Imports & configuration

In [None]:
import os, random, math
import numpy as np
import torch, torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, transforms, models
import timm

# Reproducibility: seed every RNG the notebook uses (Python, NumPy, PyTorch)
# and ask cuDNN for deterministic kernels instead of autotuned ones.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Checkpoints from `fit` are written here (Colab-style absolute path).
SAVE_DIR = "/content/results_task1"
os.makedirs(SAVE_DIR, exist_ok=True)

# Training hyper-parameters shared by both backbones.
BATCH_SIZE = 256
EPOCHS = 12            # increase to ~20 if time allows
LR = 3e-4
NUM_WORKERS = 2

1) CIFAR-10 data & a fixed train/val/test split

In [None]:
# CIFAR-10 per-channel statistics, used by the 32x32 ResNet pipeline. They are
# close to the ImageNet statistics the torchvision ResNet-50 weights were
# trained with, so reusing them for ResNet fine-tuning is reasonable.
MEAN = (0.4914, 0.4822, 0.4465)
STD  = (0.2470, 0.2435, 0.2616)

# Normalization for the ViT pipeline. The ViT is loaded with timm pretrained
# weights whose default pretrained config normalizes with mean = std = 0.5 per
# channel. Using the CIFAR statistics above (std ~0.25) would feed the
# pretrained patch embedding inputs at roughly twice the scale it saw during
# pre-training. NOTE(review): if the timm checkpoint tag changes, confirm these
# against `timm.data.resolve_model_data_config(vit)`.
VIT_MEAN = (0.5, 0.5, 0.5)
VIT_STD  = (0.5, 0.5, 0.5)

# 32x32 pipeline (ResNet): standard CIFAR pad-and-crop + flip augmentation.
train_tf_rn = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])
test_tf_rn = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])

# 224x224 pipeline (ViT): upsample to the model's native resolution first,
# then apply the same style of crop/flip augmentation at that scale.
train_tf_vit = transforms.Compose([
    transforms.Resize(224, antialias=True),
    transforms.RandomCrop(224, padding=8),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(VIT_MEAN, VIT_STD),
])
test_tf_vit = transforms.Compose([
    transforms.Resize(224, antialias=True),
    transforms.ToTensor(),
    transforms.Normalize(VIT_MEAN, VIT_STD),
])

# Base datasets are loaded without a transform; each TransformView below
# attaches its own pipeline to a shared copy.
root = "/content/data"
full_train = datasets.CIFAR10(root, train=True, download=True, transform=None)
test_set = datasets.CIFAR10(root, train=False, download=True, transform=None)

# Fixed split of the 50,000 training images into 45k train / 5k val. The
# permutation comes from a Generator seeded with SEED, so it is identical on
# every run.
num_train = len(full_train)
rng = np.random.default_rng(SEED)
idx = rng.permutation(num_train)
train_idx, val_idx = np.split(idx, [45000])

# Two views of the same split with different transforms
class TransformView(torch.utils.data.Dataset):
    """Dataset view of `base` restricted to `indices`, with its own transform.

    Several views (e.g. augmented train vs. plain val) can share one underlying
    dataset. Item ``i`` is ``base[indices[i]]`` with `transform` applied to the
    input, or returned unchanged when `transform` is None.
    """

    def __init__(self, base, indices, transform):
        self.base, self.indices, self.transform = base, indices, transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        sample, target = self.base[self.indices[i]]
        if self.transform is None:
            return sample, target
        return self.transform(sample), target

# Both backbones see the same fixed train/val/test indices; only the input
# pipeline differs (32x32 for ResNet, 224x224 for ViT).
test_idx = np.arange(len(test_set))

train_rn = TransformView(full_train, train_idx, train_tf_rn)
val_rn   = TransformView(full_train, val_idx,   test_tf_rn)
test_rn  = TransformView(test_set,  test_idx,   test_tf_rn)

train_vit = TransformView(full_train, train_idx, train_tf_vit)
val_vit   = TransformView(full_train, val_idx,   test_tf_vit)
test_vit  = TransformView(test_set,  test_idx,   test_tf_vit)

# Pinned host memory only helps host->GPU copies; on a CPU-only runtime it is
# wasted work, and PyTorch warns about it.
PIN_MEMORY = device.type == "cuda"


def _make_loader(dataset, shuffle):
    """Build a DataLoader with the notebook's shared batch/worker settings."""
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle,
                      num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)


# Only the training loaders shuffle; val/test order is fixed.
train_loader_rn = _make_loader(train_rn, shuffle=True)
val_loader_rn   = _make_loader(val_rn,   shuffle=False)
test_loader_rn  = _make_loader(test_rn,  shuffle=False)

train_loader_vit = _make_loader(train_vit, shuffle=True)
val_loader_vit   = _make_loader(val_vit,   shuffle=False)
test_loader_vit  = _make_loader(test_vit,  shuffle=False)


2) Models (ResNet-50 & ViT-S/16, pretrained)

In [None]:
def build_resnet50(num_classes=10, pretrained=True):
    """Return a torchvision ResNet-50 with a freshly initialised linear head.

    Args:
        num_classes: output size of the new `fc` layer.
        pretrained: when True, start from the IMAGENET1K_V2 weights; otherwise
            the backbone is randomly initialised.

    Returns:
        The model with `fc` replaced by ``nn.Linear(2048-ish, num_classes)``;
        the input width is read from the original `fc` layer.
    """
    weights = models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
    net = models.resnet50(weights=weights)
    in_dim = net.fc.in_features
    net.fc = nn.Linear(in_dim, num_classes)
    return net

def build_vit_s16(num_classes=10, pretrained=True):
    """Return a timm ViT-S/16 (224px) with a `num_classes`-way classifier head.

    The classifier size is passed to `timm.create_model` via `num_classes`.
    This is timm's documented way to resize the head. When the class count
    differs from the checkpoint's, timm loads the pretrained backbone and
    leaves the new head freshly initialised. That avoids hand-patching the
    model-specific `head` attribute afterwards.

    Args:
        num_classes: number of output classes.
        pretrained: load timm's pretrained weights for the backbone.
    """
    return timm.create_model(
        'vit_small_patch16_224',
        pretrained=pretrained,
        num_classes=num_classes,
    )

# Build both backbones with their default 10-class, pretrained configuration
# and place them on the selected device. Reassigning keeps the cell's last
# statement an assignment, so the notebook prints no module repr.
resnet = build_resnet50(num_classes=10, pretrained=True)
resnet = resnet.to(device)
vit = build_vit_s16(num_classes=10, pretrained=True)
vit = vit.to(device)


3) Train/Eval Utilities

In [None]:
def train_epoch(model, loader, opt):
    """Run one optimisation pass over `loader`; return (mean loss, accuracy).

    Loss and accuracy are averaged over samples, so a smaller final batch is
    weighted correctly. Accuracy is measured on the training-mode forward pass
    used for the update.

    Args:
        model: network to train (switched to train mode and updated in place).
        loader: iterable of ``(inputs, labels)`` tensor batches.
        opt: optimizer over `model`'s parameters.

    Returns:
        ``(loss_per_sample, accuracy)`` as Python floats.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    model.train()
    # Move batches to wherever the model's weights live instead of relying on
    # a notebook-global device variable.
    dev = next(model.parameters()).device
    total, correct, loss_sum = 0, 0, 0.0
    for x, y in loader:
        # non_blocking lets pinned-memory copies overlap with compute on CUDA;
        # it has no effect for CPU tensors.
        x = x.to(dev, non_blocking=True)
        y = y.to(dev, non_blocking=True)
        opt.zero_grad(set_to_none=True)
        out = model(x)
        loss = F.cross_entropy(out, y)
        loss.backward()
        opt.step()
        n = y.size(0)
        loss_sum += loss.item() * n  # undo the batch mean so samples weigh equally
        correct += (out.argmax(1) == y).sum().item()
        total += n
    if total == 0:
        raise ValueError("train_epoch: loader yielded no samples")
    return loss_sum / total, correct / total

@torch.no_grad()
def eval_acc(model, loader):
    """Score `model` on `loader` in eval mode; return (mean loss, accuracy).

    Runs without autograd (see the decorator) and with the model in eval mode,
    so dropout/batch-norm behave deterministically. Metrics are averaged over
    samples, not batches.

    Args:
        model: network to evaluate (left in eval mode afterwards).
        loader: iterable of ``(inputs, labels)`` tensor batches.

    Returns:
        ``(loss_per_sample, accuracy)`` as Python floats.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    model.eval()
    # Follow the model's own device rather than a notebook-global variable.
    dev = next(model.parameters()).device
    total, correct, loss_sum = 0, 0, 0.0
    for x, y in loader:
        x = x.to(dev, non_blocking=True)
        y = y.to(dev, non_blocking=True)
        out = model(x)
        # Summed per-sample loss; divided by the sample count once at the end.
        loss_sum += F.cross_entropy(out, y, reduction="sum").item()
        correct += (out.argmax(1) == y).sum().item()
        total += y.size(0)
    if total == 0:
        raise ValueError("eval_acc: loader yielded no samples")
    return loss_sum / total, correct / total

def fit(model, train_loader, val_loader, test_loader, tag):
    """Fine-tune `model` for EPOCHS epochs with AdamW; keep the best-on-val weights.

    Each epoch: one training pass, then loss/accuracy on the validation and
    test loaders. Whenever validation accuracy improves, the state dict is
    saved to ``{SAVE_DIR}/{tag}_best.pth``. Model selection uses only the
    validation split; test metrics are logged for reference. After the last
    epoch the best checkpoint is reloaded into `model` and its test accuracy
    printed.

    Args:
        model: network to train; it ends up holding the best-on-val weights.
        train_loader, val_loader, test_loader: iterables of (inputs, labels).
        tag: prefix for log lines and the checkpoint filename.

    Returns:
        ``(history, best_path)``. `history` has one
        ``(epoch, tr_loss, tr_acc, va_loss, va_acc, te_loss, te_acc)`` tuple
        per epoch; `best_path` is the checkpoint file.

    Raises:
        ValueError: if EPOCHS < 1, since no checkpoint could be produced.
    """
    if EPOCHS < 1:
        raise ValueError(f"EPOCHS must be >= 1, got {EPOCHS}")

    opt = torch.optim.AdamW(model.parameters(), lr=LR)
    # Start below any reachable accuracy so epoch 1 always writes a checkpoint.
    # With 0.0, a run whose val accuracy never rises above 0 would leave no
    # file for the reload below.
    best_val = float("-inf")
    history = []
    best_path = f"{SAVE_DIR}/{tag}_best.pth"

    for e in range(1, EPOCHS + 1):
        tr_loss, tr_acc = train_epoch(model, train_loader, opt)
        va_loss, va_acc = eval_acc(model, val_loader)
        te_loss, te_acc = eval_acc(model, test_loader)
        history.append((e, tr_loss, tr_acc, va_loss, va_acc, te_loss, te_acc))
        print(f"[{tag}] Ep{e:02d} | train {tr_loss:.3f}/{tr_acc:.3f} | val {va_loss:.3f}/{va_acc:.3f} | test {te_loss:.3f}/{te_acc:.3f}")

        if va_acc > best_val:
            best_val = va_acc
            torch.save(model.state_dict(), best_path)

    # Load best-by-val and report final test
    model.load_state_dict(torch.load(best_path, map_location=device))
    _, final_test_acc = eval_acc(model, test_loader)
    print(f"[{tag}] Best-on-val checkpoint -> test acc: {final_test_acc:.4f}")
    return history, best_path


In [None]:
# Fine-tune each backbone on its own input pipeline. Every call returns the
# per-epoch history and the path of its best-on-validation checkpoint.
hist_rn, ckpt_rn = fit(
    model=resnet,
    train_loader=train_loader_rn,
    val_loader=val_loader_rn,
    test_loader=test_loader_rn,
    tag="resnet50_cifar",
)
hist_vt, ckpt_vt = fit(
    model=vit,
    train_loader=train_loader_vit,
    val_loader=val_loader_vit,
    test_loader=test_loader_vit,
    tag="vit_s16_cifar",
)

print("Saved:", ckpt_rn, ckpt_vt)
