# 7. Data Augmentation and Retraining

In this final notebook, we'll use the insights gained from our analysis to improve our model. Challenging training samples can be identified in several ways—for example, samples that were misclassified, are highly unique, or had low prediction confidence. Here we focus on the training samples the model **misclassified**. We apply targeted **data augmentation** to them and fine-tune our LeNet model on the original training set enriched with these augmented copies.

**Key concepts covered:**
*   Identifying problematic samples for augmentation
*   Defining effective data augmentation strategies for MNIST
*   Creating a combined dataset of original and augmented data
*   Fine-tuning a pre-trained model
*   Comparing performance before and after fine-tuning

## Setup
Let's begin by setting up our environment, including all necessary imports and helper functions.

In [None]:
import os
import random
import cv2
from PIL import Image
import numpy as np
from tqdm import tqdm
from pathlib import Path
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as Fun
import torchvision.transforms.v2 as transforms
from torch.utils.data import Dataset, ConcatDataset
from torch.optim import Adam

import fiftyone as fo
from fiftyone import ViewField as F
import albumentations as A

# Redefine model and dataset classes
class ModernLeNet5(nn.Module):
    """LeNet-5 style CNN for single-channel images.

    The spatial arithmetic (5x5 conv, 2x2 pool, 5x5 conv, 2x2 pool, 4x4 conv)
    reduces a 28x28 input to a 120x1x1 feature map, which is what ``fc1``
    expects. Other input sizes will generally not fit the 120-wide ``fc1``.

    Args:
        num_classes: number of output logits produced per image.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Feature extractor. Attribute names define state_dict keys, so they
        # must stay as-is for saved checkpoints to load.
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=4)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Classifier head.
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Map a batch of images to unnormalized class scores of shape (batch, num_classes)."""
        feat = self.pool(Fun.relu(self.conv1(x)))
        feat = self.pool(Fun.relu(self.conv2(feat)))
        feat = Fun.relu(self.conv3(feat))
        flat = feat.view(feat.size(0), -1)
        # Dropout is only active in train() mode.
        hidden = self.dropout(Fun.relu(self.fc1(flat)))
        return self.fc2(hidden)

def create_deterministic_training_dataloader(dataset, batch_size, shuffle=True, **kwargs):
    """Build a DataLoader whose shuffling order is reproducible.

    When ``shuffle`` is truthy, a dedicated ``torch.Generator`` seeded with 51
    drives the sampler, so every loader built by this function produces the
    same sequence of epoch orders. When ``shuffle`` is falsy, no generator is
    passed. Any extra keyword arguments are forwarded to ``DataLoader``.
    """
    if shuffle:
        rng = torch.Generator().manual_seed(51)
    else:
        rng = None
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        generator=rng,
        **kwargs,
    )

def set_seeds(seed=51):
    """Seed Python's, NumPy's and PyTorch's global RNGs.

    On CUDA machines it also seeds every GPU and switches cuDNN to
    deterministic kernels (with autotuning disabled) for reproducible runs.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for reproducible kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

## Identifying Samples for Augmentation

First, we need the model's predictions on the training set so we can identify the samples it misclassified. We run inference on the training set just as we did for the test set. The results are stored in a FiftyOne field, so this step is skipped if that field already exists.

In [None]:
# Load datasets and model
device = "cuda" if torch.cuda.is_available() else "cpu"
train_dataset = fo.load_dataset("mnist-training-set")
test_dataset = fo.load_dataset("mnist-test-set")
model_save_path = Path.cwd() / 'best_lenet.pth'
loaded_model = ModernLeNet5().to(device)
# The checkpoint is a plain state_dict, so restrict unpickling to tensors and
# primitive containers (weights_only=True) instead of arbitrary objects.
loaded_model.load_state_dict(
    torch.load(model_save_path, map_location=device, weights_only=True)
)

# Recreate transforms and dataloaders
mean_intensity, std_intensity = 0.1307, 0.3081  # Pre-computed
image_transforms = transforms.Compose([
    transforms.ToImage(), transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize((mean_intensity,), (std_intensity,))
])
# Class names in sorted order; a label's position is its integer class index.
dataset_classes = sorted(train_dataset.distinct("ground_truth.label"))
label_map = {label: i for i, label in enumerate(dataset_classes)}

from torch.utils.data import Dataset
class CustomTorchImageDataset(Dataset):
    """Expose a FiftyOne dataset's images and ground-truth labels to PyTorch.

    Args:
        fo_dset: FiftyOne dataset/view providing ``filepath`` and
            ``ground_truth.label`` values; both are read once at construction.
        xforms: callable applied to each image after conversion to grayscale.
        l_map: mapping from label string to class index.

    Each item is ``(xforms(image), l_map[label])``.
    """

    def __init__(self, fo_dset, xforms, l_map):
        self.fo_dset, self.xforms, self.l_map = fo_dset, xforms, l_map
        self.img_paths = self.fo_dset.values("filepath")
        self.labels = self.fo_dset.values("ground_truth.label")

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        # Release the file handle deterministically; convert() returns an
        # independent in-memory image that outlives the `with` block.
        with Image.open(self.img_paths[idx]) as src:
            img = src.convert('L')
        return self.xforms(img), self.l_map[self.labels[idx]]

torch_train_set = CustomTorchImageDataset(train_dataset, image_transforms, label_map)
train_inference_loader = torch.utils.data.DataLoader(torch_train_set, batch_size=64, shuffle=False)

# Inference must run in eval mode. ModernLeNet5 contains Dropout(0.5), so in
# train mode the predictions, and the misclassified set derived from them,
# would be random.
loaded_model.eval()

# Add predictions to training set if they don't exist.
# NOTE: an existing field is reused as-is; delete it to force recomputation.
if "lenet_train_classification" not in train_dataset.get_field_schema():
    classifications = []
    with torch.inference_mode():
        for imgs, _ in tqdm(train_inference_loader, desc="Getting train preds"):
            logits = loaded_model(imgs.to(device))
            # Softmax is monotonic, so the most probable class equals the
            # argmax of the logits; its probability is the confidence.
            confs, pred_idxs = Fun.softmax(logits, dim=1).max(dim=1)
            for pred_idx, conf, sample_logits in zip(
                pred_idxs.tolist(), confs.tolist(), logits.cpu().tolist()
            ):
                classifications.append(
                    fo.Classification(
                        label=dataset_classes[pred_idx],
                        confidence=conf,
                        logits=sample_logits,
                    )
                )
    # The loader is unshuffled, so classifications follow the same sample
    # order as train_dataset.values(...).
    train_dataset.set_values("lenet_train_classification", classifications)
    train_dataset.save()

# Create view of misclassified training samples
mislabeled_train_images_view = train_dataset.match(F("lenet_train_classification.label") != F("ground_truth.label"))
print(f"Found {len(mislabeled_train_images_view)} misclassified training samples.")

## Defining Augmentations

Effective augmentation for MNIST creates realistic variations that the model might encounter. We'll use small geometric transformations (rotation, translation, scaling) plus elastic and grid distortions to simulate natural variation in handwriting. We use the `albumentations` library for this.

In [None]:
set_seeds(51)
# Mild global jitter: shift up to ±10% per axis, scale 0.9-1.1, rotate ±10°.
affine_jitter = A.Affine(
    translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)},
    scale=(0.9, 1.1),
    rotate=(-10, 10),
    p=0.8,
)
# Local stroke warping; out-of-image pixels are filled with a constant border.
elastic_warp = A.ElasticTransform(
    alpha=20, sigma=5, border_mode=cv2.BORDER_CONSTANT, p=0.6
)
# Coarse 3x3 grid distortion with a small displacement limit.
grid_warp = A.GridDistortion(num_steps=3, distort_limit=0.1, p=0.3)

# Applied in order: affine, elastic, grid.
mnist_augmentations = A.Compose([affine_jitter, elastic_warp, grid_warp])

### Creating an Augmented Dataset

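We'll define a new PyTorch `Dataset` class that takes our misclassified samples and applies these augmentations on the fly. Each misclassified sample contributes `augment_factor` entries to the dataset. Every time an entry is loaded, a fresh random augmentation is applied, so the model sees different variants from epoch to epoch.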

In [None]:
class AugmentedMNISTDataset(Dataset):
    """Serve ``augment_factor`` randomly augmented copies of each sample in a FiftyOne view.

    ``augmentations`` is re-applied on every ``__getitem__`` call, so loading
    the same index twice yields two independently sampled variants.

    Args:
        fiftyone_view: FiftyOne dataset/view providing ``filepath`` and
            ``ground_truth.label`` values (read once at construction).
        label_map: mapping from label string to class index.
        base_transforms: optional callable applied to the augmented PIL image.
        augmentations: callable invoked as ``augmentations(image=<uint8 ndarray>)``
            that returns a dict holding the augmented array under ``"image"``.
        augment_factor: number of dataset entries exposed per source sample.

    Items are ``(image, label_idx)``. ``label_idx`` is ``label_map[label]``,
    returned as-is (not wrapped in a tensor) to match CustomTorchImageDataset.
    """

    def __init__(self, fiftyone_view, label_map, base_transforms, augmentations, augment_factor=5):
        self.image_paths = fiftyone_view.values("filepath")
        self.str_labels = fiftyone_view.values("ground_truth.label")
        self.label_map = label_map
        self.base_transforms = base_transforms
        self.augmentations = augmentations
        self.augment_factor = augment_factor

    def __len__(self):
        return len(self.image_paths) * self.augment_factor

    def _resolve(self, idx):
        """Map a flat dataset index to ``(image_path, label_idx)``."""
        base_idx = idx // self.augment_factor
        # Direct lookup: an unknown label fails here with a KeyError instead of
        # yielding a -1 target that CrossEntropyLoss would reject much later.
        return self.image_paths[base_idx], self.label_map[self.str_labels[base_idx]]

    def __getitem__(self, idx):
        path, label_idx = self._resolve(idx)
        with Image.open(path) as src:
            image_np = np.array(src.convert('L'), dtype=np.uint8)
        augmented = self.augmentations(image=image_np)['image']
        # A 2-D uint8 array is inferred as mode 'L'; the explicit `mode`
        # argument is deprecated in recent Pillow releases.
        image = Image.fromarray(augmented)
        if self.base_transforms:
            image = self.base_transforms(image)
        # Same label type as CustomTorchImageDataset: default_collate cannot
        # stack a batch that mixes ints and tensors, and shuffling a
        # ConcatDataset of both datasets produces exactly such batches.
        return image, label_idx

In [None]:
torch_augmented_dataset = AugmentedMNISTDataset(
    mislabeled_train_images_view,
    label_map=label_map,
    base_transforms=image_transforms,
    augmentations=mnist_augmentations,
    augment_factor=10
)

# Combine original training set with the new augmented samples
combined_dataset = ConcatDataset([torch_train_set, torch_augmented_dataset])
print(f"Original training set size: {len(torch_train_set)}")
print(f"Augmented samples added: {len(torch_augmented_dataset)}")
print(f"Combined dataset size: {len(combined_dataset)}")

# Create a new DataLoader for fine-tuning.
# os.cpu_count() returns None when the CPU count cannot be determined, and
# DataLoader needs an int, so fall back to 0 (load in the main process).
combined_train_loader = create_deterministic_training_dataloader(
    combined_dataset, batch_size=64, shuffle=True, num_workers=os.cpu_count() or 0
)

## Fine-Tuning the Model

We'll now fine-tune our model. We start from the best weights of our initial training and train for 15 more epochs on the combined dataset, keeping the checkpoint with the lowest validation loss. We use a **lower learning rate** (`1e-4`) for fine-tuning to make small, careful adjustments to the already-learned weights.

In [None]:
set_seeds(51)
# Start fine-tuning from the best weights of the initial training run.
# The checkpoint is a plain state_dict, so weights_only=True is sufficient.
retrain_model = ModernLeNet5().to(device)
retrain_model.load_state_dict(
    torch.load(model_save_path, map_location=device, weights_only=True)
)

# Use a smaller learning rate for fine-tuning
retrain_optimizer = Adam(retrain_model.parameters(), lr=0.0001)
ce_loss = nn.CrossEntropyLoss()

# Reload the validation loader
val_dataset = fo.load_dataset("mnist-validation-set")
torch_val_set = CustomTorchImageDataset(val_dataset, image_transforms, label_map)
val_loader = torch.utils.data.DataLoader(torch_val_set, batch_size=64)

# Training loop: keep the checkpoint with the lowest validation loss.
retrain_epochs = 15
best_retrain_val_loss = float('inf')
retrain_model_save_path = Path.cwd() / 'retrained_lenet.pth'

for epoch in range(retrain_epochs):
    # --- One pass over original + augmented training samples ---
    retrain_model.train()
    for images, labels in tqdm(combined_train_loader, desc=f"Retraining Epoch {epoch+1}"):
        images, labels = images.to(device), labels.to(device)
        retrain_optimizer.zero_grad()
        logits = retrain_model(images)
        loss = ce_loss(logits, labels)
        loss.backward()
        retrain_optimizer.step()

    # --- Validation ---
    retrain_model.eval()
    val_loss_sum, val_count = 0.0, 0
    with torch.inference_mode():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            logits = retrain_model(images)
            # ce_loss is a per-batch mean; weight it by batch size so a
            # shorter final batch does not skew the epoch average.
            val_loss_sum += ce_loss(logits, labels).item() * labels.size(0)
            val_count += labels.size(0)

    avg_val_loss = val_loss_sum / val_count
    print(f"Epoch {epoch+1} - Validation Loss: {avg_val_loss:.4f}")

    if avg_val_loss < best_retrain_val_loss:
        best_retrain_val_loss = avg_val_loss
        torch.save(retrain_model.state_dict(), retrain_model_save_path)
        print("✓ Saved improved retrained model")

## Final Evaluation

Finally, let's evaluate our newly fine-tuned model on the test set and compare its performance to the original model.

In [None]:
# Load the best retrained model (a plain state_dict, so weights_only=True).
final_model = ModernLeNet5().to(device)
final_model.load_state_dict(
    torch.load(retrain_model_save_path, map_location=device, weights_only=True)
)
final_model.eval()

# Create test loader. It is unshuffled, so batches follow the order of
# test_dataset.values("filepath").
torch_test_set = CustomTorchImageDataset(test_dataset, image_transforms, label_map)
test_loader = torch.utils.data.DataLoader(torch_test_set, batch_size=64)

# Run inference with the retrained model
retrained_predictions, retrained_logits, retrained_confidences = [], [], []
with torch.inference_mode():
    for images, _ in tqdm(test_loader, desc="Evaluating retrained model"):
        logits = final_model(images.to(device))
        # Softmax is monotonic, so the most probable class equals the argmax
        # of the logits; its probability is the confidence.
        confs, preds = Fun.softmax(logits, dim=1).max(dim=1)
        retrained_logits.append(logits.cpu().numpy())
        retrained_predictions.extend(preds.cpu().numpy())
        retrained_confidences.extend(confs.cpu().tolist())

retrained_logits = np.concatenate(retrained_logits, axis=0)

# Store retrained predictions in FiftyOne with one batched write (the same
# approach used for the training-set predictions) instead of a per-sample
# save() round trip.
retrained_classifications = [
    fo.Classification(
        label=dataset_classes[pred_idx], confidence=conf, logits=sample_logits.tolist()
    )
    for pred_idx, conf, sample_logits in zip(
        retrained_predictions, retrained_confidences, retrained_logits
    )
]
test_dataset.set_values("retrained_lenet_classification", retrained_classifications)

print("Retrained predictions stored.")

In [None]:
# Score both prediction fields against ground truth, each under its own eval key.
original_eval = test_dataset.evaluate_classifications(
    "lenet_classification", eval_key="original_eval"
)
retrained_eval = test_dataset.evaluate_classifications(
    "retrained_lenet_classification", eval_key="retrained_eval"
)

print("\n--- Original Model Performance ---")
original_eval.print_report()

print("\n--- Retrained Model Performance ---")
retrained_eval.print_report()

# Signed change in overall accuracy (positive means the retrained model is better).
orig_accuracy, retrain_accuracy = (
    results.metrics()['accuracy'] for results in (original_eval, retrained_eval)
)
print(f"\nAccuracy Improvement: {retrain_accuracy - orig_accuracy:+.4f}")

### Analysis of Changes

Let's see exactly which samples were fixed by retraining.

In [None]:
# Field references for the three labels being compared.
gt_label = F("ground_truth.label")
original_label = F("lenet_classification.label")
retrained_label = F("retrained_lenet_classification.label")

# Wrong before retraining, right after it.
originally_wrong = test_dataset.match(original_label != gt_label)
now_correct = originally_wrong.match(retrained_label == gt_label)

# Right before retraining, wrong after it (regressions).
now_wrong = test_dataset.match(
    (original_label == gt_label) & (retrained_label != gt_label)
)

print(f"Samples fixed by retraining: {len(now_correct)}")
print(f"Samples broken by retraining: {len(now_wrong)}")
print(f"Net improvement in correct predictions: {len(now_correct) - len(now_wrong)}")

# Open the App on the test set, then narrow it to the fixed samples.
session = fo.launch_app(test_dataset)
session.view = now_correct
print(f"\nView the fixed samples in the App: {session.url}")

## Conclusion

Congratulations! You have completed the entire workflow from data exploration to model training, analysis, and targeted improvement. You've seen how a generalist model like CLIP provides a strong baseline, and how a specialized, supervised model can achieve superior performance. Most importantly, you've learned how to use model predictions and embeddings to analyze your dataset, find problematic samples, and use that information to make your model even better.

Please see `summary.md` for a full recap and suggested next steps.