Imports

In [None]:
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt

Helper functions

In [None]:
def plot_example(X, y_true, y_pred, n=5):
    """Show up to ``n`` images side by side, each titled ``"<true> -> <pred>"``.

    Intended for eyeballing misclassified digits.

    Parameters
    ----------
    X : array-like
        Images whose rows each hold 28 * 28 pixel values. Both the flat
        ``(k, 784)`` layout and the ``(k, 1, 28, 28)`` layout work, because
        every row is reshaped to ``(28, 28)``.
    y_true, y_pred : array-like
        True and predicted labels, aligned with the rows of ``X``.
    n : int, default 5
        Maximum number of images to draw. Fewer are drawn when ``X`` has
        fewer rows, and nothing is drawn when it is empty.
    """
    # reshape(-1, ...) instead of reshape(n, ...) so that fewer than n rows
    # (e.g. a model that made only a couple of mistakes) do not crash.
    images = np.asarray(X[:n]).reshape(-1, 28, 28)
    count = len(images)
    if count == 0:
        # Nothing to show (e.g. no misclassified samples).
        return

    # squeeze=False keeps `axes` 2-D even when count == 1.
    fig, axes = plt.subplots(1, count, squeeze=False)
    for ax, img, y_t, y_p in zip(axes[0], images, y_true[:count], y_pred[:count]):
        ax.imshow(img)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(f"{y_t} -> {y_p}")
    plt.show()

Loading data

In [None]:
# Fetch MNIST from OpenML as NumPy arrays (as_frame=False).
# cache=True keeps the downloaded copy in scikit-learn's data home
# (~/scikit_learn_data unless configured otherwise). Restart & Run All then
# reuses it instead of downloading the dataset again every time.
mnist = fetch_openml('mnist_784',
                     as_frame=False,
                     cache=True)

print(mnist.keys())
print(mnist.data.shape)

Preprocessing data

In [None]:
# Features: cast to float32 (PyTorch's default parameter dtype) and divide
# by 255, which assumes raw pixel intensities lie in [0, 255].
# Labels: cast to int64, the dtype nn.CrossEntropyLoss expects for class
# indices.
X, y = mnist.data, mnist.target
X, y = X.astype(np.float32), y.astype(np.int64)
X /= 255.

# Hold out 25% of the samples; the fixed seed keeps the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=87)

# Sanity checks: the split neither drops nor duplicates samples, and each
# feature matrix stays aligned with its label vector.
assert X_train.shape[0] + X_test.shape[0] == mnist.data.shape[0]
assert X_train.shape[0] == y_train.shape[0]
assert X_test.shape[0] == y_test.shape[0]

print(X_train.shape, y_train.shape)

Build a neural network with PyTorch

In [None]:
import torch
from torch import nn
from torch.nn import functional as F

# Use Apple's Metal (MPS) backend when PyTorch reports it as available;
# otherwise run on the CPU.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

# Layer sizes for the fully connected classifier:
#   input  -> number of features per sample (columns of X),
#   hidden -> one eighth of the input width,
#   output -> number of distinct labels in y.
mnist_dim = X.shape[1]
hidden_dim = mnist_dim // 8
output_dim = np.unique(y).size
print(mnist_dim, hidden_dim, output_dim)

In [None]:
class ClassifierModule(nn.Module):
    """Two-layer fully connected classifier: Linear -> ReLU -> Dropout -> Linear.

    ``forward`` returns raw logits (no softmax). ``nn.CrossEntropyLoss``
    applies log-softmax itself.

    Args:
        input_dim: number of input features per sample.
        hidden_dim: width of the hidden layer.
        output_dim: number of output logits (one per class).
        dropout: probability of zeroing a hidden activation during training.
    """

    def __init__(self, input_dim=mnist_dim,
                 hidden_dim=hidden_dim, output_dim=output_dim,
                 dropout=0.5):
        super().__init__()
        # Submodules are created in the same order as before, so seeded
        # weight initialisation is unchanged.
        self.dropout = nn.Dropout(dropout)
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.output = nn.Linear(hidden_dim, output_dim)

    def forward(self, X):
        hidden_act = self.dropout(F.relu(self.hidden(X)))
        return self.output(hidden_act)  # logits

In [None]:
from skorch import NeuralNetClassifier

# Seed PyTorch's RNG so weight initialisation and other torch randomness
# repeat across runs.
torch.manual_seed(87)

# skorch wraps the PyTorch module in a scikit-learn style estimator.
# This MLP is trained on the CPU; pass device=device instead to use the
# accelerator selected earlier.
net = NeuralNetClassifier(
    module=ClassifierModule,
    criterion=nn.CrossEntropyLoss,
    optimizer=torch.optim.SGD,
    lr=0.1,
    max_epochs=20,
    batch_size=32,
    iterator_train__shuffle=True,
    device='cpu',
)
net.fit(X_train, y_train)


Make predictions

In [None]:
from sklearn.metrics import accuracy_score

y_pred = net.predict(X_test)
print("Accuracy: ", accuracy_score(y_test, y_pred))

# Show a few misclassified test digits as "true -> predicted".
error_mask = y_test != y_pred
plot_example(X_test[error_mask],
             y_test[error_mask],
             y_pred[error_mask])



Convolutional Neural Network

In [None]:
# Convolutional layers need image-shaped input: (samples, 1 channel, 28, 28).
X_cnn = X.reshape(-1, 1, 28, 28)

# Reuse the MLP's split seed (87) so both models are evaluated on the same
# held-out digits and their accuracies are directly comparable.
X_cnn_train, X_cnn_test, y_cnn_train, y_cnn_test = train_test_split(
    X_cnn, y, test_size=0.25, random_state=87)

# Same seed and sample count give the same permutation, so the CNN test set
# must be exactly the MLP test set.
assert np.array_equal(y_cnn_test, y_test), "CNN and MLP test labels differ"
assert np.array_equal(X_cnn_test.reshape(len(X_cnn_test), -1), X_test), \
    "CNN and MLP test images differ"


In [None]:
class CNN(nn.Module):
    """Small convolutional classifier for 1 x 28 x 28 images.

    Architecture::

        conv(1 -> 32, 3x3) -> max-pool 2 -> ReLU
        conv(32 -> 64, 3x3) -> channel dropout -> max-pool 2 -> ReLU
        flatten (64 * 5 * 5 = 1600) -> Linear(1600, 100) -> dropout -> ReLU
        -> Linear(100, num_classes)

    ``forward`` returns raw logits (no softmax). ``nn.CrossEntropyLoss``
    applies log-softmax itself.

    Args:
        dropout: drop probability used both for the 2-D (whole-channel)
            dropout after the second convolution and for the dropout after
            the first dense layer.
        num_classes: number of output logits. Defaults to 10, the size that
            was previously hard-coded.
    """

    def __init__(self, dropout=0.5, num_classes=10):
        super().__init__()
        # Submodules are registered in the original order so seeded weight
        # initialisation is unchanged.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.conv2_drop = nn.Dropout2d(p=dropout)
        # Spatial size for a 28x28 input:
        # 28 -conv3-> 26 -pool2-> 13 -conv3-> 11 -pool2-> 5,
        # so the flattened feature map holds 64 channels * 5 * 5 = 1600 values.
        self.fc1 = nn.Linear(64 * 5 * 5, 100)
        self.fc2 = nn.Linear(100, num_classes)
        self.fc1_drop = nn.Dropout(p=dropout)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))

        # Flatten channel, height and width into one feature axis. Unlike
        # .view(), torch.flatten does not require contiguous memory.
        x = torch.flatten(x, start_dim=1)

        x = F.relu(self.fc1_drop(self.fc1(x)))
        return self.fc2(x)

In [None]:
# Reseed PyTorch so the CNN's initialisation is reproducible on its own,
# independent of how much randomness the MLP run consumed.
torch.manual_seed(87)

# Train the CNN on the device chosen earlier (MPS when available, else CPU).
net = NeuralNetClassifier(
    module=CNN,
    criterion=nn.CrossEntropyLoss,
    optimizer=torch.optim.SGD,
    lr=0.1,
    max_epochs=10,
    batch_size=32,
    iterator_train__shuffle=True,
    device=device,
)
net.fit(X_cnn_train, y_cnn_train)

In [None]:
y_cnn_pred = net.predict(X_cnn_test)
print("Accuracy: ", accuracy_score(y_cnn_test, y_cnn_pred))

# Show a few test digits the CNN got wrong as "true -> predicted".
cnn_error_mask = y_cnn_test != y_cnn_pred
plot_example(X_cnn_test[cnn_error_mask],
             y_cnn_test[cnn_error_mask],
             y_cnn_pred[cnn_error_mask])