<a href="https://colab.research.google.com/github/s4908819/Colab-for-COMP3710/blob/main/3.3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Report the Python / torch versions and the visible GPU
# (Colab usually ships torch and torchvision preinstalled).
import torch, platform, os
print("Python:", platform.python_version())
print("Torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("CUDA device:", torch.cuda.get_device_name(0))
# Turn on TF32 and the cuDNN autotuner (the same "speed switches" as the slurm script).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
try:
    torch.set_float32_matmul_precision("high")
except AttributeError:
    # This torch build has no set_float32_matmul_precision; the flags set above still apply.
    pass


Python: 3.12.11
Torch: 2.8.0+cu126
CUDA available: True
CUDA device: NVIDIA A100-SXM4-40GB


In [None]:
%%writefile fast_cifar10.py
# -*- coding: utf-8 -*-
import os, time, argparse
from pathlib import Path
from collections import defaultdict
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by BatchNorm.

    The input is added back to the output of the second BatchNorm before the
    final ReLU. When the stride is not 1 or the channel count changes, the
    input first goes through a 1x1 convolution + BatchNorm (``self.down``) so
    that both operands of the addition have the same shape.

    Args:
        in_planes: number of input channels.
        planes: number of output channels.
        stride: stride of the first convolution (and of the 1x1 shortcut).
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # A projection is only needed when the shortcut would not match the main path.
        needs_projection = stride != 1 or in_planes != planes
        if needs_projection:
            self.down = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            )
        else:
            self.down = None

    def forward(self, x):
        """Return ``relu(shortcut(x) + main_path(x))``."""
        residual = self.conv1(x)
        residual = F.relu(self.bn1(residual), inplace=True)
        residual = self.bn2(self.conv2(residual))
        shortcut = x if self.down is None else self.down(x)
        return F.relu(shortcut + residual, inplace=True)

class ResNet18(nn.Module):
    """ResNet-18 style classifier with a 3x3, stride-1 stem and no stem pooling.

    The network is a 3x3 convolution + BatchNorm + ReLU stem, four stages of
    two ``BasicBlock`` each (64, 128, 256, 512 channels; stages 2-4 start with
    stride 2), global average pooling and a linear classifier.

    Args:
        num_classes: size of the output logit vector.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.in_planes = 64  # running input width, advanced by _make()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make(64, 2, 1)
        self.layer2 = self._make(128, 2, 2)
        self.layer3 = self._make(256, 2, 2)
        self.layer4 = self._make(512, 2, 2)
        self.avg = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512, num_classes)
        # Re-initialise every convolution (Kaiming normal, fan-out) and every
        # BatchNorm (weight 1, bias 0); the linear layer keeps its default init.
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")

    def _make(self, planes, blocks, stride):
        """Build one stage: a (possibly strided) block followed by stride-1 blocks."""
        first = BasicBlock(self.in_planes, planes, stride)
        self.in_planes = planes
        rest = [BasicBlock(self.in_planes, planes, 1) for _ in range(blocks - 1)]
        return nn.Sequential(first, *rest)

    def forward(self, x):
        """Return class logits of shape ``[N, num_classes]`` for an image batch."""
        x = F.relu(self.bn1(self.conv1(x)), inplace=True)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        pooled = self.avg(x).flatten(1)
        return self.fc(pooled)

def tensor_stats(t):
    """Summarise a tensor for debug printing.

    Returns a dict with keys ``dtype`` (string), ``shape`` (list), ``min`` and
    ``max`` (floats, or ``None`` for an empty tensor) and ``unique`` (number of
    distinct values, or ``None`` when the tensor has 10000 or more elements).
    """
    count = t.numel()
    stats = {
        "dtype": str(t.dtype),
        "shape": list(t.shape),
        "min": None,
        "max": None,
        "unique": None,
    }
    if count:
        stats["min"] = float(t.min().item())
        stats["max"] = float(t.max().item())
    if count < 10000:
        # Only count distinct values for small tensors.
        stats["unique"] = int(t.unique().numel())
    return stats

@torch.no_grad()
def param_snapshot(model):
    """Record the norms of the first few weight tensors of ``model``.

    Picks up to three parameters, in ``named_parameters()`` order, that require
    gradients and have at least two dimensions.

    Returns:
        ``(picks, snap)`` where ``picks`` is the list of chosen parameter names
        and ``snap`` maps each name to the float norm of that tensor.
    """
    eligible = (
        name
        for name, param in model.named_parameters()
        if param.requires_grad and param.ndim >= 2
    )
    # zip() against range(3) stops after three names without draining the generator.
    picks = [name for _, name in zip(range(3), eligible)]
    state = model.state_dict()
    snap = {name: state[name].float().norm().item() for name in picks}
    return picks, snap

def param_delta(model, picks, before):
    """Return how much the norm of each picked tensor changed since ``before``.

    Args:
        model: module whose current ``state_dict()`` is inspected.
        picks: names of the tensors to compare.
        before: mapping from name to the previously recorded norm.

    Returns:
        Dict mapping each name to ``abs(current_norm - before[name])``. This is
        the change of the norm, not the norm of the change.
    """
    state = model.state_dict()
    return {
        name: abs(state[name].float().norm().item() - before[name])
        for name in picks
    }

class EMA:
    """Exponential moving average (EMA) copy of a model.

    ``self.ema`` is a frozen deep copy of the tracked model. Each call to
    :meth:`update` blends every floating-point entry of the state dict
    (parameters and buffers) as ``ema = decay * ema + (1 - decay) * model``;
    non-floating entries are copied over verbatim.

    The tracked model may be any ``nn.Module``. If it exposes an ``_orig_mod``
    attribute (as modules returned by ``torch.compile`` do), the wrapped module
    is used, so that state-dict keys are not prefixed with ``_orig_mod.`` and
    match the keys of the EMA copy.

    Args:
        model: module to track.
        decay: EMA decay factor; larger values average over more steps.
    """

    def __init__(self, model, decay=0.999):
        import copy  # local import keeps this class self-contained

        self.decay = float(decay)
        # deepcopy preserves architecture, device and current weights of the source.
        self.ema = copy.deepcopy(self._unwrap(model))
        for p in self.ema.parameters():
            p.requires_grad_(False)

    @staticmethod
    def _unwrap(model):
        """Return the module wrapped in ``model._orig_mod`` if present, else ``model``."""
        return getattr(model, "_orig_mod", model)

    @torch.no_grad()
    def update(self, model):
        """Blend the current state of ``model`` into the EMA copy, in place."""
        msd = self._unwrap(model).state_dict()
        esd = self.ema.state_dict()
        d = self.decay
        i = 1.0 - d
        for k, v in esd.items():
            src = msd[k]
            if torch.is_floating_point(v):
                v.mul_(d).add_(src, alpha=i)
            else:
                # Integer entries cannot be averaged; mirror the source value instead.
                v.copy_(src)

def get_loaders(root, batch_size, num_workers, autoaugment=True):
    """Build the CIFAR-10 train and test ``DataLoader`` objects.

    Training images get a padded random crop and a random horizontal flip,
    optionally followed by the CIFAR-10 AutoAugment policy; both splits are
    converted to tensors and normalised per channel. The datasets are first
    opened without downloading and are downloaded only if that fails.

    Args:
        root: dataset directory passed to ``torchvision.datasets.CIFAR10``.
        batch_size: training batch size (the test loader always uses 1024).
        num_workers: number of loader worker processes; 0 loads in-process.
        autoaugment: whether to add the AutoAugment CIFAR-10 policy.

    Returns:
        ``(train_loader, test_loader)``.
    """
    mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
    train_tfms = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]
    if autoaugment:
        train_tfms.append(transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10))
    train_tfms += [transforms.ToTensor(), transforms.Normalize(mean, std)]
    test_tfms = [transforms.ToTensor(), transforms.Normalize(mean, std)]
    train_tf = transforms.Compose(train_tfms)
    test_tf = transforms.Compose(test_tfms)

    def _open(download):
        # Open both splits with the same download policy.
        return (
            datasets.CIFAR10(root=root, train=True, download=download, transform=train_tf),
            datasets.CIFAR10(root=root, train=False, download=download, transform=test_tf),
        )

    try:
        train, test = _open(download=False)
    except (RuntimeError, OSError):
        # Dataset missing or unreadable under `root`: fetch it and try again.
        train, test = _open(download=True)

    # Pinned host memory only helps host-to-GPU copies, so skip it without CUDA.
    loader_kwargs = dict(num_workers=num_workers, pin_memory=torch.cuda.is_available())
    if num_workers > 0:
        # These two options are only valid with worker processes; recent torch
        # versions raise ValueError if prefetch_factor is given with num_workers=0.
        loader_kwargs.update(persistent_workers=True, prefetch_factor=4)

    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test, batch_size=1024, shuffle=False, **loader_kwargs)
    return train_loader, test_loader

@torch.no_grad()
def evaluate(model, loader, device, amp=True, channels_last=False):
    """Return the top-1 accuracy (in percent) of ``model`` over ``loader``.

    The model is switched to eval mode and left there. Autocast is used only
    when ``amp`` is true and ``device`` is a CUDA device; otherwise the forward
    pass runs without autocast.

    Args:
        model: classifier returning logits of shape ``[N, num_classes]``.
        loader: iterable of ``(inputs, labels)`` batches.
        device: ``torch.device`` to run on.
        amp: enable CUDA autocast (ignored on non-CUDA devices).
        channels_last: move inputs to channels-last memory format.

    Returns:
        Accuracy in ``[0, 100]``; ``0.0`` if the loader yields no samples.
    """
    from contextlib import nullcontext  # local import keeps this function self-contained

    model.eval()
    use_amp = amp and device.type == "cuda"
    ctx = torch.amp.autocast('cuda') if use_amp else nullcontext()
    mem_fmt = torch.channels_last if channels_last else torch.contiguous_format
    # Count on the device and read back once, instead of syncing every batch.
    correct = torch.zeros((), dtype=torch.long, device=device)
    total = 0
    with ctx:
        for x, y in loader:
            x = x.to(device, non_blocking=True, memory_format=mem_fmt)
            y = y.to(device, non_blocking=True)
            pred = model(x).argmax(1)
            correct += (pred == y).sum()
            total += y.size(0)
    if total == 0:
        return 0.0
    return 100.0 * correct.item() / total

def _grad_norm_preview(model, limit=3):
    """Return ``(name, grad_norm)`` for the first ``limit`` trainable parameters that have a gradient."""
    stats = []
    for n, p in model.named_parameters():
        if p.grad is not None and p.requires_grad:
            stats.append((n, p.grad.detach().float().norm().item()))
            if len(stats) >= limit:
                break
    return stats


def debug_one_batch(train_loader, model, device, criterion, optimizer, use_amp=False):
    """Run one optimisation step on one batch and print diagnostics.

    Prints, in order: label statistics (with warnings if the labels are not a
    1-D int64 tensor), the mean predicted class probabilities, the loss before
    the step, the gradient norms of the first few parameters, how much the norm
    of the first few weight tensors changed, and the loss after the step.

    Args:
        train_loader: loader from which the first batch is taken.
        model: model to step; it is put in train mode.
        device: ``torch.device`` the batch is moved to.
        criterion: loss taking ``(logits, labels)``.
        optimizer: optimizer over the parameters of ``model``.
        use_amp: use CUDA autocast and a ``GradScaler`` (only effective on CUDA).
    """
    model.train()
    xb, yb = next(iter(train_loader))
    xb, yb = xb.to(device), yb.to(device)
    print("[DEBUG] y:", tensor_stats(yb))
    if yb.dtype != torch.long:
        print("!! CrossEntropyLoss 需要 int64 索引标签（非 one-hot）。")
    if yb.ndim != 1:
        print("!! 目标形状应为 [N]。当前可能是 one-hot 或多维。")
    # The first forward pass is diagnostic only, so no autograd graph is needed.
    with torch.no_grad():
        logits = model(xb)
        probs = torch.softmax(logits, dim=1).mean(0)
        print("[DEBUG] mean class probs:", probs.detach().cpu().numpy().round(3).tolist())
        loss = criterion(logits, yb)
    print("[DEBUG] loss(before step):", float(loss.item()))
    picks, snap = param_snapshot(model)
    optimizer.zero_grad(set_to_none=True)
    amp_on = use_amp and device.type == "cuda"
    if amp_on:
        scaler = torch.amp.GradScaler('cuda')
        with torch.amp.autocast('cuda'):
            loss2 = criterion(model(xb), yb)
        scaler.scale(loss2).backward()
        # Undo the loss scaling so the printed norms are the true gradient norms.
        scaler.unscale_(optimizer)
    else:
        loss2 = criterion(model(xb), yb)
        loss2.backward()
    gstats = _grad_norm_preview(model)
    print("[DEBUG] grad norms:", [(n, round(v, 6)) for n, v in gstats] or "all None")
    if amp_on:
        scaler.step(optimizer); scaler.update()
    else:
        optimizer.step()
    moved = param_delta(model, picks, snap)
    print("[DEBUG] param moved (|Δ||W||):", {k: round(v, 6) for k, v in moved.items()})
    if all(v == 0.0 for v in moved.values()):
        print("!! 参数没有变化：可能优化器没 step、梯度为 0/None、或被 AMP/NaN/clip 抑制。")
    with torch.no_grad():
        loss_after = criterion(model(xb), yb).item()
    print("[DEBUG] loss(after step):", float(loss_after))

def main():
    """Command-line entry point: train (or only evaluate) ResNet18 on CIFAR-10.

    Flow: parse arguments, build the data loaders, model, loss, optimizer and
    optional EMA, apply the ``--tiny-overfit`` overrides, build the LR
    scheduler, then either run ``--debug-one-batch``, stop after the initial
    evaluation (``--evaluate``), or train for up to ``--epochs`` epochs,
    stopping early once test accuracy reaches ``--target-acc``. The EMA model
    is the one evaluated when EMA is enabled (``--ema`` > 0).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--data", type=str, default=str(Path.home()/ "datasets/cifar10"))
    ap.add_argument("--epochs", type=int, default=200)
    ap.add_argument("--batch-size", type=int, default=1024)
    ap.add_argument("--lr", type=float, default=0.4)
    ap.add_argument("--weight-decay", type=float, default=5e-4)
    ap.add_argument("--momentum", type=float, default=0.9)
    ap.add_argument("--label-smoothing", type=float, default=0.1)
    ap.add_argument("--cosine", action="store_true")
    ap.add_argument("--amp", action="store_true")
    ap.add_argument("--channels-last", action="store_true")
    ap.add_argument("--ema", type=float, default=0.999)
    ap.add_argument("--compile", action="store_true")
    ap.add_argument("--target-acc", type=float, default=93.0)
    ap.add_argument("--evaluate", action="store_true")
    ap.add_argument("--debug-one-batch", action="store_true")
    ap.add_argument("--tiny-overfit", action="store_true")
    args = ap.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.backends.cudnn.benchmark = True
    try:
        torch.set_float32_matmul_precision("high")
    except AttributeError:
        # This torch build has no set_float32_matmul_precision; keep the default precision.
        pass

    # On Colab, 2-4 loader workers are the stable choice
    # (persistent_workers=True requires num_workers > 0).
    num_workers = 2
    train_loader, test_loader = get_loaders(args.data, args.batch_size, num_workers, autoaugment=True)

    model = ResNet18().to(device)
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)
    if args.compile:
        try:
            model = torch.compile(model, mode="max-autotune")
        except Exception as e:
            # Compilation is an optional speed-up; fall back to the eager model.
            print(f"[warn] torch.compile failed: {e}")

    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing).to(device)
    opt = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    ema = EMA(model, decay=args.ema) if args.ema>0 else None

    if args.tiny_overfit:
        # Sanity mode: a small fixed subset, a small batch, no weight decay.
        subset_n = 512
        from torch.utils.data import Subset
        train_idx = list(range(min(subset_n, len(train_loader.dataset))))
        train_loader = DataLoader(Subset(train_loader.dataset, train_idx),
                                  batch_size=128, shuffle=True, num_workers=2,
                                  pin_memory=True, persistent_workers=True, prefetch_factor=4)
        base_lr = 0.4 * (128 / 1024)
        for g in opt.param_groups:
            g['lr'] = base_lr
            g['weight_decay'] = 0.0
        print(f"[TinyOverfit] subset={len(train_idx)}, batch=128, lr={base_lr:.4f}, wd=0.0")

    # Build the scheduler only after any LR override above, so the base LR it
    # records is the LR that is actually being used.
    if args.cosine:
        sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=args.epochs)
    else:
        sch = torch.optim.lr_scheduler.MultiStepLR(opt, [100,150], 0.1)

    if args.debug_one_batch:
        debug_one_batch(train_loader, model, device, criterion, opt, use_amp=args.amp and device.type=="cuda")
        return

    t0 = time.monotonic()
    base = evaluate(model, test_loader, device, amp=args.amp, channels_last=args.channels_last)
    print(f"[Eval@Start] Acc={base:.2f}%  Device={device}  CUDA={torch.version.cuda}")

    if args.evaluate:
        print(f"[Evaluate only] Elapsed {time.monotonic()-t0:.2f}s"); return

    # Mixed precision (autocast + loss scaling) is only used on CUDA.
    scaler = torch.amp.GradScaler('cuda') if (args.amp and device.type=="cuda") else None
    use_amp = scaler is not None
    mem_fmt = torch.channels_last if args.channels_last else torch.contiguous_format
    best = base
    for ep in range(args.epochs):
        model.train()
        # Accumulate the loss on the device; calling .item() every step would
        # force a host/device synchronisation per batch.
        loss_sum = torch.zeros((), device=device)
        for x,y in train_loader:
            x = x.to(device, non_blocking=True, memory_format=mem_fmt)
            y = y.to(device, non_blocking=True)
            opt.zero_grad(set_to_none=True)
            if use_amp:
                with torch.amp.autocast('cuda'):
                    logits = model(x); loss = criterion(logits, y)
                scaler.scale(loss).backward()
                scaler.step(opt); scaler.update()
            else:
                logits = model(x); loss = criterion(logits, y)
                loss.backward(); opt.step()
            if ema is not None:
                ema.update(model)
            loss_sum += loss.detach().float()
        sch.step()
        eval_model = ema.ema if ema is not None else model
        acc = evaluate(eval_model, test_loader, device, amp=use_amp, channels_last=args.channels_last)
        best = max(best, acc); elapsed = time.monotonic()-t0
        mean_loss = loss_sum.item() / len(train_loader)
        print(f"[Epoch {ep+1}/{args.epochs}] loss={mean_loss:.4f}  acc={acc:.2f}%  best={best:.2f}%  time={elapsed:.1f}s")
        if acc >= args.target_acc:
            print(f"[Reached {args.target_acc:.1f}%] Time-to-accuracy: {elapsed:.2f} seconds")
            break
    eval_model = ema.ema if ema is not None else model
    final = evaluate(eval_model, test_loader, device, amp=use_amp, channels_last=args.channels_last)
    print(f"[Final] Acc={final:.2f}%  Total Time={time.monotonic()-t0:.2f}s")

if __name__ == "__main__":
    main()



Writing fast_cifar10.py


In [None]:
# === Fast CIFAR-10 训练（无 label smoothing / 无 EMA，自动 batch，达成 93% 自动停） ===
from pathlib import Path
import torch, os, platform

# 1) 环境 & 数据目录准备（Colab 默认在 /content）
data_root = "/content"

# 兼容你之前手动放的数据：把 cifar-10-python 改成 torchvision 期望的名字
if os.path.isdir(f"{data_root}/cifar-10-python") and not os.path.isdir(f"{data_root}/cifar-10-batches-py"):
    os.rename(f"{data_root}/cifar-10-python", f"{data_root}/cifar-10-batches-py")

# 2) GPU 与加速开关
print("Python:", platform.python_version())
print("Torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("CUDA device:", torch.cuda.get_device_name(0))
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    try: torch.set_float32_matmul_precision("high")
    except: pass

# 3) 自动选择 batch size 与学习率（更稳的组合）
if torch.cuda.is_available():
    name = torch.cuda.get_device_name(0).lower()
    if "t4" in name or "p100" in name:   # 显存较小
        batch = 512;  lr = 0.15
    elif "v100" in name or "l4" in name or "a10" in name:
        batch = 768;  lr = 0.20
    else:                                 # A100/更强
        batch = 1024; lr = 0.20
else:
    batch = 128;  lr = 0.10

print(f"Chosen batch size: {batch}, lr: {lr}")

# 4) 开始训练：无 label smoothing / 无 EMA，AMP + channels_last + cosine 调度
#    提前停止阈值 target-acc=93.0
data_root_var = data_root  # 供 shell 展开
batch_var = batch
lr_var = lr

# 用 IPython 变量替换（$var 语法）
!python fast_cifar10.py \
  --data "$data_root_var" \
  --epochs 400 \
  --target-acc 93.0 \
  --batch-size $batch_var \
  --lr $lr_var \
  --weight-decay 5e-4 \
  --label-smoothing 0.0 \
  --ema 0.0 \
  --cosine \
  --amp \
  --channels-last





Python: 3.12.11
Torch: 2.8.0+cu126
CUDA available: True
CUDA device: NVIDIA A100-SXM4-40GB
Chosen batch size: 768, lr: 0.2
100% 170M/170M [00:10<00:00, 16.3MB/s]
[Eval@Start] Acc=10.00%  Device=cuda  CUDA=12.6
[Epoch 1/400] loss=3.1838  acc=12.94%  best=12.94%  time=55.9s
[Epoch 2/400] loss=2.2582  acc=17.50%  best=17.50%  time=75.7s
[Epoch 3/400] loss=2.1624  acc=23.60%  best=23.60%  time=95.5s
[Epoch 4/400] loss=2.0581  acc=25.33%  best=25.33%  time=115.4s
[Epoch 5/400] loss=1.9460  acc=32.66%  best=32.66%  time=135.5s
[Epoch 6/400] loss=1.8478  acc=34.18%  best=34.18%  time=155.5s
[Epoch 7/400] loss=1.7380  acc=35.39%  best=35.39%  time=175.4s
[Epoch 8/400] loss=1.5996  acc=42.75%  best=42.75%  time=195.2s
[Epoch 9/400] loss=1.4765  acc=43.85%  best=43.85%  time=215.2s
[Epoch 10/400] loss=1.3422  acc=54.86%  best=54.86%  time=235.2s
[Epoch 11/400] loss=1.2194  acc=59.29%  best=59.29%  time=255.5s
[Epoch 12/400] loss=1.1050  acc=60.55%  best=60.55%  time=275.4s
[Epoch 13/400] loss=1.