## CNN (convolutional neural network) example: LeNet-5 on MNIST

Wei Li

In [None]:
import os
import torch
from torch import nn
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt

import time
import random

# random_seed = 123
# os.environ["PL_GLOBAL_SEED"] = str(random_seed)
# random.seed(random_seed)
# np.random.seed(random_seed)

In [None]:
# Record the author and the versions of the environment / key packages
# (numpy, torch, torchvision) via the IPython `watermark` extension, so the
# results below can be tied to a specific software setup.
# %pip install watermark
%load_ext watermark
%watermark -a "Wei Li" -u -t -d -v -p numpy,torch,torchvision

In [None]:
# From local helper files
from utils_evaluation import set_all_seeds, set_deterministic, evaluate_epoch_loss, evaluate_epoch_metrics, get_predictions
from utils_plotting import plot_accuracy, plot_loss, show_images, plot_confusion_matrix
from utils_data import get_dataloaders_mnist

In [None]:
## Settings: fixed seed for reproducibility, and the first GPU when one is present.
RANDOM_SEED = 2022
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
set_all_seeds(RANDOM_SEED)
set_deterministic()

In [None]:
##########################
### MNIST DATASET
##########################

BATCH_SIZE = 512

# Compose a series of image transformations applied to every sample.
transformCompose = transforms.Compose(
    [
        transforms.Resize((32, 32)),           # Resize to 32x32 pixels (the LeNet5 flatten size 16*5*5 assumes this).
        transforms.ToTensor(),                 # Image -> float tensor, scaling pixel values from [0, 255] to [0.0, 1.0].
        transforms.Normalize((0.5,), (0.5,)),  # Per channel: (x - 0.5) / 0.5.
    ]
)
# After normalization, pixel values lie in [-1, 1] (the [0, 1] range is shifted and stretched so 0 is its midpoint).

# validation_fraction=0.2 presumably holds out 20% of the training data for validation;
# the same transforms are used for all three splits.
train_loader, valid_loader, test_loader = get_dataloaders_mnist(
    batch_size=BATCH_SIZE,
    validation_fraction=0.2,
    train_transforms=transformCompose,
    test_transforms=transformCompose)

# Sanity-check the dataset by peeking at a single training batch.
images, labels = next(iter(train_loader))
print('Image batch dimensions:', images.shape)  # NCHW
print('Image label dimensions:', labels.shape)
print('Class labels of 10 examples:', labels[:10])

# len(train_loader) # 93 minibatches

In [None]:
# Number of minibatches per split (len() of a DataLoader = batches per epoch).
print("Train batches        : ", len(train_loader))
print("Val batches          : ", len(valid_loader))
print("Test batches        : ", len(test_loader))

In [None]:
def extract_data_from_dataloader(dataloader, num_images=10):
    """
    Extract a specified number of images and labels from a DataLoader.

    Batches are drawn in order until at least ``num_images`` samples have been
    collected, so ``num_images`` may exceed the loader's batch size. If the
    loader is exhausted first, every available sample is returned.

    Args:
        dataloader (DataLoader or iterable): Yields ``(images, labels)`` tensor pairs.
        num_images (int): Number of images to extract.

    Returns:
        Tuple of numpy arrays: (images, labels)

    Raises:
        ValueError: If ``dataloader`` yields no batches at all.
    """
    image_chunks, label_chunks = [], []
    collected = 0
    for images, labels in dataloader:
        image_chunks.append(images)
        label_chunks.append(labels)
        collected += len(labels)
        if collected >= num_images:
            break

    if not image_chunks:
        raise ValueError("dataloader yielded no batches")

    images = torch.cat(image_chunks)[:num_images]
    labels = torch.cat(label_chunks)[:num_images]
    # detach()/cpu() so tensors tracked by autograd or living on an accelerator
    # can still be converted to numpy.
    return images.detach().cpu().numpy(), labels.detach().cpu().numpy()

# Pull the first 10 training samples (as numpy arrays) for a visual sanity check.
train_images, train_labels = extract_data_from_dataloader(train_loader, num_images=10)


In [None]:
# mean/std match the Normalize((0.5,), (0.5,)) transform; presumably show_images uses them
# to undo the normalization before display — confirm in utils_plotting.
# NOTE(review): (0.5) is a plain float, not a 1-tuple like (0.5,).
show_images(train_images, train_labels, num_images=10, normalize=True, mean=(0.5), std=(0.5))

## Model


In [None]:
# Define the LeNet-5 model
class LeNet5(nn.Module):
    """LeNet-5 style CNN with tanh activations and max pooling.

    The ``16*5*5`` input size of the first linear layer assumes 32x32 input
    images: 32 -conv5-> 28 -pool2-> 14 -conv5-> 10 -pool2-> 5.

    Args:
        num_classes: Number of output logits (default 10, one per MNIST digit).
        in_channels: Number of input image channels (default 1, grayscale).
    """

    def __init__(self, num_classes=10, in_channels=1):
        super().__init__()
        self.convnet = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=6, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.MaxPool2d(kernel_size=2),
        )

        self.fc = nn.Sequential(
            nn.Linear(in_features=16*5*5, out_features=120),
            nn.Tanh(),
            nn.Linear(in_features=120, out_features=84),
            nn.Tanh(),
            nn.Linear(in_features=84, out_features=num_classes),
        )

    def forward(self, x):
        """Map an NCHW image batch to raw (unnormalized) class logits of shape (N, num_classes)."""
        x = self.convnet(x)      # NCHW feature maps
        x = torch.flatten(x, 1)  # N by C*H*W
        x = self.fc(x)
        return x

In [None]:
# Hyperparameters
LR = 0.1          # SGD learning rate
NUM_EPOCHS = 15   # full passes over the training loader

# Model, loss and optimizer
model = LeNet5()
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=LR)


In [None]:
start_time = time.time()
minibatch_loss_list, avg_loss_list, train_loss_list, val_loss_list, train_acc_list, val_acc_list = [], [], [], [], [], []

for epoch in range(NUM_EPOCHS):

    # ---- Training pass over all minibatches ----
    model.train()
    for batch_idx, (features, targets) in enumerate(train_loader):

        features = features.to(device)
        targets = targets.to(device)

        # ## FORWARD AND BACK PROP
        y_pred = model(features)
        loss = criterion(y_pred, targets)
        optimizer.zero_grad()

        loss.backward()

        # ## UPDATE MODEL PARAMETERS
        optimizer.step()

        # ## LOGGING
        batch_loss = loss.item()  # convert to a Python float once; reused below
        minibatch_loss_list.append(batch_loss)
        if not batch_idx % 50:
            print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} '
                  f'| Batch {batch_idx:04d}/{len(train_loader):04d} '
                  f'| Loss: {batch_loss:.4f}')

    # ---- End-of-epoch evaluation ----
    model.eval()
    with torch.no_grad():  # no autograd graph is needed for evaluation
        # Mean minibatch loss over this epoch only (the last len(train_loader) entries).
        # Average the whole minibatch_loss_list instead for a running mean over all epochs.
        avg_loss = torch.tensor(minibatch_loss_list[-len(train_loader):], dtype=torch.float32).mean()
        avg_loss_list.append(avg_loss)

        # Reuse the training criterion so reported losses match the optimized objective.
        # NOTE(review): [1] selects the loss from evaluate_epoch_loss's return tuple — confirm in utils_evaluation.
        train_loss = evaluate_epoch_loss(model, train_loader, device=device, criterion=criterion)[1]
        val_loss = evaluate_epoch_loss(model, valid_loader, device=device, criterion=criterion)[1]

        train_acc = evaluate_epoch_metrics(model, train_loader, device=device)
        val_acc = evaluate_epoch_metrics(model, valid_loader, device=device)

        # Print train and validation loss along with accuracy
        print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} '
            f'| Train Loss: {train_loss:.4f} '
            f'| Valid Loss: {val_loss:.4f} '
            f'| Train Acc: {train_acc:.2f}% '
            f'| Valid Acc: {val_acc:.2f}% ')

        train_loss_list.append(train_loss)
        val_loss_list.append(val_loss)
        train_acc_list.append(train_acc)
        val_acc_list.append(val_acc)

    # Cumulative wall-clock time since training started.
    elapsed = (time.time() - start_time)/60
    print(f'Time elapsed: {elapsed:.2f} min')

elapsed = (time.time() - start_time)/60
print(f'Total Training Time: {elapsed:.2f} min')

print()
test_acc = evaluate_epoch_metrics(model, test_loader, device=device)
print(f'Test accuracy {test_acc :.2f}')

# benchmark:
# Epoch: 015/015 | Train Loss: 0.0516 | Valid Loss: 0.0589 | Train Acc: 98.59% | Valid Acc: 98.27% 
# Test accuracy 98.34

In [None]:
# Evaluation: learning curves — one figure for loss, one for accuracy.
plot_loss(train_loss_list=train_loss_list, valid_loss_list=val_loss_list, results_dir=None)
plt.show()

plot_accuracy(train_acc_list=train_acc_list, valid_acc_list=val_acc_list, results_dir=None)
plt.show()

#### Predictions


In [None]:
# Get predictions on the test split.
# Presumably returns the test images, their true labels and the model's predicted labels,
# in that order — confirm against get_predictions in utils_evaluation.
test_images, test_labels, test_predictions = get_predictions(model, test_loader, device)

In [None]:
# Visualize the first 20 test images together with their true and predicted labels.
# mean/std mirror the Normalize((0.5,), (0.5,)) transform applied to the data.
show_images(
    test_images,
    test_labels,
    test_predictions,
    num_images=20,
    normalize=True,
    mean=(0.5),
    std=(0.5),
)

#### Confusion matrix

In [None]:
# MNIST classes are the digits 0-9, so each label simply maps to its own string form
# (optional for MNIST, but it gives the confusion-matrix axes readable tick names).
mnist_label_dict = dict(enumerate("0123456789"))

plot_confusion_matrix(test_labels, test_predictions, mnist_label_dict)