<a href="https://colab.research.google.com/github/xuwangfmc/dlbook/blob/main/modelcompression/KnowledgeDistillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 知识蒸馏

知识蒸馏（Knowledge Distillation）的思想是用原始高精度的大模型指导小模型进行训练，最后只采用小模型，从而在保持一定精度的同时减少参数量。知识蒸馏主要分为logits学习和features学习两种：logits方法的上限较低，而feature方法与模型结构和任务本身强相关。

![KnowledgeDistillation.png](https://s2.loli.net/2022/01/22/LoZBtdq1AnzbhwX.png)

该教程主要介绍知识蒸馏有效的原因，以及在实际运用中如何用大模型指导小模型进行训练。


## 知识蒸馏有效的原因  
- 当数据不够干净（例如存在错误标签）时，这些噪声对一般的模型来说只会干扰学习；而通过学习大模型预测的logits，效果往往会更好。  
- 不同label之间可能存在关联，这种关联可以引导小模型去学习。例如数字5可能与4、6有关系。  
- 弱化已经学得不错的target，避免其gradient干扰其它还没学好的task。  


## 知识蒸馏后的模型有时会比原来的模型效果还要好的原因
因为常见的标签是one-hot的，而TeacherNet给出的是包含其它类别可能性的软标签。例如对于MNIST中数字'1'的图片，人工标注的标签为{'1': 1.0}，而TeacherNet的输出可能是{'1': 0.7, '7': 0.2, '9': 0.1}这样的形式，从而让StudentNet学到一些“暗知识”（dark knowledge）。
在实际操作中，通常会对TeacherNet最后的softmax加入Temperature（温度）$T$：$q_i = \frac{\exp(z_i/T)}{\sum_j \exp(z_j/T)}$，其中$z_i$为logits。原来的softmax（即$T=1$）经过自然指数放大后，输出往往接近one-hot向量；而Temperature会软化这种接近one-hot的特性，从而让StudentNet学习到暗知识。Temperature通常取大于1的值，$T$越大，输出分布越平滑；$T$越接近0，概率最大类别的输出越接近1，其它类别越接近0。
知识蒸馏中“蒸馏”的意思就是在训练时使用较大的Temperature，让StudentNet学习TeacherNet的输出分布；训练结束后，StudentNet再使用$T=1$，即经典的softmax输出。Temperature由高变低的过程就是抽象意义上的“蒸馏”。


## 实战实例
Loss构建为Teacher与Student之间的KL散度，加上Student与Ground Truth的损失，用$\alpha$调整权重。
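$Loss = \alpha T^2 \cdot KL\left(\mathrm{softmax}\left(\frac{z_t}{T}\right) \,\|\, \mathrm{softmax}\left(\frac{z_s}{T}\right)\right) + (1-\alpha)\,\mathrm{CE}(z_s, y)$，其中$z_t$、$z_s$分别为Teacher与Student的logits，$y$为Ground Truth标签，$\mathrm{CE}$为原始的交叉熵损失（Original Loss）。 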

步骤1：加载数据集

In [1]:
# Download the food-11 dataset archive from Google Drive.
# NOTE(review): newer gdown releases deprecate `--id` (the file ID can be
# passed positionally) — confirm against the installed gdown version.
!gdown --id '1O6pFYd9aw1cZbry-NXk3k3tTXLVgssIg' --output food-11.zip
# Unzip the files quietly (-q) into the current working directory.
!unzip -q food-11.zip

Downloading...
From: https://drive.google.com/uc?id=1O6pFYd9aw1cZbry-NXk3k3tTXLVgssIg
To: /content/food-11.zip
100% 277M/277M [00:01<00:00, 148MB/s]


步骤2：加载StudentNet

In [2]:
import torch
import os
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.models as models
# Download and run the earlier Architecture_Design notebook. StudentNet is used
# below but not defined in this notebook, so presumably it comes from here.
!gdown --id '1-sSaAOk3vnmfZv8F4Vo_YhdrHkTNn7bh' --output "Architecture_Design.ipynb"
%run "Architecture_Design.ipynb"

Downloading...
From: https://drive.google.com/uc?id=1-sSaAOk3vnmfZv8F4Vo_YhdrHkTNn7bh
To: /content/Architecture_Design.ipynb
  0% 0.00/7.62k [00:00<?, ?B/s]100% 7.62k/7.62k [00:00<00:00, 12.4MB/s]


步骤3：执行代码

In [3]:
import torchvision.models as models
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader
# Download the TeacherNet parameters (a state_dict later loaded into a
# ResNet-18 with 11 output classes).
!gdown --id '1B8ljdrxYXJsZv2vmTequdPOofp3VF3NN' --output teacher_resnet18.bin
class FoodDataset(torch.utils.data.Dataset):
    """Image dataset that eagerly loads every file of a directory into memory.

    Each file name must start with an integer class label followed by an
    underscore (``<label>_...``); the label is parsed from that prefix.

    Args:
        dir_path: Directory containing the image files (read in sorted order).
        transform: Callable applied to each PIL image in ``__getitem__``.
        cuda: If True, ``__getitem__`` moves the image and label tensors to
            the GPU. NOTE(review): this only works with single-process data
            loading (DataLoader's default ``num_workers=0``).
    """

    def __init__(self, dir_path, transform, cuda=False):
        self.cuda = cuda
        self.transform = transform
        self.x = []
        self.y = []
        # sorted() makes the sample order deterministic across file systems.
        for img_name in sorted(os.listdir(dir_path)):
            img_path = os.path.join(dir_path, img_name)
            label = int(img_name.split("_")[0])

            # convert() loads the pixels into a new, file-independent image,
            # so the file handle can be closed right away by the `with` block
            # (keeping thousands of files open would hit the OS open-file
            # limit). Forcing RGB guarantees 3 channels even for grayscale,
            # palette or RGBA files, which the 3-channel networks expect.
            with Image.open(img_path) as image:
                self.x.append(image.convert("RGB"))
            self.y.append(label)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for sample ``idx``; label is int64."""
        image = self.transform(self.x[idx])
        label = torch.tensor(self.y[idx], dtype=torch.int64)
        if self.cuda:
            image = image.cuda()
            label = label.cuda()
        return image, label

    def __len__(self):
        """Return the number of loaded images."""
        return len(self.x)


# Training-time augmentation: a random 256x256 crop (images smaller than 256
# on a side are first padded by mirroring their borders), a random horizontal
# flip, a random rotation within [-15, 15] degrees, then conversion of the PIL
# image to a float tensor in [0, 1].
# NOTE(review): there is no Normalize step; presumably this matches how the
# teacher checkpoint was trained — confirm.
trainTransform = transforms.Compose([
    transforms.RandomCrop(256, pad_if_needed=True, padding_mode='symmetric'),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
])
# Evaluation transform: deterministic 256x256 center crop (no augmentation),
# then conversion to a tensor.
testTransform = transforms.Compose([
    transforms.CenterCrop(256),
    transforms.ToTensor(),
])


def get_dataloader(dir_path='/data', mode='training', batch_size=32, cuda=False):
    """Build a DataLoader over the FoodDataset images stored in ``dir_path``.

    Args:
        dir_path: Directory holding the images of this split.
        mode: One of 'training', 'testing' or 'validation'. 'training' uses the
            augmenting ``trainTransform`` and shuffles every epoch; the other
            modes use ``testTransform`` and keep the dataset order.
        batch_size: Number of samples per batch.
        cuda: Forwarded to FoodDataset (moves each sample to the GPU).

    Returns:
        A ``torch.utils.data.DataLoader`` over the split.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    if mode not in ('training', 'testing', 'validation'):
        # Explicit raise instead of `assert`: asserts are stripped under -O.
        raise ValueError(
            f"mode must be 'training', 'testing' or 'validation', got {mode!r}")

    is_training = mode == 'training'
    dataset = FoodDataset(
        dir_path,
        transform=trainTransform if is_training else testTransform,
        cuda=cuda)
    return DataLoader(dataset, batch_size=batch_size, shuffle=is_training)

# Training configuration.
batch_size = 4
# Use the GPU only when one is present; a hard-coded True makes every later
# .cuda() call fail on CPU-only runtimes.
cuda = torch.cuda.is_available()
epochs = 15



# Knowledge-distillation loss: weighted sum of the hard-label cross entropy and
# the temperature-softened KL divergence towards the teacher.
def loss_fn_kd(outputs, labels, teacher_outputs, T=20, alpha=0.5):
    """Return the knowledge-distillation loss for one batch.

    loss = (1 - alpha) * CE(outputs, labels)
           + alpha * T^2 * KL(softmax(teacher_outputs / T) || softmax(outputs / T))

    Args:
        outputs: Student logits; softmax is taken over dim 1.
        labels: Ground-truth class indices, as expected by ``F.cross_entropy``.
        teacher_outputs: Teacher logits with the same shape as ``outputs``.
        T: Softmax temperature applied to both student and teacher logits.
        alpha: Weight of the soft (distillation) term; the hard term gets
            ``1 - alpha``.

    Returns:
        A scalar loss tensor; the KL term is averaged over the batch
        ('batchmean').
    """
    # Scaling the soft term by T^2 keeps its gradient magnitude comparable
    # to the hard term as T changes (Hinton et al., 2015).
    soft_weight = alpha * T * T
    hard_weight = 1. - alpha

    student_log_probs = F.log_softmax(outputs / T, dim=1)
    teacher_probs = F.softmax(teacher_outputs / T, dim=1)
    # kl_div takes log-probabilities as input and probabilities as target.
    kd_term = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
    ce_term = F.cross_entropy(outputs, labels)

    return ce_term * hard_weight + kd_term * soft_weight

# Run one epoch of distillation training (or evaluation) over a dataloader.
def run_epoch(dataloader, update=True, alpha=0.5, T=20):
    """Run one pass of knowledge distillation over ``dataloader``.

    Relies on the module-level ``teacher_net``, ``student_net`` and
    ``optimizer`` (created in the ``__main__`` section) and on ``loss_fn_kd``.

    Args:
        dataloader: Iterable yielding ``(inputs, hard_labels)`` batches.
        update: If True, back-propagate and take an optimizer step per batch;
            if False, only evaluate without building autograd graphs.
        alpha: Weight of the distillation term, forwarded to ``loss_fn_kd``.
        T: Softmax temperature forwarded to ``loss_fn_kd``; the default 20
            keeps the value that was previously hard-coded here.

    Returns:
        ``(mean_loss, accuracy)``, both weighted by batch size.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    total_num, total_hit, total_loss = 0, 0, 0
    for inputs, hard_labels in dataloader:
        # The teacher is only used for inference, so no gradients are needed.
        with torch.no_grad():
            soft_labels = teacher_net(inputs)

        if update:
            # Clear gradients only when we are actually going to step.
            optimizer.zero_grad()
            logits = student_net(inputs)
            loss = loss_fn_kd(logits, hard_labels, soft_labels, T, alpha)
            loss.backward()
            optimizer.step()
        else:
            # Validation only: skip gradient tracking for the student too.
            with torch.no_grad():
                logits = student_net(inputs)
                loss = loss_fn_kd(logits, hard_labels, soft_labels, T, alpha)

        batch_len = len(inputs)
        total_hit += torch.sum(torch.argmax(logits, dim=1) == hard_labels).item()
        total_num += batch_len
        # loss is a per-sample mean, so weight it by the batch size.
        total_loss += loss.item() * batch_len

    if total_num == 0:
        raise ValueError("run_epoch received an empty dataloader")
    return total_loss / total_num, total_hit / total_num


if __name__ == '__main__':
    # Load data. Assumes food-11.zip was extracted into ./training and
    # ./validation in the current working directory.
    train_dataloader = get_dataloader('training', 'training', batch_size, cuda)
    valid_dataloader = get_dataloader('validation', 'validation', batch_size, cuda)
    print('Data Loaded')

    # Load the networks. map_location='cpu' lets the teacher checkpoint load
    # on machines without a GPU (e.g. if it was saved from GPU tensors); the
    # model is moved to the GPU below when `cuda` is set.
    # NOTE(review): `pretrained=` is deprecated in torchvision >= 0.13 in
    # favour of `weights=None`; kept for compatibility with older versions.
    teacher_net = models.resnet18(pretrained=False, num_classes=11)
    teacher_net.load_state_dict(
        torch.load('./teacher_resnet18.bin', map_location='cpu'))
    student_net = StudentNet(base=16)
    if cuda:
        teacher_net = teacher_net.cuda()
        student_net = student_net.cuda()
    print('Model Loaded')

    # Train the student by knowledge distillation.
    print('Training Started')
    optimizer = optim.AdamW(student_net.parameters(), lr=1e-3)
    # Only the student is optimized; the teacher stays in eval mode
    # (e.g. BatchNorm uses its running statistics) for the whole run.
    teacher_net.eval()
    now_best_acc = 0
    for epoch in range(epochs):
        student_net.train()
        train_loss, train_acc = run_epoch(train_dataloader, update=True)
        student_net.eval()
        valid_loss, valid_acc = run_epoch(valid_dataloader, update=False)

        # Keep only the checkpoint with the best validation accuracy so far.
        if valid_acc > now_best_acc:
            now_best_acc = valid_acc
            torch.save(student_net.state_dict(), './student_model.bin')
        print('epoch {:>3d}: train loss: {:6.4f}, acc {:6.4f} valid loss: {:6.4f}, acc {:6.4f}'.format(
            epoch, train_loss, train_acc, valid_loss, valid_acc))


Downloading...
From: https://drive.google.com/uc?id=1B8ljdrxYXJsZv2vmTequdPOofp3VF3NN
To: /content/teacher_resnet18.bin
100% 44.8M/44.8M [00:00<00:00, 123MB/s] 
Data Loaded
Model Loaded
Training Started
epoch   0: train loss: 20.8721, acc 0.2838 valid loss: 20.6007, acc 0.2108
epoch   1: train loss: 20.2215, acc 0.3128 valid loss: 19.5978, acc 0.2719
epoch   2: train loss: 19.7559, acc 0.3347 valid loss: 20.3388, acc 0.2911
epoch   3: train loss: 18.8262, acc 0.3616 valid loss: 17.9361, acc 0.3276
epoch   4: train loss: 18.4149, acc 0.3745 valid loss: 19.0117, acc 0.3303
epoch   5: train loss: 18.0903, acc 0.3931 valid loss: 19.5524, acc 0.3312
epoch   6: train loss: 18.1428, acc 0.3969 valid loss: 18.5035, acc 0.3239
epoch   7: train loss: 17.6772, acc 0.4155 valid loss: 17.8237, acc 0.3485
epoch   8: train loss: 17.0190, acc 0.4275 valid loss: 16.4812, acc 0.3932
epoch   9: train loss: 16.0832, acc 0.4449 valid loss: 18.4974, acc 0.3859
epoch  10: train loss: 15.9349, acc 0.4602 vali