# **Knowledge Distillation for deep learning models**


*   Teacher (ResNet50) -> Pretrained weights and then fine tuned to CIFAR10 dataset.
*   Student (ResNet18)



In [1]:
#loading the libraries
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision import models
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10

# **Setting up the GPU**

In [2]:
print(f"PyTorch Version: {torch.__version__}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

PyTorch Version: 2.6.0+cu124
Device: cuda


# **Dataset preparation**

In [3]:
# Convert images to tensors and normalize them using pre-calculated CIFAR-10 mean and std values
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                         std=[0.2023, 0.1994, 0.2010])
])

# load full CIFAR-10 train set
train_set = CIFAR10(root='./data', train=True, download=True, transform=transform)
test_set = CIFAR10(root='./data', train=False, download=True, transform=transform)

# spliting sizes for train and validation sets
total_size = len(train_set)
train_size = int(0.9 * total_size)
val_size = total_size - train_size

# perform split
train_subset, val_subset = random_split(train_set, [train_size, val_size])
print(f"Train samples: {train_size}")
print(f"Validation samples: {val_size}")
print(f"Test samples: {len(test_set)}")

# create DataLoaders
BATCH_SIZE = 128
train_loader = DataLoader(train_subset, batch_size= BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_subset, batch_size= BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_set, batch_size= BATCH_SIZE, shuffle=False)


Train samples: 45000
Validation samples: 5000
Test samples: 10000


In [4]:
from model_helpers import setup_models, extract_teacher_features
from utils import count_params, measure_latency, evaluate_accuracy
from distill_loss import distillation_loss, student_training_step
from teacher_core import train_teacher
from student_core import train_student
import os

PyTorch Version: 2.6.0+cu124
Device: cuda
PyTorch Version: 2.6.0+cu124
Device: cuda
PyTorch Version: 2.6.0+cu124
Device: cuda
PyTorch Version: 2.6.0+cu124
Device: cuda
PyTorch Version: 2.6.0+cu124
Device: cuda


In [5]:
# Setup models first (this returns teacher and student_wrapper)
teacher, student_wrapper = setup_models(device)

# **Teacher model**

In [6]:
# train the teacher on CIFAR-10
teacher = train_teacher(
    teacher=teacher,
    loader= train_loader,
    val_loader= val_loader,
    epochs=20,
    tag="Fine-tuning teacher",
    save_path="tuned_pretrained_resnet50_on_CIFAR10.pth"
)


(Fine-tuning teacher)	Epoch 1: loss=0.8133, Accuracy (validation): 79.80%
(Fine-tuning teacher)	Epoch 2: loss=0.4484, Accuracy (validation): 83.42%
(Fine-tuning teacher)	Epoch 3: loss=0.5439, Accuracy (validation): 84.14%
(Fine-tuning teacher)	Epoch 4: loss=0.3513, Accuracy (validation): 83.96%
(Fine-tuning teacher)	Epoch 5: loss=0.2241, Accuracy (validation): 84.12%
(Fine-tuning teacher)	Epoch 6: loss=0.1254, Accuracy (validation): 82.96%
(Fine-tuning teacher)	Epoch 7: loss=0.2714, Accuracy (validation): 85.48%
(Fine-tuning teacher)	Epoch 8: loss=0.1182, Accuracy (validation): 84.60%
(Fine-tuning teacher)	Epoch 9: loss=0.0782, Accuracy (validation): 85.62%
(Fine-tuning teacher)	Epoch 10: loss=0.1112, Accuracy (validation): 85.40%
(Fine-tuning teacher)	Epoch 11: loss=0.0757, Accuracy (validation): 84.94%
(Fine-tuning teacher)	Epoch 12: loss=0.1008, Accuracy (validation): 85.22%
(Fine-tuning teacher)	Epoch 13: loss=0.1285, Accuracy (validation): 85.66%
(Fine-tuning teacher)	Epoch 14: lo

**Evaluation**

In [7]:
#evaluate the teacher model
teacher_params = count_params(teacher)
teacher_latency = measure_latency(teacher, device=device)
teacher_acc = evaluate_accuracy(teacher, test_loader)

#print the results

print(f"Teacher Params: {teacher_params / 1e6:.2f}M")
print(f"Teacher Latency: {teacher_latency:.2f} ms")
print(f"Teacher Test Accuracy: {teacher_acc * 100:.2f}%")

Teacher Params: 23.53M
Teacher Latency: 8.72 ms
Teacher Test Accuracy: 85.11%


## **Student Model via Knowledge Distillation**

In [8]:
# train student training with fixed Temperature and Alpha
student_fixed = train_student(teacher, student_wrapper, train_loader, val_loader, 20, device, fixed_T=5.0, fixed_alpha=0.7, save_path="student_fixed.pth")


[(Training student)	Epoch 1] Loss = 10.5607 | Val Acc = 62.42%
New best model saved.
[(Training student)	Epoch 2] Loss = 6.6346 | Val Acc = 67.36%
New best model saved.
[(Training student)	Epoch 3] Loss = 4.9954 | Val Acc = 71.88%
New best model saved.
[(Training student)	Epoch 4] Loss = 3.9482 | Val Acc = 71.78%
[(Training student)	Epoch 5] Loss = 3.1448 | Val Acc = 74.40%
New best model saved.
[(Training student)	Epoch 6] Loss = 2.4659 | Val Acc = 74.60%
New best model saved.
[(Training student)	Epoch 7] Loss = 2.0139 | Val Acc = 74.32%
[(Training student)	Epoch 8] Loss = 1.6365 | Val Acc = 75.86%
New best model saved.
[(Training student)	Epoch 9] Loss = 1.3252 | Val Acc = 75.44%
[(Training student)	Epoch 10] Loss = 1.0997 | Val Acc = 76.12%
New best model saved.
[(Training student)	Epoch 11] Loss = 0.9982 | Val Acc = 74.68%
[(Training student)	Epoch 12] Loss = 0.9988 | Val Acc = 75.44%
[(Training student)	Epoch 13] Loss = 0.9205 | Val Acc = 76.50%
New best model saved.
[(Training st

In [9]:
# evaluate size, latency, and accuracy
student_fixed_params = count_params(student_fixed)
student_fixed_latency = measure_latency(student_fixed, device=device)
student_fixed_acc = evaluate_accuracy(student_fixed, test_loader)


print(f"Student Fixed Params: {student_fixed_params / 1e6:.2f}M")
print(f"Student Fixed Latency: {student_fixed_latency:.2f} ms")
print(f"Student Fixed Test Accuracy: {student_fixed_acc * 100:.2f}%")


Student Fixed Params: 11.18M
Student Fixed Latency: 3.93 ms
Student Fixed Test Accuracy: 76.83%


# **Hyperparameter tuning**

**Fixed Alpha and Temperature:**

Alpha (α)
*   High (e.g., 0.8): Focuses on soft labels — useful when the teacher is reliable.

*   Low (e.g., 0.2): Emphasizes hard labels — better if the teacher isn’t perfect.

Temperature (T)
*  Low (1–2): Sharp outputs — can be too strict.

* High (3–5): Softer outputs — easier for student to learn patterns.


---


**Temperature & Alpha Scheduling:**
* Used exponential decay (0.95) for both:

Temperature (0.5 → 0.3):
* Starts high to soften the teacher's outputs, allowing the student to absorb nuanced, generalized knowledge. As it decays, the teacher’s predictions become sharper, helping the student refine its learning and align with confident teacher outputs.

Alpha (0.8 → 0.5):
* Initially gives more weight to the teacher’s soft targets, reinforcing transfer of knowledge. As training progresses, lower alpha encourages the student to rely more on hard labels, improving its own discriminative ability and generalization.



In [6]:
#T value between (5.0, 3.0) and aplha value between (0.8, 0.5 )
student_schedule = train_student(teacher, student_wrapper, train_loader, val_loader, 20, device,  save_path="student_schedule.pth")

[(Training student)	Epoch 1] Loss = 1.4867 | Val Acc = 51.22%
New best model saved.
[(Training student)	Epoch 2] Loss = 1.3632 | Val Acc = 52.94%
New best model saved.
[(Training student)	Epoch 3] Loss = 1.3385 | Val Acc = 66.78%
New best model saved.
[(Training student)	Epoch 4] Loss = 1.3185 | Val Acc = 70.32%
New best model saved.
[(Training student)	Epoch 5] Loss = 1.3099 | Val Acc = 72.86%
New best model saved.
[(Training student)	Epoch 6] Loss = 1.2938 | Val Acc = 72.74%
[(Training student)	Epoch 7] Loss = 1.2786 | Val Acc = 76.02%
New best model saved.
[(Training student)	Epoch 8] Loss = 1.2127 | Val Acc = 77.48%
New best model saved.
[(Training student)	Epoch 9] Loss = 1.1863 | Val Acc = 77.36%
[(Training student)	Epoch 10] Loss = 1.1775 | Val Acc = 76.28%
[(Training student)	Epoch 11] Loss = 1.1603 | Val Acc = 77.64%
New best model saved.
[(Training student)	Epoch 12] Loss = 1.1467 | Val Acc = 76.02%
[(Training student)	Epoch 13] Loss = 1.1414 | Val Acc = 77.98%
New best model

In [7]:
# evaluate size, latency, and accuracy
student_schedule_params = count_params(student_schedule)
student_schedule_latency = measure_latency(student_schedule, device=device)
student_schedule_acc = evaluate_accuracy(student_schedule, test_loader)


print(f"Student_schedule Params: {student_schedule_params / 1e6:.2f}M")
print(f"Student_schedule Latency: {student_schedule_latency:.2f} ms")
print(f"Student_schedule Accuracy: {student_schedule_acc * 100:.2f}%")

Student Fixed Params: 11.18M
Student Fixed Latency: 4.10 ms
Student Fixed Test Accuracy: 78.60%


# **Student Model without Knowledge Distillation**

In [9]:
# define baseline student: ResNet18 training from scratch on its own, re-headed for CIFAR-10
baseline_student = models.resnet18(weights=None)
baseline_student.fc = nn.Linear(512, 10).to(device)
baseline_student = baseline_student.to(device)

# Train the baseline student on CIFAR-10
baseline_student = train_teacher(baseline_student, train_loader, val_loader, epochs=20, tag="baseline-student", save_path="baseline_student.pth")

# Evaluate baseline student
baseline_student_acc = evaluate_accuracy(baseline_student, test_loader)
print(f"\nBaseline Student Test Accuracy: {baseline_student_acc * 100:.2f}%")

Model already trained. Loading from baseline_student.pth

Baseline Student Test Accuracy: 75.06%


In [14]:
# evaluate size, latency, and accuracy
baseline_student_params = count_params(baseline_student)
baseline_student_latency = measure_latency(baseline_student, device=device)
baseline_student_acc = evaluate_accuracy(baseline_student, test_loader)


print(f"baseline_student_Params: {baseline_student_params / 1e6:.2f}M")
print(f"baseline_student Latency: {baseline_student_latency:.2f} ms")
print(f"baseline_student Test Accuracy: {baseline_student_acc * 100:.2f}%")

baseline_student_Params: 11.18M
baseline_student Latency: 4.72 ms
baseline_student Test Accuracy: 75.06%


# **Comparision Table**

In [28]:
import pandas as pd
from IPython.display import display

data = {
    "Model": ["Teacher", "Student_Schedule (distilled best)", "student_fixed (distilled)", "Baseline Student"],
    "Accuracy (%)": [85.11,78.60,76.80,75.06],
    "Parameters (M)": [23.53,11.18,11.18,11.18],
    "Latency (ms)": [8.72,3.9,4.1,4.72],
}

df = pd.DataFrame(data)

# Function to color rows based on model name
def highlight_models(row):
    color = ''
    text_color = 'color: black;'
    if row['Model'] == 'Teacher':
        color = 'background-color: #cc4c4c; ' + text_color
    elif row['Model'] == 'Student_Schedule (distilled best)':
        color = 'background-color: #4caf50; ' + text_color
    elif row['Model'] == 'Baseline Student':
        color = 'background-color: #82b4e3; ' + text_color
    return [color] * len(row)

# Display the styled dataframe
styled_df = df.style.apply(highlight_models, axis=1).set_caption("Model Comparison Table")
display(styled_df)


Unnamed: 0,Model,Accuracy (%),Parameters (M),Latency (ms)
0,Teacher,85.11,23.53,8.72
1,Student_Schedule (distilled best),78.6,11.18,3.9
2,student_fixed (distilled),76.8,11.18,4.1
3,Baseline Student,75.06,11.18,4.72


# **Drafts :** Experimented with tuning T and Alpha to find the best student model

In [7]:
# train student training with fixed Temperature 5.0 and Alpha 0.8
student_fixed1 = train_student(teacher, student_wrapper, train_loader, val_loader, 20, device, fixed_T=5.0, fixed_alpha=0.8, save_path="student_fixed1.pth")

[(Training student)	Epoch 1] Loss = 1.4536 | Val Acc = 52.54%
New best model saved.
[(Training student)	Epoch 2] Loss = 1.3312 | Val Acc = 63.24%
New best model saved.
[(Training student)	Epoch 3] Loss = 1.3026 | Val Acc = 69.54%
New best model saved.
[(Training student)	Epoch 4] Loss = 1.2886 | Val Acc = 71.78%
New best model saved.
[(Training student)	Epoch 5] Loss = 1.2781 | Val Acc = 72.84%
New best model saved.
[(Training student)	Epoch 6] Loss = 1.2663 | Val Acc = 74.38%
New best model saved.
[(Training student)	Epoch 7] Loss = 1.2534 | Val Acc = 71.00%
[(Training student)	Epoch 8] Loss = 1.1908 | Val Acc = 76.52%
New best model saved.
[(Training student)	Epoch 9] Loss = 1.1655 | Val Acc = 78.98%
New best model saved.
[(Training student)	Epoch 10] Loss = 1.1553 | Val Acc = 77.50%
[(Training student)	Epoch 11] Loss = 1.1427 | Val Acc = 78.92%
[(Training student)	Epoch 12] Loss = 1.1260 | Val Acc = 78.06%
[(Training student)	Epoch 13] Loss = 1.1115 | Val Acc = 78.68%
[(Training stu

In [None]:
#train the student with Temperature and Alpha scheduling
student = train_student(teacher, student_wrapper, train_loader, val_loader, 20, device)

[(Training student)	Epoch 1] Loss = 1.6438 | Val Acc = 52.04%
New best model saved.
[(Training student)	Epoch 2] Loss = 1.5064 | Val Acc = 64.16%
New best model saved.
[(Training student)	Epoch 3] Loss = 1.4677 | Val Acc = 67.76%
New best model saved.
[(Training student)	Epoch 4] Loss = 1.4380 | Val Acc = 69.50%
New best model saved.
[(Training student)	Epoch 5] Loss = 1.4155 | Val Acc = 72.76%
New best model saved.
[(Training student)	Epoch 6] Loss = 1.3811 | Val Acc = 74.28%
New best model saved.
[(Training student)	Epoch 7] Loss = 1.3503 | Val Acc = 71.84%
[(Training student)	Epoch 8] Loss = 1.3178 | Val Acc = 76.32%
New best model saved.
[(Training student)	Epoch 9] Loss = 1.2723 | Val Acc = 75.42%
[(Training student)	Epoch 10] Loss = 1.2424 | Val Acc = 75.46%
[(Training student)	Epoch 11] Loss = 1.2093 | Val Acc = 76.66%
New best model saved.
[(Training student)	Epoch 12] Loss = 1.1779 | Val Acc = 76.98%
New best model saved.
[(Training student)	Epoch 13] Loss = 1.1598 | Val Acc =

In [None]:
student_schedule1 = train_student(teacher, student_wrapper, train_loader, val_loader, 20, device,  save_path="student_schedule1.pth")

[(Training student)	Epoch 1] Loss = 1.3132 | Val Acc = 50.66%
New best model saved.
[(Training student)	Epoch 2] Loss = 1.2714 | Val Acc = 66.18%
New best model saved.
[(Training student)	Epoch 3] Loss = 1.2601 | Val Acc = 70.16%
New best model saved.
[(Training student)	Epoch 4] Loss = 1.2461 | Val Acc = 73.44%
New best model saved.
[(Training student)	Epoch 5] Loss = 1.2384 | Val Acc = 74.96%
New best model saved.
[(Training student)	Epoch 6] Loss = 1.2265 | Val Acc = 74.38%
[(Training student)	Epoch 7] Loss = 1.2136 | Val Acc = 77.74%
New best model saved.
[(Training student)	Epoch 8] Loss = 1.1478 | Val Acc = 79.30%
New best model saved.
[(Training student)	Epoch 9] Loss = 1.1313 | Val Acc = 79.08%
[(Training student)	Epoch 10] Loss = 1.1226 | Val Acc = 78.80%
[(Training student)	Epoch 11] Loss = 1.1122 | Val Acc = 79.36%
New best model saved.
[(Training student)	Epoch 12] Loss = 1.1033 | Val Acc = 79.26%
[(Training student)	Epoch 13] Loss = 1.0993 | Val Acc = 78.66%
[(Training stu