In [None]:
# ==== CELL 1: Notebook header / short notes =====
"""
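Vision Transformer (ViT) trained from scratch on CIFAR-10 (ready to run).
Implements patch embedding ("patchify"), learnable positional embeddings, a CLS token, and pre-norm
Transformer encoder blocks (multi-head self-attention + MLP, each wrapped in LayerNorm and a residual connection).
Uses AdamW with a cosine learning-rate schedule (with linear warmup), RandAugment, label smoothing, and optional MixUp.
At the end it prints BEST_TEST_ACCURACY: XX.XX% and writes the same value to /content/best_accuracy.txt for easy copying.
Note: the best epoch is selected on the CIFAR-10 test split, so the reported number is an optimistic estimate.
Edit CONFIG for quicker runs (e.g. fewer epochs or a different batch_size).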
"""

In [None]:
#===== CELL 2: Install + Imports =====
!pip -q install timm
import os, math, random, time
from pathlib import Path
from tqdm.notebook import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms, datasets
import timm

# Prefer the GPU when one is visible to PyTorch; everything below moves tensors to `device`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

In [None]:
#===== CELL 3: CONFIG =====
# All tunable hyper-parameters for the run, in one place.
CONFIG = dict(
    # --- reproducibility / batching ---
    seed=42,
    batch_size=128,
    # --- optimisation ---
    epochs=80,  # reduce to 20-30 for a quick test
    lr=3e-4,
    weight_decay=0.05,  # AdamW weight decay
    # --- model geometry ---
    image_size=32,
    patch_size=4,  # 32 / 4 -> 8 x 8 = 64 patch tokens (+1 CLS token)
    embed_dim=256,  # must be divisible by num_heads (Attention reshapes into heads)
    depth=8,  # number of Transformer blocks
    num_heads=8,
    mlp_ratio=4.0,  # MLP hidden width = embed_dim * mlp_ratio
    # --- regularisation ---
    drop_rate=0.0,
    drop_path_rate=0.1,  # rate of the last block; earlier blocks get linearly smaller rates
    mixup_alpha=0.8,  # Beta(alpha, alpha) concentration used by MixUp
    label_smoothing=0.1,
    # --- data loading ---
    num_workers=4,
    use_mixup=True,
    randaugment=True,
    # --- training loop ---
    gradient_clip=1.0,  # max global grad norm; None disables clipping
    out_dir="/content/vit_cifar10",
    save_every=10,  # also write a checkpoint every N epochs
)
# Create the checkpoint directory (and any missing parents); no-op if it already exists.
Path(CONFIG["out_dir"]).mkdir(parents=True, exist_ok=True)

In [None]:
#===== CELL 4: Reproducibility =====
def set_seed(seed=42):
    """Seed Python's `random`, NumPy's global RNG and PyTorch (CPU and all CUDA devices).

    Args:
        seed: integer seed applied to every RNG listed above.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


# Called at top level (not inside `set_seed`, which would recurse and never run) so the
# seeds are applied before the datasets, loaders and model are created in later cells.
set_seed(CONFIG["seed"])

In [None]:
#===== CELL 5: Datasets & Transforms =====
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# Flip + padded random crop are always applied; RandAugment is optional and runs last on
# the PIL image (before ToTensor/Normalize), in the same order as before.
train_transforms = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(CONFIG["image_size"], padding=4),
]
if CONFIG["randaugment"]:
    train_transforms.append(transforms.RandAugment(num_ops=2, magnitude=9))
train_transforms += [transforms.ToTensor(), transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)]
train_transform = transforms.Compose(train_transforms)
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

DATA_ROOT = "/content/data"
train_dataset = datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root=DATA_ROOT, train=False, download=True, transform=test_transform)

_loader_kwargs = dict(
    batch_size=CONFIG["batch_size"],
    num_workers=CONFIG["num_workers"],
    # Page-locked host memory only speeds up host->GPU copies; on CPU-only runs it just warns.
    pin_memory=torch.cuda.is_available(),
    # Keep worker processes alive across epochs instead of re-spawning them for every pass
    # (only valid when worker processes are used at all).
    persistent_workers=CONFIG["num_workers"] > 0,
)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

In [None]:
#===== CELL 6: MixUp + Loss =====
def mixup_data(x, y, alpha=1.0):
    """Apply MixUp to a batch.

    Draws ``lam ~ Beta(alpha, alpha)`` and blends every sample with a randomly
    permuted partner from the same batch.

    Args:
        x: input batch; the first dimension indexes samples.
        y: labels aligned with the first dimension of ``x``.
        alpha: Beta-distribution concentration. ``alpha <= 0`` disables mixing.

    Returns:
        ``(mixed_x, y_a, y_b, lam)``: the blended batch, the original labels, the
        partner labels and the mixing weight as a Python float. Train with
        ``lam * loss(logits, y_a) + (1 - lam) * loss(logits, y_b)``.
    """
    if alpha <= 0:
        # Same 4-tuple layout as the mixing path so callers can always unpack
        # `x, y_a, y_b, lam`; lam=1.0 makes the y_b term vanish.
        return x, y, y, 1.0
    lam = float(np.random.beta(alpha, alpha))
    index = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    return mixed_x, y, y[index], lam


class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy against a label-smoothed target distribution.

    With ``smoothing = s > 0`` and ``K`` classes the target puts ``1 - s`` on the
    true class and ``s / (K - 1)`` on every other class. With ``s <= 0`` this is
    plain ``F.cross_entropy``. Both paths return the mean over the batch.
    """

    def __init__(self, smoothing=0.0):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, preds, target):
        # Guard clause: no smoothing requested -> defer to the built-in loss.
        if not self.smoothing > 0:
            return F.cross_entropy(preds, target)

        num_classes = preds.size(-1)
        log_probs = F.log_softmax(preds, dim=-1)
        off_target = self.smoothing / (num_classes - 1)
        # The target distribution is a constant; keep it out of the autograd graph.
        with torch.no_grad():
            soft_targets = torch.full_like(log_probs, off_target)
            soft_targets.scatter_(1, target.unsqueeze(1), 1.0 - self.smoothing)
        per_sample = -(soft_targets * log_probs).sum(dim=-1)
        return per_sample.mean()


In [None]:
#===== CELL 7: ViT model (from-scratch) =====
class PatchEmbed(nn.Module):
    """Split a batched image into non-overlapping patches and linearly embed each one.

    A Conv2d whose kernel size equals its stride (``patch_size``) applies the same
    linear projection to every flattened patch.
    """

    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=256):
        super().__init__()
        assert img_size % patch_size == 0
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side

    def forward(self, x):
        # (B, in_chans, H, W) -> (B, embed_dim, H/p, W/p)
        grid = self.proj(x)
        # -> (B, embed_dim, num_patches) -> (B, num_patches, embed_dim)
        return grid.flatten(start_dim=2).transpose(1, 2)


class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    ``hidden_features`` and ``out_features`` default to ``in_features``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.0):
        super().__init__()
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, out)
        # One Dropout module, applied both after the activation and after fc2.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        for layer in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = layer(x)
        return x


class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention over a (B, N, C) token sequence."""

    def __init__(self, dim, num_heads=8, qkv_bias=True, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.num_heads = num_heads
        # Softmax temperature 1/sqrt(d_head).
        self.scale = (dim // num_heads) ** -0.5
        # A single projection yields queries, keys and values for all heads at once.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads
        # (B, N, 3C) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))
        # (B, heads, N, head_dim) -> (B, N, heads, head_dim) -> (B, N, C)
        context = torch.matmul(weights, v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(context))


class DropPath(nn.Module):
    """Stochastic depth: randomly drop the whole residual branch for individual samples.

    In training mode each sample's input is zeroed with probability ``drop_prob``
    and otherwise scaled by ``1 / (1 - drop_prob)``, so the expected value is
    unchanged. In eval mode, or with ``drop_prob == 0``, the input is returned as is.
    """

    def __init__(self, drop_prob=0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli draw per sample (first dim), broadcast over all other dims.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        if keep_prob > 0.0:
            # Avoid 0/0 when drop_prob == 1 (mask is all zeros then anyway).
            mask = mask / keep_prob
        return x * mask

    def extra_repr(self):
        return f"drop_prob={self.drop_prob}"


class Block(nn.Module):
    """Pre-norm Transformer encoder block.

    ``x = x + drop_path(attn(norm1(x)))`` followed by
    ``x = x + drop_path(mlp(norm2(x)))``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.0, drop=0.0, attn_drop=0.0, drop_path=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = Attention(dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=drop)
        # Per-sample stochastic depth rather than element-wise nn.Dropout: the whole residual
        # branch is skipped for a random subset of samples. No parameters, so state_dict
        # keys are unaffected.
        self.drop_path = nn.Identity() if drop_path == 0.0 else DropPath(drop_path)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.mlp = MLP(in_features=dim, hidden_features=int(dim * mlp_ratio), drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class ViTLike(nn.Module):
    """Vision Transformer classifier.

    Pipeline: patchify -> prepend a learnable CLS token -> add learnable positional
    embeddings -> dropout -> ``depth`` Transformer blocks -> final LayerNorm ->
    linear head applied to the CLS token.
    """

    def __init__(
        self,
        img_size=32,
        patch_size=4,
        in_chans=3,
        num_classes=10,
        embed_dim=256,
        depth=8,
        num_heads=8,
        mlp_ratio=4.0,
        drop_rate=0.0,
        drop_path_rate=0.1,
    ):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        seq_len = self.patch_embed.num_patches + 1  # +1 for the CLS token

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Per-block drop_path rate grows linearly from 0 (first block) to drop_path_rate (last).
        drop_path_rates = torch.linspace(0, drop_path_rate, depth).tolist()
        self.blocks = nn.ModuleList(
            [
                Block(embed_dim, num_heads, mlp_ratio, drop_rate, attn_drop=0.0, drop_path=rate)
                for rate in drop_path_rates
            ]
        )
        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
        self.head = nn.Linear(embed_dim, num_classes)

        # Token parameters first, then every Linear/LayerNorm (same order as before, so the
        # RNG draws for a given seed are unchanged).
        for token_param in (self.pos_embed, self.cls_token):
            nn.init.trunc_normal_(token_param, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform Linear weights with zero bias; LayerNorm reset to weight 1, bias 0."""
        if isinstance(m, nn.LayerNorm):
            nn.init.zeros_(m.bias)
            nn.init.ones_(m.weight)
            return
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, x):
        tokens = self.patch_embed(x)
        cls = self.cls_token.expand(tokens.shape[0], -1, -1)
        tokens = self.pos_drop(torch.cat((cls, tokens), dim=1) + self.pos_embed)
        for block in self.blocks:
            tokens = block(tokens)
        # Classify from the normalised CLS token (position 0).
        return self.head(self.norm(tokens)[:, 0])


In [None]:
#=====Build model,optimizer,scheduler,criterion =====

# Every architecture hyper-parameter except image_size shares its name with the ViTLike argument.
model = ViTLike(
    img_size=CONFIG["image_size"],
    **{
        name: CONFIG[name]
        for name in (
            "patch_size",
            "embed_dim",
            "depth",
            "num_heads",
            "mlp_ratio",
            "drop_rate",
            "drop_path_rate",
        )
    },
).to(device)

print(model)


def count_parameters(m):
    """Return the total number of elements over all parameters of `m` that require grad."""
    total = 0
    for param in m.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


print("Trainable params:", count_parameters(model))

# AdamW weight decay is applied only to weight matrices / conv kernels. Biases and LayerNorm
# gains (all 1-D) plus the learnable positional embedding and CLS token get no decay,
# following the usual ViT/DeiT recipe (cf. timm's `param_groups_weight_decay`).
_NO_DECAY_NAMES = {"pos_embed", "cls_token"}
decay_params, no_decay_params = [], []
for param_name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    if param.ndim <= 1 or param_name in _NO_DECAY_NAMES:
        no_decay_params.append(param)
    else:
        decay_params.append(param)

opt = torch.optim.AdamW(
    [
        {"params": decay_params, "weight_decay": CONFIG["weight_decay"]},
        {"params": no_decay_params, "weight_decay": 0.0},
    ],
    lr=CONFIG["lr"],
)


def get_scheduler(optimizer,epochs,warmup=5):
    """Build a LambdaLR with linear warmup followed by cosine decay to 0.

    ``epochs`` and ``warmup`` are counted in scheduler steps; the notebook steps it
    once per epoch, after that epoch's optimizer updates.

    Args:
        optimizer: optimizer whose base learning rates are scaled by the factor.
        epochs: total number of scheduler steps in the run.
        warmup: number of warmup steps; 0 disables warmup.

    Returns:
        torch.optim.lr_scheduler.LambdaLR
    """
    def lr_lambda(step):
        if step < warmup:
            # (step + 1) so the very first epoch already trains with a non-zero LR;
            # a factor of step / warmup would be 0 at step 0 and waste a whole epoch.
            return float(step + 1) / float(warmup + 1)
        progress = float(step - warmup) / float(max(1, epochs - warmup))
        # Clamp so stepping past `epochs` keeps the LR at 0 instead of rising again.
        progress = min(max(progress, 0.0), 1.0)
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda)


# Warmup length in epochs: add a "warmup_epochs" key to CONFIG to override (defaults to 5).
scheduler = get_scheduler(opt, CONFIG["epochs"], warmup=CONFIG.get("warmup_epochs", 5))
criterion = LabelSmoothingCrossEntropy(smoothing=CONFIG["label_smoothing"])


In [None]:
#=====Train/Evaluation functions =====

def evaluate(model,loader):
    """Evaluate `model` on every batch of `loader` without gradient tracking.

    Uses the module-level ``criterion`` and ``device``. The loss is averaged per
    sample (each batch weighted by its size), so a smaller final batch does not
    skew the mean.

    Args:
        model: classifier returning logits whose last dim indexes classes.
        loader: iterable of ``(inputs, integer_labels)`` batches.

    Returns:
        ``(mean_loss, accuracy)`` as Python floats; accuracy is a fraction in [0, 1].

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            logits = model(x)
            batch_size = y.size(0)
            # criterion returns a batch mean; multiply back so we can average per sample.
            loss_sum += criterion(logits, y).item() * batch_size
            correct += (logits.argmax(dim=-1) == y).sum().item()
            total += batch_size
    if total == 0:
        raise ValueError("evaluate(): loader yielded no samples")
    return loss_sum / total, correct / total


def train_one_epoch(model,loader,optimizer,epoch,cfg):
    """Run one training epoch and return the sample-weighted mean training loss.

    Uses the module-level ``criterion``, ``device`` and ``tqdm``. When
    ``cfg["use_mixup"]`` is set and ``cfg["mixup_alpha"] > 0`` each batch is mixed
    with ``mixup_data``; loss and displayed accuracy are then the lam-weighted
    combination over both label sets.

    Args:
        model: classifier returning logits whose last dim indexes classes.
        loader: iterable of ``(inputs, integer_labels)`` batches.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        epoch: epoch number, only used in the progress-bar text.
        cfg: dict with keys ``use_mixup``, ``mixup_alpha`` and ``gradient_clip``
            (``None`` disables gradient-norm clipping).

    Returns:
        float: mean loss over all samples seen this epoch.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    total = 0
    correct = 0.0
    # alpha <= 0 means "no mixing": skip MixUp entirely rather than depend on what
    # mixup_data returns for that degenerate case.
    use_mixup = cfg["use_mixup"] and cfg["mixup_alpha"] > 0
    pbar = tqdm(loader)

    for xb, yb in pbar:
        xb = xb.to(device)
        yb = yb.to(device)

        if use_mixup:
            xb, y_a, y_b, lam = mixup_data(xb, yb, alpha=cfg["mixup_alpha"])
            logits = model(xb)
            loss = lam * criterion(logits, y_a) + (1 - lam) * criterion(logits, y_b)
        else:
            logits = model(xb)
            loss = criterion(logits, yb)

        optimizer.zero_grad()
        loss.backward()

        if cfg["gradient_clip"] is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg["gradient_clip"])

        optimizer.step()

        batch_size = xb.size(0)
        running_loss += loss.item() * batch_size
        preds = logits.argmax(dim=-1)
        if use_mixup:
            # Under MixUp credit each prediction by the weight of the label it matches.
            correct += lam * (preds == y_a).sum().item() + (1 - lam) * (preds == y_b).sum().item()
        else:
            correct += (preds == yb).sum().item()
        total += batch_size

        pbar.set_description(f"Epoch {epoch} loss: {running_loss/total:.4f} acc: {correct/total:.4f}")

    if total == 0:
        raise ValueError("train_one_epoch(): loader yielded no samples")
    return running_loss / total


In [None]:
#=====Training loop=====
best_acc = 0.0
save_path = os.path.join(CONFIG["out_dir"], "best_vit.pth")
history = {"train_loss": [], "val_loss": [], "val_acc": []}
start_time = time.perf_counter()


def _save_checkpoint(path, epoch, val_acc):
    """Save model, optimizer and scheduler state plus the run CONFIG to `path`."""
    torch.save(
        {
            "model_state": model.state_dict(),
            "opt_state": opt.state_dict(),
            "sched_state": scheduler.state_dict(),  # needed to resume the LR schedule
            "epoch": epoch,
            "val_acc": val_acc,
            "config": CONFIG,
        },
        path,
    )


# NOTE(review): `test_loader` (the CIFAR-10 test split) is used both to pick the best epoch
# and to report the final number, so BEST_TEST_ACCURACY is an optimistic estimate. Hold out
# part of the training set as a validation split to get an unbiased test score.
for epoch in range(1, CONFIG["epochs"] + 1):
    train_loss = train_one_epoch(model, train_loader, opt, epoch, CONFIG)
    val_loss, val_acc = evaluate(model, test_loader)
    # The LR schedule is defined per epoch: step once, after this epoch's optimizer updates.
    scheduler.step()

    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)
    history["val_acc"].append(val_acc)

    print(f"Epoch {epoch}: train_loss {train_loss:.4f} val_loss {val_loss:.4f} val_acc {val_acc:.4f}")

    if val_acc > best_acc:
        best_acc = val_acc
        _save_checkpoint(save_path, epoch, val_acc)
        print("Saved best model:", save_path)

    if epoch % CONFIG["save_every"] == 0:
        _save_checkpoint(os.path.join(CONFIG["out_dir"], f"checkpoint_{epoch}.pth"), epoch, val_acc)

elapsed = time.perf_counter() - start_time
print("Training complete. Best val acc:", best_acc, "Elapsed (s):", int(elapsed))


In [None]:
#=====Report & save best accuracy =====
# evaluate() returns accuracy as a fraction (correct / total), so convert once to a percentage.
best_acc_pct = best_acc * 100
print(f"BEST_TEST_ACCURACY: {best_acc_pct:.2f}%")

best_acc_file = Path("/content/best_accuracy.txt")
best_acc_file.write_text(f"{best_acc_pct:.2f}%\n", encoding="utf-8")

print("Best accuracy written to /content/best_accuracy.txt")


In [None]:
#=====Plot quick curves=====
import matplotlib.pyplot as plt

# Epochs are numbered 1..N in the training loop; plot against that instead of list indices 0..N-1.
epochs_axis = range(1, len(history["val_acc"]) + 1)

fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(8, 4))

# train_loss includes the MixUp-weighted objective when use_mixup is on, so it is not
# directly comparable to val_loss.
loss_ax.plot(epochs_axis, history["train_loss"], label="train_loss")
loss_ax.plot(epochs_axis, history["val_loss"], label="val_loss")
loss_ax.set(title="Loss", xlabel="epoch", ylabel="loss")
loss_ax.legend()

acc_ax.plot(epochs_axis, history["val_acc"], label="val_acc")
acc_ax.set(title="Validation Accuracy", xlabel="epoch", ylabel="accuracy (fraction)")
acc_ax.legend()

fig.tight_layout()
plt.show()
