In [1]:
import pyspark
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

import os

In [2]:
# --- Spark session ----------------------------------------------------------
# getOrCreate() reuses an already-running session if the kernel has one.
spark = (
    SparkSession.builder
    .appName("Predict")
    .getOrCreate()
)

# --- Filesystem layout (resolved relative to the notebook's working dir) -----
BASE_PATH = os.getcwd()
MODEL_PATH = os.path.join(BASE_PATH, 'model')
DATA_PATH = os.path.join(BASE_PATH, 'data')
# NOTE(review): despite its name, this points at data/training. If that is the
# data the model was fitted on, the evaluation below measures training accuracy,
# not generalization — confirm the intended test directory.
TEST_DATASET_PATH = os.path.join(DATA_PATH, 'training')

In [3]:
def predict(sample, model_load_path: str):
    """Load a saved Spark ML ``PipelineModel`` and apply it to ``sample``.

    The loaded pipeline applies its feature transformations and the final
    prediction stage in a single ``transform`` call.

    Args:
        sample: Spark DataFrame to score (must contain the input columns the
            pipeline was fitted on).
        model_load_path: Directory the ``PipelineModel`` was saved to.

    Returns:
        The DataFrame returned by ``PipelineModel.transform(sample)``.
    """
    # Honour the caller-supplied path. Previously this ignored the argument
    # and always loaded the global MODEL_PATH.
    model = PipelineModel.load(model_load_path)
    return model.transform(sample)

In [4]:
# Sampling parameters. A fixed seed makes the reported accuracy reproducible
# across Restart & Run All.
SAMPLE_FRACTION = 0.1
SAMPLE_SEED = 42

# Load test dataset
# NOTE(review): TEST_DATASET_PATH is configured as data/training; a perfect
# accuracy may mean the model is being scored on its own training data.
test_data = spark.read.parquet(TEST_DATASET_PATH)

# Extract a batch of distinct rows. Sampling *without* replacement avoids the
# same row being counted more than once in the accuracy metric.
sample = test_data.sample(withReplacement=False, fraction=SAMPLE_FRACTION, seed=SAMPLE_SEED)

# `fraction` is only an expected size, so a small dataset can yield an empty
# batch. Check cheaply with head(1) instead of a full count().
assert sample.head(1), f"Sampled batch is empty; increase SAMPLE_FRACTION (currently {SAMPLE_FRACTION})"

# Load model and provide predictions
predictions = predict(sample, MODEL_PATH)

# Evaluate predictions against the given labels and report accuracy.
# `prediction` is in the pipeline's indexed label space, so it is compared to
# `indexedLabel` rather than the raw `label` column.
evaluator = MulticlassClassificationEvaluator(
    labelCol="indexedLabel", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)

print(f'Prediction accuracy: {accuracy}')
predictions.show()

Prediction accuracy: 1.0
+-----+--------------------+------------+--------------------+-------------+-----------+----------+--------------+
|label|            features|indexedLabel|     indexedFeatures|rawPrediction|probability|prediction|predictedLabel|
+-----+--------------------+------------+--------------------+-------------+-----------+----------+--------------+
|  0.0|(692,[122,123,124...|         1.0|(692,[122,123,124...|   [0.0,15.0]|  [0.0,1.0]|       1.0|           0.0|
|  0.0|(692,[127,128,129...|         1.0|(692,[127,128,129...|   [0.0,15.0]|  [0.0,1.0]|       1.0|           0.0|
|  0.0|(692,[153,154,155...|         1.0|(692,[153,154,155...|   [0.0,15.0]|  [0.0,1.0]|       1.0|           0.0|
|  1.0|(692,[151,152,153...|         0.0|(692,[151,152,153...|   [15.0,0.0]|  [1.0,0.0]|       0.0|           1.0|
|  1.0|(692,[155,156,157...|         0.0|(692,[155,156,157...|   [15.0,0.0]|  [1.0,0.0]|       0.0|           1.0|
+-----+--------------------+------------+--------------