https://arxiv.org/pdf/1503.02531.pdf

# Knowledge Distillation 

[TODO] write introduction for knowledge distillation
[TODO] Add relevant references at the end of the notebook

## What is Knowledge Distillation exactly?

## In this tutorial

- [Setup](#Setup)
- [Functions](#Functions)
- [Data](#load-data)
- [Experiments](#experiments)

## Setup

In [54]:
%pip install split-folders

Collecting split-folders
  Using cached split_folders-0.5.1-py3-none-any.whl (8.4 kB)
Installing collected packages: split-folders
Successfully installed split-folders-0.5.1
Note: you may need to restart the kernel to use updated packages.


In [2]:
import random
import tqdm
import numpy as np
import PIL
from sklearn.model_selection import train_test_split
import splitfolders

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms
from torchvision.datasets import ImageFolder

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
drop_rate = 0
train_val_test_split = (0.7,0.2,0.1)
batch_size = 96
num_workers=6
drop_rate = 0.23
train_dir = '../data/train'
epochs = 10
lr = 0.001

## Functions

In [4]:
def get_model_size(model):
    """function to calculate the model size in MB

    Args:
        model (nn.Module): pytorch model
    """
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
    
    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
    
    size_all_mb = (param_size + buffer_size) / 1024**2

    return size_all_mb

In [5]:
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for X, y in tqdm.tqdm(dataloader, desc = "Training", unit = " Iterations"):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        #if batch % 100 == 0:
        #    loss, current = loss.item(), (batch + 1) * len(X)
        #    print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in tqdm.tqdm(dataloader, desc = "Validating", unit="Iterations"):
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
        

In [6]:
class DistillationLoss:

    """Custom loss calculcation combining
    the loss of the student model with the distillation loss
    """

    def __init__(self, student_loss, temperature=1, alpha=0.25):
        self.student_loss = student_loss
        self.temperature = 1
        self.alpha = 0.25

    def __call__(self, student_logits, student_target_loss, teacher_logits):
        distillation_loss = nn.KLDivLoss(
            F.log_softmax(student_logits / self.temperature, dim=1),
            F.softmax(teacher_logits/self.temperature, dim=1))
        loss = (1 - self.alpha) * student_target_loss \
            + self.alpha * distillation_loss
        return loss


## Student/Teacher Models

In [7]:
class Teacher(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = torchvision.models.densenet201(weights='DEFAULT')
        for params in self.model.parameters():
            params.requires_grad_ = False

        num_ftrs = self.model.classifier.in_features
        self.model.classifier = nn.Sequential(
            nn.Linear(num_ftrs, 500),
            nn.Linear(500, 2)
            )
        
    def forward(self, x):
        x = self.model(x)
        return x

In [8]:
class Student(nn.Module):

    def __init__(self):
        super().__init__()

        # onvolutional layers (3,16,32)
        self.conv1 = nn.Conv2d(in_channels = 3, out_channels = 16, kernel_size=(5, 5), stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels = 16, out_channels = 32, kernel_size=(5, 5), stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels = 32, out_channels = 64, kernel_size=(3, 3), padding=1)

        # conected layers
        self.fc1 = nn.Linear(in_features= 64 * 3 * 3, out_features=500)
        self.fc2 = nn.Linear(in_features=500, out_features=50)
        self.fc3 = nn.Linear(in_features=50, out_features=2)


    def forward(self, x):

        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)

        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)

        x = F.relu(self.conv3(x))
        x = F.max_pool2d(x, 2)

        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x

# Load Data

In [9]:
# Download Cats&Dogs Dataset; Unzip the dataset
#!curl -O https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_5340.zip
#!unzip -q kagglecatsanddogs_5340.zip

In [10]:
data_dir = './PetImages/'
#files = [f for f in os.listdir(data_dir) if '.jpg' in f]

In [11]:
#splitfolders.ratio(data_dir, output="output", ratio=train_val_test_split)

In [12]:
train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.ColorJitter(),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.Resize((128,128)),
    transforms.ToTensor()
])

val_transform = transforms.Compose([
    transforms.Resize((128,128)),
    transforms.ToTensor()
])

In [20]:
train_dir = './output/train/'
val_dir = './output/val/'
test_dir = './output/test/'

train_dataset = ImageFolder(train_dir, transform=train_transform)
val_dataset = ImageFolder(val_dir, transform=val_transform)
test_dataset = ImageFolder(test_dir, transform=val_transform)

train_dataloader = DataLoader(train_dataset, batch_size = batch_size, shuffle=True, num_workers=num_workers)
val_dataloader = DataLoader(val_dataset, batch_size = batch_size, shuffle=True, num_workers=num_workers)
test_dataloader = DataLoader(test_dataset, batch_size = batch_size, shuffle=True, num_workers=num_workers)

In [8]:
# Creating dataset

class DogsVsCatsDataset(Dataset):
    def __init__(self, file_list, dir, mode='train', transform = val_transform):
        self.file_list = file_list
        self.dir = dir
        self.mode= mode
        self.transform = transform
            
    def __len__(self):
        return len(self.file_list)
    
    def __getitem__(self, idx):
        img = PIL.Image.open(os.path.join(self.dir, self.file_list[idx]))
        img = self.transform(img)
        img = np.array(img)
        if 'dog' in self.file_list[idx]:
            self.label = 1
        else:
            self.label = 0
        return img.astype('float32'), self.label


train_files, test_files = train_test_split(train_files, 
                                    test_size=train_val_test_split[2], 
                                    random_state=42
                                    )
train_files, valid_files = train_test_split(train_files,
                                    test_size=train_val_test_split[1]/train_val_test_split[0], 
                                    random_state=42
                                    )

TrainDataSet = DogsVsCatsDataset(train_files, dir = train_dir, mode='train', transform = train_transform)
TrainDataLoader = DataLoader(TrainDataSet, batch_size = batch_size, shuffle=True, num_workers=num_workers)

ValidDataSet = DogsVsCatsDataset(valid_files, dir = train_dir, mode='valid')
ValidDataLoader = DataLoader(ValidDataSet, batch_size = batch_size, shuffle=False, num_workers=num_workers)

TestDataSet = DogsVsCatsDataset(test_files, dir = train_dir, mode='test')
TestDataLoader = DataLoader(TestDataSet, batch_size = batch_size, shuffle=False, num_workers=num_workers)

# Experiments

In [16]:
student_model = Student()
teacher_model = Teacher()

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [17]:
model_size_student = get_model_size(student_model)
model_size_teacher = get_model_size(teacher_model)

print('model size teachermodel : {:.3f}MB'.format(model_size_teacher))
print('model size student model: {:.3f}MB'.format(model_size_student))

model size teachermodel : 73.562MB
model size student model: 1.321MB


def train(model):

    train_loss_list = []
    train_acc_list = []
    valid_loss_list = []
    valid_acc_list = []
    test_acc_list = []

    for epoch in range(epochs):
        print("Epoch",epoch+1,"/",epochs)
        model.train()
        train_loss = 0
        train_acc = 0
        itr = 1
        tot_itr = len(TrainDataLoader)
        for samples, labels in tqdm.tqdm(TrainDataLoader, desc = "Training", unit = " Iterations"):
            samples, labels = samples.to(device), labels.to(device)
            optimizer.zero_grad()
            output = model(samples)
            loss = criterion(output, labels)
            train_loss += loss.item()
            loss.backward()
            optimizer.step()
            pred = torch.argmax(output, dim=1)
            correct = pred.eq(labels)
            train_acc+= torch.mean(correct.float())
            torch.cuda.empty_cache()
            itr += 1
            
        train_loss_list.append(train_loss/tot_itr)
        train_acc_list.append(train_acc.item()/tot_itr)
        print(' Total Loss: {:.4f}, Accuracy: {:.1f} %'.format(train_loss, train_acc/tot_itr*100))

        model.eval()
        valid_loss=0
        valid_acc=0
        itr=1
        tot_itr = len(ValidDataLoader)
        for samples, labels in tqdm.tqdm(ValidDataLoader, desc = "Validating", unit = " Iterations"):
            with torch.no_grad():
                samples, labels = samples.to(device), labels.to(device)
                output = model(samples)
                loss = criterion(output, labels)
                valid_loss += loss.item()
                pred = torch.argmax(output, dim=1)
                correct = pred.eq(labels)
                valid_acc += torch.mean(correct.float())
                torch.cuda.empty_cache()
                itr += 1
                
        valid_loss_list.append(valid_loss/tot_itr)
        valid_acc_list.append(valid_acc.item()/tot_itr)
        print('-----------------------------> Validation Loss: {:.4f}, Accuracy: {:.1f} %'.format(valid_loss, valid_acc/tot_itr*100))


## Teacher model

In [18]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(teacher_model.parameters(), lr=lr, amsgrad=True)
teacher_model = teacher_model.to(device)


In [21]:
# Fine-tune the final classification layers of the teacher model
epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, teacher_model, loss_fn=criterion, optimizer=optimizer)
    test(val_dataloader, teacher_model, loss_fn=criterion)
print("Done!")



Epoch 1
-------------------------------


Training:   0%|                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       | 0/183 [00:00<?, ? Iterations/s]B

UnidentifiedImageError: Caught UnidentifiedImageError in DataLoader worker process 4.
Original Traceback (most recent call last):
  File "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torchvision/datasets/folder.py", line 230, in __getitem__
    sample = self.loader(path)
  File "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torchvision/datasets/folder.py", line 269, in default_loader
    return pil_loader(path)
  File "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torchvision/datasets/folder.py", line 248, in pil_loader
    img = Image.open(f)
  File "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/PIL/Image.py", line 2967, in open
    raise UnidentifiedImageError(
PIL.UnidentifiedImageError: cannot identify image file <_io.BufferedReader name='./output/train/Cat/666.jpg'>


## Train Student Model from Scratch

In [11]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(student_model.parameters(), lr=lr, amsgrad=True)
student_model = student_model.to(device)
train(student_model)

Epoch 1 / 10
 Total Loss: 90.5205, Accuracy: 56.7 %
-----------------------------> Validation Loss: 37.1146, Accuracy: 57.0 %
Epoch 2 / 10
 Total Loss: 85.4250, Accuracy: 63.6 %
-----------------------------> Validation Loss: 36.2891, Accuracy: 59.5 %
Epoch 3 / 10
 Total Loss: 81.3723, Accuracy: 67.3 %
-----------------------------> Validation Loss: 31.9461, Accuracy: 69.3 %
Epoch 4 / 10
 Total Loss: 78.0509, Accuracy: 69.2 %
-----------------------------> Validation Loss: 29.9043, Accuracy: 71.4 %
Epoch 5 / 10
 Total Loss: 71.9816, Accuracy: 72.7 %
-----------------------------> Validation Loss: 29.7169, Accuracy: 72.1 %
Epoch 6 / 10
 Total Loss: 68.4196, Accuracy: 75.0 %
-----------------------------> Validation Loss: 28.2762, Accuracy: 74.7 %
Epoch 7 / 10
 Total Loss: 65.8942, Accuracy: 76.2 %
-----------------------------> Validation Loss: 26.3024, Accuracy: 76.5 %
Epoch 8 / 10
 Total Loss: 64.1505, Accuracy: 76.8 %
-----------------------------> Validation Loss: 24.9781, Accuracy:

Training: 100%|██████████| 134/134 [01:36<00:00,  1.39 Iterations/s]
Validating: 100%|██████████| 54/54 [00:37<00:00,  1.44 Iterations/s]
Training: 100%|██████████| 134/134 [01:35<00:00,  1.41 Iterations/s]
Validating: 100%|██████████| 54/54 [00:36<00:00,  1.46 Iterations/s]
Training: 100%|██████████| 134/134 [01:36<00:00,  1.39 Iterations/s]
Validating: 100%|██████████| 54/54 [00:37<00:00,  1.45 Iterations/s]
Training: 100%|██████████| 134/134 [01:35<00:00,  1.40 Iterations/s]
Validating: 100%|██████████| 54/54 [00:37<00:00,  1.46 Iterations/s]
Training: 100%|██████████| 134/134 [01:35<00:00,  1.40 Iterations/s]
Validating: 100%|██████████| 54/54 [00:37<00:00,  1.43 Iterations/s]
Training: 100%|██████████| 134/134 [01:35<00:00,  1.41 Iterations/s]
Validating: 100%|██████████| 54/54 [00:37<00:00,  1.45 Iterations/s]
Training: 100%|██████████| 134/134 [01:35<00:00,  1.41 Iterations/s]
Validating: 100%|██████████| 54/54 [00:36<00:00,  1.47 Iterations/s]
Training: 100%|██████████| 134/134

## Train Student Model with Knowledge Distillation

In [12]:
student_model_distilled = Student()

In [14]:
criterion = nn.CrossEntropyLoss()
ds_loss = DistillationLoss()
optimizer = torch.optim.Adam(student_model_distilled.parameters(), lr=lr, amsgrad=True)
teacher_model = teacher_model.to(device)
student_model_distilled = student_model_distilled.to(device)

In [15]:
def train_with_distillation(student_model, teacher_model):

    train_loss_list = []
    train_acc_list = []
    valid_loss_list = []
    valid_acc_list = []
    test_acc_list = []

    for epoch in range(epochs):
        print("Epoch",epoch+1,"/",epochs)
        student_model.train()
        teacher_model.eval()

        train_loss = 0
        train_acc = 0
        
        itr = 1
        tot_itr = len(TrainDataLoader)
        for samples, labels in tqdm.tqdm(TrainDataLoader, desc = "Training", unit = " Iterations"):
            samples, labels = samples.to(device), labels.to(device)
            optimizer.zero_grad()
            output_student = student_model(samples)
            output_teacher = teacher_model(samples)

            student_target_loss = criterion(output_student, labels)
            loss = ds_loss(output_student, student_target_loss, output_teacher)

            train_loss += loss.item()
            loss.backward()
            optimizer.step()
            pred = torch.argmax(output_student, dim=1)
            correct = pred.eq(labels)
            train_acc+= torch.mean(correct.float())
            torch.cuda.empty_cache()
            itr += 1
            
        train_loss_list.append(train_loss/tot_itr)
        train_acc_list.append(train_acc.item()/tot_itr)
        print(' Total Loss: {:.4f}, Accuracy: {:.1f} %'.format(train_loss, train_acc/tot_itr*100))

        student_model.eval()
        valid_loss=0
        valid_acc=0
        itr=1
        tot_itr = len(ValidDataLoader)
        for samples, labels in tqdm.tqdm(ValidDataLoader, desc = "Validating", unit = " Iterations"):
            with torch.no_grad():
                samples, labels = samples.to(device), labels.to(device)
                output = student_model(samples)
                loss = criterion(output, labels)
                valid_loss += loss.item()
                pred = torch.argmax(output, dim=1)
                correct = pred.eq(labels)
                valid_acc += torch.mean(correct.float())
                torch.cuda.empty_cache()
                itr += 1
                
        valid_loss_list.append(valid_loss/tot_itr)
        valid_acc_list.append(valid_acc.item()/tot_itr)
        print('-----------------------------> Validation Loss: {:.4f}, Accuracy: {:.1f} %'.format(valid_loss, valid_acc/tot_itr*100))


In [16]:
train_with_distillation(student_model=student_model_distilled, teacher_model=teacher_model)

Epoch 1 / 10
 Total Loss: 78.9578, Accuracy: 55.2 %
-----------------------------> Validation Loss: 37.6149, Accuracy: 57.5 %
Epoch 2 / 10
 Total Loss: 71.9161, Accuracy: 65.7 %
-----------------------------> Validation Loss: 34.4853, Accuracy: 63.6 %
Epoch 3 / 10
 Total Loss: 67.2270, Accuracy: 69.4 %
-----------------------------> Validation Loss: 30.0797, Accuracy: 71.4 %
Epoch 4 / 10
 Total Loss: 62.5051, Accuracy: 72.8 %
-----------------------------> Validation Loss: 30.3225, Accuracy: 71.5 %
Epoch 5 / 10
 Total Loss: 60.0720, Accuracy: 74.5 %
-----------------------------> Validation Loss: 27.9254, Accuracy: 74.9 %
Epoch 6 / 10
 Total Loss: 56.2233, Accuracy: 76.3 %
-----------------------------> Validation Loss: 27.6159, Accuracy: 75.3 %
Epoch 7 / 10
 Total Loss: 55.2416, Accuracy: 77.1 %
-----------------------------> Validation Loss: 25.5862, Accuracy: 77.9 %
Epoch 8 / 10
 Total Loss: 52.4742, Accuracy: 78.8 %
-----------------------------> Validation Loss: 24.7130, Accuracy:

Training: 100%|██████████| 134/134 [01:34<00:00,  1.42 Iterations/s]
Validating: 100%|██████████| 54/54 [00:37<00:00,  1.45 Iterations/s]
Training: 100%|██████████| 134/134 [01:35<00:00,  1.41 Iterations/s]
Validating: 100%|██████████| 54/54 [00:36<00:00,  1.49 Iterations/s]
Training: 100%|██████████| 134/134 [01:33<00:00,  1.44 Iterations/s]
Validating: 100%|██████████| 54/54 [00:36<00:00,  1.48 Iterations/s]
Training: 100%|██████████| 134/134 [01:33<00:00,  1.44 Iterations/s]
Validating: 100%|██████████| 54/54 [00:36<00:00,  1.49 Iterations/s]
Training: 100%|██████████| 134/134 [01:32<00:00,  1.45 Iterations/s]
Validating: 100%|██████████| 54/54 [00:36<00:00,  1.46 Iterations/s]
Training: 100%|██████████| 134/134 [01:34<00:00,  1.42 Iterations/s]
Validating: 100%|██████████| 54/54 [00:37<00:00,  1.46 Iterations/s]
Training: 100%|██████████| 134/134 [01:35<00:00,  1.41 Iterations/s]
Validating: 100%|██████████| 54/54 [00:36<00:00,  1.47 Iterations/s]
Training: 100%|██████████| 134/134

In [17]:
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [18]:
test(TestDataLoader, student_model_distilled, criterion)

Test Error: 
 Accuracy: 79.7%, Avg loss: 0.443518 



In [19]:
test(TestDataLoader, student_model, criterion)

Test Error: 
 Accuracy: 74.7%, Avg loss: 0.524062 



In [20]:
test(TestDataLoader, teacher_model, criterion)

Test Error: 
 Accuracy: 95.9%, Avg loss: 0.132754 



# References 