In [1]:
# Libraries + Config
import time
import math
import random
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as T

from tqdm.auto import tqdm

@dataclass
class Config:
    # Data
    img_size: int = 32
    patch_size: int = 4
    in_chans: int = 3
    num_classes: int = 10

    # ViT
    d_model: int = 64
    depth: int = 6
    num_heads: int = 4
    mlp_dim: int = 256       # = 4 * d_model
    dropout: float = 0.1

    # Train
    batch_size: int = 128
    epochs: int = 10
    lr: float = 3e-4
    weight_decay: float = 0.05
    seed: int = 0
    num_workers: int = 2

cfg = Config()

def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # deterministic
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

set_seed(cfg.seed)

print(cfg)


Config(img_size=32, patch_size=4, in_chans=3, num_classes=10, d_model=64, depth=6, num_heads=4, mlp_dim=256, dropout=0.1, batch_size=128, epochs=10, lr=0.0003, weight_decay=0.05, seed=0, num_workers=2)


In [2]:
# Data + Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
if device.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))

# CIFAR-10 is already 32x32 RGB
train_tfms = T.Compose([
    T.ToTensor(),
    T.Normalize((0.4914, 0.4822, 0.4465),
                (0.2470, 0.2435, 0.2616)),
])

test_tfms = T.Compose([
    T.ToTensor(),
    T.Normalize((0.4914, 0.4822, 0.4465),
                (0.2470, 0.2435, 0.2616)),
])

train_ds = torchvision.datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=train_tfms
)

test_ds = torchvision.datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=test_tfms
)

train_loader = DataLoader(
    train_ds,
    batch_size=cfg.batch_size,
    shuffle=True,
    num_workers=cfg.num_workers,
    pin_memory=(device.type == "cuda"),
)

test_loader = DataLoader(
    test_ds,
    batch_size=cfg.batch_size,
    shuffle=False,
    num_workers=cfg.num_workers,
    pin_memory=(device.type == "cuda"),
)

# quick sanity check
x, y = next(iter(train_loader))
print("Batch x:", x.shape, "Batch y:", y.shape)
print("x dtype:", x.dtype, "y[:10]:", y[:10].tolist())

# tokens check: 32x32 with patch 4 => 8x8=64 patches (+cls)
num_patches = (cfg.img_size // cfg.patch_size) * (cfg.img_size // cfg.patch_size)
print("Num patches:", num_patches, "Tokens with CLS:", num_patches + 1)

Device: cuda
GPU: Tesla T4


100%|██████████| 170M/170M [00:04<00:00, 35.0MB/s]


Batch x: torch.Size([128, 3, 32, 32]) Batch y: torch.Size([128])
x dtype: torch.float32 y[:10]: [4, 3, 9, 0, 1, 7, 2, 0, 7, 7]
Num patches: 64 Tokens with CLS: 65


In [3]:
# ViT Model
class PatchEmbed(nn.Module):
    """
    Convert image to patch embeddings using Conv2d:
    (B, C, H, W) -> (B, N, D)
    """
    def __init__(self, img_size: int, patch_size: int, in_chans: int, embed_dim: int):
        super().__init__()
        assert img_size % patch_size == 0, "img_size must be divisible by patch_size"
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) * (img_size // patch_size)

        self.proj = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        # x: (B, C, H, W)
        x = self.proj(x)                 # (B, D, H/P, W/P)
        x = x.flatten(2).transpose(1, 2) # (B, N, D)
        return x


class MLP(nn.Module):
    def __init__(self, d_model: int, mlp_dim: int, dropout: float):
        super().__init__()
        self.fc1 = nn.Linear(d_model, mlp_dim)
        self.fc2 = nn.Linear(mlp_dim, d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class EncoderBlock(nn.Module):
    def __init__(self, d_model: int, num_heads: int, mlp_dim: int, dropout: float):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.drop_path = nn.Dropout(dropout)

        self.norm2 = nn.LayerNorm(d_model)
        self.mlp = MLP(d_model, mlp_dim, dropout)

    def forward(self, x):
        # Pre-LN Attention
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + self.drop_path(attn_out)

        # Pre-LN MLP
        h = self.norm2(x)
        x = x + self.drop_path(self.mlp(h))
        return x


class ViT(nn.Module):
    def __init__(
        self,
        img_size: int,
        patch_size: int,
        in_chans: int,
        num_classes: int,
        d_model: int,
        depth: int,
        num_heads: int,
        mlp_dim: int,
        dropout: float,
    ):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, d_model)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, d_model))
        self.pos_drop = nn.Dropout(dropout)

        self.blocks = nn.ModuleList([
            EncoderBlock(d_model, num_heads, mlp_dim, dropout) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, num_classes)

        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.head.weight, std=0.02)
        nn.init.zeros_(self.head.bias)

    def forward(self, x):
        # x: (B, C, H, W)
        x = self.patch_embed(x)  # (B, N, D)
        B, N, D = x.shape

        cls = self.cls_token.expand(B, -1, -1)  # (B, 1, D)
        x = torch.cat([cls, x], dim=1)          # (B, N+1, D)
        x = x + self.pos_embed[:, :N+1, :]
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        cls_out = x[:, 0]        # CLS token
        logits = self.head(cls_out)
        return logits


model = ViT(
    img_size=cfg.img_size,
    patch_size=cfg.patch_size,
    in_chans=cfg.in_chans,
    num_classes=cfg.num_classes,
    d_model=cfg.d_model,
    depth=cfg.depth,
    num_heads=cfg.num_heads,
    mlp_dim=cfg.mlp_dim,
    dropout=cfg.dropout,
).to(device)

print(model.__class__.__name__)
print("Params:", sum(p.numel() for p in model.parameters())/1e6, "M")


ViT
Params: 0.308042 M


In [4]:
# Train
def train_one_epoch(model, loader, optimizer, criterion, device):
    model.train()
    running_loss = 0.0
    t0 = time.time()

    pbar = tqdm(loader, desc="Train", leave=False)
    for x, y in pbar:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * x.size(0)
        pbar.set_postfix(loss=loss.item())

    avg_loss = running_loss / len(loader.dataset)
    epoch_time = time.time() - t0
    return avg_loss, epoch_time


@torch.no_grad()
def evaluate_accuracy(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    for x, y in loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        logits = model(x)
        pred = logits.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.numel()
    return correct / total


criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

total_train_start = time.time()

for epoch in range(1, cfg.epochs + 1):
    # tqdm
    epoch_bar = tqdm(total=len(train_loader), desc=f"Epoch {epoch}/{cfg.epochs}", leave=True)

    model.train()
    running_loss = 0.0
    epoch_start = time.time()

    for x, y in train_loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * x.size(0)
        epoch_bar.update(1)

    epoch_bar.close()

    avg_loss = running_loss / len(train_loader.dataset)
    epoch_time = time.time() - epoch_start
    print(f"Epoch {epoch}/{cfg.epochs}, Average Loss: {avg_loss:.4f}, Time Taken: {epoch_time:.2f}s")

total_train_time = time.time() - total_train_start
print(f"\nTotal training time: {total_train_time:.2f}s")


Epoch 1/10:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 1/10, Average Loss: 1.8293, Time Taken: 18.45s


Epoch 2/10:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 2/10, Average Loss: 1.5478, Time Taken: 17.30s


Epoch 3/10:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 3/10, Average Loss: 1.3895, Time Taken: 17.69s


Epoch 4/10:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 4/10, Average Loss: 1.2942, Time Taken: 18.81s


Epoch 5/10:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 5/10, Average Loss: 1.2282, Time Taken: 18.10s


Epoch 6/10:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 6/10, Average Loss: 1.1738, Time Taken: 16.89s


Epoch 7/10:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 7/10, Average Loss: 1.1308, Time Taken: 16.96s


Epoch 8/10:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 8/10, Average Loss: 1.0854, Time Taken: 18.23s


Epoch 9/10:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 9/10, Average Loss: 1.0514, Time Taken: 17.76s


Epoch 10/10:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 10/10, Average Loss: 1.0211, Time Taken: 18.18s

Total training time: 178.48s


In [5]:
# Evaluate: Accuracy + Train time + Inference time
@torch.no_grad()
def measure_inference_time(model, loader, device, warmup_batches=10):
    model.eval()

    # warmup
    if device.type == "cuda":
        for i, (x, _) in enumerate(loader):
            if i >= warmup_batches:
                break
            x = x.to(device, non_blocking=True)
            _ = model(x)
        torch.cuda.synchronize()

    t0 = time.time()
    total_samples = 0

    for x, _ in loader:
        x = x.to(device, non_blocking=True)
        _ = model(x)
        total_samples += x.size(0)

    if device.type == "cuda":
        torch.cuda.synchronize()

    t = time.time() - t0
    return t, total_samples

test_acc = evaluate_accuracy(model, test_loader, device)
infer_time, infer_samples = measure_inference_time(model, test_loader, device)

print(f"Test Accuracy: {test_acc*100:.2f}%")
print(f"Train time (total): {total_train_time:.2f}s")
print(f"Inference time (test set): {infer_time:.4f}s for {infer_samples} samples "
      f"({(infer_time/infer_samples)*1000:.4f} ms/sample)")

Test Accuracy: 62.79%
Train time (total): 178.48s
Inference time (test set): 2.2482s for 10000 samples (0.2248 ms/sample)
