In [1]:
import pyspark
from pyspark.sql.types import *
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.linalg import VectorUDT

import os

In [2]:
# Spark session used by every later cell; getOrCreate reuses an active one.
spark = (
    SparkSession.builder
    .appName("Predict")
    .getOrCreate()
)

# Every path below is resolved against the directory the kernel was started in.
# They are kept as plain strings because they are handed straight to Spark APIs.
BASE_PATH = os.getcwd()
MODEL_PATH = os.path.join(BASE_PATH, 'model')
DATA_PATH = os.path.join(BASE_PATH, 'data')
TEST_DATASET_PATH = os.path.join(DATA_PATH, 'test_json')

In [3]:
def get_sample_from_json_buffer(json_file: str, num_samples_to_return: int=1, seed=None):
    """Read labelled feature vectors from JSON and return a random subset of rows.

    Args:
        json_file: Path accepted by ``spark.read.json`` (a JSON file or a
            directory of JSON part files).
        num_samples_to_return: Maximum number of rows to return. Fewer rows
            come back only when the input itself has fewer rows.
        seed: Optional integer seed for the shuffle, so that the same rows are
            picked on every run. ``None`` lets Spark choose a seed.

    Returns:
        A Spark DataFrame with columns ``label`` (float) and ``features``
        (vector). It holds at most ``num_samples_to_return`` rows, and no
        input row appears more than once.
    """
    # Local import keeps this cell self-contained; the notebook's import cell
    # does not bring in pyspark.sql.functions.
    from pyspark.sql import functions as F

    schema = StructType().add("label", FloatType()).add("features", VectorUDT())
    records = spark.read.schema(schema).json(json_file)
    # Shuffle, then take the first N rows. Every row is equally likely to be
    # picked and none is repeated. Sampling with replacement (fraction=1.0)
    # has two problems: it can duplicate rows, and it can return an empty
    # frame for small inputs, because each row's Poisson draw may be 0.
    return records.orderBy(F.rand(seed)).limit(num_samples_to_return)

In [4]:
def predict(sample, model_load_path: str):
    """Score ``sample`` with the fitted PipelineModel saved at ``model_load_path``.

    The saved pipeline's stages (feature transformations followed by the
    classifier) are applied via ``PipelineModel.transform``.

    Args:
        sample: Spark DataFrame with the input columns the pipeline expects
            (presumably ``label`` and ``features``; confirm against the
            training notebook).
        model_load_path: Directory the fitted PipelineModel was saved to.

    Returns:
        The DataFrame produced by ``model.transform(sample)``, i.e. ``sample``
        with the pipeline's output columns added.
    """
    # Load from the caller-supplied path, not a module-level constant, so any
    # saved model can be used for scoring.
    model = PipelineModel.load(model_load_path)
    return model.transform(sample)

In [5]:
# Number of test rows to score. With a single row the reported accuracy can
# only be 0.0 or 1.0; raise this for a meaningful estimate.
NUM_TEST_SAMPLES = 1

# Load test dataset. Sort the listing so the same file is chosen on every run
# (os.listdir returns entries in arbitrary order).
available_json_files = sorted(
    f for f in os.listdir(TEST_DATASET_PATH) if f.endswith('.json')
)
if not available_json_files:
    raise FileNotFoundError(f"No .json files found in {TEST_DATASET_PATH}")
file_path = os.path.join(TEST_DATASET_PATH, available_json_files[0])
test_data = get_sample_from_json_buffer(file_path, NUM_TEST_SAMPLES)
test_data.printSchema()

# Load model and provide predictions. Cache them because both evaluate() and
# show() below trigger a Spark action; without caching, each would rerun the
# whole pipeline.
predictions = predict(test_data, MODEL_PATH).cache()

# Evaluate predictions against the known labels and report accuracy.
# `indexedLabel` and `prediction` are both in the pipeline's indexed label
# space (compare `label` vs `indexedLabel` in the output), so they can be
# compared directly.
evaluator = MulticlassClassificationEvaluator(
        labelCol="indexedLabel", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)

print(f'Prediction accuracy: {accuracy}')
predictions.show()

root
 |-- label: float (nullable = true)
 |-- features: vector (nullable = true)

Prediction accuracy: 1.0
+-----+--------------------+------------+--------------------+-------------+-----------+----------+--------------+
|label|            features|indexedLabel|     indexedFeatures|rawPrediction|probability|prediction|predictedLabel|
+-----+--------------------+------------+--------------------+-------------+-----------+----------+--------------+
|  0.0|(692,[98,99,100,1...|         1.0|(692,[98,99,100,1...|   [0.0,15.0]|  [0.0,1.0]|       1.0|           0.0|
+-----+--------------------+------------+--------------------+-------------+-----------+----------+--------------+

