In [10]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import MultilayerPerceptronClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
import numpy as np
from sklearn.datasets import load_digits

spark = SparkSession.builder.appName("MLP").getOrCreate()

dataset = load_digits()

In [None]:
data_array = np.concatenate((dataset.data, dataset.target.reshape(-1, 1)), axis=1)
columns = dataset.feature_names + ['label']
spark_df = spark.createDataFrame(data_array, columns)
spark_df.show(5)

train, test = spark_df.randomSplit([0.8, 0.2], seed=87)

input_features = spark_df.columns[:-1]

assembler = VectorAssembler(
    inputCols=input_features, outputCol="features")
train_df = assembler.transform(train)
test_df = assembler.transform(test)

layers = [len(input_features), 64, 64, len(dataset.target_names)]
mlp = MultilayerPerceptronClassifier(
    layers=layers, featuresCol='features', labelCol='label', seed=87)

model = mlp.fit(train_df)
predictions = model.transform(test_df)

evaluator = MulticlassClassificationEvaluator(
    predictionCol="prediction", labelCol="label", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)

evaluator_precision = MulticlassClassificationEvaluator(
    predictionCol="prediction", labelCol="label", metricName="weightedPrecision")
precision = evaluator_precision.evaluate(predictions)

evaluator_recall = MulticlassClassificationEvaluator(
    predictionCol="prediction", labelCol="label", metricName="weightedRecall")
recall = evaluator_recall.evaluate(predictions)

evaluator_f1 = MulticlassClassificationEvaluator(
    predictionCol="prediction", labelCol="label", metricName="f1")
f1 = evaluator_f1.evaluate(predictions)

print(f"Accuracy: {accuracy}")
print(f"Precision: {precision}")
print(f"Recall: {recall}")
print(f"F1: {f1}")

confusion_matrix = (predictions
                    .groupBy("label")
                    .pivot("prediction")
                    .count()
                    .fillna(0)
                    .orderBy("label")
)
confusion_matrix.show(truncate=False)

+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+-----+
|pixel_0_0|pixel_0_1|pixel_0_2|pixel_0_3|pixel_0_4|pixel_0_5|pixel_0_6|pixel_0_7|pixel_1_0|pixel_1_1|pixel_1_2|pixel_1_3|pixel_1_4|pixel_1_5|pixel_1_6|pixel_1_7|pixel_2_0|pixel_2_1|pixel_2_2|pixel_2_3|pixel_2_4|pixel_2_5|pixel_2_6|pixel_2_7|pixel_3_0|pixel_3_1|pixel_3_2|pixel_3_3|pixel_3_4|pixel_3_5|pixel_3_6|pixel_3_7|pixel_4_0|pixel_4_1|pixel_4_2|p

24/11/15 00:32:41 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS


Accuracy: 0.9435897435897436
Precision: 0.9429757436753393
Recall: 0.9435897435897437
F1: 0.9428243681109911
+-----+---+---+---+---+---+---+---+---+---+---+
|label|0.0|1.0|2.0|3.0|4.0|5.0|6.0|7.0|8.0|9.0|
+-----+---+---+---+---+---+---+---+---+---+---+
|0.0  |46 |0  |0  |0  |1  |0  |0  |0  |0  |0  |
|1.0  |0  |42 |0  |0  |0  |0  |0  |0  |0  |0  |
|2.0  |0  |0  |39 |1  |0  |0  |0  |0  |0  |0  |
|3.0  |0  |0  |2  |40 |0  |0  |0  |0  |3  |1  |
|4.0  |0  |0  |0  |0  |34 |0  |0  |1  |0  |0  |
|5.0  |0  |0  |0  |1  |0  |42 |0  |0  |1  |0  |
|6.0  |0  |0  |0  |0  |0  |0  |26 |0  |1  |0  |
|7.0  |0  |0  |0  |0  |0  |0  |0  |30 |0  |0  |
|8.0  |0  |3  |0  |1  |0  |1  |1  |0  |31 |2  |
|9.0  |1  |0  |0  |0  |1  |0  |0  |0  |0  |38 |
+-----+---+---+---+---+---+---+---+---+---+---+

