# Knowledge Distillation

Knowledge Distillation, often abbreviated as KD, is a powerful technique in the field of machine learning and artificial intelligence. It's a method that allows us to compress and transfer the knowledge learned by a large, complex model to a smaller and more efficient one. The core idea behind Knowledge Distillation is to take the extensive expertise of a large "teacher" model and distill it into a more compact "student" model.

**Teacher-Student Paradigm**: Knowledge Distillation leverages a teacher-student paradigm. The "teacher" model is a large, well-performing model, while the "student" model is a smaller, more lightweight model. The large, complex network is "distilled" into a much smaller network that approximates the function learned by the teacher nearly as well.

**Goal**: The primary goal of knowledge distillation is to train the student model to not only predict the target labels but also to mimic the knowledge contained within the teacher model.

### Formal Definition:

Neural networks typically use a softmax function to convert the logits $z_i$ into class probabilities $p(z_i, T) = \frac{\exp(z_i/T)}{\sum_j \exp(z_j/T)}$. Here, $i, j = 0, 1, 2, \dots, C-1$, where $C$ is the number of classes, and $T$ is the temperature, which is normally set to 1.



Stated in these terms, the **goal** of knowledge distillation is to align the class probability distributions of the teacher and student networks.

# Knowledge

Vanilla knowledge distillation primarily uses the logits of a large deep model as the source of teacher knowledge. The activations, neurons, or features of intermediate layers can also be used as knowledge to guide the learning of the student model. The relationships between different activations, neurons, or pairs of samples contain rich information learned by the teacher model. Furthermore, the parameters of the teacher model (or the connections between layers) contain yet another form of knowledge. We will discuss three forms of knowledge: response-based knowledge, feature-based knowledge, and relation-based knowledge.

### 1. Response-based knowledge

Response-based knowledge is centered around the final output layer of the teacher model. It aims to train the student model to replicate the teacher's predictions. A loss function, known as the distillation loss, quantifies the difference between the student's logits and the teacher's logits. As the distillation loss is minimized during training, the student model learns to make predictions similar to the teacher's.

<!-- ![response_based_knowledge](https://drive.google.com/uc?id=1vybYI95A6-BJ3ZA2Jz18_NohlmUwO427) -->
![response_based_knowledge](https://i.postimg.cc/J4wSWc7X/image.png)
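As a rough illustration of a response-based distillation loss, the sketch below compares temperature-softened output distributions with the KL divergence, which is one common choice rather than the only one. The function name `response_distillation_loss`, the default temperature, and the random logits are assumptions made for this example.

```python
import torch
import torch.nn.functional as F

def response_distillation_loss(student_logits, teacher_logits, temperature=3.0):
    """KL(teacher || student) between temperature-softened output distributions."""
    # F.kl_div expects log-probabilities as input and probabilities as target.
    student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
    teacher_probs = F.softmax(teacher_logits.detach() / temperature, dim=1)
    return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")

torch.manual_seed(0)
teacher_logits = torch.randn(4, 10)                      # hypothetical teacher outputs
student_logits = torch.randn(4, 10, requires_grad=True)  # hypothetical student outputs
loss = response_distillation_loss(student_logits, teacher_logits)
loss.backward()  # gradients flow only into the student logits
```

The loss is zero when the student reproduces the teacher's distribution exactly and grows as the two distributions diverge.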

### 2. Feature-based knowledge

Feature-based knowledge involves the intermediate layers of the teacher model, where valuable data features are captured. This type of knowledge is relevant, particularly in deep neural networks. Feature-based knowledge distillation involves training the student model to learn the same feature activations as the teacher model. The distillation loss function is used to minimize the disparity between the feature activations of the teacher and student models.

<!-- ![feature_based_knowledge](https://drive.google.com/uc?id=1kWU2DXTTchQaPj_aPvqW1ZXAJDUdiB2b) -->
![feature_based_knowledge](https://i.postimg.cc/6qxcFcV5/image.png)
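One way to picture feature-based distillation is a "hint"-style loss: the student's intermediate feature map is projected to the teacher's channel count and then matched with a mean squared error. The sketch below assumes a 1x1 convolution as the projection; the class name `FeatureDistillationLoss` and the tensor shapes (chosen to match the convolutional outputs of the `Student` and `Teacher` classes in the code example of this document) are illustrative.

```python
import torch
from torch import nn
import torch.nn.functional as F

class FeatureDistillationLoss(nn.Module):
    def __init__(self, student_channels, teacher_channels):
        super().__init__()
        # 1x1 convolution mapping student feature maps to the teacher's channel count
        self.adapter = nn.Conv2d(student_channels, teacher_channels, kernel_size=1)

    def forward(self, student_features, teacher_features):
        # Teacher features are targets only, so no gradient flows into the teacher.
        return F.mse_loss(self.adapter(student_features), teacher_features.detach())

torch.manual_seed(0)
student_features = torch.randn(8, 32, 2, 2)    # (batch, channels, height, width)
teacher_features = torch.randn(8, 128, 2, 2)
feature_loss = FeatureDistillationLoss(student_channels=32, teacher_channels=128)
loss = feature_loss(student_features, teacher_features)
```

The adapter is trained together with the student and discarded afterwards; it only exists to make the two feature spaces comparable.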

### 3. Relation-based knowledge

In addition to knowledge in output and intermediate layers, neural networks can contain knowledge about relationships between feature maps or data representations. This knowledge, referred to as relation-based knowledge, captures correlations, graphs, similarity matrices, feature embeddings, or probabilistic distributions based on feature representations. Relation-based knowledge is used to train a student model by modeling the relationships between feature maps or data representations. The distillation process incorporates these relationships to guide the student model's learning.

<!-- ![relation_based_knowledge](https://drive.google.com/uc?id=14bSslYG4QRzq4XOSMYWjkJl5auKYtigj) -->
![relation_based_knowledge](https://i.postimg.cc/L85tCGd3/image.png)
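A minimal sketch of relation-based distillation, assuming the relation of interest is the pairwise cosine similarity between samples in a batch. Instead of matching features directly, the student matches the teacher's batch-level similarity matrix, so the two feature dimensions do not need to agree. The function names and shapes are illustrative.

```python
import torch
import torch.nn.functional as F

def pairwise_similarity(features):
    """Cosine-similarity matrix between all samples in a batch: (B, ...) -> (B, B)."""
    flat = F.normalize(features.flatten(start_dim=1), dim=1)
    return flat @ flat.t()

def relation_distillation_loss(student_features, teacher_features):
    # Only the (B, B) relation matrices are compared, not the features themselves.
    return F.mse_loss(pairwise_similarity(student_features),
                      pairwise_similarity(teacher_features.detach()))

torch.manual_seed(0)
student_features = torch.randn(8, 128)   # hypothetical student embeddings
teacher_features = torch.randn(8, 512)   # hypothetical teacher embeddings
loss = relation_distillation_loss(student_features, teacher_features)
```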

# Training/Distillation Schemes

Based on whether the teacher model is updated simultaneously with the student model or not, the learning schemes of knowledge distillation can be divided into three main categories: **offline distillation**, **online distillation** and **self-distillation**.


<!-- ![dist_scheme](https://drive.google.com/uc?id=1yldXqGD-Ze9elPOCINhKm3w067GgfQKl) -->
![dist_scheme](https://i.postimg.cc/MG8GMjYf/image.png)

### 1. Offline Distillation

In offline distillation, the teacher model is pre-trained and remains fixed during the distillation process. There is no further training or fine-tuning of the teacher model.

**Training Process:**
The training process is divided into two stages:
- The large teacher model is first trained independently on a set of training samples before the distillation process begins.
- The teacher model is then used to extract knowledge, such as logits or intermediate features, which are used to guide the training of the student model during distillation.

**Focus:** Offline distillation methods mainly focus on the design of knowledge transfer, including knowledge representation and loss functions for matching features or distributions.

**Advantages:** Offline methods are simple and easy to implement. They work well when the teacher model is pre-defined and known in advance.
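The second stage of offline distillation can be sketched as follows, assuming the teacher has already been trained. The two `nn.Linear` layers are stand-ins for real networks, and the random inputs are placeholders; the point is only that the teacher is frozen and supplies targets while the student alone is updated.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Stand-in networks; any pre-trained teacher and smaller student would do.
teacher = nn.Linear(20, 10)
student = nn.Linear(20, 10)

# Stage 1 is assumed done: the teacher is already trained. Freeze it for stage 2.
teacher.eval()
for param in teacher.parameters():
    param.requires_grad_(False)
teacher_before = [p.clone() for p in teacher.parameters()]

optimizer = torch.optim.SGD(student.parameters(), lr=0.1)
inputs = torch.randn(16, 20)

with torch.no_grad():  # the teacher only supplies targets
    soft_targets = torch.softmax(teacher(inputs), dim=1)

log_probs = torch.log_softmax(student(inputs), dim=1)
loss = -(soft_targets * log_probs).sum(dim=1).mean()  # cross-entropy with soft targets
optimizer.zero_grad()
loss.backward()
optimizer.step()  # only the student's parameters change
```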

### 2. Online Distillation

Offline distillation assumes that a large pre-trained teacher is available, which is not always the case. In such scenarios, the teacher and student networks are trained simultaneously, which is called Online Distillation. The entire distillation process is end-to-end trainable.

**Training Process:**
- Various online knowledge distillation methods have been proposed in recent years, including deep mutual learning, ensemble learning, and feature fusion.
- The teacher-student interaction is more dynamic, allowing for improved adaptation and performance of the student model.

**Advantages:** Online distillation is a one-phase end-to-end training scheme with efficient parallel computing. It is particularly useful when a high-capacity teacher model is not available, and it allows the student model to adapt to new data and domains.
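As one example of an online scheme, the sketch below shows a single training step in the spirit of deep mutual learning: two peer networks are trained at the same time, and each one adds a KL term that pulls it towards the other's current predictions. The peer networks, optimizers, and random data are placeholders chosen for illustration.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
peer_a = nn.Linear(20, 10)
peer_b = nn.Linear(20, 10)
opt_a = torch.optim.SGD(peer_a.parameters(), lr=0.1)
opt_b = torch.optim.SGD(peer_b.parameters(), lr=0.1)

inputs = torch.randn(16, 20)
labels = torch.randint(0, 10, (16,))

def mutual_loss(own_logits, other_logits, labels):
    # Supervised term plus a KL term towards the other peer's predictions.
    # The other peer is detached, so it acts as a fixed target for this step.
    ce = F.cross_entropy(own_logits, labels)
    kl = F.kl_div(F.log_softmax(own_logits, dim=1),
                  F.softmax(other_logits.detach(), dim=1),
                  reduction="batchmean")
    return ce + kl

logits_a, logits_b = peer_a(inputs), peer_b(inputs)
loss_a = mutual_loss(logits_a, logits_b, labels)
loss_b = mutual_loss(logits_b, logits_a, labels)

opt_a.zero_grad()
opt_b.zero_grad()
loss_a.backward()
loss_b.backward()
opt_a.step()
opt_b.step()
```

Neither network needs to be pre-trained; both improve together within one training phase.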

### 3. Self Distillation

In self-distillation, the teacher and the student are the same network. There is no external teacher model.

**Training Process:**
- Knowledge from deeper sections of the network is distilled into shallower sections of the same network.
- Self-attention distillation methods and snapshot distillation are examples of self-distillation techniques.
- Self-distillation is used to optimize deep models with the same architecture one by one.

**Advantages:** Self-distillation simplifies the distillation process, and the network learns from its own predictions and feature representations.
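The idea of distilling deeper sections of a network into shallower ones can be sketched with an auxiliary classifier attached to an early block. The architecture below is a toy assumption (the class `SelfDistillNet` and its layer sizes are invented for this example): the final head acts as the teacher for the shallow head within the same network.

```python
import torch
from torch import nn
import torch.nn.functional as F

class SelfDistillNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = nn.Sequential(nn.Linear(20, 32), nn.ReLU())
        self.block2 = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
        self.shallow_head = nn.Linear(32, 10)  # auxiliary classifier after block1
        self.deep_head = nn.Linear(32, 10)     # final classifier after block2

    def forward(self, x):
        h1 = self.block1(x)
        h2 = self.block2(h1)
        return self.shallow_head(h1), self.deep_head(h2)

torch.manual_seed(0)
model = SelfDistillNet()
inputs = torch.randn(16, 20)
labels = torch.randint(0, 10, (16,))

shallow_logits, deep_logits = model(inputs)
# Both heads learn from the labels; the shallow head additionally mimics the
# (detached) predictions of the deep head.
loss = (F.cross_entropy(deep_logits, labels)
        + F.cross_entropy(shallow_logits, labels)
        + F.kl_div(F.log_softmax(shallow_logits, dim=1),
                   F.softmax(deep_logits.detach(), dim=1),
                   reduction="batchmean"))
loss.backward()
```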

## How is knowledge transferred from teacher to student?

Following is the setup of the training process of the student model:

<!-- ![know_dist](https://drive.google.com/uc?id=1QwDc9NFW8YRENWxwsDR7bvJhvM8kctjX) -->
![know_dist](https://i.postimg.cc/nLZHcfHv/image.png)

- The generalization ability of the teacher model is transferred to the student model through **soft targets.**
- There are two loss terms. The first uses the soft labels (from the teacher) and the soft predictions (from the student); the second uses the hard predictions (from the student) and the hard labels. The contribution of each term is controlled by a configurable weight.
- If the teacher model is an ensemble of simpler models, the arithmetic or geometric mean of their individual predictive distributions is used as the soft targets.
- When high entropy is present in the soft targets, much more information per training case is provided than in the case of hard targets, and much less variance in the gradient between training cases is observed. As a result, the small model can often be trained on much less data than the original cumbersome model while using a much higher learning rate.
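For the ensemble case mentioned in the list above, the two ways of combining member predictions can be sketched as follows. The function names and the `(members, batch, classes)` layout are assumptions for this example; the geometric mean is computed in log space and renormalised so that each row is again a probability distribution.

```python
import torch

def arithmetic_mean_targets(member_probs):
    """member_probs: (M, B, C) probabilities from M ensemble members."""
    return member_probs.mean(dim=0)

def geometric_mean_targets(member_probs, eps=1e-12):
    # Average in log space, then renormalise so each row sums to 1 again.
    log_mean = member_probs.clamp_min(eps).log().mean(dim=0)
    return torch.softmax(log_mean, dim=-1)

torch.manual_seed(0)
member_logits = torch.randn(3, 4, 10)   # 3 members, batch of 4, 10 classes
member_probs = torch.softmax(member_logits, dim=-1)
soft_arith = arithmetic_mean_targets(member_probs)
soft_geo = geometric_mean_targets(member_probs)
```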



### The Temperature Parameter:
The temperature parameter, often denoted as $T$, is a crucial element in the Knowledge Distillation process. It controls the softening of the teacher's predictions, making them more informative for the student. The temperature parameter is incorporated through a softmax function.

### Softmax Function:
The softmax function is typically used to convert logits (real-valued scores) into probabilities. In the context of Knowledge Distillation, it is modified to include the temperature parameter:

The standard softmax function is defined as:
$P_i=\frac{e^{z_i}}{\sum_j e^{z_j}}$
where,
- $P_i$ is the probability of class $i$.
- $z_i$ is the unnormalized logit of class $i$.
- The denominator sums over all logits.

Now, with the temperature parameter:

$p_i = \frac{e^{(z_i/T)}}{\sum_j e^{(z_j/T)}}$

The temperature parameter $T$ controls how closely the output distribution resembles a hard, one-hot target:

1. **Towards Hard Targets**: When $T=1$, the function reduces to the standard softmax. For a confident teacher this distribution is sharply peaked, and as $T \to 0$ it approaches a one-hot (hard) target, so the probabilities of the non-predicted classes carry little signal for the student.

2. **Soft Targets**: When $T>1$, the softened probabilities are used as "soft targets." The small probabilities assigned to the incorrect classes are raised, which exposes how the teacher ranks the classes relative to one another. Soft targets act as a regularizer and can help the student generalize better.

### Distillation

We use the logits (the inputs to the final softmax) for distilling the learned knowledge. One option is to train the student on the logits directly, by minimizing the squared difference between the logits produced by the teacher model and those produced by the student model. More generally, both sets of logits are converted to softened probabilities with the temperature-scaled softmax:

$$p_i = \frac{e^{(z_i/T)}}{\sum_j e^{(z_j/T)}}$$



**How Temperature Affects Knowledge Distillation:**
- Raising the temperature shrinks the scaled logits $z_i/T$.
- As we increase $T$, the resulting probabilities become smoother (closer to each other).

For high temperatures ($T \to \infty$), nearly the same probability is assigned to all classes. At lower temperatures ($T \to 0$), the distribution is dominated by the largest logit, and the probability of the class with the highest logit approaches 1.
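The limiting behaviour described above can be checked numerically with a small, dependency-free sketch. The logits and the set of temperatures are arbitrary values picked for illustration; the entropy of the output distribution is printed as a simple measure of how "soft" it is.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(probs):
    return -sum(p * math.log(p) for p in probs if p > 0)

logits = [5.0, 2.0, 0.5]
temperatures = (0.1, 1.0, 3.0, 100.0)
results = {t: softmax_with_temperature(logits, t) for t in temperatures}
for t, probs in results.items():
    print(f"T={t}: probs={[round(p, 3) for p in probs]}, entropy={entropy(probs):.3f}")
```

At the lowest temperature almost all probability mass sits on the largest logit, while at the highest temperature the three classes receive nearly equal probability.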

In distillation, the temperature of the final softmax is raised until a suitably soft set of targets is generated by the cumbersome model. The same high temperature is then employed when training the small model to match these soft targets.

### Objective Function

- The first objective function is the cross-entropy with the soft targets. It is computed using the same high temperature in the softmax of the distilled model as was used for generating the soft targets from the larger model.
- The second objective function is the cross-entropy with the correct labels. It is computed using exactly the same logits in the softmax of the distilled model, but at a temperature of 1.

The teacher model is trained using the categorical cross-entropy applied to the one-hot labels. When applying knowledge distillation, the student model is trained against the soft labels predicted by the teacher. One variant mixes the Kullback-Leibler divergence with an MAE loss on those soft labels; the code example later in this document instead uses the cross-entropy against the soft targets, which differs from the Kullback-Leibler divergence only by the teacher's entropy, a constant with respect to the student. The Kullback-Leibler divergence measures the difference between two probability distributions, so the objective is to make the distribution predicted by the student as close as possible to the distribution of the teacher.
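The two objective terms above can be combined into a single function, sketched here under a few assumptions: the names `temperature` and `loss_lambda` mirror the hyperparameters used in the code example of this document, the hard-label term receives raw logits (which is what `F.cross_entropy` expects), and the soft-target term is multiplied by $T^2$ because its gradients scale as $1/T^2$.

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels,
                      temperature=3.0, loss_lambda=0.5):
    # Hard-label term at T = 1: F.cross_entropy applies log-softmax to raw logits.
    hard_loss = F.cross_entropy(student_logits, labels)
    # Soft-target term: cross-entropy between temperature-softened distributions.
    soft_targets = F.softmax(teacher_logits.detach() / temperature, dim=1)
    log_probs = F.log_softmax(student_logits / temperature, dim=1)
    soft_loss = -(soft_targets * log_probs).sum(dim=1).mean()
    # Multiply by T^2 so both terms contribute gradients of comparable magnitude.
    return (1 - loss_lambda) * hard_loss + loss_lambda * temperature ** 2 * soft_loss

torch.manual_seed(0)
student_logits = torch.randn(8, 10, requires_grad=True)
teacher_logits = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))
loss = distillation_loss(student_logits, teacher_logits, labels)
loss.backward()
```

Setting `loss_lambda=0` recovers ordinary supervised training, and `loss_lambda=1` trains on the teacher's soft targets alone.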

# Knowledge Distillation in NLP

- Conventional language models like BERT are complex and resource-intensive. Knowledge distillation has emerged as a powerful approach to create lightweight and efficient language models in the field of NLP.

- Numerous knowledge distillation (KD) methods have been proposed to address NLP tasks like neural machine translation (NMT), text generation, question-answering systems, event detection, document retrieval, text recognition, etc.

- Most KD methods for NLP fall into the category of Natural Language Understanding (NLU). They are designed as **task-specific distillation** or **multi-task distillation**.

- NMT stands out as a prominent application in NLP, and KD has been instrumental in creating lightweight NMT models. Various KD techniques have been employed to make NMT models more efficient and compact.

- For instance, non-autoregressive machine translation (NAT) models have benefited from KD-based approaches, where the capacity of the student model and the knowledge transfer play a vital role. KD methods in the context of NMT have also leveraged sequence-level distillation, data augmentation, and regularization techniques to achieve good performance.

- On the other hand, in NLU, BERT models have been a focal point. These deep models are effective but resource-intensive. To make them more lightweight and deployable, several compact BERT variants have been created using knowledge distillation. This line of work, referred to as BERT model compression, aims to preserve the accuracy of the original BERT models while reducing their size and complexity.

- As an example, DistilBERT is a smaller, faster, and cheaper variant of the BERT model, developed by Hugging Face. During the pre-training phase, knowledge distillation was used to produce a distilled BERT model that is 40% smaller (66 million parameters compared to 110 million parameters) and 60% faster, all while preserving approximately 97% of the original BERT model's language understanding performance.







# Distillation Algorithms

### 1. Adversarial Distillation

Adversarial distillation, inspired by generative adversarial networks (GANs), trains a generator to create synthetic data resembling the true distribution and a discriminator to differentiate genuine from synthetic data. This concept extends to knowledge distillation, improving both teacher and student models' understanding of the true data distribution. It can be implemented by training a generator for synthetic data, using a discriminator to help the student mimic the teacher, or by jointly optimizing the student and teacher models online.

<!-- ![GAN_KD](https://drive.google.com/uc?id=1MCdWB6F1VWu4DGetSauDQCECQl6qLqsa) -->
![GAN_KD](https://i.postimg.cc/kgy5SgYj/image.png)

### 2. Multi-Teacher Distillation

Multi-teacher distillation involves a student model learning from multiple teacher models, providing diverse knowledge. The knowledge from different teachers, typically based on logits and feature representations, can be combined by averaging their responses. This approach offers advantages over learning from a single teacher.

<!-- ![MulTeacher_KD](https://drive.google.com/uc?id=1wrebtycIoQTUa4elKUot_8KON4Bgrm1l) -->
![MulTeacher_KD](https://i.postimg.cc/fTnWg962/image.png)

### 3. Cross-Modal Distillation

Cross-modal distillation transfers knowledge between modalities. It's particularly useful in the visual domain, where a teacher trained on labeled image data imparts knowledge to a student with an unlabeled input domain like optical flow, text, or audio. The teacher's image features guide the supervised training of the student. This approach is applied in applications such as visual question answering and image captioning.

<!-- ![CrossModel_KD](https://drive.google.com/uc?id=17PTckHIxI3Y4xaOmiSjFEUI3_9CA_zsL) -->
![CrossModel_KD](https://i.postimg.cc/fyh45xJY/image.png)

### 4. Graph-Based Distillation

Graph-based distillation methods aim to use the graph as a carrier of teacher knowledge or to control the message passing of teacher knowledge. These methods often involve constructing graphs based on teacher knowledge and using them to facilitate knowledge transfer. Various techniques have been proposed, including using locality-preserving loss functions, multi-head graphs, mutual relations of data samples, similarity matrices, and instance features and relationships modeled as graph vertices and edges. Some methods also control knowledge transfer using graphs by considering modality discrepancies, privileged information, and diverse knowledge transfer patterns. Constructing graphs to model the structure knowledge of data remains a challenging research area in graph-based distillation.

<!-- ![GraphBased_KD](https://drive.google.com/uc?id=1dn4Ky0y-M0-Cr_ZAxwUbdvznlYDw-C39) -->
![GraphBased_KD](https://i.postimg.cc/L88cy46Q/image.png)

### 5. Attention-Based Distillation

Attention mechanisms are employed in knowledge distillation to improve the performance of student networks by reflecting neuron activations in convolutional neural networks. Different attention-based knowledge distillation methods define various attention transfer mechanisms to distill knowledge from the teacher network to the student network. These methods revolve around defining attention maps for feature embeddings in neural network layers, effectively transferring knowledge about feature embeddings using attention map functions. Additionally, one approach utilizes an attention mechanism to assign different confidence rules for knowledge distillation.
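A minimal sketch of one attention-transfer formulation for convolutional features, assuming the attention map is the channel-wise mean of squared activations, flattened and L2-normalised. Other definitions of the attention map exist; the function names and tensor shapes here are illustrative.

```python
import torch
import torch.nn.functional as F

def attention_map(features):
    """(B, C, H, W) -> (B, H*W): mean of squared activations over channels, L2-normalised."""
    attn = features.pow(2).mean(dim=1).flatten(start_dim=1)
    return F.normalize(attn, dim=1)

def attention_transfer_loss(student_features, teacher_features):
    # Channel counts may differ; spatial sizes are assumed to match.
    diff = attention_map(student_features) - attention_map(teacher_features.detach())
    return diff.pow(2).mean()

torch.manual_seed(0)
student_features = torch.randn(8, 32, 5, 5)
teacher_features = torch.randn(8, 128, 5, 5)
loss = attention_transfer_loss(student_features, teacher_features)
```

Because the channel dimension is collapsed, the student is asked to attend to the same spatial locations as the teacher rather than to reproduce its features.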

### 6. Data-Free Distillation

Data-free knowledge distillation methods address issues related to unavailability of data due to privacy, legal, security, or confidentiality concerns. These methods do not rely on existing training data but instead generate synthetic data. Several approaches have been proposed, including the use of Generative Adversarial Networks (GANs) to create transfer data, reconstructing transfer data from teacher network activations or spectral activations, and techniques that model the softmax space. Additionally, some methods explore knowledge distillation with few-shot learning, where the teacher uses limited labeled data. Data distillation, a related concept, employs new training annotations from unlabeled data generated by the teacher model to train the student.

<!-- ![DataFree_KD](https://drive.google.com/uc?id=1_WlZZCHUo-MTGpB4QvdzpKeRTytokwqq) -->
![DataFree_KD](https://i.postimg.cc/yxK50sbP/image.png)

### 7. Quantized Distillation

Quantized distillation combines network quantization, which reduces neural network complexity by converting high-precision networks into low-precision ones, with knowledge distillation. Several methods in this category aim to transfer knowledge from full-precision teacher networks to small, quantized student networks. This involves weight quantization, co-studying teacher and student networks, and the self-study of quantized student networks. Some recent approaches also incorporate self-distillation to improve the performance of quantized deep models, where teachers share model parameters with students.

<!-- ![Quantized_KD](https://drive.google.com/uc?id=1IZSTGqhK36YmkYQA2Zz0YqQFUqzcJJJ-) -->
![Quantized_KD](https://i.postimg.cc/d1PSKJwY/image.png)

### 8. Lifelong Distillation

Lifelong learning aims to accumulate and transfer knowledge over time. Knowledge distillation, a tool for preserving and transferring knowledge, has given rise to variants based on lifelong learning. These include meta-transfer networks, the Leap framework for meta-learning, and methods like semantic-aware knowledge preservation and global distillation. These approaches mitigate catastrophic forgetting in lifelong learning scenarios.

### 9. NAS-Based Distillation

Neural architecture search (NAS) automates the discovery of deep neural models. Knowledge distillation is integrated into NAS to address the challenge of transferring knowledge from large teacher models to smaller student models. This includes techniques like AdaNAS, NAS with distilled architecture knowledge, teacher-guided search for architectures (TGSA), and one-shot NAS. TGSA, for instance, guides architecture searches to mimic the teacher's intermediate feature representations for effective and efficient knowledge transfer.

# Code Example:

Let us see an example of knowledge distillation using the student-teacher architecture on the **MNIST dataset**. Let us start by importing the necessary libraries.

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from tqdm import tqdm

We now create a teacher and a smaller student model, both of which are convolutional neural networks. The student model takes MNIST images, which are grayscale but replicated to 3 channels by the input transform, and outputs class predictions (10 classes). The teacher has the same architecture as the student but wider convolutional layers (128 channels instead of 32).

In [None]:
class Student(nn.Module):
	def __init__(self):
		super().__init__()
		self.conv = nn.Sequential(
			nn.Conv2d(3, 32, 3, 2),    # 3x28x28 -> 32x13x13
			nn.ReLU(),
			nn.Conv2d(32, 32, 3, 2),   # -> 32x6x6
			nn.ReLU(),
			nn.Conv2d(32, 32, 2, 1),   # -> 32x5x5
			nn.ReLU(),
			nn.AdaptiveAvgPool2d(2)    # -> 32x2x2
		)

		self.fc = nn.Linear(128, 10)   # 32 * 2 * 2 = 128 features -> 10 class logits

	def forward(self, x):
		out = self.conv(x)
		out = out.view(x.shape[0], -1)  # flatten to (batch, features)
		out = self.fc(out)
		return out


class Teacher(nn.Module):
	def __init__(self):
		super().__init__()
		self.conv = nn.Sequential(
			nn.Conv2d(3, 128, 3, 2),    # 3x28x28 -> 128x13x13
			nn.ReLU(),
			nn.Conv2d(128, 128, 3, 2),  # -> 128x6x6
			nn.ReLU(),
			nn.Conv2d(128, 128, 2, 1),  # -> 128x5x5
			nn.ReLU(),
			nn.AdaptiveAvgPool2d(2)     # -> 128x2x2
		)

		self.fc = nn.Linear(512, 10)    # 128 * 2 * 2 = 512 features -> 10 class logits

	def forward(self, x):
		out = self.conv(x)
		out = out.view(x.shape[0], -1)  # flatten to (batch, features)
		out = self.fc(out)
		return out

Define a cross-entropy loss between predicted probabilities and a target probability distribution. It is used later to compare the student's softened predictions with the teacher's soft targets.

In [None]:
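def cross_entropy_loss(output, target):
	# `output` and `target` are (batch, classes) probability distributions.
	# Clamp before the log so that a probability of exactly 0 cannot produce -inf.
	return -torch.sum(output.clamp_min(1e-12).log() * target) / output.shape[0]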

We will now do the following:
- Define data transformations to preprocess the images (resize, expand the single grayscale channel to 3 identical channels, and convert to tensors).
- Create training and testing datasets with these transformations.
- Set up data loaders for efficient batching and shuffling.

In [None]:
batch_size = 64
num_epochs = 5
temperature = 3
loss_lambda = 0.5

input_transform = transforms.Compose([
	transforms.Resize(28),
	transforms.Grayscale(3),
	transforms.ToTensor()
])

train_data = MNIST('./data', train=True, transform=input_transform, download=True)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_data = MNIST('./data', train=False, transform=input_transform, download=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)
len_train_data = len(train_data)
len_test_data = len(test_data)

Firstly, we train the **student model without knowledge distillation** using standard cross-entropy loss.

In [None]:
# ------------------------------------------------------------------------------------------------------#

print('\n\nTraining student without teacher...')
student_without_teacher = Student().cuda()  # requires a CUDA-capable GPU
optimizer = torch.optim.Adam(student_without_teacher.parameters())
criterion = nn.CrossEntropyLoss()
for epoch in range(num_epochs):
	print(f'Epoch {epoch + 1}/{num_epochs}:')

	train_loss = .0
	train_acc = .0
	student_without_teacher.train()
	for batch_x, batch_y in tqdm(train_loader):
		batch_x, batch_y = batch_x.cuda(), batch_y.cuda()

		# Note: nn.CrossEntropyLoss expects raw logits and applies log-softmax itself.
		# Passing softmax outputs, as this loop does, applies softmax twice and weakens
		# the gradients; `criterion(logits, batch_y)` is the usual form.
		out = torch.softmax(student_without_teacher(batch_x), dim=1)
		loss = criterion(out, batch_y)
		optimizer.zero_grad()
		loss.backward()
		optimizer.step()

		train_loss += loss.item()
		train_acc += (out.argmax(dim=1) == batch_y).sum().item()
	# train_loss sums per-batch mean losses, so this prints roughly (mean loss / batch_size).
	print('Train loss: {:.6f}, acc: {:.6f}'.format(train_loss / len_train_data, train_acc / len_train_data))

	eval_loss = .0
	eval_acc = .0
	student_without_teacher.eval()
	with torch.no_grad():
		for batch_x, batch_y in tqdm(test_loader):
			batch_x, batch_y = batch_x.cuda(), batch_y.cuda()

			out = torch.softmax(student_without_teacher(batch_x), dim=1)
			loss = criterion(out, batch_y)

			eval_loss += loss.item()
			eval_acc += (out.argmax(dim=1) == batch_y).sum().item()
		print('Eval loss: {:.6f}, acc: {:.6f}'.format(eval_loss / len_test_data, eval_acc / len_test_data))



Training student without teacher...
Epoch 1/5:


100%|██████████| 938/938 [00:20<00:00, 45.74it/s]


Train loss: 0.027950, acc: 0.682200


100%|██████████| 157/157 [00:02<00:00, 54.17it/s]


Eval loss: 0.026629, acc: 0.764400
Epoch 2/5:


100%|██████████| 938/938 [00:17<00:00, 53.71it/s]


Train loss: 0.026460, acc: 0.768717


100%|██████████| 157/157 [00:02<00:00, 62.99it/s]


Eval loss: 0.026377, acc: 0.780200
Epoch 3/5:


100%|██████████| 938/938 [00:18<00:00, 51.43it/s]


Train loss: 0.026337, acc: 0.775800


100%|██████████| 157/157 [00:02<00:00, 67.34it/s]


Eval loss: 0.026314, acc: 0.785000
Epoch 4/5:


100%|██████████| 938/938 [00:17<00:00, 52.35it/s]


Train loss: 0.026226, acc: 0.782950


100%|██████████| 157/157 [00:02<00:00, 57.88it/s]


Eval loss: 0.026228, acc: 0.789000
Epoch 5/5:


100%|██████████| 938/938 [00:17<00:00, 52.83it/s]


Train loss: 0.026167, acc: 0.786267


100%|██████████| 157/157 [00:02<00:00, 67.01it/s]

Eval loss: 0.026163, acc: 0.793300





Now, we **train the teacher model** using the same procedure as training the student.

In [None]:
# ------------------------------------------------------------------------------------------------------#

print('\n\nTraining teacher...')
teacher = Teacher().cuda()  # requires a CUDA-capable GPU
optimizer = torch.optim.Adam(teacher.parameters())
criterion = nn.CrossEntropyLoss()
for epoch in range(num_epochs):
	print(f'Epoch {epoch + 1}/{num_epochs}:')

	train_loss = .0
	train_acc = .0
	teacher.train()
	for batch_x, batch_y in tqdm(train_loader):
		batch_x, batch_y = batch_x.cuda(), batch_y.cuda()

		# Note: nn.CrossEntropyLoss expects raw logits; this loop passes softmax outputs,
		# so softmax is effectively applied twice. `criterion(logits, batch_y)` is the usual form.
		out = torch.softmax(teacher(batch_x), dim=1)
		loss = criterion(out, batch_y)
		optimizer.zero_grad()
		loss.backward()
		optimizer.step()

		train_loss += loss.item()
		train_acc += (out.argmax(dim=1) == batch_y).sum().item()
	print('Train loss: {:.6f}, acc: {:.6f}'.format(train_loss / len_train_data, train_acc / len_train_data))

	eval_loss = .0
	eval_acc = .0
	teacher.eval()
	with torch.no_grad():
		for batch_x, batch_y in tqdm(test_loader):
			batch_x, batch_y = batch_x.cuda(), batch_y.cuda()

			out = torch.softmax(teacher(batch_x), dim=1)
			loss = criterion(out, batch_y)

			eval_loss += loss.item()
			eval_acc += (out.argmax(dim=1) == batch_y).sum().item()
		print('Eval loss: {:.6f}, acc: {:.6f}'.format(eval_loss / len_test_data, eval_acc / len_test_data))



Training teacher...
Epoch 1/5:


100%|██████████| 938/938 [00:22<00:00, 42.24it/s]


Train loss: 0.028455, acc: 0.641000


100%|██████████| 157/157 [00:03<00:00, 51.36it/s]


Eval loss: 0.025109, acc: 0.863800
Epoch 2/5:


100%|██████████| 938/938 [00:22<00:00, 41.46it/s]


Train loss: 0.024835, acc: 0.873300


100%|██████████| 157/157 [00:02<00:00, 53.40it/s]


Eval loss: 0.024741, acc: 0.884900
Epoch 3/5:


100%|██████████| 938/938 [00:22<00:00, 42.44it/s]


Train loss: 0.024085, acc: 0.921300


100%|██████████| 157/157 [00:02<00:00, 52.86it/s]


Eval loss: 0.023296, acc: 0.978700
Epoch 4/5:


100%|██████████| 938/938 [00:22<00:00, 42.04it/s]


Train loss: 0.023220, acc: 0.977017


100%|██████████| 157/157 [00:02<00:00, 52.58it/s]


Eval loss: 0.023199, acc: 0.983700
Epoch 5/5:


100%|██████████| 938/938 [00:22<00:00, 42.54it/s]


Train loss: 0.023156, acc: 0.980700


100%|██████████| 157/157 [00:02<00:00, 53.25it/s]

Eval loss: 0.023225, acc: 0.982700





Finally, we **train the student model with knowledge distillation**.
We use a combination of standard cross-entropy loss and an additional loss function (cross_entropy_loss) to distill knowledge from the teacher.
Temperature scaling is applied during knowledge distillation, controlling the softness of the predicted probabilities.

In [None]:
# ------------------------------------------------------------------------------------------------------#

print('\n\nTraining student with teacher...')
student_with_teacher = Student().cuda()
teacher.eval()  # the trained teacher stays fixed during distillation (offline scheme)
optimizer = torch.optim.Adam(student_with_teacher.parameters())
# criterion[0]: hard-label loss; criterion[1]: soft-target loss (cross_entropy_loss)
criterion = [nn.CrossEntropyLoss(), cross_entropy_loss]
for epoch in range(num_epochs):
	print(f'Epoch {epoch + 1}/{num_epochs}:')

	train_loss = .0
	train_acc = .0
	student_with_teacher.train()
	for batch_x, batch_y in tqdm(train_loader):
		batch_x, batch_y = batch_x.cuda(), batch_y.cuda()

		out = student_with_teacher(batch_x)
		with torch.no_grad():  # the teacher only supplies soft targets
			batch_y_teacher = torch.softmax(teacher(batch_x) / temperature, dim=1)
		out_result = torch.softmax(out, dim=1)
		# Hard-label term (T = 1) plus soft-target term, scaled by T^2 to keep its gradient
		# magnitude comparable. Note: criterion[0] expects raw logits, so passing `out`
		# instead of `out_result` is the usual form.
		loss = (1 - loss_lambda) * criterion[0](out_result, batch_y) + \
		       loss_lambda * temperature**2 * criterion[1](torch.softmax(out / temperature, dim=1), batch_y_teacher)
		assert not torch.isnan(loss)

		optimizer.zero_grad()
		loss.backward()
		torch.nn.utils.clip_grad_norm_(student_with_teacher.parameters(), 1)
		optimizer.step()

		train_loss += loss.item()
		train_acc += (out_result.argmax(dim=1) == batch_y).sum().item()
	print('Train loss: {:.6f}, acc: {:.6f}'.format(train_loss / len_train_data, train_acc / len_train_data))

	eval_loss = .0
	eval_acc = .0
	student_with_teacher.eval()
	with torch.no_grad():
		for batch_x, batch_y in tqdm(test_loader):
			batch_x, batch_y = batch_x.cuda(), batch_y.cuda()

			out = student_with_teacher(batch_x)
			batch_y_teacher = torch.softmax(teacher(batch_x) / temperature, dim=1)
			out_result = torch.softmax(out, dim=1)
			loss = (1 - loss_lambda) * criterion[0](out_result, batch_y) + \
			       loss_lambda * temperature ** 2 * criterion[1](torch.softmax(out / temperature, dim=1), batch_y_teacher)

			eval_loss += loss.item()
			eval_acc += (out_result.argmax(dim=1) == batch_y).sum().item()
		print('Eval loss: {:.6f}, acc: {:.6f}'.format(eval_loss / len_test_data, eval_acc / len_test_data))



Training student with teacher...
Epoch 1/5:


100%|██████████| 938/938 [00:21<00:00, 43.47it/s]


Train loss: 0.054643, acc: 0.820483


100%|██████████| 157/157 [00:03<00:00, 52.15it/s]


Eval loss: 0.029902, acc: 0.922200
Epoch 2/5:


100%|██████████| 938/938 [00:21<00:00, 44.52it/s]


Train loss: 0.026885, acc: 0.938950


100%|██████████| 157/157 [00:03<00:00, 51.00it/s]


Eval loss: 0.023749, acc: 0.954000
Epoch 3/5:


100%|██████████| 938/938 [00:21<00:00, 43.40it/s]


Train loss: 0.023196, acc: 0.955067


100%|██████████| 157/157 [00:02<00:00, 53.22it/s]


Eval loss: 0.022308, acc: 0.954500
Epoch 4/5:


100%|██████████| 938/938 [00:21<00:00, 43.36it/s]


Train loss: 0.021392, acc: 0.963200


100%|██████████| 157/157 [00:03<00:00, 52.04it/s]


Eval loss: 0.019958, acc: 0.971500
Epoch 5/5:


100%|██████████| 938/938 [00:21<00:00, 43.93it/s]


Train loss: 0.020283, acc: 0.967917


100%|██████████| 157/157 [00:03<00:00, 51.81it/s]

Eval loss: 0.019338, acc: 0.973400





With the teacher trained for 5 epochs and the student distilled from it for 5 epochs, this example shows a clear gain over training the same student from scratch: the teacher reaches about 98.3% evaluation accuracy, the student trained from scratch about 79.3%, and the distilled student about 97.3%. The distilled student therefore approaches the accuracy of the teacher. Two points matter when reading the logs. First, the reported losses are not comparable across runs, because the distilled student optimizes a different objective that includes the $T^2$-scaled soft-target term; its training loss starts higher (0.0546 in epoch 1) and ends lower (0.0203 in epoch 5) than in the other two runs. Second, all three training loops pass softmax outputs to `nn.CrossEntropyLoss`, which expects raw logits; this weakens the hard-label gradient and likely depresses the accuracy of the student trained from scratch, so the gap between the two students should not be attributed to distillation alone.

In summary, we can say that knowledge distillation is an effective technique for transferring knowledge from a more complex teacher model to a smaller student model. The student model, when trained with knowledge distillation, can achieve accuracy levels close to the teacher's accuracy, demonstrating the benefits of this approach for model compression and efficient deployment.