# Knowledge distillation
Based on the paper [Distilling the Knowledge in a Neural Network](https://arxiv.org/abs/1503.02531) by Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. Made in a rush, so please excuse any brevity.

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize

In [None]:
# Load MNIST and rescale pixel intensities from [0, 255] to [0, 1] (float32)
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train, x_test = (split.astype("float32") / 255.0 for split in (x_train, x_test))

# Append a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1)
x_train = x_train[..., np.newaxis]
x_test = x_test[..., np.newaxis]

In [None]:
# Show the first test digit together with its ground-truth label
sample_image, sample_label = x_test[0], y_test[0]

plt.imshow(sample_image, cmap="gray")
plt.title(f"Label: {sample_label}")
plt.axis("off")
plt.show()

print("Label:", sample_label)

In [None]:
# The teacher and student below are fully connected (Dense) networks, so each
# 28x28x1 image is flattened into a 784-dimensional vector.
x_train = x_train.reshape(-1, 28 * 28)
x_test = x_test.reshape(-1, 28 * 28)

In [None]:
# Early-stopping callback reused by every fit() below: stop once the validation
# loss has gone 10 epochs without improving, then roll back to the best weights.
from tensorflow.keras.callbacks import EarlyStopping

early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
#Create teacher
def create_teacher():
    """Build the large teacher network.

    Two fully connected 1200-unit ReLU layers, each followed by 50% dropout,
    on top of a flattened 784-dimensional input. The final Dense(10) layer has
    no activation, so the model returns logits rather than probabilities.
    """
    hidden_stack = []
    for _ in range(2):
        hidden_stack += [layers.Dense(1200, activation='relu'), layers.Dropout(0.5)]
    return keras.Sequential([layers.Input(shape=(784,)), *hidden_stack, layers.Dense(10)])

# Train the teacher on hard labels. It outputs logits, hence from_logits=True.
teacher = create_teacher()
teacher.compile(
    optimizer='adam',
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)
teacher.fit(
    x_train,
    y_train,
    epochs=50,
    batch_size=64,
    validation_split=0.1,
    callbacks=[early_stopping],
)
teacher.evaluate(x_test, y_test)

In [None]:
# Create Hinton-style student (1 hidden layer, 300 units)
from tensorflow import keras
from tensorflow.keras import layers

# Hinton-style student: a single 300-unit hidden layer
def create_student():
    """Build the small student network.

    One fully connected 300-unit ReLU layer (no dropout) on a flattened
    784-dimensional input. The final Dense(10) layer returns raw logits.
    """
    return keras.Sequential(
        [
            layers.Input(shape=(784,)),
            layers.Dense(300, activation='relu'),
            layers.Dense(10),
        ]
    )


student = create_student()


### Knowledge Distillation with KL Divergence

In distillation, we train a **student model** to mimic the **softened output** of a **teacher model**.  
We use **KL Divergence** to measure how much the student’s predicted distribution diverges from the teacher’s.

$$
\text{KL}(P || Q) = \sum_i P(i) \log\left(\frac{P(i)}{Q(i)}\right)
$$

Where:
- $P$: the teacher's softmax output (soft labels)
- $Q$: the student's softmax output

> KL divergence encourages the student to match the **class probabilities**, not just the hard label.

![KL Divergence Intuition](https://upload.wikimedia.org/wikipedia/commons/thumb/8/8e/Kullback%E2%80%93Leibler_distributions_example_1.svg/2560px-Kullback%E2%80%93Leibler_distributions_example_1.svg.png)

See [here](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence)


## Softmax Temperature in Knowledge Distillation

This plot shows how the temperature parameter $T$ affects the softmax output distribution in knowledge distillation:

![Softmax Temperature Scaling](https://miro.medium.com/v2/resize:fit:1400/format:webp/0*7xj72SjtNHvCMQlV.jpeg)

- **T = 1**: Standard softmax with sharp, confident predictions.
- **T > 1**: Softens the probability distribution, revealing more information about less likely classes.
- **T → ∞**: Produces a uniform distribution over classes.

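Higher temperature values help the student model learn better from the teacher's softened outputs by capturing relative similarities between classes. Because the gradients produced by the soft targets scale as $1/T^2$, Hinton et al. multiply the soft-target loss by $T^2$ when combining it with the hard-label loss. This keeps the two terms on a comparable scale.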

In [None]:
# Custom model class for knowledge distillation
# This wraps a student and teacher model and defines training logic
class Distiller(keras.Model):
    """Train a student network to imitate a frozen, pre-trained teacher.

    Objective (Hinton, Vinyals & Dean, 2015)::

        total = alpha * hard_loss(y, student_logits)
                + (1 - alpha) * T**2 * soft_loss(softmax(teacher_logits / T),
                                                softmax(student_logits / T))

    The ``T**2`` factor follows the paper. Gradients of the soft-target term
    scale as ``1 / T**2``, so without it a large temperature would make the
    soft targets almost irrelevant next to the hard-label term.

    Only the student's variables are updated. The teacher is always run with
    ``training=False``. Calling the distiller runs the student alone.
    """

    def __init__(self, student, teacher):
        super().__init__()  # Initialize base keras.Model
        self.student = student  # Student model to be trained
        self.teacher = teacher  # Pre-trained teacher model (never updated)
        # Running means, so logged losses are averages over the whole epoch /
        # evaluation run instead of the value of the last batch only.
        self.student_loss_tracker = keras.metrics.Mean(name="student_loss")
        self.distill_loss_tracker = keras.metrics.Mean(name="distill_loss")
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")

    @property
    def metrics(self):
        # Listing the trackers here lets Keras reset them at the start of each
        # epoch and each evaluate() call. The compiled metrics follow them.
        trackers = [
            self.student_loss_tracker,
            self.distill_loss_tracker,
            self.total_loss_tracker,
        ]
        others = [m for m in super().metrics if not any(m is t for t in trackers)]
        return trackers + others

    def compile(self, optimizer, metrics,
                student_loss_fn, distillation_loss_fn,
                alpha=0.1, temperature=10):
        """Configure the optimizer, metrics and distillation objective.

        Args:
            optimizer: Optimizer applied to the student's variables.
            metrics: Metrics updated with ``(y, student_logits)``.
            student_loss_fn: Hard-label loss, called as ``fn(y, student_logits)``.
            distillation_loss_fn: Soft-target loss, called as
                ``fn(teacher_probs, student_probs)`` on temperature-softened
                probabilities.
            alpha: Weight of the hard-label loss. The soft-target loss is
                weighted by ``1 - alpha``.
            temperature: Softmax temperature T applied to both networks' logits.
        """
        # No Keras loss is compiled: the objective is computed in train_step /
        # test_step from the functions stored below.
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def _distillation_loss(self, teacher_logits, student_logits):
        """Return the T**2-scaled soft-target loss between teacher and student."""
        t = self.temperature
        soft_teacher = tf.nn.softmax(teacher_logits / t, axis=1)
        soft_student = tf.nn.softmax(student_logits / t, axis=1)
        return self.distillation_loss_fn(soft_teacher, soft_student) * (t ** 2)

    def _compute_losses(self, y, teacher_logits, student_logits):
        """Return (hard-label loss, soft-target loss, alpha-weighted total)."""
        student_loss = self.student_loss_fn(y, student_logits)
        distill_loss = self._distillation_loss(teacher_logits, student_logits)
        total = self.alpha * student_loss + (1 - self.alpha) * distill_loss
        return student_loss, distill_loss, total

    def _update_metrics(self, y, student_logits, student_loss, distill_loss, total):
        """Record one batch and return the running value of every metric by name."""
        self.student_loss_tracker.update_state(student_loss)
        self.distill_loss_tracker.update_state(distill_loss)
        self.total_loss_tracker.update_state(total)
        self.compiled_metrics.update_state(y, student_logits)
        return {m.name: m.result() for m in self.metrics}

    def train_step(self, data):
        """One optimisation step on the student. Logs student/distill/total losses."""
        x, y = data  # Unpack the data (features and labels)

        # Teacher is frozen: inference mode (dropout off), and computed outside
        # the tape so none of its operations are recorded for differentiation.
        teacher_logits = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            student_logits = self.student(x, training=True)
            student_loss, distill_loss, loss = self._compute_losses(
                y, teacher_logits, student_logits
            )

        # Differentiate w.r.t. the student only; the teacher is never updated.
        grads = tape.gradient(loss, self.student.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.student.trainable_variables))

        return self._update_metrics(y, student_logits, student_loss, distill_loss, loss)

    def test_step(self, data):
        """Evaluate the student. ``loss`` is the student's hard-label loss.

        compile() receives no Keras loss, so without this override the
        validation ``loss`` (seen by callbacks as ``val_loss``) would not
        measure the student at all.
        """
        x, y = data
        teacher_logits = self.teacher(x, training=False)
        student_logits = self.student(x, training=False)
        student_loss, distill_loss, loss = self._compute_losses(
            y, teacher_logits, student_logits
        )
        results = self._update_metrics(y, student_logits, student_loss, distill_loss, loss)
        # Report the hard-label loss as "loss": it is what the student is
        # ultimately judged on, and what early stopping on val_loss should track.
        results["loss"] = self.student_loss_tracker.result()
        return results

    def call(self, inputs, training=False, mask=None):
        """Forward pass of the distiller: the student's forward pass only."""
        return self.student(inputs, training=training)


In [None]:
# Pair the fresh student with the trained teacher
distiller = Distiller(student=student, teacher=teacher)

# Hard-label loss on logits, KL divergence between the temperature-softened
# teacher and student distributions. alpha weights the hard-label term and
# (1 - alpha) the soft-target term.
distiller.compile(
    optimizer=tf.keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=10,
)

In [None]:
# Distil for up to 50 epochs (early stopping may end sooner), then score on the test set
history = distiller.fit(
    x_train,
    y_train,
    epochs=50,
    batch_size=64,
    validation_split=0.1,
    callbacks=[early_stopping],
)
distiller.evaluate(x_test, y_test)

In [None]:
import matplotlib.pyplot as plt

# Per-epoch loss curves recorded in the distillation history
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.plot(history.history['student_loss'], label='Student Loss')
plt.plot(history.history['distill_loss'], label='Distillation Loss')
plt.plot(history.history['total_loss'], label='Combined Loss')
plt.legend()
plt.show()

In [None]:
# The distiller trained `student` in place, but `student` itself was never
# compiled. Give it a loss and metric so evaluate() can run on it directly
# (compiling does not reset the learned weights).
student.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer="adam",
    metrics=["accuracy"],
)

student.evaluate(x_test, y_test)

In [None]:
# Baseline for comparison: an identical student architecture trained from
# scratch on the hard labels only, with the same training settings.
student_scratch = create_student()
student_scratch.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)

# Train from scratch, then score on the test set like the teacher and the
# distilled student so the three accuracies can be compared directly.
student_scratch.fit(x_train, y_train, epochs=50, batch_size=64, validation_split=0.1, callbacks=[early_stopping])
student_scratch.evaluate(x_test, y_test)

In [None]:
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt
import numpy as np

# Optional HEP-style plot theme (CMS). mplhep is an extra dependency, so fall
# back to the default matplotlib style instead of failing the whole cell.
try:
    import mplhep
    mplhep.style.use("CMS")
except ImportError:
    print("mplhep not installed (run `%pip install mplhep`); using the default matplotlib style.")

# Binarize the test labels (one-hot) so each class can be scored one-vs-rest
y_test_bin = label_binarize(y_test, classes=range(10))

# Select the digit class to analyze (class 8)
target_class = 8
y_true_target = y_test_bin[:, target_class]

# Class probabilities from each model (all three output logits)
teacher_probs = tf.nn.softmax(teacher(x_test), axis=1).numpy()
student_probs = tf.nn.softmax(student(x_test), axis=1).numpy()
student_scratch_probs = tf.nn.softmax(student_scratch(x_test), axis=1).numpy()

# ROC curve and AUC for the target class:
# teacher (_t), distilled student (_d), student trained from scratch (_s)
fpr_t, tpr_t, _ = roc_curve(y_true_target, teacher_probs[:, target_class])
auc_t = auc(fpr_t, tpr_t)

fpr_d, tpr_d, _ = roc_curve(y_true_target, student_probs[:, target_class])
auc_d = auc(fpr_d, tpr_d)

fpr_s, tpr_s, _ = roc_curve(y_true_target, student_scratch_probs[:, target_class])
auc_s = auc(fpr_s, tpr_s)

# Random-guess reference (TPR = FPR). A two-point segment from (0, 0) to
# (1, 1) does not render as that diagonal on a log-scaled x axis, because
# x = 0 cannot be placed on a log scale. Instead, sample the diagonal on a
# geometric grid starting at the smallest positive FPR among the curves.
min_positive_fpr = min(fpr[fpr > 0].min() for fpr in (fpr_t, fpr_d, fpr_s))
chance_fpr = np.geomspace(min_positive_fpr, 1.0, 200)

# Plot
plt.figure(figsize=(8, 6))
plt.semilogx(fpr_t, tpr_t, label=f"Teacher (AUC = {auc_t:.4f})", linewidth=2)
plt.semilogx(fpr_d, tpr_d, label=f"Distilled Student (AUC = {auc_d:.4f})", linestyle='--', linewidth=2)
plt.semilogx(fpr_s, tpr_s, label=f"Student from Scratch (AUC = {auc_s:.4f})", linestyle=':', linewidth=2)
plt.semilogx(chance_fpr, chance_fpr, 'k--', linewidth=1, label="Random guess")

plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title(f"ROC curve: digit {target_class} vs. rest")
plt.legend(loc="lower right")
plt.show()


In [None]:
# Scatter two logit dimensions of the distilled student to get a feeling for
# how the classes separate (compare with the teacher plot).
import numpy as np

# Logits from the distilled student, shape (num_samples, 10)
logits = student(x_test).numpy()

# Plot only the first two logit dimensions, coloured by the true digit
plt.figure(figsize=(8, 6))
scatter = plt.scatter(logits[:, 0], logits[:, 1], c=y_test, cmap='tab10', s=10, alpha=0.7)
plt.xlabel("Logit dimension 0")
plt.ylabel("Logit dimension 1")
plt.title("Distilled student logits (dimensions 0 and 1)")
plt.colorbar(scatter, ticks=range(10), label='True Label')

plt.show()


In [None]:
# Scatter two logit dimensions of the teacher to get a feeling for how the
# classes separate (compare with the student plot).
import numpy as np

# Logits from the teacher model, shape (num_samples, 10)
logits = teacher(x_test).numpy()

# Plot only the first two logit dimensions, coloured by the true digit
plt.figure(figsize=(8, 6))
scatter = plt.scatter(logits[:, 0], logits[:, 1], c=y_test, cmap='tab10', s=10, alpha=0.7)
plt.xlabel("Logit dimension 0")
plt.ylabel("Logit dimension 1")
plt.title("Teacher logits (dimensions 0 and 1)")
plt.colorbar(scatter, ticks=range(10), label='True Label')

plt.show()