# Case Study: Visualizing MNIST with PCA

In this case study, we will once again take a look at the [MNIST dataset](http://yann.lecun.com/exdb/mnist/).  We will use principal component analysis (PCA) to visualize this dataset. Let's get started.

## Setup

In [1]:
import mnist
import altair as alt
import pandas as pd
import numpy as np

alt.data_transformers.disable_max_rows()

training_set = mnist.train_images()
training_labels = mnist.train_labels()

# Flatten each 28x28 image into a 784-element row (row-major order).
M = training_set.reshape((60000, 28*28), order="C").astype(float)

As before, we load the training set from the `mnist` package. It is a NumPy array of shape 60000x28x28. In preparation for PCA, we flatten each 28x28 image into a 784-element vector and store it as a row of the matrix `M`. Thus, `M` is a 60000x784 matrix in which each row is a flattened image.
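To make the flattening step concrete, here is a minimal sketch on a tiny stand-in array (the `images` array and its 3x4x4 shape are made up for illustration, not taken from MNIST). It shows where a given pixel lands after a row-major reshape, and that the reshape can be undone.

```python
import numpy as np

# Hypothetical stand-in for the MNIST array: 3 "images" of 4x4 pixels.
images = np.arange(3 * 4 * 4).reshape(3, 4, 4)

# Row-major ("C") flattening: each image becomes one row of length 16.
flat = images.reshape((images.shape[0], -1), order="C")

# Pixel (row, col) of image i ends up in column row * width + col.
i, row, col = 2, 1, 3
width = images.shape[2]
same_pixel = flat[i, row * width + col] == images[i, row, col]

# Reshaping back recovers the original images exactly.
restored = flat.reshape(images.shape)
```

The same index arithmetic applies to `M`: pixel (row, col) of an image sits in column `row * 28 + col`.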

## PCA computation

We will use the [PCA implementation](https://scikit-learn.org/stable/modules/decomposition.html#pca) from scikit-learn, a popular machine learning package for Python.

In [2]:
import sklearn
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

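# Standardize each pixel column to zero mean and unit variance.
M_scaled = StandardScaler().fit_transform(M)
# Keep only the first two principal components.
pca = PCA(n_components=2)
# Fit PCA and project every image onto those two components.
r = pca.fit_transform(M_scaled)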

The first thing we do is [standardize](https://scikit-learn.org/stable/modules/preprocessing.html#standardization-or-mean-removal-and-variance-scaling) the data in each dimension so that it has mean 0 and standard deviation 1. This is a common preprocessing step, and it is what the `StandardScaler` object does. In our setting, each dimension is a pixel, and the values of a given pixel across all images form a column of the matrix `M`. After standardization, each column of the scaled matrix, `M_scaled`, has mean 0 and standard deviation 1. The one exception is a pixel that takes the same value in every image: it has zero variance and cannot be rescaled, so `StandardScaler` leaves that column at 0.
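The standardization described above can be sketched by hand with NumPy. This is an illustration on made-up toy data, not scikit-learn's actual code; the array `X` and the name `safe_std` are hypothetical, and the zero-variance handling is an assumption meant to mimic the default behavior of `StandardScaler`.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical toy data: 100 samples, 3 varying columns plus 1 constant column.
X = np.hstack([rng.normal(5.0, 2.0, size=(100, 3)), np.zeros((100, 1))])

mean = X.mean(axis=0)
std = X.std(axis=0)  # population standard deviation (ddof=0)

# Constant columns have std 0; divide them by 1 instead to avoid NaNs.
safe_std = np.where(std == 0, 1.0, std)
X_scaled = (X - mean) / safe_std
```

After this step the three varying columns have mean 0 and standard deviation 1, while the constant column is all zeros.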

Once the data is standardized, we are ready for PCA. PCA computes a set of so-called principal components: mutually orthogonal directions, ordered so that the first is the direction along which the dataset varies the most, the second is the direction of greatest remaining variation, and so on. Because we want to visualize the dataset in 2D, we create a PCA object that computes the first two principal components of the standardized data. The data is then projected onto the two-dimensional space spanned by these components. This projection is stored in a 60000x2 matrix named `r`.
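For intuition, the computation can be sketched with a singular value decomposition in NumPy. This is a minimal illustration on made-up toy data, not scikit-learn's implementation; the function `pca_project` and the array `X` are hypothetical names, and the resulting components may differ in sign from what `PCA` returns.

```python
import numpy as np

def pca_project(X, n_components):
    """Project the rows of X onto its top principal components (sketch)."""
    X_centered = X - X.mean(axis=0)
    # Rows of Vt are orthonormal directions, ordered by decreasing singular value.
    U, S, Vt = np.linalg.svd(X_centered, full_matrices=False)
    components = Vt[:n_components]
    projection = X_centered @ components.T
    # Fraction of the total variance captured by each kept component.
    explained_ratio = (S[:n_components] ** 2) / np.sum(S ** 2)
    return components, projection, explained_ratio

rng = np.random.default_rng(0)
# Hypothetical toy data: 200 points in 5 dimensions, stretched along the first axis.
X = rng.normal(size=(200, 5)) * np.array([10.0, 3.0, 1.0, 1.0, 1.0])
components, projection, explained_ratio = pca_project(X, 2)
```

Here `components` plays the role of `pca.components_` and `projection` the role of `r`. The fitted scikit-learn object exposes the analogous variance fractions as `pca.explained_variance_ratio_`, which indicates how much of the dataset's variation a 2D view retains.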

## Visualization with PCA

With the projection computed, we can now visualize the data set.

In [3]:
df_pca = pd.DataFrame({
    "x":r[:,0], 
    "y":r[:,1], 
    "label":training_labels,
})
alt.Chart(df_pca).mark_point().encode(x="x:Q", y="y:Q", color="label:N")

In the above visualization, the x axis represents the first principal component, and the y axis represents the second principal component. The projection of each data point (i.e. an image) is drawn as a dot, and the color of each dot encodes its label. As shown above, while PCA provides a low-dimensional view of the data, data points with different labels are not well separated in this embedding.

We can see a cleaner picture with a faceted view based on the label.

In [4]:
alt.Chart(df_pca).mark_point().encode(x="x:Q", y="y:Q", color="label:N")\
    .properties(width=300, height=300)\
    .facet("label:N", columns=2)

As shown above, data points with the same label do form a cluster. However, clusters of different labels can overlap significantly with one another. As a result, the first two principal components alone are not very effective for telling MNIST digits apart.

An interesting fact about PCA is that each principal component has the same dimension as a data point. In our case, since each data point is an image, we can interpret the principal components as images as well. This gives us a way of visualizing the principal components and gaining some understanding of what signal they represent.

In [5]:
df_pca_axis = pd.DataFrame({
    "row": np.repeat(np.arange(28), 28),
    "col": np.tile(np.arange(28), 28),
    "axis_1": pca.components_[0],
    "axis_2": pca.components_[1]
})
alt.Chart(df_pca_axis)\
    .mark_rect()\
    .encode(x="col:O", y="row:O", color="axis_1")\
    .properties(width=300, height=300) |\
alt.Chart(df_pca_axis)\
    .mark_rect()\
    .encode(x="col:O", y="row:O", color="axis_2")\
    .properties(width=300, height=300)

As shown above, the first principal component checks for a circular feature in the image. The second principal component checks for strong signals around the diagonals of the image. One can compute more principal components and visualize their signal in a similar manner.
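The `row` and `col` columns built in the cell above rely on the same row-major layout used when flattening the images. The following sketch checks that correspondence on a stand-in vector (the `component` array here is hypothetical filler, not an actual row of `pca.components_`).

```python
import numpy as np

side = 28
# Hypothetical stand-in for one row of pca.components_: a vector of length 784.
component = np.arange(side * side, dtype=float)

row = np.repeat(np.arange(side), side)  # 0, 0, ..., 0, 1, 1, ...
col = np.tile(np.arange(side), side)    # 0, 1, ..., 27, 0, 1, ...

# Reshaping in row-major order gives the same layout as the (row, col) table.
image = component.reshape(side, side)
matches = np.array_equal(image[row, col], component)
```

Because the layouts agree, `pca.components_[k].reshape(28, 28)` should give the same picture as the heatmap for component `k`, which is convenient when inspecting further components.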

## Summary

In this case study, we have visualized the MNIST dataset using PCA. We computed the first two principal components of the dataset and visualized the projection of the dataset onto these two components. PCA gives us a baseline visualization for bringing high-dimensional data into a lower dimension. The principal components can also be visualized as images to gain insight into the signal they capture.