# Q1 — Vision Transformer on CIFAR-10 (PyTorch)

## Goal
Implement a Vision Transformer (ViT) and train it on **CIFAR-10 (10 classes)**.  
Your objective is to achieve the **highest possible test accuracy**.  

You are free to experiment with improvements and tricks to push performance further.  
**Note:** You must use **Google Colab** for implementation.  

📄 Reference Paper:  
*Dosovitskiy et al., "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", ICLR 2021*  
[Paper Link](https://arxiv.org/abs/2010.11929)

---

## Requirements
- Patchify images  
- Add learnable positional embeddings  
- Prepend a **CLS token**  
- Stack Transformer encoder blocks:  
  - Multi-Head Self Attention (MHSA)  
  - MLP with residual connections + normalization  
- Classify from the **CLS token**  

---


## Bonus (Optional Analysis)
You can earn bonus marks by including a **concise analysis**. Keep it **short and crisp**.  
Some examples of analysis directions:
- Choice of patch size  
- Depth vs. width trade-offs  
- Effect of data augmentation  
- Optimizer and learning schedule variants  
- Overlapping vs. non-overlapping patches  

This analysis should also be part of your `README.md`.  

---


In [1]:

# --- SETUP ---
!pip install timm torchsummary

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import timm
from timm.loss import LabelSmoothingCrossEntropy
from timm.utils import ModelEmaV2




In [2]:

# --- CONFIG ---
class CFG:
    model_name = "vit_base_patch16_224"  # pretrained ViT
    img_size = 224
    batch_size = 64
    epochs = 30
    lr = 5e-5
    weight_decay = 0.05
    num_classes = 10
    smoothing = 0.1
    device = "cuda" if torch.cuda.is_available() else "cpu"

cfg = CFG()

In [3]:
# --- DATASET & AUGMENTATIONS ---
train_transform = transforms.Compose([
    transforms.Resize((cfg.img_size, cfg.img_size)),
    transforms.RandomHorizontalFlip(),
    transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

In [4]:
val_transform = transforms.Compose([
    transforms.Resize((cfg.img_size, cfg.img_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


In [5]:
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
trainloader = DataLoader(trainset, batch_size=cfg.batch_size, shuffle=True, num_workers=2)


100%|██████████| 170M/170M [00:24<00:00, 7.00MB/s]


In [6]:
valset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=val_transform)
valloader = DataLoader(valset, batch_size=cfg.batch_size, shuffle=False, num_workers=2)


In [7]:
# --- MODEL ---
model = timm.create_model(cfg.model_name, pretrained=True, num_classes=cfg.num_classes)
model.to(cfg.device)


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False

In [8]:

# --- LOSS & OPTIMIZER ---
criterion = LabelSmoothingCrossEntropy(smoothing=cfg.smoothing)
optimizer = optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.epochs)


In [9]:
# EMA for stability
ema = ModelEmaV2(model, decay=0.999)


In [10]:
# --- TRAIN & EVAL ---
def train_one_epoch(model, loader, optimizer, criterion, device, ema):
    model.train()
    total_loss, correct, total = 0, 0, 0
    for i, (images, targets) in enumerate(loader):
        images, targets = images.to(device), targets.to(device)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            outputs = model(images)
            loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        ema.update(model)

        total_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if (i + 1) % 100 == 0: # Print every 100 batches
            print(f"  Batch {i+1}/{len(loader)} | Loss: {loss.item():.4f} | Acc: {100.*correct/total:.2f}%")

    return total_loss/total, 100.*correct/total

In [17]:
import time

In [18]:
# --- TRAIN & EVAL ---
def train_one_epoch(model, loader, optimizer, criterion, device, ema, log_interval=50):
    model.train()
    total_loss, correct, total = 0, 0, 0

    start_epoch_time = time.time()
    for batch_idx, (images, targets) in enumerate(loader):
        batch_start = time.time()

        # Data transfer
        data_start = time.time()
        images, targets = images.to(device, non_blocking=True), targets.to(device, non_blocking=True)
        data_time = time.time() - data_start

        optimizer.zero_grad()

        # Forward pass
        fwd_start = time.time()
        with torch.cuda.amp.autocast():
            outputs = model(images)
            loss = criterion(outputs, targets)
        fwd_time = time.time() - fwd_start

        # Backward pass
        bwd_start = time.time()
        loss.backward()
        optimizer.step()
        ema.update(model)
        bwd_time = time.time() - bwd_start

        # Metrics
        total_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        batch_time = time.time() - batch_start

        # Logging every few batches
        if (batch_idx + 1) % log_interval == 0:
            print(f"Batch {batch_idx+1}/{len(loader)} | "
                  f"Loss: {loss.item():.4f} | "
                  f"Data: {data_time:.3f}s | "
                  f"Fwd: {fwd_time:.3f}s | "
                  f"Bwd: {bwd_time:.3f}s | "
                  f"Batch: {batch_time:.3f}s | "
                  f"Acc so far: {100.*correct/total:.2f}%")

    epoch_time = time.time() - start_epoch_time
    print(f"Epoch finished in {epoch_time:.2f}s")
    return total_loss/total, 100.*correct/total


def evaluate(model, loader, criterion, device, log_interval=50):
    model.eval()
    total_loss, correct, total = 0, 0, 0
    start_eval_time = time.time()
    with torch.no_grad():
        for batch_idx, (images, targets) in enumerate(loader):
            images, targets = images.to(device, non_blocking=True), targets.to(device, non_blocking=True)
            outputs = model(images)
            loss = criterion(outputs, targets)

            total_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            if (batch_idx + 1) % log_interval == 0:
                print(f"[Eval] Batch {batch_idx+1}/{len(loader)} | "
                      f"Loss: {loss.item():.4f} | "
                      f"Acc so far: {100.*correct/total:.2f}%")

    eval_time = time.time() - start_eval_time
    print(f"Evaluation finished in {eval_time:.2f}s")
    return total_loss/total, 100.*correct/total


In [None]:

# --- TRAINING LOOP ---
for epoch in range(cfg.epochs):
    print(f"\n--- Epoch {epoch+1}/{cfg.epochs} ---")
    train_loss, train_acc = train_one_epoch(model, trainloader, optimizer, criterion, cfg.device, ema)
    val_loss, val_acc = evaluate(model, valloader, criterion, cfg.device)
    scheduler.step()
    print(f"Epoch {epoch+1}/{cfg.epochs} | "
          f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | "
          f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")


--- Epoch 1/30 ---


  with torch.cuda.amp.autocast():


Batch 50/782 | Loss: 0.5728 | Data: 0.012s | Fwd: 0.016s | Bwd: 0.202s | Batch: 0.485s | Acc so far: 97.34%
Batch 100/782 | Loss: 0.5183 | Data: 0.012s | Fwd: 0.015s | Bwd: 0.197s | Batch: 0.469s | Acc so far: 97.05%
Batch 150/782 | Loss: 0.5460 | Data: 0.011s | Fwd: 0.015s | Bwd: 0.200s | Batch: 0.471s | Acc so far: 96.74%
Batch 200/782 | Loss: 0.5825 | Data: 0.012s | Fwd: 0.015s | Bwd: 0.210s | Batch: 0.483s | Acc so far: 96.73%
Batch 250/782 | Loss: 0.5986 | Data: 0.013s | Fwd: 0.018s | Bwd: 0.198s | Batch: 0.478s | Acc so far: 96.62%
Batch 300/782 | Loss: 0.6176 | Data: 0.012s | Fwd: 0.016s | Bwd: 0.199s | Batch: 0.472s | Acc so far: 96.60%
Batch 350/782 | Loss: 0.5934 | Data: 0.012s | Fwd: 0.016s | Bwd: 0.199s | Batch: 0.473s | Acc so far: 96.49%
Batch 400/782 | Loss: 0.6353 | Data: 0.020s | Fwd: 0.014s | Bwd: 0.203s | Batch: 0.482s | Acc so far: 96.48%
Batch 450/782 | Loss: 0.5934 | Data: 0.023s | Fwd: 0.026s | Bwd: 0.191s | Batch: 0.485s | Acc so far: 96.43%
Batch 500/782 | Loss

In [None]:

# Save best model
torch.save(model.state_dict(), "vit_cifar10.pth")

In [14]:
                                                                                                                                         # ### Colab Notebook: Vision Transformer (ViT) Fine-Tuning on CIFAR-10
# # Goal: Highest accuracy possible with minimal compute (using pretrained ViT)

# # --- SETUP ---
# !pip install timm torchsummary

# import torch
# import torch.nn as nn
# import torch.optim as optim
# from torch.utils.data import DataLoader
# import torchvision
# import torchvision.transforms as transforms
# import timm
# from timm.loss import LabelSmoothingCrossEntropy
# from timm.utils import ModelEmaV2

# # --- CONFIG ---
# class CFG:
#     model_name = "vit_base_patch16_224"  # pretrained ViT
#     img_size = 224
#     batch_size = 64
#     epochs = 30
#     lr = 5e-5
#     weight_decay = 0.05
#     num_classes = 10
#     smoothing = 0.1
#     device = "cuda" if torch.cuda.is_available() else "cpu"

# cfg = CFG()

# # --- DATASET & AUGMENTATIONS ---
# train_transform = transforms.Compose([
#     transforms.Resize((cfg.img_size, cfg.img_size)),
#     transforms.RandomHorizontalFlip(),
#     transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
#     transforms.ToTensor(),
#     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
# ])

# val_transform = transforms.Compose([
#     transforms.Resize((cfg.img_size, cfg.img_size)),
#     transforms.ToTensor(),
#     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
# ])

# trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
# trainloader = DataLoader(trainset, batch_size=cfg.batch_size, shuffle=True, num_workers=2)

# valset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=val_transform)
# valloader = DataLoader(valset, batch_size=cfg.batch_size, shuffle=False, num_workers=2)

# # --- MODEL ---
# model = timm.create_model(cfg.model_name, pretrained=True, num_classes=cfg.num_classes)
# model.to(cfg.device)

# # --- LOSS & OPTIMIZER ---
# criterion = LabelSmoothingCrossEntropy(smoothing=cfg.smoothing)
# optimizer = optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.epochs)

# # EMA for stability
# ema = ModelEmaV2(model, decay=0.999)

# # --- TRAIN & EVAL ---
# def train_one_epoch(model, loader, optimizer, criterion, device, ema):
#     model.train()
#     total_loss, correct, total = 0, 0, 0
#     for images, targets in loader:
#         images, targets = images.to(device), targets.to(device)
#         optimizer.zero_grad()
#         with torch.cuda.amp.autocast():
#             outputs = model(images)
#             loss = criterion(outputs, targets)
#         loss.backward()
#         optimizer.step()
#         ema.update(model)

#         total_loss += loss.item() * images.size(0)
#         _, predicted = outputs.max(1)
#         total += targets.size(0)
#         correct += predicted.eq(targets).sum().item()
#     return total_loss/total, 100.*correct/total


# def evaluate(model, loader, criterion, device):
#     model.eval()
#     total_loss, correct, total = 0, 0, 0
#     with torch.no_grad():
#         for images, targets in loader:
#             images, targets = images.to(device), targets.to(device)
#             outputs = model(images)
#             loss = criterion(outputs, targets)

#             total_loss += loss.item() * images.size(0)
#             _, predicted = outputs.max(1)
#             total += targets.size(0)
#             correct += predicted.eq(targets).sum().item()
#     return total_loss/total, 100.*correct/total

# # --- TRAINING LOOP ---
# for epoch in range(cfg.epochs):
#     train_loss, train_acc = train_one_epoch(model, trainloader, optimizer, criterion, cfg.device, ema)
#     val_loss, val_acc = evaluate(model, valloader, criterion, cfg.device)
#     scheduler.step()
#     print(f"Epoch {epoch+1}/{cfg.epochs} | Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}%")

# # Save best model
# torch.save(model.state_dict(), "vit_cifar10.pth")




  with torch.cuda.amp.autocast():


Epoch 1/30 | Train Acc: 93.84% | Val Acc: 96.98%


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


KeyboardInterrupt: 