In [1]:
%matplotlib inline
import os

import matplotlib.pyplot as plt
from tqdm.notebook import tqdm, trange

import numpy as np

import torch
import torch.nn as nn
from torchvision import datasets, transforms

In [2]:
class MNIST_CNN(nn.Module):
    """Two-stage convolutional classifier for single-channel 28x28 MNIST digits.

    All modules live in one ``nn.Sequential`` stored as ``self.layers``, so the
    saved ``state_dict`` keys have the form ``layers.<index>.<param>``.

    Spatial flow for a 28x28 input:
        Conv2d(1->32, 5x5, pad 2) -> ReLU -> MaxPool 2x2    28x28 -> 14x14
        Conv2d(32->64, 5x5, pad 2) -> ReLU -> MaxPool 2x2   14x14 -> 7x7
        Flatten -> Linear(7*7*64 -> 256) -> ReLU -> Linear(256 -> 10)

    The final layer emits raw logits (no softmax), matching ``nn.CrossEntropyLoss``.
    """

    def __init__(self):
        super().__init__()
        # For stride 1, a conv with input size n, kernel f and zero-padding p outputs
        # n - f + 2p + 1. With f=5 and p=2 that is n, so only the pools shrink the image.
        feature_extractor = [
            nn.Conv2d(1, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        ]
        classifier = [
            nn.Flatten(),
            # 64 feature maps of 7x7 remain after two 2x poolings of a 28x28 image
            nn.Linear(7 * 7 * 64, 256),
            nn.ReLU(),
            # one logit per digit class
            nn.Linear(256, 10),
        ]
        self.layers = nn.Sequential(*feature_extractor, *classifier)

    def forward(self, x):
        """Return class logits for a batch of images.

        Assumes ``x`` is shaped (batch, 1, 28, 28); other spatial sizes do not
        match the 7*7*64 input width of the first Linear layer.
        """
        logits = self.layers(x)
        return logits

In [3]:
## Load data - as usual
# In principle, early stopping needs a separate validation split to tune hyperparameters
# such as the patience. Here the test set doubles as the validation set purely for
# demonstration. Avoid this in practice: choosing the checkpoint on the test set makes the
# reported test accuracy optimistically biased.
SEED = 0
torch.manual_seed(SEED)  # makes weight init and DataLoader shuffling reproducible

mnist_train = datasets.MNIST(root="./datasets", train=True, transform=transforms.ToTensor(), download=True)
mnist_test = datasets.MNIST(root="./datasets", train=False, transform=transforms.ToTensor(), download=True)
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=100, shuffle=True)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=100, shuffle=False)

## Training
MAX_EPOCHS = 10
PATIENCE = 2  # stop after this many consecutive epochs without a new best accuracy
CHECKPOINT_PATH = './models/CNN_Early-Stop.pt'
# torch.save does not create missing directories, so make sure ./models exists on a fresh checkout
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)

model = MNIST_CNN()
criterion = nn.CrossEntropyLoss()
# use Adam instead of SGD; it typically needs less learning-rate tuning for neural networks
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

best = -1  # below any possible count, so the first epoch always writes a checkpoint
count = 0  # consecutive epochs without improvement
for epoch in trange(MAX_EPOCHS):
    model.train()
    for images, labels in tqdm(train_loader):
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

    ## Testing - as usual
    model.eval()
    correct = 0
    with torch.no_grad():
        for images, labels in tqdm(test_loader):
            predictions = torch.argmax(model(images), dim=1)
            # accumulate a plain Python int so the printed accuracy is an exact ratio
            correct += (predictions == labels).sum().item()

    print(f'Test accuracy: {correct/len(mnist_test)}')

    if correct > best:
        count = 0
        best = correct
        torch.save(model.state_dict(), CHECKPOINT_PATH)
    else:
        count += 1
        if count >= PATIENCE:
            print(f'Early stop @ epoch #{epoch}')
            break

# After the loop `model` holds the weights of the *last* epoch. After an early stop these
# are worse than the checkpoint, so restore the best weights for the cells that follow.
model.load_state_dict(torch.load(CHECKPOINT_PATH))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=600.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))


Test accuracy: 0.9868000149726868


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=600.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))


Test accuracy: 0.9861000180244446


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=600.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))


Test accuracy: 0.989300012588501


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=600.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))


Test accuracy: 0.991100013256073


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=600.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))


Test accuracy: 0.9904000163078308


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=600.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))


Test accuracy: 0.9930999875068665


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=600.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))


Test accuracy: 0.9904000163078308


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=600.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))


Test accuracy: 0.9901999831199646


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=600.0), HTML(value='')))





KeyboardInterrupt: 