<a href="https://colab.research.google.com/github/tsakailab/prml/blob/master/ipynb/MNIST_GaussianNB_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Render matplotlib figures inline, directly below the cell that creates them.
%matplotlib inline

Gaussian naive Bayes classification of the MNIST digits dataset
=============================================================


Fetch the MNIST digits dataset
------------------------------------
If loading fails, restart the runtime (or kernel) and run this cell again.

In [None]:
import os

import numpy as np
import pandas as pd

# MNIST test split as a header-less CSV: column 0 holds the digit label and the
# remaining 784 columns hold the 28x28 pixel values.
# The local path is the sample copy found in a Colab runtime; when it is absent
# (e.g. running outside Colab), read the same file from GitHub instead of failing.
MNIST_LOCAL = '/content/sample_data/mnist_test.csv'
MNIST_URL = 'https://github.com/tsakailab/prml/raw/master/datasets/mnist_test.csv'
mnist = pd.read_csv(MNIST_LOCAL if os.path.exists(MNIST_LOCAL) else MNIST_URL, header=None)

# Labels as int64 so they can later be used directly as indices into `lbl`.
y = mnist.iloc[:, 0].to_numpy().astype(np.int64)

# Keep an image-shaped copy for plotting ...
Ximages = mnist.drop(columns=0).to_numpy().reshape(-1, 28, 28)
print("(#images, height, width)", Ximages.shape)

# ... and a flattened (one row per image) copy as the feature matrix.
X = Ximages.reshape(Ximages.shape[0], -1)
print("X.shape = ", X.shape)

# Divide by the largest pixel value so that no feature exceeds 1.
X = X / X.max()

# Number of classes and the digit each class index stands for (identity for ten classes).
c = 10
lbl = range(c)

Plot the data: images of digits
-------------------------------


In [None]:
from matplotlib import pyplot as plt

# 6x6 inch canvas with the subplots packed almost edge to edge.
fig = plt.figure(figsize=(6, 6))
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)
print("64 out of %d images" % len(y))

# Draw 64 random sample indices (with replacement) and lay them out on an 8x8 grid.
p = np.random.randint(0, len(y), 64)
for i, k in enumerate(p):
    ax = fig.add_subplot(8, 8, i + 1, xticks=[], yticks=[])
    ax.imshow(Ximages[k], cmap=plt.cm.gray)
    # overlay the target value near the top-left corner of the image
    ax.text(0, 7, str(y[k]), color='white')

Choose two classes if you enjoy binary classification
-----------------------------------------------------------------
Skip this cell to keep all ten classes.

In [None]:
pos = 1 # choose from 0 to 9
neg = 0 # choose from 0 to 9

# Picking the same digit twice would leave a single class (every label ends up 0).
if pos == neg:
    raise ValueError("pos and neg must be different digits, got %d for both" % pos)

# Masks are taken from the labels as they are now, before anything is overwritten.
yp, yn = y == pos, y == neg

# Both digits must be present. This fails when pos/neg are not digits in the data,
# or when this cell is re-run after y was already remapped to {0, 1}; in the
# latter case re-run the data-loading cell first.
if not (yp.any() and yn.any()):
    raise ValueError("digits %d and %d were not both found in y; "
                     "re-run the data-loading cell before this one" % (pos, neg))

# Keep only the two chosen digits, then relabel them: pos -> 1, neg -> 0.
keep = np.logical_or(yp, yn)
X = X[keep, :]
y = y[keep]
yp, yn = yp[keep], yn[keep]
y[yp] = 1
y[yn] = 0

# Only switch to two-class mode once the data has actually been filtered.
c = 2
lbl = [neg, pos]  # class index -> original digit, used for display

Split the data into training and test sets
--------------------------------------------------

In [None]:
from sklearn.model_selection import train_test_split

# Fixed seed: without it the split, and every score and figure computed from it,
# changes on each run of the notebook.
SPLIT_SEED = 0
# Upper bound on the number of test samples that are kept.
MAX_TEST = 2000

# split the data into training and test sets (20% held out for testing)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=SPLIT_SEED)

# Slicing is a no-op when the test set is already within the limit.
X_test = X_test[:MAX_TEST, :]
y_test = y_test[:MAX_TEST]

print("(#training data, dim.)=", X_train.shape)
print("(#test data,)=", X_test.shape)

Run the training
---------------------------------



In [None]:
# Gaussian Naive Bayes
from sklearn.naive_bayes import GaussianNB

# fit() returns the estimator itself, so construction and training are chained.
clf = GaussianNB().fit(X_train, y_train)

# Report how many classes were seen during training and their prior probabilities.
for caption, value in (
    ("# of classes: ", len(clf.class_count_)),
    ("Prior probs: ", clf.class_prior_),
):
    print(caption, value)

In [None]:
# use the model to predict the labels of the test data
predicted = clf.predict(X_test)
expected = y_test

# Plot the prediction
fig = plt.figure(figsize=(6, 6))  # figure size in inches
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)

# Restore the image shape once, outside the loop, instead of on every iteration.
test_images = X_test.reshape(-1, Ximages.shape[1], Ximages.shape[2])

# Show up to 64 distinct test samples. Sampling without replacement raises if
# more samples are requested than exist, so cap the request at the test-set size.
n_show = min(64, len(y_test))
idx64 = np.random.choice(len(y_test), n_show, replace=False)
for j, i in enumerate(idx64):
    ax = fig.add_subplot(8, 8, j + 1, xticks=[], yticks=[])
    ax.imshow(test_images[i], cmap=plt.cm.gray)

    # Expected digit on the left in white; predicted digit on the right,
    # light green when it matches the expected one and red otherwise.
    # `lbl` maps a class index back to the digit it stands for.
    ax.text(0, 7, str(lbl[expected[i]]), color='white')
    is_correct = predicted[i] == expected[i]
    ax.text(21, 7, str(lbl[predicted[i]]), color='#a0ffa0' if is_correct else 'red')

# the number of correct matches / the total number of data points
matches = (predicted == expected)
score = matches.mean()
print("%d / %d = %2.1f %%" % (matches.sum(), len(matches), 100*score))

Quantify the performance in detail
------------------------
Print the classification report

In [None]:
from sklearn import metrics

# Class indices that occur in the expected or predicted labels, in sorted order;
# this is the set of rows the report prints when no explicit labels are given.
present = np.unique(np.concatenate([expected, predicted]))

# Name each row after the digit it stands for (via `lbl`) so that, in two-class
# mode, the report shows the chosen digits rather than the internal 0/1 indices.
print(metrics.classification_report(
    expected, predicted, target_names=[str(lbl[k]) for k in present]))

Print the confusion matrix



In [None]:
import seaborn as sns

# Class indices that occur in the expected or predicted labels (sorted); passing
# them explicitly keeps the matrix rows/columns aligned with the tick labels below.
present = np.unique(np.concatenate([expected, predicted]))
cm = metrics.confusion_matrix(expected, predicted, labels=present)

# Normalize each row by the number of samples of that actual class. A class that
# only appears among the predictions has an all-zero row; leave it at 0 rather
# than dividing by zero and filling the row with NaN.
row_totals = cm.sum(axis=1, keepdims=True)
cm_normalized = np.divide(cm, row_totals,
                          out=np.zeros(cm.shape, dtype=float),
                          where=row_totals > 0)

# Show the digit each class index stands for (matters in two-class mode).
tick_digits = [lbl[k] for k in present]

plt.figure(figsize=(9,9))
sns.heatmap(cm_normalized, annot=True, fmt=".3f", linewidths=.5, square = True, cmap = 'Blues_r',
            xticklabels=tick_digits, yticklabels=tick_digits);
plt.ylabel('Actual label');
plt.xlabel('Predicted label');
# `score` is the accuracy computed in the prediction cell.
all_sample_title = 'Accuracy Score: {:.3f}'.format(score)
plt.title(all_sample_title, size = 15);