https://arxiv.org/pdf/1503.02531.pdf

# Knowledge Distillation 

- Current SOTA performance in AI and ML is mainly driven by large and complex deep neural network models that consist of billions of model parameters. 

- deploying large complex models on constrained devices (e.g. edge devices) is not straightforward

- while deep learning models often achieve excellent accuracy, they often fail to meet other requirements such as  latency and memory footprint. 

- Knowledge distillation distills the knowlegde of a larger, complex model into a smaller and easier to deploy model. 

- The complex model is called the 'teacher' and the smaller model is referred to as the 'student'. 

### different kinds of knowledge
- response based knowledge
- Feature based knowledge
- Relation-based knowledge

### Different kinds of training
- Offline distillation
- Online distillation
- Self distillation

### Real world examples
- DistilBERT
-> add short summary (e.g. smaller by 40%, whilst retaining xx % of performance). 


[TODO] write introduction for knowledge distillation
[TODO] Add relevant references at the end of the notebook

## What is Knowledge Distillation exactly?

## In this tutorial

- [Setup](#Setup)
- [Functions](#Functions)
- [Data](#load-data)
- [Experiments](#experiments)

## Setup

In [1]:
import tqdm
import numpy as np
import PIL
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

train_val_test_split = (0.7,0.2,0.1)
batch_size = 96
num_workers=6

data_dir = '../data/train'
epochs = 10
lr = 0.001

## Functions

In [3]:
def get_model_size(model):
    """function to calculate the model size in MB

    Args:
        model (nn.Module): pytorch model
    """
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
    
    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
    
    size_all_mb = (param_size + buffer_size) / 1024**2

    return size_all_mb

In [4]:
class DistillationLoss:

    """Custom loss calculcation combining
    the loss of the student model with the distillation loss
    """

    def __init__(self, student_loss, temperature=1, alpha=0.25):
        self.student_loss = student_loss
        self.distillation_loss = nn.KLDivLoss()
        self.temperature = 1
        self.alpha = 0.25

    def __call__(self, student_logits, student_target_loss, teacher_logits):
        distillation_loss = self.distillation_loss(
            F.log_softmax(student_logits / self.temperature, dim=1),
            F.softmax(teacher_logits/self.temperature, dim=1))
        loss = (1 - self.alpha) * student_target_loss \
            + self.alpha * distillation_loss
        return loss


In [5]:
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for X, y in tqdm.tqdm(dataloader, desc = "Training", unit = " Iterations"):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        #if batch % 100 == 0:
        #    loss, current = loss.item(), (batch + 1) * len(X)
        #    print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in tqdm.tqdm(dataloader, desc = "Validating", unit="Iterations"):
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

def train_with_distillation(dataloader, student_model, teacher_model, loss_fn, optimizer):
    distillation_loss = DistillationLoss(student_loss=loss_fn)
    size = len(dataloader.dataset)
    student_model.train()
    teacher_model.eval()

    for X, y in tqdm.tqdm(dataloader, desc = "Training with Distillation", unit = " Iterations"):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred_student = student_model(X)
        pred_teacher = teacher_model(X)

        student_target_loss = loss_fn(pred_student, y)
        loss = distillation_loss(pred_student, student_target_loss, pred_teacher)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        

## Student/Teacher Models

In [6]:
class Teacher(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = torchvision.models.densenet121(weights='DEFAULT')
        for params in self.model.parameters():
            params.requires_grad_ = False

        num_ftrs = self.model.classifier.in_features
        self.model.classifier = nn.Sequential(
            nn.Linear(num_ftrs, 500),
            nn.Linear(500, 2)
            )
        
    def forward(self, x):
        x = self.model(x)
        return x

In [7]:
class Student(nn.Module):

    def __init__(self):
        super().__init__()

        # onvolutional layers (3,16,32)
        self.conv1 = nn.Conv2d(in_channels = 3, out_channels = 16, kernel_size=(5, 5), stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels = 16, out_channels = 32, kernel_size=(5, 5), stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels = 32, out_channels = 64, kernel_size=(3, 3), padding=1)

        # conected layers
        self.fc1 = nn.Linear(in_features= 64 * 3 * 3, out_features=500)
        self.fc2 = nn.Linear(in_features=500, out_features=50)
        self.fc3 = nn.Linear(in_features=50, out_features=2)


    def forward(self, x):

        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)

        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)

        x = F.relu(self.conv3(x))
        x = F.max_pool2d(x, 2)

        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x

# Load Data

In [8]:
#from google.colab import files
#files.upload()

In [9]:
#!rm -r ~/.kaggle
#!mkdir ~/.kaggle
#!mv ./kaggle.json ~/.kaggle/
#!chmod 600 ~/.kaggle/kaggle.json

In [10]:
#!kaggle competitions download -c dogs-vs-cats

In [11]:
#!unzip -o -q dogs-vs-cats.zip -d ./data/ 
#!unzip -o -q ./data/train.zip -d ./data/ 

In [12]:
import random
files = os.listdir(data_dir)
files = [f for f in files if '.jpg' in f]
#files = random.sample(files, 10000)

In [13]:
train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.ColorJitter(),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.Resize((128,128)),
    transforms.ToTensor()
])

val_transform = transforms.Compose([
    transforms.Resize((128,128)),
    transforms.ToTensor()
])

In [14]:
class DogsVsCatsDataset(Dataset):
    def __init__(self, file_list, dir, mode='train', transform = val_transform):
        self.file_list = file_list
        self.dir = dir
        #self.mode= mode
        self.transform = transform
            
    def __len__(self):
        return len(self.file_list)
    
    def __getitem__(self, idx):
        img = PIL.Image.open(os.path.join(self.dir, self.file_list[idx]))
        img = self.transform(img)
        img = np.array(img)
        if 'dog' in self.file_list[idx]:
            self.label = 1
        else:
            self.label = 0
        return img.astype('float32'), self.label


train_files, test_files = train_test_split(files, 
                                    test_size=train_val_test_split[2], 
                                    random_state=42
                                    )
train_files, val_files = train_test_split(train_files,
                                    test_size=train_val_test_split[1]/train_val_test_split[0], 
                                    random_state=42
                                    )

train_dataset = DogsVsCatsDataset(train_files, dir = data_dir, transform = train_transform)
train_dataloader = DataLoader(train_dataset, batch_size = batch_size, shuffle=True, num_workers=num_workers)

val_dataset = DogsVsCatsDataset(val_files, dir = data_dir, transform = val_transform)
val_dataloader = DataLoader(val_dataset, batch_size = batch_size, shuffle=False, num_workers=num_workers)

test_dataset = DogsVsCatsDataset(test_files, dir = data_dir, transform = val_transform)
test_dataloader = DataLoader(test_dataset, batch_size = batch_size, shuffle=False, num_workers=num_workers)

# Experiments

In [15]:
student_model = Student()
teacher_model = Teacher()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device: {}".format(device))

Device: cuda


In [16]:
model_size_student = get_model_size(student_model)
model_size_teacher = get_model_size(teacher_model)

print('model size teachermodel : {:.3f}MB'.format(model_size_teacher))
print('model size student model: {:.3f}MB'.format(model_size_student))

model size teachermodel : 28.806MB
model size student model: 1.321MB


## Teacher model

In [17]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(teacher_model.parameters(), lr=lr, amsgrad=True)
teacher_model = teacher_model.to(device)


In [18]:
# Fine-tune the final classification layers of the teacher model
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, teacher_model, loss_fn=criterion, optimizer=optimizer)
    test(val_dataloader, teacher_model, loss_fn=criterion)
print("Done!")



Epoch 1
-------------------------------


Training: 100%|██████████| 168/168 [02:00<00:00,  1.39 Iterations/s]
Validating:  57%|█████▋    | 38/67 [00:30<00:24,  1.19Iterations/s]

## Train Student Model from Scratch

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(student_model.parameters(), lr=lr, amsgrad=True)
student_model = student_model.to(device)

In [None]:
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, student_model, loss_fn=criterion, optimizer=optimizer)
    test(val_dataloader, student_model, loss_fn=criterion)
print("Done!")

## Train Student Model with Knowledge Distillation

In [None]:
student_model_distilled = Student()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(student_model_distilled.parameters(), lr=lr, amsgrad=True)
teacher_model = teacher_model.to(device)
student_model_distilled = student_model_distilled.to(device)

In [None]:
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_with_distillation(train_dataloader, student_model_distilled, teacher_model, loss_fn=criterion, optimizer=optimizer)
    test(val_dataloader, student_model_distilled, loss_fn=criterion)
print("Done!")


In [None]:
train_with_distillation(student_model=student_model_distilled, teacher_model=teacher_model)

In [None]:
test(test_dataloader, student_model_distilled, criterion)

In [None]:
test(test_dataloader, student_model, criterion)

In [None]:
test(test_dataloader, teacher_model, criterion)

# References 