In [None]:
import os, sys
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
os.environ["OMP_NUM_THREADS"] = "1"

# Make src/ importable from notebooks/
sys.path.insert(0, os.path.abspath("../src"))

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

from utils import set_seed, plot_history

# Case Study: MLP vs CNN on CIFAR-10

Compare a flat MLP baseline against a simple CNN on a real image classification task. This notebook covers the full PyTorch workflow: data loading, model definition, training, evaluation, and visual comparison.

The key insight is that **spatial structure matters** — convolutional layers exploit the 2D arrangement of pixels, while an MLP treats each pixel as an independent feature.

In [None]:
set_seed(42)

# Prefer a CUDA GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

## Data Loading

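CIFAR-10: 60,000 32x32 RGB images across 10 classes (50,000 for training, 10,000 for testing). We normalise each colour channel with the standard CIFAR-10 per-channel mean and standard deviation, and hold out 5,000 of the training images as a validation set.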

In [3]:

# Per-channel (R, G, B) mean and standard deviation used to normalise CIFAR-10 inputs.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# ToTensor turns each image into a float tensor laid out as (C, H, W);
# Normalize then rescales every channel as (x - mean) / std.
_transform_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
]
transform = transforms.Compose(_transform_steps)

In [4]:
# Download CIFAR-10 (not re-downloaded if already present) into ../data,
# resolved relative to the notebook's working directory.
data_root = "../data"
common_kwargs = {"root": data_root, "download": True, "transform": transform}

train_full = datasets.CIFAR10(train=True, **common_kwargs)  # official training split
test_ds = datasets.CIFAR10(train=False, **common_kwargs)    # official test split

In [5]:
# Sanity-check split sizes and the label names (label index i -> classes[i]).
summary_rows = [
    ("train:", len(train_full)),
    ("test:", len(test_ds)),
    ("classes:", train_full.classes),
]
for row_label, row_value in summary_rows:
    print(row_label, row_value)

train: 50000
test: 10000
classes: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']


In [6]:
# Split training dataset into train and validation sets

# Carve a fixed-size validation set out of the official training split.
# The split uses its own seeded generator, so it is the same on every run
# regardless of how much of the global RNG earlier cells consumed.
val_size = 5000
train_size = len(train_full) - val_size
split_generator = torch.Generator().manual_seed(42)

train_ds, val_ds = random_split(train_full, [train_size, val_size], generator=split_generator)

len(train_ds), len(val_ds)

(45000, 5000)

In [7]:
# Create the data loaders
# Mini-batch loaders: only the training loader shuffles; val/test order is fixed.
batch_size = 128
num_workers = 0  # 0 = batches are loaded in the main process

loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **loader_kwargs)

In [8]:
# Peek at one training batch: images come out as (batch, channels, height, width),
# labels as a 1-D tensor with one class index per image.
xb, yb = next(iter(train_loader))
tuple(t.shape for t in (xb, yb))

(torch.Size([128, 3, 32, 32]), torch.Size([128]))

## MLP Baseline

The MLP flattens each 32x32x3 image into a 3,072-dimensional vector and passes it through one hidden layer (512 units, ReLU). It has no notion of spatial locality — neighbouring pixels are treated no differently from distant ones.

In [9]:
class MLPBaseline(nn.Module):
    """Fully connected baseline: flatten -> Linear -> ReLU -> Linear.

    Every pixel/channel value becomes an independent input feature, so the model
    has no built-in notion of which pixels are spatial neighbours.

    Args:
        hidden_dim: width of the single hidden layer.
        num_classes: number of output logits.
    """

    def __init__(self, hidden_dim=512, num_classes=10):
        super().__init__()
        in_features = 3 * 32 * 32  # one image flattened: 3 channels x 32 x 32 pixels
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Map images (B, 3, 32, 32) to unnormalised class scores (B, num_classes)."""
        hidden = F.relu(self.fc1(self.flatten(x)))
        return self.fc2(hidden)

# MLP on the selected device, trained with cross-entropy and Adam (lr=1e-3).
mlp = MLPBaseline(hidden_dim=512, num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=mlp.parameters(), lr=1e-3)


### Anatomy of a Training Step

Before writing the full training loop, we walk through a single batch to see each piece: zero gradients, forward pass, loss, backward pass, parameter update. We also inspect gradient magnitudes to confirm that backpropagation is working.

In [None]:
# The DataLoader yields CPU tensors; move the batch to the model's device first.
# Without this, a CUDA model would raise a device-mismatch error on the forward pass.
inputs = xb.to(device)
targets = yb.to(device)

# Zero gradients (prevents accidental accumulation across steps)
optimizer.zero_grad()

# Forward pass: raw class scores (logits), one row per image
logits = mlp(inputs)

# Compute loss (CrossEntropyLoss takes raw logits and integer class targets)
loss = criterion(logits, targets)

# Backward pass (populates .grad on all parameters)
loss.backward()

# Inspect gradients of the first layer's weights
print("fc1.weight.grad is None?", mlp.fc1.weight.grad is None)
print("fc1.weight.grad shape:", mlp.fc1.weight.grad.shape)
print("mean |grad|:", mlp.fc1.weight.grad.abs().mean().item())

# Update parameters
optimizer.step()

loss.item()

### Training and Evaluation Functions

In [None]:
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``loader``, stepping the optimizer every batch.

    Args:
        model: network to train; it is switched to training mode first.
        loader: iterable yielding ``(inputs, targets)`` mini-batches.
        criterion: loss taking ``(logits, targets)``. It is expected to return a
            per-batch mean, which is re-weighted by batch size below.
        optimizer: optimizer over ``model``'s parameters.
        device: device every batch is moved to before the forward pass.

    Returns:
        ``(avg_loss, accuracy)``: the sample-weighted mean loss and the fraction
        of correctly classified samples. Both are measured on-the-fly while the
        weights were changing.
    """
    model.train()

    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        batch_loss = criterion(outputs, targets)
        batch_loss.backward()
        optimizer.step()

        n_batch = inputs.size(0)
        loss_sum += batch_loss.item() * n_batch
        n_correct += (outputs.argmax(dim=1) == targets).sum().item()
        n_seen += n_batch

    return loss_sum / n_seen, n_correct / n_seen

In [12]:
# One full epoch on the MLP; this continues from the single optimisation step above.
train_loss, train_acc = train_one_epoch(
    model=mlp,
    loader=train_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
)
(train_loss, train_acc)

(1.7046472429275512, 0.416)

In [None]:
def evaluate(model, loader, criterion, device):
    """Score ``model`` on ``loader`` without tracking gradients or updating weights.

    Args:
        model: network to score; it is switched to eval mode and left there.
        loader: iterable yielding ``(inputs, targets)`` mini-batches.
        criterion: loss taking ``(logits, targets)``. It is expected to return a
            per-batch mean, which is re-weighted by batch size below.
        device: device every batch is moved to before the forward pass.

    Returns:
        ``(avg_loss, accuracy)``: the sample-weighted mean loss and the fraction
        of correctly classified samples over the whole loader.
    """
    model.eval()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs)
            n_batch = inputs.size(0)

            loss_sum += criterion(outputs, targets).item() * n_batch
            n_correct += (outputs.argmax(dim=1) == targets).sum().item()
            n_seen += n_batch

    return loss_sum / n_seen, n_correct / n_seen

In [14]:
# Validation metrics for the MLP after that epoch (no parameter updates).
val_loss, val_acc = evaluate(model=mlp, loader=val_loader, criterion=criterion, device=device)
(val_loss, val_acc)

(1.538858579826355, 0.4654)

## CNN

Two convolutional layers (3x3 kernels, ReLU, max-pooling) followed by a single linear classifier. Unlike the MLP, convolutions share weights across spatial positions and only look at local patches, giving the model a strong inductive bias for image data.

In [15]:
class SimpleCNN(nn.Module):
    """Two-stage convolutional classifier for 3x32x32 images.

    Each stage is Conv(3x3, padding=1) -> ReLU -> MaxPool(2), so the spatial size
    halves per stage (32 -> 16 -> 8) while channels go 3 -> 32 -> 64. The final
    64x8x8 feature map is flattened and fed to a single linear layer that
    produces the class logits.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        def conv_stage(in_ch, out_ch):
            # padding=1 keeps H and W unchanged through the 3x3 conv; the pool halves them.
            return [
                nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ]

        # Unpacked into one flat Sequential so module indices run 0-5 (state_dict keys unchanged).
        self.features = nn.Sequential(*conv_stage(3, 32), *conv_stage(32, 64))

        flat_dim = 64 * 8 * 8  # channels x height x width after two 2x poolings of 32x32
        self.classifier = nn.Sequential(nn.Flatten(), nn.Linear(flat_dim, num_classes))

    def forward(self, x):
        return self.classifier(self.features(x))

In [16]:
# CNN on the selected device, with its own loss and an optimizer bound to its parameters.
cnn = SimpleCNN(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=cnn.parameters(), lr=1e-3)

# Displaying the module shows its layer-by-layer structure.
cnn


SimpleCNN(
  (features): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=4096, out_features=10, bias=True)
  )
)

In [17]:
# Shape check: push one fresh batch through the untrained CNN; expect one logit per class.
xb, yb = next(iter(train_loader))
xb, yb = xb.to(device), yb.to(device)

logits = cnn(xb)
logits.shape

torch.Size([128, 10])

In [18]:
# One full training epoch on the CNN, using the optimizer bound to its parameters.
train_loss_cnn, train_acc_cnn = train_one_epoch(
    model=cnn,
    loader=train_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
)
(train_loss_cnn, train_acc_cnn)

(1.4002999457465277, 0.5075333333333333)

In [19]:
# Validation metrics for the CNN after one epoch.
val_loss_cnn, val_acc_cnn = evaluate(model=cnn, loader=val_loader, criterion=criterion, device=device)
(val_loss_cnn, val_acc_cnn)

(1.1561440475463867, 0.5928)

## Full Training Comparison

We train both models from scratch for 5 epochs using Adam (lr=1e-3) and cross-entropy loss, then compare their training curves side by side.

In [None]:
def fit(model, train_loader, val_loader, criterion, optimizer, device, epochs=5):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Prints a one-line summary per epoch.

    Returns:
        dict with keys ``"train_loss"``, ``"train_acc"``, ``"val_loss"``,
        ``"val_acc"``; each maps to a list with one value per epoch, in order.
    """
    metric_names = ("train_loss", "train_acc", "val_loss", "val_acc")
    history = {name: [] for name in metric_names}

    for epoch in range(1, epochs + 1):
        tr_loss, tr_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
        va_loss, va_acc = evaluate(model, val_loader, criterion, device)

        for name, value in zip(metric_names, (tr_loss, tr_acc, va_loss, va_acc)):
            history[name].append(value)

        print(
            f"Epoch {epoch:02d} | "
            f"train loss {tr_loss:.4f} acc {tr_acc:.4f} | "
            f"val loss {va_loss:.4f} acc {va_acc:.4f}"
        )

    return history

# Start from a freshly initialised CNN and a new optimizer, so the weights from the
# single-epoch run above do not leak into the comparison.
cnn = SimpleCNN(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=cnn.parameters(), lr=1e-3)

history_cnn = fit(cnn, train_loader, val_loader, criterion, optimizer, device, epochs=5)

In [21]:
# Same recipe for the MLP: fresh weights, fresh Adam optimizer (lr=1e-3), 5 epochs.
mlp = MLPBaseline(hidden_dim=512, num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=mlp.parameters(), lr=1e-3)

history_mlp = fit(mlp, train_loader, val_loader, criterion, optimizer, device, epochs=5)


Epoch 01 | train loss 1.7118 acc 0.4105 | val loss 1.5783 acc 0.4506
Epoch 02 | train loss 1.4697 acc 0.4883 | val loss 1.5082 acc 0.4808
Epoch 03 | train loss 1.3925 acc 0.5187 | val loss 1.4847 acc 0.4968
Epoch 04 | train loss 1.3274 acc 0.5430 | val loss 1.4788 acc 0.4992
Epoch 05 | train loss 1.2640 acc 0.5659 | val loss 1.5314 acc 0.4918


## Results

In [None]:
# Per-model training curves from the shared helper.
plot_history(history_mlp, "MLP")
plot_history(history_cnn, "CNN")

# Direct comparison (validation only). Each model is plotted against its own epoch
# axis, so the figure still works if the two runs used different numbers of epochs.
histories = {"MLP": history_mlp, "CNN": history_cnn}
panels = [
    # (history key, y-axis label, panel title)
    ("val_acc", "Validation accuracy", "Validation accuracy comparison"),
    ("val_loss", "Validation loss", "Validation loss comparison"),
]

fig, axes = plt.subplots(1, len(panels), figsize=(12, 4))

for ax, (key, ylabel, title) in zip(axes, panels):
    for model_name, history in histories.items():
        model_epochs = list(range(1, len(history[key]) + 1))
        ax.plot(model_epochs, history[key], label=model_name)
    longest = max(len(history[key]) for history in histories.values())
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_xticks(list(range(1, longest + 1)))
    ax.legend()

plt.tight_layout()
plt.show()

## Takeaway

Final-epoch (epoch 5) validation metrics from a single training run of each model:

| Model | Val Accuracy | Val Loss |
|-------|-------------|----------|
| MLP   | ~49%        | ~1.53    |
| CNN   | ~70%        | ~0.86    |

- The CNN outperforms the MLP by **~21 percentage points** with the same optimizer, learning rate, and number of epochs.
- Convolutional layers exploit **spatial locality** and **weight sharing**, both crucial inductive biases for vision.
- Both models show signs of overfitting (training loss drops while validation loss plateaus or rises), suggesting that regularisation (dropout, data augmentation) would help as a next step.