# Training a Vision Transformer (ViT) from Scratch on CIFAR-10

This project is an implementation of a Vision Transformer (ViT), based on the paper **"An Image is Worth 16x16 Words" (Dosovitskiy et al., 2021)**.

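The original ViT was designed for large-scale pre-training on datasets like JFT-300M; the paper notes that, lacking the inductive biases of CNNs, it generalizes poorly when trained on small datasets alone. This notebook shows that a modernized ViT can be **trained from scratch** on a small dataset (CIFAR-10, 50,000 training images) and still reach **>90%** test accuracy.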

This is achieved by adding several techniques from later ViT and training research to the architecture and training recipe:
* **Model Stability:** LayerScale and Pre-Normalization
* **Regularization:** Stochastic Depth (`DropPath`), `AdamW` Optimizer, and `LabelSmoothingLoss`
* **Data Augmentation:** `MixUp` and `CutMix`
* **LR Schedule:** `CosineAnnealingLR` with `Warmup`
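The first two bullets usually meet inside a single transformer block. The sketch below shows one common way to combine them, assuming a standard pre-norm layout; `drop_path` and `PreNormBlock` are illustrative names, and the actual `src/model.py` is not shown here and may differ (for example in how attention is implemented).

```python
import torch
import torch.nn as nn


def drop_path(x: torch.Tensor, drop_prob: float, training: bool) -> torch.Tensor:
    """Stochastic depth: drop the whole residual branch for random samples."""
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1.0 - drop_prob
    # One Bernoulli draw per sample, broadcast over tokens and channels
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    # Rescale so the expected output matches inference, where nothing is dropped
    return x * mask / keep_prob


class PreNormBlock(nn.Module):
    """Hypothetical pre-norm ViT block with LayerScale and DropPath."""

    def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4.0,
                 drop_path_prob: float = 0.0, layer_scale_init: float = 1e-4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))
        # LayerScale: learnable per-channel gains initialised near zero, so each
        # block starts close to the identity function
        self.gamma1 = nn.Parameter(layer_scale_init * torch.ones(dim))
        self.gamma2 = nn.Parameter(layer_scale_init * torch.ones(dim))
        self.drop_path_prob = drop_path_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-norm: normalise the input of each sub-layer, leave the residual path untouched
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + drop_path(self.gamma1 * attn_out, self.drop_path_prob, self.training)
        x = x + drop_path(self.gamma2 * self.mlp(self.norm2(x)),
                          self.drop_path_prob, self.training)
        return x
```

A frequently used arrangement (assumed here, not confirmed for `src/model.py`) is to grow the drop probability linearly with depth, e.g. `torch.linspace(0, 0.15, 12)` for `drop_path_rate=0.15` and `depth=12`, so the earliest blocks are almost never skipped.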

This notebook analyzes the results of this improved training process. The model architecture and helper functions live in the `src/` directory (`model.py`, `utils.py`, `engine.py`).

## 1. Environment Setup

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

import random
import os
from tqdm import tqdm
import time
import math
from typing import Optional, Tuple

# --- Import our refactored code ---
# We add '..' to the path to go up one level from /notebooks to the root
import sys
sys.path.append('..') 

from src.model import VisionTransformer
from src.utils import (
    set_seed, get_dataloaders, CutMix, MixUp, 
    LabelSmoothingLoss, WarmupCosineScheduler, mixup_criterion
)
from src.engine import train_epoch, evaluate

# --- Setup Device and Seed ---
set_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')

## 2. Configuration & Model Initialization

In [2]:
# Best configuration found through experimentation
config = {
    'image_size': 32,
    'patch_size': 4,  # 4x4 patches -> (32 // 4)**2 = 64 patch tokens per image
    'embed_dim': 384,  # ViT-Small width (same as DeiT-S)
    'depth': 12,  # 12 transformer blocks
    'num_heads': 6,  # 6 attention heads
    'mlp_ratio': 4.0,  # MLP hidden dim = 4 * embed_dim
    'dropout': 0.1,
    'attn_dropout': 0.0,
    'drop_path_rate': 0.15,  # Stochastic depth
    'layer_scale_init': 1e-4,
    'use_conv_patch': True,

    # Training hyperparameters
    'epochs': 200,
    'batch_size': 128,
    'base_lr': 1e-3,
    'weight_decay': 0.05,
    'warmup_epochs': 20,
    'min_lr': 1e-5,

    # Augmentation
    'mixup_alpha': 0.8,
    'cutmix_alpha': 1.0,
    'mixup_prob': 0.5,
    'cutmix_prob': 0.5,
    'label_smoothing': 0.1,
}

print("Configuration:")
for key, value in config.items():
    print(f"{key}: {value}")

Configuration:
image_size: 32
patch_size: 4
embed_dim: 384
depth: 12
num_heads: 6
mlp_ratio: 4.0
dropout: 0.1
attn_dropout: 0.0
drop_path_rate: 0.15
layer_scale_init: 0.0001
use_conv_patch: True
epochs: 200
batch_size: 128
base_lr: 0.001
weight_decay: 0.05
warmup_epochs: 20
min_lr: 1e-05
mixup_alpha: 0.8
cutmix_alpha: 1.0
mixup_prob: 0.5
cutmix_prob: 0.5
label_smoothing: 0.1
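The patch size controls the sequence length. The paper's "16x16 words" are 16x16-pixel patches on 224x224 ImageNet images (14 x 14 = 196 tokens); on 32x32 CIFAR-10 images, 16x16 patches would leave only 2 x 2 = 4 tokens, while 4x4 patches give 8 x 8 = 64. The sketch below assumes that `use_conv_patch=True` means a strided convolution with `kernel_size == stride == patch_size` plus a prepended class token; `ConvPatchEmbed` is a hypothetical name, and `src/model.py` may use a different (e.g. deeper) convolutional stem.

```python
import torch
import torch.nn as nn


class ConvPatchEmbed(nn.Module):
    """Hypothetical patch embedding: a strided conv acts as 'split into patches + linear projection'."""

    def __init__(self, image_size=32, patch_size=4, in_chans=3, embed_dim=384):
        super().__init__()
        assert image_size % patch_size == 0, "image must tile evenly into patches"
        self.num_patches = (image_size // patch_size) ** 2  # (32 // 4)**2 = 64
        # kernel_size == stride == patch_size: each output position sees exactly one patch
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)                  # (B, 384, 8, 8)
        x = x.flatten(2).transpose(1, 2)  # (B, 64, 384): one token per patch
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls, x], dim=1)    # (B, 65, 384): prepend the class token
        return x + self.pos_embed         # learned position embeddings
```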


In [3]:
# Initialize model
model = VisionTransformer(
    image_size=config['image_size'],
    patch_size=config['patch_size'],
    embed_dim=config['embed_dim'],
    depth=config['depth'],
    num_heads=config['num_heads'],
    mlp_ratio=config['mlp_ratio'],
    dropout=config['dropout'],
    attn_dropout=config['attn_dropout'],
    drop_path_rate=config['drop_path_rate'],
    layer_scale_init=config['layer_scale_init'],
    use_conv_patch=config['use_conv_patch']
).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

Total parameters: 21,423,562
Trainable parameters: 21,423,562
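As a sanity check on the reported total, the parameter count of a standard pre-norm ViT with LayerScale can be estimated by hand. This is a sketch under that assumption (fused QKV with bias, two-layer MLP, class token, learned position embeddings, single-conv patch embedding); `estimate_vit_params` is a hypothetical helper, not part of `src/`.

```python
def estimate_vit_params(image_size=32, patch_size=4, in_chans=3, embed_dim=384,
                        depth=12, mlp_ratio=4.0, num_classes=10):
    """Rough parameter count for an assumed standard pre-norm ViT with LayerScale."""
    d = embed_dim
    hidden = int(d * mlp_ratio)
    num_patches = (image_size // patch_size) ** 2

    per_block = (
        (d * 3 * d + 3 * d)      # fused QKV projection (weights + bias)
        + (d * d + d)            # attention output projection
        + (d * hidden + hidden)  # MLP fc1
        + (hidden * d + d)       # MLP fc2
        + 2 * (2 * d)            # two LayerNorms (weight + bias)
        + 2 * d                  # two LayerScale gain vectors
    )
    patch_embed = in_chans * patch_size * patch_size * d + d  # conv kernel + bias
    tokens = d + (num_patches + 1) * d                        # class token + position embeddings
    head = 2 * d + (d * num_classes + num_classes)            # final LayerNorm + classifier
    blocks = depth * per_block
    return blocks, blocks + patch_embed + tokens + head


blocks, total = estimate_vit_params()
print(f"Transformer blocks: {blocks:,}  Estimated total: {total:,}")
```

The estimate comes to 21,351,562, within about 0.3% of the reported 21,423,562, and the 12 transformer blocks alone account for over 99% of the parameters. The remaining 72,000-parameter gap presumably comes from details of `src/model.py` that this sketch does not model (the convolutional patch stem is one candidate).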


## 3. Training Loop


The model was trained for the 200 epochs set in the configuration, which took roughly 2-3 hours.

**Training is not re-run in this notebook.** The next two code cells (setup and main loop) are kept for reference only; they were not executed in this session, which is why they show `In [None]` and no output. The next section instead loads the saved `best_vit_cifar10.pth` checkpoint for analysis.
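The `WarmupCosineScheduler` in `src/utils` is not shown here. Judging by `scheduler.step(epoch)` returning the current learning rate, it is stepped once per epoch; the sketch below assumes a linear warmup over `warmup_epochs` followed by a cosine decay from `base_lr` to `min_lr`. `warmup_cosine_lr` is a hypothetical stand-in, and the exact warmup formula in `src/utils` may differ.

```python
import math


def warmup_cosine_lr(epoch, warmup_epochs=20, total_epochs=200, base_lr=1e-3, min_lr=1e-5):
    """Per-epoch LR: linear warmup, then cosine decay to min_lr (hypothetical helper)."""
    if epoch < warmup_epochs:
        # Linear ramp; epoch 0 starts at base_lr / warmup_epochs rather than 0
        return base_lr * (epoch + 1) / warmup_epochs
    # Fraction of the post-warmup phase completed, in [0, 1)
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


for e in (0, 19, 20, 110, 199):
    print(f"epoch {e:3d}: lr = {warmup_cosine_lr(e):.2e}")
```

With the configured values, the LR climbs to 1e-3 by the end of warmup, is halfway between `base_lr` and `min_lr` at epoch 110, and is close to 1e-5 at the final epoch. The warmup keeps the large early updates from destabilizing attention layers that start from random initialization.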

In [None]:
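# Setup training
# Data loaders: get_dataloaders returns (train loader, test loader, class names)
trainloader, testloader, classes = get_dataloaders(batch_size=config['batch_size'])

criterion = LabelSmoothingLoss(num_classes=10, smoothing=config['label_smoothing'])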
optimizer = optim.AdamW(model.parameters(), 
                        lr=config['base_lr'],
                        weight_decay=config['weight_decay'],
                        betas=(0.9, 0.999))

# Learning rate scheduler
scheduler = WarmupCosineScheduler(
    optimizer,
    warmup_epochs=config['warmup_epochs'],
    total_epochs=config['epochs'],
    base_lr=config['base_lr'],
    min_lr=config['min_lr']
)

# Data augmentation
mixup_fn = MixUp(alpha=config['mixup_alpha'], prob=config['mixup_prob'])
cutmix_fn = CutMix(alpha=config['cutmix_alpha'], prob=config['cutmix_prob'])

# Training history
train_losses = []
train_accs = []
test_losses = []
test_accs = []
best_acc = 0

print(f"\nStarting training for {config['epochs']} epochs...")
print("="*60)
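The `criterion` above is `LabelSmoothingLoss` from `src/utils`, whose code is not shown. A minimal sketch of the idea, assuming the common convention that the smoothing mass is spread uniformly over all classes: the target becomes `(1 - eps) * one_hot + eps / K`, which discourages over-confident logits. `label_smoothing_loss` is a hypothetical name; under this convention it matches PyTorch's built-in `nn.CrossEntropyLoss(label_smoothing=...)` (PyTorch 1.10+), while some implementations instead spread `eps` over the `K - 1` wrong classes only.

```python
import torch
import torch.nn.functional as F


def label_smoothing_loss(logits: torch.Tensor, target: torch.Tensor,
                         smoothing: float = 0.1) -> torch.Tensor:
    """Cross-entropy against a smoothed one-hot target (sketch, not the src/utils code)."""
    num_classes = logits.size(-1)
    log_probs = F.log_softmax(logits, dim=-1)
    # Every class gets eps / K; the true class additionally gets 1 - eps
    soft = torch.full_like(log_probs, smoothing / num_classes)
    soft.scatter_(1, target.unsqueeze(1), 1.0 - smoothing + smoothing / num_classes)
    return -(soft * log_probs).sum(dim=-1).mean()
```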


In [None]:
# Main training loop
for epoch in range(config['epochs']):
    # Update learning rate
    current_lr = scheduler.step(epoch)
    
    # Train
    train_loss, train_acc = train_epoch(
        model, trainloader, criterion, optimizer, device, mixup_fn, cutmix_fn
    )
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    
    # Evaluate
    test_loss, test_acc = evaluate(model, testloader, criterion, device)
    test_losses.append(test_loss)
    test_accs.append(test_acc)
    
    # Save best model to the same path the analysis section loads from
    if test_acc > best_acc:
        best_acc = test_acc
        os.makedirs('../checkpoints', exist_ok=True)
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'test_acc': test_acc,
            'config': config,
        }, '../checkpoints/best_vit_cifar10.pth')
    
    # Print progress
    if (epoch + 1) % 10 == 0 or epoch == 0:
        print(f"Epoch [{epoch+1}/{config['epochs']}]")
        print(f"  LR: {current_lr:.6f}")
        print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"  Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")
        print(f"  Best Test Acc: {best_acc:.2f}%")
        print("-"*40)

print("\n" + "="*60)
print(f"Training completed! Best Test Accuracy: {best_acc:.2f}%")
print("="*60)
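`train_epoch` receives `mixup_fn` and `cutmix_fn`, and the loss on a mixed batch is presumably computed with `mixup_criterion`; neither is shown here. The sketch below illustrates the two techniques under stated assumptions: MixUp blends two whole images with weight `lam ~ Beta(alpha, alpha)`, CutMix pastes a rectangle from a shuffled partner and resets `lam` to the area actually kept, and the loss is the `lam`-weighted sum of the losses against both label sets. `mixup`, `cutmix`, and `mix_loss` are hypothetical names; how `train_epoch` chooses between the two using `mixup_prob` and `cutmix_prob` is not shown and not modeled here.

```python
import math
import random

import torch


def mixup(x, y, alpha=0.8):
    """Blend each image with a shuffled partner; returns both label sets and the weight."""
    lam = random.betavariate(alpha, alpha)
    perm = torch.randperm(x.size(0))
    return lam * x + (1.0 - lam) * x[perm], y, y[perm], lam


def cutmix(x, y, alpha=1.0):
    """Paste a random box from a shuffled partner; lam = fraction of pixels kept."""
    lam = random.betavariate(alpha, alpha)
    perm = torch.randperm(x.size(0))
    _, _, H, W = x.shape
    cut_h, cut_w = int(H * math.sqrt(1.0 - lam)), int(W * math.sqrt(1.0 - lam))
    cy, cx = random.randrange(H), random.randrange(W)
    y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, H)
    x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, W)
    x = x.clone()  # do not modify the caller's batch
    x[:, :, y1:y2, x1:x2] = x[perm, :, y1:y2, x1:x2]
    # Recompute lam from the real box area, since the box may be clipped at the border
    lam = 1.0 - (y2 - y1) * (x2 - x1) / (H * W)
    return x, y, y[perm], lam


def mix_loss(criterion, pred, y_a, y_b, lam):
    """Loss for a mixed batch: weighted sum of the losses against both label sets."""
    return lam * criterion(pred, y_a) + (1.0 - lam) * criterion(pred, y_b)
```

Because the targets are mixed, training loss and training accuracy under MixUp/CutMix are not directly comparable with the test metrics computed on clean images.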


## 4. Analysis: Training Curves & Final Results

In [None]:
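# Plot training curves
# Needs the history lists (train_losses, test_losses, train_accs, test_accs) filled in
# by the training loop above; they are empty or undefined if training was skipped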
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Loss plot
ax1.plot(train_losses, label='Train Loss', linewidth=2)
ax1.plot(test_losses, label='Test Loss', linewidth=2)
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Loss', fontsize=12)
ax1.set_title('Training and Test Loss', fontsize=14, fontweight='bold')
ax1.legend(fontsize=11)
ax1.grid(True, alpha=0.3)

# Accuracy plot
ax2.plot(train_accs, label='Train Accuracy', linewidth=2)
ax2.plot(test_accs, label='Test Accuracy', linewidth=2)
ax2.axhline(y=best_acc, color='r', linestyle='--', label=f'Best: {best_acc:.2f}%', linewidth=2)
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('Accuracy (%)', fontsize=12)
ax2.set_title('Training and Test Accuracy', fontsize=14, fontweight='bold')
ax2.legend(fontsize=11)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nFinal Results:")
print(f"Best Test Accuracy: {best_acc:.2f}%")
print(f"Final Test Accuracy: {test_accs[-1]:.2f}%")

### Per-Class Accuracy

The cell below loads the `best_vit_cifar10.pth` checkpoint and computes the overall accuracy and the accuracy for each class on the CIFAR-10 test set (10,000 images, 1,000 per class).

In [7]:
# --- Final evaluation of the saved checkpoint ---
# The training cells were not run in this session, so rebuild the test loader and class names
print("Loading test data...")
_, testloader, classes = get_dataloaders(batch_size=config['batch_size'])

# Load the best model from the checkpoints folder
# (paths are relative to /notebooks, hence '..')
checkpoint_path = '../checkpoints/best_vit_cifar10.pth'
print(f"Loading checkpoint from {checkpoint_path}...")

# map_location lets a GPU-trained checkpoint load on a CPU-only machine
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

# Recompute overall and per-class accuracy on the test set
print("Evaluating model on test set...")
model.eval()
num_classes = len(classes)
class_correct = torch.zeros(num_classes, dtype=torch.long)
class_total = torch.zeros(num_classes, dtype=torch.long)

with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        predicted = model(images).argmax(dim=1)
        hits = predicted == labels
        # Tally totals and hits per class without a Python loop over the batch
        class_total += torch.bincount(labels, minlength=num_classes).cpu()
        class_correct += torch.bincount(labels[hits], minlength=num_classes).cpu()

overall_acc = 100 * class_correct.sum().item() / class_total.sum().item()

# 'best_acc' (used by the summary cell) is the test accuracy stored at save time;
# fall back to the freshly computed value if the checkpoint lacks it
best_acc = checkpoint.get('test_acc', overall_acc)

print(f'\nOverall Test Accuracy: {overall_acc:.2f}%')
print('\nPer-class Accuracy:')
print('-' * 30)
for i in range(num_classes):
    acc = 100 * class_correct[i].item() / class_total[i].item()
    print(f'{classes[i]:10s}: {acc:.2f}%')

Loading test data...
Loading checkpoint from ../checkpoints/best_vit_cifar10.pth...
Evaluating model on test set...

Overall Test Accuracy: 90.78%

Per-class Accuracy:
------------------------------
plane     : 92.50%
car       : 95.80%
bird      : 89.90%
cat       : 80.80%
deer      : 88.80%
dog       : 86.00%
frog      : 92.80%
horse     : 93.00%
ship      : 95.10%
truck     : 93.10%
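Because the CIFAR-10 test split is balanced (1,000 images per class), each per-class percentage maps to a whole number of correct images (0.1 percentage points per image), and the overall accuracy is simply their mean. The check below uses only the numbers printed above:

```python
# Per-class accuracies copied from the output above (%)
per_class = {
    'plane': 92.50, 'car': 95.80, 'bird': 89.90, 'cat': 80.80, 'deer': 88.80,
    'dog': 86.00, 'frog': 92.80, 'horse': 93.00, 'ship': 95.10, 'truck': 93.10,
}

# acc% of 1,000 images -> number of correct predictions for that class
correct_per_class = {name: round(acc * 10) for name, acc in per_class.items()}
total_correct = sum(correct_per_class.values())
overall = 100 * total_correct / 10_000

weakest = min(per_class, key=per_class.get)
print(f"Correct: {total_correct} / 10000 -> {overall:.2f}%")
print(f"Weakest class: {weakest} ({per_class[weakest]:.2f}%)")
```

The 9,078 correct predictions reproduce the reported 90.78% exactly. The two animal classes `cat` (80.8%) and `dog` (86.0%) are the weakest, about 10 and 5 points below the average.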


## 5. Final Summary

In [8]:
print("="*50)
print("VISION TRANSFORMER - CIFAR-10 RESULTS")
print("="*50)

print(f"\nModel Configuration:")
print(f"  - Patch Size: {config['patch_size']}x{config['patch_size']}")
print(f"  - Embedding Dim: {config['embed_dim']}")
print(f"  - Depth: {config['depth']} blocks")
print(f"  - Heads: {config['num_heads']}")
print(f"  - Parameters: {total_params:,}")
print(f"\nTraining Configuration:")
print(f"  - Epochs: {config['epochs']}")
print(f"  - Batch Size: {config['batch_size']}")
print(f"  - Learning Rate: {config['base_lr']} (with warmup)")
print(f"  - Augmentations: CutMix + MixUp + RandAugment")
print(f"\nResults:")
print(f"  - Best Test Accuracy: {best_acc:.2f}%")
print(f"  - Overall Test Accuracy: {overall_acc:.2f}%")
print("="*50)


VISION TRANSFORMER - CIFAR-10 RESULTS

Model Configuration:
  - Patch Size: 4x4
  - Embedding Dim: 384
  - Depth: 12 blocks
  - Heads: 6
  - Parameters: 21,423,562

Training Configuration:
  - Epochs: 200
  - Batch Size: 128
  - Learning Rate: 0.001 (with warmup)
  - Augmentations: CutMix + MixUp + RandAugment

Results:
  - Best Test Accuracy: 90.78%
  - Overall Test Accuracy: 90.78%
