# Q1 — Vision Transformer on CIFAR-10 (PyTorch)
This notebook implements a Vision Transformer (ViT) from scratch in PyTorch and trains it on CIFAR-10.

**Usage (Colab):**
1. Runtime > Change runtime type > GPU
2. Run all cells. The install cell will install required packages.

The notebook includes:
- Patch embedding, learnable positional embeddings, CLS token
- Transformer encoder blocks (multi-head self-attention + MLP, each wrapped in a residual connection with LayerNorm applied before the sub-layer)
- Training loop, per-epoch evaluation on the CIFAR-10 test split, and simple augmentation (random crop with padding, random horizontal flip)

Adjust the hyperparameters in the `Config` cell (for example `epochs`, `emb_dim`, `depth` or `lr`) to improve accuracy.

In [None]:
# Install dependencies (Colab)
# The leading `!` is notebook shell-escape syntax, so this cell runs only inside
# an IPython/Jupyter/Colab kernel and is not valid in a plain Python script.
# NOTE(review): `timm` is installed here but is not referenced by any cell
# visible in this notebook — confirm it is still needed.
!pip install -q torch torchvision einops tqdm timm
import torch
import torchvision
# Report the torch version in use.
print('torch', torch.__version__)

In [None]:
# Config: the hyperparameters read by the model, data and training cells.
from types import SimpleNamespace

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    _device = 'cuda'
else:
    _device = 'cpu'

cfg = SimpleNamespace(**{
    # Input geometry: 32x32 images cut into 4x4 patches -> (32 // 4) ** 2 = 64 patches.
    'img_size': 32,
    'patch_size': 4,
    'in_channels': 3,
    'num_classes': 10,
    # Encoder size: 192-wide tokens split over 3 heads (64 dims per head).
    'emb_dim': 192,
    'depth': 6,
    'num_heads': 3,
    'mlp_ratio': 4,   # MLP hidden width = emb_dim * mlp_ratio
    'dropout': 0.1,
    # Optimisation.
    'lr': 3e-4,
    'batch_size': 128,
    'epochs': 20,
    'device': _device,
})
cfg  # last expression of the cell, so the notebook displays the namespace

In [None]:
# Model implementation (compact)
import torch.nn as nn
from einops import rearrange
import math

class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    A convolution whose kernel size and stride both equal ``patch_size``
    maps every patch to an ``embed_dim``-dimensional vector.

    Attributes:
        patch_size: side length of one patch, in pixels.
        num_patches: ``(img_size // patch_size) ** 2``.
        proj: the patch-projection convolution.
    """

    def __init__(self, img_size, patch_size, in_chans, embed_dim):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.patch_size = patch_size
        self.num_patches = patches_per_side ** 2
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Return the patch embeddings of ``x`` as a (B, N, E) tensor."""
        feature_map = self.proj(x)                 # (B, E, H/ps, W/ps)
        tokens = feature_map.flatten(start_dim=2)  # (B, E, N)
        return tokens.permute(0, 2, 1)             # (B, N, E)

class MLP(nn.Module):
    """Feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        in_features: width of the input and of the output.
        hidden_features: width of the hidden layer; when omitted (or falsy)
            it defaults to ``in_features``.
        dropout: dropout probability applied after each of the two stages.
    """

    def __init__(self, in_features, hidden_features=None, dropout=0.):
        super().__init__()
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        # One Dropout module is reused after both linear stages.
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the two-layer network to the last dimension of ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

class Attention(nn.Module):
    """Multi-head self-attention over a (B, N, C) token sequence.

    A single linear layer produces queries, keys and values; the channel
    dimension is split evenly across ``num_heads`` heads, scaled dot-product
    attention is computed per head, and the heads are concatenated and
    projected back to ``dim`` channels.

    Args:
        dim: channel width of the input and output tokens.
        num_heads: number of attention heads; must divide ``dim``.
        qkv_bias: whether the fused q/k/v projection has a bias term.
        attn_dropout: dropout probability on the attention weights.
        proj_dropout: dropout probability on the output projection.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide ``dim``.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_dropout=0., proj_dropout=0.):
        super().__init__()
        # Fail at construction time: an uneven split would otherwise surface
        # only in forward() as an opaque reshape error.
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f'dim ({dim}) must be divisible by num_heads ({num_heads})'
            )
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5  # 1 / sqrt(head_dim)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_dropout)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_dropout)

    def forward(self, x):
        """Return the attended tokens, with the same (B, N, C) shape as ``x``."""
        B, N, C = x.shape
        head_dim = C // self.num_heads
        # (B, N, 3*C) -> (3, B, heads, N, head_dim), then split into q, k, v.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, N)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        # Merge the heads back into the channel dimension.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class Block(nn.Module):
    """One Transformer encoder layer: attention then MLP, each with a residual.

    LayerNorm is applied to the input of each sub-layer, and the sub-layer
    output is added back onto the un-normalised stream.

    Args:
        dim: token channel width.
        num_heads: number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        dropout: dropout probability for the attention output projection and
            for the MLP; dropout on the attention weights is fixed at 0.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., dropout=0.):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            attn_dropout=0.,
            proj_dropout=dropout,
        )
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, hidden_dim, dropout=dropout)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        attn_out = self.attn(self.norm1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.norm2(x))
        return x + mlp_out

class ViT(nn.Module):
    """Vision Transformer classifier built from ``PatchEmbed`` and ``Block``.

    Images are turned into patch tokens, a learnable class token is prepended,
    learnable positional embeddings are added, and the sequence is passed
    through ``cfg.depth`` encoder blocks. The final LayerNorm output at the
    class-token position feeds a linear classification head.

    Args:
        cfg: object exposing ``img_size``, ``patch_size``, ``in_channels``,
            ``emb_dim``, ``dropout``, ``num_heads``, ``mlp_ratio``, ``depth``
            and ``num_classes``.
    """

    def __init__(self, cfg):
        super().__init__()
        dim = cfg.emb_dim
        self.patch_embed = PatchEmbed(cfg.img_size, cfg.patch_size, cfg.in_channels, dim)
        seq_len = self.patch_embed.num_patches + 1  # patches plus the class token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, dim))
        self.pos_drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList()
        for _ in range(cfg.depth):
            self.blocks.append(Block(dim, cfg.num_heads, cfg.mlp_ratio, cfg.dropout))
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, cfg.num_classes)
        # Both learnable embeddings start from a truncated normal (std 0.02).
        for param in (self.pos_embed, self.cls_token):
            nn.init.trunc_normal_(param, std=0.02)

    def forward(self, x):
        """Return class logits of shape (B, num_classes) for the images ``x``."""
        batch = x.shape[0]
        tokens = self.patch_embed(x)
        cls = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat([cls, tokens], dim=1) + self.pos_embed
        tokens = self.pos_drop(tokens)
        for layer in self.blocks:
            tokens = layer(tokens)
        cls_out = self.norm(tokens)[:, 0]  # class-token position only
        return self.head(cls_out)


In [None]:
# Data, transforms, dataloaders
import torchvision.transforms as T
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

# Per-channel mean / std handed to T.Normalize for both splits
# (presumably the CIFAR-10 channel statistics — verify if changed).
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.247, 0.243, 0.261)

# Training pipeline: random crop and flip, then tensor conversion + normalisation.
transform_train = T.Compose(
    [T.RandomCrop(32, padding=4), T.RandomHorizontalFlip()]
    + [T.ToTensor(), T.Normalize(_NORM_MEAN, _NORM_STD)]
)
# Test pipeline: tensor conversion + normalisation only, no augmentation.
transform_test = T.Compose([T.ToTensor(), T.Normalize(_NORM_MEAN, _NORM_STD)])

_DATA_ROOT = './data'
trainset = CIFAR10(root=_DATA_ROOT, train=True, download=True, transform=transform_train)
testset = CIFAR10(root=_DATA_ROOT, train=False, download=True, transform=transform_test)

# Training batches are shuffled and sized from cfg; test batches are fixed at 256.
trainloader = DataLoader(
    trainset,
    batch_size=cfg.batch_size,
    shuffle=True,
    num_workers=2,
)
testloader = DataLoader(
    testset,
    batch_size=256,
    shuffle=False,
    num_workers=2,
)


In [None]:
# Training utilities
import torch.optim as optim
from tqdm import tqdm

def evaluate(model, loader, device):
    """Return the top-1 accuracy of ``model`` over ``loader`` as a percentage.

    The model is switched to eval mode and left there; callers that continue
    training must call ``model.train()`` again themselves.

    Args:
        model: callable module mapping a batch of inputs to per-class logits.
        loader: iterable yielding ``(inputs, labels)`` batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        ``100 * correct / total``, or ``0.0`` when ``loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    if total == 0:
        # An empty loader would otherwise raise ZeroDivisionError below.
        return 0.0
    return 100 * correct / total

# Build the model, optimiser, learning-rate schedule and loss from the config.
device = cfg.device
model = ViT(cfg).to(device)
opt = optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=0.05)
# Cosine schedule spanning the whole run; stepped once per epoch below.
scheduler = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=cfg.epochs)
criterion = nn.CrossEntropyLoss()
best_acc = 0.0

for epoch in range(cfg.epochs):
    model.train()  # evaluate() leaves the model in eval mode
    pbar = tqdm(trainloader, desc=f'Epoch {epoch+1}/{cfg.epochs}')
    running_loss = 0.0
    # Count batches explicitly: while iterating, tqdm only refreshes `pbar.n`
    # at display intervals, so it lags the true step count and dividing by it
    # overstates the running mean loss.
    for step, (xb, yb) in enumerate(pbar, start=1):
        xb, yb = xb.to(device), yb.to(device)
        opt.zero_grad()
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        opt.step()
        running_loss += loss.item()
        pbar.set_postfix(loss=running_loss / step)
    scheduler.step()
    # NOTE(review): accuracy is measured on `testloader` (the CIFAR-10 test
    # split) and also drives checkpoint selection, so the "best" figure is
    # selected on test data — consider a held-out validation split.
    acc = evaluate(model, testloader, device)
    print(f'Validation accuracy: {acc:.2f}%')
    if acc > best_acc:
        # Keep only the weights of the best-scoring epoch so far.
        best_acc = acc
        torch.save(model.state_dict(), 'best_vit_cifar10.pth')
    print(f'Best so far: {best_acc:.2f}%')

print('Training finished. Best accuracy:', best_acc)
