# Knowledge Distillation

This notebook demonstrates an implementation of the paper [Distilling the Knowledge in a Neural Network](https://arxiv.org/abs/1503.02531).

## Steps to transfer knowledge from a teacher to a student model

- Load the dataset and create the dataloaders
- Create the teacher and student models and load the pretrained weights of the teacher model
- Load the configuration YAML file
- Create a `KDTransfer` object and pass it the teacher model, student model, dataloaders and configuration
- Transfer the knowledge to the student model by calling the `compress_model` method

In [1]:
import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import sys

sys.path.append("../../../")

import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm
from sklearn.metrics import accuracy_score
import numpy as np
import yaml

from trailmet.algorithms.distill.response_kd import KDTransfer
from trailmet.models import ModelsFactory



In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Dataset

### Augmentations

In [3]:
transform_train = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Pad(4, padding_mode="reflect"),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomCrop(32),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
    ]
)
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
    ]
)

transforms1 = {"train": transform_train, "val": transform_test, "test": transform_test}

target_transforms = {"train": None, "val": None, "test": None}

### Load Dataset

In [4]:
from trailmet.datasets.classification import DatasetFactory

cifar_dataset = DatasetFactory.create_dataset(
    name="CIFAR100",
    root="./data",
    split_types=["train", "val", "test"],
    val_fraction=0.1,
    transform=transforms1,
    target_transform=target_transforms,
)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


### Define data loaders

In [5]:
train_loader = torch.utils.data.DataLoader(
    cifar_dataset["train"],
    batch_size=128,
    sampler=cifar_dataset["train_sampler"],
    num_workers=2,
)

val_loader = torch.utils.data.DataLoader(
    cifar_dataset["val"],
    batch_size=128,
    sampler=cifar_dataset["val_sampler"],
    num_workers=2,
)

test_loader = torch.utils.data.DataLoader(
    cifar_dataset["test"],
    batch_size=128,
    sampler=cifar_dataset["test_sampler"],
    num_workers=2,
)

dataloaders = {"train": train_loader, "val": val_loader, "test": test_loader}
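As a quick sanity check on the loaders above: the CIFAR-100 training set has 50,000 images, so `val_fraction=0.1` should leave roughly 45,000 for training and 5,000 for validation, assuming the validation split is carved out of the training set. With `batch_size=128` and the `DataLoader` default `drop_last=False`, the expected number of steps per epoch can be worked out as below (the helper name is only for illustration):

```python
import math


def steps_per_epoch(num_samples: int, batch_size: int) -> int:
    """Batches per epoch when the last partial batch is kept (drop_last=False)."""
    return math.ceil(num_samples / batch_size)


total_train = 50_000   # size of the CIFAR-100 training set
val_fraction = 0.1
batch_size = 128

# Assumption: the validation split is taken from the training images
num_val = round(total_train * val_fraction)
num_train = total_train - num_val

train_steps = steps_per_epoch(num_train, batch_size)
val_steps = steps_per_epoch(num_val, batch_size)

# Number of images left over for the final validation batch
last_val_batch = num_val - (val_steps - 1) * batch_size

print(train_steps, val_steps, last_val_batch)  # 352 40 8
```

These values agree with the `352 / 352 Steps` and `40 / 40 Steps` shown in the training log of this notebook. A final validation batch of only 8 images would also explain why the `top1` and `top5` values inside the progress bar are multiples of 12.5: they appear to describe the last batch, while the `* acc@1` lines summarise the whole epoch.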

### Load configurations for training the student model

The configuration should contain the distillation arguments: training parameters such as the total number of epochs, learning rate, milestones, etc., as well as the distillation hyperparameters (here `LAMBDA` and `TEMPERATURE`).

__Note:__ The student is trained for only 5 epochs here for demonstration purposes.

In [6]:
with open("./resnet50-resnet18.yaml", "r") as stream:
    data_loaded = yaml.safe_load(stream)
data_loaded

{'DISTILL_ARGS': {'LAMBDA': 0.5,
  'TEMPERATURE': 20,
  'seed': 43,
  'EPOCHS': 5,
  'LR': 0.1,
  'WEIGHT_DECAY': 0.0005},
 'VERBOSE': True,
 'log_dir': 'r50_to_r18',
 'DEVICE': 'cuda',
 'insize': 32,
 'wandb': True}
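To make the roles of `LAMBDA` and `TEMPERATURE` concrete, here is a minimal sketch of the loss described in the paper: a KL term between temperature-softened teacher and student distributions, scaled by the squared temperature, mixed with the usual cross-entropy on the hard labels. The function name `kd_loss` and its arguments are hypothetical, and whether `KDTransfer` weights the two terms in exactly this way is an assumption, not something this notebook confirms.

```python
import torch
import torch.nn.functional as F


def kd_loss(student_logits, teacher_logits, targets, lam=0.5, temperature=20.0):
    """Hinton-style distillation loss (sketch; weighting convention is assumed)."""
    # Soften both distributions by dividing the logits by the temperature
    log_p_student = F.log_softmax(student_logits / temperature, dim=1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=1)

    # KL(teacher || student); the T^2 factor keeps gradient magnitudes
    # comparable to the hard-label term, as suggested in the paper
    soft = F.kl_div(log_p_student, p_teacher, reduction="batchmean") * temperature**2

    # Standard cross-entropy against the ground-truth labels
    hard = F.cross_entropy(student_logits, targets)

    # Assumption: lam weights the soft term, (1 - lam) the hard term
    return lam * soft + (1.0 - lam) * hard


torch.manual_seed(0)
student_logits = torch.randn(8, 100)   # batch of 8, 100 classes
teacher_logits = torch.randn(8, 100)
targets = torch.randint(0, 100, (8,))

loss = kd_loss(student_logits, teacher_logits, targets)
print(loss.item())
```

A high temperature such as 20 flattens the teacher's distribution, so the student also receives signal from the relative scores of the non-target classes.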

# Model

### Create the teacher and student models

In [7]:
teacher_model = ModelsFactory.create_model(
    name="resnet50",
    num_classes=100,
    version="original",
    pretrained=False,
    **data_loaded
)
student_model = ModelsFactory.create_model(
    name="resnet18",
    num_classes=100,
    version="original",
    pretrained=False,
    **data_loaded
)

teacher_model.to(device)
student_model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (activ): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activ): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activ): ReLU(inplace=True)
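The cell above builds both networks with `pretrained=False`, and the step of loading trained teacher weights is not shown in this notebook. The following is a generic PyTorch sketch of that step. The helper `load_teacher_weights` is hypothetical, the `"state_dict"` wrapper key is an assumption about the checkpoint layout, and a tiny `nn.Linear` with an in-memory buffer stands in for the real teacher and checkpoint file.

```python
import io

import torch
import torch.nn as nn


def load_teacher_weights(model: nn.Module, checkpoint) -> nn.Module:
    """Load a saved state dict into `model` and switch it to eval mode."""
    state = torch.load(checkpoint, map_location="cpu")
    # Assumption: some checkpoints nest the weights under a "state_dict" key
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]
    model.load_state_dict(state)
    model.eval()  # the teacher only provides targets, it is not trained
    return model


# Stand-in for a trained teacher and its checkpoint file
trained = nn.Linear(4, 2)
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

# A freshly initialised model receives the saved weights
restored = load_teacher_weights(nn.Linear(4, 2), buffer)
print(torch.equal(restored.weight, trained.weight))
```

With a real checkpoint, `checkpoint` would be a file path and `model` would be `teacher_model`.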

# Training

### Training student model

In [8]:
distillation_box = KDTransfer(teacher_model, student_model, dataloaders, **data_loaded)

[34m[1mwandb[0m: Currently logged in as: [33manimesh-007[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [9]:
distillation_box.compress_model()

=====> TRAINING STUDENT NETWORK <=====


Training student network Epoch [0] (352 / 352 Steps) (batch time=0.04378s) (data time=0.00259s) (loss=3.09938): 100%|| 352/352 [00:18<00:00, 19.08it/s]
Validating student network Epoch [0] (40 / 40 Steps) (batch time=0.01604s) (loss=2.58542) (top1=12.50000) (top5=37.50000): 100%|| 40/40 [00:00<00:00, 40.20it/s]


 * acc@1 14.240 acc@5 30.100


Training student network Epoch [1] (352 / 352 Steps) (batch time=0.04645s) (data time=0.00328s) (loss=2.51164): 100%|| 352/352 [00:16<00:00, 20.76it/s]
Validating student network Epoch [1] (40 / 40 Steps) (batch time=0.01612s) (loss=3.04244) (top1=0.00000) (top5=37.50000): 100%|| 40/40 [00:01<00:00, 38.40it/s] 


 * acc@1 23.440 acc@5 40.940


Training student network Epoch [2] (352 / 352 Steps) (batch time=0.04465s) (data time=0.00256s) (loss=2.43971): 100%|| 352/352 [00:17<00:00, 20.51it/s]
Validating student network Epoch [2] (40 / 40 Steps) (batch time=0.01649s) (loss=2.22127) (top1=37.50000) (top5=50.00000): 100%|| 40/40 [00:01<00:00, 38.51it/s]


 * acc@1 27.740 acc@5 47.700


Training student network Epoch [3] (352 / 352 Steps) (batch time=0.04462s) (data time=0.00202s) (loss=2.42793): 100%|| 352/352 [00:17<00:00, 20.37it/s]
Validating student network Epoch [3] (40 / 40 Steps) (batch time=0.01679s) (loss=1.87125) (top1=62.50000) (top5=62.50000): 100%|| 40/40 [00:01<00:00, 39.11it/s]


 * acc@1 30.620 acc@5 50.840


Training student network Epoch [4] (352 / 352 Steps) (batch time=0.04411s) (data time=0.00256s) (loss=2.12732): 100%|| 352/352 [00:16<00:00, 20.75it/s]
Validating student network Epoch [4] (40 / 40 Steps) (batch time=0.01600s) (loss=2.25056) (top1=12.50000) (top5=50.00000): 100%|| 40/40 [00:00<00:00, 40.76it/s]


 * acc@1 35.440 acc@5 54.600
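The log reports `acc@1` and `acc@5`, the percentage of samples whose true class is the top prediction or among the five highest-scoring predictions. A minimal sketch of how such top-k accuracies can be computed from logits is shown below; `topk_accuracy` is a hypothetical helper and not necessarily the routine `KDTransfer` uses internally.

```python
import torch


def topk_accuracy(logits, targets, ks=(1, 5)):
    """Return the top-k accuracy in percent for every k in `ks`."""
    max_k = max(ks)
    # Indices of the max_k highest-scoring classes, shape (batch, max_k)
    _, pred = logits.topk(max_k, dim=1)
    # True where a predicted class matches the target, shape (batch, max_k)
    correct = pred.eq(targets.unsqueeze(1))
    # A sample counts as correct for k if any of its first k predictions match
    return [correct[:, :k].any(dim=1).float().mean().item() * 100.0 for k in ks]


# Toy example with 4 samples and 3 classes
logits = torch.tensor(
    [
        [0.10, 0.90, 0.00],
        [0.80, 0.15, 0.05],
        [0.20, 0.30, 0.50],
        [0.60, 0.30, 0.10],
    ]
)
targets = torch.tensor([1, 1, 2, 2])

top1, top2 = topk_accuracy(logits, targets, ks=(1, 2))
print(top1, top2)  # 50.0 75.0
```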
