In [5]:
# Cell 1 ── imports & global config
import time, random, torch, torch.nn as nn, torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import numpy as np

# Pick the best available accelerator: Apple-silicon MPS first, then CUDA,
# and finally plain CPU so the notebook still runs on machines without either
# (previously "cuda" was selected even when no CUDA device existed).
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
BS      = 128     # mini-batch size handed to every DataLoader
LR      = 0.05    # learning rate handed to the SGD optimisers
# Seed all three RNG sources used in this notebook.
torch.manual_seed(0); random.seed(0); np.random.seed(0)
torch.backends.cudnn.benchmark = True   # cuDNN autotuner flag
print("Device:", DEVICE)


Device: mps


In [6]:
# Cell 2 ── single-worker loaders for every resolution
def make_loader(size, train):
    """Build an STL-10 DataLoader whose images are produced at `size`.

    Args:
        size:  value passed to RandomResizedCrop (train) or Resize (eval).
        train: if truthy, use the "train" split with random-resized-crop and
               horizontal-flip augmentation and shuffle the data; otherwise
               use the "test" split with a plain resize and no shuffling.

    Returns:
        A single-process DataLoader (num_workers=0) with batch size `BS`.
    """
    if train:
        split = "train"
        steps = [
            transforms.RandomResizedCrop(size, scale=(0.6, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]
    else:
        split = "test"
        steps = [
            transforms.Resize(size),
            transforms.Lambda(lambda img: img),   # identity placeholder
        ]
    steps.append(transforms.ToTensor())

    dataset = datasets.STL10(
        "data",
        split=split,
        download=True,
        transform=transforms.Compose(steps),
    )
    return DataLoader(dataset, batch_size=BS, shuffle=train, num_workers=0)

# Resolutions for which loaders are built eagerly in this cell.
# NOTE: the seeds-loop cell further down rebinds train_ld / val_ld to empty
# dicts and fills them on demand, so these eager loaders are not reused there.
SIZES = [32, 40, 56, 72, 96]

# All training loaders are created first, then all validation loaders.
train_ld = dict(zip(SIZES, (make_loader(px, True) for px in SIZES)))
val_ld = dict(zip(SIZES, (make_loader(px, False) for px in SIZES)))
print("Loaders ready.")


Loaders ready.


In [7]:
# Cell 3 ── helpers
# Training criterion used by train_one_epoch: cross-entropy with
# label smoothing of 0.1.
loss_CE = nn.CrossEntropyLoss(label_smoothing=0.1)

def accuracy(net, loader):
    """Return the top-1 accuracy of `net` over `loader`, in percent.

    The network is switched to eval mode (and is not switched back) and the
    forward passes run without gradient tracking. Inputs and labels are moved
    to `DEVICE` batch by batch.
    """
    net.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for images, labels in loader:
            predictions = net(images.to(DEVICE)).argmax(1)
            hits = predictions == labels.to(DEVICE)
            n_correct += hits.sum().item()
            n_seen += labels.size(0)
    return 100 * n_correct / n_seen

def train_one_epoch(net, loader, opt):
    """Run one optimisation pass over every batch in `loader`.

    Calls `net.train()` first, then for each batch: moves it to `DEVICE`,
    clears gradients, back-propagates the `loss_CE` loss and takes one
    optimiser step. Returns nothing.
    """
    net.train()
    for batch in loader:
        inputs, targets = (tensor.to(DEVICE) for tensor in batch)
        opt.zero_grad(set_to_none=True)
        loss = loss_CE(net(inputs), targets)
        loss.backward()
        opt.step()

def resnet18():
    """Create a 10-class torchvision ResNet-18 adapted for small images.

    The stem convolution is replaced by a 3×3, stride-1, padding-1 convolution
    without bias, and the stem max-pool is replaced by an identity. The model
    is returned already moved to `DEVICE`.
    """
    model = models.resnet18(num_classes=10)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = nn.Identity()
    return model.to(DEVICE)


In [10]:
# Cell X – seeds loop (on-demand loaders + optional BN-freeze)
import time, math, random, numpy as np, torch
from torch import nn

# ─── 1.  Experiment knobs ────────────────────────────────────────────────
SEEDS        = list(range(20, 31))     # eleven seeds: 20 … 30 inclusive
ladder       = [32, 48, 64, 80, 96]    # training resolutions, low → high
epochs_step  = [4, 3, 3, 2, 12]        # epochs per ladder rung (24 in total)
FREEZE_BN    = True                    # set to False for the ablation run
baseline_times = []                    # minutes per seed, filled by the benchmark loop
ladder_times   = []

# ─── 2.  On-demand DataLoader cache ──────────────────────────────────────
# Both caches start empty and are populated by ensure_loader().
train_ld = {}
val_ld   = {}

def ensure_loader(sz: int):
    """Create the STL-10 train/val DataLoaders for resolution `sz` if missing.

    The two caches are checked independently, so a half-populated state
    (train loader present but val loader missing, e.g. after an interrupted
    call) is repaired here instead of surfacing later as a KeyError on
    `val_ld[sz]`.

    Args:
        sz: image resolution, used as the cache key and passed to make_loader.
    """
    if sz not in train_ld:
        train_ld[sz] = make_loader(sz, train=True)
    if sz not in val_ld:
        val_ld[sz] = make_loader(sz, train=False)

# ─── 3.  Utility to freeze BN layers once curriculum leaves first stage ──
def freeze_bn_layers(net: nn.Module):
    """Permanently freeze every BatchNorm2d layer inside `net`.

    For each BatchNorm2d module this:
      * switches it to eval mode, so it normalises with its stored running
        statistics and stops updating them;
      * disables gradients on its affine weight/bias (when it has them);
      * pins it in eval mode, so a later `net.train()` on the parent does not
        silently re-enable batch statistics.

    The pinning matters because `nn.Module.train()` propagates the mode to
    every child module; without it, the `net.train()` issued at the start of
    each training epoch would undo the `eval()` call and only the affine
    parameters would actually stay frozen.

    Calling this function more than once on the same network is harmless.
    """
    import types

    def _stay_in_eval(self, mode: bool = True):
        # Replacement for Module.train on a frozen layer: ignore the request
        # and keep the current (eval) mode. Returns self like Module.train.
        return self

    for m in net.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eval()                              # must run before pinning
            # affine=False layers carry no learnable scale/shift
            if m.weight is not None:
                m.weight.requires_grad_(False)
            if m.bias is not None:
                m.bias.requires_grad_(False)
            # Instance-level override; shadows the class-level train()
            m.train = types.MethodType(_stay_in_eval, m)

# ─── 4.  Baseline: full-resolution only ──────────────────────────────────
def run_baseline(seed: int) -> float:
    """Train a fresh ResNet-18 at 96 px only and time it.

    Seeds torch / random / numpy with `seed`, makes sure the 96-px loaders
    exist, then trains with SGD for at most 40 epochs. Validation accuracy is
    measured after every epoch and training stops early once it reaches 60 %.

    Returns:
        Wall-clock minutes from the start of the first epoch to the end of the
        last evaluated epoch. Loader construction is not included.
    """
    for seeder in (torch.manual_seed, random.seed, np.random.seed):
        seeder(seed)

    ensure_loader(96)                        # built before the clock starts
    net = resnet18()
    opt = torch.optim.SGD(
        net.parameters(), lr=LR, momentum=0.9, weight_decay=1e-4
    )

    def minutes_since_start():
        return (time.perf_counter() - start) / 60

    start = time.perf_counter()
    for epoch in range(40):
        train_one_epoch(net, train_ld[96], opt)
        acc = accuracy(net, val_ld[96])
        mins = minutes_since_start()
        print(f"[seed {seed}] BASE e{epoch:02d} | val {acc:5.1f}% | {mins:5.2f} m")
        if acc >= 60:
            break
    return minutes_since_start()

# ─── 5.  Ladder curriculum with on-demand loaders & BN freeze ────────────
def run_ladder(seed: int) -> float:
    """Train a fresh ResNet-18 through the resolution ladder and time it.

    Seeds torch / random / numpy with `seed`, then trains for
    `epochs_step[i]` epochs at resolution `ladder[i]`, in order. When
    `FREEZE_BN` is set, `freeze_bn_layers` is applied once, on entering the
    first rung whose resolution exceeds `ladder[0]`. Validation accuracy is
    measured after every epoch at the current resolution; the run stops early
    once accuracy at 96 px reaches 60 %.

    All loaders are created before the timer starts, mirroring run_baseline,
    so first-use dataset loading is not charged to the ladder's time.

    Returns:
        Wall-clock minutes spent in training and evaluation.
    """
    torch.manual_seed(seed); random.seed(seed); np.random.seed(seed)

    # Build every loader up front so the timed region contains only
    # training and evaluation work.
    for sz in ladder:
        ensure_loader(sz)

    net = resnet18()
    opt = torch.optim.SGD(net.parameters(), lr=LR,
                          momentum=0.9, weight_decay=1e-4)

    t0         = time.perf_counter()
    bn_frozen  = False
    reached    = False        # set once the 96-px target accuracy is hit

    for sz, n_ep in zip(ladder, epochs_step):
        if FREEZE_BN and sz > ladder[0] and not bn_frozen:
            freeze_bn_layers(net)
            bn_frozen = True

        for ep in range(n_ep):
            train_one_epoch(net, train_ld[sz], opt)
            acc  = accuracy(net, val_ld[sz])
            mins = (time.perf_counter() - t0) / 60
            print(f"[seed {seed}] LAD {sz} e{ep:02d} | "
                  f"val {acc:5.1f}% | {mins:5.2f} m")
            if sz == 96 and acc >= 60:
                reached = True
                break
        if reached:
            break

    return (time.perf_counter() - t0) / 60

# ─── 6.  Paired benchmark loop ───────────────────────────────────────────
def paired_t(delta: np.ndarray):
    """Summarise paired differences with a one-sample t-test against zero.

    Args:
        delta: 1-D array-like of per-pair differences.

    Returns:
        Tuple ``(mean, ci, p, t)``:
          mean – sample mean of `delta`;
          ci   – half-width of the two-sided 95 % confidence interval;
          p    – two-sided p-value for H0 "mean is 0", or None if unavailable;
          t    – the t statistic, mean / (std / sqrt(n)).

    When SciPy is importable, the critical value and the p-value come from
    the Student-t distribution with n-1 degrees of freedom, which is the
    appropriate reference for small samples. Without SciPy the function falls
    back to the normal approximation (1.96 critical value, p only for n > 30).

    Degenerate inputs do not raise: fewer than two samples give NaN for `ci`
    and `t` and None for `p`; zero variance gives ``ci == 0`` with `t` of
    ±inf (p = 0.0), or NaN (p = None) when the mean is also zero.
    """
    delta = np.asarray(delta, dtype=float).ravel()
    n = delta.size
    nan = float("nan")
    if n == 0:
        return nan, nan, None, nan

    mean = float(delta.mean())
    if n < 2:
        # A sample standard deviation needs at least two observations.
        return mean, nan, None, nan

    std = float(delta.std(ddof=1))
    sem = std / math.sqrt(n)             # standard error of the mean
    if sem == 0.0:
        # Every difference is identical: the interval collapses to a point.
        if mean == 0.0:
            return mean, 0.0, None, nan
        return mean, 0.0, 0.0, math.copysign(math.inf, mean)

    t = mean / sem

    try:
        from scipy import stats
    except ImportError:
        stats = None

    if stats is not None:
        crit = float(stats.t.ppf(0.975, df=n - 1))
        p = float(2.0 * stats.t.sf(abs(t), df=n - 1))
    else:
        crit = 1.96
        # erfc(z/√2) equals the two-sided normal tail, without the
        # cancellation that 1 - erf(...) suffers for large |t|.
        p = math.erfc(abs(t) / math.sqrt(2)) if n > 30 else None

    return mean, crit * sem, p, t

# Benchmark every seed with both schedules: baseline first, then the ladder.
# Each run_* call returns its elapsed time in minutes.
for s in SEEDS:
    print(f"\n========= Seed {s} baseline ===========================")
    b = run_baseline(s)
    print(f"→ baseline done in {b:.2f} min\n")

    print(f"========= Seed {s} ladder (BN-freeze={FREEZE_BN}) ======")
    l = run_ladder(s)
    print(f"→ ladder   done in {l:.2f} min\n")

    # Both timings are recorded only after both runs have finished, so a
    # seed that is interrupted part-way contributes to neither list and the
    # two lists stay the same length.
    baseline_times.append(b)
    ladder_times  .append(l)

# ─── 7.  Stats summary ───────────────────────────────────────────────────
# Per-seed timings (minutes) as arrays; a positive Δ means the ladder run
# took less time than the baseline for that seed.
base  = np.array(baseline_times)
lad   = np.array(ladder_times)
delta = base - lad

m, ci, p, t_stat = paired_t(delta)

print("-----------------------------------------------------------")
print("Seeds:", SEEDS)
for s, b, l in zip(SEEDS, base, lad):
    print(f"seed {s}: Δ = {b - l:4.2f} min   (base {b:5.2f} | lad {l:5.2f})")
print(f"\nMean Δ-time : {m:4.2f} ±{ci:4.2f} (95% CI)")
# Compare against None explicitly: a p-value of exactly 0.0 is a real
# result and must not be reported as "missing".
print(f"Paired t-stat {t_stat:4.2f}   p ≈ {p if p is not None else 'n/a'}")
print(f"Speed-up    : {base.mean() / lad.mean():4.1f}×")
print("-----------------------------------------------------------")



[seed 20] BASE e00 | val  19.8% |  0.28 m
[seed 20] BASE e01 | val  26.8% |  0.56 m
[seed 20] BASE e02 | val  31.0% |  0.85 m
[seed 20] BASE e03 | val  34.3% |  1.13 m
[seed 20] BASE e04 | val  34.4% |  1.41 m
[seed 20] BASE e05 | val  38.1% |  1.69 m
[seed 20] BASE e06 | val  35.6% |  1.97 m
[seed 20] BASE e07 | val  41.2% |  2.25 m
[seed 20] BASE e08 | val  44.1% |  2.53 m
[seed 20] BASE e09 | val  34.0% |  2.82 m
[seed 20] BASE e10 | val  49.3% |  3.10 m
[seed 20] BASE e11 | val  45.4% |  3.38 m
[seed 20] BASE e12 | val  40.6% |  3.66 m
[seed 20] BASE e13 | val  51.4% |  3.95 m
[seed 20] BASE e14 | val  49.5% |  4.23 m
[seed 20] BASE e15 | val  44.5% |  4.51 m
[seed 20] BASE e16 | val  56.7% |  4.79 m
[seed 20] BASE e17 | val  55.4% |  5.07 m
[seed 20] BASE e18 | val  56.8% |  5.36 m
[seed 20] BASE e19 | val  51.7% |  5.64 m
[seed 20] BASE e20 | val  59.3% |  5.92 m
[seed 20] BASE e21 | val  57.8% |  6.20 m
[seed 20] BASE e22 | val  44.6% |  6.48 m
[seed 20] BASE e23 | val  59.2% |

KeyboardInterrupt: 