In [None]:
import os
import random
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

# Teacher network and student network
from nets.resnet import resnet32x4, resnet8x4
# Decoupled Knowledge Distillation (DKD) loss function
from loss.dkd import dkd_loss

# TensorBoard logging; the comment string is appended to the auto-generated run directory name.
Train_Info = "DKD : Res32x4 To Res8x4"
writer = SummaryWriter(
    comment=Train_Info,
)

In [None]:
# Set the random number seed so that it can be reproduced
def setup_seed(seed):
    """Seed every RNG used in this notebook and make cuDNN deterministic.

    Args:
        seed: integer seed applied to PyTorch (CPU and all GPUs), NumPy and ``random``.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # With benchmark=True cuDNN auto-tunes and may select non-deterministic
    # algorithms, which would undermine the deterministic flag above.
    torch.backends.cudnn.benchmark = False


setup_seed(42)

In [None]:
# Restrict PyTorch to one GPU: change "0" to the index of the GPU you want to use.
# This only takes effect if it is set before CUDA is initialized in this process.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

## Define hyperparameters

In [None]:
# --- Distillation settings ---
T = 4               # temperature : temperature used in knowledge distillation
ALPHA = 1.0         # alpha : loss weight of the TCKD (target-class KD) term
BETA = 8.0          # beta : loss weight of the NCKD (non-target-class KD) term
LOSS_CE = 1.0       # loss_ce : loss weight of the cross-entropy term

# --- Training settings ---
N = 100             # num_classes : number of classes
EPOCH = 20          # epoch : number of training epochs
BATCH_SIZE = 128    # batch_size : mini-batch size
LR = 0.05           # learning_rate : initial learning rate

# The remaining optimizer settings (momentum, weight_decay, milestones, gamma, ...) are rarely changed.
# Whenever EPOCH changes, the LR-scheduler milestones must be adjusted to match.

## Load the teacher model and define the student network

In [None]:
# Teacher: ResNet32x4 restored from a pretrained checkpoint, then switched to eval mode.
res32x4 = resnet32x4(num_classes=N)
# The checkpoint is mapped onto the CPU first; its weights are stored under the "model" key.
# NOTE(review): torch.load unpickles the file — only load checkpoints from trusted sources.
ckpt = torch.load("checkpoints/teacher/ckpt_epoch_240.pth", map_location='cpu')
res32x4.load_state_dict(ckpt["model"])
res32x4 = nn.DataParallel(res32x4).cuda().eval()

# Student: ResNet8x4 with freshly initialized weights (no checkpoint is loaded).
res8x4 = nn.DataParallel(resnet8x4(num_classes=N)).cuda()

teacher_net, student_net = res32x4, res8x4

## Load the dataset
On the first run, CIFAR-100 is downloaded automatically into the `data` folder. If the download is slow, you can download the dataset manually and place it in the `data` folder.

In [None]:
# Preprocessing: per-channel (R, G, B) mean and standard deviation passed to Normalize.
CIFAR100_MEAN = (0.5071, 0.4867, 0.4408)
CIFAR100_STD = (0.2675, 0.2565, 0.2761)

# Training augmentation: zero-pad 4 pixels on every side and randomly crop back to 32x32,
# then flip horizontally at random (torchvision's default probability is 0.5).
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),
])

# Evaluation: no augmentation, only tensor conversion and normalization.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),
])

DATA_PATH = "data"
# num_workers should be tuned to the CPU of the machine.
# Training split (shuffled every epoch).
trainset = torchvision.datasets.CIFAR100(
    root=DATA_PATH, train=True, download=True, transform=transform_train
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=1
)
# Test split (fixed order).
testset = torchvision.datasets.CIFAR100(
    root=DATA_PATH, train=False, download=True, transform=transform_test
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=1
)


## Optimizer

In most papers, models on the CIFAR-100 dataset are trained for `240` epochs.

The corresponding `milestones` value is `[150, 180, 210]`

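Here, for demonstration only, `EPOCH` is reduced (see the hyperparameter cell), so the `milestones` must be scaled down accordingly so the learning-rate decays still happen within the shortened run — e.g. `[15, 25, 35]` for a 40-epoch run.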

In [None]:
# SGD with momentum and weight decay over the student's parameters only; the teacher is never updated.
optimizer = torch.optim.SGD(student_net.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)

# The reference CIFAR-100 schedule decays the LR by 10x at epochs 150/180/210 out of 240.
# Scale those milestones to the configured EPOCH so every decay falls inside the run
# (a hard-coded [15, 25, 35] only fits a 40-epoch run; with EPOCH = 20 just one decay happened).
# Milestones are clamped to >= 1 so a tiny EPOCH never decays the LR at construction time,
# and de-duplicated so collapsed milestones do not compound gamma in a single step.
MILESTONES = sorted({max(1, EPOCH * m // 240) for m in (150, 180, 210)})
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=MILESTONES, gamma=0.1)

## Train function and Test function

In [None]:
# Best accuracies seen so far on the training and test sets. train() and test() rebind
# these module-level names via `global` whenever an epoch beats the previous best.
best_train_acc = best_test_acc = 0

In [None]:
from tqdm import tqdm

def train(epoch):
    """Run one epoch of DKD training of the student against the teacher.

    Args:
        epoch: 1-based epoch index, used for logging and TensorBoard global steps.

    Side effects: updates the student's weights through ``optimizer``, writes
    per-batch ``Train/Accuracy`` and ``Train/Loss`` scalars to ``writer``, and
    raises the global ``best_train_acc`` when this epoch's accuracy beats it.
    """
    global best_train_acc

    # Put only the student in training mode; the teacher stays in eval mode.
    student_net.train()

    print('\nEpoch: %d' % epoch)

    train_loss = 0
    correct = 0
    total = 0
    # TensorBoard step offset derived from the real number of batches per epoch
    # (a hard-coded 782 is only correct for batch size 64, not BATCH_SIZE = 128).
    steps_per_epoch = len(trainloader)
    step_offset = (epoch - 1) * steps_per_epoch

    # Wrap trainloader with tqdm to display a progress bar.
    with tqdm(trainloader, desc=f"Training Epoch {epoch}", total=steps_per_epoch) as pbar:
        for batch_idx, (inputs, targets) in enumerate(pbar):
            inputs, targets = inputs.cuda(), targets.cuda()
            optimizer.zero_grad()

            logits_student, _ = student_net(inputs)
            # The teacher only supplies distillation targets; no gradients are needed.
            with torch.no_grad():
                logits_teacher, _ = teacher_net(inputs)

            # Hard loss: cross-entropy against the ground-truth labels, weighted by LOSS_CE.
            ce_loss = LOSS_CE * F.cross_entropy(logits_student, targets)
            # Soft loss: DKD weights its TCKD part by ALPHA and its NCKD part by BETA at temperature T.
            # NOTE(review): assumes loss.dkd.dkd_loss has the signature
            # (logits_student, logits_teacher, target, alpha, beta, temperature) — confirm.
            kd_loss = dkd_loss(logits_student, logits_teacher, targets, ALPHA, BETA, T)
            total_loss = ce_loss + kd_loss

            total_loss.backward()
            optimizer.step()

            batch_loss = total_loss.item()
            train_loss += batch_loss
            _, predicted = logits_student.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            running_acc = 100. * correct / total
            global_step = step_offset + batch_idx
            writer.add_scalar('Train/Accuracy', running_acc, global_step)
            writer.add_scalar('Train/Loss', batch_loss, global_step)

            pbar.set_postfix(loss=train_loss / (batch_idx + 1), acc=f"{running_acc:.1f}%")

    # If this epoch's training accuracy beats best_train_acc, record it.
    # An empty loader yields 0 instead of a ZeroDivisionError.
    acc = 100 * correct / total if total else 0.
    if acc > best_train_acc:
        best_train_acc = acc


In [None]:
def test(net, epoch, save_path='checkpoints/student/kd_res8x4.pth'):
    """Evaluate ``net`` on the test set and checkpoint it on a new best accuracy.

    Args:
        net: model to evaluate; its forward pass returns a 2-tuple whose first
            element is the class logits.
        epoch: 1-based epoch index, used for the progress bar and TensorBoard steps.
        save_path: file the whole model is written to with ``torch.save`` when the
            test accuracy beats the global ``best_test_acc``.

    Side effects: writes per-batch ``Test/Accuracy`` and ``Test/Loss`` scalars to
    ``writer`` and may update the global ``best_test_acc``.
    """
    global best_test_acc
    net.eval()

    test_loss = 0
    correct = 0
    total = 0
    # TensorBoard step offset derived from the real number of test batches per epoch
    # (a hard-coded 157 is only correct for batch size 64, not BATCH_SIZE = 128).
    steps_per_epoch = len(testloader)
    step_offset = (epoch - 1) * steps_per_epoch

    with torch.no_grad():
        # Use tqdm to wrap testloader to display a progress bar.
        with tqdm(testloader, desc=f"Testing Epoch {epoch}", total=steps_per_epoch) as pbar:
            for batch_idx, (inputs, targets) in enumerate(pbar):
                inputs, targets = inputs.cuda(), targets.cuda()
                logits, _ = net(inputs)

                batch_loss = F.cross_entropy(logits, targets)

                test_loss += batch_loss.item()
                _, predicted = logits.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

                running_acc = 100. * correct / total
                pbar.set_postfix(loss=test_loss / (batch_idx + 1), acc=f"{running_acc:.1f}%")

                global_step = step_offset + batch_idx
                writer.add_scalar('Test/Accuracy', running_acc, global_step)
                writer.add_scalar('Test/Loss', batch_loss.item(), global_step)

    # Accuracy on the whole test set; an empty loader yields 0 instead of a ZeroDivisionError.
    acc = 100. * correct / total if total else 0.

    # On a new best test accuracy, record it and save the evaluated model.
    if acc > best_test_acc:
        print('Saving..')
        save_dir = os.path.dirname(save_path)
        if save_dir:
            # torch.save does not create missing directories.
            os.makedirs(save_dir, exist_ok=True)
        # Save the model that was actually evaluated (not a hard-wired global).
        # NOTE: this pickles the whole (DataParallel-wrapped) module, not just a state_dict.
        torch.save(net, save_path)
        best_test_acc = acc

## Train

In [None]:
# Main loop: one training epoch, one evaluation, then one LR-scheduler step per epoch.
for epoch in range(1, EPOCH + 1):
    train(epoch)
    test(student_net, epoch)

    # Update the learning rate after this epoch's optimizer steps.
    scheduler.step()

# Flush any buffered TensorBoard events to disk and close the event file;
# without this the last logged scalars can be lost when the kernel stops.
writer.close()
print(f"Best train acc: {best_train_acc:.2f}% | Best test acc: {best_test_acc:.2f}%")
