# Knowledge Distillation

This notebook demonstrates an implementation of the paper [Distilling the Knowledge in a Neural Network](https://arxiv.org/abs/1503.02531).

## Steps to transfer knowledge from a teacher model to a student model

- Load the dataset and create the dataloaders
- Create the teacher and student models and load the pretrained weights of the teacher model
- Load the configuration YAML file
- Create a `KDTransfer` object, passing it the teacher model, student model, dataloaders and configuration
- Transfer the knowledge to the student model by calling the `compress_model` method

In [3]:
import sys
sys.path.append("../../")

import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm
from sklearn.metrics import accuracy_score
import numpy as np
import yaml

from trailmet.algorithms.distill.response_kd import KDTransfer
from trailmet.models.resnet import get_resnet_model

In [4]:
# Use the first CUDA GPU when one is visible; otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Dataset

### Augmentations

In [5]:
# Per-channel mean / std used by the Normalize step of every split.
# NOTE(review): these look like the widely used CIFAR-10 statistics rather than
# CIFAR-100 ones — confirm they match how the pretrained teacher was trained.
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.247, 0.243, 0.261)

# Training pipeline: tensor conversion, 4-pixel reflect padding, random
# horizontal flip, random 32x32 crop, then normalization.
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Pad(4, padding_mode='reflect'),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomCrop(32),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

# Evaluation pipeline: tensor conversion and normalization only.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

# Validation and test share the same augmentation-free pipeline.
transforms1 = dict(train=transform_train, val=transform_test, test=transform_test)

# No label transform is applied to any split.
target_transforms = dict.fromkeys(('train', 'val', 'test'))

### Load Dataset

In [6]:
from trailmet.datasets.classification import DatasetFactory

# Options handed to the dataset factory. The per-split transform dicts are
# keyed by the same split names listed in `split_types`.
# NOTE(review): val_fraction=0.1 presumably holds out 10% of the training data
# for validation — confirm against DatasetFactory.
dataset_options = dict(
    name='CIFAR100',
    root="./data",
    split_types=['train', 'val', 'test'],
    val_fraction=0.1,
    transform=transforms1,
    target_transform=target_transforms,
)

cifar_dataset = DatasetFactory.create_dataset(**dataset_options)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


  0%|          | 0/169001437 [00:00<?, ?it/s]

Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified
Files already downloaded and verified


### Define data loaders

In [7]:
def _build_loader(split):
    """Return a DataLoader over the `split` dataset, driven by that split's sampler."""
    return torch.utils.data.DataLoader(
        cifar_dataset[f'{split}_dataset'],
        batch_size=128,
        sampler=cifar_dataset[f'{split}_sampler'],
        num_workers=2,
    )


# One loader per split, all sharing the same batch size and worker count.
dataloaders = {split: _build_loader(split) for split in ('train', 'val', 'test')}

# Keep the individual names available for later cells.
train_loader = dataloaders['train']
val_loader = dataloaders['val']
test_loader = dataloaders['test']

# Model

### Create the teacher and student models

In [8]:
# Build the teacher (ResNet-50) and the student (ResNet-18) with otherwise
# identical arguments.
# NOTE(review): 100 is presumably the class count and 32 the input size; the
# meaning of the trailing False is not visible here — confirm against
# get_resnet_model.
teacher_model, student_model = (
    get_resnet_model(arch, 100, 32, False) for arch in ('resnet50', 'resnet18')
)

# Move both networks to the selected device. The final call's return value is
# what the cell displays.
teacher_model.to(device)
student_model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (activ): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activ): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activ): ReLU(inplace=True)

### Load Pretrained Teacher Model

In [9]:
# Read the teacher checkpoint and take the weights stored under 'state_dict'.
# map_location remaps the saved tensors onto the active device, so a checkpoint
# written on a GPU still loads when the notebook runs on a CPU-only machine.
checkpoint = torch.load("./resnet50_cifar100_pretrained.pth", map_location=device)
weights = checkpoint['state_dict']
teacher_model.load_state_dict(weights)

<All keys matched successfully>

# Training

### Load configurations for training the student model

The configuration should contain distillation arguments including the training parameters such as total epochs, learning rates, milestones, etc.

__Note:__ Training runs for only 5 epochs here, for demonstration purposes.

In [18]:
# Parse the distillation config with the safe YAML loader, then echo the
# resulting dict as the cell output.
with open("resnet50-resnet18.yaml") as config_file:
    data_loaded = yaml.safe_load(config_file)

data_loaded

{'DISTILL_ARGS': {'LAMBDA': 0.5,
  'TEMPERATURE': 20,
  'seed': 43,
  'EPOCHS': 5,
  'LR': 0.1,
  'WEIGHT_DECAY': 0.0005},
 'VERBOSE': True,
 'log_dir': 'r50_to_r18',
 'cuda_id': 0}

### Training student model

In [19]:
# Hand the teacher, student and split loaders to the distillation driver; the
# top-level config entries (DISTILL_ARGS, VERBOSE, log_dir, cuda_id) are
# forwarded as keyword arguments.
distillation_box = KDTransfer(
    teacher_model,
    student_model,
    dataloaders,
    **data_loaded,
)

In [20]:
# Run the distillation. The recorded output shows one training pass and one
# validation pass per epoch, with a checkpoint saved after some epochs.
# NOTE(review): the checkpoint is presumably written only when validation
# improves (no save is logged after epoch 5, where accuracy dropped) — confirm
# against KDTransfer.
distillation_box.compress_model()

=====TRAINING STUDENT NETWORK=====
Epoch: 1


100%|██████████| 352/352 [01:25<00:00,  4.10it/s, loss=2.32]
100%|██████████| 40/40 [00:02<00:00, 14.24it/s, acc=0.232, loss=2.11]


**Saving checkpoint**
Epoch: 2


100%|██████████| 352/352 [01:17<00:00,  4.52it/s, loss=1.88]
100%|██████████| 40/40 [00:03<00:00, 11.56it/s, acc=0.323, loss=1.83]


**Saving checkpoint**
Epoch: 3


100%|██████████| 352/352 [01:17<00:00,  4.52it/s, loss=1.68]
100%|██████████| 40/40 [00:02<00:00, 13.92it/s, acc=0.36, loss=1.66]


**Saving checkpoint**
Epoch: 4


100%|██████████| 352/352 [01:17<00:00,  4.53it/s, loss=1.54]
100%|██████████| 40/40 [00:02<00:00, 15.46it/s, acc=0.387, loss=1.61]


**Saving checkpoint**
Epoch: 5


100%|██████████| 352/352 [01:17<00:00,  4.52it/s, loss=1.44]
100%|██████████| 40/40 [00:02<00:00, 15.55it/s, acc=0.373, loss=1.69]


# Evaluate student model on test set

In [21]:
# Restore the student weights from the checkpoint named after the configured
# log_dir. map_location keeps the load working on CPU-only machines even if the
# checkpoint was written from a GPU.
# NOTE(review): presumably this is the best-validation checkpoint saved by
# compress_model — confirm against KDTransfer.
checkpoint_path = f"./checkpoints/{data_loaded['log_dir']}.pth"
checkpoint = torch.load(checkpoint_path, map_location=device)
student_model.load_state_dict(checkpoint['state_dict'])

preds = []
valid_labels = []
student_model.eval()

# Run the restored student over the whole test set, collecting per-class
# probabilities and ground-truth labels batch by batch.
for images, labels in tqdm(test_loader):
    images = images.to(device, dtype=torch.float)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        y_preds = student_model(images)

    preds.append(y_preds.softmax(1).to('cpu').numpy())
    # Labels are only consumed on the host, so they are never moved to the device.
    valid_labels.append(labels.to('cpu').numpy())

predictions = np.concatenate(preds)
valid_labels = np.concatenate(valid_labels)

# Top-1 accuracy: the predicted class is the arg-max over the class axis.
score = accuracy_score(valid_labels, predictions.argmax(1))

# Display the result; previously it was computed but never shown.
score

100%|██████████| 79/79 [00:03<00:00, 25.96it/s]
