# Knowledge Distillation
- The concept of **knowledge distillation** is to utilize class probabilities of a higher-capacity model (teacher) as soft targets of a smaller model (student)
- The implement processes can be divided into several stages:
  1. Finish the `ResNet()` classes
  2. Train the teacher model (ResNet50) and the student model (ResNet18) from scratch, i.e. **without KD**
  3. Define the `Distiller()` class and `loss_re()`, `loss_fe()` functions
  4. Train the student model **with KD** from the teacher model in two different ways, response-based and feature based distillation
  5. Comparison of student models w/ & w/o KD

## Setup

In [1]:
! pip install torchinfo



In [2]:
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset, random_split
from torchinfo import summary
from tqdm import tqdm
import sys
import numpy as np
import math
import matplotlib.pyplot as plt
import os
from PIL import Image

In [3]:
torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = True

## Download dataset

In [4]:
validation_split = 0.2
batch_size = 128

# data augmentation and normalization
transform_train = transforms.Compose([
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

transform_test = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# download dataset
train_and_val_dataset = torchvision.datasets.CIFAR10(
    root='dataset/',
    train=True,
    transform=transform_train,
    download=True
)

test_dataset = torchvision.datasets.CIFAR10(
    root='dataset/',
    train=False,
    transform=transform_test,
    download=True
)

# split train and validation dataset
train_size = int((1 - validation_split) * len(train_and_val_dataset))
val_size = len(train_and_val_dataset) - train_size
train_dataset, val_dataset = random_split(train_and_val_dataset, [train_size, val_size])

# create dataLoader
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

test_num = len(test_dataset)
test_steps = len(test_loader)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to dataset/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:14<00:00, 11.9MB/s] 


Extracting dataset/cifar-10-python.tar.gz to dataset/
Files already downloaded and verified


## Create teacher and student models
### Define BottleNeck for ResNet50

In [5]:
class BottleNeck(nn.Module):
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BottleNeck, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel * self.expansion, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channel * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out

### Define Resifual Block

In [6]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out

### Define ResNet Model

In [7]:
class ResNet(nn.Module):

    def __init__(self, block, blocks_num, num_classes=1000):
        super(ResNet, self).__init__()
        self.in_channel = 64

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        layers.append(block(self.in_channel, channel, downsample=downsample, stride=stride))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel, channel))

        return nn.Sequential(*layers)

    def forward(self, x):
        # 1. Finish the forward pass and return the output layer as well as hidden features.
        # 2. The output layer and hidden features will be used later for distilling.
        # 3. You can refer to the ResNet structure illustration to finish it.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        feature1 = x

        x = self.layer2(x)
        feature2 = x

        x = self.layer3(x)
        feature3 = x

        x = self.layer4(x)
        feature4 = x

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x, [feature1, feature2, feature3, feature4]

### Define ResNet50 and Resnet18

In [8]:
def resnet18(num_classes=10):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)

def resnet50(num_classes=10):
    return ResNet(BottleNeck, [3, 4, 6, 3], num_classes=num_classes)

## Teacher Model (ResNet50)

In [9]:
Teacher = resnet50(num_classes=10)  # commment out this line if loading trained teacher model
# Teacher = torch.load('Teacher.pt', weights_only=False)  # loading trained teacher model
Teacher = Teacher.to(device)

In [10]:
summary(Teacher)

Layer (type:depth-idx)                   Param #
ResNet                                   --
├─Conv2d: 1-1                            1,728
├─BatchNorm2d: 1-2                       128
├─ReLU: 1-3                              --
├─MaxPool2d: 1-4                         --
├─Sequential: 1-5                        --
│    └─BottleNeck: 2-1                   --
│    │    └─Conv2d: 3-1                  4,096
│    │    └─BatchNorm2d: 3-2             128
│    │    └─Conv2d: 3-3                  36,864
│    │    └─BatchNorm2d: 3-4             128
│    │    └─Conv2d: 3-5                  16,384
│    │    └─BatchNorm2d: 3-6             512
│    │    └─ReLU: 3-7                    --
│    │    └─Sequential: 3-8              16,896
│    └─BottleNeck: 2-2                   --
│    │    └─Conv2d: 3-9                  16,384
│    │    └─BatchNorm2d: 3-10            128
│    │    └─Conv2d: 3-11                 36,864
│    │    └─BatchNorm2d: 3-12            128
│    │    └─Conv2d: 3-13               

## Student Model (ResNet18)

In [11]:
Student = resnet18(num_classes=10)  # commment out this line if loading trained student model
# Student = torch.load('Student.pt', weights_only=False)  # loading trained student model
Student = Student.to(device)

In [12]:
summary(Student)

Layer (type:depth-idx)                   Param #
ResNet                                   --
├─Conv2d: 1-1                            1,728
├─BatchNorm2d: 1-2                       128
├─ReLU: 1-3                              --
├─MaxPool2d: 1-4                         --
├─Sequential: 1-5                        --
│    └─BasicBlock: 2-1                   --
│    │    └─Conv2d: 3-1                  36,864
│    │    └─BatchNorm2d: 3-2             128
│    │    └─ReLU: 3-3                    --
│    │    └─Conv2d: 3-4                  36,864
│    │    └─BatchNorm2d: 3-5             128
│    └─BasicBlock: 2-2                   --
│    │    └─Conv2d: 3-6                  36,864
│    │    └─BatchNorm2d: 3-7             128
│    │    └─ReLU: 3-8                    --
│    │    └─Conv2d: 3-9                  36,864
│    │    └─BatchNorm2d: 3-10            128
├─Sequential: 1-6                        --
│    └─BasicBlock: 2-3                   --
│    │    └─Conv2d: 3-11                 73,728

## Define training function

In [13]:
def train_from_scratch(model, train_loader, val_loader, epochs, learning_rate, device, model_name):
    criterion = nn.CrossEntropyLoss()
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(params, lr=learning_rate, weight_decay=5e-4)

    loss = []
    train_error=[]
    val_error = []
    valdation_error = []
    train_loss = []
    valdation_loss = []
    train_accuraacy = []
    valdation_accuracy= []

    for epoch in range(epochs):
        train_loss = 0.0
        valid_loss = 0.0
        train_acc = 0.0
        valid_acc = 0.0
        correct = 0.
        total = 0.
        V_correct = 0.
        V_total = 0.

        model.train()
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            logits, hidden = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * images.size(0)
            pred = logits.data.max(1, keepdim=True)[1]
            correct += np.sum(np.squeeze(pred.eq(labels.data.view_as(pred))).cpu().numpy())
            total += images.size(0)
            train_acc =  correct/total
            train_bar.desc = "train epoch[{}/{}]".format(epoch + 1, epochs)

        model.eval()
        with torch.no_grad():
            val_bar = tqdm(val_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                val_images, val_labels = val_images.to(device), val_labels.to(device)
                outputs, hidden_outputs = model(val_images)
                loss = criterion(outputs, val_labels)
                valid_loss += loss.item() * val_images.size(0)
                pred = outputs.data.max(1, keepdim=True)[1]
                V_correct += np.sum(np.squeeze(pred.eq(val_labels.data.view_as(pred))).cpu().numpy())
                V_total += val_images.size(0)
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)

        train_loss = train_loss / len(train_loader.dataset)
        train_error.append(train_loss)
        valid_loss = valid_loss / len(val_loader.dataset)
        val_error.append(valid_loss)
        train_accuraacy.append( correct / total)
        valdation_accuracy.append(V_correct / V_total)

        print('\tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(train_loss, valid_loss))
        print('\tTrain Accuracy: %.3fd%% (%2d/%2d)\tValdation Accuracy: %.3fd%% (%2d/%2d) '% (100. * correct / total, correct, total, 100. * V_correct / V_total, V_correct, V_total))

    torch.save(model, f'{model_name}.pt')
    print(f'{model_name}.pt is saved')

    print('Finished Training')

## Define testing function

In [14]:
def test(model, test_loader ,device, type=None):
    criterion = nn.CrossEntropyLoss()
    acc = 0.0
    test_loss = 0.0

    if type == None:
        model.eval()
    elif type == 'distiller':
        model.eval()
        model.teacher.eval()
        model.student.eval()
    else:
       raise ValueError(f'Error: only support response-based and feature-based distillation')

    with torch.no_grad():
        test_bar = tqdm(test_loader, file=sys.stdout)
        for test_data in test_bar:
            test_images, test_labels = test_data
            test_images, test_labels = test_images.to(device), test_labels.to(device)
            if type == None:
                outputs, features = model(test_images)
                loss = criterion(outputs, test_labels)
            elif type == 'distiller':
                outputs, loss = model(test_images, test_labels)
            else:
                raise ValueError(f'Error: only support response-based and feature-based distillation')

            predict_y = torch.max(outputs, dim=1)[1]
            acc += torch.eq(predict_y, test_labels.to(device)).sum().item()
            test_loss += loss.item()
            test_bar.desc = "test"

    test_accurate = acc / test_num
    print('test_loss: %.3f  test_accuracy: %.3f' %(test_loss / test_steps, test_accurate * 100))
    return test_loss / test_steps, test_accurate * 100.

## Train Teacher and Student model from scratch

In [17]:
# Decide the epochs and learning rate
train_from_scratch(Teacher, train_loader, val_loader, epochs= 50 , learning_rate= 1e-3 , device=device, model_name="Teacher")

train epoch[1/50]: 100%|██████████| 313/313 [00:17<00:00, 18.34it/s]
valid epoch[1/50]: 100%|██████████| 79/79 [00:02<00:00, 34.16it/s]
	Training Loss: 0.338919 	Validation Loss: 0.442719
	Train Accuracy: 88.470d% (35388/40000)	Valdation Accuracy: 85.480d% (8548/10000) 
train epoch[2/50]: 100%|██████████| 313/313 [00:16<00:00, 18.88it/s]
valid epoch[2/50]: 100%|██████████| 79/79 [00:02<00:00, 34.06it/s]
	Training Loss: 0.240642 	Validation Loss: 0.437250
	Train Accuracy: 91.593d% (36637/40000)	Valdation Accuracy: 86.060d% (8606/10000) 
train epoch[3/50]: 100%|██████████| 313/313 [00:16<00:00, 18.87it/s]
valid epoch[3/50]: 100%|██████████| 79/79 [00:02<00:00, 34.29it/s]
	Training Loss: 0.210331 	Validation Loss: 0.443833
	Train Accuracy: 92.657d% (37063/40000)	Valdation Accuracy: 86.330d% (8633/10000) 
train epoch[4/50]: 100%|██████████| 313/313 [00:16<00:00, 18.92it/s]
valid epoch[4/50]: 100%|██████████| 79/79 [00:02<00:00, 34.06it/s]
	Training Loss: 0.201623 	Validation Loss: 0.449541

In [18]:
T_loss, T_accuracy = test(Teacher, test_loader, device=device)

test: 100%|██████████| 79/79 [00:01<00:00, 40.15it/s]
test_loss: 0.459  test_accuracy: 89.840


In [19]:
# Decide the epochs and learning rate
train_from_scratch(Student, train_loader, val_loader, epochs= 50 , learning_rate= 1e-3 , device=device, model_name="Student")

train epoch[1/50]: 100%|██████████| 313/313 [00:10<00:00, 30.03it/s]
valid epoch[1/50]: 100%|██████████| 79/79 [00:01<00:00, 43.86it/s]
	Training Loss: 1.409035 	Validation Loss: 1.123096
	Train Accuracy: 48.617d% (19447/40000)	Valdation Accuracy: 59.640d% (5964/10000) 
train epoch[2/50]: 100%|██████████| 313/313 [00:09<00:00, 31.70it/s]
valid epoch[2/50]: 100%|██████████| 79/79 [00:01<00:00, 43.13it/s]
	Training Loss: 1.011369 	Validation Loss: 1.268687
	Train Accuracy: 64.097d% (25639/40000)	Valdation Accuracy: 57.980d% (5798/10000) 
train epoch[3/50]: 100%|██████████| 313/313 [00:09<00:00, 31.68it/s]
valid epoch[3/50]: 100%|██████████| 79/79 [00:01<00:00, 43.52it/s]
	Training Loss: 0.823543 	Validation Loss: 0.842080
	Train Accuracy: 71.118d% (28447/40000)	Valdation Accuracy: 70.590d% (7059/10000) 
train epoch[4/50]: 100%|██████████| 313/313 [00:09<00:00, 31.73it/s]
valid epoch[4/50]: 100%|██████████| 79/79 [00:01<00:00, 43.47it/s]
	Training Loss: 0.715940 	Validation Loss: 0.752887

In [20]:
S_loss, S_accuracy = test(Student, test_loader, device=device)

test: 100%|██████████| 79/79 [00:01<00:00, 50.80it/s]
test_loss: 0.514  test_accuracy: 87.500


## Define distillation

### Define the loss functions

In [21]:
# Finish the loss function for response-based distillation.
def loss_re(student_logits, teacher_logits, target, temperature=5.0, alpha=0.9):
    # ---------------------------------------------------------
    # T (Temperature): 控制機率分佈的平滑程度。T 越高，分佈越平滑 (Dark Knowledge 越明顯)。
    # alpha: 平衡 Distillation Loss (學老師) 和 CE Loss (學正確答案) 的權重。
    # ---------------------------------------------------------

    T = temperature
    loss_student = nn.functional.cross_entropy(student_logits, target)

    distillation_loss = nn.functional.kl_div(
        nn.functional.log_softmax(student_logits / T, dim=1),
        nn.functional.softmax(teacher_logits / T, dim=1),
        reduction='batchmean'
    )


    total_loss = (1. - alpha) * loss_student + alpha * (T ** 2) * distillation_loss

    return total_loss

In [22]:
def loss_fe(student_features, teacher_features, student_logits, target, alpha=0.5):

    # 1. 計算原本的分類損失 (Cross Entropy)
    # 這是為了讓學生不要為了模仿特徵而忘記怎麼做分類
    loss_ce = nn.functional.cross_entropy(student_logits, target)

    # 2. 計算特徵蒸餾損失 (Feature Distillation Loss)
    # 使用 MSE (均方誤差) 來計算兩組特徵圖的差異
    loss_feat = 0

    # 我們遍歷每一層 (feature1 ~ feature4) 加總誤差
    # 注意：這裡假設傳進來的 student_features 已經經過「維度對齊」了！
    for s_feat, t_feat in zip(student_features, teacher_features):
        loss_feat += nn.functional.mse_loss(s_feat, t_feat)

    # 3. 結合兩者
    total_loss = (1. - alpha) * loss_ce + alpha * loss_feat

    return total_loss

### Define Distillation Framework

In [23]:
class Distiller(nn.Module):
    def __init__(self, teacher, student, type):
        super(Distiller, self).__init__()
        self.teacher = teacher
        self.student = student
        self.type = type

        # ----------------------------------------------------------------
        # 1. Feature-Based KD 的維度對齊 (Connectors)
        # ----------------------------------------------------------------
        # 如果是 feature-based，我們需要建立 connector 來轉換維度
        # ResNet18 features channels: [64, 128, 256, 512]
        # ResNet50 features channels: [256, 512, 1024, 2048]
        # 我們用 1x1 Conv (kernel_size=1) 進行升維

        if self.type == 'feature':
            self.connectors = nn.ModuleList([
                nn.Conv2d(64, 256, kernel_size=1),   # feat1: 64 -> 256
                nn.Conv2d(128, 512, kernel_size=1),  # feat2: 128 -> 512
                nn.Conv2d(256, 1024, kernel_size=1), # feat3: 256 -> 1024
                nn.Conv2d(512, 2048, kernel_size=1)  # feat4: 512 -> 2048
            ])

        for param in self.teacher.parameters():
            param.requires_grad = False

    def forward(self, x, target):
        # ----------------------------------------------------------------
        # 2. Forward Pass
        # ----------------------------------------------------------------

        # 取得 Student 的 logits 和 features
        student_logits, student_features = self.student(x)

        # 取得 Teacher 的 logits 和 features (不需要計算梯度)
        with torch.no_grad():
            teacher_logits, teacher_features = self.teacher(x)

        # ----------------------------------------------------------------
        # 3. 計算 Loss
        # ----------------------------------------------------------------
        if self.type == 'response':
            # 呼叫之前寫好的 loss_re
            # 參數順序要對: (student_logits, teacher_logits, target)
            loss_distill = loss_re(student_logits, teacher_logits, target)

        elif self.type == 'feature':
            # 先將 Student 的 features 通過 Connector 升維
            s_features_transformed = []
            for i, connector in enumerate(self.connectors):
                s_features_transformed.append(connector(student_features[i]))

            loss_distill = loss_fe(s_features_transformed, teacher_features, student_logits, target)

        else:
            raise ValueError(f'Error: only support response-based and feature-based distillation')

        return student_logits, loss_distill

### Training function

In [24]:
def train_distillation(distiller, student, train_loader, val_loader, epochs, learning_rate, device):
    ce_loss = nn.CrossEntropyLoss()
    # define the parameter the optimizer used
    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, distiller.parameters()), lr=learning_rate, weight_decay=1e-4)

    loss = []
    train_error=[]
    val_error = []
    valdation_error = []
    train_loss = []
    valdation_loss = []
    train_accuraacy = []
    valdation_accuracy= []

    for epoch in range(epochs):
        distiller.train()
        distiller.teacher.train()
        distiller.student.train()

        train_loss = 0.0
        valid_loss = 0.0
        train_acc = 0.0
        valid_acc  = 0.0
        correct = 0.
        total = 0.
        V_correct = 0.
        V_total = 0.
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            images, labels = images.to(device), labels.to(device)

            outputs, loss = distiller(images, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * images.size(0)
            pred = outputs.data.max(1, keepdim=True)[1]
            result = pred.eq(labels.data.view_as(pred))
            result = np.squeeze(result.cpu().numpy())
            correct += np.sum(result)
            total += images.size(0)
            train_bar.desc = "train epoch[{}/{}]".format(epoch + 1, epochs)

        distiller.eval()
        distiller.teacher.eval()
        distiller.student.eval()

        with torch.no_grad():
            val_bar = tqdm(val_loader, file=sys.stdout)
            for val_data in val_bar:

                val_images, val_labels = val_data
                val_images, val_labels = val_images.to(device), val_labels.to(device)

                outputs, loss = distiller(val_images, val_labels)

                valid_loss += loss.item() * val_images.size(0)
                pred = outputs.max(1, keepdim=True)[1]
                V_correct += np.sum(np.squeeze(pred.eq(val_labels.data.view_as(pred))).cpu().numpy())
                V_total += val_images.size(0)
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)

        train_loss = train_loss / len(train_loader.dataset)
        train_error.append(train_loss)
        valid_loss = valid_loss / len(val_loader.dataset)
        val_error.append(valid_loss)
        train_accuraacy.append( correct / total)
        valdation_accuracy.append(V_correct / V_total)

        print('\tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(train_loss, valid_loss))
        print('\tTrain Accuracy: %.3fd%% (%2d/%2d)\tValdation Accuracy: %.3fd%% (%2d/%2d) '% (100. * correct / total, correct, total, 100. * V_correct / V_total, V_correct, V_total))

    print('Finished Distilling')

## Response-based distillation

In [25]:
# Decide the epochs and learning rate
Student_re = resnet18(num_classes=10)
Student_re = Student_re.to(device)
distiller_re = Distiller(Teacher, Student_re, type='response')
train_distillation(distiller_re, Student_re, train_loader, val_loader, epochs= 50  , learning_rate= 1e-3 , device=device)

train epoch[1/50]: 100%|██████████| 313/313 [00:12<00:00, 24.12it/s]
valid epoch[1/50]: 100%|██████████| 79/79 [00:02<00:00, 30.45it/s]
	Training Loss: 13.766044 	Validation Loss: 11.934295
	Train Accuracy: 51.030d% (20412/40000)	Valdation Accuracy: 55.710d% (5571/10000) 
train epoch[2/50]: 100%|██████████| 313/313 [00:12<00:00, 24.88it/s]
valid epoch[2/50]: 100%|██████████| 79/79 [00:02<00:00, 30.74it/s]
	Training Loss: 8.084339 	Validation Loss: 6.595953
	Train Accuracy: 67.528d% (27011/40000)	Valdation Accuracy: 70.180d% (7018/10000) 
train epoch[3/50]: 100%|██████████| 313/313 [00:12<00:00, 25.20it/s]
valid epoch[3/50]: 100%|██████████| 79/79 [00:02<00:00, 30.00it/s]
	Training Loss: 5.815515 	Validation Loss: 6.399501
	Train Accuracy: 74.483d% (29793/40000)	Valdation Accuracy: 70.750d% (7075/10000) 
train epoch[4/50]: 100%|██████████| 313/313 [00:12<00:00, 24.97it/s]
valid epoch[4/50]: 100%|██████████| 79/79 [00:02<00:00, 31.05it/s]
	Training Loss: 4.726374 	Validation Loss: 3.9815

In [26]:
reS_loss, reS_accuracy = test(distiller_re, test_loader, type='distiller', device=device)

test: 100%|██████████| 79/79 [00:02<00:00, 32.44it/s]
test_loss: 1.360  test_accuracy: 89.460


## Feature-based distillation

In [28]:
# Decide the epochs and learning rate
Student_fe = resnet18(num_classes=10)
Student_fe = Student_fe.to(device)
distiller_fe = Distiller(Teacher, Student_fe, type='feature')
distiller_fe = distiller_fe.to(device)
train_distillation(distiller_fe, Student_fe, train_loader, val_loader, epochs= 50 , learning_rate= 1e-3 , device=device)

train epoch[1/50]: 100%|██████████| 313/313 [00:14<00:00, 22.24it/s]
valid epoch[1/50]: 100%|██████████| 79/79 [00:02<00:00, 28.79it/s]
	Training Loss: 3.196629 	Validation Loss: 2.681766
	Train Accuracy: 46.895d% (18758/40000)	Valdation Accuracy: 56.950d% (5695/10000) 
train epoch[2/50]: 100%|██████████| 313/313 [00:13<00:00, 22.39it/s]
valid epoch[2/50]: 100%|██████████| 79/79 [00:02<00:00, 28.99it/s]
	Training Loss: 2.203169 	Validation Loss: 2.061657
	Train Accuracy: 65.540d% (26216/40000)	Valdation Accuracy: 69.820d% (6982/10000) 
train epoch[3/50]: 100%|██████████| 313/313 [00:13<00:00, 22.47it/s]
valid epoch[3/50]: 100%|██████████| 79/79 [00:02<00:00, 28.80it/s]
	Training Loss: 1.853790 	Validation Loss: 2.004774
	Train Accuracy: 73.790d% (29516/40000)	Valdation Accuracy: 74.180d% (7418/10000) 
train epoch[4/50]: 100%|██████████| 313/313 [00:13<00:00, 22.37it/s]
valid epoch[4/50]: 100%|██████████| 79/79 [00:02<00:00, 28.54it/s]
	Training Loss: 1.680830 	Validation Loss: 1.802797

In [29]:
ftS_loss, ftS_accuracy = test(distiller_fe, test_loader, type='distiller', device=device)

test: 100%|██████████| 79/79 [00:02<00:00, 33.67it/s]
test_loss: 1.244  test_accuracy: 88.650


## Result and Comparison

In [30]:
print(f'Teacher from scratch: loss = {T_loss:.2f}, accuracy = {T_accuracy:.2f}')
print(f'Student from scratch: loss = {S_loss:.2f}, accuracy = {S_accuracy:.2f}')
print(f'Response-based student: loss = {reS_loss:.2f}, accuracy = {reS_accuracy:.2f}')
print(f'Featured-based student: loss = {ftS_loss:.2f}, accuracy = {ftS_accuracy:.2f}')

Teacher from scratch: loss = 0.46, accuracy = 89.84
Student from scratch: loss = 0.51, accuracy = 87.50
Response-based student: loss = 1.36, accuracy = 89.46
Featured-based student: loss = 1.24, accuracy = 88.65
