In [1]:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchinfo import summary
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from tqdm import tqdm

# -------------------------------
# 1) Vision Transformer Modules
# -------------------------------
class PatchEmbed(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=256):
        super().__init__()
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_chans, embed_dim, 
                              kernel_size=patch_size, 
                              stride=patch_size)
        # learnable CLS token + positional encoding
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + self.n_patches, embed_dim))

    def forward(self, x):
        B = x.shape[0]
        x = self.proj(x)                  # (B, E, H/ps, W/ps)
        x = x.flatten(2).transpose(1, 2)  # (B, n_patches, E)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B,1,E)
        x = torch.cat((cls_tokens, x), dim=1)          # (B,1+n_patches,E)
        x = x + self.pos_embed
        return x

class TransformerEncoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_hidden_dim, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn  = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp   = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # x: (B, N, E) -> transpose for MHAttention: (N, B, E)
        x2 = self.norm1(x)
        attn_out, _ = self.attn(x2.transpose(0,1), x2.transpose(0,1), x2.transpose(0,1))
        x = x + attn_out.transpose(0,1)
        x = x + self.mlp(self.norm2(x))
        return x

class ViT(nn.Module):
    def __init__(self, 
                 img_size=32, patch_size=4, in_chans=3, 
                 num_classes=100, embed_dim=256, depth=4, 
                 num_heads=4, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        mlp_hidden = embed_dim * mlp_ratio
        self.encoder = nn.ModuleList([
            TransformerEncoderLayer(embed_dim, num_heads, mlp_hidden, dropout)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)   # (B, 1+n, E)
        for blk in self.encoder:
            x = blk(x)
        x = self.norm(x)
        cls_token_final = x[:, 0]
        return self.head(cls_token_final)

# -------------------------------
# 2) Data Loaders
# -------------------------------
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5071,0.4867,0.4408), (0.2675,0.2565,0.2761))
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5071,0.4867,0.4408), (0.2675,0.2565,0.2761))
])

train_ds = datasets.CIFAR100(root='./data', train=True, download=True, transform=train_transform)
test_ds  = datasets.CIFAR100(root='./data', train=False, download=True, transform=test_transform)
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True,  num_workers=4, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=64, shuffle=False, num_workers=4, pin_memory=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# -------------------------------
# 3) Training & Evaluation
# -------------------------------
def train_one_epoch(model, optimizer, criterion, loader):
    model.train()
    running_loss = 0.0
    start = time.time()
    for imgs, labels in tqdm(loader, leave=False):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        out = model(imgs)
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * imgs.size(0)
    elapsed = time.time() - start
    return running_loss / len(loader.dataset), elapsed / len(loader)

@torch.no_grad()
def test(model, loader):
    model.eval()
    correct = 0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        out = model(imgs)
        pred = out.argmax(dim=1)
        correct += (pred == labels).sum().item()
    return correct / len(loader.dataset)

# -------------------------------
# 4) Experiment Loop
# -------------------------------
vit_configs = []
for ps in [4,8]:
  for ed in [256,512]:
    for dp in [4,8]:
      for nh in [2,4]:
        for mr in [2,4]:
          vit_configs.append({
            'patch_size': ps,
            'embed_dim': ed,
            'depth': dp,
            'num_heads': nh,
            'mlp_ratio': mr
          })

results = []

# ViT runs (20 epochs each)
for cfg in vit_configs:
    name = f"ViT_ps{cfg['patch_size']}_ed{cfg['embed_dim']}_d{cfg['depth']}_h{cfg['num_heads']}_mr{cfg['mlp_ratio']}"
    print(f"\n===== TRAINING {name} =====")
    model = ViT(patch_size=cfg['patch_size'],
                embed_dim=cfg['embed_dim'],
                depth=cfg['depth'],
                num_heads=cfg['num_heads'],
                mlp_ratio=cfg['mlp_ratio']
               ).to(device)
    opt   = optim.Adam(model.parameters(), lr=1e-3)
    crit  = nn.CrossEntropyLoss()

    # summary for params & FLOPs
    sumry = summary(model, input_size=(1,3,32,32), verbose=0)
    n_params = sum(p.numel() for p in model.parameters())
    flops    = sumry.total_mult_adds  # approx FLOPs

    # train
    total_time = 0.
    for epoch in range(20):
        loss, tpe = train_one_epoch(model, opt, crit, train_loader)
        total_time += tpe
    avg_time = total_time / 20
    acc = test(model, test_loader)

    results.append({
        'model': name,
        'params': n_params,
        'flops': flops,
        'time/epoch(s)': avg_time,
        'test_acc': acc
    })

# ResNet-18 baseline (10 epochs)
print("\n===== TRAINING ResNet-18 CIFAR100 =====")
resnet = models.resnet18(pretrained=False)
# adapt for CIFAR:
resnet.conv1 = nn.Conv2d(3,64,kernel_size=3,stride=1,padding=1,bias=False)
resnet.maxpool = nn.Identity()
resnet.fc = nn.Linear(resnet.fc.in_features, 100)
resnet = resnet.to(device)

opt = optim.Adam(resnet.parameters(), lr=1e-3)
crit = nn.CrossEntropyLoss()
sumry = summary(resnet, input_size=(1,3,32,32), verbose=0)
n_params = sum(p.numel() for p in resnet.parameters())
flops    = sumry.total_mult_adds

total_time = 0.
for epoch in range(10):
    loss, tpe = train_one_epoch(resnet, opt, crit, train_loader)
    total_time += tpe
avg_time = total_time / 10
acc = test(resnet, test_loader)

results.append({
    'model': 'ResNet-18',
    'params': n_params,
    'flops': flops,
    'time/epoch(s)': avg_time,
    'test_acc': acc
})

# -------------------------------
# 5) Summary
# -------------------------------
print("\n=== FINAL RESULTS ===")
print(f"{'Model':40} | {'Params':12} | {'FLOPs':12} | {'Time/Ep(s)':10} | {'Test Acc':8}")
print("-"*90)
for r in results:
    print(f"{r['model']:40} | {r['params']:12,d} | {int(r['flops']):12,d} | {r['time/epoch(s)']:10.3f} | {r['test_acc']*100:7.2f}%")


Files already downloaded and verified
Files already downloaded and verified

===== TRAINING ViT_ps4_ed256_d4_h2_mr2 =====


                                                                                                                       


===== TRAINING ViT_ps4_ed256_d4_h2_mr4 =====


                                                                                                                       


===== TRAINING ViT_ps4_ed256_d4_h4_mr2 =====


                                                                                                                       


===== TRAINING ViT_ps4_ed256_d4_h4_mr4 =====


                                                                                                                       


===== TRAINING ViT_ps4_ed256_d8_h2_mr2 =====


                                                                                                                       


===== TRAINING ViT_ps4_ed256_d8_h2_mr4 =====


                                                                                                                       


===== TRAINING ViT_ps4_ed256_d8_h4_mr2 =====


                                                                                                                       


===== TRAINING ViT_ps4_ed256_d8_h4_mr4 =====


                                                                                                                       


===== TRAINING ViT_ps4_ed512_d4_h2_mr2 =====


                                                                                                                       


===== TRAINING ViT_ps4_ed512_d4_h2_mr4 =====


                                                                                                                       


===== TRAINING ViT_ps4_ed512_d4_h4_mr2 =====


                                                                                                                       


===== TRAINING ViT_ps4_ed512_d4_h4_mr4 =====


                                                                                                                       


===== TRAINING ViT_ps4_ed512_d8_h2_mr2 =====


                                                                                                                       


===== TRAINING ViT_ps4_ed512_d8_h2_mr4 =====


                                                                                                                       


===== TRAINING ViT_ps4_ed512_d8_h4_mr2 =====


                                                                                                                       


===== TRAINING ViT_ps4_ed512_d8_h4_mr4 =====


                                                                                                                       


===== TRAINING ViT_ps8_ed256_d4_h2_mr2 =====


                                                                                                                       


===== TRAINING ViT_ps8_ed256_d4_h2_mr4 =====


                                                                                                                       


===== TRAINING ViT_ps8_ed256_d4_h4_mr2 =====


                                                                                                                       


===== TRAINING ViT_ps8_ed256_d4_h4_mr4 =====


                                                                                                                       


===== TRAINING ViT_ps8_ed256_d8_h2_mr2 =====


                                                                                                                       


===== TRAINING ViT_ps8_ed256_d8_h2_mr4 =====


                                                                                                                       


===== TRAINING ViT_ps8_ed256_d8_h4_mr2 =====


                                                                                                                       


===== TRAINING ViT_ps8_ed256_d8_h4_mr4 =====


                                                                                                                       


===== TRAINING ViT_ps8_ed512_d4_h2_mr2 =====


                                                                                                                       


===== TRAINING ViT_ps8_ed512_d4_h2_mr4 =====


                                                                                                                       


===== TRAINING ViT_ps8_ed512_d4_h4_mr2 =====


                                                                                                                       


===== TRAINING ViT_ps8_ed512_d4_h4_mr4 =====


                                                                                                                       


===== TRAINING ViT_ps8_ed512_d8_h2_mr2 =====


                                                                                                                       


===== TRAINING ViT_ps8_ed512_d8_h2_mr4 =====


                                                                                                                       


===== TRAINING ViT_ps8_ed512_d8_h4_mr2 =====


                                                                                                                       


===== TRAINING ViT_ps8_ed512_d8_h4_mr4 =====





===== TRAINING ResNet-18 CIFAR100 =====


                                                                                                                       


=== FINAL RESULTS ===
Model                                    | Params       | FLOPs        | Time/Ep(s) | Test Acc
------------------------------------------------------------------------------------------
ViT_ps4_ed256_d4_h2_mr2                  |    2,164,068 |    1,884,772 |      0.019 |   35.72%
ViT_ps4_ed256_d4_h2_mr4                  |    3,214,692 |    2,935,396 |      0.020 |   39.71%
ViT_ps4_ed256_d4_h4_mr2                  |    2,164,068 |    1,884,772 |      0.019 |   10.86%
ViT_ps4_ed256_d4_h4_mr4                  |    3,214,692 |    2,935,396 |      0.020 |   35.24%
ViT_ps4_ed256_d8_h2_mr2                  |    4,272,484 |    2,940,516 |      0.025 |    2.69%
ViT_ps4_ed256_d8_h2_mr4                  |    6,373,732 |    5,041,764 |      0.026 |   10.11%
ViT_ps4_ed256_d8_h4_mr2                  |    4,272,484 |    2,940,516 |      0.025 |   13.56%
ViT_ps4_ed256_d8_h4_mr4                  |    6,373,732 |    5,041,764 |      0.026 |   18.86%
ViT_ps4_ed512_d4_h2_mr2        