In [8]:
from pyspark.mllib.evaluation import BinaryClassificationMetrics, MulticlassMetrics
from pyspark.sql import SparkSession

# Reuse the active Spark session, or create one if none exists yet.
spark = SparkSession.builder.appName("MetricsExample").getOrCreate()

# NOTE: 'predictions' is not defined in this cell. It must already exist in the
# kernel as a DataFrame of model output with numeric 'prediction' and 'label' columns.

# MulticlassMetrics consumes an RDD of (prediction, label) pairs and expects both
# values to be doubles; an integer 'label' column can be rejected with a type
# error. Cast both columns explicitly before converting to an RDD.
predictionAndLabels = predictions.selectExpr(
    "cast(prediction as double) as prediction",
    "cast(label as double) as label",
).rdd

# Instantiate MulticlassMetrics for the RDD
metrics = MulticlassMetrics(predictionAndLabels)

# Accuracy
accuracy = metrics.accuracy
print("Accuracy:", accuracy)

# Precision (class-weighted average, not the precision of the positive class alone)
precision = metrics.weightedPrecision
print("Precision:", precision)

# Recall / Sensitivity (class-weighted average)
recall = metrics.weightedRecall
print("Recall:", recall)

# Specificity = TN / (TN + FP)
# Assumes a binary problem with labels 0.0 / 1.0, where row 0 of the confusion
# matrix holds the actual negatives (rows = actual class, columns = predicted
# class, ascending label order) - TODO confirm against your label encoding.
confusionMatrix = metrics.confusionMatrix().toArray()
# Only a 2x2 matrix has a well-defined TN/FP pair at these positions. With a
# single class present the [0][1] lookup would raise IndexError, and with more
# than two classes it would silently ignore the other columns.
is_binary = confusionMatrix.shape == (2, 2)
TN = confusionMatrix[0][0] if is_binary else float("nan")
FP = confusionMatrix[0][1] if is_binary else float("nan")
negatives = TN + FP
# NaN (instead of an exception or a misleading number) when specificity is undefined.
specificity = TN / negatives if negatives > 0 else float("nan")
print("Specificity:", specificity)

# F1-Score (class-weighted average; weightedFMeasure is a method, unlike the
# weightedPrecision / weightedRecall properties above)
f1Score = metrics.weightedFMeasure()
print("F1-Score:", f1Score)

# Stop the Spark session (if you're done with other tasks).
# NOTE: stopping the session makes DataFrames created from it (including
# 'predictions') unusable; skip this line if later cells still need them.
spark.stop()

ERROR:root:KeyboardInterrupt while sending command.
Traceback (most recent call last):
  File "C:\Users\yubar\anaconda3\envs\spark_env\lib\site-packages\py4j\java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
  File "C:\Users\yubar\anaconda3\envs\spark_env\lib\site-packages\py4j\clientserver.py", line 511, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
  File "C:\Users\yubar\anaconda3\envs\spark_env\lib\socket.py", line 717, in readinto
    return self._sock.recv_into(b)
KeyboardInterrupt


KeyboardInterrupt: 

### Another example with ROC and Precision-Recall Curve

In [None]:
from pyspark.sql import SparkSession
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.feature import VectorAssembler
from pyspark.mllib.evaluation import BinaryClassificationMetrics, MulticlassMetrics

# Initialize Spark Session (reuses the active session if one exists)
spark = SparkSession.builder.appName("EvaluationExample").getOrCreate()

# NOTE: 'df' is not defined in this cell. It must already exist in the kernel as a
# Spark DataFrame with feature columns and a binary target column named 'label'.
# Split the data into training and test sets (fixed seed for reproducibility)
(train, test) = df.randomSplit([0.7, 0.3], seed=12345)

# VectorAssembler to combine feature columns.
# Every column except 'label' is treated as a feature - presumably all of them
# are numeric; verify against your schema.
assembler = VectorAssembler(inputCols=[column for column in df.columns if column != 'label'], outputCol="features")
train = assembler.transform(train)
test = assembler.transform(test)

# Train a Logistic Regression model. Sections 1-3 and the area under PR below work
# with any classifier that outputs 'prediction' and 'rawPrediction'; only the
# curve points rely on the logistic-regression-specific model.evaluate().
lr = LogisticRegression(featuresCol='features', labelCol='label')
model = lr.fit(train)

# Make predictions
predictions = model.transform(test)

# Select (prediction, true label) and compute metrics
predictionAndLabels = predictions.select("prediction", "label")

# 1. Accuracy, Precision, Recall, F1-Score
# Precision, recall and F1 are the class-weighted variants of each metric.
evaluatorMulti = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction")
for display_name, metric_name in [
    ("Accuracy", "accuracy"),
    ("Precision", "weightedPrecision"),
    ("Recall", "weightedRecall"),
    ("F1-Score", "f1"),
]:
    print(display_name + ":", evaluatorMulti.evaluate(predictionAndLabels, {evaluatorMulti.metricName: metric_name}))

# 2. Confusion Matrix
# MulticlassMetrics expects (prediction, label) pairs of doubles; an integer
# 'label' column can be rejected with a type error, so cast before building the RDD.
metrics = MulticlassMetrics(
    predictionAndLabels.selectExpr(
        "cast(prediction as double) as prediction",
        "cast(label as double) as label",
    ).rdd
)
print("Confusion Matrix:\n", metrics.confusionMatrix().toArray())

# 3. ROC Curve and AUC
# Uses the default 'rawPrediction' column produced by the model.
evaluator = BinaryClassificationEvaluator(labelCol="label")
print("AUC:", evaluator.evaluate(predictions, {evaluator.metricName: "areaUnderROC"}))

# 4. Precision-Recall Curve
# The area under the PR curve is available from the same evaluator.
print("Area under PR:", evaluator.evaluate(predictions, {evaluator.metricName: "areaUnderPR"}))

# The individual curve points come from the logistic regression summary on the
# test set: 'roc' holds (FPR, TPR) rows and 'pr' holds (recall, precision) rows.
testSummary = model.evaluate(test)
testSummary.roc.show(5)
testSummary.pr.show(5)

# Stop the Spark session.
# NOTE: DataFrames created from this session become unusable afterwards.
spark.stop()
