In [None]:
from pyspark.conf import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf, column
from pyspark.ml import Pipeline, PipelineModel
from sparktorch import PysparkPipelineWrapper
from pyspark.sql.types import StructType
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
import json
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Spark configuration: submit to YARN via the ResourceManager at 127.0.0.1:8032,
# with 2g driver/executor memory and 2 cores per executor.
# For a quick local run, setMaster("local[2]") can replace the YARN master.
sparkConf = (
    SparkConf()
    .setMaster("yarn")
    .setAppName("MNIST_TEST")
    .set("spark.hadoop.yarn.resourcemanager.address", "127.0.0.1:8032")
    .set("spark.driver.memory", "2g")
    .set("spark.executor.memory", "2g")
    .set("spark.executor.cores", 2)
)

# getOrCreate() returns an already-running session if there is one; otherwise it
# starts a new one from sparkConf.
spark = SparkSession.builder.config(conf=sparkConf).getOrCreate()
# Only surface ERROR-level Spark logs in the notebook output.
spark.sparkContext.setLogLevel("ERROR")

In [None]:
# Build the CSV schema from a Spark StructType JSON description stored in
# schema.json (relative to the notebook's working directory).
with open("schema.json") as schemaFile:
    schemaJson = json.load(schemaFile)
schema = StructType.fromJson(schemaJson)

In [None]:
# Read the MNIST test CSV from HDFS using the explicit schema, rename the "_c0"
# column (presumably the digit label — confirm against schema.json) to "labels",
# and coalesce into at most 2 partitions.
# NOTE(review): header=true makes Spark skip the first line of the file; confirm
# mnist_test.csv really has a header row, otherwise one image is silently dropped.
df = (
    spark.read
    .schema(schema)
    .option("header", "true")
    .csv('/user/hduser/input/mnist_test.csv')
    .withColumnRenamed("_c0", "labels")
    .coalesce(2)
)

In [None]:
# Load the saved ML pipeline model from HDFS. PysparkPipelineWrapper.unwrap is
# sparktorch's post-load step, presumably needed to restore its custom torch stages.
modelPath = "/user/hduser/models/mnist"
loadedModel = PipelineModel.load(modelPath)
p = PysparkPipelineWrapper.unwrap(loadedModel)

In [None]:
# Score the test set with the pipeline and persist the result, because several
# later cells (evaluation, collect() of labels/predictions) reuse it.
predictions = p.transform(df)
predictions = predictions.persist()

In [None]:
# Accuracy of the loaded pipeline on the rows of `predictions`, comparing the
# "labels" column with the "predictions" column.
evaluator = MulticlassClassificationEvaluator(labelCol="labels", predictionCol="predictions", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)
# The data comes from mnist_test.csv, so this is test accuracy, not training accuracy.
print("Test accuracy = %g" % accuracy)

In [None]:
# Filter all images where the prediction was wrong.
# Each entry is [rowIndex, label, prediction]. rowIndex is the row's position in the
# collected result; the plotting cell uses it to look up the matching image.
# Enumerating the collected rows avoids hard-coding the dataset size (was 10000).
# NOTE(review): this assumes predictions.collect() preserves the row order of df —
# the plotting cell indexes df.collect() with these positions; confirm.
labelRows = predictions.select(col("labels"), col("predictions")).collect()
wrongPredictions = [
    [rowIndex, row["labels"], row["predictions"]]
    for rowIndex, row in enumerate(labelRows)
    if row["labels"] != row["predictions"]
]

In [None]:
# Number of misclassified test images (displayed as the cell output).
numWrong = len(wrongPredictions)
numWrong

In [None]:
# Plot wrong identified images: one grid cell per misclassified image, titled
# with the (wrong) predicted digit.
num_col = 10

# Ceiling division gives exactly enough rows (the old n//10 + 1 added an empty row
# whenever n was a multiple of 10); keep at least one row so plt.subplots always
# receives a valid grid shape, even with no wrong predictions.
num_row = max(1, (len(wrongPredictions) + num_col - 1) // num_col)

# Pixel columns only (labels dropped), one 28x28 image per row; -1 lets numpy
# infer the image count instead of assuming 10000 rows.
# NOTE(review): wrongPredictions[i][0] indexes into this array, which assumes
# df.collect() yields rows in the same order as the predictions collect() that
# produced those indices — confirm.
images = np.array(df.drop(col("labels")).collect()).reshape(-1, 28, 28)

# squeeze=False keeps `axes` 2-D even when num_row == 1; a squeezed 1-D axes
# array would raise IndexError when indexed as [row, col].
fig, axes = plt.subplots(num_row, num_col, figsize=(1.5*num_col, 2*num_row), squeeze=False)

for i, ax in enumerate(axes.flat):
    if i >= len(wrongPredictions):
        # Trailing grid cells with no image: hide them completely.
        ax.set_axis_off()
        continue
    imageIndex = wrongPredictions[i][0]
    predicted = wrongPredictions[i][2]
    # Empty tick lists remove ticks and their labels while keeping the frame.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(images[imageIndex], cmap='gray_r')
    ax.set_title('Prediction: {}'.format(int(predicted)))
fig.tight_layout()
plt.show()

In [None]:
# Shut down Spark at the session level: SparkSession.stop() stops the underlying
# SparkContext and also resets the session that SparkSession.builder.getOrCreate()
# would otherwise return, so a re-run in the same kernel starts from a clean session.
spark.stop()

In [None]:
#END