# Feed-Forward Attention Classifier for CIFAR-10

This notebook builds an **attention mechanism from scratch** and uses it inside
a small image classifier for the **CIFAR-10** dataset
(10 classes, 32 × 32 colour images, 3 RGB channels).

> **Goal.** By the end of this notebook you should be able to explain how
> queries, keys, and values work, why we scale by $\sqrt{d_k}$, what a causal
> mask does, and how attention layers can be stacked into a classifier.


### Work to complete this notebook

You will find that, with a limited number of epochs and a small model, this classifier will likely saturate at 60-70% validation accuracy.

Expanding the CNN to provide more tokenized data to the attention portion of the model, and increasing the number of attention layers, may both contribute to a better design. That is left as an open problem: see if you can get better than 80% validation accuracy. (Hint: if you explore this, segmenting the input image into patches that feed the attention layers, as a Vision Transformer does, is likely the best approach.)

As in previous weeks, this notebook is meant to run from top → bottom.

The main takeaway is familiarity with attention, but the notebook also serves as a nice example of a classification exercise that is difficult to "solve" with different model architectures.

Again, you are to complete the `## FINISH_ME ##` sections of code in the notebook to make it work.
**The first four problems are due as the second part of Homework 3, by 9:30am on Friday 27 February.**

| <p align='left'> Problem  (by Section number)                | <p align='left'> Marks possible |  <p align='left'> Marks awarded |
| ------------------------------------- | --- | --- |
| <p align='left'> 1. Finish the Attention class forward calculations  | <p align='left'> 2 | |
| <p align='left'> 2. Finish assembling the AttentionClassifier model  | <p align='left'> 1 | |
| <p align='left'> 3. Finish loading the input data of CIFAR10 | <p align='left'> 1 | |
| <p align='left'> 4. Complete the train 1 step method and evaluate method | <p align='left'> 1 | |
| <p align='left'> 5.1 Plot the train/validation and calculate the test performance of the model  | <p align='left'> -- | -- |
| <p align='left'> 5.2 Calculate the per-class accuracy, interesting to see where the classifier struggles | <p align='left'> -- | -- |
| <p align='left'> 6. Plot the distribution of attention weights that light-up for a given image  | <p align='left'> -- | -- |
| <p align='left'> **Total** | <p align='left'> max **5** | |


## 0 — Imports

We need **PyTorch** for the model and training, **torchvision** for the CIFAR-10
dataset and image transforms, and **matplotlib** for plotting.

In [None]:
# Reproducibility: set global seeds and deterministic flags
# This makes training runs repeatable across restarts and machines.
import os
import random
import numpy as np

# Technical, but important for CUDA determinism with cuBLAS on CUDA >= 10.2:
# Must be set BEFORE any cuBLAS operations (GEMMs) are invoked.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

SEED = 42

def seed_everything(seed: int = 42):
    """
    Seed Python, NumPy, and PyTorch RNGs.
    Enable deterministic algorithms where possible.
    Note: Some GPU ops may still be non-deterministic depending on hardware.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)

    # torch is imported late, inside this function, so that the environment
    # variables and the Python/NumPy seeds above are set before the library loads.
    import torch

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Make cuDNN deterministic (may reduce performance slightly)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

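    # Strict deterministic mode: ops without a deterministic implementation
    # raise a RuntimeError when they run. The try/except below only covers
    # PyTorch builds where this call itself is unavailable.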
    try:
        torch.use_deterministic_algorithms(True)
    except Exception:
        pass

seed_everything(SEED)
print(f"Reproducibility setup complete. Global seed = {SEED}")
print(f"CUBLAS_WORKSPACE_CONFIG = {os.environ.get('CUBLAS_WORKSPACE_CONFIG')}")

# Tip: For fully deterministic DataLoader shuffling across runs,
# pass a seeded generator to DataLoader (e.g., generator=torch.Generator().manual_seed(SEED)).

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import math
import matplotlib.pyplot as plt
from tqdm.auto import tqdm           # progress bars for training loops

# Automatically use the GPU if one is available, otherwise fall back to CPU.
# Training on a GPU is ~10–50× faster for this model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 1 — The Attention Layer (from scratch)

Attention lets every position in a sequence "look at" every other position and
decide how much information to take from each.  We implement
**single-head scaled dot-product attention** in five clear steps.

Given an input sequence $X \in \mathbb{R}^{B \times N \times D_{\text{in}}}$
(a batch of $B$ sequences, each with $N$ tokens of dimension $D_{\text{in}}$):

---

**Step 1 — Project to Q, K, V.**
Three learned weight matrices create *queries*, *keys*, and *values*:

$$Q = XW_Q, \quad K = XW_K, \quad V = XW_V \qquad \bigl(W_\cdot \in \mathbb{R}^{D_{\text{in}} \times d_k}\bigr)$$

- **Query** = "what am I looking for?"
- **Key** = "what do I contain?"
- **Value** = "what information do I provide?"

**Step 2 — Compute raw scores.**
The dot product $QK^\top$ measures similarity between every query–key pair.
We scale by $\sqrt{d_k}$ to keep the variance stable (without this, large
$d_k$ pushes softmax into a near-one-hot regime and gradients vanish):

$$\text{scores} = \frac{Q\,K^\top}{\sqrt{d_k}} \;\in\; \mathbb{R}^{N \times N}$$
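
The effect of the scaling can be checked numerically. The sketch below uses only the Python standard library (not the notebook's PyTorch code) and assumes that query and key entries are independent with zero mean and unit variance. Under that assumption the raw dot product has variance close to $d_k$, and dividing by $\sqrt{d_k}$ brings it back to about 1. The helper name `dot_product_variance` is illustrative.

```python
import math
import random


def dot_product_variance(d_k, n_samples=2000, scale=False, seed=0):
    """Estimate Var(q . k) for random q, k with i.i.d. N(0, 1) entries."""
    rng = random.Random(seed)
    dots = []
    for _ in range(n_samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        dot = sum(qi * ki for qi, ki in zip(q, k))
        if scale:
            dot /= math.sqrt(d_k)          # the 1 / sqrt(d_k) factor from Step 2
        dots.append(dot)
    mean = sum(dots) / n_samples
    return sum((d - mean) ** 2 for d in dots) / (n_samples - 1)


raw_var = dot_product_variance(64)                 # close to d_k = 64
scaled_var = dot_product_variance(64, scale=True)  # close to 1
print(f"unscaled: {raw_var:.1f}   scaled: {scaled_var:.2f}")
```

Scores with a spread of roughly $\sqrt{64} = 8$ would make softmax concentrate almost all weight on a single key, which is the near-one-hot regime mentioned above.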

**Step 3 — Apply a causal (forward) mask.**
We set "future" entries to $-\infty$ so they become **exactly zero** after
softmax — the model can only attend to earlier (or same) positions:

$$\text{scores}[i,j] \;\leftarrow\; \begin{cases} \text{scores}[i,j] & \text{if } j \le i \\[2pt] -\infty & \text{if } j > i \end{cases}$$

**Step 4 — Softmax.**
Row-wise softmax converts scores into a probability distribution (rows sum to 1).
$$\text{weights}_{\text{attention}} = \text{softmax}(\text{scores})$$

**Step 5 — Output is weighted sum of V.**
Each output token is then a *weighted average* of the value vectors it attends to:

$$\text{output} = \text{weights}_{\text{attention}} \, V$$
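
Steps 3 to 5 can be traced by hand on a toy sequence. The sketch below is plain Python rather than the notebook's PyTorch implementation: it assumes three tokens, a made-up score matrix, and one scalar value per token. The name `masked_softmax_rows` is illustrative.

```python
import math

NEG_INF = float("-inf")


def masked_softmax_rows(scores):
    """Apply a causal mask (block j > i), then a row-wise softmax."""
    weights = []
    for i, row in enumerate(scores):
        masked = [s if j <= i else NEG_INF for j, s in enumerate(row)]
        m = max(masked)                              # subtract the max for stability
        exps = [math.exp(s - m) for s in masked]     # exp(-inf) evaluates to 0.0
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights


scores = [[0.0, 2.0, 1.0],
          [1.0, 1.0, 3.0],
          [0.5, 0.5, 0.5]]
values = [10.0, 20.0, 30.0]    # one scalar "value vector" per token

weights = masked_softmax_rows(scores)
output = [sum(w * v for w, v in zip(row, values)) for row in weights]

for row in weights:
    print([round(w, 3) for w in row])
print([round(o, 3) for o in output])
```

The first token can only attend to itself, so its output equals its own value (10.0). The last row sees all three tokens with equal scores and returns their mean (20.0).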

In [None]:
class Attention(nn.Module):
    """
    Single-head scaled dot-product attention with an explicit causal mask.

    This is the core building block of transformer-style models.
    We implement it step-by-step so every mathematical operation is visible.

    Parameters
    ----------
    d_in    : int – dimension of each input token
    d_k     : int – internal dimension for queries, keys, and values
    seq_len : int – sequence length (needed to pre-build the mask)
    """

    def __init__(self, d_in: int, d_k: int, seq_len: int = 16):
        super().__init__()
        self.d_k = d_k

        # ----- Learned projection matrices (no bias, to keep the maths clean) -----
        # Each is a matrix W ∈ ℝ^{d_in × d_k}.
        # nn.Linear(d_in, d_k, bias=False) multiplies input by W^T internally.
        self.W_q = nn.Linear(d_in, d_k, bias=False)   # Query  projection
        self.W_k = nn.Linear(d_in, d_k, bias=False)   # Key    projection
        self.W_v = nn.Linear(d_in, d_k, bias=False)   # Value  projection

        # ----- Build the causal mask ONCE at initialisation time -----
        #
        #   We want an N×N matrix where:
        #       mask[i, j] =  0     if j ≤ i   (position j is in the "past" — allowed)
        #                  = -∞     if j > i   (position j is in the "future" — blocked)
        #
        #   Adding -∞ to a score before softmax forces that entry to exp(-∞) = 0,
        #   so the model cannot attend to future positions.
        #
        #   torch.triu(matrix, diagonal=1) extracts the strictly upper-triangular
        #   part — exactly the "future" positions.  This is equivalent to:
        #
        #       for i in range(N):
        #           for j in range(N):
        #               if j > i:
        #                   mask[i, j] = -inf
        #
        mask = torch.zeros(seq_len, seq_len)
        mask = mask + torch.triu(
            torch.full((seq_len, seq_len), float('-inf')),   # fill with -∞
            diagonal=1                                        # above main diagonal
        )

        # register_buffer keeps the mask on the same device as the model
        # (it moves to GPU automatically) but it is NOT a learnable parameter.
        # Look up how to use it (and how to access the 'mask' buffer afterwards) in the PyTorch documentation.
        self.register_buffer('mask', mask)

    # -----------------------------------------------------------------
    def forward(self, x):
        """
        Forward pass: implements the five steps described in the markdown above.

        x : (batch, seq_len, d_in)   – input token embeddings
        returns : (batch, seq_len, d_k) – attended output
        """

        # ── Step 1: Project input into Q, K, V ──────────────────────
        Q = self.W_q(x)                          # (B, N, d_k)  — queries
        K = ## FINISH_ME ##                      # (B, N, d_k)  — keys
        V = ## FINISH_ME ##                      # (B, N, d_k)  — values

        # ── Step 2: Raw attention scores ─────────────────────────────
        # scores[b, i, j] = how much query i attends to key j
        K_T = K.transpose(-2, -1)                # (B, d_k, N)  — transpose keys for dot product
        scores = ## FINISH_ME ##  matmul step    # (B, N, N) — dot products
        scores = ## FINISH_ME ##  scale down     # scale to stabilise softmax

        # ── Step 3: Apply the causal mask ────────────────────────────
        # Broadcasting: mask is (N, N), scores is (B, N, N) — mask is
        # added identically across every sample in the batch.
        # For a classifier this mask is not strictly necessary: image tokens have no temporal
        # order, so nothing needs to be hidden from the model.  However, we include it here for
        # pedagogical clarity and to show how masks are applied in general (e.g. for autoregressive language models).
        scores = ## FINISH_ME ##  add mask to scores (broadcasting)  # (B, N, N) — masked scores

        # ── Step 4: Softmax → attention weights ──────────────────────
        weights_attn = torch.softmax(scores, dim=-1)  # (B, N, N)  — rows sum to 1

        # ── Step 5: Weighted sum of values ───────────────────────────────
        # This is the "attended output" — each output vector is a weighted average of the value vectors,
        # where the weights are determined by the attention mechanism.
        output  = ## FINISH_ME ## weighted value out  # (B, N, d_k) — weighted average

        return output

### Quick sanity check

Let's pass a tiny random tensor through the `Attention` layer to make sure the
shapes are correct before we build anything bigger.

We create a batch of 2 sequences, each with 4 tokens of dimension 8, and
project into an internal dimension of 6.

In [None]:
# Tiny test:  batch=2, seq_len=4, d_in=8, d_k=6
attn = Attention(d_in=8, d_k=6, seq_len=4)
x_test = torch.randn(2, 4, 8)         # random input
y_test = attn(x_test)                  # forward pass

print("Input shape :", x_test.shape)   # expect (2, 4, 8)
print("Output shape:", y_test.shape)   # expect (2, 4, 6)  — d_in → d_k

## 2 — The Classifier

Now we wrap the attention layer inside a complete image classifier.

### Why a CNN stem?

Raw pixels are not great tokens — a single pixel has no spatial context,
and attention is **permutation-equivariant** (it has no built-in notion of
"left" or "right").  A small convolutional stem solves two problems:

1. **Local feature extraction**: convolutions detect edges, textures, and
   colour gradients before attention sees the data.
2. **Translation robustness**: a cat near the top-left should be classified
   the same as a cat near the bottom-right. Convolutions are translation-equivariant,
   and with pooling the features become approximately translation-invariant.

The CNN reduces the 32 × 32 image to an 8 × 8 feature map with 64 channels.
We then **reshape** that map into a sequence of **64 tokens** (one per spatial
position), each of dimension **64** — ready for the attention layers.
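
The spatial sizes quoted above follow from the standard output-size formula for convolution and pooling layers, $\lfloor (n + 2p - k)/s \rfloor + 1$. A minimal sketch of that arithmetic, assuming no dilation (the helper name `conv_out` is illustrative):

```python
def conv_out(size, kernel, padding=0, stride=1):
    """Spatial output size of a convolution or pooling layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 32
size = conv_out(size, kernel=3, padding=1)     # 3×3 conv, pad 1  -> 32
size = conv_out(size, kernel=2, stride=2)      # 2×2 max-pool     -> 16
size = conv_out(size, kernel=3, padding=1)     # 3×3 conv, pad 1  -> 16
size = conv_out(size, kernel=2, stride=2)      # 2×2 max-pool     -> 8

channels = 64
num_tokens = size * size                       # one token per spatial position
print(f"feature map: {channels} x {size} x {size}  ->  {num_tokens} tokens of dim {channels}")
```

Any convolution padded to preserve size (for example 5×5 with padding 2) leaves the spatial dimensions unchanged, so only the two pooling layers shrink the map.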

### Architecture diagram

```
Input image   (3 × 32 × 32)                   ← 3 RGB channels
      │
  Conv2d 3 → 3, 5×5, pad 2  +  BatchNorm      (3 × 32 × 32)
      │
  Conv2d 3 → 32, 3×3, pad 1  +  ReLU          (32 × 32 × 32)
      │
  MaxPool 2×2                                  (32 × 16 × 16)
      │
  Conv2d 32 → 64, 3×3, pad 1  +  BatchNorm  +  ReLU   (64 × 16 × 16)
      │
  MaxPool 2×2                                  (64 × 8  × 8)
      │
  Reshape → 64 tokens × 64 dims               ← one token per spatial location
      │
  Attention Layer 1  (d_in=64, d_k=64)
      │
  Residual connection + LayerNorm + SiLU
      │
  Attention Layer 2  (d_in=64, d_k=64)
      │
  Residual connection + LayerNorm + SiLU
      │
  Mean-pool over the 64 tokens → 64-dim vector
      │
  Linear 64 → 10  (one logit per class)
```

### Design choices

| Choice | Reason |
|---|---|
| **Residual connections** (`x = x + attn(x)`) | Lets gradients flow directly through skip paths — essential for training deeper networks. |
| **LayerNorm** | Normalises each token independently; stabilises training and speeds up convergence. |
| **SiLU activation** ($x \cdot \sigma(x)$) | Smooth, non-monotonic non-linearity; works well in transformer-style architectures. |
| **Mean-pool** (not max-pool) | Aggregates information from *all* tokens equally before classification. |
| **Weight initialisation** | Kaiming for conv (matches ReLU), Xavier for linear (good general default). |
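
To make the first three rows of the table concrete, the sketch below applies a residual connection, LayerNorm, and SiLU to a single four-feature token in plain Python. It assumes LayerNorm at its initial state (weight 1, bias 0) and uses a made-up vector in place of the attention output. The function names are illustrative, not part of the notebook.

```python
import math


def layer_norm(token, eps=1e-5):
    """Normalise one token to zero mean and unit variance over its features."""
    mean = sum(token) / len(token)
    var = sum((t - mean) ** 2 for t in token) / len(token)   # biased variance
    return [(t - mean) / math.sqrt(var + eps) for t in token]


def silu(x):
    """SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


token = [2.0, 4.0, 6.0, 8.0]
attn_out = [0.5, -0.5, 0.5, -0.5]    # stand-in for attn(x); values are made up

residual = [t + a for t, a in zip(token, attn_out)]     # x + attn(x)
normed = layer_norm(residual)
activated = [silu(v) for v in normed]
print([round(v, 3) for v in activated])
```

Because the statistics are computed per token, changing one token never alters the normalisation of another.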

In [None]:
class AttentionClassifier(nn.Module):
    """
    Image classifier:  CNN stem  →  two Attention layers  →  linear head.

    Parameters
    ----------
    num_classes : int  – number of output classes (10 for CIFAR-10)
    d           : int  – token / channel dimension throughout the model
    """

    def __init__(self, num_classes: int = 10, d: int = 64):
        super().__init__()
        self.d = d

        # ── CNN stem ─────────────────────────────────────────────────
        # Three conv layers (with BatchNorm and ReLU) and two 2×2 max-pooling steps.
        # Input: (B, 3, 32, 32)  →  Output: (B, d, 8, 8)
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 3, kernel_size=5, padding=2),      # 3 → 3 channels
            nn.BatchNorm2d(3),                              # normalise after conv
            nn.Conv2d(3, 32, kernel_size=3, padding=1),     # 3 → 32 channels
            nn.ReLU(),                                      # non-linearity
            nn.MaxPool2d(2),                                # halve spatial dims → 16×16
            nn.Conv2d(32, d, kernel_size=3, padding=1),     # 32 → d channels (64 by default)
            nn.BatchNorm2d(d),                              # normalise after conv
            nn.ReLU(),                                      # non-linearity
            nn.MaxPool2d(2),                                # halve again → 8×8
        )

        # The 8×8 feature map will become 64 tokens of dimension d
        self.seq_len = 8 * 8   # = 64 tokens

        # ── Two attention layers ─────────────────────────────────────
        self.attn1 = Attention(d_in=d, d_k=d, seq_len=self.seq_len)
        self.attn2 = Attention(d_in=d, d_k=d, seq_len=self.seq_len)

        # ── Layer normalisation (one per attention block) ────────────
        # LayerNorm normalises across the feature dimension of each token,
        # keeping different tokens independent of each other.
        self.norm1 = nn.LayerNorm(d)
        self.norm2 = nn.LayerNorm(d)

        # ── Classification head ──────────────────────────────────────
        # Maps the 64-dimensional pooled vector to 10 class logits.
        self.classifier = nn.Linear(d, num_classes)

        # Apply explicit weight initialisation
        self._init_weights()

    # -----------------------------------------------------------------
    def _init_weights(self):
        """
        Set weights to sensible starting values.

        Good initialisation prevents early training instability:
          • Conv2d    → Kaiming (He) normal    – matches ReLU's nonlinearity
          • Linear    → Xavier (Glorot) normal – good general-purpose default
          • LayerNorm → weight = 1, bias = 0   – starts as the identity transform
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    # -----------------------------------------------------------------
    def forward(self, x):
        """
        Full forward pass:  image → CNN → tokens → attention → class logits.

        x       : (B, 3, 32, 32) – batch of CIFAR-10 images
        returns : (B, 10)         – raw class scores (logits)
        """
        B = x.shape[0]

        # ── 1. CNN stem ──────────────────────────────────────────────
        x = self.cnn(x)                              # (B, 64, 8, 8)

        # ── 2. Reshape feature map → token sequence ──────────────────
        # We treat each of the 8×8 = 64 spatial positions as a "token"
        # with d = 64 features, giving a sequence the attention layer
        # can process.
        x = ## FINISH_ME ## Flatten CNN output       # (B, 64, 64) — channels first NB: view(..., -1) reduces by 1-dim
        x = x.transpose(1, 2)                        # (B, 64, 64) — tokens first

        # ── 3. Attention block 1 ─────────────────────────────────────
        x = x + self.attn1(x)                        # residual connection
        x = self.norm1(x)                            # normalise
        x = torch.nn.functional.silu(x)              # SiLU non-linearity

        # ── 4. Attention block 2 ─────────────────────────────────────
        x = ## FINISH_ME ##                          # residual connection
        x = ## FINISH_ME ##                          # normalise
        x = torch.nn.functional.silu(x)              # SiLU non-linearity

        # ── 5. Aggregate tokens → single vector ─────────────────────
        x = x.mean(dim=1)                            # (B, 64) — mean over tokens

        # ── 6. Classify ─────────────────────────────────────────────
        logits = self.classifier( ## FINISH_ME ##    # (B, 10)
        return logits

### Sanity check — shapes end-to-end

Before training, we pass a dummy batch through the full model to confirm that
every layer's output shape is correct and count the total number of learnable
parameters.
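
The expected parameter count can be worked out by hand as a cross-check. The sketch below assumes the default `d = 64` and the layers exactly as listed in `AttentionClassifier.__init__`; the helper names are illustrative. BatchNorm running statistics and the attention mask are buffers, so they are not counted.

```python
def conv2d_params(c_in, c_out, k, bias=True):
    """Weights (c_in * c_out * k * k) plus one bias per output channel."""
    return c_in * c_out * k * k + (c_out if bias else 0)


def linear_params(d_in, d_out, bias=True):
    """Weights (d_in * d_out) plus one bias per output feature."""
    return d_in * d_out + (d_out if bias else 0)


d, num_classes = 64, 10

counts = {
    "conv 3->3 (5x5)":  conv2d_params(3, 3, 5),
    "batchnorm (3)":    2 * 3,                  # weight + bias per channel
    "conv 3->32 (3x3)": conv2d_params(3, 32, 3),
    "conv 32->d (3x3)": conv2d_params(32, d, 3),
    "batchnorm (d)":    2 * d,
    "attention x2":     2 * 3 * linear_params(d, d, bias=False),   # W_q, W_k, W_v
    "layernorm x2":     2 * 2 * d,              # weight + bias per feature
    "classifier":       linear_params(d, num_classes),
}
total = sum(counts.values())
for name, n in counts.items():
    print(f"{name:<18} {n:>7,}")
print(f"{'total':<18} {total:>7,}")
```

If the total printed by the sanity-check cell differs from this figure, the model definition no longer matches the layer list assumed here.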

In [None]:
model = AttentionClassifier().to(device)

# Create a fake batch of 4 CIFAR-10-shaped images (3 channels, 32×32)
dummy = torch.randn(4, 3, 32, 32, device=device)
out   = model(dummy)

print("Model output shape:", out.shape) # expect (4, 10)
print("Total parameters  :", f"{sum(p.numel() for p in model.parameters()):,}")

## 3 — Load CIFAR-10 (train / validation / test)

[CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) contains
50 000 training and 10 000 test images (32 × 32 pixels, 3 RGB channels)
across 10 everyday object classes.

### Pre-processing

The transforms below apply only `ToTensor`, which scales every channel to the
range [0, 1]; no explicit `Normalize` step is used.  Keeping all channels on
the same scale helps the optimiser converge, and the `BatchNorm2d` layer near
the start of the CNN stem then rescales the activations per channel.

### Train / validation / test split

We reserve **5 000** of the 50 000 training images as a **validation set**
(for monitoring over-fitting during training).  The test set is only
touched **once**, at the very end, for a fair evaluation.

### Data augmentation (training only)

Augmentation artificially increases the effective size and diversity of
the training set, which reduces over-fitting:

| Transform | Effect |
|---|---|
| `RandomHorizontalFlip()` | Mirror left ↔ right with 50 % probability |
| `RandomRotation(5)` | Rotate by a random angle in [−5°, +5°] |
| `RandomCrop(32, padding=4)` | Pad 4 px on each side, then crop back to 32 × 32 at a random offset |

> **Note:** Augmentation is applied *only* to the training set.
> Validation and test images use only `ToTensor`.
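
What the flip and the padded crop do to the pixels can be seen on a toy image. The sketch below is plain Python, not torchvision: it assumes a square single-channel image stored as a list of rows and zero padding. The function names are illustrative.

```python
import random


def hflip(img):
    """Mirror a 2-D image (list of rows) left <-> right."""
    return [row[::-1] for row in img]


def random_crop(img, size, padding, rng):
    """Zero-pad on every side, then cut a size x size window at a random offset."""
    width = len(img[0]) + 2 * padding
    blank = [0] * width
    padded = ([blank[:] for _ in range(padding)]
              + [[0] * padding + list(row) + [0] * padding for row in img]
              + [blank[:] for _ in range(padding)])
    max_offset = len(padded) - size            # assumes a square image
    top = rng.randint(0, max_offset)
    left = rng.randint(0, max_offset)
    return [row[left:left + size] for row in padded[top:top + size]]


rng = random.Random(0)
img = [[1, 2, 3, 4],
       [5, 6, 7, 8],
       [9, 10, 11, 12],
       [13, 14, 15, 16]]                       # toy 4 x 4 "image"

flipped = hflip(img)
cropped = random_crop(img, size=4, padding=1, rng=rng)
print(flipped[0])                              # [4, 3, 2, 1]
print(len(cropped), len(cropped[0]))
```

With `padding=4` and a 32 × 32 crop, the window offset ranges over 0 to 8 pixels, which amounts to shifting the image by up to 4 pixels in any direction.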

In [None]:
# ═══════════════════════════════════════════════════════════════════════
# Transforms
# ═══════════════════════════════════════════════════════════════════════

# There are various transforms which we can add to the training pipeline
#
#    transforms.RandomHorizontalFlip(),          # 50 % chance of mirror
#    transforms.RandomRotation(5),                # random ±5° rotation
#    transforms.RandomCrop(32, padding=4),       # random translation (±4 px)
#    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1), # random brightness, contrast, saturation, hue
#

# Training pipeline: augment data on load, then convert to tensor.
# Augmentation is applied randomly each time an image is loaded, so the
# model never sees the exact same version of an image twice.
train_transform = transforms.Compose([
    
    ## FINISH_ME ##   Add at least 2 augmentation transforms to apply during training

    transforms.ToTensor()                      # PIL → tensor, scale to [0, 1]
])

# Validation / test pipeline: no augmentation — just convert to tensor.
eval_transform = transforms.Compose([
    transforms.ToTensor()
])

# ═══════════════════════════════════════════════════════════════════════
# Datasets
# ═══════════════════════════════════════════════════════════════════════

# Load the full training set with eval_transform first, because the
# validation split should NOT be augmented.
full_train_set = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=eval_transform)

test_set = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=eval_transform)

# Reproducible 45 000 / 5 000 split (seeded for consistency across runs).
train_subset, val_subset = torch.utils.data.random_split(
    full_train_set, [45000, 5000],
    generator=torch.Generator().manual_seed(SEED))


class AugmentedSubset(torch.utils.data.Dataset):
    """
    Thin wrapper that overrides the transform on a Subset.

    The underlying dataset uses eval_transform (no augmentation).
    This wrapper converts each image back to PIL, then applies the
    augmented training transform instead.
    """
    def __init__(self, subset, transform):
        self.subset    = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        img, label = self.subset[idx]                # img is already a tensor
        img = transforms.ToPILImage()(img)           # tensor → PIL image
        img = self.transform(img)                    # apply augmented pipeline
        return img, label


# Wrap only the training subset with augmentation.
train_set = AugmentedSubset(train_subset, train_transform)
val_set   = val_subset    # no augmentation — uses eval_transform from the parent

# ═══════════════════════════════════════════════════════════════════════
# Data loaders
# ═══════════════════════════════════════════════════════════════════════
# num_workers  : number of processes for loading data in parallel
#                set to 2 if running entirely on a CPU; try 4 if running on GPU
# pin_memory   : faster CPU → GPU transfer on CUDA devices
# persistent_workers : keep worker processes alive between epochs
loader_kwargs = dict(num_workers=2, pin_memory=True, persistent_workers=True)

# Seeded generator → deterministic DataLoader shuffling across runs
g = torch.Generator().manual_seed(SEED)

train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=128, shuffle=True, generator=g, **loader_kwargs)

val_loader = torch.utils.data.DataLoader(
    val_set, batch_size=256, shuffle=False, generator=g, **loader_kwargs)

test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=256, shuffle=False, **loader_kwargs)

# Human-readable class names (in label-index order)
classes = ('Plane', 'Car', 'Bird', 'Cat', 'Deer',
           'Dog', 'Frog', 'Horse', 'Ship', 'Truck')

print(f"Training   samples : {len(train_set):,}  (with augmentation)")
print(f"Validation samples : {len(val_set):,}")
print(f"Test       samples : {len(test_set):,}")

### Explore the dataset

Before training, it's good practice to **look at your data**.  We'll check:

1. **Sample images** — one from each class, so we know what we're dealing with.
2. **Class balance** — are all classes equally represented?
3. **Pixel statistics** — verify the per-channel means and standard deviations
   of the input images.
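
Per-channel statistics can be accumulated in a single pass by keeping running sums of the pixel values and of their squares. A minimal sketch in plain Python, assuming images stored as nested lists shaped `[C][H][W]` with values in [0, 1]; the toy data and the name `channel_stats` are illustrative.

```python
import math


def channel_stats(images):
    """Per-channel mean and std over a list of images shaped [C][H][W]."""
    n_channels = len(images[0])
    count = 0
    sums = [0.0] * n_channels
    sq_sums = [0.0] * n_channels
    for img in images:
        for c, channel in enumerate(img):
            for row in channel:
                sums[c] += sum(row)
                sq_sums[c] += sum(p * p for p in row)
        count += len(img[0]) * len(img[0][0])       # pixels per channel
    means = [s / count for s in sums]
    # max(0, ...) guards against tiny negative values from rounding
    stds = [math.sqrt(max(0.0, sq / count - m * m)) for sq, m in zip(sq_sums, means)]
    return means, stds


# Two toy 3-channel, 2 x 2 "images" (made-up values).
images = [
    [[[0.0, 1.0], [0.0, 1.0]], [[0.5, 0.5], [0.5, 0.5]], [[0.2, 0.4], [0.6, 0.8]]],
    [[[1.0, 0.0], [1.0, 0.0]], [[0.5, 0.5], [0.5, 0.5]], [[0.8, 0.6], [0.4, 0.2]]],
]
means, stds = channel_stats(images)
print([round(m, 3) for m in means], [round(s, 3) for s in stds])
```

The same accumulation works batch by batch on the real dataset, so the full training set never has to sit in memory at once.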

In [None]:
# ── 1. Show one sample image per class ───────────────────────────────

fig, axes = plt.subplots(2, 5, figsize=(12, 5))
axes = axes.ravel()

# Load the raw dataset (no normalisation) so colours look natural
raw_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=False, transform=transforms.ToTensor())

shown = set()
for img, label in raw_dataset:
    if label not in shown:
        ax = axes[ ## FINISH_ME ## Make sure we get the correct sub-plot
        ax.imshow(img.permute(1, 2, 0).numpy())   # (C,H,W) → (H,W,C) for display
        ax.set_title(## FINISH_ME ## Make sure we get the correct class label, fontsize=11)
        ax.axis("off")
        shown.add(label)
    if len(shown) == 10:
        break

fig.suptitle("One sample per class", fontsize=14, y=1.02)
fig.tight_layout()
plt.show()

# ── 2. Class distribution ───────────────────────────────────────────
# A balanced dataset means the model won't be biased towards any class.

## It's worth answering: how many images of each class are there?
## Print the counts in a table and plot them as a bar chart:

## FINISH_ME ##



## 4 — Training Loop

### Loss function — Cross-entropy

For a classification problem with $C$ classes, **cross-entropy loss** measures
how far the model's predicted probability distribution is from the true label.
For a single sample with true class $y$:

$$\mathcal{L} = -\log \hat{p}_y$$

where $\hat{p}_y$ is the softmax probability assigned to the correct class.
Lower is better; a perfect prediction gives $\mathcal{L} = 0$.
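
The formula can be evaluated directly from a vector of logits. The sketch below is plain Python using the log-sum-exp trick for numerical stability; it is an illustration of the definition, not the internals of `nn.CrossEntropyLoss`, and the logits are made up.

```python
import math


def cross_entropy(logits, target):
    """-log softmax(logits)[target], computed with the log-sum-exp trick."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


uniform = cross_entropy([0.0] * 10, target=3)                        # equals ln(10)
confident = cross_entropy([0.0] * 3 + [20.0] + [0.0] * 6, target=3)  # close to 0
print(f"uniform logits : {uniform:.4f}")
print(f"confident right: {confident:.6f}")
```

A loss near $\ln 10 \approx 2.30$ early in training therefore means the model is doing no better than guessing uniformly among the 10 classes.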

### Optimiser — Adam

[Adam](https://arxiv.org/abs/1412.6980) maintains per-parameter running
averages of the gradient and its square, giving each weight its own adaptive
learning rate.  It usually converges faster than plain SGD.
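
A single Adam update for one scalar parameter can be written in a few lines. The sketch below follows the update rule from the Adam paper with its usual default hyper-parameters; it is an illustration, not the code inside `optim.Adam`, and the name `adam_step` is hypothetical.

```python
import math


def adam_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a single scalar parameter (t counts from 1)."""
    m = beta1 * m + (1 - beta1) * grad            # running average of the gradient
    v = beta2 * v + (1 - beta2) * grad ** 2       # running average of its square
    m_hat = m / (1 - beta1 ** t)                  # bias correction
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v


# The first step has size close to lr whatever the gradient's scale.
for grad in (0.01, 1.0, 100.0):
    theta, _, _ = adam_step(theta=0.0, grad=grad, m=0.0, v=0.0, t=1)
    print(f"grad = {grad:>6}: step = {-theta:.6f}")
```

Dividing by the root of the squared-gradient average is what makes the step size adaptive: parameters with consistently large gradients are not updated more aggressively than those with small ones.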

### Learning-rate schedule — Cosine annealing

We smoothly decay the learning rate from its initial value down to near zero
over the course of training, following a half-cosine curve.  This lets the
model take large steps early on and fine-tune with small steps later.
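
The half-cosine curve has a simple closed form. The sketch below assumes an initial rate of $10^{-3}$, a period of 20 epochs, and a minimum rate of 0; it reproduces the shape of the schedule rather than calling PyTorch's scheduler, and the name `cosine_lr` is illustrative.

```python
import math


def cosine_lr(epoch, base_lr=1e-3, t_max=20, eta_min=0.0):
    """Closed-form cosine-annealed learning rate after `epoch` scheduler steps."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))


for epoch in (0, 5, 10, 15, 20):
    print(f"epoch {epoch:>2}: lr = {cosine_lr(epoch):.6f}")
```

The rate is halved at the midpoint and falls most steeply there; it changes slowly at the start and at the end. If the scheduler is stepped once after each of 20 epochs, the last epoch trains at `cosine_lr(19)`, which is small but not exactly zero.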

### What we track

Each epoch we record:
- **Training loss** — how well the model fits the training data.
- **Validation loss** — how well it generalises to unseen data.
- **Validation accuracy** — the metric we actually care about.

If validation loss starts *rising* while training loss keeps *falling*, the
model is **over-fitting** — memorising training data rather than learning
generalisable patterns.

In [None]:
def train_one_epoch(model, loader, criterion, optimizer):
    """
    Train the model for one full pass over the training set.

    Returns the average loss across all training samples.
    """
    model.train()            # enable dropout / batch-norm training behaviour
    running_loss = 0.0
    total        = 0

    pbar = tqdm(loader, desc="  Training", leave=False)
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()          # reset gradients from previous step
        logits = model(images)         # forward pass  → (B, 10)
        loss   = criterion(logits, labels)   # compute cross-entropy loss
        ## FINISH_ME ## Go backwards   # backpropagate gradients
        ## FINISH_ME ## Optimizer step # update weights

        # Accumulate batch loss (weighted by batch size for correct average)
        running_loss += loss.item() * images.size(0)
        total        += images.size(0)
        pbar.set_postfix(loss=f"{running_loss / total:.4f}")

    return running_loss / total

In [None]:
@torch.no_grad()   # disable gradient tracking — saves memory, runs faster
def evaluate(model, loader, criterion):
    """
    Evaluate the model on a dataset (validation or test).

    Returns (average_loss, accuracy_percent).
    """
    model.eval()             # disable dropout / set batch-norm to eval mode
    running_loss = 0.0
    correct      = 0
    total        = 0

    pbar = tqdm(loader, desc="  Evaluating", leave=False)
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)                    # forward pass only
        loss   = criterion(logits, labels)

        running_loss += loss.item() * images.size(0)
        correct      += (logits.argmax(dim=1) == labels).sum().item()
        total        += labels.size(0)
        pbar.set_postfix(loss=f"{running_loss / total:.4f}",
                         acc=f"{100.0 * correct / total:.1f}%")
        
    avg_loss = ## FINISH_ME ## What's the average loss?
    pct_acc = ## FINISH_ME ## What's the average accuracy?

    return avg_loss, pct_acc

### Run training

We train for **20 epochs** with:
- **Adam** optimiser at an initial learning rate of $10^{-3}$
- **Cosine-annealing** LR schedule (smoothly decays lr to near zero)

Each epoch prints the current learning rate, training loss, validation loss,
and validation accuracy.

In [None]:
# ── Create a fresh model, loss function, and optimiser ───────────────
model     = AttentionClassifier().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

NUM_EPOCHS = 20

# Cosine-annealing schedule: lr follows a half-cosine from 1e-3 → ~0
# over NUM_EPOCHS.  This lets the model explore early and fine-tune later.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

# Dictionary to record metrics for plotting later
history = {"train_loss": [], "val_loss": [], "val_acc": []}

for epoch in range(1, NUM_EPOCHS + 1):
    current_lr = scheduler.get_last_lr()[0]
    print(f"Epoch {epoch}/{NUM_EPOCHS}  (lr = {current_lr:.6f})")

    train_loss        = train_one_epoch(model, train_loader, criterion, optimizer)
    val_loss, val_acc = evaluate(model, val_loader, criterion)

    scheduler.step()   # update the learning rate for the next epoch

    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)
    history["val_acc"].append(val_acc)

    print(f"  → train loss = {train_loss:.4f}   "
          f"val loss = {val_loss:.4f}   "
          f"val acc  = {val_acc:.2f}%\n")

## 5 — Results

### 5.1 Learning curves

Plotting training and validation loss side by side is the best way to diagnose
how training went:

- **Both curves falling together** → the model is learning and generalising well.
- **Training loss falling, validation loss rising** → **over-fitting** — the model
  is memorising training data.
- **Both curves plateauing** → the model has converged; more epochs won't help much.
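
The over-fitting pattern in the list above can also be flagged programmatically. The sketch below is one possible heuristic, not part of the assignment: it assumes a `history` dictionary with `train_loss` and `val_loss` lists, and a hypothetical `patience` threshold of three epochs. The numbers are made up for illustration.

```python
def diagnose(history, patience=3):
    """Return (best_epoch, overfitting).

    overfitting is True when the lowest validation loss is at least
    `patience` epochs old while the training loss has kept falling.
    """
    val = history["val_loss"]
    best_epoch = min(range(len(val)), key=val.__getitem__) + 1    # 1-based
    epochs_since_best = len(val) - best_epoch
    train = history["train_loss"]
    train_still_falling = train[-1] < train[best_epoch - 1]
    overfitting = epochs_since_best >= patience and train_still_falling
    return best_epoch, overfitting


# Made-up numbers for illustration only (not results from this notebook).
toy_history = {
    "train_loss": [1.9, 1.5, 1.2, 1.0, 0.8, 0.7, 0.6],
    "val_loss":   [1.8, 1.5, 1.3, 1.25, 1.3, 1.35, 1.4],
}
best_epoch, overfitting = diagnose(toy_history)
print(f"best epoch = {best_epoch}, over-fitting suspected = {overfitting}")
```

A check like this is a complement to the plot, not a replacement: noisy validation curves can trigger it spuriously, so the threshold is a judgment call.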

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 4))
epochs = range(1, NUM_EPOCHS + 1)

# ── Left panel: loss curves ──────────────────────────────────────────
ax1.plot(epochs, ## FINISH_ME ## plot the 'train_loss', marker='o', markersize=4, label='Train')
ax1.plot(epochs, ## FINISH_ME ## plot the validation loss against this,   marker='s', markersize=4, label='Validation')
ax1.set_title("Loss  (train vs. validation)")
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Cross-Entropy Loss")
ax1.legend()
ax1.grid(True, alpha=0.3)

# ── Right panel: validation accuracy ─────────────────────────────────
ax2.plot(epochs, ## FINISH_ME ## plot the validation accuracy, marker='o', markersize=4, color='green')
ax2.set_title("Validation Accuracy")
ax2.set_xlabel("Epoch")
ax2.set_ylabel("Accuracy (%)")
ax2.grid(True, alpha=0.3)

fig.tight_layout()
plt.show()

### Final test evaluation

The test set has been **completely untouched** during training — neither for
weight updates nor for hyper-parameter choices.  We evaluate on it exactly
**once** to get an honest, unbiased estimate of how well the model generalises
to images it has never seen.

In [None]:
test_loss, test_acc = evaluate( ## FINISH_ME ## evaluate the model on the test set using the evaluate function we defined above for testing in the training loop)
print(f"Test loss     : {test_loss:.4f}")
print(f"Test accuracy : {test_acc:.2f}%")

### 5.2 Per-class accuracy

Overall accuracy can hide large class-level differences.  Some classes
(e.g. *Car*, *Ship*) have distinctive shapes and are easier to classify,
while others (e.g. *Cat* vs *Dog*) share similar silhouettes and textures.
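
To see *which* classes get mixed up, and not just how often each one is right, a confusion matrix is a common next step. The sketch below uses plain Python lists and made-up labels; `confusion_matrix` and `most_confused_pair` are hypothetical helpers, and with PyTorch tensors you would first convert predictions and labels to integer lists (for example with `.tolist()`).

```python
def confusion_matrix(true_labels, pred_labels, num_classes):
    """counts[t][p] = number of samples with true class t predicted as p."""
    counts = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(true_labels, pred_labels):
        counts[t][p] += 1
    return counts

def most_confused_pair(counts):
    """Return (true, predicted, count) for the largest off-diagonal entry."""
    best = (None, None, 0)
    for t, row in enumerate(counts):
        for p, n in enumerate(row):
            if t != p and n > best[2]:
                best = (t, p, n)
    return best

# Toy labels for a 3-class problem, made up for illustration.
names  = ["Cat", "Dog", "Ship"]
y_true = [0, 0, 0, 1, 1, 1, 2, 2, 2]
y_pred = [0, 1, 1, 1, 1, 0, 2, 2, 2]

cm = confusion_matrix(y_true, y_pred, num_classes=3)
t, p, n = most_confused_pair(cm)
print(f"Most common mistake: {names[t]} predicted as {names[p]} ({n} times)")
```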

In [None]:
@torch.no_grad()
def per_class_accuracy(model, loader, class_names):
    """Compute and print accuracy separately for each class."""
    model.eval()
    correct = [0] * len(class_names)
    total   = [0] * len(class_names)

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)        # predicted class indices
        for pred, label in zip(preds, labels):
            total[label]   += 1
            correct[label] += ## FINISH_ME ## 1. if pred == label else 0.

    # Print a neat table
    print(f"  {'Class':>8s}   Accuracy")
    print(f"  {'─'*8}   {'─'*8}")
    for i, name in enumerate(class_names):
        acc = ## FINISH_ME ## using num correct and num total for this class (i) calculate accuracy as a percentage 
        print(f"  {name:>8s}   {acc:5.1f}%")

per_class_accuracy(model, test_loader, classes)

## 6 — Visualise Attention Weights

The attention weight matrix $W \in \mathbb{R}^{N \times N}$ (where $N = 64$
tokens) tells us **which spatial positions the model is attending to**.

- $W[i, j]$ is the weight that token $i$ (query) places on token $j$ (key).
- Each row sums to 1 (because of softmax).
- The causal mask sets every entry above the diagonal ($j > i$) to exactly zero, so
  each token can only attend to positions ≤ itself.

We extract the weight matrices from **both** attention layers and display them
as heat-maps alongside the original image.  Bright cells mean strong attention.
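
The two properties listed above (rows summing to 1, zeros above the diagonal) can be checked on a tiny example. This is a minimal sketch in plain Python rather than PyTorch; `causal_softmax` is a hypothetical helper and the 4 × 4 score matrix is made up.

```python
import math

def causal_softmax(scores):
    """Row-wise softmax over an N×N list of lists with a causal mask.

    Entries with j > i are set to -inf before the softmax, so their
    weights come out as exactly 0.
    """
    n = len(scores)
    weights = []
    for i in range(n):
        masked = [scores[i][j] if j <= i else -math.inf for j in range(n)]
        row_max = max(masked)                  # subtract the max for stability
        exps = [math.exp(s - row_max) for s in masked]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights

# Toy 4×4 score matrix, made up for illustration.
scores = [[ 0.5, 1.0, -0.3,  0.2],
          [ 1.2, 0.1,  0.4, -1.0],
          [ 0.0, 0.7,  0.3,  0.9],
          [-0.5, 0.2,  1.1,  0.6]]

W = causal_softmax(scores)
for row in W:
    print("  ".join(f"{w:.3f}" for w in row))
```

The first row always has a single non-zero weight, because the first token can only attend to itself.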

In [None]:
@torch.no_grad()
def get_attention_weights(attn_layer, x):
    """
    Re-run the attention computation and return the N×N weight matrix.

    We need to duplicate the forward logic (rather than just call
    attn_layer(x)) because the standard forward() only returns the
    output, not the intermediate weights.
    """
    Q = attn_layer.W_q(x)
    K = attn_layer.W_k(x)

    scores = Q @ K.transpose(-2, -1) / math.sqrt(attn_layer.d_k)
    scores = scores + attn_layer.mask                      # apply causal mask

    weights = torch.softmax(scores, dim=-1)                # (B, N, N)
    return weights


# ── Pick one test image ─────────────────────────────────────────────
images, labels = next(iter(test_loader))
img   = images[0:1].to(device)               # single image: (1, 3, 32, 32)
label = labels[0].item()

# ── Run through the CNN stem to get the token sequence ──────────────
B = 1
x = model.cnn(img)                            # (1, 64, 8, 8)
x = x.view(B, model.d, -1).transpose(1, 2)   # (1, 64, 64)  — 64 tokens

# ── Layer 1 attention weights ───────────────────────────────────────
w1 = get_attention_weights(model.attn1, x)[0].cpu()

# ── Pass through layer 1 to get input for layer 2 ──────────────────
# (Must mirror the forward pass: residual → norm → activation)
x = x + model.attn1(x)
x = model.norm1(x)
x = torch.nn.functional.silu(x)              # SiLU, matching the classifier

# ── Layer 2 attention weights ───────────────────────────────────────
w2 = get_attention_weights(model.attn2, x)[0].cpu()

# ── Plot ────────────────────────────────────────────────────────────
fig, axes = plt.subplots(1, 3, figsize=(14, 4))

# Left: original image (un-normalise for display)
# Image is permuted to (H, W, C) for display, and rescaled to [0, 1] for correct colours.
raw = images[0].permute(1, 2, 0).numpy()              # (32, 32, 3)
raw = (raw - raw.min()) / ## FINISH_ME ## divide by the range of values for the raw     # rescale to [0, 1]
axes[0].imshow(raw)
axes[0].set_title(f"Input image\n(class: {classes[label]})")
axes[0].axis("off")

# Centre: layer 1 weights
# `w1` holds the layer-1 attention weights returned by `get_attention_weights`,
# already moved to the CPU.  Pass it to `imshow`; call `.numpy()` on it first
# if you prefer an explicit NumPy array.
im1 = axes[1].imshow(## FINISH_ME ##, cmap="viridis", aspect="auto")
axes[1].set_title("Attention weights — Layer 1")
axes[1].set_xlabel("Key position  (j)")
axes[1].set_ylabel("Query position  (i)")
fig.colorbar(im1, ax=axes[1], fraction=0.046)

# Right: layer 2 weights
im2 = axes[2].imshow(## FINISH_ME ##, cmap="viridis", aspect="auto")
axes[2].set_title("Attention weights — Layer 2")
axes[2].set_xlabel("Key position  (j)")
axes[2].set_ylabel("Query position  (i)")
fig.colorbar(im2, ax=axes[2], fraction=0.046)

fig.tight_layout()
plt.show()

## 7 — Key Take-aways

| Idea | What it does | Why it matters |
|---|---|---|
| **Q, K, V projections** | Three learned linear maps turn each token into a query, a key, and a value. | Separating "what am I looking for?" from "what do I contain?" from "what do I provide?" lets the model learn flexible interactions. |
| **Scaled dot-product** | $QK^\top / \sqrt{d_k}$ computes a pairwise similarity score matrix. | Without the $\sqrt{d_k}$ scaling, large dimensions push softmax into a near-one-hot regime and gradients vanish. |
| **Causal mask** | Setting future positions to $-\infty$ before softmax forces them to exactly zero. | The model can only attend to earlier (or same) positions — this is what makes it a *causal* / *autoregressive* mechanism. |
| **softmax(scores) · V** | Each output token is a *weighted average* of value vectors, where the weights are learned attention scores. | This is the core of attention — dynamically routing information between positions based on content similarity. |
| **Residual connections** | `x = x + attn(x)` adds the attention output back to the input. | Gives gradients a direct path through the network, making deeper models trainable. |
| **LayerNorm** | Normalises each token to zero mean / unit variance across features. | Stabilises the distribution of activations, leading to faster and more reliable training. |
| **CNN stem** | Two conv + pool layers before attention. | Extracts local features (edges, textures) and supplies a locality and translation bias that attention alone lacks. |
| **Cosine LR schedule** | Smoothly decays the learning rate over training. | Large steps early for fast exploration; small steps late for fine-tuning. This often improves final accuracy compared with a constant learning rate. |
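
The $\sqrt{d_k}$ row in the table can be checked empirically. The sketch below assumes query and key entries are independent standard normal values, which is an idealisation and not something measured from the trained model; `dot_product_std` is a hypothetical helper.

```python
import math
import random

def dot_product_std(d_k, n_samples=1000, scaled=False, seed=0):
    """Empirical std of q·k for random q, k with i.i.d. N(0, 1) entries."""
    rng = random.Random(seed)
    dots = []
    for _ in range(n_samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        dot = sum(a * b for a, b in zip(q, k))
        dots.append(dot / math.sqrt(d_k) if scaled else dot)
    mean = sum(dots) / n_samples
    return math.sqrt(sum((x - mean) ** 2 for x in dots) / n_samples)

for d_k in (4, 64, 256):
    raw = dot_product_std(d_k)
    scaled = dot_product_std(d_k, scaled=True)
    print(f"d_k = {d_k:3d}: std(q·k) = {raw:6.2f}   "
          f"std(q·k / sqrt(d_k)) = {scaled:.2f}")
```

Under this assumption the unscaled spread grows roughly like $\sqrt{d_k}$, while the scaled scores stay near unit spread, which keeps the softmax away from the saturated regime.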

### What to try next

- **Remove the causal mask** — for image classification we don't need causal ordering.  How does bidirectional attention compare?
- **Multi-head attention** — split Q, K, V into $h$ parallel heads so the model can attend to different patterns simultaneously.
- **Positional encodings** — add information about *where* each token is in the sequence (apart from the causal mask, attention has no notion of token order).
- **Deeper CNN stem or more attention layers** — trade compute for accuracy.
- **Learning rate warm-up** — gradually increase lr for the first few epochs before cosine decay.
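
As a starting point for the positional-encoding idea, here is a sketch of the fixed sinusoidal scheme from the original Transformer, written in plain Python. It is only one option (learned position embeddings are another); `sinusoidal_positions` is a hypothetical helper, and the 64 tokens × 64 features are assumed from the shape comments in Section 6.

```python
import math

def sinusoidal_positions(num_tokens, d_model):
    """Return a num_tokens × d_model table of fixed sinusoidal encodings.

    Even feature indices use sin, odd indices use cos, with wavelengths
    forming a geometric progression (d_model is assumed to be even).
    """
    table = []
    for pos in range(num_tokens):
        row = []
        for i in range(0, d_model, 2):
            angle = pos / (10000 ** (i / d_model))
            row.extend([math.sin(angle), math.cos(angle)])
        table.append(row)
    return table

pe = sinusoidal_positions(num_tokens=64, d_model=64)
print(len(pe), len(pe[0]))                   # table dimensions
print([round(v, 3) for v in pe[1][:4]])      # first few features of token 1
```

In a model, a table like this would typically be converted to a tensor and added to the token features before the first attention layer.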