# ATML Assignment 3 — Task 3: Knowledge Distillation (Colab, T4 GPU)

Lean, single-notebook pipeline for CIFAR-100 with VGG-16/19 (teacher) and VGG-11 (student).

**Experiments covered**
1. Independent Student (SI)
2. Logit Matching (LM)
3. Label Smoothing (LS) baseline
4. Decoupled KD (DKD)
5. Hints KD (intermediate feature matching)
6. Contrastive Representation Distillation (CRD)
7. Color-invariance transfer (CRD with jittered teacher)
8. Larger teacher check (VGG-16 vs VGG-19 → LM)

**Artifacts saved** (CSV/PNG/CKPT):
- `ckpts/teacher_vgg16.pt`, `teacher_vgg19.pt`, `student_*.pt`
- `results/metrics_overall.csv`, `results/kl_alignment.csv`, `results/color_invariance.csv`, `results/larger_teacher.csv`
- `figures/acc_bar.png`, `figures/kl_bar.png`, `figures/gradcam_grid.png`, `figures/color_shift.png`
- `gradcam/IMGxxxx_METHOD.png`

**Notes**
- Keep epochs modest on Colab T4. Use AMP. Fix seeds.
- For Grad-CAM, we use `pytorch-grad-cam` if available; otherwise, skip gracefully.


In [None]:
# %%capture
# Colab setup: uncomment whichever installs your runtime is missing.
# !pip -q install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121
# !pip -q install pytorch-grad-cam==1.4.8
# !pip -q install pandas matplotlib scikit-learn
_colab_hint = "If running in Colab, enable GPU (T4) and uncomment installs above if needed."
print(_colab_hint)

In [None]:
import os, random, json, math, time
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from datetime import datetime

import sys  # used by the Colab check below; not part of the import cell above

USE_AMP = True  # mixed-precision training via torch.cuda.amp
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Standard Colab detection: the runtime has `google.colab` loaded in sys.modules.
# (The previous check looked for 'google.colab' inside google.__file__, which is a
# filesystem path and can never contain that dotted name.)
IN_COLAB = 'google.colab' in sys.modules
ROOT = Path('/content/ATML_A3') if IN_COLAB else Path.cwd() / 'ATML_A3'
CKPTS = ROOT / 'ckpts'    # model checkpoints
FIGS  = ROOT / 'figures'  # PNG plots
GRADS = ROOT / 'gradcam'  # per-image Grad-CAM outputs
LOGS  = ROOT / 'logs'
RES   = ROOT / 'results'  # CSV result tables
for p in [CKPTS, FIGS, GRADS, LOGS, RES]:
    p.mkdir(parents=True, exist_ok=True)
print(f"Project root: {ROOT}")
print(f"Device: {DEVICE}")

In [None]:
def set_seed(seed=1337):
    """Seed Python, NumPy and torch (CPU + all CUDA devices) and make cuDNN deterministic."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed(1337)
print("Seeds fixed.")

## Data: CIFAR-100 (train/val loaders)

In [None]:
def get_cifar100(batch_size=128, num_workers=2, color_jitter=False):
    """Build CIFAR-100 (train_loader, test_loader), downloading into ROOT/'data' if needed.

    Train split: RandomCrop(32, padding=4) + RandomHorizontalFlip, plus
    ColorJitter(0.4, 0.4, 0.4, 0.1) when ``color_jitter`` is True; shuffled.
    Test split: no augmentation; not shuffled. Note that ``color_jitter`` only
    affects the TRAIN split. Both splits are normalized with CIFAR-100 channel stats.
    """
    mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)

    def _finish(steps):
        # Every pipeline ends with tensor conversion + normalization.
        return transforms.Compose(steps + [transforms.ToTensor(), transforms.Normalize(mean, std)])

    train_aug = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]
    if color_jitter:
        train_aug += [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)]

    data_root = str(ROOT / 'data')
    loaders = []
    for is_train, tf in ((True, _finish(train_aug)), (False, _finish([]))):
        split = datasets.CIFAR100(root=data_root, train=is_train, download=True, transform=tf)
        loaders.append(DataLoader(split, batch_size=batch_size, shuffle=is_train,
                                  num_workers=num_workers, pin_memory=True))
    train_loader, test_loader = loaders
    return train_loader, test_loader


train_loader, val_loader = get_cifar100()
print("Data loaders ready.")

## Models: VGG-16/19 (Teacher), VGG-11 (Student)

In [None]:
def make_vgg(name='vgg16', num_classes=100, pretrained=False):
    """Build a torchvision VGG ('vgg11', 'vgg16' or 'vgg19') with a ``num_classes``-way head.

    Args:
        name: architecture name.
        num_classes: output size of the replaced final classifier layer.
        pretrained: load torchvision's IMAGENET1K_V1 weights for the backbone.
            Honoured for every supported depth (previously ignored for vgg11).

    Raises:
        ValueError: for an unsupported ``name``.
    """
    builders = {
        'vgg11': (models.vgg11, models.VGG11_Weights.IMAGENET1K_V1),
        'vgg16': (models.vgg16, models.VGG16_Weights.IMAGENET1K_V1),
        'vgg19': (models.vgg19, models.VGG19_Weights.IMAGENET1K_V1),
    }
    if name not in builders:
        raise ValueError('Unsupported VGG: ' + name)
    build, weights = builders[name]
    net = build(weights=weights if pretrained else None)
    # replace classifier for CIFAR-100
    in_feats = net.classifier[-1].in_features
    net.classifier[-1] = nn.Linear(in_feats, num_classes)
    return net

# NOTE(review): these module-level models are not referenced by the helpers below;
# building teacher16 with pretrained=True fetches the ImageNet weights if not cached.
teacher16 = make_vgg('vgg16', pretrained=True).to(DEVICE)
student11 = make_vgg('vgg11', pretrained=False).to(DEVICE)
print("Teacher VGG-16 & Student VGG-11 created.")

## Losses: CE, Label Smoothing, KD (LM), DKD (stub), Hints, CRD (minimal)

In [None]:
class LabelSmoothingCE(nn.Module):
    """Cross-entropy against a smoothed one-hot target.

    The true class receives probability ``1 - eps`` and the remaining ``eps`` is
    spread evenly over the other ``C - 1`` classes. Returns the batch mean.
    """

    def __init__(self, eps=0.1):
        super().__init__()
        self.eps = eps
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, logits, targets):
        num_classes = logits.size(1)
        log_probs = self.log_softmax(logits)
        with torch.no_grad():
            # Smoothed target distribution; constant w.r.t. the logits.
            off_value = self.eps / (num_classes - 1)
            smoothed = torch.full_like(log_probs, off_value)
            smoothed.scatter_(1, targets.unsqueeze(1), 1 - self.eps)
        per_sample = torch.sum(-smoothed * log_probs, dim=1)
        return per_sample.mean()

def kd_loss_logits(student_logits, teacher_logits, T=4.0):
    """Logit-matching distillation loss: KL(teacher || student) at temperature ``T``.

    Both distributions are softened by ``T``; the KL is averaged over the batch
    ('batchmean') and multiplied by ``T * T`` so gradient magnitudes stay
    comparable across temperatures.
    """
    log_p_student = F.log_softmax(student_logits / T, dim=1)
    p_teacher = F.softmax(teacher_logits / T, dim=1)
    kl = F.kl_div(log_p_student, p_teacher, reduction='batchmean')
    return kl * (T * T)

class DKDLoss(nn.Module):
    """Decoupled Knowledge Distillation loss.

    Splits the temperature-softened KD objective into two terms:

    * TCKD: KL between the binary (target-class, all-other-classes) probability
      masses of teacher and student.
    * NCKD: KL between teacher and student distributions over the non-target
      classes only, each renormalized among those classes.

    Returns ``(alpha * TCKD + beta * NCKD) * T**2``, with each KL averaged over
    the batch. (The previous version summed two partial KL terms without
    renormalization, which reduces to ordinary KD rather than DKD.)

    Args:
        alpha: weight of the target-class term (TCKD).
        beta: weight of the non-target-class term (NCKD).
        T: softmax temperature.
    """

    def __init__(self, alpha=1.0, beta=8.0, T=4.0):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.T = T

    @staticmethod
    def _binary_log_probs(z, targets, gt_mask):
        """Return log [p(target), p(non-target)] under softmax(z), shape (B, 2)."""
        log_norm = torch.logsumexp(z, dim=1, keepdim=True)
        log_p_gt = z.gather(1, targets.unsqueeze(1)) - log_norm
        # logsumexp over the non-target logits avoids computing log(1 - p) when p -> 1.
        log_p_rest = torch.logsumexp(z.masked_fill(gt_mask, float('-inf')), dim=1, keepdim=True) - log_norm
        return torch.cat([log_p_gt, log_p_rest], dim=1)

    def forward(self, s_logits, t_logits, targets):
        """Compute the DKD loss for (B, C) student/teacher logits and (B,) integer targets.

        Raises:
            ValueError: if there are fewer than 2 classes (no non-target classes).
        """
        T = self.T
        # float32 keeps the log-space math stable when called under AMP autocast.
        s = s_logits.float() / T
        t = t_logits.float() / T
        batch, n_classes = s.shape
        if n_classes < 2:
            raise ValueError('DKDLoss needs at least 2 classes')
        gt_mask = F.one_hot(targets, num_classes=n_classes).bool()

        # TCKD: binary target / non-target KL.
        tckd = F.kl_div(self._binary_log_probs(s, targets, gt_mask),
                        self._binary_log_probs(t, targets, gt_mask),
                        reduction='batchmean', log_target=True)

        # NCKD: KL over non-target classes, renormalized (target column removed).
        s_rest = s[~gt_mask].reshape(batch, n_classes - 1)
        t_rest = t[~gt_mask].reshape(batch, n_classes - 1)
        nckd = F.kl_div(F.log_softmax(s_rest, dim=1), F.log_softmax(t_rest, dim=1),
                        reduction='batchmean', log_target=True)

        return (self.alpha * tckd + self.beta * nckd) * (T * T)

class HintLoss(nn.Module):
    """Hint (feature-matching) loss: ``weight * MSE(proj_s(feat_s), feat_t)``.

    ``proj_s`` maps student features into the teacher's feature space. It is
    registered as a submodule, so its parameters appear in ``parameters()``.
    """

    def __init__(self, proj_s: nn.Module, weight=1e-2):
        super().__init__()
        self.proj_s = proj_s
        self.weight = weight
        self.mse = nn.MSELoss()

    def forward(self, feat_s, feat_t):
        projected = self.proj_s(feat_s)
        return self.mse(projected, feat_t) * self.weight

class CRDLoss(nn.Module):
    """Simplified contrastive distillation loss using in-batch negatives.

    Student and teacher vectors are linearly projected (no bias), L2-normalized,
    and compared by cosine similarity scaled by ``1 / T``. Sample ``i`` of the
    student should match sample ``i`` of the teacher; every other teacher sample
    in the batch is a negative. No memory bank is used.
    """

    def __init__(self, dim=128, T=0.07):
        super().__init__()
        self.T = T
        # Creation order (proj_t, then proj_s) fixes init RNG use and state_dict order.
        self.proj_t = nn.Linear(dim, dim, bias=False)
        self.proj_s = nn.Linear(dim, dim, bias=False)

    def forward(self, z_s, z_t):
        emb_s = F.normalize(self.proj_s(z_s), dim=1)
        emb_t = F.normalize(self.proj_t(z_t), dim=1)
        similarity = torch.matmul(emb_s, emb_t.T) / self.T
        positives = torch.arange(emb_s.shape[0], device=emb_s.device)
        return F.cross_entropy(similarity, positives)


## Train/Eval helpers

In [None]:
def accuracy_topk(logits, targets, topk=(1,)):
    """Top-k accuracy in percent for each k in ``topk``.

    Returns a list of 1-element tensors, one per k, in the order given.
    """
    batch = targets.size(0)
    _, top_idx = logits.topk(max(topk), 1, True, True)
    ranked = top_idx.t()  # row r holds every sample's r-th best guess
    hits = ranked.eq(targets.view(1, -1).expand_as(ranked))
    scale = 100.0 / batch
    return [hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale) for k in topk]

def evaluate(model, loader):
    """Evaluate ``model`` on ``loader`` without gradients (model is left in eval mode).

    Returns (mean cross-entropy, top-1 %, top-5 %), all averaged per sample.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    total_loss, seen, hits1, hits5 = 0.0, 0, 0.0, 0.0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            out = model(inputs)
            batch_loss = criterion(out, labels)
            acc1, acc5 = accuracy_topk(out, labels, topk=(1, 5))
            bsz = inputs.size(0)
            total_loss += batch_loss.item() * bsz
            seen += bsz
            # accuracy_topk reports percentages; turn them back into hit counts.
            hits1 += acc1.item() * bsz / 100.0
            hits5 += acc5.item() * bsz / 100.0
    return total_loss / seen, 100 * hits1 / seen, 100 * hits5 / seen

def train_ce(model, train_loader, val_loader, epochs=60, lr=0.1, weight_decay=5e-4):
    """Train ``model`` with plain cross-entropy (SGD momentum 0.9, cosine LR, optional AMP).

    Evaluates on ``val_loader`` after every epoch and returns the
    (val_loss, top1, top5) triple of the epoch with the highest top-1.
    The model itself keeps its final-epoch weights.
    """
    model.to(DEVICE)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)
    criterion = nn.CrossEntropyLoss()
    best = (1e9, 0, 0)
    for ep in range(1, epochs + 1):
        model.train()
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=USE_AMP):
                loss = criterion(model(inputs), labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        scheduler.step()
        vl, a1, a5 = evaluate(model, val_loader)
        best = (vl, a1, a5) if a1 > best[1] else best
        if ep == 1 or ep % 10 == 0:
            print(f"[CE] Epoch {ep}/{epochs} | val_loss={vl:.3f} top1={a1:.2f} top5={a5:.2f}")
    return best

def train_kd_lm(student, teacher, train_loader, val_loader, epochs=60, lr=0.1, alpha=0.5, T=4.0):
    """Distill ``teacher`` into ``student`` by logit matching.

    Per batch: ``alpha * kd_loss_logits(s, t, T) + (1 - alpha) * CE(s, y)``.
    Teacher logits are computed without gradients (outside autocast).
    Returns the (val_loss, top1, top5) triple of the epoch with the best top-1.
    """
    student.to(DEVICE)
    teacher.eval().to(DEVICE)
    optimizer = torch.optim.SGD(student.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)
    hard_loss = nn.CrossEntropyLoss()
    best = (1e9, 0, 0)
    for ep in range(1, epochs + 1):
        student.train()
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            with torch.no_grad():
                t_logits = teacher(inputs)
            optimizer.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=USE_AMP):
                s_logits = student(inputs)
                soft = kd_loss_logits(s_logits, t_logits, T=T)
                loss = alpha * soft + (1 - alpha) * hard_loss(s_logits, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        scheduler.step()
        vl, a1, a5 = evaluate(student, val_loader)
        best = (vl, a1, a5) if a1 > best[1] else best
        if ep == 1 or ep % 10 == 0:
            print(f"[KD-LM] Epoch {ep}/{epochs} | val_loss={vl:.3f} top1={a1:.2f} top5={a5:.2f}")
    return best


## Probability alignment (mean KL(T∥S))

In [None]:
def mean_kl_teacher_student(teacher, student, loader, T=4.0, max_batches=50):
    """Mean per-sample KL(teacher || student) at temperature ``T``.

    Uses only the first ``max_batches`` batches of ``loader`` (labels ignored).
    This is the plain KL divergence of the softened distributions. The training
    loss ``kd_loss_logits`` additionally multiplies by ``T**2`` as a
    gradient-scaling convention, and that factor is deliberately not applied
    here. Inputs are moved to the device of the teacher's parameters (CPU if it
    has none). Returns 0.0 when no samples are seen.
    """
    teacher.eval(); student.eval()
    first_param = next(teacher.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    kl_sum = 0.0; n = 0
    with torch.no_grad():
        for b, (x, _) in enumerate(loader):
            if b >= max_batches:
                break
            x = x.to(device)
            log_p_s = F.log_softmax(student(x).float() / T, dim=1)
            log_p_t = F.log_softmax(teacher(x).float() / T, dim=1)
            # Sum over samples, then divide by the sample count, so a partial
            # last batch is weighted correctly.
            kl_sum += F.kl_div(log_p_s, log_p_t, reduction='sum', log_target=True).item()
            n += x.size(0)
    return kl_sum / max(1, n)


## Hints KD (feature hooks)

In [None]:
def _get_conv_feature_module_vgg(model, stage_idx=3):
    """Return the module whose forward output serves as the hint feature.

    NOTE: ``stage_idx`` is accepted but currently ignored. This always returns
    the whole ``model.features`` block, so a forward hook on it sees the full
    ``features`` output.
    """
    feature_block = model.features
    return feature_block

def collect_features(module, input, output, storage: dict, key: str):
    """Forward-hook body: store ``output`` under ``storage[key]``, overwriting any previous value.

    ``module`` and ``input`` are part of the hook signature and are unused.
    """
    storage.update({key: output})

def train_kd_hints(student, teacher, train_loader, val_loader, epochs=60, lr=0.1, hint_weight=1e-2):
    """Train ``student`` with CE plus a hint (feature-matching) loss against ``teacher``.

    Forward hooks on each model's ``features`` block capture the feature maps.
    A trainable 1x1 conv projects student features to the teacher's channel
    count before ``HintLoss`` (weighted MSE); the projection is optimized
    together with the student. The hooks are removed in a ``finally`` block, so
    they do not stay attached if training raises or is interrupted.

    Returns the (val_loss, top1, top5) triple of the epoch with the best top-1.
    """
    student.to(DEVICE); teacher.to(DEVICE).eval()

    def _last_conv_channels(model, default=512):
        # Output channels of the last Conv2d in `features`. In torchvision VGG only
        # ReLU/MaxPool follow it, so the hooked output keeps this channel count.
        convs = [m for m in model.features if isinstance(m, nn.Conv2d)]
        return convs[-1].out_channels if convs else default

    # register simple hooks on the tail of features
    feats_s, feats_t = {}, {}
    h_s = _get_conv_feature_module_vgg(student).register_forward_hook(lambda m, i, o: collect_features(m, i, o, feats_s, 's'))
    h_t = _get_conv_feature_module_vgg(teacher).register_forward_hook(lambda m, i, o: collect_features(m, i, o, feats_t, 't'))
    try:
        # 1x1 projection from student channels to teacher channels
        proj = nn.Conv2d(_last_conv_channels(student), _last_conv_channels(teacher), kernel_size=1).to(DEVICE)
        hint_crit = HintLoss(proj, weight=hint_weight)
        ce = nn.CrossEntropyLoss()
        opt = torch.optim.SGD(list(student.parameters()) + list(proj.parameters()), lr=lr, momentum=0.9, weight_decay=5e-4)
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
        scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)
        best = (1e9, 0, 0)
        for ep in range(1, epochs + 1):
            student.train()
            for x, y in train_loader:
                x = x.to(DEVICE); y = y.to(DEVICE)
                with torch.no_grad():
                    _ = teacher(x)  # fills feats_t via the hook
                    feat_t = feats_t.get('t')
                opt.zero_grad(set_to_none=True)
                with torch.cuda.amp.autocast(enabled=USE_AMP):
                    logits = student(x)  # fills feats_s via the hook
                    feat_s = feats_s.get('s')
                    h_loss = hint_crit(feat_s, feat_t)
                    loss = ce(logits, y) + h_loss
                scaler.scale(loss).backward()
                scaler.step(opt)
                scaler.update()
            sched.step()
            vl, a1, a5 = evaluate(student, val_loader)
            if a1 > best[1]:
                best = (vl, a1, a5)
            if ep % 10 == 0 or ep == 1:
                print(f"[Hints] Epoch {ep}/{epochs} | val_loss={vl:.3f} top1={a1:.2f} top5={a5:.2f}")
    finally:
        h_s.remove(); h_t.remove()
    return best


## CRD (batch-only minimal contrastive)

In [None]:
class GlobalAvgPool(nn.Module):
    """Global average pooling that flattens (N, C, H, W) feature maps to (N, C)."""

    def forward(self, x):
        pooled = F.adaptive_avg_pool2d(x, output_size=1)
        return torch.flatten(pooled, start_dim=1)

def train_kd_crd(student, teacher, train_loader, val_loader, epochs=60, lr=0.1, rep_dim=128):
    """Train ``student`` with CE plus the in-batch contrastive ``CRDLoss``.

    Representations are the global-average-pooled ``features`` outputs of each
    model; models without ``features`` fall back to their logits. The CRD
    projection heads are moved to DEVICE and optimized jointly with the student.
    Previously they stayed on CPU (a device mismatch on CUDA) and were never
    trained.

    Returns the (val_loss, top1, top5) triple of the epoch with the best top-1.
    """
    student.to(DEVICE); teacher.to(DEVICE).eval()
    gap = GlobalAvgPool()
    crd = CRDLoss(dim=512 if 'vgg' in teacher.__class__.__name__.lower() else rep_dim).to(DEVICE)
    ce = nn.CrossEntropyLoss()
    opt = torch.optim.SGD(list(student.parameters()) + list(crd.parameters()), lr=lr, momentum=0.9, weight_decay=5e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

    def _logits_and_feat(model, x):
        """Return (logits, pooled features), running the conv trunk only once when possible."""
        if all(hasattr(model, attr) for attr in ('features', 'avgpool', 'classifier')):
            # Same computation as torchvision VGG.forward, but the feature map is kept.
            fmap = model.features(x)
            logits = model.classifier(torch.flatten(model.avgpool(fmap), 1))
            return logits, gap(fmap)
        logits = model(x)
        feat = gap(model.features(x)) if hasattr(model, 'features') else logits
        return logits, feat

    best = (1e9, 0, 0)
    for ep in range(1, epochs + 1):
        student.train()
        for x, y in train_loader:
            x = x.to(DEVICE); y = y.to(DEVICE)
            with torch.no_grad():
                _, t_feat = _logits_and_feat(teacher, x)
            opt.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=USE_AMP):
                s_logits, s_feat = _logits_and_feat(student, x)
                loss = ce(s_logits, y) + crd(s_feat, t_feat)
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
        sched.step()
        vl, a1, a5 = evaluate(student, val_loader)
        if a1 > best[1]:
            best = (vl, a1, a5)
        if ep % 10 == 0 or ep == 1:
            print(f"[CRD] Epoch {ep}/{epochs} | val_loss={vl:.3f} top1={a1:.2f} top5={a5:.2f}")
    return best


## Training scripts: Teacher, SI, KD variants

In [None]:
def save_ckpt(model, path):
    """Serialize ``model``'s weights to ``path`` as ``{'state_dict': ...}``."""
    payload = {'state_dict': model.state_dict()}
    torch.save(payload, path)

def load_ckpt(model, path):
    """Load weights written by ``save_ckpt`` into ``model`` (tensors mapped onto DEVICE)."""
    checkpoint = torch.load(path, map_location=DEVICE)
    state = checkpoint['state_dict']
    model.load_state_dict(state)

def _train_ce_and_save(arch, pretrained, ckpt_name, epochs):
    """Build a VGG via make_vgg, train it with plain CE, save it as CKPTS/ckpt_name and return it."""
    model = make_vgg(arch, pretrained=pretrained).to(DEVICE)
    # NOTE(review): lr=0.1 is aggressive for VGG without BatchNorm, especially when
    # fine-tuning ImageNet weights. Confirm these runs actually converge.
    best = train_ce(model, train_loader, val_loader, epochs=epochs, lr=0.1)
    save_ckpt(model, str(CKPTS / ckpt_name))
    print(f"Saved {ckpt_name} | best:", best)
    return model


def train_teacher_vgg16(epochs=60):
    """Fine-tune an ImageNet-initialized VGG-16 teacher; saved as teacher_vgg16.pt."""
    return _train_ce_and_save('vgg16', True, 'teacher_vgg16.pt', epochs)


def train_teacher_vgg19(epochs=60):
    """Fine-tune an ImageNet-initialized VGG-19 teacher; saved as teacher_vgg19.pt."""
    return _train_ce_and_save('vgg19', True, 'teacher_vgg19.pt', epochs)


def train_student_si(epochs=60):
    """Train the independent VGG-11 student from scratch (no teacher); saved as student_SI.pt."""
    return _train_ce_and_save('vgg11', False, 'student_SI.pt', epochs)

# Number of Conv2d layers in a torchvision VGG `features` block -> make_vgg name.
_VGG_ARCH_BY_CONV_COUNT = {8: 'vgg11', 13: 'vgg16', 16: 'vgg19'}


def _load_teacher(teacher_path=None):
    """Rebuild a teacher from a ``save_ckpt`` checkpoint, inferring its VGG depth.

    Defaults to CKPTS/'teacher_vgg16.pt'. The depth is read from the number of
    4-D ``features.*.weight`` tensors, so a VGG-19 checkpoint is loaded into a
    VGG-19 (previously every checkpoint went into a VGG-16 and failed).

    Returns:
        (teacher model on DEVICE, architecture name).

    Raises:
        ValueError: if the conv-layer count matches no supported VGG.
    """
    if teacher_path is None:
        teacher_path = CKPTS / 'teacher_vgg16.pt'
    state = torch.load(str(teacher_path), map_location=DEVICE)['state_dict']
    n_conv = sum(1 for k, v in state.items()
                 if k.startswith('features.') and k.endswith('.weight') and v.dim() == 4)
    arch = _VGG_ARCH_BY_CONV_COUNT.get(n_conv)
    if arch is None:
        raise ValueError(f'Cannot infer VGG depth for {teacher_path}: found {n_conv} conv layers')
    teacher = make_vgg(arch, pretrained=False).to(DEVICE)
    teacher.load_state_dict(state)
    return teacher, arch


def _student_ckpt_name(tag, arch):
    """'student_<tag>.pt' for the default VGG-16 teacher, 'student_<tag>_<arch>.pt' otherwise."""
    return f'student_{tag}.pt' if arch == 'vgg16' else f'student_{tag}_{arch}.pt'


def distill_lm(teacher_path=None, epochs=60, alpha=0.5, T=4.0):
    """Distill a fresh VGG-11 from the teacher checkpoint by logit matching and save it."""
    t, arch = _load_teacher(teacher_path)
    s = make_vgg('vgg11', pretrained=False).to(DEVICE)
    best = train_kd_lm(s, t, train_loader, val_loader, epochs=epochs, lr=0.1, alpha=alpha, T=T)
    # Non-VGG-16 teachers get a tagged name so they don't overwrite student_LM.pt.
    ckpt_name = _student_ckpt_name('LM', arch)
    save_ckpt(s, str(CKPTS / ckpt_name))
    print(f"Saved {ckpt_name} | best:", best)
    return s


def distill_hints(teacher_path=None, epochs=60):
    """Distill a fresh VGG-11 from the teacher checkpoint with hint (feature) matching and save it."""
    t, arch = _load_teacher(teacher_path)
    s = make_vgg('vgg11', pretrained=False).to(DEVICE)
    best = train_kd_hints(s, t, train_loader, val_loader, epochs=epochs, lr=0.1)
    ckpt_name = _student_ckpt_name('HINTS', arch)
    save_ckpt(s, str(CKPTS / ckpt_name))
    print(f"Saved {ckpt_name} | best:", best)
    return s


def distill_crd(teacher_path=None, epochs=60):
    """Distill a fresh VGG-11 from the teacher checkpoint with the contrastive CRD loss and save it."""
    t, arch = _load_teacher(teacher_path)
    s = make_vgg('vgg11', pretrained=False).to(DEVICE)
    best = train_kd_crd(s, t, train_loader, val_loader, epochs=epochs, lr=0.1)
    ckpt_name = _student_ckpt_name('CRD', arch)
    save_ckpt(s, str(CKPTS / ckpt_name))
    print(f"Saved {ckpt_name} | best:", best)
    return s


## KL alignment, Color-invariance, Larger teacher

In [None]:
import pandas as pd

def kl_alignment_table(teacher_path, student_paths, names, T=4.0):
    """Tabulate mean_kl_teacher_student between a VGG-16 teacher and each VGG-11 student checkpoint.

    Scores up to 50 validation batches per student, writes RES/'kl_alignment.csv'
    with columns (method, mean_kl), and returns the DataFrame.
    """
    teacher = make_vgg('vgg16', pretrained=False).to(DEVICE)
    load_ckpt(teacher, str(teacher_path))

    def _score(ckpt):
        student = make_vgg('vgg11', pretrained=False).to(DEVICE)
        load_ckpt(student, str(ckpt))
        return mean_kl_teacher_student(teacher, student, val_loader, T=T, max_batches=50)

    records = [{'method': label, 'mean_kl': _score(ckpt)} for ckpt, label in zip(student_paths, names)]
    table = pd.DataFrame(records)
    out_path = RES / 'kl_alignment.csv'
    table.to_csv(out_path, index=False)
    print('Saved', out_path)
    return table

def color_invariance_experiment(epochs_teacher_ft=10, epochs_student=30):
    """Test whether color invariance transfers from a color-jitter-tuned teacher via CRD.

    1. Fine-tunes the saved VGG-16 teacher on color-jittered training images.
    2. CRD-distills a fresh VGG-11 from it on the regular (unjittered) train set.
    3. Evaluates the student on clean vs color-jittered validation images.

    ``get_cifar100(color_jitter=True)`` only jitters the TRAIN split, so the
    jittered validation loader is built explicitly here. Before this fix the
    "jittered" evaluation used clean images and delta was always 0. Writes
    RES/'color_invariance.csv' and returns it as a DataFrame.
    """
    # 1) finetune teacher16 with color jitter
    t = make_vgg('vgg16', pretrained=False).to(DEVICE)
    load_ckpt(t, str(CKPTS/'teacher_vgg16.pt'))
    jitter_loader, _ = get_cifar100(color_jitter=True)
    print('Finetuning teacher with color jitter...')
    train_ce(t, jitter_loader, val_loader, epochs=epochs_teacher_ft, lr=0.01)
    save_ckpt(t, str(CKPTS/'teacher_vgg16_colorjitter.pt'))
    # 2) CRD distill to student
    s = make_vgg('vgg11', pretrained=False).to(DEVICE)
    print('Distilling CRD from jittered teacher...')
    train_kd_crd(s, t, train_loader, val_loader, epochs=epochs_student, lr=0.1)
    save_ckpt(s, str(CKPTS/'student_CRD_color.pt'))
    # 3) Evaluate on clean vs jittered val (same jitter strength and normalization as training).
    mean = (0.5071, 0.4867, 0.4408)
    std = (0.2675, 0.2565, 0.2761)
    jitter_test_tf = transforms.Compose([
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    jitter_test_ds = datasets.CIFAR100(root=str(ROOT/'data'), train=False, download=True, transform=jitter_test_tf)
    # A dedicated generator seeds the loader's workers, so the random jitter is
    # repeatable without touching the global RNG.
    val_loader_j = DataLoader(jitter_test_ds, batch_size=128, shuffle=False, num_workers=2,
                              pin_memory=True, generator=torch.Generator().manual_seed(1337))
    clean = evaluate(s, val_loader)
    jitter = evaluate(s, val_loader_j)
    df = pd.DataFrame([{'method':'CRD_color', 'clean_top1':clean[1], 'jitter_top1':jitter[1], 'delta': jitter[1]-clean[1]}])
    path = RES/'color_invariance.csv'
    df.to_csv(path, index=False)
    print('Saved', path)
    return df

def larger_teacher_experiment(epochs=60):
    """Compare logit-matching VGG-11 students distilled from VGG-16 vs VGG-19 teachers.

    Requires CKPTS/'teacher_vgg16.pt' and CKPTS/'teacher_vgg19.pt'. Each teacher
    is rebuilt with its own architecture, because a VGG-19 checkpoint cannot be
    loaded into a VGG-16. A fresh student is distilled with the LM defaults
    (alpha=0.5, T=4.0, lr=0.1). Students are saved as
    'student_LM_from_<arch>.pt' so the main student_LM.pt is not overwritten.
    Writes RES/'larger_teacher.csv' with the final-epoch top-1/top-5 and returns
    it as a DataFrame.
    """
    rows = []
    for arch in ('vgg16', 'vgg19'):
        teacher = make_vgg(arch, pretrained=False).to(DEVICE)
        load_ckpt(teacher, str(CKPTS / f'teacher_{arch}.pt'))
        student = make_vgg('vgg11', pretrained=False).to(DEVICE)
        best = train_kd_lm(student, teacher, train_loader, val_loader,
                           epochs=epochs, lr=0.1, alpha=0.5, T=4.0)
        ckpt_name = f'student_LM_from_{arch}.pt'
        save_ckpt(student, str(CKPTS / ckpt_name))
        print(f"Saved {ckpt_name} | best:", best)
        _, top1, top5 = evaluate(student, val_loader)
        rows.append({'teacher': arch, 'student': 'vgg11', 'top1': top1, 'top5': top5})
    df = pd.DataFrame(rows)
    path = RES/'larger_teacher.csv'
    df.to_csv(path, index=False)
    print('Saved', path)
    return df


## Grad-CAM utilities (optional; uses `pytorch-grad-cam` if present)

In [None]:
def gradcam_grid(models_dict, loader, n_images=3, save_path=FIGS/'gradcam_grid.png',
                 mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761)):
    """Save a Grad-CAM grid (rows = images, columns = models) to ``save_path``.

    The CAM target layer is the last Conv2d in each model's ``features``. No
    explicit class targets are passed, so pytorch-grad-cam uses each model's
    top-scoring class. Every model is switched to eval mode so classifier
    dropout does not randomize the maps. The displayed images are
    un-normalized with ``mean``/``std``, whose defaults are the CIFAR-100
    statistics used by ``get_cifar100``.

    Best-effort: prints a message and returns None if pytorch-grad-cam or
    matplotlib cannot be imported, or if there is nothing to plot.
    """
    try:
        from pytorch_grad_cam import GradCAM
        from pytorch_grad_cam.utils.image import show_cam_on_image
        import matplotlib.pyplot as plt
    except ImportError:
        print("pytorch-grad-cam not available; skipping Grad-CAM.")
        return
    if n_images < 1 or not models_dict:
        print("Nothing to plot; skipping Grad-CAM.")
        return

    # Collect the first n_images samples, counting actual batch sizes.
    batches, n_seen = [], 0
    for x_batch, _ in loader:
        batches.append(x_batch)
        n_seen += x_batch.size(0)
        if n_seen >= n_images:
            break
    if not batches:
        print("Loader is empty; skipping Grad-CAM.")
        return
    x = torch.cat(batches, dim=0)[:n_images].to(DEVICE)
    n_rows = x.size(0)  # the loader may hold fewer samples than requested

    # Undo Normalize(mean, std) so the overlay shows real colors in [0, 1].
    mean_t = torch.tensor(mean, dtype=torch.float32).view(1, -1, 1, 1)
    std_t = torch.tensor(std, dtype=torch.float32).view(1, -1, 1, 1)
    img_np = (x.detach().cpu().float() * std_t + mean_t).clamp(0, 1).permute(0, 2, 3, 1).numpy()

    def make_cam(model, layer):
        try:
            return GradCAM(model=model, target_layers=[layer], use_cuda=torch.cuda.is_available())
        except TypeError:
            # Some newer pytorch-grad-cam releases no longer accept `use_cuda`.
            return GradCAM(model=model, target_layers=[layer])

    cols = len(models_dict)
    # squeeze=False keeps `ax` 2-D even with a single row or column.
    fig, ax = plt.subplots(n_rows, cols, figsize=(3 * cols, 3 * n_rows), squeeze=False)
    for c, (name, model) in enumerate(models_dict.items()):
        model.eval()
        target_layer = [m for m in model.features if isinstance(m, nn.Conv2d)][-1]
        # Context manager releases the CAM's hooks on the model afterwards.
        with make_cam(model, target_layer) as cam:
            grayscale_cams = cam(input_tensor=x)
        for r in range(n_rows):
            vis = show_cam_on_image(img_np[r], grayscale_cams[r], use_rgb=True)
            ax[r, c].imshow(vis)
            ax[r, c].axis('off')
            if r == 0:
                ax[r, c].set_title(name)
    fig.tight_layout()
    fig.savefig(save_path, dpi=160)
    plt.close(fig)
    print('Saved', save_path)


## Plotting (accuracy bars, KL bars)

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

def _plot_bar_from_csv(csv_path, out, value_col, ylabel, title, missing_msg):
    """Bar-plot ``value_col`` per ``method`` from a CSV and save it to ``out`` (skips if the CSV is missing)."""
    if not Path(csv_path).exists():
        print(missing_msg)
        return
    frame = pd.read_csv(csv_path)
    labels = frame['method']
    heights = frame[value_col]
    plt.figure(figsize=(6, 4))
    plt.bar(labels, heights)
    plt.xticks(rotation=20)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.tight_layout()
    plt.savefig(out, dpi=160)
    plt.close()
    print('Saved', out)


def plot_acc_bar(csv_path=RES/'metrics_overall.csv', out=FIGS/'acc_bar.png'):
    """Plot top-1 accuracy per method from a CSV with columns (method, top1_acc)."""
    _plot_bar_from_csv(csv_path, out, 'top1_acc', 'Top-1 Accuracy',
                       'KD Methods — CIFAR-100', 'metrics_overall.csv not found; skipping.')


def plot_kl_bar(csv_path=RES/'kl_alignment.csv', out=FIGS/'kl_bar.png'):
    """Plot mean teacher-student KL per method from a CSV with columns (method, mean_kl)."""
    _plot_bar_from_csv(csv_path, out, 'mean_kl', 'Mean KL(T||S)',
                       'Distribution Alignment (lower is better)', 'kl_alignment.csv not found; skipping.')


## Orchestration (run cells in order in Colab)

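Run these *after* the training/distillation steps above (`train_teacher_vgg16()`, `train_teacher_vgg19()`, `train_student_si()`, `distill_lm()`, `distill_hints()`, `distill_crd()`) to produce CSVs/PNGs: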
- `kl_alignment_table()`
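- `plot_acc_bar()`, `plot_kl_bar()`. Note: `plot_acc_bar()` reads `results/metrics_overall.csv` (columns `method`, `top1_acc`), which no function in this notebook writes. Assemble it from the `best` results printed by the training/distillation functions.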
- `gradcam_grid(...)`
- `color_invariance_experiment()`
- `larger_teacher_experiment()`