In [3]:
import os
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import wandb
import hashlib
import numpy as np
import random

In [4]:
def find_root():
    cur = os.path.abspath(os.curdir)
    while os.path.basename(cur) != "cifar-week3":
        parent = os.path.dirname(cur)
        if parent == cur:
            raise RuntimeError("Folder must be named 'cifar-week3'")
        cur = parent
    return cur

ROOT = find_root()
DATA_DIR = os.path.join(ROOT, "data")
ARTIFACTS_DIR = os.path.join(ROOT, "artifacts")
os.makedirs(ARTIFACTS_DIR, exist_ok=True)

In [5]:
random.seed(42); np.random.seed(42); torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33musansrita[0m ([33musansrita-kathmandu-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [7]:
wandb.init(
    project="cifar10-week3",
    group="Day2-FromScratch-CNN",
    name="cnn9-amp-onecycle",
    config={
        "model": "CNN9 (from scratch)",
        "epochs": 50,
        "batch_size": 128,
        "optimizer": "SGD + OneCycleLR",
        "lr_max": 0.1,
        "weight_decay": 5e-4,
        "amp": True,
        "augmentation": "basic + RandAugment",
    }
)

In [8]:
def md5(p): 
    if not os.path.exists(p): return "MISSING"
    h = hashlib.md5()
    with open(p, "rb") as f:
        for c in iter(lambda: f.read(4096), b""): h.update(c)
    return h.hexdigest()

print("MD5 checksums:")
for f in ["data_batch_1", "test_batch"]:
    p = os.path.join(DATA_DIR, "cifar-10-batches-py", f)
    print(f"  {p}: {md5(p)}")

MD5 checksums:
  c:\cifar-week3\data\cifar-10-batches-py\data_batch_1: c99cafc152244af753f735de768cd75f
  c:\cifar-week3\data\cifar-10-batches-py\test_batch: 40351d587109b95175f43aff81a1287e


In [9]:
MEAN = [0.4914, 0.4822, 0.4465]
STD  = [0.2470, 0.2430, 0.2610]

train_tf = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandAugment(num_ops=2, magnitude=9),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])
test_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])

full_train = datasets.CIFAR10(root=DATA_DIR, train=True,  download=True, transform=train_tf)
test_ds    = datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=test_tf)

train_ds, val_ds = random_split(full_train, [45000, 5000], generator=torch.Generator().manual_seed(42))

train_loader = DataLoader(train_ds, batch_size=128, shuffle=True,  num_workers=4, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=256, shuffle=False, num_workers=4, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=256, shuffle=False, num_workers=4, pin_memory=True)

In [10]:
try:
    import pandera.pandas as pa
    labels = [int(train_ds[i][1]) for i in range(500)]
    pa.DataFrameSchema({"label": pa.Column(int, pa.Check(lambda s: s.between(0,9).all()))})(pd.DataFrame({"label": labels}))
    print("Pandera validation PASSED")
except: pass

In [11]:
class CNN9(nn.Module):
    def conv(self, i, o): 
        return nn.Sequential(nn.Conv2d(i,o,3,1,1,bias=False), nn.BatchNorm2d(o), nn.ReLU(inplace=True))
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            self.conv(3,64),   self.conv(64,128),   nn.MaxPool2d(2),
            self.conv(128,128), self.conv(128,256),  nn.MaxPool2d(2),
            self.conv(256,512), self.conv(512,512),  nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(512,10)
        )
    def forward(self,x): return self.net(x)

model = CNN9().to(device)
print(f"Params: {sum(p.numel() for p in model.parameters()):,}")

Params: 4,065,098


In [12]:
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

In [13]:
total_steps = 50 * len(train_loader)

In [14]:
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.1,
    total_steps=total_steps, 
    pct_start=0.2,
    anneal_strategy='cos',
    cycle_momentum=True,
    base_momentum=0.85,
    max_momentum=0.95,
)

In [15]:
criterion = nn.CrossEntropyLoss()
scaler = torch.amp.GradScaler('cuda')

In [16]:
best_val = 0.0
global_step = 0  

for epoch in range(1, 51):
    model.train()
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        with torch.amp.autocast('cuda'):
            loss = criterion(model(x), y)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        if global_step < total_steps:
            scheduler.step()
        global_step += 1

    # Validation
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            with torch.amp.autocast('cuda'):
                correct += (model(x).argmax(1) == y).sum().item()
            total += y.size(0)
    val_acc = correct / total

    wandb.log({
        "val_accuracy": val_acc,
        "learning_rate": optimizer.param_groups[0]['lr'],
        "epoch": epoch
    })

    if val_acc > best_val:
        best_val = val_acc

    print(f"Epoch {epoch:02d} → Val: {val_acc:.4f} (Best: {best_val:.4f})")

Epoch 01 → Val: 0.4952 (Best: 0.4952)
Epoch 02 → Val: 0.4178 (Best: 0.4952)
Epoch 03 → Val: 0.5954 (Best: 0.5954)
Epoch 04 → Val: 0.6270 (Best: 0.6270)
Epoch 05 → Val: 0.6950 (Best: 0.6950)
Epoch 06 → Val: 0.6280 (Best: 0.6950)
Epoch 07 → Val: 0.6672 (Best: 0.6950)
Epoch 08 → Val: 0.6866 (Best: 0.6950)
Epoch 09 → Val: 0.7024 (Best: 0.7024)
Epoch 10 → Val: 0.7442 (Best: 0.7442)
Epoch 11 → Val: 0.7852 (Best: 0.7852)
Epoch 12 → Val: 0.6918 (Best: 0.7852)
Epoch 13 → Val: 0.7194 (Best: 0.7852)
Epoch 14 → Val: 0.7318 (Best: 0.7852)
Epoch 15 → Val: 0.7836 (Best: 0.7852)
Epoch 16 → Val: 0.7690 (Best: 0.7852)
Epoch 17 → Val: 0.7392 (Best: 0.7852)
Epoch 18 → Val: 0.7496 (Best: 0.7852)
Epoch 19 → Val: 0.7514 (Best: 0.7852)
Epoch 20 → Val: 0.7964 (Best: 0.7964)
Epoch 21 → Val: 0.8066 (Best: 0.8066)
Epoch 22 → Val: 0.7730 (Best: 0.8066)
Epoch 23 → Val: 0.7550 (Best: 0.8066)
Epoch 24 → Val: 0.7870 (Best: 0.8066)
Epoch 25 → Val: 0.7624 (Best: 0.8066)
Epoch 26 → Val: 0.8334 (Best: 0.8334)
Epoch 27 → V

In [17]:
model.eval()
correct = total = 0
with torch.no_grad():
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        correct += (model(x).argmax(1) == y).sum().item()
        total += y.size(0)
test_acc = correct / total

wandb.log({"test_accuracy_final": test_acc, "best_val_accuracy": best_val})
print(f"\nFINAL TEST ACCURACY: {test_acc:.4f}")


FINAL TEST ACCURACY: 0.9349


In [18]:
path = os.path.join(ARTIFACTS_DIR, "day2_cnn9_final.pth")
torch.save(model.state_dict(), path)

artifact = wandb.Artifact("day2-cnn9-model", type="model")
artifact.add_file(path)
wandb.log_artifact(artifact)

wandb.finish()

0,1
best_val_accuracy,▁
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
learning_rate,▁▃▄▅▆▇███████▇▇▇▇▆▆▆▅▅▅▄▄▄▃▃▃▃▂▂▂▂▂▁▁▁▁▁
test_accuracy_final,▁
val_accuracy,▂▁▄▄▅▅▅▆▆▅▅▆▆▆▆▆▇▆▆▆▇▆▆▆▇▆▇▇▆▇▇▇▇▇██████

0,1
best_val_accuracy,0.9112
epoch,50.0
learning_rate,0.0
test_accuracy_final,0.9349
val_accuracy,0.9078
