# MNIST Handwritten Digit Recognition

This notebook demonstrates how to use Apache Spark to train a decision tree. It also uses MLflow to track the training process and to give a better understanding of some critical hyperparameters of tree learning algorithms. I have attached the MLflow output (a CSV file) in the repository.

Data: MNIST Handwritten Digit Recognition.

Goal: Learn to recognize digits (0-9) from images of handwriting.

## Part 1: Setup and Loading MNIST Datasets

Before loading the dataset, I created a Spark ML cluster and attached the notebook to the cluster.

In [4]:
# Check location of the data files
%fs ls /databricks-datasets/mnist-digits/data-001

path,name,size
dbfs:/databricks-datasets/mnist-digits/data-001/mnist-digits-test.txt,mnist-digits-test.txt,11671108
dbfs:/databricks-datasets/mnist-digits/data-001/mnist-digits-train.txt,mnist-digits-train.txt,69430283


In [5]:
# Load the datasets and cache them. Caching matters because both DataFrames are reused many times below.
data_train = spark.read.format("libsvm").load("/databricks-datasets/mnist-digits/data-001/mnist-digits-train.txt")
data_test = spark.read.format("libsvm").load("/databricks-datasets/mnist-digits/data-001/mnist-digits-test.txt")

data_train.cache();
data_test.cache();
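
To make the file format concrete, here is a minimal pure-Python sketch of how one LIBSVM-style text line could be turned into a label plus a sparse feature vector. The helper `parse_libsvm_line` and the sample line are hypothetical and only illustrate the idea; Spark's own reader does this work internally. The one detail taken from Spark's behavior is that the 1-based indices in the file become 0-based in the loaded vector.

```python
# Illustrative sketch only: parse_libsvm_line is a hypothetical helper,
# and sample_line is a shortened, made-up line in LIBSVM format.

def parse_libsvm_line(line, num_features):
    """Parse 'label idx:val idx:val ...' into (label, size, indices, values)."""
    parts = line.strip().split()
    label = float(parts[0])
    indices, values = [], []
    for item in parts[1:]:
        idx, val = item.split(":")
        indices.append(int(idx) - 1)  # LIBSVM indices are 1-based; shift to 0-based
        values.append(float(val))
    if indices and max(indices) >= num_features:
        raise ValueError("feature index exceeds num_features")
    return label, num_features, indices, values

sample_line = "5 153:3 154:18 155:18"
parsed = parse_libsvm_line(sample_line, num_features=780)
print(parsed)
```

Only the non-zero pixels are stored, which is why the text files stay fairly small even though each image has several hundred features.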

## Part 2: Explore using Spark

In [7]:
print("There are {} training images and {} test images.".format(data_train.count(), data_test.count()))

In [8]:
# Checking the structure of the dataset
data_train.printSchema()

Display the data. Each image has its true label (the `label` column) and a vector of features that represents pixel intensities.

In [10]:
display(data_train.limit(2))

label,features
5.0,"List(0, 780, List(152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 260, 261, 262, 263, 264, 265, 266, 268, 269, 289, 290, 291, 292, 293, 319, 320, 321, 322, 347, 348, 349, 350, 376, 377, 378, 379, 380, 381, 405, 406, 407, 408, 409, 410, 434, 435, 436, 437, 438, 439, 463, 464, 465, 466, 467, 493, 494, 495, 496, 518, 519, 520, 521, 522, 523, 524, 544, 545, 546, 547, 548, 549, 550, 551, 570, 571, 572, 573, 574, 575, 576, 577, 578, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 648, 649, 650, 651, 652, 653, 654, 655, 656, 657, 676, 677, 678, 679, 680, 681, 682, 683), List(3.0, 18.0, 18.0, 18.0, 126.0, 136.0, 175.0, 26.0, 166.0, 255.0, 247.0, 127.0, 30.0, 36.0, 94.0, 154.0, 170.0, 253.0, 253.0, 253.0, 253.0, 253.0, 225.0, 172.0, 253.0, 242.0, 195.0, 64.0, 49.0, 238.0, 253.0, 253.0, 253.0, 253.0, 253.0, 253.0, 253.0, 253.0, 251.0, 93.0, 82.0, 82.0, 56.0, 39.0, 18.0, 219.0, 253.0, 253.0, 253.0, 253.0, 253.0, 198.0, 182.0, 247.0, 241.0, 80.0, 156.0, 107.0, 253.0, 253.0, 205.0, 11.0, 43.0, 154.0, 14.0, 1.0, 154.0, 253.0, 90.0, 139.0, 253.0, 190.0, 2.0, 11.0, 190.0, 253.0, 70.0, 35.0, 241.0, 225.0, 160.0, 108.0, 1.0, 81.0, 240.0, 253.0, 253.0, 119.0, 25.0, 45.0, 186.0, 253.0, 253.0, 150.0, 27.0, 16.0, 93.0, 252.0, 253.0, 187.0, 249.0, 253.0, 249.0, 64.0, 46.0, 130.0, 183.0, 253.0, 253.0, 207.0, 2.0, 39.0, 148.0, 229.0, 253.0, 253.0, 253.0, 250.0, 182.0, 24.0, 114.0, 221.0, 253.0, 253.0, 253.0, 253.0, 201.0, 78.0, 23.0, 66.0, 213.0, 253.0, 253.0, 253.0, 253.0, 198.0, 81.0, 2.0, 18.0, 171.0, 219.0, 253.0, 253.0, 253.0, 253.0, 195.0, 80.0, 9.0, 55.0, 172.0, 226.0, 253.0, 253.0, 253.0, 253.0, 244.0, 133.0, 11.0, 136.0, 253.0, 253.0, 253.0, 212.0, 135.0, 132.0, 16.0))"
0.0,"List(0, 780, List(127, 128, 129, 130, 131, 154, 155, 156, 157, 158, 159, 181, 182, 183, 184, 185, 186, 187, 188, 189, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 289, 290, 291, 292, 293, 294, 295, 296, 297, 300, 301, 302, 316, 317, 318, 319, 320, 321, 328, 329, 330, 343, 344, 345, 346, 347, 348, 349, 356, 357, 358, 371, 372, 373, 374, 384, 385, 386, 399, 400, 401, 412, 413, 414, 426, 427, 428, 429, 440, 441, 442, 454, 455, 456, 457, 466, 467, 468, 469, 470, 482, 483, 484, 493, 494, 495, 496, 497, 510, 511, 512, 520, 521, 522, 523, 538, 539, 540, 547, 548, 549, 550, 566, 567, 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 622, 623, 624, 625, 626, 627, 628, 629, 630, 651, 652, 653, 654, 655, 656, 657), List(51.0, 159.0, 253.0, 159.0, 50.0, 48.0, 238.0, 252.0, 252.0, 252.0, 237.0, 54.0, 227.0, 253.0, 252.0, 239.0, 233.0, 252.0, 57.0, 6.0, 10.0, 60.0, 224.0, 252.0, 253.0, 252.0, 202.0, 84.0, 252.0, 253.0, 122.0, 163.0, 252.0, 252.0, 252.0, 253.0, 252.0, 252.0, 96.0, 189.0, 253.0, 167.0, 51.0, 238.0, 253.0, 253.0, 190.0, 114.0, 253.0, 228.0, 47.0, 79.0, 255.0, 168.0, 48.0, 238.0, 252.0, 252.0, 179.0, 12.0, 75.0, 121.0, 21.0, 253.0, 243.0, 50.0, 38.0, 165.0, 253.0, 233.0, 208.0, 84.0, 253.0, 252.0, 165.0, 7.0, 178.0, 252.0, 240.0, 71.0, 19.0, 28.0, 253.0, 252.0, 195.0, 57.0, 252.0, 252.0, 63.0, 253.0, 252.0, 195.0, 198.0, 253.0, 190.0, 255.0, 253.0, 196.0, 76.0, 246.0, 252.0, 112.0, 253.0, 252.0, 148.0, 85.0, 252.0, 230.0, 25.0, 7.0, 135.0, 253.0, 186.0, 12.0, 85.0, 252.0, 223.0, 7.0, 131.0, 252.0, 225.0, 71.0, 85.0, 252.0, 145.0, 48.0, 165.0, 252.0, 173.0, 86.0, 253.0, 225.0, 114.0, 238.0, 253.0, 162.0, 85.0, 252.0, 249.0, 146.0, 48.0, 29.0, 85.0, 178.0, 225.0, 253.0, 223.0, 167.0, 56.0, 85.0, 252.0, 252.0, 252.0, 229.0, 215.0, 252.0, 252.0, 252.0, 196.0, 130.0, 28.0, 199.0, 
252.0, 252.0, 253.0, 252.0, 252.0, 233.0, 145.0, 25.0, 128.0, 252.0, 253.0, 252.0, 141.0, 37.0))"
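
Each `features` value above is a sparse vector: a size (780), a list of positions, and a list of the non-zero values at those positions. The sketch below shows, in plain Python, how such a triple could be expanded into a dense list. The function `to_dense` is a hypothetical helper, and the three positions and values are a shortened excerpt used for illustration only.

```python
# Illustrative sketch: expand a (size, indices, values) sparse triple
# into a dense list of floats. to_dense is a hypothetical helper name.

def to_dense(size, indices, values):
    dense = [0.0] * size
    for i, v in zip(indices, values):
        dense[i] = v  # every position not listed stays 0.0
    return dense

size, indices, values = 780, [152, 153, 154], [3.0, 18.0, 18.0]
dense = to_dense(size, indices, values)

# Collect the non-zero entries to confirm the round trip.
nonzero = [(i, v) for i, v in enumerate(dense) if v != 0.0]
print(len(dense), nonzero)
```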


In [11]:
import numpy as np
import matplotlib.image as mpimg
import math
from matplotlib import pyplot as plt

# Function to render images stored in the dataset.
def show_images(data):
  # Each item in data is expected to be a tuple:
  #   - the first element is a 780-dimensional sparse vector holding the MNIST pixel features;
  #   - the second element is a number giving the label or predicted digit.
  # The images are displayed in rows of four, each titled with its label.
  # Avoid passing too many images to this function.
  # e.g. show_images([(r.features, r.label) for r in df.take(4)])

  fig = plt.figure()
  columns = 4
  rows = math.ceil(len(data) / columns)  # determine how many rows we need

  for i, (features, label) in enumerate(data):
      # An image is 28x28 (= 784) grayscale pixels, but the feature vector has only 780 entries.
      # Convert it to a dense float array, zero-pad it to 784 values, and reshape it to a 28x28 matrix.
      pixels = np.asarray(features.toArray(), dtype='float')
      img = np.pad(pixels, (0, 784 - len(pixels)), 'constant', constant_values=(0, 0)).reshape((28, 28))
      # Create the subplot and draw the image on it.
      ax = fig.add_subplot(rows, columns, i + 1)
      ax.set_title(str(int(label)))  # title each image with its label
      ax.imshow(img, cmap='gray')  # render the image
      ax.axis('off')  # turn off the axes

  # display the images
  display(fig)

In [12]:
show_images([(r.features, r.label) for r in data_train.take(4)])
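
The padding and reshaping step inside `show_images` can be hard to picture, so here is a small standalone sketch of the same idea without Spark or matplotlib. The helpers `to_image_rows` and `render_ascii` are hypothetical names, and the single bright pixel is made-up test data, not a real digit.

```python
# Illustrative sketch: zero-pad a 780-value pixel list to 784 values,
# cut it into 28 rows of 28 pixels, and render it as text.

def to_image_rows(pixels, side=28):
    if len(pixels) > side * side:
        raise ValueError("too many pixels for the requested image size")
    padded = list(pixels) + [0.0] * (side * side - len(pixels))
    return [padded[r * side:(r + 1) * side] for r in range(side)]

def render_ascii(rows, threshold=128):
    # '#' marks bright pixels, '.' marks dark ones.
    return ["".join("#" if p >= threshold else "." for p in row) for row in rows]

pixels = [0.0] * 780
pixels[152] = 255.0  # position 152 = row 5, column 12 (5 * 28 + 12)
rows = to_image_rows(pixels)
art = render_ascii(rows)
print("\n".join(art))
```

Row-major order is assumed here, which matches how `reshape((28, 28))` lays out a flat array.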

## Part 3: Training a base Decision Tree Classifier

In [14]:
# Import the ML classification, evaluator, indexer, and pipeline classes 
from pyspark.ml.classification import DecisionTreeClassifier, DecisionTreeClassificationModel
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline

In [15]:
# StringIndexer: Read input column "label" (digits) and annotate them as categorical values.
indexer = StringIndexer(inputCol="label", outputCol="indexedLabel")
# DecisionTreeClassifier: Learn to predict column "indexedLabel" using the "features" column.
dtc = DecisionTreeClassifier(labelCol="indexedLabel")
# Chain indexer + dtc together into a single ML Pipeline.
pipeline = Pipeline(stages=[indexer, dtc])
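
By default, `StringIndexer` orders labels by descending frequency, so the most common label gets index 0.0. The pure-Python sketch below mimics that ordering on a made-up list of labels. The function `fit_string_indexer` is a hypothetical stand-in, not Spark's implementation, and the alphabetical tie-break is an assumption that does not matter for this sample.

```python
from collections import Counter

# Illustrative sketch of frequency-descending label indexing.
# fit_string_indexer is a hypothetical helper; the labels are made up.

def fit_string_indexer(labels):
    counts = Counter(labels)
    # Most frequent label first; ties (none in this sample) fall back to string order.
    ordered = sorted(counts, key=lambda lab: (-counts[lab], lab))
    return {lab: float(i) for i, lab in enumerate(ordered)}

labels = ["1.0", "7.0", "1.0", "3.0", "7.0", "1.0"]
mapping = fit_string_indexer(labels)
print(mapping)
```

The practical consequence is that an index such as 0.0 need not correspond to the digit 0, so values in the `prediction` column live in index space rather than digit space.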

In [16]:
# First have a look at the performance of the base model.
# The evaluator uses metricName="weightedPrecision", so the value below is weighted precision rather than plain accuracy.
pipelineModel = pipeline.fit(data_train)
predictions = pipelineModel.transform(data_test)
evaluator = MulticlassClassificationEvaluator(labelCol="indexedLabel", predictionCol="prediction", metricName="weightedPrecision")
weighted_precision = evaluator.evaluate(predictions)

In [17]:
print("Weighted precision = %g" % (weighted_precision))

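Using a simple decision tree classifier, we get a score of about 70%. Note that the evaluator is configured with `metricName="weightedPrecision"`, so this figure is weighted precision rather than plain accuracy. Next, let's visualize these results before training a cross-validated model using MLflow.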
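
Accuracy and weighted precision can differ, so it helps to see both computed side by side. The sketch below is a plain-Python illustration on made-up labels, assuming weighted precision means per-class precision weighted by each class's share of the true labels. The function names are hypothetical.

```python
from collections import Counter

# Illustrative sketch: accuracy vs. weighted precision on made-up data.

def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def weighted_precision(y_true, y_pred):
    support = Counter(y_true)      # how often each class truly occurs
    predicted = Counter(y_pred)    # how often each class is predicted
    correct = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    total = len(y_true)
    score = 0.0
    for label, n_true in support.items():
        # Precision is taken as 0 for a class that is never predicted.
        precision = correct[label] / predicted[label] if predicted[label] else 0.0
        score += precision * n_true / total
    return score

y_true = [0, 0, 0, 1, 1, 2]
y_pred = [0, 0, 1, 1, 1, 0]
acc = accuracy(y_true, y_pred)
wp = weighted_precision(y_true, y_pred)
print(acc, wp)
```

On this toy data, four of six predictions are correct, while the weighted precision is lower because class 2 is never predicted correctly.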

### Visualize Results

In [20]:
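# The model predicts in the "indexedLabel" space created by StringIndexer, so compare against that column rather than the raw "label".
correct_prediction = predictions.filter(predictions['indexedLabel'] == predictions['prediction'])
# Map each predicted index back to its digit for the image titles.
index_to_digit = [float(s) for s in pipelineModel.stages[0].labels]
show_images([(r.features, index_to_digit[int(r.prediction)]) for r in correct_prediction.sample(False, 0.1, 0).take(8)])

In [21]:
incorrect_prediction = predictions.filter(predictions['indexedLabel'] != predictions['prediction'])
show_images([(r.features, index_to_digit[int(r.prediction)]) for r in incorrect_prediction.sample(False, 0.1, 0).take(8)])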

## Part 4: Training a Decision Tree Classifier with Automated MLflow Tracking for CrossValidator model tuning

In [23]:
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

grid = ParamGridBuilder() \
  .addGrid(dtc.maxDepth, [2, 3, 4, 5, 6, 7, 8]) \
  .addGrid(dtc.maxBins, [2, 4, 8]) \
  .build()

cv = CrossValidator(estimator=pipeline, evaluator=evaluator, estimatorParamMaps=grid, numFolds=3)
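
To get a feel for how much work this search involves, the sketch below enumerates the same grid in plain Python and counts the model fits. The dictionary keys simply mirror the parameter names. The fit count assumes one fit per parameter combination per fold, plus one final refit of the best combination on the full training set.

```python
from itertools import product

# Illustrative sketch: enumerate the hyperparameter grid and count model fits.
max_depths = [2, 3, 4, 5, 6, 7, 8]
max_bins = [2, 4, 8]
num_folds = 3

param_maps = [{"maxDepth": d, "maxBins": b} for d, b in product(max_depths, max_bins)]

fits_during_search = len(param_maps) * num_folds  # one fit per combination per fold
total_fits = fits_during_search + 1               # plus the final refit of the best combination

print(len(param_maps), fits_during_search, total_fits)
```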

In [24]:
import mlflow
import mlflow.mleap
with mlflow.start_run():
  cvModel = cv.fit(data_train)
  mlflow.set_tag('owner_team', 'UX Data Science') # Logs user-defined tags
  test_metric = evaluator.evaluate(cvModel.transform(data_test))
  mlflow.log_metric('test_' + evaluator.getMetricName(), test_metric) # Logs additional metrics
  mlflow.mleap.log_model(spark_model=cvModel.bestModel, sample_input=data_test, artifact_path='best-model') # Logs the best model via mleap

## Conclusion

Using MLflow, I found the following parameters for the best model:
- maxDepth: 8
- maxBins: 4

I have attached a csv file showing the results of all 21 models.

Future steps to improve the model:
We could create new images by rotating the current images by 180 degrees, which would double the training data. More training data may help the model train better and improve accuracy.
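
As a rough illustration of the rotation idea, the sketch below rotates a tiny made-up image by 180 degrees in plain Python. The function `rotate_180` is a hypothetical helper, and the 2x3 grid stands in for a 28x28 digit. One caveat to keep in mind: a rotated 6 looks much like a 9 (and vice versa), so the labels for those digits would likely need special handling.

```python
# Illustrative sketch: rotate an image (a list of pixel rows) by 180 degrees
# by reversing the row order and reversing each row.

def rotate_180(image):
    return [row[::-1] for row in image[::-1]]

image = [
    [1, 2, 3],
    [4, 5, 6],
]
rotated = rotate_180(image)
print(rotated)
```

Applying the rotation twice returns the original image, which is a quick sanity check for the helper.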