# KD vs RKD vs QRKD (mini MNIST)

This walkthrough trains small teacher/student CNNs on a **10k subset of MNIST** for a few epochs to contrast classical distillation flavors:

- **KD**: match teacher logits (soft targets).
- **RKD**: match relational geometry (pairwise distances/angles) of teacher features.
- **QRKD**: extend RKD with a quantum-inspired fidelity kernel on normalized features (see `QRKD.txt` and the paper *Quantum Relational Knowledge Distillation*, arXiv:2508.13054, for the conceptual background: map features to a Hilbert space and align the quantum kernels $k(x_i, x_j) = |\langle \phi(x_i) | \phi(x_j) \rangle|^2$ between teacher and student to transfer richer relations).

Everything runs classically here: the "quantum" part is only a fidelity-kernel regularizer over feature vectors (the QRKD run computes it with the `merlin` backend). Runs are short (3 epochs) to allow a quick comparison.

In [None]:
# Add project to path when running from repo root
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

# Put ./QRKD at the front of sys.path so the `lib.*` imports below resolve when the
# notebook is started from the repository root; the membership check keeps re-runs
# of this cell from stacking duplicate entries.
ROOT = Path(".").resolve()
_qrkd_dir = str(ROOT / "QRKD")
if _qrkd_dir not in sys.path:
    sys.path.insert(0, _qrkd_dir)

In [None]:
from lib.losses import DistillationLoss
from lib.models import StudentCNN, TeacherCNN
from lib.train import TrainConfig, train_student, train_teacher

In [None]:
# Hyperparameters for the quick demo (kept small so each run finishes quickly)
SEED = 1337
SUBSET_SIZE = 10_000  # number of images drawn from the MNIST train split
BATCH_SIZE = 128
EPOCHS = 3
LR = 1e-3

_use_cuda = torch.cuda.is_available()
DEVICE = torch.device("cuda") if _use_cuda else torch.device("cpu")

# Seed the global torch RNG and force deterministic cuDNN kernel selection so that
# repeated runs of the notebook produce comparable numbers.
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

DEVICE  # echo the selected device as the cell output

In [None]:
# Data: SUBSET_SIZE-image MNIST training subset, full 10k test split for evaluation
tfm = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

train_full = datasets.MNIST("data", train=True, download=True, transform=tfm)
# Draw the subset with its own SEED-initialized generator: re-running this cell on
# its own (after other cells have consumed the global RNG) still selects the same
# images. `.tolist()` hands Subset plain ints instead of 0-d tensors.
subset_gen = torch.Generator().manual_seed(SEED)
indices = torch.randperm(len(train_full), generator=subset_gen)[:SUBSET_SIZE].tolist()
train_subset = Subset(train_full, indices)
test_set = datasets.MNIST("data", train=False, download=True, transform=tfm)

train_loader = DataLoader(
    train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0
)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

len(train_subset), len(test_set)

In [None]:
# Helper: train teacher
def run_teacher():
    """Create a fresh TeacherCNN on DEVICE and train it with the notebook settings.

    ``test_loader`` is passed through to ``train_teacher`` so the returned history
    can report test accuracy alongside training accuracy.

    Returns:
        The ``(teacher, history)`` pair produced by ``train_teacher``.
    """
    model = TeacherCNN().to(DEVICE)
    config = TrainConfig(epochs=EPOCHS, lr=LR, device=str(DEVICE), verbose=True)
    trained_model, history = train_teacher(model, train_loader, config, test_loader)
    return trained_model, history


teacher, hist_teacher = run_teacher()
hist_teacher

In [None]:
# Student variants
def run_student(
    student_name: str,
    weights: DistillationLoss,
    teacher_model=None,
    seed=SEED,
):
    """Train one StudentCNN variant and return its history and final test accuracy.

    Args:
        student_name: Run identifier forwarded to ``train_student``.
        weights: ``DistillationLoss`` carrying the loss-term weights for this variant.
        teacher_model: Trained teacher used by the distillation terms; ``None`` for
            the from-scratch baseline.
        seed: Value used to reseed the global torch RNG right before the student is
            built. Every variant therefore starts from identical initial weights, and
            usually also from the same batch order, since ``train_loader`` shuffles
            with the global RNG. This keeps the KD / RKD / QRKD comparison controlled.
            Pass ``None`` to skip reseeding.

    Returns:
        ``(result["history"], result["test_acc"])`` from ``train_student``.
    """
    if seed is not None:
        torch.manual_seed(seed)
    student = StudentCNN().to(DEVICE)
    cfg = TrainConfig(epochs=EPOCHS, lr=LR, device=str(DEVICE), verbose=True)
    result = train_student(
        student,
        teacher_model,
        train_loader,
        test_loader,
        cfg,
        weights,
        student_name=student_name,
    )
    return result["history"], result["test_acc"]


histories = {}
test_accs = {}

# Scratch: every distillation weight is zero and no teacher is supplied
scratch_hist, scratch_acc = run_student(
    "student_scratch",
    DistillationLoss(kd=0.0, dr=0.0, ar=0.0, qk=0.0),
    teacher_model=None,
)
histories["Scratch"] = scratch_hist
test_accs["Scratch"] = scratch_acc

# KD: only the soft-target (kd) term is enabled
kd_weights = DistillationLoss(
    kd=0.5, dr=0.0, ar=0.0, qk=0.0, temperature=4.0, kd_alpha=0.5
)
kd_hist, kd_acc = run_student("student_kd", kd_weights, teacher_model=teacher)
histories["KD"] = kd_hist
test_accs["KD"] = kd_acc

# RKD: only the relational terms (dr, ar) are enabled
rkd_weights = DistillationLoss(
    kd=0.0, dr=0.1, ar=0.1, qk=0.0, temperature=4.0, kd_alpha=0.5
)
rkd_hist, rkd_acc = run_student("student_rkd", rkd_weights, teacher_model=teacher)
histories["RKD"] = rkd_hist
test_accs["RKD"] = rkd_acc

# QRKD: kd + relational terms + the fidelity-kernel (qk) term
qrkd_weights = DistillationLoss(
    kd=0.5,
    dr=0.1,
    ar=0.1,
    qk=0.1,
    qk_backend="merlin",  # Merlin fidelity-kernel backend for the "quantum" relational term
    qk_n_modes=10,
    qk_n_photons=5,
    temperature=4.0,
    kd_alpha=0.5,
)
qrkd_hist, qrkd_acc = run_student("student_qrkd", qrkd_weights, teacher_model=teacher)
histories["QRKD"] = qrkd_hist
test_accs["QRKD"] = qrkd_acc

test_accs

In [None]:
# Summarize final accuracy: last recorded train/test accuracy per model
from pprint import pprint

summary = {
    "Teacher train": hist_teacher["train_acc"][-1],
    "Teacher test": hist_teacher["test_acc"][-1],
    **{
        f"{variant} {split}": histories[variant][f"{split}_acc"][-1]
        for variant in ("Scratch", "KD", "RKD", "QRKD")
        for split in ("train", "test")
    },
}

pprint(summary)

In [None]:
# Plot accuracy curves: one color per model, solid line = train, dashed line = test
plt.figure(figsize=(8, 5))

curves = {"Teacher": hist_teacher}
curves.update({name: histories[name] for name in ["Scratch", "KD", "RKD", "QRKD"]})
colors = ["black", "C0", "C1", "C2", "C3"]

for (name, hist), color in zip(curves.items(), colors):
    for split, marker, linestyle in (("train", "o", "-"), ("test", "x", "--")):
        values = hist[f"{split}_acc"]
        # Build the x-axis from the series itself rather than EPOCHS, so a history
        # with more or fewer points than EPOCHS still plots instead of raising.
        epochs = range(1, len(values) + 1)
        plt.plot(
            epochs,
            values,
            label=f"{name} {split}",
            color=color,
            marker=marker,
            linestyle=linestyle,
        )

plt.xlabel("Epoch")
plt.ylabel("Accuracy (%)")
plt.title("KD vs RKD vs QRKD on 10k MNIST subset")
plt.legend(ncol=2, fontsize="small")  # 10 entries: two columns keep it compact
plt.grid(True)
plt.tight_layout()
plt.show()