<a href="https://colab.research.google.com/github/vlassner/DSML_4220_Deep_Learning/blob/main/Lab5_lenet_w_mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Lab 5: LeNet from Scratch

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sgeinitz/DSML4220/blob/main/lab5_lenet_w_mnist.ipynb)

[![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/sgeinitz/DSML4220/blob/main/lab5_lenet_w_mnist.ipynb)

In this notebook we'll build one of the [LeNet CNNs](https://en.wikipedia.org/wiki/LeNet) from scratch. Along the way we'll use a few new Python modules/tools:
* [torchvision](https://pytorch.org/vision/stable/index.html) - a supplementary PyTorch module with popular image datasets, including [MNIST](https://en.wikipedia.org/wiki/MNIST_database)
* [captum](https://captum.ai) - a tool that provides some utilities for model explainability and interpretability
* [tensorboard](https://www.tensorflow.org/tensorboard/get_started) - developed for TensorFlow, TensorBoard provides model experimentation tools similar to Weights and Biases ([wandb.ai](https://wandb.ai)), and works with PyTorch too


### Lab 5 Assignment/Task

This lab has only three questions.  As before, you'll submit the link to your completed notebook with your answers to the questions.

# Dataset Preparation

## Download the dataset

In [None]:
from torchvision import datasets

In [None]:
train_val_dataset = datasets.MNIST(root="~/data/lenet/", train=True, download=True)
test_dataset = datasets.MNIST(root="~/data/lenet/", train=False, download=True)

## Analyze the dataset

In [None]:
train_val_dataset, test_dataset

In [None]:
class_names = train_val_dataset.classes
class_names[0]

In [None]:
class_index = train_val_dataset.class_to_idx
class_index

In [None]:
img, label = train_val_dataset[0]
img, label

In [None]:
import matplotlib.pyplot as plt

plt.imshow(img, cmap='gray')
plt.title(f"{class_names[label]}");

In [None]:
import numpy as np

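def img_pixel_superimpose(img, ax):
    """Draw a grayscale image and print each pixel's value on top of it."""
    h, w = img.shape  # NumPy image arrays are (rows, cols) = (height, width)
    color_map = plt.get_cmap('gray_r')  # reversed gray, so the text contrasts with its pixel
    ax.imshow(img, cmap='gray')
    for row in range(h):
        for col in range(w):
            color = color_map(img[row][col])
            ax.annotate(str(img[row][col]), xy=(col, row), horizontalalignment='center',
                        verticalalignment='center', color=color)
    ax.axis('off')  # hide the axes once, after all annotations are drawn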


In [None]:
import torch

torch.manual_seed(42)

fig = plt.figure(figsize = (12,12))
ax0 = fig.add_subplot(1, 1, 1)

rand_ind = torch.randint(0, len(train_val_dataset), size=[1]).item()

img0 = train_val_dataset.data[rand_ind]
img0 = img0.numpy()
img_pixel_superimpose(img0, ax0)

In [None]:
import torch
torch.manual_seed(42)     # Search on the internet why '42' is special or
                          # even '42 * 2' = 84

fig = plt.figure(figsize=(16, 4))
rows, cols = 2, 10

for i in range(1, (rows*cols) + 1):
    rand_ind = torch.randint(0, len(train_val_dataset), size=[1]).item()
    img, lab = train_val_dataset[rand_ind]
    fig.add_subplot(rows, cols, i)
    plt.imshow(img, cmap='gray')
    plt.title(f"{class_names[lab]}")
    plt.axis(False)
    plt.tight_layout()

## Transform the dataset

In [None]:
from torchvision import transforms

### ToTensor()

In [None]:
train_val_dataset = datasets.MNIST(root="~/data/lenet/", train=True, download=False, transform=transforms.ToTensor())
test_dataset = datasets.MNIST(root="~/data/lenet/", train=False, download=False, transform=transforms.ToTensor())
train_val_dataset, test_dataset

In [None]:
img, lab = train_val_dataset[0]
img, lab

In [None]:
img.min(), img.max()  # The ToTensor() transformation scaled down 0-255 --> 0-1
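
The rescaling noted in the comment above can be reproduced by hand. The following is a minimal NumPy sketch, assuming a 2-D `uint8` grayscale image; `to_tensor_like` is a hypothetical stand-in written for illustration, not torchvision's actual implementation.

```python
import numpy as np

def to_tensor_like(img_uint8):
    """Hypothetical stand-in for transforms.ToTensor() on a 2-D uint8 image."""
    arr = np.asarray(img_uint8, dtype=np.uint8)
    scaled = arr.astype(np.float32) / 255.0   # map 0-255 onto 0.0-1.0
    return scaled[np.newaxis, :, :]           # add a leading channel axis: (1, H, W)

demo = np.array([[0, 128], [255, 64]], dtype=np.uint8)  # made-up 2x2 "image"
out = to_tensor_like(demo)
print(out.shape, out.dtype, out.min(), out.max())
```

The leading channel axis is why the notebook calls `squeeze_()` before passing an image to `plt.imshow`.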

In [None]:
plt.imshow(img.squeeze_(), cmap='gray')
plt.title(f"{class_names[lab]}");

### Normalization

In [None]:
# Calculate mean and std

imgs = torch.stack([img for img, _ in train_val_dataset], dim=0)
print(imgs.shape)

In [None]:
mean = imgs.view(1, -1).mean(dim=1)    # or imgs.mean()
std = imgs.view(1, -1).std(dim=1)     # or imgs.std()
mean, std

In [None]:
mnist_transforms = transforms.Compose([transforms.ToTensor(),
                                       transforms.Normalize(mean=mean, std=std)])
mnist_transforms

In [None]:
train_val_dataset = datasets.MNIST(root="~/data/lenet/", train=True, download=False, transform=mnist_transforms)
test_dataset = datasets.MNIST(root="~/data/lenet/", train=False, download=False, transform=mnist_transforms)
train_val_dataset, test_dataset

In [None]:
img, label = train_val_dataset[0]
img, label

In [None]:
img.min(), img.max()
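
The new minimum and maximum follow from the per-pixel formula that `Normalize` applies, `(x - mean) / std`. Here is a small sketch of that arithmetic, assuming the commonly quoted MNIST statistics of roughly 0.1307 (mean) and 0.3081 (std); the values computed earlier in this notebook should be close to these but may differ in the last digits.

```python
def normalize(x, mean, std):
    """Per-pixel normalization: shift by the mean, then divide by the std."""
    return (x - mean) / std

# Assumed values: the commonly quoted MNIST training-set statistics.
MEAN, STD = 0.1307, 0.3081

lowest = normalize(0.0, MEAN, STD)    # a black pixel (0.0 after ToTensor)
highest = normalize(1.0, MEAN, STD)   # a white pixel (1.0 after ToTensor)
print(f"min: {lowest:.4f}, max: {highest:.4f}")
```

Because most MNIST pixels are black, the mean sits near zero and the normalized range is lopsided: slightly negative at the low end, close to 3 at the high end.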

In [None]:
plt.imshow(img.squeeze_(), cmap='gray')
plt.title(f"{class_names[label]}");

## Split dataset into Train/Val/Test

In [None]:
train_size = int(0.9 * len(train_val_dataset))
val_size = len(train_val_dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(dataset=train_val_dataset, lengths=[train_size, val_size])
train_dataset, val_dataset

In [None]:
len(train_dataset), len(val_dataset), len(test_dataset)

In [None]:
len(train_dataset), len(train_dataset.dataset)  # train_dataset.dataset is the parent train_val_dataset

In [None]:
# Validate train dataset is working fine
img, label = train_dataset[0]
plt.imshow(img.squeeze_(), cmap='gray')
plt.title(f"{class_names[label]}");

## Dataloader preparation

In [None]:
from torch.utils.data import DataLoader

In [None]:
BATCH_SIZE = 32

train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=True)
train_dataloader, val_dataloader, test_dataloader

In [None]:
img, label = train_dataloader.dataset[0]
img.shape, label

In [None]:
# Validate the dataloader is working fine
plt.imshow(img.squeeze_(), cmap='gray')
plt.title(f"{class_names[label]}");

In [None]:
no_train_batches = len(train_dataloader.dataset) / train_dataloader.batch_size
no_val_batches = len(val_dataloader.dataset) / val_dataloader.batch_size
no_test_batches = len(test_dataloader.dataset) / test_dataloader.batch_size

# Number of batches implied by the current batch size (fractional when the split is not exact)
no_train_batches, no_val_batches, no_test_batches

In [None]:
len(train_dataloader), len(val_dataloader), len(test_dataloader)   # len() rounds up: the final, smaller batch still counts
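
The difference between the fractional counts and the `len()` values is a ceiling operation. A minimal sketch, assuming the split sizes used in this notebook (54,000 / 6,000 / 10,000) and a batch size of 32; `num_batches` is a hypothetical helper, and its `drop_last` argument mirrors the `DataLoader` option of the same name.

```python
import math

def num_batches(n_samples, batch_size, drop_last=False):
    """Number of batches a loader yields per epoch (hypothetical helper)."""
    if drop_last:
        return n_samples // batch_size        # discard the final partial batch
    return math.ceil(n_samples / batch_size)  # keep the final partial batch

sizes = {"train": 54_000, "val": 6_000, "test": 10_000}
counts = {name: num_batches(n, 32) for name, n in sizes.items()}
print(counts)
```

With the default `drop_last=False`, the last batch of each loader holds only 16 samples, which matters slightly when per-batch metrics are averaged with equal weight.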

# Model Architecture, Construct Training & Evaluation Pipeline

In [None]:
from torch import nn

In [None]:
class LeNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature = nn.Sequential(
            #1
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2),   # 28*28->32*32-->28*28
            nn.ReLU(),  # nn.Sigmoid(),
            nn.AvgPool2d(kernel_size=2, stride=2),  # 14*14

            #2
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),  # 10*10
            nn.Sigmoid(),
            nn.AvgPool2d(kernel_size=2, stride=2),  # 5*5

        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=16*5*5, out_features=120),
            nn.BatchNorm1d(120), # not part of original LeNet
            nn.Sigmoid(),
            nn.Linear(in_features=120, out_features=84),
            nn.Sigmoid(),
            nn.Linear(in_features=84, out_features=10),
        )

    def forward(self, x):
        return self.classifier(self.feature(x))


In [None]:
lenetmodel = LeNet()
lenetmodel
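
The shape comments in the `LeNet` definition can be checked with the standard output-size formula for convolution and pooling layers, `floor((n + 2p - k) / s) + 1`. This is a small sketch under that formula (no dilation); `conv_out` is a hypothetical helper name.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output height/width of a conv or pooling layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1

s0 = 28                                            # MNIST input is 28x28
s1 = conv_out(s0, kernel=5, stride=1, padding=2)   # first Conv2d
s2 = conv_out(s1, kernel=2, stride=2)              # first AvgPool2d
s3 = conv_out(s2, kernel=5, stride=1)              # second Conv2d
s4 = conv_out(s3, kernel=2, stride=2)              # second AvgPool2d
flat_features = 16 * s4 * s4                       # channels * height * width
print(s1, s2, s3, s4, flat_features)
```

The final value is what the first `nn.Linear` layer expects as `in_features`.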

### Model summary

In [None]:
 !pip install torchinfo

from torchinfo import summary

In [None]:
summary(model=lenetmodel, input_size=(1, 1, 28, 28), col_width=20, col_names=['input_size', 'output_size', 'num_params', 'trainable'], row_settings=['var_names'], verbose=0)

### Loss, Optimizer, Metrics

In [None]:
 !pip install torchmetrics

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=lenetmodel.parameters(), lr=0.001)

from torchmetrics import Accuracy
accuracy = Accuracy(task='multiclass', num_classes=10)

### Train model

In [None]:
from tqdm.notebook import tqdm
from torch.utils.tensorboard import SummaryWriter

from datetime import datetime
import os

# Experiment tracking
timestamp = datetime.now().strftime("%Y-%m-%d")
experiment_name = "MNIST"
model_name = "LeNet"
log_dir = os.path.join("runs", timestamp, experiment_name, model_name)
writer = SummaryWriter(log_dir)

# device-agnostic setup
device = 'cuda' if torch.cuda.is_available() else 'cpu'
accuracy = accuracy.to(device)
lenetmodel = lenetmodel.to(device)

EPOCHS = 12

for epoch in tqdm(range(EPOCHS)):
    # Training loop
    train_loss, train_acc = 0.0, 0.0
    for X, y in train_dataloader:
        X, y = X.to(device), y.to(device)

        lenetmodel.train()

        y_pred = lenetmodel(X)

        loss = loss_fn(y_pred, y)
        train_loss += loss.item()

        acc = accuracy(y_pred, y)
        train_acc += acc

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    train_loss /= len(train_dataloader)
    train_acc /= len(train_dataloader)

    # Validation loop
    val_loss, val_acc = 0.0, 0.0
    lenetmodel.eval()
    with torch.inference_mode():
        for X, y in val_dataloader:
            X, y = X.to(device), y.to(device)

            y_pred = lenetmodel(X)

            loss = loss_fn(y_pred, y)
            val_loss += loss.item()

            acc = accuracy(y_pred, y)
            val_acc += acc

        val_loss /= len(val_dataloader)
        val_acc /= len(val_dataloader)

    writer.add_scalars(main_tag="Loss", tag_scalar_dict={"train/loss": train_loss, "val/loss": val_loss}, global_step=epoch)
    writer.add_scalars(main_tag="Accuracy", tag_scalar_dict={"train/acc": train_acc, "val/acc": val_acc}, global_step=epoch)

    print(f"Epoch: {epoch}| Train loss: {train_loss: .5f}| Train acc: {train_acc: .5f}| Val loss: {val_loss: .5f}| Val acc: {val_acc: .5f}")


### Save model

In [None]:
from pathlib import Path

MODEL_PATH = Path("~/models/lenet/").expanduser()  # pathlib does not expand "~" on its own
MODEL_PATH.mkdir(parents=True, exist_ok=True)

MODEL_NAME = "lenet.pth"
MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME


In [None]:
print(f"Saving the model: {MODEL_SAVE_PATH}")
torch.save(obj=lenetmodel.state_dict(), f=MODEL_SAVE_PATH)

### Load model

In [None]:
lenetmodel_loaded = LeNet()
lenetmodel_loaded.load_state_dict(torch.load(MODEL_SAVE_PATH))
lenetmodel_loaded

### Evaluate model

In [None]:
test_loss, test_acc = 0, 0

lenetmodel_loaded.to(device)

lenetmodel_loaded.eval()
with torch.inference_mode():
    for X, y in test_dataloader:
        X, y = X.to(device), y.to(device)
        y_pred = lenetmodel_loaded(X)

        test_loss += loss_fn(y_pred, y)
        test_acc += accuracy(y_pred, y)

    test_loss /= len(test_dataloader)
    test_acc /= len(test_dataloader)

print(f"Test loss: {test_loss: .5f}| Test acc: {test_acc: .5f}")

In [None]:
# See random images with their labels
torch.manual_seed(42)  # setting random seed
import matplotlib.pyplot as plt
%matplotlib inline
fig = plt.figure(figsize=(12, 4))

rows, cols = 2, 6
for i in range(1, (rows * cols) + 1):
    random_idx = torch.randint(0, len(test_dataset), size=[1]).item()
    img, label_gt = test_dataset[random_idx]
    img_temp = img.unsqueeze(dim=0).to(device)
    # print(img.shape)
    label_pred = torch.argmax(lenetmodel_loaded(img_temp))
    fig.add_subplot(rows, cols, i)
    img = img.permute(1, 2, 0)    # CHW --> HWC
    plt.imshow(img, cmap='gray')
    if label_pred == label_gt:
        plt.title(class_names[label_pred], color='g')
    else:
        plt.title(class_names[label_pred], color='r')
    plt.axis(False)
    plt.tight_layout();

### Model Interpretability/Explainability

Let's now use `Captum` to try to understand more about why the model makes the predictions that it does. Specifically, we want to try to understand what it is about one observation that causes the model to make the prediction that it does.

In [None]:
!pip install captum
from captum.attr import IntegratedGradients, Occlusion

This first approach, Integrated Gradients, uses gradients calculated for one specific observation. This may sound odd, particularly since we are not training, but here the gradients are taken with respect to the input pixels rather than the model parameters, and they are accumulated along a path from a baseline image (all zeros by default in Captum) to the actual image. This will give us an idea of which features of this image (which pixels) contribute most to the model's prediction for it.

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(109)

random_idx = torch.randint(0, len(test_dataset), size=[1]).item()
img, label = test_dataset[random_idx]
img = img.to(device)

# Instantiate an IntegratedGradients object for the model
ig = IntegratedGradients(lenetmodel_loaded)

# Compute the attribution scores for the random image
attr, delta = ig.attribute(img.unsqueeze(0), target=label, return_convergence_delta=True)

# Visualize the attribution scores
fig, ax = plt.subplots(1, 2)

ax[0].imshow(img.permute(1, 2, 0).to('cpu'), cmap='gray')
ax[1].imshow(attr[0][0].detach().to('cpu').numpy(), cmap='gray')
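
To make the idea behind Integrated Gradients concrete, here is a minimal NumPy sketch on a toy function with a known gradient. It is not Captum's implementation: the function `f`, its weights `w`, and the helper `integrated_gradients` are all made up for illustration, and the integral is approximated with a simple midpoint sum.

```python
import numpy as np

def integrated_gradients(grad_fn, x, baseline, steps=50):
    """Approximate IG: (x - baseline) * average gradient along the straight path."""
    alphas = (np.arange(steps) + 0.5) / steps          # midpoints in (0, 1)
    total = np.zeros_like(x, dtype=float)
    for a in alphas:
        total += grad_fn(baseline + a * (x - baseline))  # gradient w.r.t. the input
    return (x - baseline) * (total / steps)

# Toy "model": f(x) = sum(w * x^2), whose input gradient is 2 * w * x.
w = np.array([1.0, -2.0, 0.5])

def f(x):
    return float(np.sum(w * x ** 2))

def grad_f(x):
    return 2.0 * w * x

x = np.array([1.0, 2.0, 3.0])
baseline = np.zeros_like(x)
attr = integrated_gradients(grad_f, x, baseline)

# Completeness check: attributions should sum to f(x) - f(baseline).
delta = attr.sum() - (f(x) - f(baseline))
print(attr, delta)
```

The `delta` here plays the same role as the convergence delta requested with `return_convergence_delta=True`: a value near zero suggests the path integral was approximated well.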

Next, let's try a different tool that is based on hiding part of the input and seeing how much the prediction differs. Similar to a convolutional kernel, this `Feature Occlusion`, or _feature hiding_, approach slides a small window over the input. Areas that yield a much different prediction when hidden than when visible are considered more influential.

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(109)

random_idx = torch.randint(0, len(test_dataset), size=[1]).item()
img, label = test_dataset[random_idx]
img = img.to(device)

# Instantiate an Occlusion object for the model
oc = Occlusion(lenetmodel_loaded)

# Compute the attribution scores for the random image
attr = oc.attribute(img.unsqueeze(0), target=label, strides=(1, 3, 1), sliding_window_shapes=(1, 3, 3), baselines=0)

# Visualize the attribution scores
fig, ax = plt.subplots(1, 2)

ax[0].imshow(img.permute(1, 2, 0).to('cpu'), cmap='gray')
ax[1].imshow(attr[0][0].detach().to('cpu').numpy(), cmap='gray')
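
The sliding-window idea can be sketched by hand on a tiny example. This is an illustrative NumPy version, not Captum's implementation: `occlusion_map`, the 4x4 image, and the linear `score_fn` are all hypothetical, and overlapping windows are simply averaged.

```python
import numpy as np

def occlusion_map(score_fn, img, window=(2, 2), stride=(2, 2), baseline=0.0):
    """Score drop caused by hiding each window, spread over the pixels it covers."""
    h, w = img.shape
    attr = np.zeros((h, w), dtype=float)
    counts = np.zeros((h, w), dtype=float)
    base_score = score_fn(img)
    for top in range(0, h - window[0] + 1, stride[0]):
        for left in range(0, w - window[1] + 1, stride[1]):
            rows = slice(top, top + window[0])
            cols = slice(left, left + window[1])
            occluded = img.copy()
            occluded[rows, cols] = baseline              # hide this patch
            attr[rows, cols] += base_score - score_fn(occluded)
            counts[rows, cols] += 1
    return attr / np.maximum(counts, 1)                  # average overlapping windows

# Toy image with a bright 2x2 patch in the top-right corner.
img = np.array([[0, 0, 1, 1],
                [0, 0, 1, 1],
                [0, 0, 0, 0],
                [0, 0, 0, 0]], dtype=float)

def score_fn(a):
    """Stand-in for a model's class score: the sum of all pixel values."""
    return float(a.sum())

heat = occlusion_map(score_fn, img)
print(heat)
```

Only the window covering the bright patch changes the score, so only those pixels receive a non-zero attribution. Changing `stride` changes how many windows overlap each pixel, which is one reason the heat map looks different for different stride settings.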

---

Q1: The `strides` parameter is a 3-tuple with values for channel, horizontal, and vertical shifts. We want to leave the channel value at 1 since the inputs are only 1-channel, but...

* Q1a: Try modifying the horizontal stride parameter (e.g. `strides=(1,3,1)`). How does the output image differ?

The image is slightly sharper with the stride being (1,3,1), but overall they are still very similar.

* Q1b: Now change the horizontal parameter back to 1 and try changing the vertical stride parameter so that `strides=(1,1,3)`. Now how does the output image differ?

It is slightly clearer compared to (1,1,1) and (1,3,1). The shape of the four is definitely more identifiable.

---

### Visualize model metrics

Similar to Weights and Biases, the module TensorBoard is a machine learning experimentation tool for tracking and visualizing metrics from training runs.

In [None]:
%load_ext tensorboard

In [None]:
%tensorboard --logdir "runs/"

---

#### Q2: Change the sigmoid activation functions to use ReLU instead. How does this change the results? Specifically, what was the validation accuracy when using Sigmoid, and what was it when using ReLU instead?

The validation accuracy was 0.98172 with sigmoid and 0.98321 with ReLU, which is slightly higher.

---

#### Q3: Add Batch normalization to the first linear layer in the model.

* Q3a: Does using Batch Normalization improve the performance (accuracy) of the model? By how much? (_recall that batch normalization was not discovered/used until around 2015, so that is why it was not used in LeNet_)

Yes, it increased the accuracy by about 0.004, to 0.98753.

* Q3b: Does using Batch Normalization increase the number of parameters in the model? If so, how many more parameters are there? Can you explain why?

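With batch normalization the model has 61,946 parameters; without BN there are 61,706, a difference of 240. BN adds two learnable parameters per feature, a stretch (scale) and a shift, and the layer normalizes 120 features, so it adds 2 × 120 = 240 parameters. This gives the model more flexibility to learn, at the cost of a slight increase in the number of parameters.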
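
The two totals can be reproduced with a short calculation. This is a sketch using the layer sizes from the `LeNet` class above; the three helper functions are hypothetical names, and each counts weights plus biases (or, for batch norm, one scale and one shift per feature).

```python
def conv2d_params(in_ch, out_ch, kernel):
    """Weights (in_ch * k * k per filter) plus one bias per output channel."""
    return (in_ch * kernel * kernel + 1) * out_ch

def linear_params(in_features, out_features):
    """Weight matrix plus one bias per output feature."""
    return (in_features + 1) * out_features

def batchnorm_params(num_features):
    """One learnable scale and one learnable shift per feature."""
    return 2 * num_features

base = (conv2d_params(1, 6, 5) + conv2d_params(6, 16, 5)
        + linear_params(16 * 5 * 5, 120) + linear_params(120, 84)
        + linear_params(84, 10))
with_bn = base + batchnorm_params(120)
print(base, with_bn, with_bn - base)
```

The running mean and variance that batch norm tracks are stored as buffers rather than learnable parameters, so they do not appear in this count.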

---