In [None]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import pandas as pd

# Train and tune a Spark ML random forest on the scikit-learn iris dataset:
# load the data, move it into a Spark DataFrame, grid-search hyperparameters
# with 5-fold cross-validation, then report accuracy on a held-out split.

# Single seed reused for the data split, the forest's bootstrapping and the
# cross-validation fold assignment so that reruns are comparable.
SEED = 87

iris_data = load_iris()
X = iris_data.data
y = iris_data.target

spark = SparkSession.builder.getOrCreate()

# Build a pandas frame first: one column per feature plus an integer "label"
# column, which Spark can ingest directly.
iris_pandas = pd.DataFrame(
    data=X,
    columns=iris_data.feature_names)
iris_pandas["label"] = y

iris_df = spark.createDataFrame(iris_pandas)

# Spark ML estimators expect all features packed into one vector column.
feature_cols = iris_data.feature_names
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
iris_df = assembler.transform(iris_df)

# NOTE: despite the X_ prefix these are full DataFrames (features + label),
# not bare feature matrices. Names kept for compatibility with other cells.
X_train, X_test = iris_df.randomSplit([0.8, 0.2], seed=SEED)

# Seed the estimator itself. The previous `spark.conf.set("spark.seed", ...)`
# was not a setting the ML estimators read, so training was left unseeded.
rf = RandomForestClassifier(featuresCol="features",
                            labelCol="label",
                            seed=SEED)

# 4 x 4 x 4 x 4 = 256 parameter combinations, each fitted once per fold.
param_grid = (ParamGridBuilder()
              .addGrid(rf.maxDepth, [2, 5, 7, 10])
              .addGrid(rf.numTrees, [20, 50, 75, 100])
              .addGrid(rf.minInstancesPerNode, [2, 3, 5, 10])
              .addGrid(rf.minInfoGain, [0.0, 0.01, 0.02, 0.03])
              .build())

evaluator = MulticlassClassificationEvaluator(metricName="accuracy")

# Seed the cross-validator as well so the fold assignment is repeatable.
cv = CrossValidator(estimator=rf,
                    estimatorParamMaps=param_grid,
                    evaluator=evaluator,
                    numFolds=5,
                    seed=SEED)

cv_model = cv.fit(X_train)

best_model = cv_model.bestModel
best_max_depth = best_model.getMaxDepth()
# On the fitted model `getNumTrees` is exposed as a property in PySpark 3.x,
# so calling it fails; reading the param value works in either case.
best_num_trees = best_model.getOrDefault(best_model.numTrees)
# NOTE: this is Spark's minInstancesPerNode (minimum rows per child node);
# the variable name and printed label below use the scikit-learn-style name.
best_min_samples_split = best_model.getMinInstancesPerNode()
best_min_info_gain = best_model.getMinInfoGain()

# Score the winning model on the held-out 20% split.
predictions = best_model.transform(X_test)

accuracy = evaluator.evaluate(predictions)

print(f"Test Accuracy: {accuracy}")
print("Best max_depth: ", best_max_depth)
print("Best n_estimators: ", best_num_trees)
print("Best min_samples_split: ", best_min_samples_split)
print("Best min_info_gain: ", best_min_info_gain)

24/10/26 13:30:06 WARN Utils: Your hostname, Pepijn-Air.local resolves to a loopback address: 127.0.0.1; using 192.168.2.5 instead (on interface en0)
24/10/26 13:30:06 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/10/26 13:30:06 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
                                                                                