

# Knowledge Distillation: A Research-Style, Reproducible Notebook

**Author:** Your Name  
**Last Updated:** {{AUTO}}  

## Abstract
This notebook demonstrates end-to-end **knowledge distillation (KD)** for text classification, following the core ideas in Hinton et al. (2015) and subsequent works. We:
1) Train (or load) a strong **teacher** model,  
2) Use it to produce **soft targets** (logits) on the training set,  
3) Train a smaller **student** model using a weighted combination of **hard-label cross-entropy** and **soft-label KL divergence** with **temperature scaling**,  
4) Evaluate accuracy, calibration (ECE), and perform **temperature / α sweeps**, and  
5) Provide ablations and plots common to research papers.

This notebook is designed to be **self-contained** and easy to adapt to new datasets and models.



## References
- Hinton, G., Vinyals, O., & Dean, J. (2015). *Distilling the Knowledge in a Neural Network*. arXiv:1503.02531.  
- Buciluǎ, C., Caruana, R., & Niculescu-Mizil, A. (2006). *Model Compression*. KDD.  
- Müller, R., Kornblith, S., & Hinton, G. (2019). *When Does Label Smoothing Help?*. NeurIPS.  
- Guo, C., Pleiss, G., Sun, Y., & Weinberger, K. Q. (2017). *On Calibration of Modern Neural Networks*. ICML.



## 1. Background & Objective

Given an input \(x\) and ground-truth class \(y \in \{1,\dots,K\}\), a **teacher** network produces logits \(\mathbf{z}_t(x) \in \mathbb{R}^K\).
A **student** network produces logits \(\mathbf{z}_s(x)\).

We define **temperature-scaled softmax**:
\[
p_t^{(T)}(k \mid x) = \mathrm{softmax}(\mathbf{z}_t/T)_k = \frac{\exp(z_{t,k}/T)}{\sum_{j=1}^K \exp(z_{t,j}/T)} , \quad
p_s^{(T)}(k \mid x) = \mathrm{softmax}(\mathbf{z}_s/T)_k.
\]
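A framework-free sketch of the temperature-scaled softmax above, using made-up 3-class logits. Raising \(T\) keeps the ranking of the classes but flattens the distribution, so its entropy grows. `softmax_T` and `entropy` are hypothetical helper names used only for illustration.

```python
import math


def softmax_T(logits, T=1.0):
    """Temperature-scaled softmax: softmax(z / T), computed stably."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max before exp for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(p):
    """Shannon entropy in nats."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


# Made-up 3-class teacher logits, for illustration only.
z_t = [5.0, 2.0, -1.0]
for T in (1.0, 2.0, 4.0):
    p = softmax_T(z_t, T)
    print(f"T={T}: p={[round(x, 3) for x in p]}, entropy={entropy(p):.3f} nats")
```

The most likely class stays the same at every temperature; what changes is how much probability mass the non-target classes receive.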

The **distillation loss** combines **hard labels** with **soft targets**:

\[
\mathcal{L} = (1-\alpha)\underbrace{\mathrm{CE}\big(y, \mathrm{softmax}(\mathbf{z}_s)\big)}_{\text{hard-label cross-entropy}}
\;+\;
\alpha T^2 \underbrace{\mathrm{KL}\!\left(p_t^{(T)}(\cdot \mid x)\; \big\|\; p_s^{(T)}(\cdot \mid x)\right)}_{\text{soft-label KL divergence}}.
\]

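- \( \alpha \in [0,1] \) trades off hard vs. soft supervision.  
- \( T > 1 \) softens both distributions (raises their entropy), exposing the teacher's "**dark knowledge**": the relative probabilities it assigns to wrong classes, which encode class similarities. \( T = 1 \) recovers the ordinary softmax.  
- The factor \( T^2 \) keeps the soft term's contribution roughly constant across temperatures, because its gradients scale as \( 1/T^2 \) (Hinton et al., 2015).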
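Where the \(T^2\) factor comes from: the gradient of the KL term with respect to the student logit \(z_{s,i}\) is \(\big(p_{s,i}^{(T)} - p_{t,i}^{(T)}\big)/T\), which for large \(T\) and zero-mean logits is approximately \((z_{s,i} - z_{t,i})/(K T^2)\) (Hinton et al., 2015). The plain-Python check below uses made-up binary logits and a hypothetical helper `soft_kl_grad`; it shows the raw gradient shrinking with \(T\) while \(T^2\) times the gradient settles near a constant.

```python
import math


def softmax_T(logits, T=1.0):
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def soft_kl_grad(student_logits, teacher_logits, T):
    """Gradient of KL(p_t^(T) || p_s^(T)) w.r.t. the student logits: (q - p) / T."""
    q = softmax_T(student_logits, T)  # student at temperature T
    p = softmax_T(teacher_logits, T)  # teacher at temperature T
    return [(qi - pi) / T for qi, pi in zip(q, p)]


# Made-up zero-mean logits (K = 2 classes).
z_s = [1.0, -1.0]
z_t = [2.0, -2.0]
for T in (1.0, 10.0, 100.0):
    g = soft_kl_grad(z_s, z_t, T)
    print(f"T={T:>5}: grad={g[0]:+.2e}, T^2 * grad={T * T * g[0]:+.4f}")
# For large T, T^2 * grad[0] approaches (z_s[0] - z_t[0]) / K = -0.5 here.
```

Without the \(T^2\) factor, the soft term would contribute almost nothing to the update at high temperatures, and changing \(T\) would silently change the effective \(\alpha\).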



## 2. Setup

> **Note:** This notebook relies on Hugging Face `transformers` and `datasets`. Uncomment the `pip` cells if needed.


In [1]:

# !pip install -U transformers datasets accelerate evaluate scikit-learn matplotlib numpy torch --quiet
# Optional (if you want pretty progress bars/logging):
# !pip install -U rich tqdm --quiet



### 2.1 Imports & Global Config


In [4]:

import os, math, random, json, time
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from datasets import load_dataset, Dataset, DatasetDict
from transformers import (
    AutoTokenizer, AutoModelForSequenceClassification,
    get_linear_schedule_with_warmup
)
# transformers' own AdamW is deprecated; use the PyTorch implementation instead.
from torch.optim import AdamW

import evaluate  # accuracy metric used in evaluate_model() and the teacher evaluation
import matplotlib.pyplot as plt

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print('Device:', device)


Device: cuda



### 2.2 Experiment Configuration
You can change the dataset, teacher/student models, and training hyper-parameters here.


In [5]:

# ===== Dataset & Task =====
DATASET_NAME = "glue"
SUBSET = "sst2"   # Binary sentiment classification (faster than IMDB)

# ===== Teacher / Student =====
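# Teacher: an SST-2 fine-tuned classifier whose label order matches GLUE (0 = negative, 1 = positive).
# `textattack/roberta-large-SST-2` is not a valid Hub id (loading it raises an OSError), so the
# base-size textattack checkpoint is used. Verify the label mapping if you swap in another teacher.
TEACHER_CHECKPOINT = "textattack/roberta-base-SST-2"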
STUDENT_CHECKPOINT = "distilbert-base-uncased"

# ===== Tokenization =====
MAX_LENGTH = 128

# ===== Training hyper-parameters =====
BATCH_SIZE = 16
EPOCHS = 3
LEARNING_RATE = 3e-5
WEIGHT_DECAY = 0.01
WARMUP_RATIO = 0.1

# ===== KD Hyper-parameters =====
ALPHA = 0.9     # weight on soft-label KL
TEMPERATURE = 2.0

# ===== Ablations / Sweeps =====
TEMP_GRID = [1.0, 2.0, 4.0]
ALPHA_GRID = [0.5, 0.9]

# ===== Misc =====
OUTPUT_DIR = "./kd_outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)
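`get_linear_schedule_with_warmup` ramps the learning rate linearly from 0 to `LEARNING_RATE` over the warmup steps, then decays it linearly to 0 by the last step; the notebook sets the warmup length to `int(WARMUP_RATIO * num_training_steps)`. The sketch below re-implements that multiplier in plain Python for illustration (it is not the library's code), assuming a hypothetical 1,000-step run.

```python
def linear_warmup_multiplier(step, num_warmup_steps, num_training_steps):
    """LR multiplier: linear warmup from 0 to 1, then linear decay back to 0."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))


# Hypothetical run: 1,000 optimizer steps with WARMUP_RATIO = 0.1.
total_steps = 1000
warmup_steps = int(0.1 * total_steps)  # 100 warmup steps
for step in (0, 50, 100, 550, 1000):
    print(step, linear_warmup_multiplier(step, warmup_steps, total_steps))
```

The effective learning rate at each step is `LEARNING_RATE` times this multiplier, so short runs with a warmup computed for a longer run spend a larger fraction of their steps below the peak rate.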



## 3. Data
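We use **GLUE/SST-2** for speed and reproducibility. Its test-split labels are hidden (stored as -1), so all evaluation below uses the 872-example validation split.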


In [6]:

raw = load_dataset(DATASET_NAME, SUBSET)
raw


DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1821
    })
})


### 3.1 Tokenization


In [7]:

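teacher_tok = AutoTokenizer.from_pretrained(TEACHER_CHECKPOINT, use_fast=True)
student_tok = AutoTokenizer.from_pretrained(STUDENT_CHECKPOINT, use_fast=True)

def tokenize_fn(batch, tokenizer):
    # No padding here: the collator pads each batch dynamically.
    return tokenizer(batch["sentence"], truncation=True, padding=False, max_length=MAX_LENGTH)

# `tokenized` holds the *teacher* tokenization. The `label` column is kept because
# PadCollator reads it; `sentence` and `idx` are dropped. The student (DistilBERT) uses a
# different vocabulary, so it must be fed its own tokenization of the same sentences.
tokenized = DatasetDict({
    split: raw[split].map(
        lambda b: tokenize_fn(b, teacher_tok),
        batched=True,
        remove_columns=[c for c in raw[split].column_names if c != "label"],
    )
    for split in raw
})

# Drop `idx` from `raw`, keeping `sentence` and `label` for later tokenization and label lookups.
raw = raw.remove_columns([c for c in raw["train"].column_names if c not in ["sentence", "label"]])

print(tokenized)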




### 3.2 Collator


In [None]:

@dataclass
class PadCollator:
    tokenizer: AutoTokenizer
    label_key: str = "label"

    def __call__(self, features):
        labels = [f[self.label_key] for f in features]
        batch = self.tokenizer.pad(
            {k: [f[k] for f in features] for k in features[0] if k not in [self.label_key]},
            return_tensors="pt"
        )
        batch["labels"] = torch.tensor(labels, dtype=torch.long)
        return batch

collate_teacher = PadCollator(teacher_tok)
collate_student = PadCollator(student_tok)
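Because `tokenize_fn` uses `padding=False`, each example keeps its own length, and `tokenizer.pad` pads each batch (by default to the longest sequence in that batch) and builds the matching attention mask. A plain-Python sketch of that dynamic padding, assuming right-side padding; `pad_batch` is a hypothetical helper, and the real pad id comes from `tokenizer.pad_token_id`.

```python
def pad_batch(sequences, pad_token_id):
    """Right-pad token-id lists to the longest sequence in the batch."""
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        input_ids.append(seq + [pad_token_id] * n_pad)
        attention_mask.append([1] * len(seq) + [0] * n_pad)  # 0 marks padding positions
    return {"input_ids": input_ids, "attention_mask": attention_mask}


# Made-up token ids; pad id 0 is an assumption (check tokenizer.pad_token_id).
batch = pad_batch([[101, 7592, 102], [101, 102]], pad_token_id=0)
print(batch)
```

Padding per batch instead of to `MAX_LENGTH` keeps short SST-2 sentences cheap to process.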



## 4. Teacher Inference: Generate Soft Targets (Logits)
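We run the teacher once over the **training split** in dataset order (no shuffling) and save its per-example logits to disk, so row \(i\) of the saved array belongs to training example \(i\).  
These logits serve as the soft targets for KD training of the student.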


In [None]:

teacher = AutoModelForSequenceClassification.from_pretrained(TEACHER_CHECKPOINT).to(device)
teacher.eval()

train_loader_teacher = DataLoader(
    tokenized["train"], batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_teacher
)

all_teacher_logits = []
with torch.no_grad():
    for batch in tqdm(train_loader_teacher, desc="Teacher -> logits"):
        batch = {k: v.to(device) for k, v in batch.items() if k != "labels"}
        out = teacher(**batch)
        all_teacher_logits.append(out.logits.cpu())

teacher_logits_train = torch.cat(all_teacher_logits, dim=0).numpy()
np.save(os.path.join(OUTPUT_DIR, "teacher_logits_train.npy"), teacher_logits_train)

print("Saved teacher logits:", teacher_logits_train.shape)



## 5. Student Model & KD Loss
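We train a smaller student with KD by minimizing the objective from Section 1, where \(p_s^{(1)} = \mathrm{softmax}(\mathbf{z}_s)\) is the student's ordinary (\(T = 1\)) distribution:
\[
(1-\alpha)\,\mathrm{CE}\big(y, p_s^{(1)}\big) + \alpha T^2\, \mathrm{KL}\big(p_t^{(T)} \,\Vert\, p_s^{(T)}\big).
\]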
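A plain-Python reference version of this objective, intended to mirror the notebook's `kd_loss` on a toy batch. The KL term follows the semantics of `nn.KLDivLoss(reduction="batchmean")`: its input is the student's log-probabilities, its target is the teacher's probabilities, and the pointwise terms target·(log target − input) are summed and divided by the batch size. `kd_loss_reference` and the logits are made up for illustration.

```python
import math


def log_softmax_T(logits, T=1.0):
    """log softmax(z / T) via a stable log-sum-exp."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    log_norm = m + math.log(sum(math.exp(s - m) for s in scaled))
    return [s - log_norm for s in scaled]


def kd_loss_reference(student_logits, teacher_logits, labels, alpha, T):
    """Return (loss, ce, kl) for (1 - alpha) * CE + alpha * T^2 * KL, averaged over the batch."""
    n = len(labels)
    # Hard-label cross-entropy at T = 1.
    ce = -sum(log_softmax_T(zs)[y] for zs, y in zip(student_logits, labels)) / n
    # Soft-label KL(p_t^(T) || p_s^(T)), "batchmean" style: total sum / batch size.
    kl = 0.0
    for zs, zt in zip(student_logits, teacher_logits):
        log_q = log_softmax_T(zs, T)  # student log-probs (the "input")
        log_p = log_softmax_T(zt, T)  # teacher log-probs (exp gives the "target")
        kl += sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))
    kl /= n
    return (1 - alpha) * ce + alpha * T * T * kl, ce, kl


# Made-up batch of two binary examples.
s = [[0.5, -0.5], [-1.0, 1.0]]
t = [[3.0, -3.0], [-2.0, 2.0]]
y = [0, 1]
loss, ce, kl = kd_loss_reference(s, t, y, alpha=0.9, T=2.0)
print(f"loss={loss:.4f}, ce={ce:.4f}, kl={kl:.4f}")
```

Setting `alpha=0` reduces the objective to plain cross-entropy (the hard-only ablation in Section 8), and the KL term vanishes when student and teacher logits coincide.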


In [None]:

student = AutoModelForSequenceClassification.from_pretrained(STUDENT_CHECKPOINT, num_labels=2).to(device)

# The student needs its own tokenization: DistilBERT's vocabulary differs from the teacher's.
# `raw` rows are in the same order as `tokenized`, so row i still matches teacher logit row i.
student_tokenized = DatasetDict({
    split: raw[split].map(
        lambda b: tokenize_fn(b, student_tok),
        batched=True,
        remove_columns=[c for c in raw[split].column_names if c != "label"],
    )
    for split in ["train", "validation"]
})

teacher_logits_train = np.load(os.path.join(OUTPUT_DIR, "teacher_logits_train.npy"))
assert teacher_logits_train.shape[0] == len(student_tokenized["train"]), "Logit count must match training dataset size."

# train_kd slices teacher logits by position, so the loader must not reshuffle.
# Instead, apply ONE fixed random permutation to the examples and their logits together.
perm = np.random.RandomState(SEED).permutation(len(student_tokenized["train"]))
train_tok_with_labels = student_tokenized["train"].select(perm)
teacher_logits_train = teacher_logits_train[perm]
val_tok_with_labels = student_tokenized["validation"]

train_loader_student = DataLoader(train_tok_with_labels, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_student)
val_loader_student   = DataLoader(val_tok_with_labels, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_student)

criterion_ce = nn.CrossEntropyLoss()
criterion_kd = nn.KLDivLoss(reduction="batchmean")  # input: student log-probs, target: teacher probs
optimizer = AdamW(student.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

num_training_steps = len(train_loader_student) * EPOCHS
num_warmup_steps = int(WARMUP_RATIO * num_training_steps)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)

print("Student ready. Training steps:", num_training_steps)



### 5.1 Training Loop (KD)
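We iterate over the training set and pair each batch with the teacher logits at the same positions (`start:start+batch_size`). This pairing is exact only if the loader yields examples in the order in which the logits are stored, i.e. with `shuffle=False`.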
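A toy simulation (plain Python, made-up ids, no real data) of why the loader order matters: positional slicing gives example `order[pos]` the logits of row `pos`, whereas looking logits up by each example's own index stays correct under any order, which is the approach the Appendix recommends for shuffled training.

```python
import random

# Made-up setup: 8 examples; "teacher logit" row i is tagged with its example id i.
n, batch_size = 8, 3
teacher_logits = [f"logit_{i}" for i in range(n)]

rng = random.Random(0)
order = list(range(n))
rng.shuffle(order)  # the example order a shuffling loader would produce
batches = [order[k:k + batch_size] for k in range(0, n, batch_size)]

start, mismatches = 0, 0
for batch_ids in batches:
    # Positional slicing (as in train_kd): assumes batches arrive in dataset order.
    sliced = teacher_logits[start:start + len(batch_ids)]
    # Index lookup: correct for any order, provided each example carries its index.
    looked_up = [teacher_logits[i] for i in batch_ids]
    mismatches += sum(a != b for a, b in zip(sliced, looked_up))
    start += len(batch_ids)
print("mismatched (example, logit) pairs under shuffling:", mismatches)
```

Every position where the shuffled order differs from the stored order yields a wrong pairing, so a shuffling loader combined with positional slicing trains the student on other examples' soft targets.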


In [None]:

def kd_loss(student_logits, true_labels, teacher_logits_batch, alpha=ALPHA, T=TEMPERATURE):
    # Hard CE
    loss_ce = criterion_ce(student_logits, true_labels)
    # Soft KL (teacher vs student @ T)
    student_log_probs_T = torch.log_softmax(student_logits / T, dim=-1)
    with torch.no_grad():
        teacher_probs_T = torch.softmax(teacher_logits_batch / T, dim=-1)
    loss_kd = criterion_kd(student_log_probs_T, teacher_probs_T) * (T * T)
    return (1 - alpha) * loss_ce + alpha * loss_kd, loss_ce.detach(), loss_kd.detach()

def evaluate_model(model, loader):
    model.eval()
    metric = evaluate.load("accuracy")
    all_probs = []
    with torch.no_grad():
        for batch in loader:
            labels = batch["labels"].to(device)
            batch = {k: v.to(device) for k, v in batch.items() if k != "labels"}
            out = model(**batch)
            preds = out.logits.argmax(-1)
            metric.add_batch(predictions=preds.cpu().numpy(), references=labels.cpu().numpy())
            probs = torch.softmax(out.logits, dim=-1)[:,1].cpu().numpy()  # prob of positive class
            all_probs.extend(probs.tolist())
    return metric.compute(), np.array(all_probs)

def train_kd(model, epochs=EPOCHS, alpha=ALPHA, T=TEMPERATURE):
    global_step = 0
    for epoch in range(1, epochs+1):
        model.train()  # re-enable dropout; evaluate_model() at the end of each epoch switches to eval mode
        pbar = tqdm(train_loader_student, desc=f"Epoch {epoch}/{epochs}")
        start = 0  # dataset position of the first example in the current batch
        for batch in pbar:
            labels = batch["labels"].to(device)
            batch_size = labels.size(0)

            # Teacher logits are paired with the batch by position. This is exact only if
            # train_loader_student yields examples in the order of teacher_logits_train
            # (shuffle=False). With a shuffling loader, store the logits inside the dataset
            # instead, as described in the Appendix.
            teacher_batch = torch.tensor(teacher_logits_train[start:start+batch_size], dtype=torch.float32, device=device)

            inputs = {k: v.to(device) for k, v in batch.items() if k != "labels"}
            out = model(**inputs)
            s_logits = out.logits

            loss, loss_ce, loss_kd = kd_loss(s_logits, labels, teacher_batch, alpha=alpha, T=T)
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            pbar.set_postfix({
                "loss": f"{loss.item():.4f}",
                "ce": f"{loss_ce.item():.4f}",
                "kd": f"{loss_kd.item():.4f}",
                "lr": f"{scheduler.get_last_lr()[0]:.2e}"
            })
            start += batch_size
            global_step += 1

        # Evaluate each epoch
        metrics, _ = evaluate_model(model, val_loader_student)
        print(f"Val @ epoch {epoch}: acc={metrics['accuracy']:.4f}")
    return model

student = train_kd(student, epochs=EPOCHS, alpha=ALPHA, T=TEMPERATURE)



## 6. Evaluation: Accuracy, Confusion Matrix, Calibration (ECE)
We compute **accuracy**, **confusion matrix**, and **Expected Calibration Error (ECE)** for the student and compare against the teacher.
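For reference, the top-label ECE of Guo et al. (2017) sorts the \(n\) validation predictions into \(M\) equal-width bins \(B_1,\dots,B_M\) by confidence \(\hat p_i = \max_k p(k \mid x_i)\), with prediction \(\hat y_i = \arg\max_k p(k \mid x_i)\), and averages the gap between accuracy and mean confidence in each bin:

\[
\mathrm{ECE} = \sum_{m=1}^{M} \frac{|B_m|}{n}\,\bigl|\mathrm{acc}(B_m) - \mathrm{conf}(B_m)\bigr|, \qquad
\mathrm{acc}(B_m) = \frac{1}{|B_m|}\sum_{i \in B_m} \mathbf{1}[\hat y_i = y_i], \qquad
\mathrm{conf}(B_m) = \frac{1}{|B_m|}\sum_{i \in B_m} \hat p_i.
\]

As a worked example with \(M = 10\): four binary predictions with positive-class probabilities 0.95, 0.85, 0.25, 0.65 and labels 1, 0, 0, 1 have confidences 0.95, 0.85, 0.75, 0.65 (one per bin) and correctness 1, 0, 1, 1, so \(\mathrm{ECE} = (0.05 + 0.85 + 0.25 + 0.35)/4 = 0.375\). ECE is 0 exactly when every non-empty bin's accuracy equals its mean confidence.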


In [None]:

from sklearn.metrics import confusion_matrix

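def expected_calibration_error(probs, labels, n_bins=15):
    """Top-label ECE (Guo et al., 2017) for binary classification.

    probs:  predicted probability of the positive class, shape (n,)
    labels: ground-truth labels in {0, 1}, shape (n,)
    """
    probs = np.asarray(probs, dtype=np.float64)
    labels = np.asarray(labels)
    preds = (probs >= 0.5).astype(int)
    conf = np.where(preds == 1, probs, 1.0 - probs)  # confidence of the predicted class
    correct = (preds == labels).astype(np.float64)
    # Equal-width bins over [0, 1]; clip so that conf == 1.0 falls into the last bin.
    inds = np.clip((conf * n_bins).astype(int), 0, n_bins - 1)
    ece = 0.0
    for b in range(n_bins):
        mask = inds == b
        if not np.any(mask):
            continue
        gap = np.abs(correct[mask].mean() - conf[mask].mean())
        ece += mask.mean() * gap  # mask.mean() == |B_b| / n
    return float(ece)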

# Teacher validation metrics
teacher.eval()
val_loader_teacher = DataLoader(tokenized["validation"], batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_teacher)
teacher_metric = evaluate.load("accuracy")
teacher_probs = []
with torch.no_grad():
    for batch in val_loader_teacher:
        labels = batch["labels"].to(device)
        inputs = {k: v.to(device) for k, v in batch.items() if k != "labels"}
        out = teacher(**inputs)
        preds = out.logits.argmax(-1)
        teacher_metric.add_batch(predictions=preds.cpu().numpy(), references=labels.cpu().numpy())
        teacher_probs.extend(torch.softmax(out.logits, dim=-1)[:,1].cpu().numpy().tolist())
teacher_acc = teacher_metric.compute()["accuracy"]
teacher_probs = np.array(teacher_probs)
teacher_labels = raw["validation"]["label"]
teacher_ece = expected_calibration_error(teacher_probs, np.array(teacher_labels))

# Student validation metrics
student_metric, student_probs = evaluate_model(student, val_loader_student)
student_acc = student_metric["accuracy"]
student_ece = expected_calibration_error(student_probs, np.array(teacher_labels))

print(f"Teacher  - Acc: {teacher_acc:.4f}, ECE: {teacher_ece:.4f}")
print(f"Student  - Acc: {student_acc:.4f}, ECE: {student_ece:.4f}")

# Confusion matrices
# Recompute predictions to build CM
def predict_labels(model, loader):
    model.eval()
    preds_all, labels_all = [], []
    with torch.no_grad():
        for batch in loader:
            labels = batch["labels"].to(device)
            inputs = {k: v.to(device) for k, v in batch.items() if k != "labels"}
            out = model(**inputs)
            preds = out.logits.argmax(-1)
            preds_all.extend(preds.cpu().numpy().tolist())
            labels_all.extend(labels.cpu().numpy().tolist())
    return np.array(preds_all), np.array(labels_all)

# The teacher must be fed its own tokenization, so use val_loader_teacher here.
t_preds, t_labels = predict_labels(teacher, val_loader_teacher)
s_preds, s_labels = predict_labels(student, val_loader_student)

cm_t = confusion_matrix(t_labels, t_preds)
cm_s = confusion_matrix(s_labels, s_preds)

fig = plt.figure(figsize=(10,4))
plt.subplot(1,2,1)
plt.imshow(cm_t, interpolation='nearest')
plt.title("Teacher Confusion Matrix")
plt.colorbar(); plt.xlabel("Pred"); plt.ylabel("True")

plt.subplot(1,2,2)
plt.imshow(cm_s, interpolation='nearest')
plt.title("Student Confusion Matrix")
plt.colorbar(); plt.xlabel("Pred"); plt.ylabel("True")
plt.tight_layout()
plt.show()



## 7. Temperature & α Sweeps (Ablation)
We train a fresh student for a short (one-epoch) run at each \((T, \alpha)\) pair in the grid to visualize how both affect accuracy and calibration.


In [None]:

def quick_kd_run(T, alpha, epochs=1):
    # Fresh student for each (T, alpha) pair
    model = AutoModelForSequenceClassification.from_pretrained(STUDENT_CHECKPOINT, num_labels=2).to(device)
    opt = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    # Warmup is computed from this run's own length, not from the EPOCHS-long main run.
    run_steps = len(train_loader_student) * epochs
    sched = get_linear_schedule_with_warmup(opt, int(WARMUP_RATIO * run_steps), run_steps)

    model.train()
    for ep in range(epochs):
        start = 0  # reset every epoch: teacher logits are sliced by dataset position
        for batch in train_loader_student:
            labels = batch["labels"].to(device)
            bs = labels.size(0)
            t_batch = torch.tensor(teacher_logits_train[start:start+bs], dtype=torch.float32, device=device)

            inputs = {k: v.to(device) for k, v in batch.items() if k != "labels"}
            out = model(**inputs)
            loss, _, _ = kd_loss(out.logits, labels, t_batch, alpha=alpha, T=T)

            loss.backward()
            opt.step(); sched.step(); opt.zero_grad()
            start += bs
    met, probs = evaluate_model(model, val_loader_student)
    ece = expected_calibration_error(probs, np.array(raw["validation"]["label"]))
    return met["accuracy"], ece

sweep_results = []
for T in TEMP_GRID:
    for a in ALPHA_GRID:
        acc, ece = quick_kd_run(T, a, epochs=1)
        sweep_results.append({"T": T, "alpha": a, "acc": acc, "ece": ece})
        print(f"T={T}, alpha={a} -> acc={acc:.4f}, ece={ece:.4f}")

import pandas as pd
df_sweep = pd.DataFrame(sweep_results)
display(df_sweep)

# Plot accuracy vs T (grouped by alpha)
for a in ALPHA_GRID:
    sub = df_sweep[df_sweep["alpha"]==a]
    plt.figure()
    plt.plot(sub["T"], sub["acc"], marker='o')
    plt.title(f"Accuracy vs Temperature (alpha={a})")
    plt.xlabel("Temperature"); plt.ylabel("Accuracy"); plt.grid(True)
    plt.show()

# Plot ECE vs T (grouped by alpha)
for a in ALPHA_GRID:
    sub = df_sweep[df_sweep["alpha"]==a]
    plt.figure()
    plt.plot(sub["T"], sub["ece"], marker='o')
    plt.title(f"ECE vs Temperature (alpha={a})")
    plt.xlabel("Temperature"); plt.ylabel("ECE"); plt.grid(True)
    plt.show()



## 8. Hard-Labels Only Ablation
We remove the KD term (set \(\alpha=0\)) and train the student purely with CE on the ground-truth labels.


In [None]:

def train_hard_only(epochs=EPOCHS):
    model = AutoModelForSequenceClassification.from_pretrained(STUDENT_CHECKPOINT, num_labels=2).to(device)
    opt = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    # Warmup is computed from this run's own length, matching quick_kd_run.
    run_steps = len(train_loader_student) * epochs
    sch = get_linear_schedule_with_warmup(opt, int(WARMUP_RATIO * run_steps), run_steps)

    model.train()
    for ep in range(epochs):
        pbar = tqdm(train_loader_student, desc=f"Hard-only epoch {ep+1}/{epochs}")
        for batch in pbar:
            labels = batch["labels"].to(device)
            inputs = {k: v.to(device) for k, v in batch.items() if k != "labels"}
            out = model(**inputs)
            loss = criterion_ce(out.logits, labels)
            loss.backward()
            opt.step(); sch.step(); opt.zero_grad()
    met, probs = evaluate_model(model, val_loader_student)
    ece = expected_calibration_error(probs, np.array(raw["validation"]["label"]))
    return model, met["accuracy"], ece

hard_model, hard_acc, hard_ece = train_hard_only(epochs=1)
print(f"Hard-only (1 epoch) -> acc={hard_acc:.4f}, ece={hard_ece:.4f}")



## 9. Report & Discussion
- **Main result:** Compare teacher vs. student (KD) vs. student (hard-only).  
- **Sweeps:** Show how temperature \(T\) and \(\alpha\) affect accuracy and calibration.  
- **Takeaways:** In many settings, KD improves the student's accuracy and calibration over hard-only training, while the student itself is smaller and cheaper at inference than the teacher.

> **Reproducibility tips**: fix seeds, log versions (`transformers`, `datasets`, `torch`), freeze teacher, and persist teacher logits.
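Following the version-logging tip, a small sketch that records installed library versions using only the standard library (`importlib.metadata`, Python 3.8+). The helper name `package_versions` and the package list are assumptions; the result could, for example, be saved as JSON in `OUTPUT_DIR` next to the teacher logits.

```python
from importlib.metadata import PackageNotFoundError, version


def package_versions(names):
    """Return {package: version string}, or 'not installed' for missing packages."""
    out = {}
    for name in names:
        try:
            out[name] = version(name)
        except PackageNotFoundError:
            out[name] = "not installed"
    return out


print(package_versions(["torch", "transformers", "datasets", "evaluate", "numpy"]))
```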



## Appendix: Utilities & Notes
- For exact alignment between teacher logits and student batches, store teacher logits **inside** the HF dataset with a fixed index, or create a custom `Dataset` that yields `(input_ids, attention_mask, label, teacher_logits)` together.
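- For **multi-class** tasks, change `num_labels` and the checkpoints, and replace the binary-specific calibration code: `evaluate_model` returns only the positive-class probability (`[:, 1]`), and `expected_calibration_error` expects that single probability, so multi-class calibration must use the full probability vector. The KD loss itself works unchanged for any number of classes.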
- To speed up sweeps, reduce training epochs or subset the training set.
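The calibration code in Section 6 uses only the positive-class probability, so multi-class tasks need the full probability vector per example. A plain-Python sketch of top-label ECE for \(K\) classes (hypothetical helper `top_label_ece`, made-up probabilities); with two classes it reproduces the binary computation.

```python
def top_label_ece(prob_rows, labels, n_bins=15):
    """Top-label ECE for K-class predictions (one probability list per example)."""
    n = len(labels)
    bins = [[] for _ in range(n_bins)]  # each entry: (confidence, is_correct)
    for row, y in zip(prob_rows, labels):
        pred = max(range(len(row)), key=row.__getitem__)  # argmax class
        conf = row[pred]
        b = min(int(conf * n_bins), n_bins - 1)  # clamp conf == 1.0 into the last bin
        bins[b].append((conf, pred == y))
    ece = 0.0
    for members in bins:
        if members:
            avg_conf = sum(c for c, _ in members) / len(members)
            acc = sum(ok for _, ok in members) / len(members)
            ece += len(members) / n * abs(acc - avg_conf)
    return ece


# Made-up 3-class example.
rows = [[0.75, 0.15, 0.10], [0.20, 0.15, 0.65], [0.03, 0.95, 0.02]]
print(top_label_ece(rows, labels=[0, 1, 1], n_bins=10))
```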
