# Classification demo
SPIE Short course on Machine Learning for Image Restoration.  
Author: Jesse Wilson (jesse.wilson@colostate.edu).

Walk through training and evaluation of a convolutional network for handwritten-digit classification. This code is provided for educational purposes.


# Preliminaries

In [None]:
# import libraries and set up GPU device for training
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from IPython.display import clear_output
import numpy as np
from random import randint
from torch import fft
from IPython.display import clear_output

# get available GPU 
# supports NVIDIA (CUDA), Intel (XPU), and Apple (MPS)
# (CAUTION: AI-generated code -- NOT validated on all systems!)
# Backends are probed in order of preference and the first available one wins.
# The hasattr() guards keep this working on PyTorch builds that lack the
# xpu / mps backends entirely.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    device = torch.device("xpu:0")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Selected device: {device}.")

In [None]:
# Load a dataset
batch_size=32

# load datasets, and automatically transform images to pytorch tensor format
transform = transforms.ToTensor()
dataset_train = datasets.MNIST(root='data',train=True,download=True,transform=transform)
dataset_val = datasets.MNIST(root='data',train=False,download=True,transform=transform)

# Set up dataloaders, which produce batches of data for training and validation.
# Only the training set is shuffled; the validation order stays fixed, so
# next(iter(dataloader_val)) always yields the same first batch.
dataloader_train = DataLoader(dataset_train, batch_size, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size, shuffle=False)

# sanity check -- plot the first image of a (shuffled) batch from the training dataloader
# batch[0] is the image tensor, batch[1] the labels; [0] picks the first image.
img = next(iter(dataloader_train))[0][0].squeeze()
plt.imshow(img)
plt.colorbar()
plt.show()

# Neural network definition and quick passthrough test

In [None]:
# Simple CNN classifier neural network 
# Based roughly on Stevens, Antiga & Viehmann Deep Learning with Pytorch CH 8 example
#
# This sets up a new class, inheriting from the pytorch nn.Module class.
# At bare minimum we need to implement __init__() and forward() functions
class Net(nn.Module):
    """Small CNN mapping single-channel 28x28 images to 10 class probabilities.

    Architecture: two [3x3 conv (16 channels) -> ReLU -> 2x2 max-pool] stages,
    which reduce 28x28 to 7x7 feature maps, followed by a 7*7*16 -> 32 -> 10
    fully connected classifier with a softmax output.
    """

    def __init__(self):
        super().__init__() # this initializes the parent nn.Module

        # define network elements
        
        # convolutional front end (feature extraction)
        # padding=1 keeps the spatial size for 3x3 kernels; each pool halves it.
        self.conv1 = nn.Conv2d(1,16,kernel_size=3,padding=1)
        self.act1 = nn.ReLU()
        self.pool1=nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(16,16,kernel_size=3,padding=1)
        self.act2=nn.ReLU()
        self.pool2=nn.MaxPool2d(2)

        # fully connected back end (classifier)
        # 7*7*16: 16 channels of 7x7 maps left after two 2x poolings of 28x28.
        self.fc1=nn.Linear(7*7*16,32)
        self.act3=nn.ReLU()
        self.fc2=nn.Linear(32,10)

        # initialize learnable parameters
        # Only the conv layers get Kaiming (He) init tuned for ReLU; the Linear
        # layers keep PyTorch's default initialization.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                nn.init.constant_(m.bias,0)
        
    def forward(self,x):
        """Classify a batch of images.

        Args:
            x: image batch of shape (batch, 1, 28, 28).

        Returns:
            Tensor of shape (batch, 10) holding softmax probabilities
            (each row sums to 1). It returns probabilities rather than logits
            because the training cell feeds them to nn.BCELoss and the
            plotting cells read them as confidences.
        """
        # first, pass through convolutional front end and extract features
        out = self.pool1(self.act1(self.conv1(x)))
        out = self.pool2(self.act2(self.conv2(out)))

        # flatten every dim except the batch dim, then run the classifier.
        # Unlike view(-1, 7*7*16), this preserves the batch size, so an input
        # with the wrong spatial size fails in fc1 instead of being silently
        # re-split into a different number of "samples".
        out = torch.flatten(out, 1)
        out = self.act3(self.fc1(out))
        out = F.softmax(self.fc2(out),dim=1)
        
        return out

# Build the network, then move its parameters onto the selected device.
# nn.Module.to() modifies the module in place and also returns it, so the
# reassignment keeps `net` bound to the same object.
net = Net()
net = net.to(device)

In [None]:
# load a saved model
# DO NOT RUN THIS THE FIRST TIME AROUND -- "classification_demo.pth" is only
# written by the training cell below.
# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved on one device (e.g. CUDA) can be loaded on a machine that
# only has another (e.g. CPU or MPS).
net.load_state_dict(
    torch.load("classification_demo.pth", map_location=device, weights_only=True)
)
net.to(device)

In [None]:
# quick check running data through the network
data = next(iter(dataloader_val))
x = data[0].to(device) # image
y = data[1] # label (target)

net.eval()
with torch.no_grad():
    yhat = net(x) # estimated label

# yhat is a vector of probabilities for each class
print(yhat[0])

# the predicted class is the one with the maximum probability
maxval, maxind = yhat.max(dim=1)
predicted = maxind

# show a random selection of images with predicted and true labels
for plotInd in range(9):
    ind = randint(0,len(x)-1)
    
    plt.subplot(3,3,plotInd+1)
    plt.imshow(x[ind].cpu().squeeze())
    plt.title(f'pred: {predicted[ind]} / true: {y[ind]}')
    plt.axis('off')

plt.show()


# Train the neural network

In [None]:
# training loop
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
loss_fn = nn.BCELoss() # binary cross-entropy loss, averaged over all elements of the batch

loss_train_vec = [] # history of training dataset loss (mean per sample)
loss_val_vec = [] # history of validation set loss (mean per sample)

n_epochs = 100

for epoch in range(n_epochs):
    net.train() # put network in training mode
    loss_this_epoch = 0
    n_samples = 0

    # iterate through the training dataset
    for data, label in dataloader_train:
        # prep data
        x = data.to(device)
        y = F.one_hot(label,10).to(torch.float).to(device) # convert label to one-hot coding

        # pass through network and evaluate loss function
        yhat = net(x)
        loss = loss_fn(yhat, y)

        # gradient descent step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # keep track of loss function values over time.
        # `loss` is already a batch *mean*, so weight it by the batch size to
        # get a per-batch sum (this also weights a short final batch correctly).
        # detach() stops the running total from keeping every batch's autograd
        # graph alive for the rest of the epoch.
        loss_this_epoch += loss.detach() * len(x)
        n_samples += len(x)

    loss_this_epoch = loss_this_epoch / n_samples
    loss_train_vec.append(loss_this_epoch.item())

    
    net.eval() # put network in evaluation mode
    loss_this_epoch = 0
    n_samples = 0
    # iterate through validation set
    with torch.no_grad(): # DON'T CALCULATE GRADIENTS!
        for data, label in dataloader_val:
            x = data.to(device)
            y = F.one_hot(label,10).to(torch.float).to(device) # convert label to one-hot coding

            # pass through network and evaluate loss function
            yhat = net(x)
            loss = loss_fn(yhat, y)

            # same batch-size weighting as in training
            loss_this_epoch += loss * len(x)
            n_samples += len(x)
    
    loss_this_epoch = loss_this_epoch / n_samples
    loss_val_vec.append(loss_this_epoch.item())

    # save the model if it achieved a minimum validation loss
    if loss_val_vec[-1] == min(loss_val_vec):
        torch.save(net.state_dict(), "classification_demo.pth")

    # plot losses
    clear_output(wait=True)
    plt.plot(loss_train_vec)
    plt.plot(loss_val_vec)
    plt.title('losses')
    plt.xlabel('epoch')
    plt.legend(['training','validation'])
    plt.grid()
    plt.show()
        

In [None]:
# your turn: change one thing above and run it again. A few ideas:
# - change batch size
# - change learning rate
# - change from Adam to SGD optimizer
# - number of channels (filters) in each convolutional layer
# - number of layers in the network
# - number of neurons in FC1
# - change ReLU to LeakyReLU or Sigmoid activation function
# - change binary cross entropy to L1 or MSE loss functions

# Extras

In [None]:
# Out-of-distribution check: load FashionMNIST test images, a dataset the
# network was not trained on, and see how the digit classifier reacts.
batch_size = 32

transform = transforms.ToTensor()
dataset_ood = datasets.FashionMNIST(
    root='data',
    train=False,
    download=True,
    transform=transform,
)

dataloader_ood = DataLoader(dataset_ood, batch_size=batch_size, shuffle=True)

# sanity check -- display one image (index 1) from a shuffled batch
img = next(iter(dataloader_ood))[0][1].squeeze()
plt.imshow(img)
plt.colorbar()
plt.show()

In [None]:
# quick check running out-of-distribution data through the network
data = next(iter(dataloader_ood))
x = data[0].to(device) # image
y = data[1] # dataset label (not used below)

net.eval()
with torch.no_grad():
    yhat = net(x) # estimated class probabilities
    
# predicted class, and the softmax probability the network assigns to it
maxval, maxind = yhat.max(dim=1)
predicted = maxind
confidence = maxval

# show a random selection of images with predicted class and confidence
for plotInd in range(9):
    ind = randint(0,len(x)-1)
    
    plt.subplot(3,3,plotInd+1)
    plt.imshow(x[ind].cpu().squeeze())
    plt.title(f'pred: {predicted[ind]} ({confidence[ind]*100 :.2f}%)')
    plt.axis('off')

plt.show()