# Knowledge Distillation 

Current state-of-the-art performance in AI and ML is mainly driven by large and complex deep neural networks that often consist of billions of parameters. Fortunately, for many applications, pre-trained models can be leveraged through transfer learning, avoiding the burden of training large models from scratch.

However, as transfer learning from large pre-trained models becomes more prevalent, deploying these models on devices with limited processing power, such as edge devices (e.g. IoT devices), is challenging. While deep learning models often achieve excellent accuracy, they frequently fail to meet other requirements such as low latency and a small memory footprint.

In this notebook, we demonstrate how a compression technique called knowledge distillation (KD) transfers knowledge from a larger neural network into a smaller, more compact one. In this way, we can (partially) benefit from the knowledge of the larger model while retaining the small memory footprint and low inference latency of the smaller model.

## What is Knowledge Distillation exactly?

<center>
<img src="https://i0.wp.com/neptune.ai/wp-content/uploads/2022/10/Knowledge-Distillation_1.png?ssl=1" width=30%><br>
<a href="https://arxiv.org/pdf/2006.05525.pdf">Source</a>
</center>

This notebook focuses on *response-based knowledge distillation*. Response-based knowledge distillation is a compression technique where a student model is optimised to reproduce the outputs of a larger 'teacher' model. The technique is described in the paper by Hinton et al. (2015) ([paper](https://arxiv.org/abs/1503.02531)). 

The idea behind response-based KD is intuitive: we first train a large 'teacher' model and then show its predictions to a smaller 'student' model. When training the student, the loss combines two terms: the usual loss between the student's predictions (logits) and the ground-truth labels, and a distillation loss between the student's logits and the teacher's logits, both softened with a temperature. Hence, the student learns both from the true labels and from the predictions made by the teacher.
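Written out, the loss implemented by the `DistillationLoss` class further down is roughly the following, where $\sigma$ is the softmax, $z_s$ and $z_t$ are the student and teacher logits, $y$ is the ground-truth label, $T$ is the `temperature` and $\alpha$ is the `alpha` weighting factor (the symbol names are our own notation):

$$
\mathcal{L} = (1-\alpha)\,\mathcal{L}_{\mathrm{CE}}\big(y,\ \sigma(z_s)\big) + \alpha\,T^2\,\mathrm{KL}\big(\sigma(z_t/T)\,\big\|\,\sigma(z_s/T)\big)
$$

Hinton et al. (2015) motivate the $T^2$ factor by noting that the gradients produced by the softened targets scale as $1/T^2$, so rescaling keeps both terms on a comparable scale when the temperature changes.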

<center>
<img src="https://miro.medium.com/v2/resize:fit:1400/0*B8vlOvK1N_CSgZMo" width=30%><br>
<a href="https://arxiv.org/pdf/2006.05525.pdf">Source</a>
</center>

For more details on other KD approaches, see the survey paper by Gou et al. ([paper](https://arxiv.org/pdf/2006.05525.pdf)).

## Where is Knowledge Distillation used?

A famous example of knowledge distillation is the DistilBERT model. DistilBERT ([link](https://arxiv.org/pdf/1910.01108.pdf)) is a faster and lighter version of the BERT model ([link](https://arxiv.org/abs/1810.04805)). Thanks to knowledge distillation, DistilBERT is 40% smaller and 60% faster than BERT while retaining 97% of its language understanding capabilities.

## In this tutorial

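In this notebook we use the Dogs vs. Cats dataset from Kaggle, which contains approximately 25,000 images of cats and dogs. The goal is to train a computer vision model that predicts whether an image shows a cat or a dog, and the aim is to illustrate how the accuracy of a very simple CNN can be boosted through knowledge distillation from a larger and more complex model (DenseNet-121). Note, however, that the data-loading cell below actually loads CIFAR-10 (10 classes: plane, car, bird, cat, deer, dog, frog, horse, ship, truck), and all outputs shown in this notebook were produced on CIFAR-10.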

- [Setup](#Setup)
- [Functions](#Functions)
- [Data](#Load-Data)
- [Experiments](#Experiments)
    - Fine-tune the teacher model on the image classification task
    - Train the student model without knowledge distillation
    - Train the student model with knowledge distillation
- [Conclusion](#Conclusion)

## Setup

In [1]:
import tqdm
import numpy as np
import PIL
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms


In [2]:
# path to the directory containing the training data
data_dir = '../data/train'

# Train, validation and test set percentage
train_val_test_split = (0.7,0.2,0.1)

# Hyperparameters for training our models
num_workers=6
batch_size = 100
epochs = 10
lr = 0.001

# Knowledge distillation parameters
alpha = 0.5
temperature = 5
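The `temperature` hyperparameter controls how much the teacher's output distribution is softened before the student tries to match it: logits are divided by $T$ before the softmax, so larger values move probability mass onto the non-maximal classes and expose how the teacher ranks the "wrong" answers. The plain-Python sketch below uses made-up logits for a hypothetical 3-class problem (not values from this notebook) and compares $T=1$ with the $T=5$ chosen above:

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Softmax over a list of logits, softened by dividing by `temperature`."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical teacher logits for a 3-class problem
teacher_logits = [6.0, 2.0, 1.0]

for T in (1, 5):
    probs = softmax_with_temperature(teacher_logits, T)
    print(f"T={T}: " + ", ".join(f"{p:.3f}" for p in probs))
```

With `T=5` the top class still has the highest probability, but the other two classes receive far more mass than at `T=1`; this relative information about non-target classes is what the distillation term passes on to the student.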

## Functions

Here we define a number of Python functions and classes that we use in this notebook:

- `get_model_size`: calculates how 'large' a model is in terms of memory footprint.
- `DistillationLoss`: a class defining the distillation loss function.
- `train`: a function to train a PyTorch model.
- `test`: a function to evaluate our trained PyTorch models.
- `train_with_distillation`: similar to `train`, except that the distillation loss is used instead of the regular loss function.

Note that `train`, `test` and `train_with_distillation` rely on a global `device` variable, which is defined at the start of the Experiments section.

In [3]:
def get_model_size(model):
    """function to calculate the model size in MB

    Args:
        model (nn.Module): pytorch model
    """
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
    
    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
    
    size_all_mb = (param_size + buffer_size) / 1024**2

    return size_all_mb
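As a sanity check on what `get_model_size` measures, the same arithmetic can be done by hand for the `Student` network defined below: each `Conv2d` has in × out × k × k weights plus one bias per output channel, each `Linear` has in × out weights plus one bias per output, and every `float32` value takes 4 bytes. The Student uses no batch normalization, so it has no buffers and its parameters are the whole story. The helpers `conv2d_params` and `linear_params` are hypothetical, written only for this sketch:

```python
# Hypothetical helpers: parameter counts (weights + biases) per layer type
def conv2d_params(in_ch, out_ch, k):
    return in_ch * out_ch * k * k + out_ch

def linear_params(in_f, out_f):
    return in_f * out_f + out_f

# Layer shapes copied from the Student class
student_layers = {
    "conv1": conv2d_params(3, 16, 5),
    "conv2": conv2d_params(16, 32, 5),
    "conv3": conv2d_params(32, 64, 3),
    "fc1": linear_params(64 * 3 * 3, 250),
    "fc2": linear_params(250, 50),
    "fc3": linear_params(50, 10),
}

num_params = sum(student_layers.values())
bytes_per_float32 = 4  # what element_size() returns for torch.float32 tensors
size_mb = num_params * bytes_per_float32 / 1024**2
print(f"{num_params} parameters -> {size_mb:.3f} MB")
```

The result agrees with the 0.724 MB that `get_model_size(student_model)` reports in the Conclusion; note that the fully connected layer `fc1` alone accounts for most of the Student's parameters.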

In [4]:
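class DistillationLoss:

    """Custom loss calculation combining
    the loss of the student model with the distillation loss
    """

    def __init__(self, student_loss, temperature=1, alpha=0.25):
        self.student_loss = student_loss
        # 'batchmean' averages the KL divergence over the batch, matching its mathematical
        # definition; the default 'mean' also divides by the number of classes
        # (the outputs shown below were produced with the default 'mean')
        self.distillation_loss = nn.KLDivLoss(reduction='batchmean')
        self.temperature = temperature
        self.alpha = alpha

    def __call__(self, student_logits, student_target_loss, teacher_logits):
        # KLDivLoss expects log-probabilities as input and probabilities as target
        distillation_loss = self.distillation_loss(F.log_softmax(student_logits / self.temperature, dim=1),
                                                   F.softmax(teacher_logits / self.temperature, dim=1))

        # Scaling by T^2 keeps the soft-target term on a comparable scale to the hard-target loss
        loss = (1 - self.alpha) * student_target_loss + (self.alpha * distillation_loss * (self.temperature**2))
        return loss, distillation_loss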
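To make the computation in `DistillationLoss` concrete, here is a minimal single-example sketch in plain Python (no PyTorch; the logits are invented for illustration). It mirrors the class: the KL divergence between the temperature-softened teacher and student distributions is scaled by $T^2$ and mixed with the ordinary cross-entropy on the hard label via `alpha`. For a single example this KL value corresponds to `nn.KLDivLoss(reduction='batchmean')`; PyTorch's default `reduction='mean'` additionally divides by the number of classes.

```python
import math

def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions given as lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def cross_entropy(logits, target):
    """Cross-entropy of one example with integer class label `target`."""
    return -math.log(softmax(logits)[target])

def distillation_loss(student_logits, teacher_logits, target, alpha=0.5, temperature=5.0):
    p_teacher = softmax(teacher_logits, temperature)  # soft targets
    p_student = softmax(student_logits, temperature)
    kd = kl_divergence(p_teacher, p_student)
    hard = cross_entropy(student_logits, target)       # uses T = 1
    return (1 - alpha) * hard + alpha * kd * temperature**2, kd

# Invented logits for a 3-class example whose true label is class 0
student_logits = [2.0, 1.0, 0.1]
teacher_logits = [4.0, 1.5, 0.5]
loss, kd = distillation_loss(student_logits, teacher_logits, target=0)
print(f"combined loss = {loss:.4f}, KL term = {kd:.4f}")
```

Setting `alpha=0` recovers plain cross-entropy training, and the KL term vanishes when student and teacher produce identical logits.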


In [5]:
def train(dataloader, model, loss_fn, optimizer):

    """Simple training function looping over a dataloader to optimize a model with given optimizer and loss function.
    """

    size = len(dataloader.dataset)
    model.train()
    for X, y in tqdm.tqdm(dataloader, desc = "Training", unit = " Iterations"):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()



def test(dataloader, model, loss_fn):

    """test function evaluating a trained model on test data provided through the dataloader argument.
    """

    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in tqdm.tqdm(dataloader, desc = "Validating", unit="Iterations"):
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

def train_with_distillation(dataloader, student_model, teacher_model, loss_fn, optimizer, alpha=0.25, temperature=1):

    """training function to train a student_model with knowledge distillation from a teacher_model. 
    """

    distillation_loss = DistillationLoss(student_loss=loss_fn, alpha=alpha, temperature=temperature)
    size = len(dataloader.dataset)
    num_batches = len(dataloader)

    student_model.train()
    teacher_model.eval()

    train_loss, train_dist_loss, correct = 0, 0, 0

    for X, y in tqdm.tqdm(dataloader, desc = "Training with Distillation", unit = " Iterations"):
        X, y = X.to(device), y.to(device)

        # Let student and teacher both make predictions
        pred_student = student_model(X)

        # The teacher is only used for inference, so no gradients are needed
        with torch.no_grad():
            pred_teacher = teacher_model(X)

        # Compute the regular student loss
        student_target_loss = loss_fn(pred_student, y)
        # Combine the student loss with the loss resulting from the difference between student and teacher predictions
        loss, dist_loss = distillation_loss(pred_student, student_target_loss, pred_teacher)

        train_loss += loss.item()
        train_dist_loss += dist_loss.item()
        correct += (pred_student.argmax(1) == y).type(torch.float).sum().item()

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    train_loss /= num_batches
    train_dist_loss /= num_batches
    correct /= size
    print(f"Train Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {train_loss:>8f}, Avg dist loss: {train_dist_loss:>8f}\n")
        

## Student/Teacher Models

In this section we define our Student and Teacher model architectures. For the Teacher we use a pretrained DenseNet-121 with a modified classifier head. For the Student we implement a very simple, shallow CNN with only three convolutional layers.

In [6]:
class Teacher(nn.Module):
    def __init__(self):
        super().__init__()
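        self.model = torchvision.models.densenet121(weights='DEFAULT')
        # Freeze the pretrained backbone. `requires_grad_` is a method, so the
        # assignment `params.requires_grad_ = False` would silently freeze nothing;
        # the outputs below were produced with that assignment, i.e. with the
        # full DenseNet being fine-tuned.
        for params in self.model.parameters():
            params.requires_grad = False

        # Replace the classifier head (its new layers are trainable by default).
        # Without activations in between, these three Linear layers compose
        # to a single linear map.
        num_ftrs = self.model.classifier.in_features
        self.model.classifier = nn.Sequential(
            nn.Linear(num_ftrs, 250),
            nn.Linear(250, 50),
            nn.Linear(50, 10)
            )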
        
    def forward(self, x):
        x = self.model(x)
        return x
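A subtle pitfall when freezing layers: `requires_grad_` (with a trailing underscore) is an in-place *method* of PyTorch tensors, so writing `params.requires_grad_ = False` does not freeze anything; it only stores a new attribute that hides the method, and the `requires_grad` flag stays `True`. Freezing needs either `params.requires_grad = False` or `params.requires_grad_(False)`. The `ToyParameter` class below is a hypothetical, pure-Python stand-in (not the PyTorch API) that reproduces the same attribute-shadowing behaviour:

```python
class ToyParameter:
    """Hypothetical stand-in for a tensor; NOT the PyTorch API."""

    def __init__(self):
        self.requires_grad = True

    def requires_grad_(self, flag=True):
        # In-place setter, mirroring the naming convention of torch.Tensor.requires_grad_
        self.requires_grad = flag
        return self

buggy = ToyParameter()
buggy.requires_grad_ = False   # shadows the method; the flag is untouched
print(buggy.requires_grad)     # True

fixed = ToyParameter()
fixed.requires_grad_(False)    # actually calls the method
print(fixed.requires_grad)     # False
```

In the real model, one way to check that freezing worked is to count the trainable parameters, e.g. `sum(p.numel() for p in model.parameters() if p.requires_grad)`, which should then cover only the new classifier head.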

In [7]:
class Student(nn.Module):

    def __init__(self):
        super().__init__()

        # Convolutional layers (3 -> 16 -> 32 -> 64 channels)
        self.conv1 = nn.Conv2d(in_channels = 3, out_channels = 16, kernel_size=(5, 5), stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels = 16, out_channels = 32, kernel_size=(5, 5), stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels = 32, out_channels = 64, kernel_size=(3, 3), padding=1)

        # Fully connected layers; 64 * 3 * 3 is the flattened size of the last
        # feature map for 128x128 input images
        self.fc1 = nn.Linear(in_features= 64 * 3 * 3, out_features=250)
        self.fc2 = nn.Linear(in_features=250, out_features=50)
        self.fc3 = nn.Linear(in_features=50, out_features=10)


    def forward(self, x):

        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)

        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)

        x = F.relu(self.conv3(x))
        x = F.max_pool2d(x, 2)

        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x
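The `in_features=64 * 3 * 3` of `fc1` follows from the 128×128 input resolution set in the data transforms. Each convolution and max-pooling step shrinks one spatial dimension according to the standard output-size formula $\lfloor (n + 2p - k)/s \rfloor + 1$ (dilation 1; `F.max_pool2d(x, 2)` uses kernel 2 and stride 2). The small helper `conv_out` below is hypothetical, written only to trace the shapes through the Student:

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output size of Conv2d/MaxPool2d along one spatial dimension (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

size = 128  # images are resized to 128x128 in the transforms
size = conv_out(size, kernel=5, stride=2, padding=1)  # conv1        -> 63
size = conv_out(size, kernel=2, stride=2)             # max_pool2d   -> 31
size = conv_out(size, kernel=5, stride=2, padding=1)  # conv2        -> 15
size = conv_out(size, kernel=2, stride=2)             # max_pool2d   -> 7
size = conv_out(size, kernel=3, stride=1, padding=1)  # conv3        -> 7
size = conv_out(size, kernel=2, stride=2)             # max_pool2d   -> 3
print(size, 64 * size * size)  # spatial size and flattened features fed to fc1
```

A consequence of this design is that the Student only accepts 128×128 inputs; feeding another resolution would require recomputing the `fc1` input size.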

# Load Data

1. Download the dataset from Kaggle: [link](https://www.kaggle.com/c/dogs-vs-cats).
2. Unzip dogs-vs-cats.zip and locate the train.zip file inside it. Unzip train.zip and put the files in an easy-to-reach directory.

If you are using Google Colab, follow the steps below:

**Step 1**: 
Use the code below to upload your kaggle.json to the Colab environment (you can download kaggle.json from your Profile->Account->API Token).

```
from google.colab import files
files.upload()
```

**Step 2**:
The code below removes any existing ~/.kaggle directory, creates a new one and moves your kaggle.json into it. It also restricts the file permissions so that only your user can read and write the API key.

```
!rm -rf ~/.kaggle
!mkdir ~/.kaggle
!mv ./kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
```

**Step 3**:
Download the dataset:

```
!kaggle competitions download -c dogs-vs-cats
```

**Step 4**:
Unzip the competition archive, and then the training images, into `./data`:

```
!mkdir -p data
!unzip -o -q dogs-vs-cats.zip -d ./data/
!unzip -o -q ./data/train.zip -d ./data/
```

In [8]:
# Define data transformations for both training and testing phases

train_transform = transforms.Compose([
    #transforms.Resize(256),
    #transforms.ColorJitter(),
    #transforms.RandomCrop(224),
    #transforms.RandomHorizontalFlip(),
    transforms.Resize((128,128)),
    transforms.ToTensor()
])

val_transform = transforms.Compose([
    transforms.Resize((128,128)),
    transforms.ToTensor()
])

In [9]:
trainset = torchvision.datasets.CIFAR10(root='../data', train=True,
                                        download=True, transform=train_transform)
train_dataloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=num_workers)

testset = torchvision.datasets.CIFAR10(root='../data', train=False,
                                       download=True, transform=val_transform)
test_dataloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=num_workers)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified
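A quick consistency check on the loaders: CIFAR-10 ships with 50,000 training and 10,000 test images, so with `batch_size = 100` one epoch consists of 500 training and 100 validation iterations, which matches the `500/500` and `100/100` progress bars in the outputs below. The counts in this sketch are hard-coded rather than read from the datasets:

```python
import math

train_images, test_images = 50_000, 10_000  # CIFAR-10 train/test split sizes
batch_size = 100                            # same value as in the hyperparameter cell

# Number of batches per epoch (a partial last batch would count as one more)
print(math.ceil(train_images / batch_size), math.ceil(test_images / batch_size))
```

For comparison, a 70% training split of the roughly 25,000-image Dogs vs. Cats training set would give about 175 iterations per epoch.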



# Experiments

1. First, we fine-tune the teacher model (DenseNet-121) on our image classification task.
2. Then, we train the student model without knowledge distillation.
3. Finally, we train the student model again, this time with knowledge distillation leveraging the predictions of the fine-tuned teacher model (see step 1).

In [11]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device: {}".format(device))

Device: cuda


## Fine-tune Teacher Model on the Image Classification Task

In [12]:
teacher_model = Teacher().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(teacher_model.parameters(), lr=lr, amsgrad=True)

In [13]:
# Fine-tune the final classification layers of the teacher model
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, teacher_model, loss_fn=criterion, optimizer=optimizer)
    test(test_dataloader, teacher_model, loss_fn=criterion)
print("Done!")



Epoch 1
-------------------------------
Test Error: 
 Accuracy: 86.6%, Avg loss: 0.390192 

Epoch 2
-------------------------------
Test Error: 
 Accuracy: 88.9%, Avg loss: 0.336189 

Epoch 3
-------------------------------
Test Error: 
 Accuracy: 90.9%, Avg loss: 0.283463 

Epoch 4
-------------------------------
Test Error: 
 Accuracy: 91.5%, Avg loss: 0.267153 

Epoch 5
-------------------------------
Test Error: 
 Accuracy: 92.0%, Avg loss: 0.264558 

Epoch 6
-------------------------------
Test Error: 
 Accuracy: 91.0%, Avg loss: 0.292799 

Epoch 7
-------------------------------
Test Error: 
 Accuracy: 92.2%, Avg loss: 0.256562 

Epoch 8
-------------------------------
Test Error: 
 Accuracy: 92.6%, Avg loss: 0.269953 

Epoch 9
-------------------------------
Test Error: 
 Accuracy: 92.7%, Avg loss: 0.273649 

Epoch 10
-------------------------------
Test Error: 
 Accuracy: 93.7%, Avg loss: 0.238995 

Done!


Training: 100%|██████████| 500/500 [01:24<00:00,  5.94 Iterations/s]
Validating: 100%|██████████| 100/100 [00:05<00:00, 19.61Iterations/s]
Training: 100%|██████████| 500/500 [01:22<00:00,  6.09 Iterations/s]
Validating: 100%|██████████| 100/100 [00:05<00:00, 19.76Iterations/s]
Training: 100%|██████████| 500/500 [01:22<00:00,  6.05 Iterations/s]
Validating: 100%|██████████| 100/100 [00:05<00:00, 19.70Iterations/s]
Training: 100%|██████████| 500/500 [01:22<00:00,  6.07 Iterations/s]
Validating: 100%|██████████| 100/100 [00:05<00:00, 19.63Iterations/s]
Training: 100%|██████████| 500/500 [01:22<00:00,  6.05 Iterations/s]
Validating: 100%|██████████| 100/100 [00:05<00:00, 19.64Iterations/s]
Training: 100%|██████████| 500/500 [01:22<00:00,  6.04 Iterations/s]
Validating: 100%|██████████| 100/100 [00:05<00:00, 19.69Iterations/s]
Training: 100%|██████████| 500/500 [01:22<00:00,  6.09 Iterations/s]
Validating: 100%|██████████| 100/100 [00:05<00:00, 19.56Iterations/s]

## Train Student Model without Knowledge Distillation

In [14]:
student_model = Student().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(student_model.parameters(), lr=lr, amsgrad=True)

In [15]:
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, student_model, loss_fn=criterion, optimizer=optimizer)
    test(test_dataloader, student_model, loss_fn=criterion)
print("Done!")

Epoch 1
-------------------------------
Test Error: 
 Accuracy: 45.1%, Avg loss: 1.478924 

Epoch 2
-------------------------------
Test Error: 
 Accuracy: 50.5%, Avg loss: 1.347503 

Epoch 3
-------------------------------
Test Error: 
 Accuracy: 57.8%, Avg loss: 1.179652 

Epoch 4
-------------------------------
Test Error: 
 Accuracy: 56.3%, Avg loss: 1.207788 

Epoch 5
-------------------------------
Test Error: 
 Accuracy: 60.1%, Avg loss: 1.129663 

Epoch 6
-------------------------------
Test Error: 
 Accuracy: 64.9%, Avg loss: 0.995375 

Epoch 7
-------------------------------
Test Error: 
 Accuracy: 66.4%, Avg loss: 0.961396 

Epoch 8
-------------------------------
Test Error: 
 Accuracy: 66.7%, Avg loss: 0.945566 

Epoch 9
-------------------------------
Test Error: 
 Accuracy: 68.0%, Avg loss: 0.923723 

Epoch 10
-------------------------------
Test Error: 
 Accuracy: 68.9%, Avg loss: 0.898444 

Done!


Training: 100%|██████████| 500/500 [00:14<00:00, 33.70 Iterations/s]
Validating: 100%|██████████| 100/100 [00:03<00:00, 32.70Iterations/s]
Training: 100%|██████████| 500/500 [00:14<00:00, 33.47 Iterations/s]
Validating: 100%|██████████| 100/100 [00:03<00:00, 31.66Iterations/s]
Training: 100%|██████████| 500/500 [00:14<00:00, 33.49 Iterations/s]
Validating: 100%|██████████| 100/100 [00:03<00:00, 32.53Iterations/s]
Training: 100%|██████████| 500/500 [00:14<00:00, 33.60 Iterations/s]
Validating: 100%|██████████| 100/100 [00:03<00:00, 32.43Iterations/s]
Training: 100%|██████████| 500/500 [00:14<00:00, 33.61 Iterations/s]
Validating: 100%|██████████| 100/100 [00:03<00:00, 32.36Iterations/s]
Training: 100%|██████████| 500/500 [00:15<00:00, 33.16 Iterations/s]
Validating: 100%|██████████| 100/100 [00:03<00:00, 32.04Iterations/s]
Training: 100%|██████████| 500/500 [00:14<00:00, 33.59 Iterations/s]
Validating: 100%|██████████| 100/100 [00:03<00:00, 32.30Iterations/s]

## Train Student Model with Knowledge Distillation

In [16]:
student_model_distilled = Student()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(student_model_distilled.parameters(), lr=lr, amsgrad=True)
teacher_model = teacher_model.to(device)
student_model_distilled = student_model_distilled.to(device)

In [17]:
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_with_distillation(train_dataloader, student_model_distilled, teacher_model, loss_fn=criterion, optimizer=optimizer, alpha=alpha, temperature=temperature)
    test(test_dataloader, student_model_distilled, loss_fn=criterion)
print("Done!")


Epoch 1
-------------------------------
Train Error: 
 Accuracy: 32.7%, Avg loss: 2.334482, Avg dist loss: 0.111679

Test Error: 
 Accuracy: 45.0%, Avg loss: 1.557855 

Epoch 2
-------------------------------
Train Error: 
 Accuracy: 47.6%, Avg loss: 1.897204, Avg dist loss: 0.089274

Test Error: 
 Accuracy: 50.6%, Avg loss: 1.471544 

Epoch 3
-------------------------------
Train Error: 
 Accuracy: 53.1%, Avg loss: 1.720639, Avg dist loss: 0.080178

Test Error: 
 Accuracy: 55.7%, Avg loss: 1.389279 

Epoch 4
-------------------------------
Train Error: 
 Accuracy: 57.7%, Avg loss: 1.571372, Avg dist loss: 0.072409

Test Error: 
 Accuracy: 57.5%, Avg loss: 1.355817 

Epoch 5
-------------------------------
Train Error: 
 Accuracy: 61.6%, Avg loss: 1.440392, Avg dist loss: 0.065861

Test Error: 
 Accuracy: 60.4%, Avg loss: 1.291276 

Epoch 6
-------------------------------
Train Error: 
 Accuracy: 64.4%, Avg loss: 1.337352, Avg dist loss: 0.060624

Test Error: 
 Accuracy: 62.5%, Avg los

Training with Distillation: 100%|██████████| 500/500 [00:27<00:00, 18.23 Iterations/s]
Validating: 100%|██████████| 100/100 [00:03<00:00, 32.26Iterations/s]
Training with Distillation: 100%|██████████| 500/500 [00:27<00:00, 18.24 Iterations/s]
Validating: 100%|██████████| 100/100 [00:03<00:00, 32.30Iterations/s]
Training with Distillation: 100%|██████████| 500/500 [00:27<00:00, 18.31 Iterations/s]
Validating: 100%|██████████| 100/100 [00:03<00:00, 32.14Iterations/s]
Training with Distillation: 100%|██████████| 500/500 [00:27<00:00, 18.29 Iterations/s]
Validating: 100%|██████████| 100/100 [00:03<00:00, 32.10Iterations/s]
Training with Distillation: 100%|██████████| 500/500 [00:27<00:00, 18.22 Iterations/s]
Validating: 100%|██████████| 100/100 [00:03<00:00, 32.33Iterations/s]
Training with Distillation: 100%|██████████| 500/500 [00:27<00:00, 18.26 Iterations/s]
Validating: 100%|██████████| 100/100 [00:03<00:00, 32.16Iterations/s]

# Conclusion

Finally, we compare the performance of our two student models, one with knowledge distillation from a teacher model and a second one without. 

The accuracy is measured on part of our dataset (test set) that none of the models has seen during training. 

How good are the students at classifying the test images?

In [18]:
# Teacher model 
test(test_dataloader, teacher_model, criterion)

Validating: 100%|██████████| 100/100 [00:05<00:00, 19.49Iterations/s]


Test Error: 
 Accuracy: 93.7%, Avg loss: 0.238995 



In [19]:
# Student model with knowledge distillation.
test(test_dataloader, student_model_distilled, criterion)

Validating: 100%|██████████| 100/100 [00:03<00:00, 32.28Iterations/s]


Test Error: 
 Accuracy: 68.8%, Avg loss: 1.170198 



In [20]:
# Student model without knowledge distillation.
test(test_dataloader, student_model, criterion)

Validating: 100%|██████████| 100/100 [00:03<00:00, 31.98Iterations/s]


Test Error: 
 Accuracy: 68.9%, Avg loss: 0.898444 



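Results vary between individual runs due to the stochastic nature of training (random weight initialization and data shuffling). In the run shown above, the distilled student (68.8%) performs on par with the student trained without distillation (68.9%), so this particular run shows no gain; across other runs, we would expect knowledge distillation to increase model accuracy by roughly 3-6 percentage points.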

Besides accuracy, model size also plays an important role. The aim of this experiment was to show how smaller models can learn from larger, more complex models through KD. Let's inspect how much smaller our student models are in comparison to the teacher model:

In [21]:
# Calculate the model size 
model_size_student = get_model_size(student_model)
model_size_teacher = get_model_size(teacher_model)

print('model size teachermodel : {:.3f}MB'.format(model_size_teacher))
print('model size student model: {:.3f}MB'.format(model_size_student))

model size teachermodel : 27.874MB
model size student model: 0.724MB


In [22]:
percentage_reduction = ((model_size_teacher - model_size_student)/model_size_teacher)*100
print('reduction in model size: {:.2f}%'.format(percentage_reduction))

reduction in model size: 97.40%
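Expressed as a positive percentage and as a compression ratio, the two sizes printed above work out as follows. The numbers are copied from the output of `get_model_size`; nothing is re-measured here:

```python
teacher_mb, student_mb = 27.874, 0.724  # sizes reported by get_model_size above

reduction_pct = (teacher_mb - student_mb) / teacher_mb * 100
compression_ratio = teacher_mb / student_mb
print(f"reduction: {reduction_pct:.2f}%, teacher is {compression_ratio:.1f}x larger")
```

In other words, the student needs roughly 1/38 of the teacher's memory for its parameters and buffers.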


# References 
- [Maximizing Model Performance with Knowledge Distillation](https://medium.com/artificialis/maximizing-model-performance-with-knowledge-distillation-in-pytorch-12b3960a486a)
- [Distilling the Knowledge in a Neural Network](https://arxiv.org/abs/1503.02531)
- [Knowledge Distillation: A Survey](https://arxiv.org/abs/2006.05525)


## Dataroots blog posts on model compression
- https://dataroots.io/research/contributions/deep-learning-model-compression/?ref=dataroots.ghost.io
- https://dataroots.io/research/contributions/model_compression/