## 1) Imports and seeds

In [6]:
# Cell 1 — Imports / Config / Seeds
import os, math, json, copy, random
import numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt

def set_seed(s=42):
    random.seed(s); np.random.seed(s)
    torch.manual_seed(s); torch.cuda.manual_seed_all(s)
    torch.backends.cudnn.benchmark = True
set_seed(42)

CFG = dict(
    device = 'cuda' if torch.cuda.is_available() else 'cpu',
    save_dir = './ckpts_task3',
    epochs = 50, batch_size = 256, num_workers = 4,
    lr = 0.1, weight_decay = 5e-4, warmup_epochs = 5,
    ema_decay = 0.999, label_smoothing = 0.1,
    kd_alpha_ce = 0.6, kd_T = 4.0, dkd_tau = 0.5,
    use_randaugment = True, download_if_missing = True
)
os.makedirs(CFG["save_dir"], exist_ok=True)

CIFAR100_MEAN=(0.5071,0.4867,0.4408); CIFAR100_STD=(0.2675,0.2565,0.2761)


## 2) Data: CIFAR-100

In [7]:
# Cell 2 — Data
from pathlib import Path
def tfs(train=True, randaug=True, color=False):
    aug = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]
    if color: aug.append(transforms.ColorJitter(0.4,0.4,0.4,0.1))
    if randaug and train:
        try:
            from torchvision.transforms.autoaugment import RandAugment
            aug.append(RandAugment(num_ops=2, magnitude=10))
        except: pass
    base = [transforms.ToTensor(), transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD)]
    return transforms.Compose((aug if train else []) + base)

def get_loaders(train_tf, val_tf, root='./data', download=True):
    tr = datasets.CIFAR100(root, train=True, transform=train_tf, download=download)
    va = datasets.CIFAR100(root, train=False, transform=val_tf, download=download)
    return (DataLoader(tr, batch_size=CFG["batch_size"], shuffle=True, num_workers=CFG["num_workers"], pin_memory=True),
            DataLoader(va, batch_size=CFG["batch_size"], shuffle=False, num_workers=CFG["num_workers"], pin_memory=True))

train_loader, val_loader = get_loaders(tfs(train=True, randaug=CFG["use_randaugment"]), tfs(train=False))


## 3) Models: VGG-16/19 (teacher), VGG-11 (student)

In [8]:
# Cell 3 — Models
def build_vgg11_student(nc=100):
    m = models.vgg11_bn(weights=None)
    m.classifier[-1] = nn.Linear(m.classifier[-1].in_features, nc)
    return m

def build_vgg_teacher(depth='vgg16_bn', nc=100, pretrained=True):
    assert depth in ['vgg16_bn','vgg19_bn']
    w = (models.VGG16_BN_Weights.IMAGENET1K_V1 if depth=='vgg16_bn' else models.VGG19_BN_Weights.IMAGENET1K_V1) if pretrained else None
    m = getattr(models, depth)(weights=w)
    m.classifier[-1] = nn.Linear(m.classifier[-1].in_features, nc)
    return m

@torch.no_grad()
def forward_logits_penult(m, x):
    # returns logits and penultimate (before last Linear)
    feats = m.features(x); feats = torch.flatten(feats, 1)
    pen = None
    h = feats
    for i, layer in enumerate(m.classifier):
        h = layer(h)
        if i == len(m.classifier)-2: pen = h.clone()
    return h, pen


## 4) Losses: LS, LM(KL), DKD, Hints, CRD

In [9]:
# Cell 4 — Losses
class LSCELoss(nn.Module):
    def __init__(self, eps=0.1): super().__init__(); self.eps=eps
    def forward(self, logits, y):
        n = logits.size(-1); logp = F.log_softmax(logits, dim=-1)
        with torch.no_grad():
            t = torch.zeros_like(logp).fill_(self.eps/(n-1))
            t.scatter_(1, y.unsqueeze(1), 1-self.eps)
        return torch.mean(torch.sum(-t*logp, dim=-1))

class DistillKLLoss(nn.Module):
    def __init__(self, T=4.0): super().__init__(); self.T=T; self.kl=nn.KLDivLoss(reduction='batchmean')
    def forward(self, s, t):
        return self.kl(F.log_softmax(s/self.T,dim=1), F.softmax(t/self.T,dim=1))*(self.T**2)

class DKDLoss(nn.Module):
    def __init__(self, alpha_ce=0.6, tau=0.5, T=4.0, eps_ls=0.1):
        super().__init__(); self.alpha_ce=alpha_ce; self.tau=tau; self.T=T; self.lsce=LSCELoss(eps_ls); self.kl=nn.KLDivLoss(reduction='batchmean')
    def forward(self, s, t, y):
        loss_ce = self.lsce(s, y)
        ps = F.log_softmax(s/self.T, dim=1); pt = F.softmax(t/self.T, dim=1)
        mask_t = torch.zeros_like(pt).scatter_(1, y.view(-1,1), 1.0); mask_nt = 1-mask_t
        ps_t = (ps*mask_t).sum(1, keepdim=True); pt_t = (pt*mask_t).sum(1, keepdim=True)
        ps_nt = F.log_softmax(ps + (mask_t*(-1e9)), dim=1)
        pt_nt = (pt*mask_nt); pt_nt = pt_nt / pt_nt.sum(1, keepdim=True).clamp_min(1e-12)
        loss_t = F.kl_div(ps_t, pt_t, reduction='batchmean', log_target=False)
        loss_nt = self.kl(ps_nt, pt_nt)
        return self.alpha_ce*loss_ce + (1-self.alpha_ce)*(self.T**2)*(self.tau*loss_t + (1-self.tau)*loss_nt)

class HintRegressor(nn.Module):
    def __init__(self, in_ch, out_ch): super().__init__(); self.proj=nn.Conv2d(in_ch,out_ch,1,bias=False)
    def forward(self, x): return self.proj(x)

class HintsLoss(nn.Module):
    def __init__(self, adapter, w=1.0): super().__init__(); self.adapter=adapter; self.w=w; self.mse=nn.MSELoss()
    def forward(self, feat_s, feat_t): return self.w*self.mse(self.adapter(feat_s), feat_t.detach())

class CRDHead(nn.Module):
    def __init__(self, in_dim, out_dim=128):
        super().__init__(); self.proj=nn.Sequential(nn.Linear(in_dim,512), nn.ReLU(True), nn.Linear(512,out_dim))
    def forward(self, x): return F.normalize(self.proj(x), dim=1)

def crd_loss(zs, zt, temp=0.07):
    logits = zs @ zt.t() / temp
    labels = torch.arange(zs.size(0), device=zs.device)
    return F.cross_entropy(logits, labels)


## 5) Training utilities: accuracy, EMA, warmup+cosine, evaluate

In [10]:
# Cell 5 — Utils (acc/EMA/scheduler/eval)
def accuracy(logits, y, topk=(1,)):
    maxk=max(topk); b=y.size(0)
    _, pred = logits.topk(maxk,1,True,True); pred=pred.t()
    correct = pred.eq(y.view(1,-1).expand_as(pred))
    res=[]
    for k in topk:
        res.append(correct[:k].reshape(-1).float().sum().mul_(100.0/b).item())
    return res

class EMA:
    def __init__(self, m, decay=0.999):
        self.decay=decay; self.shadow={n:p.detach().clone() for n,p in m.named_parameters() if p.requires_grad}
    @torch.no_grad()
    def update(self, m):
        for n,p in m.named_parameters():
            if p.requires_grad: self.shadow[n]=self.decay*self.shadow[n]+(1-self.decay)*p.detach()
    @torch.no_grad()
    def copy_to(self, m):
        for n,p in m.named_parameters():
            if p.requires_grad: p.data.copy_(self.shadow[n])

class WarmupCosine:
    def __init__(self, opt, base_lr, warm, maxe): self.opt=opt; self.base=base_lr; self.warm=max(1,warm); self.maxe=maxe; self.e=0
    def step(self):
        self.e+=1
        for pg in self.opt.param_groups:
            if self.e<=self.warm: lr=self.base*self.e/self.warm
            else:
                t=(self.e-self.warm)/max(1,(self.maxe-self.warm)); lr=0.5*self.base*(1+math.cos(math.pi*t))
            pg['lr']=lr

@torch.no_grad()
def evaluate(model, loader, device):
    model.eval(); ce=nn.CrossEntropyLoss()
    losses=[]; t1=[]; t5=[]
    for x,y in loader:
        x,y=x.to(device),y.to(device)
        out=model(x); losses.append(ce(out,y).item())
        a1,a5=accuracy(out,y,(1,5)); t1.append(a1); t5.append(a5)
    return dict(val_loss=float(np.mean(losses)), top1=float(np.mean(t1)), top5=float(np.mean(t5)))


## 6) Core train loop that supports: SI (LS), LM, DKD, Hints, CRD

In [11]:
# Cell 6 — Train loop
from torch.cuda.amp import GradScaler, autocast

def train_model(mode, student, train_loader, val_loader, device,
                teacher=None, hints_cfg=None, crd_cfg=None, run_name="run"):
    student.to(device)
    if teacher is not None:
        teacher.to(device); teacher.eval()
        for p in teacher.parameters(): p.requires_grad=False

    opt = torch.optim.SGD(student.parameters(), lr=CFG["lr"], momentum=0.9, weight_decay=CFG["weight_decay"], nesterov=True)
    sch = WarmupCosine(opt, CFG["lr"], CFG["warmup_epochs"], CFG["epochs"])
    scaler=GradScaler(); ema=EMA(student, CFG["ema_decay"])

    ls = LSCELoss(CFG["label_smoothing"])
    kd = DistillKLLoss(CFG["kd_T"])
    dkd = DKDLoss(CFG["kd_alpha_ce"], CFG["dkd_tau"], CFG["kd_T"], CFG["label_smoothing"])

    # hints
    hint_loss=None; feat_s=None; feat_t=None; handles=[]
    if mode=="hints":
        # pick layers that match (verify names on your VGGs)
        s_name=hints_cfg["s_layer"]; t_name=hints_cfg["t_layer"]
        s_feats={}; t_feats={}
        def hook_s(m,i,o): s_feats["f"]=o
        def hook_t(m,i,o): t_feats["f"]=o
        handles.append(dict(student.named_modules())[s_name].register_forward_hook(hook_s))
        handles.append(dict(teacher.named_modules())[t_name].register_forward_hook(hook_t))
        hint_loss=HintsLoss(hints_cfg["adapter"], hints_cfg.get("w",1.0))
        hint_loss.adapter.to(device) 
        feat_s=s_feats; feat_t=t_feats

    # crd
    head_s=None; head_t=None
    if mode=="crd":
        head_s=CRDHead(crd_cfg["penult_dim"]).to(device)
        head_t=CRDHead(crd_cfg["penult_dim"]).to(device)
        for p in head_t.parameters(): p.requires_grad=False

    hist={"epoch":[], "train_loss":[], "val_loss":[], "top1":[], "top5":[]}
    best={"metric":-1, "epoch":-1, "sd":None}

    for ep in range(1, CFG["epochs"]+1):
        student.train(); train_losses=[]
        for xb,yb in train_loader:
            xb,yb=xb.to(device), yb.to(device)
            opt.zero_grad(set_to_none=True)
            with autocast():
                if mode=="si":
                    out = student(xb); loss = ls(out,yb)
                elif mode=="lm":
                    out = student(xb); tout = teacher(xb)
                    loss = CFG["kd_alpha_ce"]*ls(out,yb) + (1-CFG["kd_alpha_ce"])*kd(out,tout)
                elif mode=="dkd":
                    out = student(xb); tout = teacher(xb)
                    loss = dkd(out, tout, yb)
                elif mode=="hints":
                    out = student(xb); _ = teacher(xb)
                    loss = ls(out,yb) + hint_loss(feat_s["f"], feat_t["f"])
                elif mode=="crd":
                    logits_s, pen_s = forward_logits_penult(student, xb)
                    with torch.no_grad():
                        logits_t, pen_t = forward_logits_penult(teacher, xb)
                    loss = (CFG["kd_alpha_ce"]*ls(logits_s,yb) +
                            (1-CFG["kd_alpha_ce"])*kd(logits_s,logits_t) +
                            crd_loss(head_s(pen_s), head_t(pen_t)))
                else:
                    raise ValueError("bad mode")
            scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
            ema.update(student)
            train_losses.append(loss.item())
        sch.step()

        # eval with EMA
        eval_m = copy.deepcopy(student).to(device); ema.copy_to(eval_m)
        mtr = evaluate(eval_m, val_loader, device)
        hist["epoch"].append(ep); hist["train_loss"].append(float(np.mean(train_losses)))
        hist["val_loss"].append(mtr["val_loss"]); hist["top1"].append(mtr["top1"]); hist["top5"].append(mtr["top5"])
        if mtr["top1"]>best["metric"]:
            best.update(metric=mtr["top1"], epoch=ep, sd=copy.deepcopy(eval_m.state_dict()))
            torch.save(best["sd"], os.path.join(CFG["save_dir"], f"{run_name}_best.pt"))
        if ep in [1,10,20,30,40,50,CFG["epochs"]]:
            print(f"[{run_name}] {ep}/{CFG['epochs']} val_loss={mtr['val_loss']:.3f} top1={mtr['top1']:.2f} top5={mtr['top5']:.2f}")

    with open(os.path.join(CFG["save_dir"], f"{run_name}_history.json"),"w") as f: json.dump(hist,f)
    return hist, best


## 7) Train teachers (VGG-16 std, VGG-19 std, VGG-16-Color)

In [12]:
# Cell 7 — Train teachers
teacher16 = build_vgg_teacher('vgg16_bn', pretrained=True).to(CFG["device"])
teacher19 = build_vgg_teacher('vgg19_bn', pretrained=True).to(CFG["device"])

tl_tr, tl_va = get_loaders(tfs(train=True, randaug=CFG["use_randaugment"]), tfs(False))
hist_T16, best_T16 = train_model("si", teacher16, tl_tr, tl_va, CFG["device"], run_name="T16_finetune")
hist_T19, best_T19 = train_model("si", teacher19, tl_tr, tl_va, CFG["device"], run_name="T19_finetune")

# Color-invariance teacher (VGG-16 with ColorJitter)
tl_tr_color, tl_va_std = get_loaders(tfs(train=True, randaug=False, color=True), tfs(False))
teacher16_color = build_vgg_teacher('vgg16_bn', pretrained=True).to(CFG["device"])
hist_T16C, best_T16C = train_model("si", teacher16_color, tl_tr_color, tl_va_std, CFG["device"], run_name="T16_color_finetune")


Downloading: "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth" to /root/.cache/torch/hub/checkpoints/vgg16_bn-6c64b313.pth
100%|██████████| 528M/528M [00:07<00:00, 72.6MB/s] 
Downloading: "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth" to /root/.cache/torch/hub/checkpoints/vgg19_bn-c79401a0.pth
100%|██████████| 548M/548M [00:06<00:00, 88.5MB/s] 
  scaler=GradScaler(); ema=EMA(student, CFG["ema_decay"])
  with autocast():


[T16_finetune] 1/50 val_loss=4.618 top1=0.98 top5=4.96
[T16_finetune] 10/50 val_loss=4.052 top1=11.26 top5=27.55
[T16_finetune] 20/50 val_loss=1.502 top1=61.74 top5=86.02
[T16_finetune] 30/50 val_loss=1.349 top1=66.32 top5=88.88
[T16_finetune] 40/50 val_loss=1.382 top1=67.42 top5=88.98
[T16_finetune] 50/50 val_loss=1.094 top1=73.33 top5=92.78
[T19_finetune] 1/50 val_loss=4.618 top1=0.98 top5=5.37
[T19_finetune] 10/50 val_loss=4.645 top1=4.72 top5=13.93
[T19_finetune] 20/50 val_loss=1.560 top1=59.06 top5=85.36
[T19_finetune] 30/50 val_loss=1.628 top1=61.45 top5=85.27
[T19_finetune] 40/50 val_loss=1.471 top1=65.39 top5=87.94
[T19_finetune] 50/50 val_loss=1.145 top1=72.12 top5=91.82
[T16_color_finetune] 1/50 val_loss=4.615 top1=1.17 top5=5.15
[T16_color_finetune] 10/50 val_loss=3.967 top1=14.56 top5=37.74
[T16_color_finetune] 20/50 val_loss=1.453 top1=63.57 top5=87.48
[T16_color_finetune] 30/50 val_loss=1.514 top1=64.38 top5=86.68
[T16_color_finetune] 40/50 val_loss=1.536 top1=65.97 top5=

## 8) Train students: SI(LS), LM, DKD, Hints, CRD, LM from VGG-19, CRD from color teacher

In [13]:
# Cell 8 — Train students
SI = build_vgg11_student().to(CFG["device"])
hist_SI, best_SI = train_model("si", SI, train_loader, val_loader, CFG["device"], run_name="S_VGG11_SI")

# LM from VGG-16
S_LM = build_vgg11_student().to(CFG["device"]); T16=build_vgg_teacher('vgg16_bn',pretrained=False)
T16.load_state_dict(torch.load(os.path.join(CFG["save_dir"],"T16_finetune_best.pt"), map_location=CFG["device"]))
hist_LM, best_LM = train_model("lm", S_LM, train_loader, val_loader, CFG["device"], teacher=T16, run_name="S_VGG11_LM_T16")

# DKD from VGG-16
S_DKD = build_vgg11_student().to(CFG["device"])
hist_DKD, best_DKD = train_model("dkd", S_DKD, train_loader, val_loader, CFG["device"], teacher=T16, run_name="S_VGG11_DKD_T16")




  scaler=GradScaler(); ema=EMA(student, CFG["ema_decay"])
  with autocast():


[S_VGG11_SI] 1/50 val_loss=4.611 top1=0.98 top5=4.88
[S_VGG11_SI] 10/50 val_loss=60.434 top1=1.07 top5=5.34
[S_VGG11_SI] 20/50 val_loss=5086.112 top1=1.11 top5=5.25
[S_VGG11_SI] 30/50 val_loss=61.910 top1=0.98 top5=6.50
[S_VGG11_SI] 40/50 val_loss=2.785 top1=30.62 top5=62.83
[S_VGG11_SI] 50/50 val_loss=1.642 top1=56.69 top5=83.39
[S_VGG11_LM_T16] 1/50 val_loss=4.606 top1=0.98 top5=4.72
[S_VGG11_LM_T16] 10/50 val_loss=4.612 top1=0.98 top5=5.63
[S_VGG11_LM_T16] 20/50 val_loss=21.199 top1=2.38 top5=12.51
[S_VGG11_LM_T16] 30/50 val_loss=1.994 top1=54.70 top5=81.88
[S_VGG11_LM_T16] 40/50 val_loss=1.569 top1=64.39 top5=87.88
[S_VGG11_LM_T16] 50/50 val_loss=1.259 top1=69.41 top5=90.70
[S_VGG11_DKD_T16] 1/50 val_loss=4.607 top1=1.12 top5=5.03
[S_VGG11_DKD_T16] 10/50 val_loss=6.137 top1=0.99 top5=5.11
[S_VGG11_DKD_T16] 20/50 val_loss=4.701 top1=0.88 top5=5.05
[S_VGG11_DKD_T16] 30/50 val_loss=4.605 top1=0.86 top5=4.87
[S_VGG11_DKD_T16] 40/50 val_loss=4.605 top1=0.98 top5=4.88
[S_VGG11_DKD_T16] 5

In [14]:
# Hints (choose layers; verify names by printing model.named_modules())
# Common choice: teacher features.28 vs student features.20 (adjust if mismatched)
adapter = HintRegressor(in_ch=512, out_ch=512)
hcfg = dict(s_layer="features.20", t_layer="features.28", adapter=adapter, w=1.0)
S_HINT = build_vgg11_student().to(CFG["device"])
hist_HINT, best_HINT = train_model("hints", S_HINT, train_loader, val_loader, CFG["device"], teacher=T16, hints_cfg=hcfg, run_name="S_VGG11_HINTS_T16")

# CRD from VGG-16 (penultimate dim is 4096 for VGGs with BN)
S_CRD = build_vgg11_student().to(CFG["device"])
hist_CRD, best_CRD = train_model("crd", S_CRD, train_loader, val_loader, CFG["device"], teacher=T16,
                                 crd_cfg=dict(penult_dim=4096), run_name="S_VGG11_CRD_T16")

# LM from VGG-19
T19=build_vgg_teacher('vgg19_bn',pretrained=False)
T19.load_state_dict(torch.load(os.path.join(CFG["save_dir"],"T19_finetune_best.pt"), map_location=CFG["device"]))
S_LM19 = build_vgg11_student().to(CFG["device"])
hist_LM19, best_LM19 = train_model("lm", S_LM19, train_loader, val_loader, CFG["device"], teacher=T19, run_name="S_VGG11_LM_T19")

# CRD from Color teacher
T16C=build_vgg_teacher('vgg16_bn',pretrained=False)
T16C.load_state_dict(torch.load(os.path.join(CFG["save_dir"],"T16_color_finetune_best.pt"), map_location=CFG["device"]))
S_CRD_color = build_vgg11_student().to(CFG["device"])
hist_CRDcol, best_CRDcol = train_model("crd", S_CRD_color, train_loader, val_loader, CFG["device"], teacher=T16C,
                                       crd_cfg=dict(penult_dim=4096), run_name="S_VGG11_CRD_T16Color")

  scaler=GradScaler(); ema=EMA(student, CFG["ema_decay"])
  with autocast():


RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.FloatTensor) should be the same

# 9) Summary table (top-1/top-5/loss @ best)

In [None]:
# Cell 9 — Summary table
import pandas as pd, json, glob
def load_hist(name):
    with open(os.path.join(CFG["save_dir"], f"{name}_history.json")) as f: return json.load(f)
def best_row(tag, hist_path, best_ckpt):
    h = load_hist(hist_path); e = int(torch.load(os.path.join(CFG["save_dir"], f"{best_ckpt}_best.pt"), map_location='cpu').get('epoch', 0) or max(h["epoch"]))
    idx = h["epoch"].index(e)
    return dict(run=tag, best_epoch=e, top1=round(h["top1"][idx],2), top5=round(h["top5"][idx],2), val_loss=round(h["val_loss"][idx],4))

rows = [
    best_row("SI", "S_VGG11_SI", "S_VGG11_SI"),
    best_row("LM_T16", "S_VGG11_LM_T16", "S_VGG11_LM_T16"),
    best_row("DKD_T16", "S_VGG11_DKD_T16", "S_VGG11_DKD_T16"),
    best_row("HINTS_T16", "S_VGG11_HINTS_T16", "S_VGG11_HINTS_T16"),
    best_row("CRD_T16", "S_VGG11_CRD_T16", "S_VGG11_CRD_T16"),
    best_row("LM_T19", "S_VGG11_LM_T19", "S_VGG11_LM_T19"),
    best_row("CRD_T16Color", "S_VGG11_CRD_T16Color", "S_VGG11_CRD_T16Color"),
]
df = pd.DataFrame(rows).sort_values("top1", ascending=False)
print(df.to_string(index=False))


## 10) Curves (loss / top-1 / top-5)

In [None]:
# Cell 10 — Curves
def plot_series(names):
    plt.figure(figsize=(8,5))
    for n in names:
        h = load_hist(n)
        plt.plot(h["epoch"], h["top1"], label=f"{n}-top1")
    plt.xlabel("Epoch"); plt.ylabel("Top-1 %"); plt.grid(True); plt.legend(); plt.show()

    plt.figure(figsize=(8,5))
    for n in names:
        h = load_hist(n)
        plt.plot(h["epoch"], h["val_loss"], label=f"{n}-loss")
    plt.xlabel("Epoch"); plt.ylabel("Val Loss"); plt.grid(True); plt.legend(); plt.show()

plot_series(["S_VGG11_SI","S_VGG11_LM_T16","S_VGG11_DKD_T16","S_VGG11_HINTS_T16","S_VGG11_CRD_T16"])


## 11) Probability-distribution alignment (KL/JS) — T vs SI/SD*

In [None]:
# Cell 11 — Probability distribution divergences
@torch.no_grad()
def dist_div(teacher, students: dict, loader, T=4.0, metric="kl", batches=20, device=CFG["device"]):
    teacher.eval(); [m.eval() for m in students.values()]
    vals={k:[] for k in students}
    c=0
    for x,_ in loader:
        x=x.to(device)
        pt = F.softmax(teacher(x)/T, dim=1)
        for name, s in students.items():
            ps = F.softmax(s(x)/T, dim=1)
            if metric=="kl":
                v = F.kl_div(ps.log(), pt, reduction='batchmean').item()
            else:
                m = 0.5*(pt+ps)
                v = 0.5*F.kl_div(ps.log(), m, reduction='batchmean').item() + 0.5*F.kl_div(pt.log(), m, reduction='batchmean').item()
            vals[name].append(v)
        c+=1
        if c>=batches: break
    return {k: float(np.mean(v)) for k,v in vals.items()}

# Build models from best checkpoints
def load_student(tag):
    m = build_vgg11_student().to(CFG["device"])
    m.load_state_dict(torch.load(os.path.join(CFG["save_dir"], f"{tag}_best.pt"), map_location=CFG["device"]))
    return m.eval()

T16_eval = build_vgg_teacher('vgg16_bn',pretrained=False).to(CFG["device"])
T16_eval.load_state_dict(torch.load(os.path.join(CFG["save_dir"], "T16_finetune_best.pt"), map_location=CFG["device"])); T16_eval.eval()

students = {
    "SI": load_student("S_VGG11_SI"),
    "LM": load_student("S_VGG11_LM_T16"),
    "DKD": load_student("S_VGG11_DKD_T16"),
    "HINTS": load_student("S_VGG11_HINTS_T16"),
    "CRD": load_student("S_VGG11_CRD_T16"),
}
print("KL(T16 || S*):", dist_div(T16_eval, students, val_loader, metric="kl"))
print("JS(T16 || S*):", dist_div(T16_eval, students, val_loader, metric="js"))


## 12) Grad-CAM + similarity to teacher

In [None]:
# Cell 12 — Grad-CAM similarity
class GradCAM:
    def __init__(self, model, layer="features.42"):
        self.m=model.eval(); self.a=None; self.g=None
        mod=dict([*model.named_modules()])[layer]
        mod.register_forward_hook(lambda m,i,o: setattr(self,'a',o.detach()))
        mod.register_full_backward_hook(lambda m,gi,go: setattr(self,'g',go[0].detach()))
    def __call__(self,x, idx=None):
        self.m.zero_grad(set_to_none=True)
        out=self.m(x)
        if idx is None: idx = out.argmax(1)
        out[torch.arange(out.size(0)), idx].sum().backward()
        w=self.g.mean((2,3), keepdim=True)
        cam=(w*self.a).sum(1,keepdim=True); cam=F.relu(cam)
        cam=cam/(cam.amax((2,3),keepdim=True).clamp_min(1e-6))
        return cam

def cam_cos(camA, camB):
    A=camA.flatten(1); B=camB.flatten(1)
    return F.cosine_similarity(F.normalize(A,1), F.normalize(B,1), dim=1).mean().item()

# Build cams
cam_T = GradCAM(T16_eval, layer="features.42")  # adjust if needed
cams_S = {k: GradCAM(v, layer="features.42") for k,v in students.items()}

# sample a small batch
xb,_ = next(iter(val_loader))
xb = xb.to(CFG["device"])[:32]
cam_t = cam_T(xb)
for name, cammer in cams_S.items():
    cam_s = cammer(xb)
    print(name, "CAM cosine vs Teacher:", cam_cos(cam_t, cam_s))


## 13) Color-invariance eval (jittered validation)

In [None]:
# Cell 13 — Color invariance evaluation
val_jitter = transforms.Compose([transforms.ColorJitter(0.4,0.4,0.4,0.1), transforms.ToTensor(), transforms.Normalize(CIFAR100_MEAN,CIFAR100_STD)])
_, val_loader_j = get_loaders(tfs(True), val_jitter)

S_CRD_col = load_student("S_VGG11_CRD_T16Color")
S_SI = load_student("S_VGG11_SI")
def eval_on(loader, model): return evaluate(model, loader, CFG["device"])["top1"]
print("Top-1 on jittered val — SI:", eval_on(val_loader_j, S_SI))
print("Top-1 on jittered val — CRD(Color teacher):", eval_on(val_loader_j, S_CRD_col))


## 14) Confusion matrix + per-class + reliability/ECE (optional but nice in write-up)

In [None]:
# Cell 14 — (Optional) CM / per-class / ECE
from sklearn.metrics import confusion_matrix
@torch.no_grad()
def preds_targets(m, loader):
    m.eval(); P=[]; Y=[]
    for x,y in loader:
        x=x.to(CFG["device"]); logits=m(x)
        P.append(logits.argmax(1).cpu().numpy()); Y.append(y.numpy())
    return np.concatenate(P), np.concatenate(Y)

p,y = preds_targets(load_student("S_VGG11_LM_T16"), val_loader)
cm = confusion_matrix(y,p); cls_acc = (cm.diagonal()/cm.sum(axis=1))
print("Mean per-class acc:", cls_acc.mean()*100)
# Reliability: compute confidences vs accuracy (ECE) if you want.
