https://arxiv.org/pdf/1503.02531.pdf

# Knowledge Distillation 

Current state-of-the-art performance in AI and ML is mainly driven by large and complex deep neural network models that often consist of multiple billions of model parameters. Fortunately, for many applications, pre-trained models can be leveraged through transfer learning, avoiding the burden of training large models from scratch.

However, as transfer learning from large pre-trained models becomes more prevalent, deploying these models on devices with limited processing power, such as edge devices (e.g. IoT devices), is challenging. While deep learning models often achieve excellent accuracy, they frequently fail to meet other requirements such as latency and memory footprint.

In this notebook, we demonstrate how a compression technique called knowledge distillation (KD) helps to transfer knowledge from a larger into a smaller, more compact neural network model. In this way, we can benefit (partially) from the knowledge of the larger model and still retain the small memory footprint and inference latency of the smaller model.

## What is Knowledge Distillation exactly?

<center>
<img src="https://i0.wp.com/neptune.ai/wp-content/uploads/2022/10/Knowledge-Distillation_1.png?ssl=1" width=30%><br>
<a href="https://arxiv.org/pdf/2006.05525.pdf">Source</a>
</center>

This notebook focuses on *response-based knowledge distillation*. Response-based knowledge distillation is a compression technique where a student model is optimised to reproduce the outputs of a larger 'teacher' model. The technique is described in the paper by Hinton et al. (2015) ([paper](https://arxiv.org/abs/1503.02531)). 

The idea behind response-based KD is intuitive: we first train a large 'teacher' model and then show the teacher's predictions to the 'student' model. When training the student, the loss combines two terms: the regular loss between the student's predictions and the ground-truth labels, and a distillation loss between the student's logits and the teacher's logits. Hence, the student learns from both the true labels and the predictions made by the teacher.
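
Written out, the combined objective that the `DistillationLoss` class later in this notebook implements can be sketched as follows. The symbols are notation chosen here for illustration and do not come from the original text: $z_s$ and $z_t$ are the student and teacher logits, $y$ is the true label, $\sigma$ is the softmax, $T$ is the temperature and $\alpha$ is the mixing weight.

$$
\mathcal{L} = (1 - \alpha)\,\mathrm{CE}\big(y,\ \sigma(z_s)\big) + \alpha\,\mathrm{KL}\big(\sigma(z_t / T)\ \|\ \sigma(z_s / T)\big)
$$

With $\alpha = 0$ this reduces to ordinary supervised training. Hinton et al. additionally multiply the distillation term by $T^2$ when combining it with the hard-label loss; the notebook's loss does not include this factor, which makes no difference at its default $T = 1$.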

<center>
<img src="https://miro.medium.com/v2/resize:fit:1400/0*B8vlOvK1N_CSgZMo" width=30%><br>
<a href="https://arxiv.org/pdf/2006.05525.pdf">Source</a>
</center>

For more details on other KD approaches, see the survey paper by Gou et al. ([paper](https://arxiv.org/pdf/2006.05525.pdf)).
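
To make the role of the temperature concrete, here is a minimal sketch in plain Python (not PyTorch) of a softmax with temperature. The function name `softmax_with_temperature` and the example logits are hypothetical and only serve as illustration; the notebook itself uses `F.softmax` on logits divided by the `temperature` attribute of `DistillationLoss`.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Softmax over a list of logits after dividing them by `temperature`."""
    scaled = [z / temperature for z in logits]
    # Subtract the maximum for numerical stability; the result is unchanged.
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical teacher logits for one image, ordered as (cat, dog).
teacher_logits = [4.0, 1.0]

for T in (1.0, 2.0, 5.0):
    probs = softmax_with_temperature(teacher_logits, T)
    print(f"T={T}: cat={probs[0]:.3f}, dog={probs[1]:.3f}")
```

A higher temperature moves the probabilities closer to uniform, so the student sees more of the teacher's relative confidence in the non-predicted class instead of a near one-hot target.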

## Where is Knowledge Distillation used?

A famous example of knowledge distillation is the DistilBERT model. DistilBERT ([link](https://arxiv.org/pdf/1910.01108.pdf)) is a faster and lighter version of the BERT model ([link](https://arxiv.org/abs/1810.04805)). Thanks to knowledge distillation, DistilBERT is 40% smaller and 60% faster than BERT while retaining 97% of its language understanding capabilities.

## In this tutorial

In this notebook we use the Dogs vs. Cats dataset from Kaggle. The dataset contains approx. 25,000 images of cats and dogs. The goal is to train a computer vision model that predicts whether an image shows a cat or a dog. The aim is to illustrate how the accuracy of a very simple CNN can be boosted through knowledge distillation from a larger and more complex model (DenseNet-121).

- [Setup](#Setup)
- [Functions](#Functions)
- [Data](#Load-data)
- [Experiments](#Experiments)
    - Fine tune the teacher model on the Dogs vs. Cats prediction task
    - Train the student model without knowledge distillation
    - Train the student model with knowledge distillation
- [Conclusion](#Conclusion)

## Setup

In [1]:
import os
import tqdm
import numpy as np
import PIL
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms



In [2]:
# path to the directory containing the training data
data_dir = '../data/train'

# Train, validation and test set fractions
train_val_test_split = (0.7, 0.2, 0.1)

# Hyperparameters for training our models
num_workers = 6
batch_size = 100
epochs = 10
lr = 0.001

## Functions

Here we define a number of Python functions and classes that we use in our notebook:

- `get_model_size`: calculates how 'large' a model is in terms of memory footprint.
- `DistillationLoss`: a class defining the distillation loss function.
- `train`: a function to train a PyTorch model.
- `test`: a function to evaluate our trained PyTorch models.
- `train_with_distillation`: similar to the `train` function, except that the distillation loss is used instead of the regular loss function.

In [3]:
def get_model_size(model):
    """Calculate the size of a model in MB (parameters plus buffers).

    Args:
        model (nn.Module): PyTorch model

    Returns:
        float: model size in MB (1 MB = 1024**2 bytes)
    """
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
    
    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
    
    size_all_mb = (param_size + buffer_size) / 1024**2

    return size_all_mb

In [4]:
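class DistillationLoss:

    """Custom loss calculation combining
    the loss of the student model with the distillation loss
    """

    def __init__(self, student_loss, temperature=1, alpha=0.25):
        self.student_loss = student_loss
        # 'batchmean' sums over classes and averages over the batch, which
        # matches the mathematical definition of the KL divergence
        self.distillation_loss = nn.KLDivLoss(reduction='batchmean')
        # Store the constructor arguments instead of hard-coded values
        self.temperature = temperature
        self.alpha = alpha

    def __call__(self, student_logits, student_target_loss, teacher_logits):
        # KLDivLoss expects log-probabilities as input and probabilities as target
        distillation_loss = self.distillation_loss(
            F.log_softmax(student_logits / self.temperature, dim=1),
            F.softmax(teacher_logits / self.temperature, dim=1))
        # Weighted sum of the regular student loss and the distillation loss
        loss = (1 - self.alpha) * student_target_loss \
            + self.alpha * distillation_loss
        return loss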
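
As a sanity check of what this loss computes, the following plain-Python sketch works through the combined loss for a single sample. It is an illustration under stated assumptions, not the notebook's implementation: the helper names and the example logits are hypothetical, and the KL term is summed over classes for one sample, whereas how PyTorch averages over a batch depends on the `reduction` argument of `nn.KLDivLoss`.

```python
import math

def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, target_index):
    """Negative log-probability the model assigns to the true class."""
    return -math.log(softmax(logits)[target_index])

def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions given as lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def distillation_loss(student_logits, teacher_logits, target_index,
                      temperature=1.0, alpha=0.25):
    # Hard-label term: student versus ground truth.
    hard = cross_entropy(student_logits, target_index)
    # Soft-label term: teacher distribution versus student distribution.
    soft = kl_divergence(softmax(teacher_logits, temperature),
                         softmax(student_logits, temperature))
    return (1 - alpha) * hard + alpha * soft

# Hypothetical logits for one dog image (class index 1).
student_logits = [0.2, 0.5]
teacher_logits = [-1.0, 2.0]
loss = distillation_loss(student_logits, teacher_logits, target_index=1)
print(f"combined loss: {loss:.4f}")
```

When the student reproduces the teacher's distribution exactly, the KL term vanishes and only the weighted hard-label loss remains.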


In [5]:
def train(dataloader, model, loss_fn, optimizer):

    """Simple training function looping over a dataloader to optimize a model with the given optimizer and loss function.
    """

    model.train()
    for X, y in tqdm.tqdm(dataloader, desc = "Training", unit = " Iterations"):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()



def test(dataloader, model, loss_fn):

    """test function evaluating a trained model on test data provided through the dataloader argument.
    """

    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in tqdm.tqdm(dataloader, desc = "Validating", unit="Iterations"):
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

def train_with_distillation(dataloader, student_model, teacher_model, loss_fn, optimizer):

    """Training function to train a student_model with knowledge distillation from a teacher_model.
    """

    distillation_loss = DistillationLoss(student_loss=loss_fn)
    student_model.train()
    # The teacher is only used for inference, so keep it in evaluation mode
    teacher_model.eval()

    for X, y in tqdm.tqdm(dataloader, desc = "Training with Distillation", unit = " Iterations"):
        X, y = X.to(device), y.to(device)

        # Let student and teacher both make predictions
        pred_student = student_model(X)
        # No gradients are needed for the teacher, which saves memory and time
        with torch.no_grad():
            pred_teacher = teacher_model(X)

        # Compute the regular student loss
        student_target_loss = loss_fn(pred_student, y)
        # Combine student loss with the loss resulting from difference between student and teacher predictions
        loss = distillation_loss(pred_student, student_target_loss, pred_teacher)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        

## Student/Teacher Models

In this section we define our Student and Teacher model architectures. For the Teacher we use a pretrained DenseNet-121 with a modified classifier head. For the Student we implement a very simple and shallow CNN with only three convolutional layers.

In [6]:
class Teacher(nn.Module):
    def __init__(self):
        super().__init__()
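        self.model = torchvision.models.densenet121(weights='DEFAULT')
        # Freeze the pretrained backbone so that only the new classifier head is trained
        for params in self.model.parameters():
            params.requires_grad = False

        # Replace the classifier head; the new layers are trainable by default
        num_ftrs = self.model.classifier.in_features
        self.model.classifier = nn.Sequential(
            nn.Linear(num_ftrs, 500),
            nn.Linear(500, 2)
            )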
        
    def forward(self, x):
        x = self.model(x)
        return x

In [7]:
class Student(nn.Module):

    def __init__(self):
        super().__init__()

        # convolutional layers (3 -> 16 -> 32 -> 64 channels)
        self.conv1 = nn.Conv2d(in_channels = 3, out_channels = 16, kernel_size=(5, 5), stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels = 16, out_channels = 32, kernel_size=(5, 5), stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels = 32, out_channels = 64, kernel_size=(3, 3), padding=1)

        # fully connected layers
        self.fc1 = nn.Linear(in_features= 64 * 3 * 3, out_features=500)
        self.fc2 = nn.Linear(in_features=500, out_features=50)
        self.fc3 = nn.Linear(in_features=50, out_features=2)


    def forward(self, x):

        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)

        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)

        x = F.relu(self.conv3(x))
        x = F.max_pool2d(x, 2)

        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x
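
The `64 * 3 * 3` input size of `fc1` follows from the 128x128 images produced by the transforms in this notebook. The sketch below traces the spatial size through the Student's layers using the standard output-size formula for convolutions and pooling; the helper names `conv_out` and `student_feature_size` are hypothetical and only used for this illustration.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a square convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1

def student_feature_size(image_size=128):
    """Spatial side length of the feature map that reaches fc1."""
    s = conv_out(image_size, kernel=5, stride=2, padding=1)  # conv1
    s = conv_out(s, kernel=2, stride=2)                      # max_pool2d(x, 2)
    s = conv_out(s, kernel=5, stride=2, padding=1)           # conv2
    s = conv_out(s, kernel=2, stride=2)                      # max_pool2d(x, 2)
    s = conv_out(s, kernel=3, stride=1, padding=1)           # conv3
    s = conv_out(s, kernel=2, stride=2)                      # max_pool2d(x, 2)
    return s

side = student_feature_size(128)
print(side, 64 * side * side)
```

For a different input resolution the flattened size changes, so `fc1` would need a matching `in_features` value.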

# Load Data

1. Download the dataset from Kaggle: [link](https://www.kaggle.com/c/dogs-vs-cats).
2. Unzip dogs-vs-cats.zip and find the train.zip file inside. Unzip train.zip and put the files in an easy-to-reach directory.

If you are using Google Colab, follow these steps:
1. Upload your Kaggle API key (JSON file)
2. Move the API key to the right location
3. Download the dataset
4. Unzip the dataset



**Step 1**:
Use the code below to upload your kaggle.json to the Colab environment (you can download kaggle.json from your Profile->Account->API Token).

```
from google.colab import files
files.upload()
```

**Step 2**:
The code below removes any existing ~/.kaggle directory and creates a new one. It also moves your kaggle.json to ~/.kaggle and restricts the file's permissions.

```
!rm -rf ~/.kaggle
!mkdir ~/.kaggle
!mv ./kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
```

**Step 3**:
Download the dataset.

```
!kaggle competitions download -c dogs-vs-cats
```

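**Step 4**:
Create a data directory and unzip the dataset into it. With these commands the images should end up under `./data/train`, so adjust `data_dir` accordingly.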

```
!mkdir data
!unzip -o -q dogs-vs-cats.zip -d ./data/ 
!unzip -o -q ./data/train.zip -d ./data/ 
```

In [8]:
# path to the directory containing the training data
data_dir = '../data/train'

# Select only the image files from your data directory.
# Sorting makes the split below reproducible, since os.listdir returns files in arbitrary order.
files = sorted(f for f in os.listdir(data_dir) if f.endswith('.jpg'))


In [9]:
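# Define data transformations for both training and testing phases

train_transform = transforms.Compose([
    transforms.Resize(256),             # resize the shorter side to 256 pixels
    transforms.ColorJitter(),           # with default arguments this leaves the image unchanged
    transforms.RandomCrop(224),         # random 224x224 crop as augmentation
    transforms.RandomHorizontalFlip(),
    transforms.Resize((128,128)),       # final input size used by both models
    transforms.ToTensor()
])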

val_transform = transforms.Compose([
    transforms.Resize((128,128)),
    transforms.ToTensor()
])

In [10]:
class DogsVsCatsDataset(Dataset):
    def __init__(self, file_list, dir, mode='train', transform = val_transform):
        self.file_list = file_list
        self.dir = dir
        self.transform = transform
            
    def __len__(self):
        return len(self.file_list)
    
    def __getitem__(self, idx):
        # Force three channels so every image matches the models' expected input
        img = PIL.Image.open(os.path.join(self.dir, self.file_list[idx])).convert('RGB')
        # The transform ends with ToTensor, so img is a float32 tensor of shape (3, H, W)
        img = self.transform(img)
        # The label is encoded in the file name: dog -> 1, cat -> 0
        label = 1 if 'dog' in self.file_list[idx] else 0
        return img, label


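# First split off the test set (10% of all files)
train_files, test_files = train_test_split(files, 
                                    test_size=train_val_test_split[2], 
                                    random_state=42
                                    )
# Then split the remainder into train and validation sets.
# Note: this ratio (0.2 / 0.7) holds out about 29% of the remaining files, so the
# effective split is roughly 64/26/10 rather than 70/20/10. Dividing by
# (train_val_test_split[0] + train_val_test_split[1]) instead would give 70/20/10.
train_files, val_files = train_test_split(train_files,
                                    test_size=train_val_test_split[1]/train_val_test_split[0], 
                                    random_state=42
                                    )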

train_dataset = DogsVsCatsDataset(train_files, dir = data_dir, transform = train_transform)
train_dataloader = DataLoader(train_dataset, batch_size = batch_size, shuffle=True, num_workers=num_workers)

val_dataset = DogsVsCatsDataset(val_files, dir = data_dir, transform = val_transform)
val_dataloader = DataLoader(val_dataset, batch_size = batch_size, shuffle=False, num_workers=num_workers)

test_dataset = DogsVsCatsDataset(test_files, dir = data_dir, transform = val_transform)
test_dataloader = DataLoader(test_dataset, batch_size = batch_size, shuffle=False, num_workers=num_workers)
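
The resulting dataset sizes can be checked by hand. The sketch below reproduces the two-stage split arithmetic in plain Python, assuming exactly 25,000 image files and assuming scikit-learn's behaviour of rounding the held-out size up and giving the remainder to the other part; the helper `split_sizes` is hypothetical. Its result agrees with the 161 training, 65 validation and 25 test batches reported by the progress bars in the logs of this notebook.

```python
import math

def split_sizes(n_samples, fractions):
    """Reproduce the two-stage split arithmetic used in the cell above.

    Assumes the held-out size is rounded up and the remainder goes
    to the other part.
    """
    train_frac, val_frac, test_frac = fractions
    # Stage 1: hold out the test set from all samples.
    n_test = math.ceil(test_frac * n_samples)
    remaining = n_samples - n_test
    # Stage 2: hold out val_frac / train_frac of what is left.
    n_val = math.ceil(val_frac / train_frac * remaining)
    n_train = remaining - n_val
    return n_train, n_val, n_test

n_train, n_val, n_test = split_sizes(25_000, (0.7, 0.2, 0.1))
print(n_train, n_val, n_test)
# Number of batches per epoch at batch_size = 100
print([math.ceil(n / 100) for n in (n_train, n_val, n_test)])
```

In other words, the training set holds about 64% of the images and the validation set about 26%, rather than the 70% and 20% that `train_val_test_split` suggests.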

# Experiments

1. First, we fine-tune the teacher model (DenseNet) on our Dogs vs. Cats prediction task.
2. Next, we train the student model without knowledge distillation.
3. Finally, we train the student model again, this time with knowledge distillation, leveraging the predictions of the fine-tuned teacher model (see step 1).

In [11]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device: {}".format(device))

Device: cuda


## Fine-tune Teacher model on Dogs vs. Cats Prediction Task

In [12]:
teacher_model = Teacher().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(teacher_model.parameters(), lr=lr, amsgrad=True)

In [13]:
# Fine-tune the final classification layers of the teacher model
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, teacher_model, loss_fn=criterion, optimizer=optimizer)
    test(val_dataloader, teacher_model, loss_fn=criterion)
print("Done!")



Epoch 1
-------------------------------


Training: 100%|██████████| 161/161 [01:51<00:00,  1.45 Iterations/s]
Validating: 100%|██████████| 65/65 [00:45<00:00,  1.43Iterations/s]


Test Error: 
 Accuracy: 93.5%, Avg loss: 0.170367 

Epoch 2
-------------------------------


Training: 100%|██████████| 161/161 [01:54<00:00,  1.41 Iterations/s]
Validating: 100%|██████████| 65/65 [00:45<00:00,  1.43Iterations/s]


Test Error: 
 Accuracy: 94.7%, Avg loss: 0.135645 

Epoch 3
-------------------------------


Training: 100%|██████████| 161/161 [01:55<00:00,  1.39 Iterations/s]
Validating: 100%|██████████| 65/65 [00:45<00:00,  1.42Iterations/s]


Test Error: 
 Accuracy: 90.9%, Avg loss: 0.296362 

Epoch 4
-------------------------------


Training: 100%|██████████| 161/161 [01:54<00:00,  1.41 Iterations/s]
Validating: 100%|██████████| 65/65 [00:45<00:00,  1.41Iterations/s]


Test Error: 
 Accuracy: 95.6%, Avg loss: 0.108159 

Epoch 5
-------------------------------


Training: 100%|██████████| 161/161 [01:56<00:00,  1.39 Iterations/s]
Validating: 100%|██████████| 65/65 [00:46<00:00,  1.40Iterations/s]


Test Error: 
 Accuracy: 93.6%, Avg loss: 0.166912 

Epoch 6
-------------------------------


Training: 100%|██████████| 161/161 [01:54<00:00,  1.40 Iterations/s]
Validating: 100%|██████████| 65/65 [00:45<00:00,  1.41Iterations/s]


Test Error: 
 Accuracy: 94.7%, Avg loss: 0.137621 

Epoch 7
-------------------------------


Training: 100%|██████████| 161/161 [01:56<00:00,  1.38 Iterations/s]
Validating: 100%|██████████| 65/65 [00:45<00:00,  1.43Iterations/s]


Test Error: 
 Accuracy: 95.8%, Avg loss: 0.108259 

Epoch 8
-------------------------------


Training: 100%|██████████| 161/161 [01:54<00:00,  1.40 Iterations/s]
Validating: 100%|██████████| 65/65 [00:45<00:00,  1.43Iterations/s]


Test Error: 
 Accuracy: 96.1%, Avg loss: 0.100093 

Epoch 9
-------------------------------


Training: 100%|██████████| 161/161 [01:54<00:00,  1.40 Iterations/s]
Validating: 100%|██████████| 65/65 [00:45<00:00,  1.42Iterations/s]


Test Error: 
 Accuracy: 94.6%, Avg loss: 0.148131 

Epoch 10
-------------------------------


Training: 100%|██████████| 161/161 [01:56<00:00,  1.38 Iterations/s]
Validating: 100%|██████████| 65/65 [00:45<00:00,  1.41Iterations/s]

Test Error: 
 Accuracy: 94.8%, Avg loss: 0.133100 

Done!





## Train Student Model without Knowledge Distillation

In [15]:
student_model = Student().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(student_model.parameters(), lr=lr, amsgrad=True)

In [16]:
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, student_model, loss_fn=criterion, optimizer=optimizer)
    test(val_dataloader, student_model, loss_fn=criterion)
print("Done!")

Epoch 1
-------------------------------


Training:  93%|█████████▎| 149/161 [01:49<00:07,  1.63 Iterations/s]

## Train Student Model with Knowledge Distillation

In [None]:
student_model_distilled = Student()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(student_model_distilled.parameters(), lr=lr, amsgrad=True)
teacher_model = teacher_model.to(device)
student_model_distilled = student_model_distilled.to(device)

In [None]:
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_with_distillation(train_dataloader, student_model_distilled, teacher_model, loss_fn=criterion, optimizer=optimizer)
    test(val_dataloader, student_model_distilled, loss_fn=criterion)
print("Done!")


Epoch 1
-------------------------------


Training with Distillation: 100%|██████████| 161/161 [01:57<00:00,  1.37 Iterations/s]
Validating: 100%|██████████| 65/65 [00:45<00:00,  1.41Iterations/s]


Test Error: 
 Accuracy: 57.7%, Avg loss: 0.680282 

Epoch 2
-------------------------------


Training with Distillation: 100%|██████████| 161/161 [01:55<00:00,  1.39 Iterations/s]
Validating: 100%|██████████| 65/65 [00:47<00:00,  1.38Iterations/s]


Test Error: 
 Accuracy: 62.0%, Avg loss: 0.656579 

Epoch 3
-------------------------------


Training with Distillation: 100%|██████████| 161/161 [01:55<00:00,  1.39 Iterations/s]
Validating: 100%|██████████| 65/65 [00:45<00:00,  1.42Iterations/s]


Test Error: 
 Accuracy: 69.8%, Avg loss: 0.577083 

Epoch 4
-------------------------------


Training with Distillation: 100%|██████████| 161/161 [01:56<00:00,  1.39 Iterations/s]
Validating: 100%|██████████| 65/65 [00:46<00:00,  1.41Iterations/s]


Test Error: 
 Accuracy: 73.4%, Avg loss: 0.543792 

Epoch 5
-------------------------------


Training with Distillation: 100%|██████████| 161/161 [01:55<00:00,  1.40 Iterations/s]
Validating: 100%|██████████| 65/65 [00:45<00:00,  1.42Iterations/s]


Test Error: 
 Accuracy: 72.7%, Avg loss: 0.543518 

Epoch 6
-------------------------------


Training with Distillation: 100%|██████████| 161/161 [01:56<00:00,  1.38 Iterations/s]
Validating: 100%|██████████| 65/65 [00:45<00:00,  1.43Iterations/s]


Test Error: 
 Accuracy: 75.7%, Avg loss: 0.516575 

Epoch 7
-------------------------------


Training with Distillation: 100%|██████████| 161/161 [01:55<00:00,  1.40 Iterations/s]
Validating: 100%|██████████| 65/65 [00:45<00:00,  1.42Iterations/s]


Test Error: 
 Accuracy: 77.9%, Avg loss: 0.473942 

Epoch 8
-------------------------------


Training with Distillation: 100%|██████████| 161/161 [01:54<00:00,  1.40 Iterations/s]
Validating: 100%|██████████| 65/65 [00:45<00:00,  1.43Iterations/s]


Test Error: 
 Accuracy: 75.5%, Avg loss: 0.522050 

Epoch 9
-------------------------------


Training with Distillation: 100%|██████████| 161/161 [01:55<00:00,  1.39 Iterations/s]
Validating: 100%|██████████| 65/65 [00:45<00:00,  1.42Iterations/s]


Test Error: 
 Accuracy: 79.6%, Avg loss: 0.447792 

Epoch 10
-------------------------------


Training with Distillation: 100%|██████████| 161/161 [01:55<00:00,  1.39 Iterations/s]
Validating: 100%|██████████| 65/65 [00:46<00:00,  1.39Iterations/s]


Test Error: 
 Accuracy: 77.9%, Avg loss: 0.493075 

Epoch 11
-------------------------------


Training with Distillation: 100%|██████████| 161/161 [01:52<00:00,  1.43 Iterations/s]
Validating: 100%|██████████| 65/65 [00:45<00:00,  1.43Iterations/s]


Test Error: 
 Accuracy: 80.0%, Avg loss: 0.436843 

Epoch 12
-------------------------------


Training with Distillation: 100%|██████████| 161/161 [01:55<00:00,  1.39 Iterations/s]
Validating: 100%|██████████| 65/65 [00:46<00:00,  1.40Iterations/s]


Test Error: 
 Accuracy: 77.2%, Avg loss: 0.494972 

Epoch 13
-------------------------------


Training with Distillation: 100%|██████████| 161/161 [01:53<00:00,  1.42 Iterations/s]
Validating: 100%|██████████| 65/65 [00:46<00:00,  1.41Iterations/s]


Test Error: 
 Accuracy: 79.6%, Avg loss: 0.455531 

Epoch 14
-------------------------------


Training with Distillation: 100%|██████████| 161/161 [01:55<00:00,  1.40 Iterations/s]
Validating: 100%|██████████| 65/65 [00:45<00:00,  1.42Iterations/s]


Test Error: 
 Accuracy: 80.3%, Avg loss: 0.446204 

Epoch 15
-------------------------------


Training with Distillation: 100%|██████████| 161/161 [01:54<00:00,  1.40 Iterations/s]
Validating: 100%|██████████| 65/65 [00:45<00:00,  1.42Iterations/s]


Test Error: 
 Accuracy: 81.0%, Avg loss: 0.420882 

Epoch 16
-------------------------------


Training with Distillation: 100%|██████████| 161/161 [01:55<00:00,  1.40 Iterations/s]
Validating: 100%|██████████| 65/65 [00:45<00:00,  1.42Iterations/s]


Test Error: 
 Accuracy: 78.3%, Avg loss: 0.466342 

Epoch 17
-------------------------------


Training with Distillation: 100%|██████████| 161/161 [01:54<00:00,  1.40 Iterations/s]
Validating: 100%|██████████| 65/65 [00:45<00:00,  1.43Iterations/s]


Test Error: 
 Accuracy: 81.2%, Avg loss: 0.420369 

Epoch 18
-------------------------------


Training with Distillation: 100%|██████████| 161/161 [01:55<00:00,  1.40 Iterations/s]
Validating: 100%|██████████| 65/65 [00:47<00:00,  1.37Iterations/s]


Test Error: 
 Accuracy: 79.8%, Avg loss: 0.443457 

Epoch 19
-------------------------------


Training with Distillation: 100%|██████████| 161/161 [01:55<00:00,  1.39 Iterations/s]
Validating: 100%|██████████| 65/65 [00:45<00:00,  1.43Iterations/s]


Test Error: 
 Accuracy: 78.9%, Avg loss: 0.442242 

Epoch 20
-------------------------------


Training with Distillation: 100%|██████████| 161/161 [01:55<00:00,  1.40 Iterations/s]
Validating: 100%|██████████| 65/65 [00:46<00:00,  1.41Iterations/s]


Test Error: 
 Accuracy: 80.1%, Avg loss: 0.439714 

Epoch 21
-------------------------------


Training with Distillation: 100%|██████████| 161/161 [01:55<00:00,  1.39 Iterations/s]
Validating: 100%|██████████| 65/65 [00:45<00:00,  1.42Iterations/s]


Test Error: 
 Accuracy: 83.2%, Avg loss: 0.377458 

Epoch 22
-------------------------------


Training with Distillation: 100%|██████████| 161/161 [01:54<00:00,  1.40 Iterations/s]
Validating: 100%|██████████| 65/65 [00:45<00:00,  1.42Iterations/s]


Test Error: 
 Accuracy: 82.6%, Avg loss: 0.392714 

Epoch 23
-------------------------------


Training with Distillation: 100%|██████████| 161/161 [01:54<00:00,  1.41 Iterations/s]
Validating: 100%|██████████| 65/65 [00:45<00:00,  1.43Iterations/s]


Test Error: 
 Accuracy: 83.1%, Avg loss: 0.386031 

Epoch 24
-------------------------------


Training with Distillation: 100%|██████████| 161/161 [01:54<00:00,  1.40 Iterations/s]
Validating: 100%|██████████| 65/65 [00:45<00:00,  1.42Iterations/s]


Test Error: 
 Accuracy: 83.9%, Avg loss: 0.375174 

Epoch 25
-------------------------------


Training with Distillation: 100%|██████████| 161/161 [01:55<00:00,  1.39 Iterations/s]
Validating: 100%|██████████| 65/65 [00:44<00:00,  1.45Iterations/s]


Test Error: 
 Accuracy: 83.7%, Avg loss: 0.383439 

Epoch 26
-------------------------------


Training with Distillation: 100%|██████████| 161/161 [01:54<00:00,  1.41 Iterations/s]
Validating: 100%|██████████| 65/65 [00:45<00:00,  1.43Iterations/s]


Test Error: 
 Accuracy: 82.1%, Avg loss: 0.417277 

Epoch 27
-------------------------------


Training with Distillation: 100%|██████████| 161/161 [01:55<00:00,  1.40 Iterations/s]
Validating: 100%|██████████| 65/65 [00:45<00:00,  1.44Iterations/s]


Test Error: 
 Accuracy: 84.3%, Avg loss: 0.363048 

Epoch 28
-------------------------------


Training with Distillation: 100%|██████████| 161/161 [01:53<00:00,  1.42 Iterations/s]
Validating: 100%|██████████| 65/65 [00:45<00:00,  1.42Iterations/s]


Test Error: 
 Accuracy: 83.4%, Avg loss: 0.395298 

Epoch 29
-------------------------------


Training with Distillation: 100%|██████████| 161/161 [01:54<00:00,  1.41 Iterations/s]
Validating: 100%|██████████| 65/65 [00:45<00:00,  1.43Iterations/s]


Test Error: 
 Accuracy: 84.7%, Avg loss: 0.358481 

Epoch 30
-------------------------------


Training with Distillation: 100%|██████████| 161/161 [01:54<00:00,  1.40 Iterations/s]
Validating: 100%|██████████| 65/65 [00:45<00:00,  1.41Iterations/s]

Test Error: 
 Accuracy: 84.6%, Avg loss: 0.353176 

Done!





# Conclusion

In [None]:
test(test_dataloader, student_model_distilled, criterion)

Validating: 100%|██████████| 25/25 [00:19<00:00,  1.30Iterations/s]

Test Error: 
 Accuracy: 84.4%, Avg loss: 0.360117 






In [None]:
test(test_dataloader, student_model, criterion)

Validating: 100%|██████████| 25/25 [00:19<00:00,  1.30Iterations/s]

Test Error: 
 Accuracy: 83.9%, Avg loss: 0.352519 






In [None]:
test(test_dataloader, teacher_model, criterion)

Validating: 100%|██████████| 25/25 [00:18<00:00,  1.32Iterations/s]

Test Error: 
 Accuracy: 96.2%, Avg loss: 0.135738 






In [None]:
# Calculate the model size 
model_size_student = get_model_size(student_model)
model_size_teacher = get_model_size(teacher_model)

print('model size teachermodel : {:.3f}MB'.format(model_size_teacher))
print('model size student model: {:.3f}MB'.format(model_size_student))

model size teachermodel : 28.806MB
model size student model: 1.321MB
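
The student figure can be verified by counting parameters by hand. The sketch below is a plain-Python check based on the layer shapes of the `Student` class, assuming 4-byte float32 parameters and no buffers; the helper names are hypothetical.

```python
def conv2d_params(in_ch, out_ch, kernel):
    """Weights plus biases of a Conv2d layer with a square kernel."""
    return in_ch * out_ch * kernel * kernel + out_ch

def linear_params(in_features, out_features):
    """Weights plus biases of a Linear layer."""
    return in_features * out_features + out_features

student_params = (
    conv2d_params(3, 16, 5)             # conv1
    + conv2d_params(16, 32, 5)          # conv2
    + conv2d_params(32, 64, 3)          # conv3
    + linear_params(64 * 3 * 3, 500)    # fc1
    + linear_params(500, 50)            # fc2
    + linear_params(50, 2)              # fc3
)

bytes_per_float32 = 4
student_mb = student_params * bytes_per_float32 / 1024**2
print(student_params, f"{student_mb:.3f}MB")

# Ratio of the two sizes reported in the output above
print(f"teacher / student: {28.806 / 1.321:.1f}x")
```

Most of the student's parameters sit in `fc1`, and the teacher is roughly 22 times larger than the student.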


In [None]:
#torch.save(teacher_model, '../models/teacher_model_densenet121_10epochs.pt')
#torch.save(student_model, '../models/student_model_10epochs.pt')
#torch.save(student_model_distilled, '../models/student_model_distilled_10epochs.pt')

# References 



## Dataroots blog posts on model compression
- https://dataroots.io/research/contributions/deep-learning-model-compression/?ref=dataroots.ghost.io
- https://dataroots.io/research/contributions/model_compression/