[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](YOUR_COLAB_LINK_HERE)

# 00 Building Blocks From Scratch
## Objectives
- Implement a tiny CNN and a tiny ViT step-by-step.
- Track tensor shapes at every stage.
- Visualize filters, feature maps, and attention.


In [None]:
!pip -q install torch torchvision matplotlib

## A) Setup + CIFAR-10 data

In [None]:
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt

# Seed torch's global RNG up front so weight init and shuffling are reproducible.
torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Only ToTensor is applied: no normalisation and no augmentation.
transform = transforms.Compose([transforms.ToTensor()])
train_ds, test_ds = (
    torchvision.datasets.CIFAR10(root='../data', train=split_is_train, download=True, transform=transform)
    for split_is_train in (True, False)
)

# Small subsets (2,000 train / 500 test images) keep each epoch quick.
train_subset = Subset(train_ds, range(2000))
test_subset = Subset(test_ds, range(500))

train_loader = DataLoader(train_subset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_subset, batch_size=64, shuffle=False)


In [None]:
# Show the first 10 images of one shuffled training batch with their class names.
classes = train_ds.classes
images, labels = next(iter(train_loader))
fig, axes = plt.subplots(2, 5, figsize=(8, 4))
for ax, img_chw, label in zip(axes.flat, images, labels):
    # matplotlib expects channels-last, the loader yields channels-first.
    ax.imshow(img_chw.permute(1, 2, 0))
    ax.set_title(classes[label])
    ax.axis('off')
plt.suptitle('CIFAR-10 Batch')
plt.tight_layout()
plt.show()


## B) Build a tiny CNN step-by-step

We start with a single convolution, then add pooling, a second conv block, and finally a linear classifier, printing the tensor shape after every layer.

In [None]:
import torch.nn as nn

def shape_summary(model, x):
    """Print the output shape after each direct child of ``model``.

    The children returned by ``model.named_children()`` are applied one after
    another, so this is only meaningful for purely sequential containers such
    as ``nn.Sequential``.

    Args:
        model: Module whose children are applied in registration order.
        x: Input tensor fed to the first child.

    Returns:
        The output of the last child (computed without autograd tracking).
    """
    print('Input:', x.shape)
    # This is a pure shape inspection, so skip building an autograd graph.
    with torch.no_grad():
        for name, layer in model.named_children():
            x = layer(x)
            print(f'{name}: {x.shape}')
    return x


In [None]:
# Step 1: Conv + ReLU
# A single 3x3 convolution (3 -> 8 channels); padding=1 keeps the spatial size.
cnn_step1 = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
)
shape_summary(cnn_step1, images[:1])


In [None]:
# Step 2: Add pooling
# Same conv + ReLU, then a 2x2 max-pool that halves height and width.
cnn_step2 = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2),
)
shape_summary(cnn_step2, images[:1])


In [None]:
# Step 3: Add a second conv block
# Two conv -> ReLU -> pool blocks: channels 3 -> 8 -> 16, spatial size halved twice.
cnn_step3 = nn.Sequential(
    # block 1
    nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2),
    # block 2
    nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2),
)
shape_summary(cnn_step3, images[:1])


In [None]:
# Step 4: Add classifier
# The two conv blocks from step 3, then flatten + a linear layer over 10 classes.
# 16 * 8 * 8: 16 channels on an 8x8 map (the input halved by each of the two pools).
cnn = nn.Sequential(
    # block 1
    nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2),
    # block 2
    nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2),
    # classifier
    nn.Flatten(),
    nn.Linear(in_features=16 * 8 * 8, out_features=10),
)
shape_summary(cnn, images[:1])


### Train for 1 epoch

In [None]:
# Train the CNN for one pass over the training subset, then report held-out accuracy.
# The loop variables are named batch_* so this cell does not rebind the global
# `images` / `labels` (the preview batch) that later visualization cells read.
cnn = cnn.to(device)
optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
train_losses = []

cnn.train()
for batch_images, batch_labels in train_loader:
    batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)
    optimizer.zero_grad()
    logits = cnn(batch_images)
    loss = loss_fn(logits, batch_labels)
    loss.backward()
    optimizer.step()
    train_losses.append(loss.item())

# Held-out accuracy on the 500-image test subset.
cnn.eval()
correct = total = 0
with torch.no_grad():
    for batch_images, batch_labels in test_loader:
        batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)
        preds = cnn(batch_images).argmax(dim=1)
        correct += (preds == batch_labels).sum().item()
        total += batch_labels.numel()
print(f'CNN test accuracy: {correct / total:.2%}')

plt.figure(figsize=(5, 3))
plt.plot(train_losses)
plt.title('CNN Training Loss (1 epoch)')
plt.xlabel('Step')
plt.ylabel('Loss')
plt.show()


### Visualize filters + feature maps

In [None]:
# Show the 8 learned kernels of the first conv layer (weight shape (8, 3, 3, 3)) as tiny RGB images.
filters = cnn[0].weight.data.clone()
# Min-max scale over all filters jointly so their relative magnitudes stay comparable.
lo, hi = filters.min(), filters.max()
filters = (filters - lo) / (hi - lo + 1e-8)
fig, axes = plt.subplots(2, 4, figsize=(6, 3))
for idx, ax in enumerate(axes.flat):
    kernel_hwc = filters[idx].permute(1, 2, 0).cpu()
    ax.imshow(kernel_hwc)
    ax.axis('off')
plt.suptitle('First-layer filters')
plt.tight_layout()
plt.show()


In [None]:
# Feature maps from first conv layer
# Pass one image through only cnn[0] (the conv, before ReLU) and plot each of its 8 output channels.
with torch.no_grad():
    feat_maps = cnn[0](images[:1].to(device)).cpu()
fig, axes = plt.subplots(2, 4, figsize=(6, 3))
for ax, channel_map in zip(axes.flat, feat_maps[0]):
    ax.imshow(channel_map, cmap='viridis')
    ax.axis('off')
plt.suptitle('Feature maps (layer 1)')
plt.tight_layout()
plt.show()


## C) Build a tiny ViT step-by-step

In [None]:
import math

class Patchify(nn.Module):
    """Split images into flattened, non-overlapping square patches.

    Input ``(B, C, H, W)`` becomes ``(B, num_patches, C * p * p)``. Patches are
    ordered row by row (left to right, then top to bottom), and each patch is
    flattened channel-first. Border rows/columns that do not fill a whole patch
    are dropped.
    """

    def __init__(self, patch_size=4):
        super().__init__()
        self.patch_size = patch_size

    def forward(self, x):
        batch, channels, height, width = x.shape
        p = self.patch_size
        rows, cols = height // p, width // p
        # Crop any leftover border, then split H and W into (grid, within-patch) axes.
        grid = x[:, :, : rows * p, : cols * p].reshape(batch, channels, rows, p, cols, p)
        # (B, C, rows, p, cols, p) -> (B, rows, cols, C, p, p)
        grid = grid.permute(0, 2, 4, 1, 3, 5)
        return grid.reshape(batch, -1, channels * p * p)

# Sanity-check the patch tensor produced for a single image.
patchify = Patchify(patch_size=4)
patches = patchify(images[:1])
print(f'Patches shape: {patches.shape}')


In [None]:
class TinyMHSA(nn.Module):
    """Minimal multi-head self-attention.

    One linear layer produces queries, keys and values for all heads at once.
    Scaled dot-product attention is computed per head, then the heads are
    concatenated and projected back to ``embed_dim``.

    Args:
        embed_dim: Token embedding size; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, embed_dim=64, num_heads=4):
        super().__init__()
        assert embed_dim % num_heads == 0
        head_dim = embed_dim // num_heads
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = head_dim ** -0.5
        # qkv is created before proj; keep this order so seeded initialisation is unchanged.
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Attend over the tokens of ``x`` of shape (batch, tokens, embed_dim).

        Returns ``(output, weights)``:
        - ``output`` has the same shape as ``x``.
        - ``weights`` has shape (batch, heads, tokens, tokens), and each row sums to 1.
        """
        batch, tokens, dim = x.shape
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, self.head_dim)
        # -> (3, batch, heads, tokens, head_dim), then split into q / k / v.
        query, key, value = packed.permute(2, 0, 3, 1, 4).unbind(0)
        scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
        weights = torch.softmax(scores, dim=-1)
        # Merge heads back: (batch, heads, tokens, head_dim) -> (batch, tokens, dim).
        merged = torch.matmul(weights, value).transpose(1, 2).reshape(batch, tokens, dim)
        return self.proj(merged), weights


In [None]:
class TinyTransformer(nn.Module):
    """Pre-norm transformer encoder block.

    Each sub-layer adds its output back to its input:
    ``x + attn(norm1(x))``, then ``x + mlp(norm2(x))``.
    ``forward`` returns the updated tokens together with the attention weights.
    """

    def __init__(self, embed_dim=64, num_heads=4, mlp_ratio=2.0):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        # Registration order matches the original so seeded initialisation is unchanged.
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = TinyMHSA(embed_dim, num_heads)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),
        )

    def forward(self, x):
        attended, attn_weights = self.attn(self.norm1(x))
        after_attn = x + attended
        out = after_attn + self.mlp(self.norm2(after_attn))
        return out, attn_weights


In [None]:
class ToyViT(nn.Module):
    """Tiny Vision Transformer classifier.

    Pipeline:
    1. Patchify the image and apply a linear patch embedding.
    2. Prepend a learnable [CLS] token and add a learnable position embedding.
    3. Run ``depth`` transformer blocks.
    4. Apply LayerNorm, then a linear head to the [CLS] token.

    ``forward`` returns ``(logits, attn_maps)``, where ``attn_maps`` lists each
    block's attention weights, first block first.
    """

    def __init__(self, img_size=32, patch_size=4, embed_dim=64, depth=2, num_heads=4, num_classes=10):
        super().__init__()
        patch_dim = 3 * patch_size * patch_size  # 3 input channels, flattened per patch
        num_patches = (img_size // patch_size) ** 2
        self.patchify = Patchify(patch_size)
        self.proj = nn.Linear(patch_dim, embed_dim)
        # Zero-initialised learnable tokens; the extra position slot belongs to [CLS].
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.blocks = nn.ModuleList(
            [TinyTransformer(embed_dim, num_heads) for _ in range(depth)]
        )
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        tokens = self.proj(self.patchify(x))
        cls = self.cls_token.expand(tokens.shape[0], -1, -1)
        tokens = torch.cat((cls, tokens), dim=1) + self.pos_embed
        attn_maps = []
        for blk in self.blocks:
            tokens, weights = blk(tokens)
            attn_maps.append(weights)
        # Classify from the [CLS] position (index 0) after the final LayerNorm.
        logits = self.head(self.norm(tokens)[:, 0])
        return logits, attn_maps


### Train the toy ViT for 1 epoch

In [None]:
# Train the toy ViT for one pass over the training subset, then report held-out accuracy.
# The loop variables are named batch_* so this cell does not rebind the global
# `images` / `labels` that the attention-visualization cells below read.
vit = ToyViT().to(device)
optimizer = torch.optim.Adam(vit.parameters(), lr=3e-4)
loss_fn = nn.CrossEntropyLoss()
vit_losses = []
vit.train()
for batch_images, batch_labels in train_loader:
    batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)
    optimizer.zero_grad()
    logits, _ = vit(batch_images)
    loss = loss_fn(logits, batch_labels)
    loss.backward()
    optimizer.step()
    vit_losses.append(loss.item())

# Held-out accuracy on the 500-image test subset (ToyViT returns (logits, attn_maps)).
vit.eval()
correct = total = 0
with torch.no_grad():
    for batch_images, batch_labels in test_loader:
        batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)
        preds = vit(batch_images)[0].argmax(dim=1)
        correct += (preds == batch_labels).sum().item()
        total += batch_labels.numel()
print(f'Toy ViT test accuracy: {correct / total:.2%}')

plt.figure(figsize=(5, 3))
plt.plot(vit_losses)
plt.title('Toy ViT Training Loss (1 epoch)')
plt.xlabel('Step')
plt.ylabel('Loss')
plt.show()


## D) Attention visualization

In [None]:
# Where does the [CLS] token attend in the last transformer block?
vit.eval()
with torch.no_grad():
    logits, attn_maps = vit(images[:1].to(device))
# attn_maps[-1] is (batch, heads, tokens, tokens): average the heads, keep sample 0.
head_avg = attn_maps[-1].mean(dim=1)[0]
# Row 0 is the [CLS] query; drop column 0 (its attention to itself).
cls_to_patches = head_avg[0, 1:]
grid = int(cls_to_patches.numel() ** 0.5)
attn_map = cls_to_patches.reshape(grid, grid).cpu().numpy()
plt.figure(figsize=(4, 4))
plt.imshow(attn_map, cmap='inferno')
plt.title('Attention over patches')
plt.axis('off')
plt.show()


In [None]:
import torch.nn.functional as F
# Overlay the [CLS] attention grid on the input image.
img = images[0].permute(1, 2, 0).cpu()  # channels-last for matplotlib
# interpolate expects (N, C, H, W), so add batch and channel axes.
attn_tensor = torch.tensor(attn_map).unsqueeze(0).unsqueeze(0)
# Upsample to the image's own resolution instead of a hard-coded 32x32,
# so the overlay stays aligned if the input size changes.
img_hw = tuple(img.shape[:2])
attn_up = F.interpolate(attn_tensor, size=img_hw, mode='bilinear', align_corners=False)[0, 0]
plt.figure(figsize=(4, 4))
plt.imshow(img)
plt.imshow(attn_up, cmap='inferno', alpha=0.5)
plt.title('Attention overlay')
plt.axis('off')
plt.show()


## E) Wrap-up

### Scale Up
- Train on the full 50,000-image training set (not just the 2,000-image subset) for 10–50 epochs, and increase the embedding size for better results.
- Add data augmentation such as random crops and color jitter.


### Summary
- CNNs bake in locality and translation equivariance.
- ViTs treat images as token sequences with global attention.
- The toy ViT lacks these built-in image priors, so it is data-hungry and needs more data and training to compete with the CNN.

### Exercises
1. Change patch size from 4 to 8 and observe shapes and accuracy.
2. Increase the number of heads or depth.
3. Add data augmentation and compare losses.
4. Remove the LayerNorm layers and observe how training stability changes.

### Further Reading
- https://arxiv.org/abs/2010.11929 (ViT)
- https://pytorch.org/tutorials/
