In [2]:

!pip install timm -q

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import timm
import numpy as np
import random
import time
import os

# Run on the GPU when PyTorch reports one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

# For reproducibility
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value handed to every generator (default 42).

    Returns:
        None. The CUDA generators are seeded only when
        ``torch.cuda.is_available()`` is true.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

# Fix all RNG seeds before any model or data loader is created.
set_seed(seed=42)

# Directory that receives the checkpoint and the training history;
# exist_ok=True makes re-running this cell harmless.
os.makedirs(name="outputs", exist_ok=True)


Using device: cuda


In [3]:
# ==========================
# CIFAR-10 DATALOADERS
# ==========================

# CIFAR-10 images are 32x32, but ViT expects 224x224
# so we resize them.
from torchvision import transforms

# Per-channel normalisation constants, shared by the train and test pipelines.
cifar10_mean = [0.4914, 0.4822, 0.4465]
cifar10_std = [0.2470, 0.2435, 0.2616]
normalize = transforms.Normalize(mean=cifar10_mean, std=cifar10_std)

# Training pipeline: resize, random horizontal flip, tensor conversion, normalise.
train_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

# Test pipeline: same steps without the random flip.
test_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    normalize,
])

# Both splits live under ./data and are downloaded when missing.
dataset_kwargs = dict(root="data", download=True)
train_dataset = datasets.CIFAR10(train=True, transform=train_transform, **dataset_kwargs)
test_dataset = datasets.CIFAR10(train=False, transform=test_transform, **dataset_kwargs)

batch_size = 64

# Only the training loader shuffles; batch size and worker count are shared.
loader_kwargs = dict(batch_size=batch_size, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

print("Train samples:", len(train_dataset))
print("Test samples :", len(test_dataset))


100%|██████████| 170M/170M [01:06<00:00, 2.58MB/s]


Train samples: 50000
Test samples : 10000


In [4]:
# ==========================
# Define ViT model
# ==========================

class ViT_CIFAR10(nn.Module):
    """Thin ``nn.Module`` wrapper around timm's ``vit_base_patch16_224``.

    Args:
        num_classes: Forwarded to ``timm.create_model`` (default 10).

    The backbone is created with ``pretrained=False``, so no pretrained
    weights are requested.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Author's note: this roughly matches the ViT-16 model family used in
        # the paper, which trains from scratch on small data.
        backbone_kwargs = dict(pretrained=False, num_classes=num_classes)
        self.backbone = timm.create_model("vit_base_patch16_224", **backbone_kwargs)

    def forward(self, x):
        """Return the backbone's output for the batch ``x``."""
        logits = self.backbone(x)
        return logits

# Build the network with its default num_classes and move it to the selected device.
model = ViT_CIFAR10().to(device)

# Total element count over all parameters, reported in millions.
n_params = sum(p.numel() for p in model.parameters())
print("Number of parameters:", n_params / 1e6, "M")


Number of parameters: 85.806346 M


In [5]:
# ==========================
# Training & Evaluation Utilities
# ==========================

# Cross-entropy loss between the model outputs and the integer class labels.
criterion = nn.CrossEntropyLoss()

# AdamW hyper-parameters, named so they are easy to spot and tune.
learning_rate = 1e-4
weight_decay = 0.05
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader`` and report averaged metrics.

    Args:
        model: Module to train; it is switched to training mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer whose gradients are zeroed and stepped per batch.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar loss.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(epoch_loss, epoch_acc)``: the sample-weighted mean loss and the
        fraction of correct predictions, both measured on the outputs computed
        before each optimizer step.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.train()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for batch in loader:
        images = batch[0].to(device)
        labels = batch[1].to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # The criterion result is weighted by the batch size so that a smaller
        # final batch does not skew the epoch average.
        loss_sum += loss.item() * images.size(0)
        preds = outputs.max(1)[1]
        n_correct += (preds == labels).sum().item()
        n_seen += labels.size(0)

    return loss_sum / n_seen, n_correct / n_seen

def evaluate(model, loader, criterion, device):
    """Measure loss and accuracy of ``model`` over ``loader`` without gradients.

    Args:
        model: Module to evaluate; it is switched to eval mode and left there.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar loss.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(epoch_loss, epoch_acc)``: the sample-weighted mean loss and the
        fraction of correct predictions.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    with torch.no_grad():
        model.eval()

        loss_sum = 0.0
        n_correct = 0
        n_seen = 0

        for batch in loader:
            images = batch[0].to(device)
            labels = batch[1].to(device)

            outputs = model(images)
            batch_loss = criterion(outputs, labels).item()

            # Weight the criterion result by the batch size before averaging.
            loss_sum += batch_loss * images.size(0)
            preds = outputs.max(1)[1]
            n_correct += (preds == labels).sum().item()
            n_seen += labels.size(0)

        return loss_sum / n_seen, n_correct / n_seen


In [6]:
# ==========================
# FAST Train Loop
# ==========================

num_epochs = 3

best_acc = 0.0   # highest evaluation accuracy seen so far
history = []     # one metrics record per epoch

start = time.time()

for epoch in range(num_epochs):
    # One training pass, then one gradient-free pass over test_loader.
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc = evaluate(model, test_loader, criterion, device)

    history.append(dict(
        epoch=epoch + 1,
        train_loss=train_loss,
        train_acc=train_acc,
        val_loss=val_loss,
        val_acc=val_acc,
    ))

    # NOTE(review): the "val" metrics come from test_loader, so the checkpoint is
    # selected on the test split and best_acc is an optimistic estimate.
    improved = val_acc > best_acc
    if improved:
        best_acc = val_acc
        torch.save(model.state_dict(), "outputs/best_vit_cifar10.pth")

    print(f"[{epoch+1}/{num_epochs}] Train Acc: {train_acc*100:.2f}% | Test Acc: {val_acc*100:.2f}%")

elapsed_minutes = (time.time() - start) / 60
print("Done in", elapsed_minutes, "minutes")
print("Best Test Accuracy:", best_acc*100, "%")


[1/3] Train Acc: 36.33% | Test Acc: 42.01%
[2/3] Train Acc: 48.75% | Test Acc: 50.60%
[3/3] Train Acc: 53.88% | Test Acc: 55.65%
Done in 83.48183061679204 minutes
Best Test Accuracy: 55.65 %


In [None]:
# ==========================
# Save training history
# ==========================
import json

# Serialise the per-epoch records first, then write them out in one call.
history_json = json.dumps(history, indent=2)
with open("outputs/history_vit_cifar10.json", "w") as f:
    f.write(history_json)

print("Saved model and history to 'outputs/' folder.")
