# Decision Trees for handwritten digit recognition

This notebook demonstrates learning a [Decision Tree](https://en.wikipedia.org/wiki/Decision_tree_learning) using Spark's distributed implementation.  It explains some critical [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_optimization) of the tree learning algorithm and uses examples to show how tuning them can improve accuracy.

**Background**: To learn more about Decision Trees, check out the resources at the end of this notebook.  [The visual description of ML and Decision Trees](http://www.r2d3.us/visual-intro-to-machine-learning-part-1/) gives helpful intuition for this notebook, and [Wikipedia](https://en.wikipedia.org/wiki/Decision_tree_learning) covers the details.

**Data**: We use the classic MNIST handwritten digit recognition dataset.

**Goal**: Our goal is to learn to recognize digits (0-9) from images of handwriting.  However, we will focus on understanding trees rather than on this particular learning problem.

**Takeaways**: Decision Trees take several hyperparameters which can affect the accuracy of the learned model.  There is no one "best" setting for these for all datasets.  To get the optimal accuracy, we need to tune these hyperparameters based on our data.
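To make the effect of tree hyperparameters concrete, here is a minimal, pure-Python sketch of greedy tree learning with Gini impurity. It is a toy, not Spark's implementation: the parameter names `max_depth` and `min_instances_per_node` are hypothetical stand-ins loosely mirroring Spark's `maxDepth` and `minInstancesPerNode`, and unlike Spark (which discretizes continuous features into at most `maxBins` bins) it tries every distinct feature value as a threshold.

```python
from collections import Counter


def gini(labels):
    """Gini impurity of a list of labels: 1 - sum_k p_k^2."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def best_split(rows, labels, min_instances_per_node):
    """Return (weighted_gini, feature, threshold) of the best split, or None."""
    best = None
    n = len(labels)
    for f in range(len(rows[0])):
        for t in sorted({r[f] for r in rows}):
            left = [y for r, y in zip(rows, labels) if r[f] <= t]
            right = [y for r, y in zip(rows, labels) if r[f] > t]
            # Reject splits that would leave a child with too few examples.
            if len(left) < min_instances_per_node or len(right) < min_instances_per_node:
                continue
            score = (len(left) * gini(left) + len(right) * gini(right)) / n
            if best is None or score < best[0]:
                best = (score, f, t)
    return best


def build_tree(rows, labels, max_depth, min_instances_per_node=1, depth=0):
    """Grow a tree; leaves are labels, internal nodes are (feature, threshold, left, right)."""
    majority = Counter(labels).most_common(1)[0][0]
    # Stop at the depth limit or when the node is already pure.
    if depth >= max_depth or gini(labels) == 0.0:
        return majority
    split = best_split(rows, labels, min_instances_per_node)
    if split is None:
        return majority
    _, f, t = split
    left = [(r, y) for r, y in zip(rows, labels) if r[f] <= t]
    right = [(r, y) for r, y in zip(rows, labels) if r[f] > t]
    return (
        f,
        t,
        build_tree([r for r, _ in left], [y for _, y in left], max_depth, min_instances_per_node, depth + 1),
        build_tree([r for r, _ in right], [y for _, y in right], max_depth, min_instances_per_node, depth + 1),
    )


def predict(tree, row):
    """Follow thresholds from the root until reaching a leaf label."""
    while isinstance(tree, tuple):
        f, t, left, right = tree
        tree = left if row[f] <= t else right
    return tree


def accuracy(tree, rows, labels):
    return sum(predict(tree, r) == y for r, y in zip(rows, labels)) / len(labels)


# XOR-style toy data: no single threshold separates the two classes.
rows = [(0, 0), (0, 1), (1, 0), (1, 1)]
labels = [0, 1, 1, 0]
for depth in (0, 1, 2):
    tree = build_tree(rows, labels, max_depth=depth)
    print(depth, accuracy(tree, rows, labels))  # prints "0 0.5", "1 0.5", "2 1.0"
```

A deeper limit lets the tree carve out finer regions (here, fitting XOR needs depth 2), while a larger minimum node size blocks small splits; on real data, the same knobs trade underfitting against overfitting.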

## Load MNIST training and test datasets

Our datasets are vectors of pixels representing images of handwritten digits.  For example:

![Image of a digit](http://training.databricks.com/databricks_guide/digit.png)
![Image of all 10 digits](http://training.databricks.com/databricks_guide/MNIST-small.png)

These datasets are stored in the popular LibSVM dataset format.  We will load them using MLlib's LibSVM dataset reader utility.
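Each LibSVM line has the form `<label> <index>:<value> <index>:<value> ...`, listing only nonzero features with one-based indices. The sketch below parses one such line by hand purely to illustrate the format; Spark's reader does this for you. It assumes 784 features (a 28×28 MNIST image), and the sample line is hypothetical, not taken from the dataset.

```python
NUM_FEATURES = 784  # assumption: 28 x 28 pixel images
SIDE = 28


def parse_libsvm_line(line, num_features=NUM_FEATURES):
    """Parse one LibSVM line into (label, dense list of feature values)."""
    tokens = line.split()
    label = float(tokens[0])
    features = [0.0] * num_features
    for token in tokens[1:]:
        index, value = token.split(":")
        # LibSVM indices are one-based; shift to zero-based list positions.
        features[int(index) - 1] = float(value)
    return label, features


def to_image(features, side=SIDE):
    """Reshape a flat pixel vector into `side` rows of `side` pixels each."""
    return [features[r * side:(r + 1) * side] for r in range(side)]


# Hypothetical line: digit 5 with three nonzero pixel intensities.
label, pixels = parse_libsvm_line("5 153:3 154:18 155:18")
image = to_image(pixels)
print(label, image[5][12:15])  # prints "5.0 [3.0, 18.0, 18.0]"
```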

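```python
# Load the MNIST training and test sets with Spark's LibSVM data source.
# `spark` is the SparkSession that Databricks notebooks provide automatically.
training = spark.read.format("libsvm").load("/databricks-datasets/mnist-digits/data-001/mnist-digits-train.txt")
test = spark.read.format("libsvm").load("/databricks-datasets/mnist-digits/data-001/mnist-digits-test.txt")

# Cache both DataFrames in memory, since they are reused repeatedly.
training.cache()
test.cache()
```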

```python
from pyspark.ml.classification import DecisionTreeClassifier, DecisionTreeClassificationModel
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer

# Number of training examples; this first action also materializes the cache.
print(training.count())
```

## Task 1

Use the links provided in the Piazza post to learn about Spark, ML Pipelines, and Decision Trees, then implement them in this notebook.

**Make sure you use the Spark ML Pipeline to carry out your workflow; otherwise, using Spark would be no different from using SKLearn.**
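A Spark ML `Pipeline` chains stages: during `fit`, each Estimator is fit on the output of the stages before it, producing a Model (a Transformer), and the result is a `PipelineModel` that replays the learned transformations on new data. The pure-Python sketch below mimics only that fit/transform contract. Every class here is a hypothetical toy, not Spark's API; `MaxScaler` and `NearestCentroid` are stand-ins for real stages such as the imported `StringIndexer` and `DecisionTreeClassifier`.

```python
class MaxScaler:
    """Toy estimator: learns the largest absolute feature value in the training rows."""
    def fit(self, rows):
        peak = max(abs(v) for r in rows for v in r["features"]) or 1.0
        return MaxScalerModel(peak)


class MaxScalerModel:
    """Toy transformer: divides every feature by the peak learned during fit()."""
    def __init__(self, peak):
        self.peak = peak

    def transform(self, rows):
        return [dict(r, features=[v / self.peak for v in r["features"]]) for r in rows]


class NearestCentroid:
    """Toy estimator standing in for a classifier: stores the mean vector of each label."""
    def fit(self, rows):
        sums, counts = {}, {}
        for r in rows:
            acc = sums.setdefault(r["label"], [0.0] * len(r["features"]))
            for i, v in enumerate(r["features"]):
                acc[i] += v
            counts[r["label"]] = counts.get(r["label"], 0) + 1
        centroids = {lab: [s / counts[lab] for s in acc] for lab, acc in sums.items()}
        return NearestCentroidModel(centroids)


class NearestCentroidModel:
    """Toy transformer: adds a `prediction` field holding the closest centroid's label."""
    def __init__(self, centroids):
        self.centroids = centroids

    def transform(self, rows):
        def closest(x):
            return min(self.centroids,
                       key=lambda lab: sum((a - b) ** 2 for a, b in zip(x, self.centroids[lab])))
        return [dict(r, prediction=closest(r["features"])) for r in rows]


class Pipeline:
    """Fits stages in order, feeding each stage the output of the previous ones."""
    def __init__(self, stages):
        self.stages = stages

    def fit(self, rows):
        fitted = []
        for i, stage in enumerate(self.stages):
            model = stage.fit(rows) if hasattr(stage, "fit") else stage
            fitted.append(model)
            if i < len(self.stages) - 1:  # the last stage's output is not needed during fit
                rows = model.transform(rows)
        return PipelineModel(fitted)


class PipelineModel:
    """Applies every fitted stage in sequence to new rows."""
    def __init__(self, stages):
        self.stages = stages

    def transform(self, rows):
        for stage in self.stages:
            rows = stage.transform(rows)
        return rows


train_rows = [{"features": [0.0, 0.0], "label": 0.0}, {"features": [10.0, 10.0], "label": 1.0}]
test_rows = [{"features": [1.0, 2.0], "label": 0.0}, {"features": [9.0, 8.0], "label": 1.0}]
model = Pipeline([MaxScaler(), NearestCentroid()]).fit(train_rows)
print([r["prediction"] for r in model.transform(test_rows)])  # prints "[0.0, 1.0]"
```

Note that the test rows are scaled with the peak learned from the training rows (10), not their own maximum; keeping preprocessing inside the pipeline is what prevents that kind of leakage during evaluation.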

## Task 2

Answer the following questions:

##### 1. What do the parameters of the Decision Tree do? How do they change the complexity of the tree (i.e. how can you tune them if the model is overfitting/underfitting)?

Ans:

##### 2. How does the Spark ML Pipeline speed up the computation? Run the same model and evaluation/validation techniques on your own machine and compare the computation speed. Use SKLearn's 5-fold cross-validator for hyperparameter tuning on your machine, and compare its speed with that of the cross-validator on Databricks.

Ans: