# Regularization: Closing the Train/Val Gap

In notebook 03, our `SimpleCNN` reached ~70% validation accuracy on CIFAR-10 — but training accuracy climbed to ~75% and was still rising. That gap between training and validation performance is **overfitting**: the model memorises training data instead of learning general patterns.

**Regularization** is any technique that reduces this gap by constraining the model. This notebook compares four common approaches:

1. **Data augmentation** — show the model transformed versions of each image
2. **Dropout** — randomly zero out activations during training
3. **Weight decay (L2)** — penalise large weights in the optimizer
4. **Batch normalization** — normalize activations within each mini-batch

We keep everything else fixed (same base architecture, optimizer, learning rate, epochs) and change one thing per run, so we can isolate the effect of each technique.

## Setup

In [None]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
os.environ["OMP_NUM_THREADS"] = "1"

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

from dlbasics import set_seed, plot_history

In [None]:
set_seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

## Data Loading

Same CIFAR-10 setup as notebook 03: normalize with per-channel statistics, split 50k training images into 45k train / 5k validation.

In [None]:
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD  = (0.2470, 0.2435, 0.2616)

# Standard transform (no augmentation)
base_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Augmented transform (used in sections 2 and 6)
aug_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

In [None]:
data_root = "../data"

train_full = datasets.CIFAR10(root=data_root, train=True, download=True, transform=base_transform)
test_ds = datasets.CIFAR10(root=data_root, train=False, download=True, transform=base_transform)

val_size = 5000
train_size = len(train_full) - val_size

train_ds, val_ds = random_split(
    train_full,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)

print(f"Train: {len(train_ds)}, Val: {len(val_ds)}, Test: {len(test_ds)}")

In [None]:
batch_size = 128

def make_loaders(train_dataset, val_dataset):
    """Create train and val DataLoaders with consistent settings."""
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    return train_loader, val_loader

train_loader, val_loader = make_loaders(train_ds, val_ds)

## Training Utilities

Same `train_one_epoch`, `evaluate`, and `fit` functions from notebook 03 (copied inline to keep this notebook self-contained).

In [None]:
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one full pass over the training set, updating parameters after each batch."""
    model.train()
    total_loss, total_correct, total_samples = 0.0, 0, 0

    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * xb.size(0)
        total_correct += (logits.argmax(1) == yb).sum().item()
        total_samples += xb.size(0)

    return total_loss / total_samples, total_correct / total_samples


def evaluate(model, loader, criterion, device):
    """Evaluate model on a dataset without computing gradients."""
    model.eval()
    total_loss, total_correct, total_samples = 0.0, 0, 0

    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            logits = model(xb)
            loss = criterion(logits, yb)

            total_loss += loss.item() * xb.size(0)
            total_correct += (logits.argmax(1) == yb).sum().item()
            total_samples += xb.size(0)

    return total_loss / total_samples, total_correct / total_samples


def fit(model, train_loader, val_loader, criterion, optimizer, device, epochs=5):
    """Train for multiple epochs, returning a history dict with loss and accuracy curves."""
    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}

    for epoch in range(1, epochs + 1):
        tr_loss, tr_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
        va_loss, va_acc = evaluate(model, val_loader, criterion, device)

        history["train_loss"].append(tr_loss)
        history["train_acc"].append(tr_acc)
        history["val_loss"].append(va_loss)
        history["val_acc"].append(va_acc)

        print(f"Epoch {epoch:02d} | train loss {tr_loss:.4f} acc {tr_acc:.4f} | val loss {va_loss:.4f} acc {va_acc:.4f}")

    return history
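
Both `train_one_epoch` and `evaluate` multiply each batch loss by `xb.size(0)` before summing. With 45,000 training images and `batch_size = 128`, the final batch holds only 72 images, so a plain average of per-batch losses would over-weight it. The sketch below illustrates the difference with made-up loss values; `weighted_mean_loss` is a hypothetical helper, not part of the notebook.

```python
def weighted_mean_loss(batch_losses, batch_sizes):
    """Average per-batch mean losses, weighting each batch by its sample count."""
    total = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)


# Illustrative values only: one full batch and one smaller final batch.
batch_losses = [1.0, 3.0]
batch_sizes = [128, 72]

naive = sum(batch_losses) / len(batch_losses)            # treats both batches equally
weighted = weighted_mean_loss(batch_losses, batch_sizes)  # per-sample average

print(f"naive={naive:.2f}, weighted={weighted:.2f}")
```

The weighted value is the true per-sample mean, which is what the training utilities report.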

## 1 — Baseline CNN (no regularization)

Same `SimpleCNN` from notebook 03. This is our control: we expect to see train accuracy climb higher than val accuracy.

In [None]:
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 8 * 8, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))

In [None]:
set_seed(42)
baseline = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(baseline.parameters(), lr=1e-3)

print("=== Baseline (no regularization) ===")
history_baseline = fit(baseline, train_loader, val_loader, criterion, optimizer, device, epochs=5)

In [None]:
plot_history(history_baseline, "Baseline")

The gap between the train and val curves is the overfitting we want to reduce. Training loss keeps dropping while validation loss flattens or rises — the model is memorising instead of generalising.
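
To put a number on that gap, one option is to subtract validation accuracy from training accuracy at every epoch. The helper below is a minimal sketch that assumes the `history` dict layout returned by `fit`; `generalization_gap` is a hypothetical name, and the toy numbers are invented for illustration rather than taken from a real run.

```python
def generalization_gap(history):
    """Return train accuracy minus validation accuracy for each epoch."""
    return [tr - va for tr, va in zip(history["train_acc"], history["val_acc"])]


# Made-up accuracies, purely illustrative (not results from this notebook).
toy_history = {
    "train_acc": [0.50, 0.65, 0.75],
    "val_acc": [0.50, 0.60, 0.65],
}

gaps = generalization_gap(toy_history)
print([round(g, 2) for g in gaps])
```

A gap that widens from epoch to epoch, as in the toy data, is the pattern to look for in the baseline curves.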

## 2 — Data Augmentation

The cheapest regularizer: show the model randomly transformed versions of each training image. We use two standard augmentations for CIFAR-10:

- `RandomHorizontalFlip()` — 50% chance to flip left/right
- `RandomCrop(32, padding=4)` — pad 4 pixels on each side, then crop back to 32x32 (small random shifts)

This effectively increases dataset diversity without collecting new data. The model sees a slightly different image each epoch, making it harder to memorise specific pixel patterns.
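
To make the "small random shifts" concrete, here is a framework-free sketch of pad-then-crop plus a horizontal flip on a single-channel image stored as nested lists. It is not torchvision's implementation: `pad_crop_flip` and `count_variants` are hypothetical names, and zero padding is assumed.

```python
import random


def pad_crop_flip(image, pad, rng):
    """Zero-pad a 2D image, crop back to its original size at a random offset, flip with p=0.5."""
    h, w = len(image), len(image[0])

    # Zero-pad `pad` pixels on every side.
    padded = [[0] * (w + 2 * pad) for _ in range(pad)]
    padded += [[0] * pad + list(row) + [0] * pad for row in image]
    padded += [[0] * (w + 2 * pad) for _ in range(pad)]

    # Choose the crop's top-left corner: 2*pad + 1 possible offsets per axis.
    top = rng.randint(0, 2 * pad)
    left = rng.randint(0, 2 * pad)
    crop = [row[left:left + w] for row in padded[top:top + h]]

    # Mirror left/right half of the time.
    if rng.random() < 0.5:
        crop = [row[::-1] for row in crop]
    return crop


def count_variants(pad):
    """Upper bound on distinct views: crop offsets per axis squared, times two flip states."""
    return 2 * (2 * pad + 1) ** 2


rng = random.Random(0)
tiny = [[1, 2], [3, 4]]
print(pad_crop_flip(tiny, pad=1, rng=rng))
print(count_variants(4))
```

With `padding=4` there are 9 offsets per axis and 2 flip states, so each training image can appear in up to 162 different forms.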

In [None]:
# Reload training data with augmentation applied
train_full_aug = datasets.CIFAR10(root=data_root, train=True, download=True, transform=aug_transform)

# Use the same indices for the train/val split
train_ds_aug = torch.utils.data.Subset(train_full_aug, train_ds.indices)

train_loader_aug, _ = make_loaders(train_ds_aug, val_ds)

In [None]:
set_seed(42)
model_aug = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_aug.parameters(), lr=1e-3)

print("=== Data Augmentation ===")
history_aug = fit(model_aug, train_loader_aug, val_loader, criterion, optimizer, device, epochs=5)

## 3 — Dropout

Dropout randomly sets a fraction of activations to zero during training. This prevents co-adaptation: neurons can't rely on specific other neurons always being present, so each one must learn more robust features.

- `nn.Dropout2d(p)` after conv blocks — drops entire feature maps (better for spatial data than per-element dropout)
- `nn.Dropout(p)` before the classifier — standard element-wise dropout for the fully connected layer

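At test time, dropout is turned off. PyTorch uses *inverted* dropout: the surviving activations are scaled by `1 / (1 - p)` during training, so no rescaling is needed at inference and the layer acts as the identity in `model.eval()` mode.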
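
The behaviour in training versus evaluation mode can be sketched in plain Python. This illustrates the inverted-dropout scheme described in the `nn.Dropout` documentation (survivors scaled by `1 / (1 - p)` during training, identity otherwise); it is not PyTorch's implementation, and `inverted_dropout` is a hypothetical name.

```python
import random


def inverted_dropout(values, p, training, rng):
    """Zero each value with probability p and scale survivors by 1/(1-p); identity when not training."""
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    if not training or p == 0.0:
        return list(values)
    keep = 1.0 - p
    return [v / keep if rng.random() < keep else 0.0 for v in values]


rng = random.Random(0)
acts = [1.0] * 10_000

train_out = inverted_dropout(acts, p=0.5, training=True, rng=rng)
eval_out = inverted_dropout(acts, p=0.5, training=False, rng=rng)

train_mean = sum(train_out) / len(train_out)
print(f"train mean ~ {train_mean:.2f}")  # stays near the input mean of 1.0
print(eval_out == acts)                   # evaluation mode leaves activations untouched
```

Because the scaling happens during training, the expected activation is the same in both modes, which is why nothing extra is needed at test time.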

In [None]:
class CNNWithDropout(nn.Module):
    def __init__(self, num_classes=10, drop_conv=0.25, drop_fc=0.5):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout2d(drop_conv),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout2d(drop_conv),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(drop_fc),
            nn.Linear(64 * 8 * 8, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))

In [None]:
set_seed(42)
model_drop = CNNWithDropout().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_drop.parameters(), lr=1e-3)

print("=== Dropout ===")
history_drop = fit(model_drop, train_loader, val_loader, criterion, optimizer, device, epochs=5)

## 4 — Weight Decay (L2 Regularization)

Weight decay adds a penalty proportional to the squared magnitude of the weights to the loss:

$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{CE}} + \frac{\lambda}{2} \sum_i w_i^2$$

This discourages any single weight from growing too large. In practice, we don't modify the loss: with `torch.optim.Adam`, the `weight_decay` parameter adds $\lambda w_i$ (the gradient of the penalty term) to each weight's gradient before the update is computed.

Same `SimpleCNN`, same data, just `weight_decay=1e-4` in the optimizer.
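
Why the name "decay"? Differentiating the penalty gives $\lambda w_i$, so each update also pulls every weight toward zero. The sketch below shows this for a plain SGD step, which is an assumption made for clarity: the notebook uses Adam, where the added term passes through the adaptive scaling and the shrink factor is no longer exactly $(1 - \eta\lambda)$. `sgd_step_l2` is a hypothetical helper, and the values are chosen for easy arithmetic rather than realism.

```python
def sgd_step_l2(weights, grads, lr, weight_decay):
    """One SGD step with the L2 penalty's gradient (weight_decay * w) added to the data gradient."""
    return [w - lr * (g + weight_decay * w) for w, g in zip(weights, grads)]


weights = [2.0, -1.0]
zero_grads = [0.0, 0.0]  # pretend the data loss contributes no gradient

# With no data gradient, each weight is multiplied by (1 - lr * weight_decay) = 0.95.
shrunk = sgd_step_l2(weights, zero_grads, lr=0.1, weight_decay=0.5)
print(shrunk)
```

Even without any signal from the data, every step shrinks the weights by a constant factor, which is the "decay".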

In [None]:
set_seed(42)
model_wd = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_wd.parameters(), lr=1e-3, weight_decay=1e-4)

print("=== Weight Decay (L2) ===")
history_wd = fit(model_wd, train_loader, val_loader, criterion, optimizer, device, epochs=5)

## 5 — Batch Normalization

Batch normalization normalizes activations within each mini-batch to have zero mean and unit variance, then applies a learnable affine transform. It was introduced primarily to stabilize and speed up training, but it also has a mild regularization effect because the batch statistics introduce noise.

We add `nn.BatchNorm2d` after each convolutional layer, before ReLU.
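
The training-mode arithmetic can be sketched for a single channel in plain Python. This is a simplification, not `nn.BatchNorm2d` itself: the real layer computes statistics per channel over the batch and spatial dimensions and keeps running averages for evaluation. `batch_norm_single_channel` is a hypothetical name; `eps=1e-5` mirrors the `nn.BatchNorm2d` default.

```python
import math


def batch_norm_single_channel(values, gamma=1.0, beta=0.0, eps=1e-5):
    """Normalize values to zero mean and unit variance, then apply the affine transform."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n  # biased (population) variance
    return [gamma * (v - mean) / math.sqrt(var + eps) + beta for v in values]


activations = [1.0, 2.0, 3.0, 4.0]
normalized = batch_norm_single_channel(activations)

out_mean = sum(normalized) / len(normalized)
out_var = sum((v - out_mean) ** 2 for v in normalized) / len(normalized)
print(f"mean={out_mean:.4f}, var={out_var:.4f}")
```

Since the mean and variance come from whichever samples happen to share a mini-batch, the same image is normalized slightly differently each epoch, which is the source of the mild regularization noise.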

In [None]:
class CNNWithBatchNorm(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 8 * 8, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))

In [None]:
set_seed(42)
model_bn = CNNWithBatchNorm().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_bn.parameters(), lr=1e-3)

print("=== Batch Normalization ===")
history_bn = fit(model_bn, train_loader, val_loader, criterion, optimizer, device, epochs=5)

## 6 — Combined: All Techniques Together

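Now we combine everything: data augmentation + dropout + weight decay + batch normalization. We expect this to give the smallest train/val gap, although with only 5 epochs the heavily regularized model may not reach the highest accuracy.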

In [None]:
class RegularizedCNN(nn.Module):
    def __init__(self, num_classes=10, drop_conv=0.25, drop_fc=0.5):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout2d(drop_conv),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout2d(drop_conv),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(drop_fc),
            nn.Linear(64 * 8 * 8, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))

In [None]:
set_seed(42)
model_combined = RegularizedCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_combined.parameters(), lr=1e-3, weight_decay=1e-4)

print("=== Combined (aug + dropout + weight decay + batch norm) ===")
history_combined = fit(model_combined, train_loader_aug, val_loader, criterion, optimizer, device, epochs=5)

## Comparison

Side-by-side results for all six runs.

In [None]:
results = {
    "Baseline":       history_baseline,
    "Data Aug":       history_aug,
    "Dropout":        history_drop,
    "Weight Decay":   history_wd,
    "Batch Norm":     history_bn,
    "Combined":       history_combined,
}

# Print comparison table
print(f"{'Method':<16} {'Train Acc':>10} {'Val Acc':>10} {'Train Loss':>11} {'Val Loss':>10} {'Gap':>8}")
print("-" * 67)
for name, h in results.items():
    ta = h["train_acc"][-1]
    va = h["val_acc"][-1]
    tl = h["train_loss"][-1]
    vl = h["val_loss"][-1]
    gap = ta - va
    print(f"{name:<16} {ta:>10.4f} {va:>10.4f} {tl:>11.4f} {vl:>10.4f} {gap:>8.4f}")

In [None]:
# Bar chart: final train and validation accuracy for each method
names = list(results.keys())
val_accs = [h["val_acc"][-1] for h in results.values()]
train_accs = [h["train_acc"][-1] for h in results.values()]

x = np.arange(len(names))
width = 0.35

fig, ax = plt.subplots(figsize=(10, 5))
bars_train = ax.bar(x - width/2, train_accs, width, label="Train", color="#4c72b0")
bars_val = ax.bar(x + width/2, val_accs, width, label="Val", color="#dd8452")

ax.set_ylabel("Accuracy")
ax.set_title("Train vs Val Accuracy by Regularization Method")
ax.set_xticks(x)
ax.set_xticklabels(names, rotation=15, ha="right")
ax.legend()
ax.set_ylim(0.4, 0.85)

# Add value labels on bars
for bar in bars_train:
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.005,
            f"{bar.get_height():.1%}", ha="center", va="bottom", fontsize=8)
for bar in bars_val:
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.005,
            f"{bar.get_height():.1%}", ha="center", va="bottom", fontsize=8)

plt.tight_layout()
plt.show()

In [None]:
# Validation accuracy curves for all methods
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

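# Derive the epoch axis from the recorded history instead of hard-coding 5 epochs
epochs = list(range(1, len(history_baseline["val_acc"]) + 1))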
for name, h in results.items():
    ax1.plot(epochs, h["val_acc"], marker="o", label=name)
    ax2.plot(epochs, h["val_loss"], marker="o", label=name)

ax1.set_xlabel("Epoch")
ax1.set_ylabel("Validation Accuracy")
ax1.set_title("Validation Accuracy")
ax1.set_xticks(epochs)
ax1.legend(fontsize=8)

ax2.set_xlabel("Epoch")
ax2.set_ylabel("Validation Loss")
ax2.set_title("Validation Loss")
ax2.set_xticks(epochs)
ax2.legend(fontsize=8)

plt.tight_layout()
plt.show()

## Takeaways

- **Data augmentation** is one of the cheapest and most effective regularizers — it adds no parameters and usually improves generalization significantly. For image tasks, `RandomHorizontalFlip` + `RandomCrop` are strong baselines.
- **Dropout** reduces co-adaptation between neurons. `Dropout2d` is often preferred in convolutional layers since it drops entire feature maps rather than individual pixels.
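- **Weight decay (L2 regularization)** penalizes large weights, encouraging simpler solutions. It's a single hyperparameter (`weight_decay`) with no architectural changes required. Note that `Adam`'s `weight_decay` folds the penalty into the gradient, whereas **AdamW** decouples the decay from the adaptive update and is the more common choice in practice.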
- **Batch normalization** was designed to stabilize training (Ioffe & Szegedy, 2015), and the stochasticity of batch statistics provides mild regularization.
- **Combining techniques** often works best since they address different failure modes (memorization, co-adaptation, weight growth, unstable optimization).
- The train/val **generalization gap** is as important as raw accuracy. A method that slightly lowers training accuracy but closes the gap is often improving generalization.
- **Not covered here**: early stopping, learning rate scheduling, and stronger augmentation policies (Cutout, Mixup, AutoAugment).
