In [None]:

# Cell 0: installs & imports
!pip install -q einops

import os
import random
import time
from tqdm import tqdm

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from einops import rearrange


In [None]:

# Cell 1: config, reproducibility, device
# Hyperparameters and run settings read by the later cells.
cfg = dict(
    seed=42,
    epochs=12,
    batch_size=128,
    lr=3e-4,
    weight_decay=1e-4,
    patch_size=4,
    embed_dim=128,
    depth=6,
    num_heads=4,
    mlp_ratio=4.0,
    dropout=0.1,
    num_classes=10,
)

# Seed every RNG in play with the same value: Python, NumPy, torch (CPU), torch (all CUDA devices).
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(cfg["seed"])

# Turn off the cuDNN autotuner and ask cuDNN for deterministic algorithms.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)


In [None]:

# Cell 2: CIFAR-10 data loaders
# Per-channel statistics passed to transforms.Normalize for both splits.
mean = (0.4914, 0.4822, 0.4465)
std = (0.2470, 0.2435, 0.2616)

# Steps shared by both pipelines: convert to tensor, then normalize per channel.
tensor_and_normalize = [transforms.ToTensor(), transforms.Normalize(mean, std)]

# Training adds random crop (with 4px padding) and horizontal flip before the shared steps.
train_transform = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        *tensor_and_normalize,
    ]
)

# Evaluation uses the shared steps only, with no augmentation.
test_transform = transforms.Compose(list(tensor_and_normalize))

data_root = "./data"

# Both splits are downloaded into data_root if they are not already present.
train_set = datasets.CIFAR10(
    root=data_root, train=True, download=True, transform=train_transform
)
test_set = datasets.CIFAR10(
    root=data_root, train=False, download=True, transform=test_transform
)

# Loader settings common to both splits; only shuffling differs.
loader_kwargs = dict(batch_size=cfg["batch_size"], num_workers=2, pin_memory=True)
train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)

print("Train size:", len(train_set), "Test size:", len(test_set))


In [None]:

# Cell 3: ViT implementation
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping square patches and embed each one.

    A single strided convolution (kernel == stride == ``patch_size``) performs
    the patch extraction and the linear projection in one step.

    Args:
        img_size: Side length of the square input image; must be divisible by ``patch_size``.
        patch_size: Side length of each square patch.
        in_chans: Number of input image channels.
        embed_dim: Output embedding width per patch.
    """

    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=128):
        super().__init__()
        assert img_size % patch_size == 0
        side = img_size // patch_size
        self.patch_size = patch_size
        self.n_patches_side = side
        self.n_patches = side * side
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        """Map images (B, C, H, W) to patch tokens (B, n_patches, embed_dim)."""
        # conv -> (B, E, H/p, W/p); flatten the grid -> (B, E, N); swap -> (B, N, E)
        return self.proj(x).flatten(2).transpose(1, 2)

class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder block: self-attention, then an MLP, each with a residual.

    Args:
        dim: Token embedding width (input and output feature size).
        num_heads: Number of attention heads given to ``nn.MultiheadAttention``.
        mlp_dim: Hidden width of the two-layer MLP.
        dropout: Dropout probability used inside attention and after each MLP linear layer.
    """

    def __init__(self, dim, num_heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        # Kept in the default sequence-first layout so parameter names and
        # checkpoints stay unchanged; forward() handles the layout swap.
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape (batch, tokens, dim); the output has the same shape."""
        # Build the sequence-first view once and pass the *same* tensor as
        # query, key and value: this lets MultiheadAttention take its packed
        # self-attention projection instead of three separate projections.
        seq_first = self.norm1(x).transpose(0, 1)
        # The averaged attention weights were always discarded, so skip computing them.
        attn_out, _ = self.attn(seq_first, seq_first, seq_first, need_weights=False)
        x = x + attn_out.transpose(0, 1)
        x = x + self.mlp(self.norm2(x))
        return x

class SimpleViT(nn.Module):
    """Small Vision Transformer classifier built on a learnable class token.

    Pipeline: patch embedding -> prepend class token -> add learned positional
    embedding -> ``depth`` Transformer blocks -> LayerNorm -> linear head on the
    class token.

    Args:
        img_size: Side length of the square input image.
        patch_size: Side length of each square patch.
        in_chans: Number of input image channels.
        num_classes: Number of output logits.
        embed_dim: Token embedding width.
        depth: Number of Transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        dropout: Dropout probability forwarded to every block.
    """

    def __init__(self, img_size=32, patch_size=4, in_chans=3, num_classes=10, embed_dim=128, depth=6, num_heads=4, mlp_ratio=4., dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        seq_len = self.patch_embed.n_patches + 1  # patches plus the class token

        # Learnable class token and positional table, both truncated-normal initialised.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        for param in (self.pos_embed, self.cls_token):
            nn.init.trunc_normal_(param, std=0.02)

        hidden = int(embed_dim * mlp_ratio)
        self.blocks = nn.ModuleList(
            TransformerBlock(embed_dim, num_heads, hidden, dropout) for _ in range(depth)
        )
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for a batch of images."""
        batch = x.shape[0]
        tokens = self.patch_embed(x)
        # Broadcast the single class token across the batch and put it first.
        cls = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat((cls, tokens), dim=1) + self.pos_embed
        for block in self.blocks:
            tokens = block(tokens)
        # Normalise, then classify from the class-token position only.
        return self.head(self.norm(tokens)[:, 0])


In [None]:

# Cell 4: instantiate model + optimizer
# Build the model from cfg. `dropout` is forwarded explicitly: previously the
# "dropout" config entry was never passed, so editing it had no effect and the
# model silently used SimpleViT's own default.
model = SimpleViT(
    img_size=32,
    patch_size=cfg["patch_size"],
    num_classes=cfg["num_classes"],
    embed_dim=cfg["embed_dim"],
    depth=cfg["depth"],
    num_heads=cfg["num_heads"],
    mlp_ratio=cfg["mlp_ratio"],
    dropout=cfg["dropout"],
).to(device)

# Cross-entropy on raw logits; AdamW with decoupled weight decay.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=cfg["lr"], weight_decay=cfg["weight_decay"])
# Cosine decay of the learning rate over the full run (stepped once per epoch).
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])

def accuracy_from_logits(logits, labels):
    """Return the fraction of rows whose arg-max class equals the label, as a Python float.

    Args:
        logits: Tensor of per-class scores; the class axis is dim 1.
        labels: Tensor of integer class indices, one per row of ``logits``.
    """
    hits = logits.argmax(dim=1).eq(labels)
    return hits.float().mean().item()


In [None]:

# Cell 5: training loop
# Train for cfg["epochs"] epochs. After every epoch the model is evaluated on
# test_loader and its weights are checkpointed whenever accuracy improves.
# NOTE(review): the checkpoint is selected on the test set, so the best/final
# accuracy is optimistically biased. Hold out a validation split of the
# training data if an unbiased test number is required.
best_acc = 0.0
best_epoch = -1
save_path = "vit_cifar10_best.pth"

for epoch in range(1, cfg["epochs"] + 1):
    # ---- training pass ----
    model.train()
    running_loss, running_acc = 0.0, 0.0
    for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']} - train", leave=False):
        # non_blocking lets the copy overlap with compute when the loader
        # yields pinned host memory; it is a no-op otherwise.
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad()
        logits = model(imgs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # Weight per-batch means by batch size so the last (smaller) batch
        # is averaged correctly.
        batch_n = imgs.size(0)
        running_loss += loss.item() * batch_n
        running_acc += accuracy_from_logits(logits, labels) * batch_n

    epoch_loss = running_loss / len(train_set)
    epoch_acc = running_acc / len(train_set)

    # ---- evaluation pass (dropout disabled, no gradients) ----
    model.eval()
    test_loss, test_acc = 0.0, 0.0
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs = imgs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            logits = model(imgs)
            loss = criterion(logits, labels)
            batch_n = imgs.size(0)
            test_loss += loss.item() * batch_n
            test_acc += accuracy_from_logits(logits, labels) * batch_n

    test_loss /= len(test_set)
    test_acc /= len(test_set)
    # One scheduler step per epoch, after that epoch's optimizer steps.
    scheduler.step()

    # The losses were previously computed and then dropped; report them too.
    print(
        f"Epoch {epoch:02d} | Train loss: {epoch_loss:.4f} | Train acc: {epoch_acc*100:.2f}% "
        f"| Test loss: {test_loss:.4f} | Test acc: {test_acc*100:.2f}%"
    )

    # Keep only the weights of the best-scoring epoch so far.
    if test_acc > best_acc:
        best_acc = test_acc
        best_epoch = epoch
        torch.save(model.state_dict(), save_path)
        print(f"New best model saved (epoch {epoch}) -> {best_acc*100:.2f}%")


In [None]:

# Cell 6: final evaluation
# Restore the weights stored at save_path and score them on the test loader.
best_state = torch.load(save_path, map_location=device)
model.load_state_dict(best_state)
model.eval()

correct = 0
total = 0
with torch.no_grad():
    for imgs, labels in test_loader:
        imgs = imgs.to(device)
        labels = labels.to(device)
        # Predicted class = index of the largest logit in each row.
        preds = model(imgs).argmax(dim=1)
        correct += preds.eq(labels).sum().item()
        total += labels.size(0)

final_acc = 100.0 * correct / total
print(f"Final Test Accuracy: {final_acc:.2f}%")
