
# [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/skojaku/applied-soft-comp/blob/master/notebooks/lenet.ipynb)


# LeNet-1 for MNIST

In this notebook, we will implement LeNet-1 for MNIST dataset. We will first train the model on the MNIST dataset and then create an interactive digit recognizer using the trained model.


![](https://production-media.paperswithcode.com/datasets/MNIST-0000000001-2e09631a_09liOmx.jpg)


## Install libraries


In [1]:
# Uncomment the following line to install the libraries
# !pip install ipywidgets pillow ipycanvas pytorch_lightning


In [2]:
# If you are using Google Colab, uncomment the following line to enable the custom widget manager
#from google.colab import output
#output.enable_custom_widget_manager()

## LeNet-1
Let us first define the LeNet-1 model using PyTorch Lightning. We note that this is not a faithful implementation of LeNet-1. We will use some modern techniques such as the Adam optimizer to speed up the training process.

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from torchmetrics import Accuracy


class LeNet1(pl.LightningModule):
    """
    PyTorch Lightning implementation of LeNet-1
    Includes training, validation, and test functionality
    """

    def __init__(self, learning_rate=1e-3):
        super(LeNet1, self).__init__()
        self.save_hyperparameters()

        # Metrics
        self.train_accuracy = Accuracy(task="multiclass", num_classes=10)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=10)
        self.test_accuracy = Accuracy(task="multiclass", num_classes=10)

        # First convolutional layer (1x28x28 -> 4x24x24)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=5, stride=1)

        # Average pooling layer (4x24x24 -> 4x12x12)
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)

        # Second convolutional layer (4x12x12 -> 12x8x8)
        self.conv2 = nn.Conv2d(in_channels=4, out_channels=12, kernel_size=5, stride=1)

        # Fully connected layer (12*4*4=192 -> 10)
        self.fc = nn.Linear(12 * 4 * 4, 10)

        # Initialize weights
        self._init_weights()

        # Initialize validation losses
        self.val_losses = []
        self.train_losses = []

    def _init_weights(self):
        """Initialize weights using Xavier initialization"""
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        # First conv block
        x = self.conv1(x)
        x = torch.tanh(
            x
        )  # while the original paper does not mention the activation function, we use tanh here
        x = self.pool(x)

        # Second conv block
        x = self.conv2(x)
        x = torch.tanh(
            x
        )  # while the original paper does not mention the activation function, we use tanh here
        x = self.pool(x)

        # Flatten and fully connected
        x = x.view(-1, 12 * 4 * 4)
        x = self.fc(x)
        return x

    def configure_optimizers(self):
        """Define optimizers and LR schedulers"""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.1, patience=10, verbose=True
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"},
        }

    def training_step(self, batch, batch_idx):
        """Training step"""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        # Log metrics
        acc = self.train_accuracy(logits, y)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)

        self.train_losses.append({"loss": loss.item(), "acc": acc.item()})

        return loss

    def validation_step(self, batch, batch_idx):
        """Validation step"""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        # Log metrics
        acc = self.val_accuracy(logits, y)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        self.val_losses.append({"loss": loss.item(), "acc": acc.item()})

    def test_step(self, batch, batch_idx):
        """Test step"""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        # Log metrics
        acc = self.test_accuracy(logits, y)
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", acc, prog_bar=True)

In [4]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import Button, HBox, VBox, HTML
from IPython.display import display
import io
import PIL.Image
import PIL.ImageDraw
from torchvision import transforms
import ipywidgets as widgets
from ipycanvas import Canvas

class DigitRecognizer:
    def __init__(self, model_path):
        # Load the trained model
        self.model = LeNet1.load_from_checkpoint(model_path)
        self.model.eval()

        # Initialize the drawing canvas
        self.canvas_size = 280  # 28x28 pixels * 10 for better drawing

        # Create canvas with sync_image_data enabled
        self.canvas = Canvas(width=self.canvas_size, height=self.canvas_size, sync_image_data=True)

        # Set initial canvas properties
        self.canvas.fill_style = "white"
        self.canvas.fill_rect(0, 0, self.canvas_size, self.canvas_size)
        self.canvas.line_cap = "round"
        self.canvas.line_join = "round"

        # Create buttons
        self.clear_button = Button(description='Clear')
        self.predict_button = Button(description='Predict')
        self.result_label = HTML(value='<h3>Draw a digit and click Predict</h3>')

        # Setup button callbacks
        self.clear_button.on_click(self.clear_canvas)
        self.predict_button.on_click(self.make_prediction)

        # Setup drawing state
        self.drawing = False
        self.last_x = None
        self.last_y = None

        # Setup mouse event handlers
        self.canvas.on_mouse_down(self.start_drawing)
        self.canvas.on_mouse_move(self.draw)
        self.canvas.on_mouse_up(self.stop_drawing)
        self.canvas.on_mouse_out(self.stop_drawing)

        # Display the UI
        display(VBox([
            self.canvas,
            HBox([self.clear_button, self.predict_button]),
            self.result_label
        ]))

    def clear_canvas(self, b=None):
        """Clear the canvas"""
        self.canvas.fill_style = "white"
        self.canvas.fill_rect(0, 0, self.canvas_size, self.canvas_size)
        self.result_label.value = '<h3>Draw a digit and click Predict</h3>'

    def start_drawing(self, x, y):
        self.drawing = True
        self.last_x = x
        self.last_y = y

        # Option 1: Remove initial point drawing entirely
        # Just initialize the position without drawing anything

        # OR Option 2: Draw a much smaller initial point
        self.canvas.begin_path()
        self.canvas.arc(x, y, 2, 0, 2 * np.pi)  # Reduced radius from 10 to 2
        self.canvas.fill_style = 'black'
        self.canvas.fill()
        self.canvas.close_path()

    def draw(self, x, y):
        if self.drawing and self.last_x is not None and self.last_y is not None:
            self.canvas.begin_path()
            self.canvas.move_to(self.last_x, self.last_y)
            self.canvas.line_to(x, y)
            self.canvas.line_width = 20
            self.canvas.stroke_style = 'black'
            self.canvas.stroke()
            self.canvas.close_path()

            self.last_x = x
            self.last_y = y

    def stop_drawing(self, x, y):
        self.drawing = False
        self.last_x = None
        self.last_y = None

    def preprocess_image(self):
        # Get image data and convert to PIL Image
        image_data = self.canvas.get_image_data()
        img = PIL.Image.frombytes('RGBA', (self.canvas_size, self.canvas_size), image_data)

        # Convert to grayscale
        img = img.convert('L')

        # Add thresholding to make strokes more distinct
        img = img.point(lambda x: 0 if x > 128 else 255)

        # Center the digit in the image
        bbox = img.getbbox()
        if bbox:
            img = img.crop(bbox)
            # Add padding to maintain aspect ratio
            padded = PIL.Image.new('L', (max(img.size), max(img.size)), 255)
            padded.paste(img, ((max(img.size)-img.size[0])//2, (max(img.size)-img.size[1])//2))
            img = padded

        # Resize to MNIST dimensions
        img = img.resize((28, 28), PIL.Image.Resampling.LANCZOS)

        # Convert to tensor and normalize
        transform = transforms.Compose(
            [
                transforms.ToTensor(),  # convert to tensor
                transforms.Normalize(
                    (0,), (1,)
                ),  # normalize the data such that the mean is 0 and the standard deviation is 1
            ]
        )
        # Add batch dimension
        tensor = transform(img).unsqueeze(0)
        return tensor

    def make_prediction(self, b=None):
        """Make a prediction on the drawn digit"""
        # Preprocess the image
        tensor = self.preprocess_image()

        # Make prediction
        with torch.no_grad():
            output = self.model(tensor)
            probabilities = torch.nn.functional.softmax(output, dim=1)
            prediction = output.argmax(dim=1).item()
            confidence = probabilities[0][prediction].item() * 100

        # Update result display
        self.result_label.value = f'<h3>Prediction: {prediction} (Confidence: {confidence:.2f}%)</h3>'



We then define the MNISTDataModule class to load the MNIST dataset.

In [5]:
class MNISTDataModule(pl.LightningDataModule):
    """
    PyTorch Lightning data module for MNIST dataset
    """

    def __init__(self, data_dir: str = "./data", batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

        # Define transforms
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),  # convert to tensor
                transforms.Normalize(
                    (0,), (1,)
                ),  # normalize the data such that the mean is 0 and the standard deviation is 1
            ]
        )

    def prepare_data(self):
        """Download data if needed"""
        datasets.MNIST(self.data_dir, train=True, download=True)
        datasets.MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Setup train, val, and test datasets"""
        if stage == "fit" or stage is None:
            mnist_full = datasets.MNIST(
                self.data_dir, train=True, transform=self.transform
            )
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
            )

        if stage == "test" or stage is None:
            self.mnist_test = datasets.MNIST(
                self.data_dir, train=False, transform=self.transform
            )

    def train_dataloader(self):
        return DataLoader(
            self.mnist_train, batch_size=self.batch_size, shuffle=True, num_workers=1
        )

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=1)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=1)

See the lecture note for the break down of the code above. 

## Train the model

In [None]:
# First, train the model or load a pre-trained model
model = LeNet1(learning_rate=1e-3)
data_module = MNISTDataModule(batch_size=512)

# Initialize trainer
trainer = pl.Trainer(
    max_epochs=20,
    accelerator="auto",  # Uses GPU if available
    devices=1,
)

# Train and test
trainer.fit(model, data_module)
trainer.save_checkpoint('lenet1.ckpt')

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name           | Type               | Params | Mode 
--------------------------------------------------------------
0 | train_accuracy | MulticlassAccuracy | 0      | train
1 | val_accuracy   | MulticlassAccuracy | 0      | train
2 | test_accuracy  | MulticlassAccuracy | 0      | train
3 | conv1          | Conv2d             | 104    | train
4 | pool           | AvgPool2d          | 0      | train
5 | conv2          | Conv2d             | 1.2 K  | train
6 | fc             | Linear             | 1.9 K  | train
--------------------------------------------------------------
3.2 K     Trainable params
0         Non-trainable params
3.2 K     Total params
0.013     Total estimated model params size (MB)
7         Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/skojaku-admin/miniforge3/envs/advnetsci/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:420: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.
/Users/skojaku-admin/miniforge3/envs/advnetsci/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:420: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

# Interactive session

In [7]:
# Then create the interactive recognizer
recognizer = DigitRecognizer('lenet1.ckpt')

VBox(children=(Canvas(height=280, sync_image_data=True, width=280), HBox(children=(Button(description='Clear',…