# Vision Transformer (ViT) from Scratch

Building the full ViT architecture in PyTorch — patch embeddings, CLS token, positional encoding, and stacked transformer encoder blocks. Uses the `MultiHeadAttention` built in notebook 01.

Reference: [An Image is Worth 16×16 Words — Dosovitskiy et al., 2020](https://arxiv.org/abs/2010.11929) · [CS231N Lecture 12](https://cs231n.stanford.edu/slides/2024/lecture_12.pdf)

## Overview

- **Builds:** `PatchEmbedding` · `TransformerEncoderBlock` · `VisionTransformer`
- **Concepts:** patch embeddings · Conv2d trick · CLS token · positional encoding · pre-norm · ViT-B/16
- **Requires:** `01_self_attention.ipynb` (uses `MultiHeadAttention` built there)
- **References:** [Dosovitskiy et al., 2020](https://arxiv.org/abs/2010.11929) · [CS231N Lecture 12](https://cs231n.stanford.edu/slides/2024/lecture_12.pdf)
- **Output:** `assets/gradcam/patch_grid.png` · `assets/gradcam/cls_attention_map.png`

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib
from einops import rearrange

matplotlib.rcParams['figure.dpi'] = 120
torch.manual_seed(42)

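# Prefer CUDA, then Apple MPS, then fall back to CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")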
print(f"Device: {device}")

## Setup — MultiHeadAttention from Notebook 01

Copied here so this notebook is self-contained. Built from scratch in `01_self_attention.ipynb`.

In [None]:
def scaled_dot_product_attention(q, k, v, mask=None):
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    attn_weights = F.softmax(scores, dim=-1)
    return torch.matmul(attn_weights, v), attn_weights


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k       = d_model // num_heads
        self.num_heads = num_heads
        self.d_model   = d_model
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def split_heads(self, x):
        B, T, _ = x.shape
        return x.view(B, T, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, x, mask=None):
        B, T, _ = x.shape
        q, k, v = self.split_heads(self.W_q(x)), self.split_heads(self.W_k(x)), self.split_heads(self.W_v(x))
        out, attn = scaled_dot_product_attention(q, k, v, mask)
        out = out.transpose(1, 2).reshape(B, T, self.d_model)
        return self.W_o(out), attn

## Part 1 — Patch Embedding

ViT treats an image as a sequence of fixed-size patches. A 224×224 image with 16×16 patches produces **(224/16)² = 196 patch tokens** — the sequence the transformer operates on.
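
To make the sequence-length arithmetic concrete, here is a small sketch. The helper name `num_patch_tokens` is hypothetical (it is not one of this notebook's classes), and patch sizes 32 and 14 are included only for comparison with the 16×16 case used here.

```python
def num_patch_tokens(image_size: int, patch_size: int) -> int:
    """Number of non-overlapping square patches in a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    return (image_size // patch_size) ** 2


for p in (32, 16, 14):
    n = num_patch_tokens(224, p)
    # Self-attention scores form an (n + 1) x (n + 1) matrix per head
    # once the CLS token (introduced in Part 2) is prepended.
    print(f"patch {p:>2}: {n:>3} patches, attention matrix {n + 1}x{n + 1}")
```

Smaller patches give a longer sequence, and the attention matrix grows with the square of that length.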

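**The Conv2d trick:** Instead of slicing patches manually, flattening them, and applying a linear layer, a single `Conv2d` with `kernel_size=patch_size` and `stride=patch_size` computes the same linear projection in one GPU-efficient operation. Because the stride equals the kernel size, the windows never overlap, and each of the `d_model` filters produces one dimension of the patch embedding.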

```
Input:   (B, 3, 224, 224)
Conv2d:  (B, 768, 14, 14)   ← 14 = 224/16
Flatten: (B, 768, 196)
Permute: (B, 196, 768)       ← 196 tokens, each 768-dim
```
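
The equivalence between the strided convolution and "slice, flatten, project" can be checked numerically. The sketch below uses deliberately small sizes (32×32 image, 8×8 patches, 16-dim embeddings) chosen for illustration only; the variable names are not part of the notebook's classes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Small illustrative sizes (not the ViT-B/16 values)
B, C, H, W, P, D = 2, 3, 32, 32, 8, 16

conv = nn.Conv2d(C, D, kernel_size=P, stride=P)
x = torch.randn(B, C, H, W)

# Path 1: Conv2d, then flatten the spatial grid into a token axis
conv_tokens = conv(x).flatten(2).transpose(1, 2)            # (B, N, D)

# Path 2: slice patches by hand, flatten each to a vector of length C*P*P
patches = x.unfold(2, P, P).unfold(3, P, P)                 # (B, C, H/P, W/P, P, P)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * P * P)

# Reuse the conv kernel as a (D, C*P*P) linear weight
weight = conv.weight.reshape(D, -1)
manual_tokens = patches @ weight.T + conv.bias              # (B, N, D)

max_diff = (conv_tokens - manual_tokens).abs().max().item()
print(conv_tokens.shape, max_diff)
```

The two paths should agree up to floating-point rounding, which is why the convolution can stand in for an explicit per-patch linear layer.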

In [None]:
class PatchEmbedding(nn.Module):
    def __init__(self, image_size=224, patch_size=16, in_channels=3, d_model=768):
        super().__init__()
        assert image_size % patch_size == 0, "Image size must be divisible by patch size"
        self.num_patches = (image_size // patch_size) ** 2
        # Conv2d trick: kernel and stride = patch_size → each kernel covers exactly one patch
        self.proj = nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)             # (B, d_model, H/P, W/P)
        x = x.flatten(2)             # (B, d_model, num_patches)
        x = x.transpose(1, 2)        # (B, num_patches, d_model)
        return x

In [None]:
# Verify patch embedding shapes
B, C, H, W = 2, 3, 224, 224
patch_size = 16
d_model    = 768

x   = torch.randn(B, C, H, W)
pe  = PatchEmbedding(image_size=H, patch_size=patch_size, d_model=d_model)
out = pe(x)

print(f"Input image:       {x.shape}")
print(f"After Conv2d:      {pe.proj(x).shape}")    # (2, 768, 14, 14)
print(f"Patch embeddings:  {out.shape}")             # (2, 196, 768)
print(f"Number of patches: {pe.num_patches}")        # 196 = 14×14

In [None]:
# Visualize: show the 196 patches on the original image grid
img = torch.rand(3, 224, 224)  # dummy image

# Rearrange image into patches using einops
patches = rearrange(img, 'c (h p1) (w p2) -> (h w) c p1 p2', p1=patch_size, p2=patch_size)

fig, axes = plt.subplots(14, 14, figsize=(8, 8))
for i, ax in enumerate(axes.flat):
    ax.imshow(patches[i].permute(1, 2, 0).numpy())
    ax.axis('off')

plt.suptitle('224×224 image split into 196 patches of 16×16 each', y=1.01)
plt.tight_layout()
plt.savefig('../assets/gradcam/patch_grid.png', bbox_inches='tight', dpi=150)
plt.show()

## Part 2 — CLS Token and Positional Encoding

**CLS token:** A learned vector prepended to the patch sequence. It has no corresponding image region — its role is to aggregate information from all patches through attention. After the final transformer block, only the CLS token's output is passed to the classifier.

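**Why not average all patch outputs?** A plain average weights every patch equally. The CLS token instead learns, through attention, which patches to aggregate for classification. Global average pooling is still a workable readout (the ViT paper reports comparable accuracy for it once the learning rate is tuned), but this notebook follows the CLS design.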
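
As a minimal sketch of the two readout options, the snippet below applies both to a random stand-in for the encoder output. The names `encoded`, `cls_pooled`, and `mean_pooled` are illustrative and do not appear elsewhere in this notebook.

```python
import torch

torch.manual_seed(0)

B, num_patches, d_model = 2, 196, 768
# Stand-in for the encoder output: position 0 is CLS, positions 1..196 are patches
encoded = torch.randn(B, num_patches + 1, d_model)

cls_pooled = encoded[:, 0]                  # (B, 768): CLS readout
mean_pooled = encoded[:, 1:].mean(dim=1)    # (B, 768): uniform average over patch tokens

print(cls_pooled.shape, mean_pooled.shape)
```

Both readouts produce one `d_model`-sized vector per image, so either can feed the same linear classifier head.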

**Positional encoding:** Self-attention has no built-in notion of order: it is permutation-equivariant, so shuffling the input tokens simply shuffles the outputs the same way. Without positional encoding, the model cannot tell by position whether a token came from patch 0 (top-left) or patch 195 (bottom-right). ViT uses learned positional embeddings, one per position (197 total: 1 CLS + 196 patches).

```
Patches:      (B, 196, 768)
CLS prepend:  (B, 197, 768)
Pos encoding: (B, 197, 768)  ← added, not concatenated
```
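
The order-blindness of attention can be verified with a toy example. This is a sketch under stated assumptions: a single-head attention with small, arbitrary sizes (`d_model=8`, 5 tokens), a fixed reversal as the permutation, and a random tensor `pos` standing in for a learned positional embedding.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

d_model, T = 8, 5
W_q, W_k, W_v = (nn.Linear(d_model, d_model) for _ in range(3))


def self_attention(x):
    """Single-head self-attention with no positional information."""
    q, k, v = W_q(x), W_k(x), W_v(x)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_model)
    return F.softmax(scores, dim=-1) @ v


x = torch.randn(1, T, d_model)
perm = torch.arange(T).flip(0)          # fixed permutation: reverse the token order

with torch.no_grad():
    out = self_attention(x)
    out_shuffled = self_attention(x[:, perm])
    # Shuffling the input shuffles the output in exactly the same way
    equivariant = torch.allclose(out[:, perm], out_shuffled, atol=1e-5)

    # Adding a position-dependent term ties each token to its slot
    pos = torch.randn(1, T, d_model)    # stand-in for a learned pos_embed
    out_pos = self_attention(x + pos)
    out_pos_shuffled = self_attention(x[:, perm] + pos)
    still_equivariant = torch.allclose(out_pos[:, perm], out_pos_shuffled, atol=1e-5)

print(equivariant, still_equivariant)
```

Without positions, the shuffled run is just a reordering of the original outputs. Once `pos` is added, the same patch content placed in a different slot yields a different result, which is what lets the model use spatial layout.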

In [None]:
# Demonstrate CLS token + positional encoding
num_patches = 196
d_model     = 768
B           = 2

patch_tokens = torch.randn(B, num_patches, d_model)

# CLS token: one learnable vector, expanded across batch
cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
cls_expanded = cls_token.expand(B, -1, -1)             # (B, 1, 768)
tokens = torch.cat([cls_expanded, patch_tokens], dim=1) # (B, 197, 768)

# Positional encoding: one per position, learned
pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, d_model))
tokens = tokens + pos_embed                             # (B, 197, 768)

print(f"Patch tokens:       {patch_tokens.shape}")
print(f"After CLS prepend:  {torch.cat([cls_expanded, patch_tokens], dim=1).shape}")
print(f"After pos encoding: {tokens.shape}")
print("\nCLS token is position 0. Patch tokens are positions 1-196.")

## Part 3 — Transformer Encoder Block

Each block applies two operations, each with a residual connection:

```
x = x + MultiHeadAttention(LayerNorm(x))   ← attention + residual
x = x + MLP(LayerNorm(x))                  ← feed-forward + residual
```

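**Pre-norm vs post-norm:** ViT uses LayerNorm *before* each sub-layer (pre-norm), not after the residual sum as in the original Transformer (post-norm). Pre-norm keeps the residual path an identity mapping, which tends to stabilize training in deep transformers.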
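
The difference between the two orderings is easiest to see side by side. In this sketch a single `nn.Linear` stands in for the sub-layer (attention or MLP); the function names `pre_norm_step` and `post_norm_step` are illustrative, not part of the notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 16

norm = nn.LayerNorm(d_model)
sublayer = nn.Linear(d_model, d_model)   # stand-in for attention or the MLP


def pre_norm_step(x):
    # Normalize only the sub-layer input; the residual path stays an identity
    return x + sublayer(norm(x))


def post_norm_step(x):
    # Original Transformer ordering: normalize after the residual sum
    return norm(x + sublayer(x))


x = torch.randn(2, 5, d_model)
with torch.no_grad():
    pre, post = pre_norm_step(x), post_norm_step(x)

print(pre.shape, post.shape)
# Post-norm output is re-normalized per token (mean near 0); pre-norm output is not
print(post.mean(-1).abs().max().item(), pre.mean(-1).abs().max().item())
```

In the pre-norm form the input `x` reaches the output untouched through the skip connection, so gradients have a direct path through every block.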

**MLP:** Two linear layers with GELU activation. Hidden dim = `d_model × mlp_ratio` (typically 4×). For ViT-B: 768 → 3072 → 768.

**GELU vs ReLU:** GELU is smooth around zero and lets small negative values through instead of clipping them to exactly zero. It is the activation the ViT paper uses in its MLP blocks.
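
A quick numeric comparison shows where the two activations differ. The sample points below are arbitrary and chosen only to cover negative, near-zero, and positive inputs.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-3.0, -1.0, -0.1, 0.0, 0.1, 1.0, 3.0])

relu_out = F.relu(x)
gelu_out = F.gelu(x)   # exact form: x * Phi(x), with Phi the standard normal CDF

for xi, r, g in zip(x.tolist(), relu_out.tolist(), gelu_out.tolist()):
    print(f"x={xi:+.1f}  relu={r:+.4f}  gelu={g:+.4f}")
```

ReLU maps every negative input to zero, while GELU returns small negative values there and approaches the identity for large positive inputs.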

In [None]:
class TransformerEncoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.norm1   = nn.LayerNorm(d_model)
        self.attn    = MultiHeadAttention(d_model, num_heads)
        self.norm2   = nn.LayerNorm(d_model)
        self.mlp     = nn.Sequential(
            nn.Linear(d_model, d_model * mlp_ratio),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * mlp_ratio, d_model),
            nn.Dropout(dropout),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Attention sub-layer with pre-norm and residual
        attn_out, attn_weights = self.attn(self.norm1(x))
        x = x + self.dropout(attn_out)

        # MLP sub-layer with pre-norm and residual
        # (self.mlp already ends in Dropout, so no second dropout is applied here)
        x = x + self.mlp(self.norm2(x))

        return x, attn_weights

In [None]:
# Verify encoder block shapes
block = TransformerEncoderBlock(d_model=768, num_heads=12)
x     = torch.randn(2, 197, 768)  # (batch, 197 tokens, 768 dim)

out, attn = block(x)
print(f"Input:        {x.shape}")
print(f"Output:       {out.shape}")   # same shape — residual connection preserves it
print(f"Attn weights: {attn.shape}")  # (2, 12 heads, 197, 197)

## Part 4 — Full Vision Transformer

Stacking everything: patch embedding → CLS + positional encoding → N encoder blocks → final LayerNorm → CLS token → classifier.

**ViT-B/16 configuration (what we build):**
- `image_size=224, patch_size=16` → 196 patches
- `d_model=768, num_heads=12, num_layers=12`
- `mlp_ratio=4` → MLP hidden dim = 3072
- ~86M parameters
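
The ~86M figure can be reproduced by hand from the configuration above. This sketch assumes the layer layout used in this notebook (biases on every linear layer, four `d_model × d_model` projections per attention module, two LayerNorms per block) and the 10-class head from the constructor defaults; the helper `linear_params` is hypothetical.

```python
d_model, num_layers, mlp_ratio = 768, 12, 4
patch_size, in_channels, num_patches, num_classes = 16, 3, 196, 10


def linear_params(n_in, n_out):
    """Weights plus bias of a fully connected layer."""
    return n_in * n_out + n_out


patch_embed = in_channels * patch_size**2 * d_model + d_model   # Conv2d kernel + bias
cls_and_pos = d_model + (num_patches + 1) * d_model             # CLS token + pos_embed
layer_norm = 2 * d_model                                        # scale + shift

attention = 4 * linear_params(d_model, d_model)                 # W_q, W_k, W_v, W_o
mlp = (linear_params(d_model, d_model * mlp_ratio)
       + linear_params(d_model * mlp_ratio, d_model))
block = attention + mlp + 2 * layer_norm

head = linear_params(d_model, num_classes)
total = patch_embed + cls_and_pos + num_layers * block + layer_norm + head

print(f"per block: {block:,}")
print(f"total:     {total:,}")
```

Almost all of the parameters sit in the 12 encoder blocks. With a 10-class head the total lands just under 86M; a larger classification head would add `d_model + 1` parameters per extra class.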

In [None]:
class VisionTransformer(nn.Module):
    def __init__(self, image_size=224, patch_size=16, in_channels=3,
                 num_classes=10, d_model=768, num_heads=12, num_layers=12,
                 mlp_ratio=4, dropout=0.1):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2

        self.patch_embed = PatchEmbedding(image_size, patch_size, in_channels, d_model)
        self.cls_token   = nn.Parameter(torch.zeros(1, 1, d_model))
        self.pos_embed   = nn.Parameter(torch.zeros(1, num_patches + 1, d_model))
        self.dropout     = nn.Dropout(dropout)

        self.blocks = nn.ModuleList([
            TransformerEncoderBlock(d_model, num_heads, mlp_ratio, dropout)
            for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, num_classes)

        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        B = x.shape[0]

        # 1. Patch embedding
        x = self.patch_embed(x)                          # (B, 196, 768)

        # 2. Prepend CLS token
        cls = self.cls_token.expand(B, -1, -1)           # (B, 1, 768)
        x   = torch.cat([cls, x], dim=1)                 # (B, 197, 768)

        # 3. Add positional encoding
        x = self.dropout(x + self.pos_embed)             # (B, 197, 768)

        # 4. Transformer encoder blocks
        attn_weights_all = []
        for block in self.blocks:
            x, attn = block(x)
            attn_weights_all.append(attn)

        # 5. Final norm
        x = self.norm(x)                                 # (B, 197, 768)

        # 6. CLS token → classifier
        logits = self.head(x[:, 0])                      # (B, num_classes)

        return logits, attn_weights_all

In [None]:
# Forward pass — verify every shape
vit = VisionTransformer(
    image_size=224, patch_size=16, in_channels=3,
    num_classes=10, d_model=768, num_heads=12, num_layers=12
)

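x = torch.randn(2, 3, 224, 224)  # batch of 2 images
vit.eval()                       # disable dropout so the check is deterministic
with torch.no_grad():            # no gradients needed for a shape check
    logits, attn_weights_all = vit(x)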

print(f"Input:                     {x.shape}")
print(f"Output logits:             {logits.shape}")                     # (2, 10)
print(f"Number of encoder blocks:  {len(attn_weights_all)}")
print(f"Attn weights per block:    {attn_weights_all[0].shape}")        # (2, 12, 197, 197)
print("\nCLS token attends to 196 patches + itself = 197 positions")

In [None]:
# Parameter count
total = sum(p.numel() for p in vit.parameters())
trainable = sum(p.numel() for p in vit.parameters() if p.requires_grad)

print(f"Total parameters:     {total:,}")
print(f"Trainable parameters: {trainable:,}")
print(f"Model size (FP32):    {total * 4 / 1e6:.1f} MB")
print("\nViT-B/16 reference:   86M parameters. Expected range: 85-87M")

## Shape Summary

| Tensor | Shape | Notes |
|---|---|---|
| Input image | `(B, 3, 224, 224)` | — |
| After Conv2d (patch embed) | `(B, 768, 14, 14)` | 14 = 224 / 16 |
| Patch tokens | `(B, 196, 768)` | 196 = 14×14 patches |
| After CLS prepend | `(B, 197, 768)` | position 0 = CLS token |
| After positional encoding | `(B, 197, 768)` | same shape — values shifted |
| Encoder block output | `(B, 197, 768)` | residuals preserve shape through all 12 blocks |
| Attention weights per block | `(B, 12, 197, 197)` | 12 heads, each token attends to 197 positions |
| Final CLS token | `(B, 768)` | `x[:, 0]` — only this goes to classifier |
| Logits | `(B, 10)` | 10 scene classes (SUN397 subset) |
| Parameters (ViT-B/16) | ~86M | patch_embed + pos_embed + 12 blocks + head |

In [None]:
# Visualize CLS token attention from the last block
# Row 0 = how CLS token attends to all 196 patches
last_attn = attn_weights_all[-1][0].detach()  # (12 heads, 197, 197)
cls_attn  = last_attn[:, 0, 1:]               # (12, 196) — CLS row, patch columns only

fig, axes = plt.subplots(3, 4, figsize=(14, 10))
for h, ax in enumerate(axes.flat):
    attn_map = cls_attn[h].reshape(14, 14).numpy()
    im = ax.imshow(attn_map, cmap='hot')
    ax.set_title(f'Head {h+1} — CLS attention')
    ax.axis('off')
    plt.colorbar(im, ax=ax)

plt.suptitle('CLS token attention to 14×14 patch grid (last block, random weights)', y=1.01)
plt.tight_layout()
plt.savefig('../assets/gradcam/cls_attention_map.png', bbox_inches='tight', dpi=150)
plt.show()

print("Near-uniform attention is expected with random (untrained) weights.")
print("After training, heads tend to specialize and attend to class-relevant regions.")