# 4. Supervised Classification: LeNet-5 with PyTorch

While zero-shot classification with CLIP is powerful, building and training a model from scratch is a fundamental skill. In this notebook, we'll implement **LeNet-5**, one of the earliest and most influential Convolutional Neural Networks (CNNs). We will train it on the MNIST training dataset using PyTorch.

**Key concepts covered:**
*   LeNet-5 architecture
*   Defining a custom `nn.Module` in PyTorch
*   Splitting FiftyOne data for training and validation
*   Creating custom PyTorch Datasets from FiftyOne views
*   Normalizing data with the training-set mean and standard deviation
*   Using PyTorch DataLoaders
*   Implementing a training loop and model checkpointing

## Setup
As always, we start with imports and helper functions.

In [None]:
import os
import random
import cv2
from PIL import Image
import numpy as np
from tqdm import tqdm
from pathlib import Path
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as Fun
import torchvision.transforms.v2 as transforms
from torch.utils.data import Dataset
from torch.optim import Adam

import fiftyone as fo
import fiftyone.zoo as foz
import fiftyone.utils.random as four
import albumentations as A

### Reproducibility
To ensure our training experiments are reproducible, we'll define a function to set random seeds for all relevant libraries.

In [None]:
def set_seeds(seed=51):
    """Sets seeds for reproducibility."""
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    cv2.setRNGSeed(seed)
    try:
        A.seed_everything(seed)
    except AttributeError:
        pass
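    try:
        # This only sets a flag; an op without a deterministic implementation
        # typically raises RuntimeError later, when it is actually executed.
        torch.use_deterministic_algorithms(True)
    except RuntimeError:
        print("Warning: Some operations may not be deterministic")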

def create_deterministic_training_dataloader(dataset, batch_size, shuffle=True, **kwargs):
    """Creates a DataLoader with deterministic behavior."""
    generator = torch.Generator()
    generator.manual_seed(51)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        generator=generator if shuffle else None,
        **kwargs
    )

set_seeds(51)
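
To see what seeding buys us, here is a minimal sketch that uses only Python's standard `random` module. `draw_shuffled_indices` is a hypothetical helper for illustration, not part of the notebook; the dedicated `random.Random` instance plays the same role as the `torch.Generator` passed to the `DataLoader` above.

```python
import random

def draw_shuffled_indices(seed, n=10):
    """Return a shuffled list of indices using a dedicated, seeded RNG."""
    rng = random.Random(seed)  # independent generator, unaffected by global state
    indices = list(range(n))
    rng.shuffle(indices)
    return indices

first_run = draw_shuffled_indices(seed=51)
second_run = draw_shuffled_indices(seed=51)
other_seed = draw_shuffled_indices(seed=52)

print(first_run == second_run)               # same seed, same order
print(sorted(first_run) == list(range(10)))  # still a permutation of 0..9
print(other_seed)                            # a different seed generally gives a different order
```

Using a dedicated generator rather than the global one means that other code drawing random numbers in between cannot change the shuffle order.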

## Defining the LeNet-5 Architecture

![LeNet-5 architecture diagram](https://raw.githubusercontent.com/andandandand/practical-computer-vision/refs/heads/main/images/lenet5-architecture.png)

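We'll define a modernized version of LeNet-5 that swaps the original `tanh` activations and average pooling for `ReLU` and max pooling, which tend to train faster and are the more common choice in current CNNs. It also adds dropout before the final layer for regularization.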
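
One way to see why these substitutions help is to compare gradients and pooling outputs by hand. The sketch below uses only the standard library; the function names and the sample values are illustrative assumptions, not part of the notebook.

```python
import math

def tanh_grad(x):
    """Derivative of tanh: 1 - tanh(x)^2, which shrinks toward 0 for large |x|."""
    return 1.0 - math.tanh(x) ** 2

def relu_grad(x):
    """Derivative of ReLU: 1 for positive inputs, 0 otherwise."""
    return 1.0 if x > 0 else 0.0

# As the pre-activation grows, tanh saturates while ReLU keeps a gradient of 1
for x in (0.5, 2.0, 5.0):
    print(f"x={x}: tanh'={tanh_grad(x):.5f}, relu'={relu_grad(x):.1f}")

# Pooling over one 2x2 window of activations
window = [0.1, 0.9, 0.2, 0.0]
max_pooled = max(window)                # keeps only the strongest response
avg_pooled = sum(window) / len(window)  # blends all four responses
print(max_pooled, avg_pooled)
```

Saturating gradients slow learning in the layers below them, which is one common explanation for why ReLU networks are easier to train.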

In [None]:
class ModernLeNet5(nn.Module):
    """
    Modernized version of LeNet-5 with ReLU activations and max pooling.
    """

    def __init__(self, num_classes=10):
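        super().__init__()

        # Feature extractor; shapes in comments assume a 1x28x28 MNIST input
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)     # 1x28x28 -> 6x24x24
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)    # 6x12x12 -> 16x8x8
        self.conv3 = nn.Conv2d(16, 120, kernel_size=4)  # 16x4x4 -> 120x1x1

        # Shared 2x2 max pooling halves height and width after conv1 and conv2
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Classifier head
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, num_classes)

        self.dropout = nn.Dropout(0.5)  # active only in model.train() mode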

    def forward(self, x):
        x = self.pool(Fun.relu(self.conv1(x)))
        x = self.pool(Fun.relu(self.conv2(x)))
        x = Fun.relu(self.conv3(x))

        x = x.view(x.size(0), -1)
        x = Fun.relu(self.fc1(x))
        x = self.dropout(x)  # Add dropout for regularization
        x = self.fc2(x)

        return x
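
The layer sizes above can be checked with a little arithmetic. The sketch below traces a 28x28 MNIST image through the network and counts trainable parameters; `conv_out`, `conv_params`, and `linear_params` are hypothetical helpers written for this illustration. It also shows why `conv3` uses a 4x4 kernel here, whereas the original LeNet-5 applied a 5x5 kernel to feature maps derived from 32x32 inputs.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a square convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28                      # MNIST images are 28x28
size = conv_out(size, 5)       # conv1 -> 24
size = conv_out(size, 2, 2)    # pool  -> 12
size = conv_out(size, 5)       # conv2 -> 8
size = conv_out(size, 2, 2)    # pool  -> 4
size = conv_out(size, 4)       # conv3 -> 1
flattened = 120 * size * size  # 120 channels of 1x1, matching fc1's input size

def conv_params(c_in, c_out, k):
    """Weights plus biases of a k x k convolution."""
    return c_in * c_out * k * k + c_out

def linear_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

total = (conv_params(1, 6, 5) + conv_params(6, 16, 5) + conv_params(16, 120, 4)
         + linear_params(120, 84) + linear_params(84, 10))
print(flattened, total)
```

If PyTorch is available, the same total can be cross-checked with `sum(p.numel() for p in ModernLeNet5().parameters())`.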

## Preparing the Data

We'll now load the `train` split of MNIST and divide it into a training set (85%) and a validation set (15%). The validation set is crucial for monitoring overfitting and for saving the best version of our model during training.

In [None]:
train_val_dataset = foz.load_zoo_dataset("mnist",
                                         split='train',
                                         dataset_name="mnist-train-val",
                                         persistent=True)

# Ensure tags from previous runs are cleared
train_val_dataset.untag_samples(["train", "validation"])

set_seeds(51)
four.random_split(train_val_dataset,
                  {"train": 0.85, "validation": 0.15},
                  seed=51)

train_dataset = train_val_dataset.match_tags("train").clone(name="mnist-training-set", persistent=True)
val_dataset = train_val_dataset.match_tags("validation").clone(name="mnist-validation-set", persistent=True)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
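
Conceptually, a seeded random split shuffles the sample indices once and cuts the shuffled list at the requested fractions. The following is a minimal standard-library sketch of that idea, not FiftyOne's implementation; `split_indices` is a hypothetical helper, and the 60,000 figure is the size of the MNIST training split.

```python
import random

def split_indices(n, fractions, seed):
    """Shuffle range(n) with a seeded RNG and cut it into named parts."""
    rng = random.Random(seed)
    indices = list(range(n))
    rng.shuffle(indices)

    splits, start = {}, 0
    names = list(fractions)
    for i, name in enumerate(names):
        # The last split takes the remainder so no sample is dropped
        end = n if i == len(names) - 1 else start + round(n * fractions[name])
        splits[name] = indices[start:end]
        start = end
    return splits

splits = split_indices(60_000, {"train": 0.85, "validation": 0.15}, seed=51)
print({name: len(idx) for name, idx in splits.items()})
```

Because every index lands in exactly one split, no image can appear in both the training and validation sets.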

### Creating a Custom PyTorch Dataset

To use our FiftyOne datasets with PyTorch, we create a custom `Dataset` class. This acts as a bridge, allowing PyTorch's `DataLoader` to efficiently load images and labels while we still benefit from FiftyOne's powerful data management features.

In [None]:
class CustomTorchImageDataset(torch.utils.data.Dataset):
    def __init__(self, fiftyone_dataset,
                 image_transforms=None,
                 label_map=None,
                 gt_field="ground_truth"):
        self.fiftyone_dataset = fiftyone_dataset
        self.image_paths = self.fiftyone_dataset.values("filepath")
        self.str_labels = self.fiftyone_dataset.values(f"{gt_field}.label")
        self.image_transforms = image_transforms

        if label_map is None:
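            # Fallback for labels stored as plain digit strings ("0" -> 0, "1" -> 1, ...).
            # Pass label_map explicitly if the dataset uses other label strings.
            self.label_map = {str(i): i for i in range(10)}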
        else:
            self.label_map = label_map

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert('L')

        if self.image_transforms:
            image = self.image_transforms(image)

        label_str = self.str_labels[idx]
        # Fail fast on unknown labels: a -1 target would make CrossEntropyLoss error out later
        label_idx = self.label_map[label_str]
        return image, torch.tensor(label_idx, dtype=torch.long)

### Data Normalization

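Standardizing inputs to zero mean and unit standard deviation is a common preprocessing step: each pixel value x becomes (x - mean) / std. This keeps input magnitudes in a range where gradient-based optimization is usually better behaved, which tends to speed up convergence. We compute the statistics on the training set only and apply the same values to every split, so that no information from the validation or test data leaks into preprocessing.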
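
As a small self-contained illustration, the sketch below computes the mean and population standard deviation from running sums, then checks that the normalized values end up with mean 0 and standard deviation 1. The toy pixel values and the `running_mean_std` helper are assumptions made for this example; accumulating sums is one way to avoid holding every pixel in memory at once.

```python
import math

def running_mean_std(batches):
    """Mean and population std from running sums, one batch at a time."""
    count, total, total_sq = 0, 0.0, 0.0
    for batch in batches:
        count += len(batch)
        total += sum(batch)
        total_sq += sum(v * v for v in batch)
    mean = total / count
    variance = total_sq / count - mean ** 2  # E[x^2] - (E[x])^2
    return mean, math.sqrt(max(variance, 0.0))

# Toy "images": pixel intensities already scaled to [0, 1]
images = [[0.0, 0.0, 0.5, 1.0], [0.0, 0.25, 0.75, 1.0]]
mean, std = running_mean_std(images)

# Apply (x - mean) / std to every pixel, then re-measure the statistics
normalized = [[(v - mean) / std for v in img] for img in images]
flat = [v for img in normalized for v in img]
new_mean = sum(flat) / len(flat)
new_std = math.sqrt(sum((v - new_mean) ** 2 for v in flat) / len(flat))

print(f"before: mean={mean:.4f}, std={std:.4f}")
print(f"after:  mean={new_mean:.4f}, std={new_std:.4f}")
```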

In [None]:
def compute_stats_fiftyone(fiftyone_view):
    filepaths = fiftyone_view.values("filepath")
    all_pixels = []
    for filepath in tqdm(filepaths, desc="Computing Stats"):
        image = Image.open(filepath).convert('L')
        pixels = np.array(image, dtype=np.float32) / 255.0
        all_pixels.append(pixels.flatten())
    all_pixels = np.concatenate(all_pixels)
    return np.mean(all_pixels), np.std(all_pixels)

mean_intensity, std_intensity = compute_stats_fiftyone(train_dataset)
print(f"Mean: {mean_intensity:.4f}, Std: {std_intensity:.4f}")

# Define transforms with normalization
image_transforms = transforms.Compose([
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize((mean_intensity,), (std_intensity,))
])

# Create label map and torch datasets
dataset_classes = sorted(train_val_dataset.distinct("ground_truth.label"))
label_map = {label: i for i, label in enumerate(dataset_classes)}

torch_train_set = CustomTorchImageDataset(train_dataset, label_map=label_map, image_transforms=image_transforms)
torch_val_set = CustomTorchImageDataset(val_dataset, label_map=label_map, image_transforms=image_transforms)
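
The label map simply assigns each distinct label string an integer index in sorted order. The sketch below shows the mapping and its inverse on a handful of made-up label strings; the `"3 - three"` format is an assumption for illustration, so check the output of `train_val_dataset.distinct("ground_truth.label")` for the real values.

```python
# Assumed label strings, deliberately incomplete and out of order
example_labels = ["3 - three", "0 - zero", "7 - seven", "1 - one"]

classes = sorted(set(example_labels))
label_map = {label: i for i, label in enumerate(classes)}

# Index -> label, useful for turning predicted class indices back into names
inverse_map = {i: label for label, i in label_map.items()}

print(label_map)
print(inverse_map[2])
```

Note that the index is a position in the sorted class list, not the digit itself: in this toy example index 2 maps to the label for digit 3. The two coincide only when all ten digit classes are present and sort in digit order.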

### PyTorch DataLoaders

A `DataLoader` wraps a dataset and handles batching, shuffling, and parallel data loading, all of which matter for efficient training.

In [None]:
batch_size = 64
num_workers = os.cpu_count() or 0  # cpu_count() can return None; 0 loads data in the main process

train_loader = create_deterministic_training_dataloader(
    torch_train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True
)

val_loader = torch.utils.data.DataLoader(
    torch_val_set,
    batch_size=batch_size,
    shuffle=False, # No need to shuffle validation data
    num_workers=num_workers,
    pin_memory=True
)
print("DataLoaders created.")
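
A quick calculation shows how many batches each loader yields per epoch. The sketch assumes the split produced exactly 51,000 training and 9,000 validation samples (85% and 15% of 60,000); `batch_summary` is a hypothetical helper that mirrors the behavior of the `DataLoader`'s `drop_last` option.

```python
import math

def batch_summary(num_samples, batch_size, drop_last=False):
    """Number of batches per epoch and the size of the final batch."""
    if drop_last:
        # Incomplete final batch is discarded, so every batch is full
        return num_samples // batch_size, batch_size
    num_batches = math.ceil(num_samples / batch_size)
    last = num_samples - (num_batches - 1) * batch_size
    return num_batches, last

# Assumed sizes: 85% / 15% of the 60,000 MNIST training images
print(batch_summary(51_000, 64))
print(batch_summary(9_000, 64))
print(batch_summary(51_000, 64, drop_last=True))
```

With the default `drop_last=False`, the final batch of each epoch is smaller than the rest, which matters when averaging per-batch losses.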

## Training the Model

We'll now set up the training loop. This involves defining a loss function, an optimizer, and functions to handle one epoch of training and validation.
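
For intuition about the loss function, the sketch below computes the cross-entropy of a single sample from raw logits in plain Python. It illustrates the quantity that `nn.CrossEntropyLoss` averages over a batch, which is also why the model's `forward` returns logits without applying a softmax. The `cross_entropy` helper and the example logits are assumptions for illustration.

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy of one sample: -log(softmax(logits)[target])."""
    shift = max(logits)  # subtract the max before exponentiating for numerical stability
    log_sum_exp = shift + math.log(sum(math.exp(z - shift) for z in logits))
    return log_sum_exp - logits[target]

confident_right = cross_entropy([5.0, 0.0, 0.0], target=0)  # small loss
confident_wrong = cross_entropy([5.0, 0.0, 0.0], target=1)  # large loss
uniform = cross_entropy([0.0] * 10, target=3)               # equals log(10) for 10 classes

print(f"{confident_right:.4f} {confident_wrong:.4f} {uniform:.4f}")
```

The uniform case is a handy sanity check: an untrained 10-class model that guesses uniformly should start near a loss of log(10), roughly 2.30.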

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = ModernLeNet5().to(device)
optimizer = Adam(model.parameters(), lr=0.003)
ce_loss = nn.CrossEntropyLoss()

def train_epoch(model, train_loader, optimizer, ce_loss):
    model.train()
    batch_losses = []
    for images, labels in tqdm(train_loader, desc="Training"): 
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(images)
        loss = ce_loss(logits, labels)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
    return np.mean(batch_losses)

def val_epoch(model, val_loader, ce_loss):
    model.eval()
    batch_losses = []
    with torch.inference_mode():
        for images, labels in tqdm(val_loader, desc="Validation"): 
            images, labels = images.to(device), labels.to(device)
            logits = model(images)
            loss = ce_loss(logits, labels)
            batch_losses.append(loss.item())
    return np.mean(batch_losses)
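
One subtlety: both functions above average the per-batch losses, so a smaller final batch counts as much as a full one. The sketch below contrasts that with a per-sample average, using hypothetical loss values and batch sizes chosen to exaggerate the difference.

```python
def mean_of_batch_means(batch_losses):
    """Unweighted average, as computed by np.mean(batch_losses)."""
    return sum(batch_losses) / len(batch_losses)

def per_sample_mean(batch_losses, batch_sizes):
    """Average weighted by the number of samples in each batch."""
    total = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)

# Hypothetical losses: two full batches of 64 and a final batch of 8
losses = [0.50, 0.50, 2.00]
sizes = [64, 64, 8]

print(mean_of_batch_means(losses))
print(per_sample_mean(losses, sizes))
```

With hundreds of batches per epoch, as in this notebook, the two averages are nearly identical, so the simpler unweighted mean is a reasonable choice.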

### The Training Loop with Checkpointing

We'll train for several epochs. After each epoch we check the validation loss and save the model's weights only when it improves. This practice, a form of **checkpointing**, keeps the weights with the lowest validation loss seen so far. It does not prevent overfitting, but it means that later epochs that overfit cannot overwrite the best model.

In [None]:
set_seeds(51)
num_epochs = 10
train_losses, val_losses = [], []
best_val_loss = float('inf')
model_save_path = Path(os.getcwd()) / 'best_lenet.pth'

for epoch in range(num_epochs):
    train_loss = train_epoch(model, train_loader, optimizer, ce_loss)
    val_loss = val_epoch(model, val_loader, ce_loss)
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    
    print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {train_loss:.4f} - Val Loss: {val_loss:.4f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), model_save_path)
        print('✓ Found and saved better model weights.')
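
The save-on-improvement rule can be traced on a list of numbers without training anything. The sketch below uses hypothetical validation losses, not results from this notebook, and a hypothetical `checkpoint_epochs` helper that reports when the rule would fire.

```python
def checkpoint_epochs(val_losses):
    """Return the 1-based epochs at which a 'save on improvement' rule fires."""
    best = float("inf")
    saved_at = []
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:  # strict improvement, matching the loop above
            best = loss
            saved_at.append(epoch)
    return saved_at

# Hypothetical validation losses, not results from this notebook
val_history = [0.120, 0.080, 0.085, 0.070, 0.074, 0.079]
saved = checkpoint_epochs(val_history)

print(saved)      # epochs that overwrote the checkpoint
print(saved[-1])  # epoch whose weights remain on disk
```

To restore the saved weights later, one would typically call `model.load_state_dict(torch.load(model_save_path, map_location=device))` on a freshly constructed `ModernLeNet5`.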

### Visualizing Training Progress

Plotting the training and validation losses helps us diagnose the training process. In an ideal scenario, both losses decrease, and the validation loss remains close to the training loss. If the validation loss starts to increase while the training loss continues to decrease, the model is overfitting.

In [None]:
epochs = range(1, len(train_losses) + 1)  # 1-based, matching the printed epoch numbers
plt.figure(figsize=(10, 5))
plt.plot(epochs, train_losses, label='Training Loss')
plt.plot(epochs, val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
best_epoch = int(np.argmin(val_losses)) + 1
plt.axvline(x=best_epoch, color='r', linestyle='--', label=f'Best Model @ Epoch {best_epoch}')
plt.legend()
plt.show()
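
The same diagnosis can be made numerically. The sketch below computes the per-epoch gap between validation and training loss and flags the first epoch where validation loss rises while training loss keeps falling. The curves are hypothetical and both helpers are illustrative assumptions; on real runs a single-epoch rise may just be noise.

```python
def generalization_gap(train_losses, val_losses):
    """Per-epoch difference between validation and training loss."""
    return [v - t for t, v in zip(train_losses, val_losses)]

def first_overfit_epoch(train_losses, val_losses):
    """First 1-based epoch where val loss rises while train loss still falls, or None."""
    for i in range(1, len(val_losses)):
        if val_losses[i] > val_losses[i - 1] and train_losses[i] < train_losses[i - 1]:
            return i + 1
    return None

# Hypothetical curves, not results from this notebook
train_curve = [0.30, 0.12, 0.08, 0.05, 0.03]
val_curve = [0.15, 0.09, 0.07, 0.08, 0.10]

print(generalization_gap(train_curve, val_curve))
print(first_overfit_epoch(train_curve, val_curve))
```

A negative gap in early epochs is not unusual for this model: dropout is active only during training, and the training loss is averaged over an epoch in which the weights are still improving.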

## Next Steps

We have trained a modernized LeNet-5 model and saved the weights with the lowest validation loss to `best_lenet.pth`.

Now, let's evaluate this model on the unseen test set to see how well it performs.

Proceed to `5_lenet_evaluation.ipynb`.