<a href="https://colab.research.google.com/github/slucey-cs-cmu-edu/RVSS26/blob/main/Classification_MLP_simple.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Image classification with a multi-layer perceptron (MLP)

This notebook is intentionally **minimal**. The goal is to make the *flow* obvious:

1. Load data
2. Define model (architecture)
3. Define loss + optimiser
4. Train with backprop (`loss.backward()`)
5. Evaluate with accuracy

You should be able to point to each of those steps in the code.

## 1. Imports + device

In [None]:
import time
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm.auto import tqdm

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

# Simple reproducibility (optional)
torch.manual_seed(0)

## 2. Data (MNIST)

We split the original 60k training set into **train (50k)** and **val (10k)**. Test is the standard **10k** test set.

We keep `num_workers=0` for Colab stability.

In [None]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean/std
])

full_train = datasets.MNIST(root="data", train=True, download=True, transform=transform)
test_set   = datasets.MNIST(root="data", train=False, download=True, transform=transform)

train_set, val_set = random_split(full_train, [50_000, 10_000], generator=torch.Generator().manual_seed(0))

batch_size = 128
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,  num_workers=0)
val_loader   = DataLoader(val_set,   batch_size=batch_size, shuffle=False, num_workers=0)
test_loader  = DataLoader(test_set,  batch_size=batch_size, shuffle=False, num_workers=0)

print("train:", len(train_set), "val:", len(val_set), "test:", len(test_set))
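
As a quick sanity check on what the loaders yield, the number of batches per epoch can be worked out by hand. This plain-Python sketch assumes the split sizes and `batch_size = 128` from the cell above, plus the `DataLoader` default `drop_last=False` (so a smaller final batch is kept). The helper name `batches_per_epoch` is made up for illustration.

```python
import math

def batches_per_epoch(num_samples, batch_size):
    """Return (number of batches, size of the final batch) when the last partial batch is kept."""
    num_batches = math.ceil(num_samples / batch_size)
    last_batch = num_samples - (num_batches - 1) * batch_size
    return num_batches, last_batch

for name, size in [("train", 50_000), ("val", 10_000), ("test", 10_000)]:
    n_batches, last = batches_per_epoch(size, 128)
    print(f"{name}: {n_batches} batches, final batch has {last} samples")
```

The train count should match the number of steps shown on the progress bar during training.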

### Visualise a small batch (optional)

In [None]:
import torchvision

x, y = next(iter(train_loader))

# Undo the normalisation for display (inverse of transforms.Normalize above)
x_disp = (x * 0.3081 + 0.1307).clamp(0, 1)
grid = torchvision.utils.make_grid(x_disp[:32], nrow=8)

plt.figure(figsize=(8,8))
plt.imshow(grid.permute(1,2,0).numpy())
plt.axis("off")
plt.show()

print("labels:", y[:32].numpy())
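
The un-normalise line above is the inverse of `transforms.Normalize`, which computes `(pixel - mean) / std` per channel. Here is a small plain-Python sketch of the round trip, using the MNIST constants from this notebook. The function names are illustrative, not torchvision API.

```python
MEAN, STD = 0.1307, 0.3081  # MNIST statistics used in this notebook

def normalise(pixel):
    """What Normalize does to one pixel value in [0, 1]."""
    return (pixel - MEAN) / STD

def unnormalise(z):
    """Inverse mapping, used only for display."""
    return z * STD + MEAN

for p in (0.0, 0.5, 1.0):
    z = normalise(p)
    print(f"pixel {p:.1f} -> normalised {z:+.4f} -> back to {unnormalise(z):.4f}")
```

Black background pixels map to a small negative value, while bright strokes map to values well above 1.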

## 3. Model (architecture)

The architecture is defined here. This MLP has the shape:

`784 → hidden → hidden → 10`

Hidden layers use ReLU. Output is **logits** (no softmax here).

In [None]:
class MLP(nn.Module):
    def __init__(self, hidden=256):
        super().__init__()
        self.fc0 = nn.Linear(28*28, hidden)
        self.fc1 = nn.Linear(hidden, hidden)
        self.fc2 = nn.Linear(hidden, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)      # flatten: (B, 1, 28, 28) -> (B, 784)
        x = F.relu(self.fc0(x))
        x = F.relu(self.fc1(x))
        logits = self.fc2(x)
        return logits

net = MLP(hidden=256).to(device)
print(net)
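
To get a feel for the size of this model, the parameter count can be worked out by hand: each `nn.Linear(n_in, n_out)` holds `n_in * n_out` weights plus `n_out` biases. The sketch below is plain Python and its helper names are hypothetical. It also tabulates the smaller hidden sizes suggested in the experiments later on.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out

def mlp_params(hidden, n_in=28 * 28, n_classes=10):
    """Parameter count for the 784 -> hidden -> hidden -> 10 MLP."""
    return (
        linear_params(n_in, hidden)
        + linear_params(hidden, hidden)
        + linear_params(hidden, n_classes)
    )

for hidden in (256, 128, 64, 32):
    print(f"hidden={hidden:3d}: {mlp_params(hidden):,} parameters")
```

In PyTorch, the same number can be read off the model with `sum(p.numel() for p in net.parameters())`.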

## 4. Loss and optimiser

- Loss: `CrossEntropyLoss` (takes raw logits + integer class labels; it applies log-softmax internally)
- Optimiser: AdamW (Adam with decoupled weight decay; a good modern default)

Backprop is triggered by `loss.backward()` in the next section.

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(net.parameters(), lr=1e-3, weight_decay=1e-2)
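
`CrossEntropyLoss` works on logits because it combines log-softmax with the negative log-likelihood. Here is a plain-Python sketch of the per-example computation; it is an illustration of the formula, not PyTorch's implementation, and the function name is made up.

```python
import math

def cross_entropy(logits, target):
    """-log softmax(logits)[target], computed stably via log-sum-exp."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

# Uniform logits over 10 classes: the model is guessing
uniform = cross_entropy([0.0] * 10, target=3)

# One large logit on the correct class: confident and right
confident = cross_entropy([0.0] * 9 + [10.0], target=9)

print(f"uniform guess: {uniform:.4f}   confident + correct: {confident:.4f}")
```

A uniform guess over 10 classes gives a loss of `ln(10)`, roughly 2.30, so a freshly initialised network should start near that value. By default `nn.CrossEntropyLoss` averages this quantity over the batch.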

## 5. Training loop

Each training step follows the same flow:

- `optimizer.zero_grad()` (clear gradients from the previous step)
- forward pass
- compute loss
- `loss.backward()` (backprop)
- `optimizer.step()` (update weights)

We report validation accuracy each epoch.
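
To see what those steps do without any framework, here is a toy analogue in plain Python: fitting a single weight `w` so that `w * x` matches `y`. It is only a sketch under simplifying assumptions: the gradient is derived by hand rather than by autograd, and the update is plain gradient descent rather than AdamW.

```python
# Toy data generated from y = 3 * x
data = [(1.0, 3.0), (2.0, 6.0), (3.0, 9.0)]

w = 0.0    # the single "parameter"
lr = 0.05  # learning rate

for step in range(200):
    grad = 0.0                                   # analogue of optimizer.zero_grad()
    loss = 0.0
    for x, y in data:
        pred = w * x                             # forward pass
        loss += (pred - y) ** 2 / len(data)      # mean squared error
        grad += 2 * (pred - y) * x / len(data)   # analogue of loss.backward()
    w -= lr * grad                               # analogue of optimizer.step()

print(f"learned w = {w:.6f}, final loss = {loss:.2e}")
```

If the gradient were not reset at the start of each step, contributions from earlier steps would accumulate, which is why the real loop calls `optimizer.zero_grad()` first.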

In [None]:
def accuracy_from_logits(logits, y):
    pred = logits.argmax(dim=1)
    return (pred == y).float().mean().item()

@torch.no_grad()
def evaluate(loader):
    net.eval()
    total_loss = 0.0
    total_acc = 0.0
    n = 0

    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = net(x)
        loss = criterion(logits, y)

        bs = x.size(0)
        total_loss += loss.item() * bs
        total_acc  += accuracy_from_logits(logits, y) * bs
        n += bs

    return total_loss / n, total_acc / n

epochs = 5

for epoch in range(1, epochs + 1):
    net.train()
    running_loss = 0.0
    n = 0
    t0 = time.time()

    for x, y in tqdm(train_loader, desc=f"epoch {epoch}/{epochs}", leave=False):
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()

        # forward
        logits = net(x)
        loss = criterion(logits, y)

        # backprop
        loss.backward()

        # update
        optimizer.step()

        bs = x.size(0)
        running_loss += loss.item() * bs
        n += bs

    train_loss = running_loss / n
    val_loss, val_acc = evaluate(val_loader)

    print(f"Epoch {epoch:02d} | train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | val_acc={val_acc*100:.2f}% | {time.time()-t0:.1f}s")

print("Finished training")
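
Both `evaluate` and the training loop multiply each batch's loss or accuracy by the batch size `bs` before summing. This matters because the final batch is usually smaller than the rest. Here is a plain-Python sketch with made-up numbers showing how a naive mean of batch means differs from the sample-weighted mean.

```python
def weighted_mean(values, sizes):
    """Average per-batch values, weighting each by its number of samples."""
    return sum(v * s for v, s in zip(values, sizes)) / sum(sizes)

# Hypothetical per-batch accuracies; the last batch is small and does badly
batch_acc = [1.0, 1.0, 0.5]
batch_sizes = [128, 128, 16]

naive = sum(batch_acc) / len(batch_acc)
weighted = weighted_mean(batch_acc, batch_sizes)

print(f"naive mean of batches: {naive:.4f}")
print(f"sample-weighted mean:  {weighted:.4f}")
```

The weighted value is the true fraction of correctly classified samples; the naive one over-counts the small batch.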

## 6. Test set performance

In [None]:
test_loss, test_acc = evaluate(test_loader)
print(f"TEST | loss={test_loss:.4f} | acc={test_acc*100:.2f}%")

## 7. Softmax (optional)

Training does **not** need softmax here: `CrossEntropyLoss` applies log-softmax to the logits internally. If you want probabilities, apply softmax to logits at inference time.

In [None]:
@torch.no_grad()
def predict_proba(x):
    net.eval()
    logits = net(x.to(device))
    return torch.softmax(logits, dim=1).cpu()

x, y = next(iter(test_loader))
probs = predict_proba(x[:8])
pred = probs.argmax(dim=1)

print("pred:", pred.numpy())
print("true:", y[:8].numpy())
print("probs[0]:", probs[0].numpy())
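
Softmax preserves the ordering of the logits, so the predicted class is the same whether you take the argmax of the logits or of the probabilities. This is why `accuracy_from_logits` never needs softmax. Here is a plain-Python sketch with arbitrary example logits; the helper names are illustrative.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def argmax(values):
    """Index of the largest element."""
    return max(range(len(values)), key=values.__getitem__)

logits = [2.0, -1.0, 0.5, 4.0]
probs = softmax(logits)

print("probs:", [round(p, 4) for p in probs])
print("argmax(logits):", argmax(logits), " argmax(probs):", argmax(probs))
```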

## 8. Experiments

Try these (one at a time):

1. Change `hidden=256` to 128, 64, 32. When does accuracy start to drop?
2. Change optimiser to SGD:
   - `optim.SGD(..., lr=0.01, momentum=0.9)`
3. Change activation to `tanh` by replacing `F.relu` with `torch.tanh`.
4. Change epochs to 1, 3, 10 and compare.

## 9. Pixel permutations (key conceptual experiment)

Fully connected networks **do not use spatial structure**.
They treat the input as a vector, not as an image.

So, does **any fixed permutation of the pixels** give essentially the *same performance*? To test this fairly:

- apply the permutation consistently to every image
- use the same permutation for train, validation, and test

This is very different from architectures like CNNs and patch-based Transformers, which *do* rely on
local spatial structure.
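
The reason an MLP is indifferent to pixel order can be seen in a single linear layer: if the inputs are shuffled and the columns of the weight matrix are shuffled the same way, the output is unchanged. So for every network trained on original images there is an equally good one for permuted images. Here is a plain-Python sketch with a tiny layer; the sizes and names are arbitrary, chosen for illustration.

```python
import math
import random

def linear(W, b, x):
    """y = W x + b for lists of floats."""
    return [sum(w * v for w, v in zip(row, x)) + bias for row, bias in zip(W, b)]

rng = random.Random(0)
n_in, n_out = 6, 3

W = [[rng.uniform(-1, 1) for _ in range(n_in)] for _ in range(n_out)]
b = [rng.uniform(-1, 1) for _ in range(n_out)]
x = [rng.uniform(-1, 1) for _ in range(n_in)]

# One fixed permutation of the input positions
perm = list(range(n_in))
rng.shuffle(perm)

x_perm = [x[j] for j in perm]                   # shuffled "pixels"
W_perm = [[row[j] for j in perm] for row in W]  # columns shuffled the same way

out = linear(W, b, x)
out_perm = linear(W_perm, b, x_perm)

print("original:", [round(v, 6) for v in out])
print("permuted:", [round(v, 6) for v in out_perm])
```

A convolution has no such counterpart, because its small shared kernels only ever see neighbouring pixels.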

### Step 1: Create a fixed random pixel permutation

We create a single random permutation of the 784 input dimensions and
keep it fixed for the entire experiment.

In [None]:
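# Create a fixed random permutation of pixel indices
# (seeded generator, so the same permutation is produced on every run)
perm = torch.randperm(28 * 28, generator=torch.Generator().manual_seed(0))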

def permute_pixels(x):
    # x: (B, 1, 28, 28)
    x = x.view(x.size(0), -1)      # (B, 784)
    x = x[:, perm]                 # apply permutation
    x = x.view(-1, 1, 28, 28)      # reshape back to image form
    return x

### Step 2: Visualise original vs permuted images

The permuted images no longer look like digits to *us*,
but to an MLP they are just vectors.
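
A permutation only reorders values, so no information is lost: the original image can always be recovered with the inverse permutation. Here is a plain-Python sketch on a short stand-in vector; the variable names are illustrative.

```python
import random

rng = random.Random(0)
n = 8
x = [float(i) for i in range(n)]  # stand-in for one flattened image

perm = list(range(n))
rng.shuffle(perm)

# Forward: position i of the permuted vector holds x[perm[i]]
x_perm = [x[j] for j in perm]

# Inverse: inverse[perm[i]] = i
inverse = [0] * n
for i, j in enumerate(perm):
    inverse[j] = i

x_back = [x_perm[k] for k in inverse]

print("permuted: ", x_perm)
print("recovered:", x_back)
```

With a PyTorch `perm` tensor, `torch.argsort(perm)` yields the same inverse.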

In [None]:
# Visualise original vs permuted images
x, y = next(iter(train_loader))

x_perm = permute_pixels(x)

def show_side_by_side(x1, x2, labels, n=8):
    import torchvision
    x1_disp = x1 * 0.3081 + 0.1307
    x2_disp = x2 * 0.3081 + 0.1307

    grid1 = torchvision.utils.make_grid(x1_disp[:n], nrow=n)
    grid2 = torchvision.utils.make_grid(x2_disp[:n], nrow=n)

    plt.figure(figsize=(12,4))
    plt.subplot(1,2,1)
    plt.title("Original")
    plt.imshow(grid1.permute(1,2,0).numpy())
    plt.axis("off")

    plt.subplot(1,2,2)
    plt.title("Permuted pixels")
    plt.imshow(grid2.permute(1,2,0).numpy())
    plt.axis("off")

    plt.show()
    print("labels:", labels[:n].numpy())

show_side_by_side(x, x_perm, y)

### Step 3: Challenge

Modify the training code so that **all inputs are permuted** before being
passed to the network.

Hints:
- Apply `permute_pixels(x)` inside the training loop
- Apply the same permutation in `evaluate(...)`
- Do *not* change the model architecture

Questions to think about:
1. Does training accuracy change?
2. Does validation/test accuracy change?
3. Why does this work for MLPs but fail badly for CNNs?
4. What does this say about inductive bias?

**Takeaway**

For fully connected networks, pixel order is arbitrary, as long as it is the same for every image.
All spatial meaning comes from the *data representation*, not the model.

This experiment is one of the cleanest ways to see the difference between
*representation* and *architecture*.