In [2]:
import os
os.environ["http_proxy"] = "http://127.0.0.1:7897"
os.environ["https_proxy"] = "http://127.0.0.1:7897"

import torch
from transformers import BertModel, DistilBertModel, DistilBertForQuestionAnswering
from transformers import DistilBertTokenizer, BertTokenizer
import torch.nn as nn

# 加载BERT教师模型与DistilBERT学生模型
teacher_model = BertModel.from_pretrained('bert-base-uncased')
student_model = DistilBertModel.from_pretrained('distilbert-base-uncased')

# 使用相同的Tokenizer进行词汇预处理
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
distil_tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

# 定义输入文本并进行编码
text = "Machine reading comprehension is essential for question-answering."
inputs = tokenizer(text, return_tensors="pt")
distil_inputs = distil_tokenizer(text, return_tensors="pt")

# 获取教师模型输出
with torch.no_grad():
    teacher_outputs = teacher_model(**inputs).last_hidden_state

# 学生模型的前向传播
student_outputs = student_model(**distil_inputs).last_hidden_state

# 定义蒸馏损失函数：使用均方误差（MSE）对齐学生与教师模型的输出
distillation_loss = nn.MSELoss()(student_outputs, teacher_outputs)

# 打印蒸馏损失
print("Distillation Loss:", distillation_loss.item())


config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Distillation Loss: 0.06905972212553024


KL散度评估

In [4]:
import torch.nn.functional as F
def kl_distillation_loss(student_logits, teacher_logits, T=2.0):

    # 计算软概率分布
    # 注意：F.kl_div 的输入期望是 log_softmax，目标是 softmax
    p_teacher = F.softmax(teacher_logits / T, dim=-1)
    p_student = F.log_softmax(student_logits / T, dim=-1)
    
    # 计算 KL 散度
    # reduction='batchmean' 是数学上标准的 KL 散度计算方式
    loss = F.kl_div(p_student, p_teacher, reduction='batchmean') * (T ** 2)
    return loss

# 4. 计算并打印损失
loss = kl_distillation_loss(student_outputs, teacher_outputs, T=2.0)
print("KL Distillation Loss:", loss.item())

KL Distillation Loss: 0.3852507472038269


循环蒸馏效果

In [5]:
from torch.optim import AdamW
from tqdm import tqdm
# 这里也可以用其他优化器
optimizer = AdamW(student_model.parameters(), lr=1e-5)
texts = ["Machine learning is the study of algorithms.",
         "Natural Language Processing involves understanding human languages."]
labels = ["It is a subset of AI.", "A field in AI focusing on language."]

# 蒸馏训练循环
for epoch in range(3):
    print(f"Epoch {epoch + 1}")
    total_loss = 0
    for text, label in zip(texts, labels):
        # 准备输入
        inputs = tokenizer(text, return_tensors="pt")
        distil_inputs = distil_tokenizer(text, return_tensors="pt")
        
        # 获取教师模型输出
        with torch.no_grad():
            teacher_outputs = teacher_model(**inputs).last_hidden_state
        
        # 获取学生模型输出
        student_outputs = student_model(**distil_inputs).last_hidden_state
        
        # 计算蒸馏损失
        loss = nn.MSELoss()(student_outputs, teacher_outputs)
        
        # 反向传播与优化
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        # 记录损失
        total_loss += loss.item()
    
    avg_loss = total_loss / len(texts)
    print(f"Average Distillation Loss: {avg_loss:.4f}")

Epoch 1
Average Distillation Loss: 0.0708
Epoch 2
Average Distillation Loss: 0.0597
Epoch 3
Average Distillation Loss: 0.0515
