# Decision Tree Classification for MNIST with Cross-Validation

This notebook demonstrates how to use PySpark's Decision Tree classifier, with automated hyperparameter tuning via cross-validation, to classify handwritten digits from the MNIST dataset.

## Setup and Data Loading

First, download the MNIST dataset in LibSVM format and load it into Spark:

1. Download the training and test sets
2. Decompress the files
3. Copy them into HDFS so that Spark can read them
4. Read the data with Spark's `libsvm` data source
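
For orientation, each line of a LibSVM file holds a label followed by sparse `index:value` pairs. The sketch below parses one such line in plain Python. The helper name `parse_libsvm_line` and the sample line are made up for illustration, and the shift to zero-based indices mirrors what Spark's reader does with the one-based indices in the file.

```python
def parse_libsvm_line(line):
    """Parse 'label idx:val idx:val ...' into (label, {zero_based_index: value})."""
    parts = line.split()
    label = float(parts[0])
    features = {}
    for item in parts[1:]:
        index, value = item.split(":")
        # LibSVM indices are one-based; shift them to zero-based.
        features[int(index) - 1] = float(value)
    return label, features


# Illustrative line written for this sketch, not read from the downloaded file.
label, features = parse_libsvm_line("5 153:3 154:18 155:18")
print(label, features)
```

Only non-zero pixels are stored, which is why Spark represents each image as a sparse vector.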

In [1]:
!wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/mnist.bz2 https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/mnist.t.bz2
!bzip2 -d mnist.bz2 mnist.t.bz2

--2025-03-23 00:54:35--  https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/mnist.bz2
Resolving www.csie.ntu.edu.tw (www.csie.ntu.edu.tw)... 140.112.30.26
Connecting to www.csie.ntu.edu.tw (www.csie.ntu.edu.tw)|140.112.30.26|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 15179306 (14M) [application/x-bzip2]
Saving to: ‘mnist.bz2’


2025-03-23 00:54:38 (7.30 MB/s) - ‘mnist.bz2’ saved [15179306/15179306]

--2025-03-23 00:54:38--  https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/mnist.t.bz2
Reusing existing connection to www.csie.ntu.edu.tw:443.
HTTP request sent, awaiting response... 200 OK
Length: 2508388 (2.4M) [application/x-bzip2]
Saving to: ‘mnist.t.bz2’


2025-03-23 00:54:38 (101 MB/s) - ‘mnist.t.bz2’ saved [2508388/2508388]

FINISHED --2025-03-23 00:54:38--
Total wall clock time: 3.3s
Downloaded: 2 files, 17M in 2.0s (8.41 MB/s)


In [2]:
!hdfs dfs -put -f mnist
!hdfs dfs -put -f mnist.t

In [3]:
import pyspark
from pyspark.ml import Pipeline
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.feature import StringIndexer
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

In [4]:
# Load the MNIST training and test sets, which are stored in LibSVM format.
# `spark` is assumed to be the SparkSession provided by the notebook environment.
training = spark.read.format("libsvm").load("mnist")
test = spark.read.format("libsvm").load("mnist.t")

# Cache data for multiple uses
training.cache()
test.cache()

print(f"There are {training.count()} training images and {test.count()} test images.")

25/03/23 00:56:04 WARN LibSVMFileFormat: 'numFeatures' option not specified, determining the number of features by going though the input. If you know the number in advance, please specify it via 'numFeatures' option to avoid the extra scan.
25/03/23 00:56:23 WARN LibSVMFileFormat: 'numFeatures' option not specified, determining the number of features by going though the input. If you know the number in advance, please specify it via 'numFeatures' option to avoid the extra scan.

There are 60000 training images and 10000 test images.

In [5]:
# Show the DataFrame schema, then the first 10 rows
display(training)
training.show(n=10)

DataFrame[label: double, features: vector]

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  5.0|(780,[152,153,154...|
|  0.0|(780,[127,128,129...|
|  4.0|(780,[160,161,162...|
|  1.0|(780,[158,159,160...|
|  9.0|(780,[208,209,210...|
|  2.0|(780,[155,156,157...|
|  1.0|(780,[124,125,126...|
|  3.0|(780,[151,152,153...|
|  1.0|(780,[152,153,154...|
|  4.0|(780,[134,135,161...|
+-----+--------------------+
only showing top 10 rows
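
The `780` at the start of each vector above is the feature dimension that Spark inferred by scanning the file, which is what the `numFeatures` warning refers to. A minimal sketch of passing it explicitly follows. The helper name `load_libsvm` is hypothetical, `spark` is assumed to be the notebook's existing SparkSession, and whether 780 also suits the test file is an assumption worth checking, since a file whose highest-indexed pixels are always zero would otherwise be inferred with a smaller dimension.

```python
def load_libsvm(spark, path, num_features=780):
    """Read a LibSVM file with a fixed feature dimension, skipping the inference scan."""
    return (
        spark.read.format("libsvm")
        .option("numFeatures", num_features)  # avoids the extra pass over the input
        .load(path)
    )
```

In this notebook the calls would be `training = load_libsvm(spark, "mnist")` and `test = load_libsvm(spark, "mnist.t")`, which gives both DataFrames feature vectors of the same size.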



## Pipeline Construction

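The machine learning pipeline has two stages:
- A `StringIndexer` that maps each distinct value in the `label` column to a categorical index in `indexedLabel` (by default, the most frequent label gets index 0)
- A `DecisionTreeClassifier` that learns to predict `indexedLabel` from the `features` column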
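
With its default settings, `StringIndexer` assigns index 0 to the most frequent label, so `indexedLabel` (and the model's `prediction`) is generally not the digit itself. The plain-Python sketch below mimics that ordering on a made-up label list. `frequency_desc_index` is a hypothetical helper, and breaking ties by label value is an assumption of this sketch.

```python
from collections import Counter


def frequency_desc_index(labels):
    """Map each distinct label to an index, most frequent label first."""
    counts = Counter(labels)
    # Sort by descending count; ties are broken by label value (an assumption).
    ordered = sorted(counts, key=lambda lab: (-counts[lab], lab))
    return {lab: float(i) for i, lab in enumerate(ordered)}


# Made-up labels: 1.0 appears three times, 7.0 twice, 0.0 once.
mapping = frequency_desc_index([1.0, 7.0, 1.0, 0.0, 7.0, 1.0])
print(mapping)  # {1.0: 0.0, 7.0: 1.0, 0.0: 2.0}
```

The fitted indexer's `labels` attribute holds the original values in index order, which is what you would use to translate predictions back to digits.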

In [6]:
# Set up the pipeline components
# StringIndexer: map each distinct value in column "label" (the digits) to a categorical index in "indexedLabel"
indexer = StringIndexer(inputCol="label", outputCol="indexedLabel")

# DecisionTreeClassifier: Learn to predict column "indexedLabel" using the "features" column
dtc = DecisionTreeClassifier(labelCol="indexedLabel")

# Chain indexer + dtc together into a single ML Pipeline
pipeline = Pipeline(stages=[indexer, dtc])

## Hyperparameter Tuning

To find the best hyperparameters within a fixed grid, we use cross-validation:
- The parameter grid covers the maximum tree depth (`maxDepth`, 0 to 7) and the maximum number of bins used to discretize continuous features (`maxBins`: 2, 4, 8, 16, 32)
- 3-fold cross-validation evaluates each parameter combination
- The combination with the highest weighted precision, averaged over the folds, is selected and refit on the full training set
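
To get a feel for the cost of this search, the plain-Python sketch below enumerates the grid and counts model fits. It does not call Spark. The variable names are illustrative, and the count assumes `CrossValidator`'s usual behaviour of fitting every combination on every fold and then refitting the best one on the full training set.

```python
from itertools import product

max_depths = list(range(0, 8))  # 0, 1, ..., 7
max_bins = [2, 4, 8, 16, 32]
num_folds = 3

# Every (maxDepth, maxBins) pair is one candidate configuration.
combinations = list(product(max_depths, max_bins))

# Each candidate is fit once per fold; the winner is then refit on all training data.
fits_during_cv = len(combinations) * num_folds
total_fits = fits_during_cv + 1

print(len(combinations), fits_during_cv, total_fits)  # 40 120 121
```

A depth of 0 corresponds to a tree with a single leaf, so those grid entries act as a trivial baseline.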

In [7]:
# Define an evaluation metric
evaluator = MulticlassClassificationEvaluator(
    labelCol="indexedLabel",
    predictionCol="prediction",
    metricName="weightedPrecision"
)
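
As a reference for what this metric measures, weighted precision is the per-class precision averaged with weights proportional to how often each class occurs among the true labels. The sketch below computes it by hand on a tiny made-up example. `weighted_precision` is a hypothetical helper written for illustration, not Spark's implementation.

```python
from collections import Counter


def weighted_precision(y_true, y_pred):
    """Average per-class precision, weighted by each class's share of true labels."""
    support = Counter(y_true)      # true-label counts per class
    predicted = Counter(y_pred)    # predicted counts per class
    correct = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    total = len(y_true)
    score = 0.0
    for cls, n_true in support.items():
        # This sketch takes precision as 0 for a class that is never predicted.
        precision = correct[cls] / predicted[cls] if predicted[cls] else 0.0
        score += (n_true / total) * precision
    return score


# Made-up labels: per-class precisions are 1, 2/3 and 1/2 with supports 2, 3 and 1.
y_true = [0, 0, 1, 1, 1, 2]
y_pred = [0, 1, 1, 1, 2, 2]
print(weighted_precision(y_true, y_pred))  # approximately 0.75
```

Because the weights follow class frequency, the score is dominated by the most common digits; for a roughly balanced dataset it stays close to the unweighted mean of the per-class precisions.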

In [8]:
# Build parameter grid for CrossValidator
paramGrid = ParamGridBuilder() \
    .addGrid(dtc.maxDepth, range(0, 8)) \
    .addGrid(dtc.maxBins, [2, 4, 8, 16, 32]) \
    .build()

In [9]:
# Create the CrossValidator
cv = CrossValidator(
    estimator=pipeline,
    estimatorParamMaps=paramGrid,
    evaluator=evaluator,
    numFolds=3
)
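
Conceptually, k-fold cross-validation holds out each fold once for validation and trains on the remaining folds. The sketch below illustrates this on row indices in plain Python. `k_fold_splits` is a hypothetical helper, and the seeded shuffle with round-robin assignment is a simplification, not a description of how Spark partitions rows internally.

```python
import random


def k_fold_splits(n_rows, num_folds, seed=0):
    """Yield (train_indices, validation_indices) for each of num_folds folds."""
    indices = list(range(n_rows))
    random.Random(seed).shuffle(indices)
    # Deal the shuffled indices round-robin into folds of near-equal size.
    folds = [indices[i::num_folds] for i in range(num_folds)]
    for held_out in range(num_folds):
        validation = folds[held_out]
        train = [i for f, fold in enumerate(folds) if f != held_out for i in fold]
        yield train, validation


for train, validation in k_fold_splits(n_rows=9, num_folds=3):
    print(len(train), len(validation))  # 6 3 on each of the three iterations
```

Every row is used for validation exactly once, so the averaged metric reflects performance on data the corresponding model did not see during fitting.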

In [10]:
# Run cross-validation and select the best model
print("Training models with cross-validation...")
cvModel = cv.fit(training)

Training models with cross-validation...

## Model Evaluation

After training:
1. Extract the best model and its parameters
2. Evaluate its performance on both the training and test sets
3. Report the best configuration and its weighted precision on each set
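
Beyond the single best model, it can be useful to see how every combination scored. `CrossValidatorModel` exposes the mean metric per parameter map as `avgMetrics`, in the same order as the parameter grid. The helper below (a hypothetical name, `rank_param_maps`) pairs and sorts the two. It is demonstrated on placeholder dictionaries and made-up scores, not on results from this run.

```python
def rank_param_maps(param_maps, avg_metrics, larger_is_better=True):
    """Return (param_map, metric) pairs sorted from best to worst."""
    pairs = list(zip(param_maps, avg_metrics))
    return sorted(pairs, key=lambda pair: pair[1], reverse=larger_is_better)


# Placeholder inputs for illustration only; these are not results from the notebook.
demo_maps = [{"maxDepth": 1, "maxBins": 2}, {"maxDepth": 3, "maxBins": 8}]
demo_scores = [0.2, 0.6]
for params, score in rank_param_maps(demo_maps, demo_scores):
    print(params, score)
```

With the objects in this notebook, the call would be `rank_param_maps(paramGrid, cvModel.avgMetrics)`. Each Spark parameter map is a dictionary keyed by `Param` objects, so `{p.name: v for p, v in params.items()}` gives a readable form.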

In [11]:
# Get the best model: a fitted PipelineModel, refit on the full training set
bestPipelineModel = cvModel.bestModel

# Extract the decision tree model (the last stage of the pipeline)
bestTreeModel = bestPipelineModel.stages[-1]

display(bestTreeModel)

# Print the best model parameters
print("Best model parameters:")
print(f"maxDepth: {bestTreeModel.getMaxDepth()}")
print(f"maxBins: {bestTreeModel.getMaxBins()}")

DecisionTreeClassificationModel: uid=DecisionTreeClassifier_4d05f4924530, depth=7, numNodes=245, numClasses=10, numFeatures=780

Best model parameters:
maxDepth: 7
maxBins: 8


In [12]:
# Make predictions on training data using the best model
predictions_train = bestPipelineModel.transform(training)

# Evaluate the model on the training dataset
weighted_precision_train = evaluator.evaluate(predictions_train)
print(f"Training weighted precision: {weighted_precision_train}")

Training weighted precision: 0.7915121381227309


In [13]:
# Make predictions on test data using the best model
predictions = bestPipelineModel.transform(test)

# Evaluate the model on the test dataset
weighted_precision_test = evaluator.evaluate(predictions)
print(f"Test weighted precision: {weighted_precision_test}")

Test weighted precision: 0.7946790596293032
