
# Training LeNet-5 on MNIST with PyTorch Lightning

* **Author:** Rasmus Johansson
Based on the hello world example from https://pytorch-lightning.readthedocs.io/en/stable/notebooks/lightning_examples/mnist-hello-world.html


Train a PyTorch implementation of the [LeNet-5 model](https://en.wikipedia.org/wiki/LeNet) on the [MNIST Handwritten Digits dataset](https://en.wikipedia.org/wiki/MNIST_database).
The notebook also shows how to log and plot loss and accuracy values more frequently than once per epoch.

## Setup
This notebook requires some packages besides pytorch-lightning.

In [None]:
! pip install "pandas" "ipython[notebook]" "torchvision" "setuptools==59.5.0" "torch>=1.8" "torchmetrics>=0.7" "seaborn" "pytorch-lightning>=1.4"

In [None]:
import os

import pandas as pd
import seaborn as sn
import torch
from IPython.core.display import display
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import MNIST

# Directory where the datasets are stored; taken from the PATH_DATASETS
# environment variable, falling back to the current directory.
PATH_DATASETS = os.getenv("PATH_DATASETS", ".")

# Use a larger batch when a CUDA device is available.
if torch.cuda.is_available():
    BATCH_SIZE = 256
else:
    BATCH_SIZE = 64

## Lenet5 model

A basic LeNet-5 implementation,
extended so that the activation function can be changed.



In [None]:
class Lenet5MNIST(LightningModule):
    """LeNet-5 style classifier for MNIST with a configurable activation.

    The module bundles the network, the train/validation/test steps, the
    optimizer and the MNIST data hooks. Besides loss and accuracy it logs
    ``epochs_as_float`` (e.g. 1.25), so metrics can be plotted against
    fractions of an epoch rather than whole epochs only.

    Args:
        data_dir: Directory the MNIST files are downloaded to and read from.
        learning_rate: Learning rate passed to the Adam optimizer.
        activation_function: Callable returning a new activation module
            (e.g. ``nn.Tanh`` or ``nn.ReLU``); it is called once for every
            activation layer in the network.
    """

    def __init__(self, data_dir=PATH_DATASETS, learning_rate=2e-4, activation_function=nn.Tanh):
        super().__init__()

        # Keep the init arguments as attributes.
        self.data_dir = data_dir
        self.learning_rate = learning_rate

        # Number of epochs completed so far as a float (e.g. 0.7 or 3.4).
        # Updated in training_step and re-logged by the validation/test steps
        # so every logged row can be placed on a fractional-epoch x axis.
        self.epochs_done_as_float = 0.0

        # Dataset specific attributes (hard-coded for MNIST).
        self.num_classes = 10
        self.dims = (1, 28, 28)
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                # Normalize with mean 0.1307 and std 0.3081.
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        # LeNet-5 inspired network; layer sizes taken from
        # https://en.wikipedia.org/wiki/File:Comparison_image_neural_networks.svg
        self.model = nn.Sequential(
            # 1x28x28 -> 6x28x28 (padding=2 keeps the spatial size)
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2),
            activation_function(),
            # 6x28x28 -> 6x14x14
            nn.AvgPool2d(kernel_size=2),
            # 6x14x14 -> 16x10x10
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),
            activation_function(),
            # 16x10x10 -> 16x5x5
            nn.AvgPool2d(kernel_size=2),
            # 16 * 5 * 5 = 400 values feed the 120 hidden units.
            nn.Flatten(),
            nn.Linear(in_features=400, out_features=120),
            activation_function(),
            nn.Linear(in_features=120, out_features=84),
            activation_function(),
            nn.Linear(in_features=84, out_features=self.num_classes),
        )

        # One torchmetrics Accuracy object per dataset split
        # (https://torchmetrics.readthedocs.io/en/stable/classification/accuracy.html).
        self.val_accuracy = Accuracy()
        self.test_accuracy = Accuracy()
        self.train_accuracy = Accuracy()

    def forward(self, x):
        """Return the per-class log-probabilities for a batch of images."""
        return F.log_softmax(self.model(x), dim=1)

    def _loss_and_preds(self, batch):
        """Run the network on ``batch`` and return ``(loss, preds, targets)``."""
        x, y = batch
        log_probs = self(x)
        # forward() already applies log_softmax, so NLL loss is used here.
        loss = F.nll_loss(log_probs, y)
        preds = torch.argmax(log_probs, dim=1)
        return loss, preds, y

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss/accuracy for one batch; return the loss."""
        loss, preds, y = self._loss_and_preds(batch)

        # Calling the metric (i.e. its forward()) rather than update() both
        # accumulates state and yields the accuracy of the current batch, so
        # training accuracy can be followed within an epoch. See
        # https://torchmetrics.readthedocs.io/en/stable/pages/overview.html
        self.train_accuracy(preds, y)

        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", self.train_accuracy, prog_bar=False)

        # Log how many epochs are finished (e.g. 1.25) to use as x axis when plotting.
        self.epochs_done_as_float = self.current_epoch + (batch_idx / self.trainer.num_training_batches)
        self.log("epochs_as_float", self.epochs_done_as_float, prog_bar=False)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss/accuracy for one batch."""
        loss, preds, y = self._loss_and_preds(batch)
        self.val_accuracy.update(preds, y)

        # self.log hands the values to the logger attached to the Trainer.
        self.log("val_loss", loss, prog_bar=False)
        self.log("val_acc", self.val_accuracy, prog_bar=False)
        # Re-log the latest fractional epoch so validation rows share the x axis.
        self.log("epochs_as_float", self.epochs_done_as_float, prog_bar=False)

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss/accuracy for one batch."""
        loss, preds, y = self._loss_and_preds(batch)
        self.test_accuracy.update(preds, y)

        # self.log hands the values to the logger attached to the Trainer.
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", self.test_accuracy, prog_bar=True)
        self.log("epochs_as_float", self.epochs_done_as_float, prog_bar=False)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        """Download the MNIST train and test files into ``data_dir``."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` ("fit", "test" or None for both)."""
        # Train/val datasets: split the 60000 training images 55000/5000.
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Test dataset.
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return the DataLoader over the training split."""
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        """Return the DataLoader over the validation split."""
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        """Return the DataLoader over the test split."""
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)

In [None]:
# MNIST is small, so validating once per completed epoch would normally be
# enough. With a larger dataset you may want to validate more often, which is
# done by setting val_check_interval to the fraction of a training epoch that
# should pass between two validation runs.
# Here we validate 4 times per epoch.
val_check_interval = 0.25


In [None]:
# Build the network with ReLU activations instead of the default Tanh.
model = Lenet5MNIST(activation_function=nn.ReLU)

trainer = Trainer(
    max_epochs=4,
    accelerator="auto",
    # A single device when CUDA is available, otherwise None.
    # NOTE(review): some Lightning releases may not accept devices=None -- confirm
    # against the installed version.
    devices=1 if torch.cuda.is_available() else None,
    # Refresh the progress bar every 20 batches.
    callbacks=[TQDMProgressBar(refresh_rate=20)],
    # Metrics are written as CSV below "logs/".
    logger=CSVLogger(save_dir="logs/"),
    # Fraction of an epoch between two validation runs.
    val_check_interval=val_check_interval,
    # Limits how often values passed to self.log() are saved to disk.
    log_every_n_steps=50,
)

trainer.fit(model)

# Testing

In [None]:
# Evaluate the trained model on the test partition of the MNIST dataset.
trainer.test(model)

# Plot values per epoch

In [None]:


# Read the CSV file the logger wrote the losses and accuracies to.
metrics = pd.read_csv(f"{trainer.logger.log_dir}/metrics.csv")

# Plot against whole epochs: drop the step counter and the fractional-epoch
# column, then use the epoch number as index (x axis).
metrics = metrics.drop(columns=["step", "epochs_as_float"]).set_index("epoch")

# Show the first rows, leaving out columns that contain no values at all.
display(metrics.dropna(axis=1, how="all").head())

sn.relplot(data=metrics, kind="line")

In [None]:
# Plot values per fraction of an epoch

In [None]:
# Read the CSV file the logger wrote the losses and accuracies to.
metrics = pd.read_csv(f"{trainer.logger.log_dir}/metrics.csv")

# Plot against fractions of an epoch: drop the step counter and the whole-epoch
# column, then use the fractional epoch as index (x axis).
metrics = metrics.drop(columns=["step", "epoch"]).set_index("epochs_as_float")

# Show the first rows, leaving out columns that contain no values at all.
display(metrics.dropna(axis=1, how="all").head())

sn.relplot(data=metrics, kind="line")