# NMF Implementation - Pattern Detection in Digits
Instead of treating the MNIST data as a supervised learning problem and training a model to predict the digit in an image, I am going to use non-negative matrix factorisation (NMF) to extract patterns from the digit images. NMF is a method of dimensionality reduction: we want to turn the 784 pixel variables of each 28 x 28 image into a smaller set of features which, together with the model's components, retain information about the original images and can be used to 'rebuild' the samples with a matrix dot product.

In [None]:
# Imports 
import numpy as np 
import pandas as pd 

import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
import matplotlib.patches as mpatches

from sklearn.preprocessing import normalize
from sklearn.decomposition import NMF


# Helper functions for displaying the digits
def show_as_image(sample, shape):
    # shape = tuple e.g. (13,8)
    bitmap = sample.reshape(shape)
    plt.figure()
    plt.imshow(bitmap, cmap='gray', interpolation='nearest')
    plt.colorbar()
    plt.show()
    
# Useful to be able to pass an axes, as subplots enable view of many images
def plot_on_axes(sample, shape, ax):
    bitmap = sample.reshape(shape)
    ax.imshow(bitmap, cmap='gray', interpolation='nearest')

As we don't require test and train sets, I'm going to combine the train and test files into one larger structure that we can work with. 

In [None]:
train = pd.read_csv('/kaggle/input/mnist-in-csv/mnist_train.csv')
test = pd.read_csv('/kaggle/input/mnist-in-csv/mnist_test.csv')
data = pd.concat([train, test])
data.head()

In [None]:
# Create feature (X) and label (y) numpy arrays
X = np.array(data.loc[:, '1x1':])
y = np.array(data.loc[:, 'label'])

## Examples of digits
Let's take a look at some of the samples for each digit. 

In [None]:
n_examples = 10
fig, axes = plt.subplots(10, n_examples, figsize=(n_examples, 10))

for digit in np.arange(10):
    digit_indexes = np.where(y == digit)[0][:n_examples]
    
    for i, index in enumerate(digit_indexes): 
        plot_on_axes(X[index], (28, 28), axes[digit][i])
        axes[digit][i].get_xaxis().set_visible(False)
        axes[digit][i].get_yaxis().set_visible(False)
        
plt.show()

NMF creates both a set of components (their number, n, is set when the model is created) and a set of features for the input data. The features act as new variables; this is where the dimensionality reduction comes in, as the information stored in the 784 pixel columns is compressed into `n_components` columns.

Taking the dot product of the features and the components 'rebuilds' an approximation of the original samples. That is how NMF works as a dimensionality reduction technique. It is similar to PCA, but the returned components can be viewed as images and tend to pick out common patterns found across the input data.
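
To make the factorisation idea concrete, here is a minimal NumPy sketch of NMF using multiplicative updates on synthetic data. This is an illustration only: it is not the solver that scikit-learn uses by default, and the function name `nmf_multiplicative`, the toy data and the iteration count are all assumptions made for the example.

```python
import numpy as np

def nmf_multiplicative(X, n_components, n_iter=200, seed=0, eps=1e-10):
    """Toy NMF via multiplicative updates (illustrative, not scikit-learn's solver)."""
    rng = np.random.default_rng(seed)
    n_samples, n_pixels = X.shape
    W = rng.random((n_samples, n_components))   # features: one row per sample
    H = rng.random((n_components, n_pixels))    # components: one row per pattern
    errors = []
    for _ in range(n_iter):
        # Each update multiplies by a non-negative ratio, so W and H stay >= 0
        H *= (W.T @ X) / (W.T @ W @ H + eps)
        W *= (X @ H.T) / (W @ H @ H.T + eps)
        errors.append(np.linalg.norm(X - W @ H))
    return W, H, errors

# Synthetic non-negative "images": 50 samples of 16 pixels built from 3 hidden patterns
rng = np.random.default_rng(42)
X_toy = rng.random((50, 3)) @ rng.random((3, 16))

W, H, errors = nmf_multiplicative(X_toy, n_components=3)
print(W.shape, H.shape)        # (50, 3) (3, 16)
print(errors[0], errors[-1])   # error after the first and after the last iteration
```

Here `W` plays the role of the features and `H` the role of the components, so `W @ H` is the rebuilt data.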

In [None]:
model = NMF(n_components=15, init='nndsvd', random_state=2021, max_iter=1000)
model.fit(X)
nmf_features = model.transform(X)

Note that I am using `init='nndsvd'` here because the [documentation](https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.NMF.html) suggests that it is better for sparseness, and our samples are sparse (most pixels are zero).
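
One rough way to check how sparse a set of images is would be to measure the share of pixels that are exactly zero. The helper `zero_fraction` below is hypothetical and the two tiny 'images' are made up for illustration.

```python
import numpy as np

def zero_fraction(matrix):
    """Share of entries that are exactly zero (hypothetical helper)."""
    return float(np.mean(matrix == 0))

# Toy batch: two 4x4 images flattened to 16 pixels, mostly background
toy = np.zeros((2, 16))
toy[0, [5, 6, 9, 10]] = 255   # a small square of ink
toy[1, [1, 5, 9, 13]] = 255   # a vertical stroke

print(zero_fraction(toy))     # 0.75
```

In this notebook the same helper could be called as `zero_fraction(X)` once `X` has been created.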

In [None]:
def show_components(model, display_on_ax_func, images_per_row, figsize_tuple):
    n_components = model.n_components_
    # Ceiling division: add an extra row only when the last row is partly filled
    n_rows = -(-n_components // images_per_row)
    fig, axes = plt.subplots(n_rows, images_per_row, figsize=figsize_tuple)
    
    for i, ax in enumerate(axes.flatten()):
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        
        if i < n_components:
            display_on_ax_func(model.components_[i], (28, 28), ax)
        else:
            ax.axis('off')
            
    plt.show()

In [None]:
show_components(model, plot_on_axes, 5, (10, 6))

While these components aren't the clear clustering of digits themselves (as you would expect in a more predictive approach), we can see that they are clearly parts of the digits and represent the common patterns that are seen throughout the data. With 15 components we are getting a mixture of small dashes and sweeping, more circular strokes such as those found in zeros and eights.  

We might consider increasing the number of components even further to try and break up the components into smaller patterns and smaller strokes.

In [None]:
nmf_features.shape

In [None]:
model.components_.shape

In [None]:
nmf_features[0]

The components of the model represent the patterns that are seen in the images. Each component holds a weighting for each of the original variables (pixels in this case); that is why we are able to visualise them in the same shape as the images, for example with the function `show_as_image` that is defined at the top of this notebook.   

The features are the new variables that contain a weighting for the corresponding component. We can, therefore, take an example and find out which components are the most influential. 
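
As a small sketch of what 'most influential' means for a single sample, the components can be ranked by their feature values with `np.argsort`. The feature vector below is made up for illustration, and the shares are only a rough guide because components can differ in scale.

```python
import numpy as np

# Made-up feature vector for one sample across 6 components (illustrative values)
sample_features = np.array([0.0, 1.8, 0.3, 2.5, 0.0, 0.9])

# Sort component indexes from most to least influential
ranked = np.argsort(sample_features)[::-1]
top_three = ranked[:3]

# Share of the sample's total weight carried by each component
shares = sample_features / sample_features.sum()

print(top_three.tolist())              # [3, 1, 5]
print(np.round(shares[top_three], 3))
```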

In [None]:
example_iloc = 0
plt.figure(figsize=(10, 3))
plt.bar(x=np.arange(len(nmf_features[example_iloc])) + 1, height=nmf_features[example_iloc])
plt.xticks(np.arange(len(nmf_features[example_iloc])) + 1)
plt.xlabel('NMF Component')
plt.title('Feature importance for given sample')
plt.show()

We can isolate and check the general importance at a digit-level, rather than a single example. 
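
The same idea extends to every class at once. The sketch below averages the features per label on random stand-in data; `features` and `labels` are placeholders for `nmf_features` and `y`, not the real MNIST values.

```python
import numpy as np

rng = np.random.default_rng(0)
features = rng.random((12, 4))                            # stand-in for nmf_features
labels = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2])   # stand-in for y

classes = np.unique(labels)
# One row per class, one column per component
class_means = np.vstack([features[labels == c].mean(axis=0) for c in classes])

print(class_means.shape)            # (3, 4)
print(class_means.argmax(axis=1))   # strongest component (0-based) for each class
```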

In [None]:
five_features = nmf_features[np.where(y == 5)[0]]
seven_features = nmf_features[np.where(y == 7)[0]]

fig, axes = plt.subplots(2, 1, figsize=(10, 10))
axes[0].boxplot(five_features)
axes[0].set_title('Features for "5"s')
axes[1].boxplot(seven_features)
axes[1].set_title('Features for "7"s')
plt.show()

Here we can see the general differences in which patterns are most important for each digit. Take the first component, for example: it plays a larger part in "5"s than in "7"s, which makes a lot of sense when you look at the shape of that component: 

In [None]:
show_as_image(model.components_[0], (28, 28))

Contrast this with the 14th component, which is more influential in "7"s than in "5"s, where it scores much lower. 

In [None]:
show_as_image(model.components_[13], (28, 28))

In [None]:
fig, axes = plt.subplots(2, 1, figsize=(10, 10))

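# Number the components from 1 so the x-axis matches the other plots
component_numbers = np.arange(five_features.shape[1]) + 1
axes[0].bar(x=component_numbers, height=np.mean(five_features, axis=0))
axes[1].bar(x=component_numbers, height=np.mean(seven_features, axis=0))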
axes[0].set_title('Average importance of features for "5"s')
axes[1].set_title('Average importance of features for "7"s')
plt.show()

Let's extract the components that, on average, have the highest influence for the digit 5. We can then plot them and get a feel for whether they look like patterns that commonly appear when writing fives. 

In [None]:
top_four_features_five = np.mean(five_features, axis=0).argsort()[-4:][::-1]
top_four_features_seven = np.mean(seven_features, axis=0).argsort()[-4:][::-1]

# Top 4 components, on average, for the digit 5
print(top_four_features_five)
np.mean(five_features, axis=0)[top_four_features_five]

In [None]:
fig, axs = plt.subplots(1, 4, figsize=(12, 3.5))
fig.suptitle('Key components in "5"s')
for i, feature in enumerate(top_four_features_five):
    plot_on_axes(model.components_[feature], (28, 28), axs.flatten()[i])
    axs.flatten()[i].set_title(f'Component {feature + 1}: {round(np.mean(five_features, axis=0)[feature], 2)}')
    axs.flatten()[i].get_xaxis().set_visible(False)
    axs.flatten()[i].get_yaxis().set_visible(False)
plt.show() 

In [None]:
fig, axs = plt.subplots(1, 4, figsize=(12, 3.5))
fig.suptitle('Key components in "7"s')
for i, feature in enumerate(top_four_features_seven):
    plot_on_axes(model.components_[feature], (28, 28), axs.flatten()[i])
    axs.flatten()[i].set_title(f'Component {feature + 1}: {round(np.mean(seven_features, axis=0)[feature], 2)}')
    axs.flatten()[i].get_xaxis().set_visible(False)
    axs.flatten()[i].get_yaxis().set_visible(False)
plt.show() 

## Reconstructing samples using components and features 
Fundamentally NMF is performing dimensionality reduction, trying to store the information held in 784 variables in just 15 instead (defined via `n_components`). We have already looked at the components that have been created, but how well do these represent the original data? We can reconstruct the sample images from the features and components of the NMF model and compare them to see!  

We can reconstruct the image from the features and components by multiplying the matrices together using `np.dot`. Their shapes are compatible, which shows how NMF finds the patterns (or components) and then creates a feature matrix that combines with those components to rebuild the original sample. This can also be done using a method on the model itself, `inverse_transform()`; see the [docs](https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.NMF.html#sklearn.decomposition.NMF.inverse_transform) for more information.
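
The shape compatibility can be seen on a small example. The arrays below are random stand-ins for `nmf_features` and `model.components_`, with 16 'pixels' instead of 784, so the numbers themselves mean nothing.

```python
import numpy as np

rng = np.random.default_rng(1)
features = rng.random((5, 3))      # (n_samples, n_components), like nmf_features
components = rng.random((3, 16))   # (n_components, n_pixels), like model.components_

# Rebuild every sample at once: (5, 3) @ (3, 16) -> (5, 16)
rebuilt_all = features @ components

# Rebuild a single sample: (3,) dot (3, 16) -> (16,)
rebuilt_one = np.dot(features[0], components)

# The same rebuild written as a weighted sum of the component images
weighted_sum = sum(w * c for w, c in zip(features[0], components))

print(rebuilt_all.shape, rebuilt_one.shape)       # (5, 16) (16,)
print(np.allclose(rebuilt_all[0], rebuilt_one))   # True
print(np.allclose(rebuilt_one, weighted_sum))     # True
```

The last check shows why the features can be read as weights: each rebuilt image is the components added together, each scaled by its feature value.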

In [None]:
sample_index = 0
sample_features = nmf_features[sample_index]
components = model.components_
manually_reconstructed_image = np.dot(sample_features, components)

fig, axes = plt.subplots(1, 3, figsize=(10, 4))
plot_on_axes(manually_reconstructed_image, (28, 28), axes[0])
axes[0].set_title(f'Manually reconstructed image')
axes[0].get_xaxis().set_visible(False)
axes[0].get_yaxis().set_visible(False)

plot_on_axes(X[sample_index], (28, 28), axes[1])
axes[1].set_title(f'Original sample')
axes[1].get_xaxis().set_visible(False)
axes[1].get_yaxis().set_visible(False)

plot_on_axes(model.inverse_transform(nmf_features[sample_index]), (28, 28), axes[2])
axes[2].set_title(f'Using `inverse_transform()`')
axes[2].get_xaxis().set_visible(False)
axes[2].get_yaxis().set_visible(False)

plt.show()

As you can see above, you can achieve the same reconstruction using the `.inverse_transform()` method on the `model` instance itself. Below you can see that the results are the same when we compare the resultant arrays.

In [None]:
# Check that every pixel of the manual reconstruction matches inverse_transform,
# allowing for floating-point rounding
np.allclose(model.inverse_transform(nmf_features[0]), manually_reconstructed_image)
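
When comparing two reconstructions it is usually safer to allow a small floating-point tolerance than to require exact equality. A tiny sketch with made-up values shows the difference:

```python
import numpy as np

a = np.array([0.1 + 0.2, 1.0])
b = np.array([0.3, 1.0])

print((a == b).tolist())   # [False, True]: 0.1 + 0.2 is not exactly 0.3
print(np.allclose(a, b))   # True: equal within the default tolerance
```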

The overall reconstruction isn't _that_ impressive here, but you begin to understand what NMF is achieving by creating two matrices with far fewer dimensions than the original data set.  

I'm going to score the differences between the reconstructed and original samples and find those examples that score poorly (the difference between the original and reconstructed is high) and those that reconstruct well (differences are minimal). 

In [None]:
# Rebuild all samples 
X_rebuilt = model.inverse_transform(nmf_features)

# Difference between each original image and its rebuild
X_diff = X - X_rebuilt

# Score each sample by the Euclidean norm of its difference (row-wise)
norm_diffs = pd.Series(np.linalg.norm(X_diff, axis=1))

# Indexes of the 20 worst and the 20 best reconstructions
most_different = norm_diffs.nlargest(20).index.to_numpy()
most_similar = norm_diffs.nsmallest(20).index.to_numpy()

In [None]:
sample_index = most_different[0]
sample_features = nmf_features[sample_index]
components = model.components_
manually_reconstructed_image = np.dot(sample_features, components)

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
fig.suptitle('Sample with poor reconstruction')
plot_on_axes(manually_reconstructed_image, (28, 28), axes[0])
axes[0].set_title(f'Reconstructed image')
axes[0].get_xaxis().set_visible(False)
axes[0].get_yaxis().set_visible(False)

plot_on_axes(X[sample_index], (28, 28), axes[1])
axes[1].set_title(f'Original sample')
axes[1].get_xaxis().set_visible(False)
axes[1].get_yaxis().set_visible(False)

plt.show()

In [None]:
y[most_different]

We can see that a range of different digits appears among the samples that the model rebuilds poorly with only 15 components available. 

In [None]:
sample_index = most_similar[2]
sample_features = nmf_features[sample_index]
components = model.components_
manually_reconstructed_image = np.dot(sample_features, components)

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
fig.suptitle('Sample with better reconstruction')
plot_on_axes(manually_reconstructed_image, (28, 28), axes[0])
axes[0].set_title(f'Manually reconstructed image')
axes[0].get_xaxis().set_visible(False)
axes[0].get_yaxis().set_visible(False)

plot_on_axes(X[sample_index], (28, 28), axes[1])
axes[1].set_title(f'Original sample')
axes[1].get_xaxis().set_visible(False)
axes[1].get_yaxis().set_visible(False)

plt.show()

In [None]:
y[most_similar]

In contrast, we can see above that the reconstruction works well for the digit 1, probably due to the simplicity of the original image. It would be interesting to see the performance of the model with additional components and some hyperparameter tuning to improve the resultant inversely transformed records. 
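
One possible way to test whether simple digits score well only because they contain less ink is to divide each error by the norm of the original image. This is a sketch of that idea, not something the analysis above does; `relative_errors` is a hypothetical helper and the two samples are made up.

```python
import numpy as np

def relative_errors(original, rebuilt, eps=1e-12):
    """Row-wise reconstruction error divided by the norm of each original row."""
    abs_err = np.linalg.norm(original - rebuilt, axis=1)
    scale = np.linalg.norm(original, axis=1)
    return abs_err / (scale + eps)   # eps guards against all-zero rows

# Two toy samples: a faint one and a bold one, each rebuilt 10% too dark
original = np.array([[1.0, 0.0, 0.0, 0.0],
                     [10.0, 10.0, 10.0, 10.0]])
rebuilt = original * 0.9

print(np.linalg.norm(original - rebuilt, axis=1))  # absolute: the bold sample looks far worse
print(relative_errors(original, rebuilt))          # relative: both are about 0.1
```

In this notebook the helper could be applied as `relative_errors(X, X_rebuilt)` and ranked in the same way as the absolute differences.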

## Combining top components to reconstruct classes
My idea is that we can use the top components for a given class to reconstruct a generic representation of that class: combining the most influential components should create an image that shows the major structure of the class.  

With components illustrating the features of those types of images, it would make sense that there is a way to combine them back together to get an 'average' digit of each type, represented by the most influential components that comprise that digit.  

Let's take 7s again as an example of this idea. 

In [None]:
seven_features = nmf_features[np.where(y == 7)[0]]
top_four_features_seven = np.mean(seven_features, axis=0).argsort()[-4:][::-1]

fig, axs = plt.subplots(1, 4, figsize=(12, 3.5))
fig.suptitle('Key components in "7"s')
for i, feature in enumerate(top_four_features_seven):
    plot_on_axes(model.components_[feature], (28, 28), axs.flatten()[i])
    axs.flatten()[i].set_title(f'Component {feature + 1}: {round(np.mean(seven_features, axis=0)[feature], 2)}')
    axs.flatten()[i].get_xaxis().set_visible(False)
    axs.flatten()[i].get_yaxis().set_visible(False)
plt.show() 

In [None]:
top_four_seven_components = model.components_[top_four_features_seven]

# Aggregate at pixel level to reconstruct the class
mean = np.mean(top_four_seven_components, axis=0)
median = np.median(top_four_seven_components, axis=0)
minimum = np.min(top_four_seven_components, axis=0)
maximum = np.max(top_four_seven_components, axis=0)

In [None]:
fig, axs = plt.subplots(1, 4, figsize=(12, 3.5))
fig.suptitle('Reconstruction methods using top 4 components')
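# Map each label to its aggregate array rather than looking names up with eval
aggregates = {'Mean': mean, 'Median': median, 'Minimum': minimum, 'Maximum': maximum}

for i, (metric, image_arr) in enumerate(aggregates.items()):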
    plot_on_axes(image_arr, (28, 28), axs.flatten()[i])
    axs.flatten()[i].set_title(f'{metric}')
    axs.flatten()[i].get_xaxis().set_visible(False)
    axs.flatten()[i].get_yaxis().set_visible(False)
plt.show()

While this doesn't necessarily provide actionable insight, it does give some intuition for the way in which NMF extracts interpretable components that represent patterns in the original samples. With 7s we can see that the information held in the top 4 components is enough to reconstruct a recognisable digit that generalises to the class that we know; this points to the effectiveness of the components, as we are able to reduce the dimensions and still build a general version of those samples. However, we did see that on an individual basis, with just 15 components, the reconstruction wasn't as effective. 
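
The aggregates above treat the chosen components equally. An alternative worth trying, sketched here on random stand-in data rather than the fitted model, is to weight each component by its average feature value for the class. The array names are placeholders for `model.components_` and the features of one digit.

```python
import numpy as np

rng = np.random.default_rng(7)
components = rng.random((5, 16))       # stand-in for model.components_
class_features = rng.random((20, 5))   # stand-in for the features of one digit

mean_weights = class_features.mean(axis=0)
top = mean_weights.argsort()[-3:][::-1]   # three most influential components

# Unweighted mean treats the chosen components equally
unweighted = components[top].mean(axis=0)
# Weighted sum scales each chosen component by its average feature value
weighted = mean_weights[top] @ components[top]

# With all components included, the weighted sum equals the mean rebuilt image
all_weighted = mean_weights @ components
mean_rebuild = (class_features @ components).mean(axis=0)

print(unweighted.shape, weighted.shape)        # (16,) (16,)
print(np.allclose(all_weighted, mean_rebuild)) # True
```

The final check follows from the dot product being linear, so the weighted version with every component is simply the average reconstruction for the class.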

# More components, better interpretability? 
Right, time to increase the number of components that the model creates, and see whether that makes a difference to how easily we can interpret the components visually. 

In [None]:
model2 = NMF(n_components=45, init='nndsvd', random_state=2021, max_iter=2000)
model2.fit(X)
nmf_features2 = model2.transform(X)
show_components(model2, plot_on_axes, 5, (10, 18))

The components are much 'smaller' now and represent certain strokes and dots in places across the whole image. Combining these in a similar way to before should yield more accurate reconstructions because the inverse transform is creating an image from smaller pieces. 
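
One way to put a number on this expectation is the `reconstruction_err_` attribute that a fitted scikit-learn `NMF` exposes. The sketch below fits two small models on synthetic data (not MNIST); the data, the component counts and the name `toy_model` are assumptions for illustration, and the error usually, though not necessarily, falls as components are added.

```python
import numpy as np
from sklearn.decomposition import NMF

rng = np.random.default_rng(3)
# Synthetic non-negative data built from 8 hidden patterns (not MNIST)
X_toy = rng.random((60, 8)) @ rng.random((8, 30))

errors = {}
for k in (2, 8):
    toy_model = NMF(n_components=k, init='nndsvd', random_state=0, max_iter=1000)
    toy_model.fit(X_toy)
    # Error between the training data and its reconstruction for this model
    errors[k] = toy_model.reconstruction_err_

print(errors)
```

In this notebook the equivalent comparison would be between `model.reconstruction_err_` and `model2.reconstruction_err_`.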

In [None]:
sample_index = 0

fig, axes = plt.subplots(1, 2, figsize=(10, 10))

plot_on_axes(X[sample_index], (28, 28), axes[0])
axes[0].set_title(f'Original sample')
axes[0].get_xaxis().set_visible(False)
axes[0].get_yaxis().set_visible(False)

plot_on_axes(model2.inverse_transform(nmf_features2[sample_index]), (28, 28), axes[1])
axes[1].set_title(f'Reconstructed image')
axes[1].get_xaxis().set_visible(False)
axes[1].get_yaxis().set_visible(False)

plt.show()

Let's have a look at the top components for the digit 3 now and see whether the 'smaller' components have made a difference. 

In [None]:
three_features = nmf_features2[np.where(y == 3)[0]]
features_display = np.mean(three_features, axis=0).argsort()[-10:][::-1]

fig, axs = plt.subplots(2, 5, figsize=(12, 5.5))
fig.suptitle('Key components in "3"s')
for i, feature in enumerate(features_display):
    plot_on_axes(model2.components_[feature], (28, 28), axs.flatten()[i])
    axs.flatten()[i].set_title(f'Component {feature + 1}: {round(np.mean(three_features, axis=0)[feature], 2)}')
    axs.flatten()[i].get_xaxis().set_visible(False)
    axs.flatten()[i].get_yaxis().set_visible(False)
plt.show() 

As the patterns are now smaller strokes, it is more difficult to see how they might come together to create a digit. 

Let's take some components and plot them on top of one another to illustrate the way they represent different parts of the digit. When applied to images this is intuitive, as the image 'completes' to represent a general digit class. With numeric data it is the same idea, but it is perhaps more difficult to visualise: the components represent patterns that are seen in the data, and by combining them you can broadly group samples that are similar to one another. 

In [None]:
colors = ["#8e44ad", "#2ecc71", "#3498db", "#f39c12", "#2c3e50", "#e74c3c", "#1abc9c", "#f1c40f", "#e84393", "#63cdda"]
plt.figure(figsize=(6, 6))
plt.axis('off')

handles = []

for i, feature in enumerate(features_display[:len(colors)]):
    # Create custom colourmap from white to color[N]
    cmap1 = LinearSegmentedColormap.from_list("mycmap", ["white", colors[i]])
    sample = model2.components_[feature]
    bitmap = sample.reshape((28, 28))
    patch = mpatches.Patch(color=colors[i], label=f'Component {feature + 1}')
    handles.append(patch)
    
    
    # Mask those nearer white so that layering them works 
    bitmap = np.ma.masked_where(bitmap < 0.01, bitmap)
    
    # Plot the sample in image space using custom cmap1
    plt.imshow(bitmap, cmap=cmap1, interpolation='nearest')
    
plt.legend(handles=handles, loc='center', bbox_to_anchor=(1, 0.5, 0.25, 0))
plt.title(f'Coloured representation of {len(colors)} components')
plt.show()
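
The masking step used for the layering can be seen in isolation on a tiny array. The values below are made up; the point is that pixels under the threshold become masked, and matplotlib draws masked pixels as transparent by default, so earlier layers stay visible.

```python
import numpy as np

component = np.array([[0.00, 0.02, 0.50],
                      [0.00, 0.80, 0.005],
                      [0.30, 0.00, 0.00]])

# Hide pixels below the threshold so they do not cover earlier layers
masked = np.ma.masked_where(component < 0.01, component)

print(int(masked.mask.sum()))   # 5 pixels hidden
print(int(masked.count()))      # 4 pixels still drawn
```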

I think we can see the rough outline of a '3' here, but it would be good to test this out on all digits in the future. 

- - - 

Thanks for taking a look over this pattern detection using NMF notebook. I will keep adding and improving this, and I want to work on some more visualisations to try and bring those components together in a more approachable way.  

Please leave a comment with any feedback, suggestions or corrections; I would love to hear what people think! If you've enjoyed this then please consider **upvoting** this notebook; it would be greatly appreciated. 
