<a href="https://colab.research.google.com/github/ruheyun/python_pytorch/blob/main/Knowledge_Distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import time

In [None]:
# Hyperparameters
BATCH_SIZE = 128
EPOCHS = 10            # training epochs for each student run
EPOCHS_T = 25          # training epochs for the teacher
TEMPERATURE = 4.0      # softmax temperature used to soften logits for distillation
ALPHA = 0.3            # weight of the hard-label cross-entropy term; the KL distillation term gets (1 - ALPHA)
LEARNING_RATE = 0.001  # Adam learning rate shared by the teacher and student optimizers

# Seed torch's global RNG (all devices) so weight initialisation and DataLoader
# shuffling are reproducible on Restart & Run All. Without this, the plain vs.
# distilled student comparison below changes from run to run.
SEED = 42
torch.manual_seed(SEED)

In [None]:
# Data loading: MNIST train/test splits; Normalize(0.5, 0.5) maps ToTensor's [0, 1] pixels to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Build the train split first, then the test split (both downloaded on demand into ./data).
train_dataset, test_dataset = (
    torchvision.datasets.MNIST(root='./data', train=split_is_train, download=True, transform=transform)
    for split_is_train in (True, False)
)

# Only the training loader is shuffled; evaluation order stays fixed.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Teacher network: the larger CNN (32 -> 64 conv channels, 128-unit hidden layer).
class TeacherModel(nn.Module):
    """Teacher CNN mapping 1-channel 28x28 images to 10 class logits.

    Two 3x3 convolutions (padding 1), each followed by ReLU and 2x2 max-pooling,
    reduce 28x28 to 7x7; two fully connected layers then map the 64*7*7
    flattened features to the logits.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and registration order define the state_dict keys saved by torch.save.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # conv -> ReLU -> 2x2 max-pool, applied for each conv stage in turn
        for conv_layer in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv_layer(x)), 2)
        features = x.view(x.shape[0], -1)
        hidden = F.relu(self.fc1(features))
        return self.fc2(hidden)

# Student network: the same layout as the teacher but far narrower (2 -> 4 channels, 8-unit hidden layer).
class StudentModel(nn.Module):
    """Compact student CNN mapping 1-channel 28x28 images to 10 class logits.

    Two 3x3 convolutions (padding 1) with ReLU and 2x2 max-pooling reduce the
    input to 4 x 7 x 7 features, followed by an 8-unit hidden layer and the
    10-way output layer.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and registration order define the state_dict keys saved by torch.save.
        self.conv1 = nn.Conv2d(1, 2, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(2, 4, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(4 * 7 * 7, 8)
        self.fc2 = nn.Linear(8, 10)

    def forward(self, x):
        # conv -> ReLU -> 2x2 max-pool, applied for each conv stage in turn
        for conv_layer in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv_layer(x)), 2)
        flat = x.view(x.shape[0], -1)
        return self.fc2(F.relu(self.fc1(flat)))

In [None]:
# Knowledge-distillation loss
def distillation_loss(student_logits, teacher_logits, true_labels, temperature, alpha):
    """Blend hard-label cross-entropy with a temperature-softened KL term.

        loss = alpha * CE(student_logits, true_labels)
             + (1 - alpha) * T**2 * KL(softmax(teacher/T) || softmax(student/T))

    Classes are along dim 1. The KL term uses 'batchmean' reduction (sum over
    classes, mean over the batch). It is scaled by T**2, the usual correction
    from Hinton et al. (2015), which keeps soft-target gradients on a scale
    comparable to the CE term.
    """
    hard_loss = F.cross_entropy(student_logits, true_labels)

    teacher_probs = (teacher_logits / temperature).softmax(dim=1)
    student_log_probs = (student_logits / temperature).log_softmax(dim=1)
    soft_loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * temperature ** 2

    return alpha * hard_loss + (1 - alpha) * soft_loss

In [None]:
# Train the teacher network
def train_teacher(teacher):
    """Train ``teacher`` on ``train_loader`` with Adam and cross-entropy for EPOCHS_T epochs.

    After every epoch it prints the mean per-batch training loss. At the end it
    saves the weights to "teacher_model.pth" and returns the same (now trained)
    module object.
    """
    optimizer = optim.Adam(teacher.parameters(), lr=LEARNING_RATE)

    for epoch_no in range(1, EPOCHS_T + 1):
        teacher.train()
        running_loss = 0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            batch_loss = F.cross_entropy(teacher(images), labels)
            batch_loss.backward()
            optimizer.step()
            running_loss += batch_loss.item()

        mean_loss = running_loss / len(train_loader)
        print(f"Epoch [{epoch_no}/{EPOCHS_T}], Loss: {mean_loss:.4f}")

    torch.save(teacher.state_dict(), "teacher_model.pth")
    print("教师模型训练完成并已保存！")
    return teacher

In [None]:
# Train a student network, optionally distilling from a teacher
def train_student(student, teacher=None):
    """Train ``student`` on ``train_loader`` with Adam for EPOCHS epochs.

    Args:
        student: the model to train.
        teacher: optional trained model. When given, the loss is
            ``distillation_loss`` (cross-entropy blended with a
            temperature-scaled KL term toward the teacher's soft targets,
            using TEMPERATURE and ALPHA). Otherwise plain cross-entropy is used.

    Prints the mean per-batch loss after every epoch. Saves the final weights
    to "student_model_Distillation.pth" (with a teacher) or "student_model.pth"
    (without), and returns the same module object.
    """
    # Explicit None check: nn.Module truthiness can fall back to __len__
    # (e.g. nn.Sequential), so `if teacher:` is not a reliable presence test.
    use_teacher = teacher is not None
    optimizer = optim.Adam(student.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()
    if use_teacher:
        teacher.eval()  # the teacher is frozen; only the student is trained

    for epoch in range(EPOCHS):
        student.train()
        losses = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = student(images)
            if use_teacher:
                # Run the teacher under no_grad instead of detaching afterwards:
                # this avoids building an autograd graph through the teacher
                # on every batch, saving memory and compute.
                with torch.no_grad():
                    teacher_outputs = teacher(images)
                loss = distillation_loss(outputs, teacher_outputs, labels, TEMPERATURE, ALPHA)
            else:
                loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            losses += loss.item()

        print(f"Epoch [{epoch+1}/{EPOCHS}], Loss: {losses / len(train_loader):.4f}")

    torch.save(student.state_dict(), "student_model_Distillation.pth" if use_teacher else 'student_model.pth')
    print("学生模型训练完成并已保存！")
    return student

In [None]:
# Evaluate a model's classification accuracy
def evaluate_model(model, loader=None):
    """Print and return the top-1 accuracy (in percent) of ``model``.

    Args:
        model: classifier whose output has class scores along dim 1.
        loader: iterable of (images, labels) batches; defaults to ``test_loader``.

    Returns:
        Accuracy as a float percentage. If the loader yields no samples, the
        result is ``float('nan')`` instead of raising ZeroDivisionError.
    """
    if loader is None:
        loader = test_loader
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total if total else float('nan')
    print(f'模型准确率: {accuracy:.2f}%')
    return accuracy

In [None]:
# Train the teacher model. Elapsed time uses time.perf_counter, a monotonic
# clock meant for measuring durations; time.time can jump if the system clock is adjusted.
print('训练教师模型')
start_time = time.perf_counter()
teacher = TeacherModel().to(device)
teacher_model = train_teacher(teacher)
end_time = time.perf_counter()
total_time = end_time - start_time
print(f"教师模型训练耗时: {total_time:.2f} 秒 ({total_time / 60:.2f} 分钟)")

# Evaluate the teacher on the test set
print("\n教师模型测试集准确率：")
evaluate_model(teacher_model)

训练教师模型
Epoch [1/25], Loss: 0.2101
Epoch [2/25], Loss: 0.0545
Epoch [3/25], Loss: 0.0382
Epoch [4/25], Loss: 0.0290
Epoch [5/25], Loss: 0.0218
Epoch [6/25], Loss: 0.0176
Epoch [7/25], Loss: 0.0143
Epoch [8/25], Loss: 0.0127
Epoch [9/25], Loss: 0.0107
Epoch [10/25], Loss: 0.0088
Epoch [11/25], Loss: 0.0065
Epoch [12/25], Loss: 0.0085
Epoch [13/25], Loss: 0.0059
Epoch [14/25], Loss: 0.0058
Epoch [15/25], Loss: 0.0056
Epoch [16/25], Loss: 0.0048
Epoch [17/25], Loss: 0.0038
Epoch [18/25], Loss: 0.0054
Epoch [19/25], Loss: 0.0024
Epoch [20/25], Loss: 0.0049
Epoch [21/25], Loss: 0.0035
Epoch [22/25], Loss: 0.0019
Epoch [23/25], Loss: 0.0040
Epoch [24/25], Loss: 0.0025
Epoch [25/25], Loss: 0.0028
教师模型训练完成并已保存！
教师模型训练耗时: 318.47 秒 (5.31 分钟)

教师模型测试集准确率：
模型准确率: 99.16%


In [None]:
# Train the student model from hard labels only (baseline). Elapsed time uses
# time.perf_counter, a monotonic clock meant for measuring durations.
print('训练学生模型')
start_time = time.perf_counter()
student = StudentModel().to(device)
student_model = train_student(student)
end_time = time.perf_counter()
total_time = end_time - start_time
print(f"学生模型训练耗时: {total_time:.2f} 秒 ({total_time / 60:.2f} 分钟)")

# Evaluate the baseline student on the test set
print("\n学生模型测试集准确率：")
evaluate_model(student_model)

训练学生模型
Epoch [1/10], Loss: 1.1992
Epoch [2/10], Loss: 0.4469
Epoch [3/10], Loss: 0.3685
Epoch [4/10], Loss: 0.3226
Epoch [5/10], Loss: 0.2928
Epoch [6/10], Loss: 0.2698
Epoch [7/10], Loss: 0.2501
Epoch [8/10], Loss: 0.2340
Epoch [9/10], Loss: 0.2206
Epoch [10/10], Loss: 0.2123
学生模型训练完成并已保存！
学生模型训练耗时: 118.36 秒 (1.97 分钟)

学生模型测试集准确率：
模型准确率: 93.97%


In [None]:
# Train a fresh student with knowledge distillation from the trained teacher.
# Elapsed time uses time.perf_counter, a monotonic clock meant for measuring durations.
print('训练学生模型（使用知识蒸馏）')
start_time = time.perf_counter()
student = StudentModel().to(device)
student_model_distillation = train_student(student, teacher=teacher_model)
end_time = time.perf_counter()
total_time = end_time - start_time
print(f"学生模型(蒸馏)训练耗时: {total_time:.2f} 秒 ({total_time / 60:.2f} 分钟)")

# Evaluate the distilled student. The header is labelled "(蒸馏)" so this
# result is distinguishable from the baseline student's in the output.
print("\n学生模型(蒸馏)测试集准确率：")
evaluate_model(student_model_distillation)

训练学生模型（使用知识蒸馏）
Epoch [1/10], Loss: 14.0457
Epoch [2/10], Loss: 4.7734
Epoch [3/10], Loss: 3.4057
Epoch [4/10], Loss: 2.7472
Epoch [5/10], Loss: 2.3705
Epoch [6/10], Loss: 2.1313
Epoch [7/10], Loss: 1.9584
Epoch [8/10], Loss: 1.8319
Epoch [9/10], Loss: 1.7312
Epoch [10/10], Loss: 1.6516
学生模型训练完成并已保存！
学生模型(蒸馏)训练耗时: 122.30 秒 (2.04 分钟)

学生模型测试集准确率：
模型准确率: 96.00%
