# Performing Knowledge Distillation on CIFAR-100 using PyTorch

This is the third practical exercise of our course [Applied Edge AI](https://learn.ki-campus.org/courses/edgeai-hpi2022).
In the last exercise, we trained a neural network for image classification on CIFAR-100 using PyTorch.
In this exercise, we build on that work and distill the knowledge of a larger network into a smaller network.

As in the previous exercise, we provide you with a notebook in which some code sections are missing.
In the graded quiz at the end of the week, we might ask some questions that deal with this exercise, so make sure to do the exercise (and have your output handy) **before** taking the quiz!

# Reusing Code

In the last exercise, we wrote quite a lot of code that can be reused here.
We have already added all of this code in the following cells.
There is nothing you need to do here, since you already wrote this code in the last exercise.

We start with the imports:

In [1]:
import math
import pickle
import statistics

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as tt
import imgaug

from collections import defaultdict
from typing import Type, List, Union

from imgaug import augmenters as iaa
from pathlib import Path
from PIL import Image
from tqdm.notebook import tqdm, trange
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler, OneCycleLR
from torch.utils.data import DataLoader, Dataset

import cifar100_resnets as models

In the following cell, we added the code for data loading:

In [2]:
class CIFAR100(Dataset):
    
    def __init__(self, dataset_path: Path, image_transforms: tt.Compose, image_augmentations: Union[None, Type[iaa.Augmenter]] = None):
        super().__init__()
        data = pickle.load(dataset_path.open("rb"), encoding="bytes")
        self.images = data[b"data"]
        self.labels = data[b"fine_labels"]
        
        self.image_transforms = image_transforms
        self.image_augmentations = image_augmentations
        
        assert len(self.images) == len(self.labels), "Number of images and labels is not equal!"
        
    def __len__(self) -> int:
        return len(self.images)
    
    def __getitem__(self, index: int) -> tuple:
        image = self.images[index]
        label = self.labels[index]
        
        image = np.reshape(image, (3, 32, 32))
        image = np.transpose(image, (1, 2, 0))
        
        if self.image_augmentations is not None:
            image = self.image_augmentations.augment_image(image)
        image = self.image_transforms(Image.fromarray(image))
        return image, label
    

image_transformations = tt.Compose([
    tt.ToTensor(),
    tt.Normalize(
        mean=(0.5074, 0.4867, 0.4411),
        std=(0.2011, 0.1987, 0.2025)
    )
])

train_augmentations = iaa.Sequential([
    iaa.Fliplr(0.5),
    iaa.CropAndPad(px=(-4, 4), pad_mode="reflect")
])


class CIFAR100Net(nn.Module):
    
    def __init__(self, model_type: str = "resnet18", temperature: int = 1):
        super().__init__()
        model_class = getattr(models, model_type)
        self.feature_extractor = model_class(num_classes=100)
        self.temperature = temperature
        
    def forward(self, images: torch.Tensor) -> torch.Tensor:
        activations = self.feature_extractor(images)
        return activations / self.temperature

    
def accuracy(predictions: torch.Tensor, labels: torch.Tensor, reduce_mean: bool = True) -> torch.Tensor:
    predicted_classes = torch.argmax(F.softmax(predictions, dim=1), dim=1)
    correct_predictions = torch.sum(predicted_classes == labels)
    if reduce_mean:
        return correct_predictions / len(labels)
    return correct_predictions


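def test_model(network: nn.Module, data_loader: DataLoader) -> float:
    num_correct_predictions = 0
    device = get_device()
    # evaluation mode: BatchNorm uses its running statistics instead of the statistics of the test batch
    network.eval()
    # no gradients are needed for testing, which saves memory and time
    with torch.no_grad():
        for images, labels in data_loader:
            images = to_device(images, device)
            labels = to_device(labels, device)
            predictions = network(images)
            num_correct_predictions += float(accuracy(predictions, labels, reduce_mean=False).item())
    # switch back to training mode, since training continues after each test run
    network.train()
    return num_correct_predictions / len(data_loader.dataset)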


def get_device() -> torch.device:
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def to_device(data: torch.Tensor, device: torch.device) -> torch.Tensor:
    if isinstance(data, (list, tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device)


def plot_metrics(metrics: dict):
    # we prepare the plotting by creating a grid of axes with two columns, so that each metric gets its own plot
    # furthermore, all plots should share the same x-axis values
    fig, axes = plt.subplots(math.ceil(len(metrics) / 2), 2, sharex=True, figsize=(20, 20))

    # we want to have a set of distinct colors for each logged metric
    colors = iter(plt.cm.rainbow(np.linspace(0, 1, len(metrics))))
    
    # create the actual plot
    for (metric_name, metric_values), axis in zip(metrics.items(), axes.flatten()):
        iterations = []
        values = []
        for logged_value in metric_values:
            iterations.append(logged_value["iteration"])
            values.append(logged_value["value"])
        axis.plot(iterations, values, label=metric_name, color=next(colors))
        axis.legend()
    plt.show()


BATCH_SIZE = 128
train_dataset = CIFAR100(Path("/kaggle/input/cifar100/train"), image_transformations, train_augmentations)
train_data_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

test_dataset = CIFAR100(Path("/kaggle/input/cifar100/test"), image_transformations)
test_data_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# Knowledge Distillation

In knowledge distillation, our aim is to distill the knowledge of a large model into a smaller model.
We can do this in two ways:
1. Train the large and the small model at the same time. Here, we train the large model only on the hard labels provided by the dataset, and the small model on the soft labels provided by the large model (in addition to the hard labels).
1. Train the large model first, then train the small model based on the outputs of the large model.

We will try to train both models at the same time.
However, we highly encourage you to also try the other way, where we first train a larger model and then a smaller model!
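The soft labels mentioned above are the teacher's softmax outputs. `CIFAR100Net` divides its logits by `temperature` before they reach the softmax, and this minimal sketch (with made-up logits for five classes, not taken from a trained model) illustrates how a higher temperature flattens the distribution, so that more of the teacher's knowledge about the non-target classes becomes visible:

```python
import torch
import torch.nn.functional as F

# made-up teacher logits for a single image over 5 classes
logits = torch.tensor([[6.0, 2.0, 1.0, 0.5, -1.0]])


def soft_labels(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    # divide by the temperature before the softmax, as CIFAR100Net.forward does
    return F.softmax(logits / temperature, dim=1)


def entropy(probs: torch.Tensor) -> torch.Tensor:
    # a higher entropy means a "softer" target distribution
    return -(probs * probs.log()).sum(dim=1)


for temperature in (1.0, 4.0, 10.0):
    probs = soft_labels(logits, temperature)
    print(f"T={temperature:4.1f}", probs.numpy().round(3), f"entropy={entropy(probs).item():.3f}")
```

The predicted class stays the same for every temperature; only the relative size of the smaller probabilities changes. In this notebook, both networks use `temperature=1`.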

To run the training, we need to perform the following steps:
1. build two networks (a large and a smaller one)
1. adapt the training code from last week to use two networks and also perform the correct loss calculations

# Task 1: Building the Networks

We will start with the creation of the networks, which should be fairly simple.
Have a look at the `CIFAR100Net` class above and figure out how you can use that class to build a `resnet56` and a `resnet20` model.
Note that the variable `models` was imported in our first code cell, from the included [utility scripts](https://www.kaggle.com/bartzi/cifar100-resnets).
We will use the `resnet56` model as the teacher model and the `resnet20` model as the student.

In [3]:
# TODO:
# define `teacher_model` as a ResNet with 56 layers based on CIFAR100Net
# define `student_model` as a ResNet with 20 layers based on CIFAR100Net
teacher_model = CIFAR100Net(model_type='resnet56', temperature=1)
student_model = CIFAR100Net(model_type='resnet20', temperature=1)
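`CIFAR100Net` resolves the architecture by name: `getattr(models, model_type)` fetches the constructor with that name from the `cifar100_resnets` module. The pattern can be sketched without that module. In this sketch, `builders`, `make_mlp`, `deep_net` and `shallow_net` are hypothetical stand-ins (small MLPs instead of ResNets), used only to show the lookup and to compare the sizes of a "teacher" and a "student":

```python
from types import SimpleNamespace

import torch
import torch.nn as nn


def make_mlp(depth: int, num_classes: int) -> nn.Module:
    # hypothetical stand-in for a ResNet constructor: `depth` hidden layers of width 64
    layers = [nn.Flatten(), nn.Linear(3 * 32 * 32, 64), nn.ReLU()]
    for _ in range(depth - 1):
        layers += [nn.Linear(64, 64), nn.ReLU()]
    layers.append(nn.Linear(64, num_classes))
    return nn.Sequential(*layers)


# hypothetical namespace playing the role of the `models` module (cifar100_resnets)
builders = SimpleNamespace(
    deep_net=lambda num_classes: make_mlp(depth=8, num_classes=num_classes),
    shallow_net=lambda num_classes: make_mlp(depth=2, num_classes=num_classes),
)


def build(model_type: str) -> nn.Module:
    # same lookup as in CIFAR100Net.__init__; unknown names raise an AttributeError
    model_class = getattr(builders, model_type)
    return model_class(num_classes=100)


def count_parameters(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters())


teacher, student = build("deep_net"), build("shallow_net")
print(count_parameters(teacher), count_parameters(student))
print(student(torch.randn(2, 3, 32, 32)).shape)  # one row of 100 class scores per image
```

The same `count_parameters` helper should also work on `teacher_model` and `student_model` to compare the real ResNet-56 and ResNet-20.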

Building the networks was simple.
Now, we need to adapt the training loop from last week.

# Task 2: Adapt our Training Code

You can reuse most parts of the training loop, but we have to make the following changes:

1. Since we are now handling two networks at the same time, we have to adapt our code to use two networks and also two optimizers (they should be given as parameters to the `train` function).
1. We have to adapt our `train_for_one_iteration` function. Here, we need to forward the batch through both networks, then we calculate the losses:
   1. the loss for the teacher network using the hard labels
   1. the loss for the student network using the soft labels (the Kullback-Leibler divergence or cross entropy between the softmax outputs of both networks) plus the loss using the hard labels ([HINT](https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html): make sure to prevent gradients from flowing into the teacher network when using its softmax outputs)
1. following the loss calculations, we need to run the backward passes for both networks and update the weights using the optimizers
1. we can then return the losses of both networks
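The effect of the `detach()` hint can be checked in isolation. In this minimal sketch, two hypothetical `nn.Linear` layers stand in for the teacher and the student; after backpropagating the student's loss, only the student has gradients, so the teacher is updated solely through its own hard-label loss:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
teacher = nn.Linear(10, 5)  # hypothetical stand-in for the teacher network
student = nn.Linear(10, 5)  # hypothetical stand-in for the student network
images = torch.randn(4, 10)
labels = torch.tensor([0, 1, 2, 3])

teacher_logits = teacher(images)
student_logits = student(images)

# the soft targets are detached, so no gradient can flow back into the teacher
soft_targets = F.softmax(teacher_logits.detach(), dim=1)
kd_loss = F.kl_div(F.log_softmax(student_logits, dim=1), soft_targets, reduction="batchmean")
student_loss = F.cross_entropy(student_logits, labels) + kd_loss
student_loss.backward()

print("teacher grad:", teacher.weight.grad)  # None: the student loss did not reach the teacher
print("student grad set:", student.weight.grad is not None)
```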

## Task 2a: Initialize the Loss Functions

First we should initialize our two loss functions:

In [4]:
# TODO: initialize the correct loss functions
# 1. teacher_loss_function should contain a PyTorch function for the Cross Entropy loss
# 2. a) student_loss_function should contain a PyTorch implementation of the Kullback-Leibler divergence loss
#    b) make sure, the mean of the student loss is calculated over the batch dimension only - not over all dimensions
#    c) check out the documentation of the Kullback-Leibler divergence loss particularly about
#       whether the inputs expect probabilities or log-probabilities
teacher_loss_function = nn.CrossEntropyLoss()

# KLDivLoss expects the *input* (student) as log-probabilities and, with log_target=False,
# the *target* (teacher) as plain probabilities;
# reduction='batchmean' sums over the classes and averages over the batch dimension only
student_loss_function = nn.KLDivLoss(reduction='batchmean', log_target=False)
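The requirement to average over the batch dimension only is what `reduction='batchmean'` provides: it sums the element-wise KL terms and divides by the batch size. `reduction='mean'` instead divides by the total number of elements (batch size × number of classes), which the PyTorch documentation notes is not the true KL divergence. A small check with random inputs (batch size and class count are arbitrary choices):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, num_classes = 8, 100
log_probs = F.log_softmax(torch.randn(batch_size, num_classes), dim=1)  # input: log-probabilities
target = F.softmax(torch.randn(batch_size, num_classes), dim=1)        # target: probabilities

elementwise = nn.KLDivLoss(reduction="none")(log_probs, target)
batchmean = nn.KLDivLoss(reduction="batchmean")(log_probs, target)
mean = nn.KLDivLoss(reduction="mean")(log_probs, target)  # PyTorch may warn about this reduction

print(batchmean.item(), mean.item())  # batchmean is num_classes times larger than mean
```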

In [5]:
## Testing KLDivLoss: the KL divergence between identical distributions should be zero
kld_loss = nn.KLDivLoss(log_target=False, reduction='none')

example_logits = torch.tensor([[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]])
example_log_probs = F.log_softmax(example_logits, dim=1)  # the input must contain log-probabilities
example_probs = F.softmax(example_logits, dim=1)          # the target contains probabilities (log_target=False)

print(example_log_probs.shape)
print(kld_loss(example_log_probs, example_probs))  # element-wise KL terms, all (close to) zero

## Task 2b: Adapt the Training Logic

Then we can adapt our training logic for a single batch:

In [6]:
def train_for_one_iteration(networks: List[Type[nn.Module]], batch: tuple, optimizers: List[Type[Optimizer]]) -> dict:
    images, labels = batch
    teacher_network, student_network = networks

    # TODO: do the forward pass and loss calculation for the *teacher* network:
    # 1. pass the images through the `teacher_network`, store the result (the predictions) in `teacher_predictions`
    # 2. calculate the `teacher_loss` with the `teacher_loss_function` based on the `teacher_predictions` and the labels
    teacher_predictions = teacher_network(images)
    teacher_loss = teacher_loss_function(teacher_predictions, labels)
    
    # TODO: do the forward pass and loss calculation for the *student* network:
    # 1. pass the images through the `student_network`, store the result (the predictions) in `student_predictions`
    # 2. calculate the cross entropy loss `student_ce_loss` with the `teacher_loss_function` based on the `student_predictions` and the labels
    # 3. calculate the knowledge distillation loss `student_kd_loss` based on:
    #    1) the softmax of our `student_predictions` calculated on the label axis (dim 1)
    #    2) the softmax of our `teacher_predictions` calculated on the label axis (dim 1)
    #    HINT: check whether you need to include the regular or the logarithmic softmax for each one (refer to the documentation of the loss function)
    # 4. disable gradient calculation of the teacher in the previous step:
    #    add `.detach()` to `teacher_predictions`, which forwards the outputs but disables backpropagation
    # 5. add up both losses (`student_ce_loss` and `student_kd_loss`) as the `student_loss`
    student_predictions = student_network(images)
    student_ce_loss = teacher_loss_function(student_predictions, labels)
    
    # KLDivLoss expects log-probabilities for the student (input) and probabilities for the teacher (target);
    # detaching the teacher's outputs keeps the distillation loss from sending gradients into the teacher
    log_sfmx_student_predictions = F.log_softmax(student_predictions, dim=1)
    sfmx_teacher_predictions = F.softmax(teacher_predictions.detach(), dim=1)
    
    student_kd_loss = student_loss_function(log_sfmx_student_predictions, sfmx_teacher_predictions)
    
    student_loss = student_ce_loss + student_kd_loss
    
    # calculate the accuracy of both predictions
    teacher_accuracy = accuracy(teacher_predictions, labels)
    student_accuracy = accuracy(student_predictions, labels)
        
    # Here come the real weight adjustments, first zero gradients, then calculate derivatives, followed by the actual update of the optimizer
    for loss, optimizer in zip([teacher_loss, student_loss], optimizers):
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    return {
        "teacher_loss": float(teacher_loss.item()),
        "teacher_train_acc": float(teacher_accuracy.item()),
        "student_loss": float(student_loss.item()),
        "student_train_acc": float(student_accuracy.item()),
    }
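To see why the student's predictions have to enter `nn.KLDivLoss` as log-probabilities while the teacher's enter as probabilities, the loss can be compared with the definition $\mathrm{KL}(p_t \,\|\, p_s) = \sum_c p_t(c)\,(\log p_t(c) - \log p_s(c))$, averaged over the batch. A minimal sketch with random logits in place of real network outputs:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
student_logits = torch.randn(4, 100)  # random stand-ins for student_predictions
teacher_logits = torch.randn(4, 100)  # random stand-ins for teacher_predictions

log_p_student = F.log_softmax(student_logits, dim=1)
p_teacher = F.softmax(teacher_logits, dim=1)

loss = nn.KLDivLoss(reduction="batchmean")(log_p_student, p_teacher)

# the same quantity written out: sum over the classes, mean over the batch
manual = (p_teacher * (p_teacher.log() - log_p_student)).sum(dim=1).mean()

# passing plain probabilities as the input silently computes something else (here even a negative value)
wrong = nn.KLDivLoss(reduction="batchmean")(F.softmax(student_logits, dim=1), p_teacher)
print(loss.item(), manual.item(), wrong.item())
```

No error is raised in the wrong case, which is why this mistake is easy to miss during training.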

We have already adapted the final `train` function for training both the student and the teacher:

In [7]:
def train(train_data: DataLoader, test_data: DataLoader, networks: List[Type[nn.Module]], optimizers: List[Type[Optimizer]], \
          lr_schedulers: List[Type[_LRScheduler]], num_epochs: int, update_lr_scheduler_each_iteration: bool = True) -> dict:
    device = get_device()
    # we save all metrics that we want to plot later on
    metrics = defaultdict(list)
    
    for epoch in trange(num_epochs, desc="Epoch: "):
        losses = defaultdict(list)
        
        with tqdm(total=len(train_data), desc="Iteration: ") as progress_bar:
            for iteration, batch in enumerate(train_data):
                current_iteration = epoch * len(train_data) + iteration
                
                batch = to_device(batch, device)
                calculated_losses = train_for_one_iteration(networks, batch, optimizers)
                
                for loss_name, loss_value in calculated_losses.items():
                    losses[loss_name].append(loss_value)
                    metrics[loss_name].append({"iteration": current_iteration, "value": loss_value})
                # postfix_data is used to display current metrics in the progress bar
                postfix_data = {name: f"{value:.2f}" for name, value in calculated_losses.items()}
                
                current_learning_rate = lr_schedulers[0].get_last_lr()[0]
                postfix_data["lr"] = f"{current_learning_rate:.6f}"
                metrics["lr"].append({"iteration": current_iteration, "value": current_learning_rate})
                
                progress_bar.set_postfix(postfix_data)
                progress_bar.update()
                
                if update_lr_scheduler_each_iteration:
                    for scheduler in lr_schedulers:
                        scheduler.step()

            progress_bar.set_description_str("Testing: ")
            accuracies = {}
            for metric_name, network in zip(["teacher_acc", "student_acc"], networks):
                test_accuracy = test_model(network, test_data)
                accuracies[metric_name] = f"{test_accuracy:.2f}"
                metrics[metric_name].append({"iteration": (epoch + 1) * len(train_data), "value": test_accuracy})

            progress_bar.set_description_str(f"Epoch: {epoch}")
            postfix_data = {name: f"{statistics.mean(loss):.2f}" for name, loss in losses.items()}
            postfix_data.update(accuracies)
            progress_bar.set_postfix(postfix_data)
            progress_bar.update()

            if not update_lr_scheduler_each_iteration:
                for scheduler in lr_schedulers:
                    scheduler.step()
    
    return metrics

Now, we just need to perform the last setup steps and then start the training. \O/

Before starting the training below, you should enable the GPU accelerator in the sidebar on the right (you can open the sidebar by clicking on the |< symbol in the top right, then select *Settings*, *Accelerator*, *GPU*).

If you did not enable it at the beginning of this exercise (which is fine), you need to run the other cells again.
To do so, you can select *Run All* in the top toolbar.
The notebook should run most of the previous cells very quickly until the training below is executed.

In [8]:
learning_rate = 0.01
num_epochs = 50

teacher_model = teacher_model.to(get_device())
student_model = student_model.to(get_device())

teacher_optimizer = torch.optim.Adam(teacher_model.parameters(), lr=learning_rate)
student_optimizer = torch.optim.Adam(student_model.parameters(), lr=learning_rate)

teacher_scheduler = OneCycleLR(teacher_optimizer, learning_rate, epochs=num_epochs, steps_per_epoch=len(train_data_loader))
student_scheduler = OneCycleLR(student_optimizer, learning_rate, epochs=num_epochs, steps_per_epoch=len(train_data_loader))

# we are done with all setup and can start the training
logged_metrics = train(
    train_data_loader,
    test_data_loader,
    [teacher_model, student_model],
    [teacher_optimizer, student_optimizer],
    [teacher_scheduler, student_scheduler],
    num_epochs
)

## Plotting of Progress

As in the last exercise, we can now plot the train progress using the `plot_metrics` function:

In [9]:
plot_metrics(logged_metrics)

## What Now?

As in last week, keep in mind what you did in this exercise, since we will ask about the implementation in the graded quiz.
Since this exercise did not require much coding, you may be wondering what else there is to do.
Here are some suggestions:

You could test different training optimizations and try to get the best performance out of your student model.
If you already developed a few improvements last week, you should try to use them this week as well.
As you may have noticed, this week's code already includes the learning rate scheduler [OneCycleLR](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html).
Maybe you can find another schedule that performs better or integrate improvements from the previous week?
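Before swapping the scheduler, it can help to look at the learning-rate curve it produces. This sketch steps `OneCycleLR` (with the same maximum learning rate as above) and, as one possible alternative, `CosineAnnealingLR` on a single dummy parameter without any training; the number of steps is an arbitrary assumption, not the notebook's actual iteration count:

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR


def lr_curve(make_scheduler, learning_rate: float, total_steps: int) -> list:
    # a single dummy parameter is enough to drive an optimizer and its scheduler
    optimizer = torch.optim.Adam([torch.zeros(1, requires_grad=True)], lr=learning_rate)
    scheduler = make_scheduler(optimizer)
    lrs = []
    for _ in range(total_steps):
        lrs.append(scheduler.get_last_lr()[0])
        optimizer.step()  # step the optimizer before the scheduler, as PyTorch expects
        scheduler.step()
    return lrs


learning_rate, total_steps = 0.01, 1000  # assumed values for illustration
one_cycle = lr_curve(lambda opt: OneCycleLR(opt, learning_rate, total_steps=total_steps), learning_rate, total_steps)
cosine = lr_curve(lambda opt: CosineAnnealingLR(opt, T_max=total_steps), learning_rate, total_steps)

print(f"OneCycleLR: start {one_cycle[0]:.5f}, peak {max(one_cycle):.5f}, end {one_cycle[-1]:.2e}")
print(f"CosineAnnealingLR: start {cosine[0]:.5f}, end {cosine[-1]:.2e}")
```

With its default settings, `OneCycleLR` starts at `max_lr / 25`, warms up during the first 30% of the steps and then anneals to a very small value, whereas `CosineAnnealingLR` starts directly at the full learning rate.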

Another interesting experiment is to compare the performance of the model trained this week (with knowledge distillation) to the model from last week.
To do this, go back to the previous exercise and compare the accuracies.
You can also adapt the code above to train the two models one after the other instead of simultaneously.
To do so, first train the teacher on its own, and then use the trained teacher to train the student network.
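For this sequential variant, the main change to `train_for_one_iteration` is that the already-trained teacher is frozen: it runs in evaluation mode under `torch.no_grad()`, and only the student's optimizer takes a step. The following is a minimal sketch of such a step; the function name `distill_one_iteration`, the weighting factor `alpha` and the toy linear models in the usage lines are illustrative assumptions, not part of the exercise code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def distill_one_iteration(teacher: nn.Module, student: nn.Module, batch: tuple,
                          optimizer: torch.optim.Optimizer, alpha: float = 1.0) -> float:
    """Hypothetical helper: one student update against a frozen, pre-trained teacher."""
    images, labels = batch
    teacher.eval()  # e.g. BatchNorm uses its running statistics
    with torch.no_grad():  # no computation graph is built for the teacher at all
        teacher_probs = F.softmax(teacher(images), dim=1)

    student.train()
    student_logits = student(images)
    ce_loss = F.cross_entropy(student_logits, labels)
    kd_loss = F.kl_div(F.log_softmax(student_logits, dim=1), teacher_probs, reduction="batchmean")
    loss = ce_loss + alpha * kd_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


# usage with toy stand-ins for the trained ResNet-56 and the ResNet-20
torch.manual_seed(0)
teacher, student = nn.Linear(10, 5), nn.Linear(10, 5)
teacher_weights_before = teacher.weight.clone()
optimizer = torch.optim.SGD(student.parameters(), lr=0.1)
batch = (torch.randn(16, 10), torch.randint(0, 5, (16,)))
losses = [distill_one_iteration(teacher, student, batch, optimizer) for _ in range(50)]
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

Compared to simultaneous training, the teacher's soft labels no longer change from one iteration to the next, so the student always learns from a fully trained teacher.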

Now you can compare the resulting accuracies to those of the ResNet-56 and the ResNet-20 trained simultaneously.
Which models are performing better and why?

Another interesting question is how much computation you can save during inference by using the ResNet-20 with distilled knowledge.
To see this, we can calculate the number of operations of each network using the [torchinfo](https://github.com/TylerYep/torchinfo) package and the following code:

In [10]:
# try to import the library we need for calculating the number of operations
# (if we can not import it, we need to install it)
try:
    import torchinfo
except ImportError:
    !pip install torchinfo
    import torchinfo
# if you get the warning "Failed to establish a new connection", go to the side bar on the right, then "Settings" and switch on "Internet"

## Teacher Model Summary

After installing torchinfo, we can now print the summary of our teacher model:

In [11]:
from torchinfo import summary

batch_size = 1
print(summary(teacher_model, input_size=(batch_size, 3, 32, 32)))

## Student Model Summary

And compare that to our student model (which should be much smaller):

In [12]:
print(summary(student_model, input_size=(batch_size, 3, 32, 32)))
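To make the operation counts less of a black box, the multiply-accumulate operations (MACs) of convolutional and linear layers can also be estimated by hand with forward hooks. This is a rough sketch under simplifying assumptions (bias terms, batch normalization, activations and pooling are ignored), so its numbers will not match torchinfo's Mult-Adds exactly; `count_macs` is a hypothetical helper, and the toy model at the end only serves as a case that is easy to verify by hand:

```python
import torch
import torch.nn as nn


def count_macs(model: nn.Module, input_size: tuple) -> int:
    """Rough MAC estimate for Conv2d and Linear layers (biases and all other layers ignored)."""
    total = 0

    def conv_hook(module: nn.Conv2d, inputs, output):
        nonlocal total
        kernel_h, kernel_w = module.kernel_size
        # every output element needs (in_channels / groups) * kernel_h * kernel_w MACs
        total += output.numel() * (module.in_channels // module.groups) * kernel_h * kernel_w

    def linear_hook(module: nn.Linear, inputs, output):
        nonlocal total
        total += output.numel() * module.in_features

    handles = []
    for layer in model.modules():
        if isinstance(layer, nn.Conv2d):
            handles.append(layer.register_forward_hook(conv_hook))
        elif isinstance(layer, nn.Linear):
            handles.append(layer.register_forward_hook(linear_hook))

    was_training = model.training
    device = next(model.parameters()).device
    model.eval()
    with torch.no_grad():
        model(torch.zeros(input_size, device=device))
    model.train(was_training)  # restore the previous mode
    for handle in handles:
        handle.remove()
    return total


toy_model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),  # 8 x 32 x 32 outputs, 3 * 3 * 3 MACs each
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),  # 10 outputs, 8 MACs each
)
print(count_macs(toy_model, (1, 3, 32, 32)))  # 8*32*32*27 + 10*8 = 221264
```

Calling it as `count_macs(teacher_model, (1, 3, 32, 32))` and `count_macs(student_model, (1, 3, 32, 32))` should give a rough ratio of the inference cost of the two networks.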