In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from transformers import (
    BertForSequenceClassification,
    DistilBertForSequenceClassification,
    BertTokenizer,
    Trainer,
    TrainingArguments,
    glue_convert_examples_to_features,
    glue_processors,
    glue_output_modes,
    glue_tasks_num_labels,
)
from datasets import load_dataset
from sklearn.metrics import accuracy_score

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [19]:
# 1. 加载 SST-2 数据集
dataset = load_dataset("glue", "sst2")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

def tokenize(batch):
    return tokenizer(batch["sentence"], padding="max_length", truncation=True, max_length=128)

dataset = dataset.map(tokenize, batched=True)
dataset.set_format("torch", columns=["input_ids", "attention_mask", "label"])

train_dataloader = DataLoader(dataset["train"], batch_size=192, shuffle=True)
valid_dataloader = DataLoader(dataset["validation"], batch_size=192)

In [21]:
# 2. 加载 Teacher 模型 (已微调的BERT)
teacher = BertForSequenceClassification.from_pretrained("textattack/bert-base-uncased-SST-2", num_labels=2, output_hidden_states=True).to(device)

Error during conversion: ChunkedEncodingError(ProtocolError('Response ended prematurely'))


In [23]:
# 3. 加载 Student 模型 (DistilBERT)
student = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2, output_hidden_states=True).to(device)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [127]:
# 4. 知识蒸馏训练设置
def cosine_sim(hidden1, hidden2):
    # shape: [batch_size, seq_len, hidden_dim]
    sim = F.cosine_similarity(hidden1, hidden2, dim=-1)  # 每个 token 的相似度
    return sim.mean().item()

optimizer = torch.optim.Adam(student.parameters(), lr=2e-5)
# 损失函数
kl_loss_fn = nn.KLDivLoss(reduction="batchmean")
ce_loss_fn = nn.CrossEntropyLoss()
mse_loss_fn = nn.MSELoss()

# 权重系数（可以调）
alpha = 0.2   # CE loss 权重
beta = 0.3    # logit 蒸馏权重
gamma = 0.5   # hidden 蒸馏权重
temperature = 5.0

best_val_acc = 0
no_improve_count = 0
early_stop_patience = 2

In [131]:
from tqdm import tqdm
import torch.nn.functional as F

for epoch in range(15):
    student.train()
    total_loss = 0
    running_loss = 0
    
    loop = tqdm(train_dataloader, leave=False, desc=f"Epoch {epoch+1}")

    for step, batch in enumerate(loop):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        with torch.no_grad():
            teacher_outputs = teacher(input_ids=input_ids, attention_mask=attention_mask)
            teacher_logits = teacher_outputs.logits
            teacher_hidden = teacher_outputs.hidden_states[-1]

        student_outputs = student(input_ids=input_ids, attention_mask=attention_mask)
        student_logits = student_outputs.logits
        student_hidden = student_outputs.hidden_states[-1]

        # 三重损失
        ce_loss = ce_loss_fn(student_logits, labels)
        student_soft = nn.functional.log_softmax(student_logits / temperature, dim=1)
        teacher_soft = nn.functional.softmax(teacher_logits / temperature, dim=1)
        logit_distill = kl_loss_fn(student_soft, teacher_soft) * (temperature ** 2)
        hidden_distill = 1 - F.cosine_similarity(student_hidden, teacher_hidden, dim=-1).mean()

        loss = alpha * ce_loss + beta * logit_distill + gamma * hidden_distill
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        total_loss += loss.item()
        running_loss += loss.item()

        # 显示在进度条上
        loop.set_postfix(loss=loss.item())

    # ========== 添加验证集评估 ==========
    student.eval()
    correct = total = 0
    with torch.no_grad():
        for batch in valid_dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            outputs = student(input_ids=input_ids, attention_mask=attention_mask)
            preds = torch.argmax(outputs.logits, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    val_acc = correct / total
    print(f"Validation Accuracy: {val_acc:.4f}")

    # 在训练或验证时添加
    with torch.no_grad():
        teacher_outputs = teacher(input_ids=input_ids, attention_mask=attention_mask)
        student_outputs = student(input_ids=input_ids, attention_mask=attention_mask)
    
        teacher_hidden = teacher_outputs.hidden_states[-1]
        student_hidden = student_outputs.hidden_states[-1]

        sim_score = cosine_sim(student_hidden, teacher_hidden)
        print(f"Hidden State Cosine Similarity: {sim_score:.4f}")

    # ========== Early Stopping 逻辑 ==========
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        no_improve_count = 0
        print("Accuracy improved. Continuing training.")
    else:
        no_improve_count += 1
        print(f"No improvement. Patience counter: {no_improve_count}/{early_stop_patience}")

        if no_improve_count >= early_stop_patience:
            print("Early stopping triggered.")
            break

    # 清理变量和显存
    del teacher_outputs
    del student_outputs
    torch.cuda.empty_cache()
    
    print(f"Epoch {epoch+1} Completed | Avg Loss: {total_loss / len(train_dataloader):.4f}")

                                                                                                                       

Validation Accuracy: 0.9174
Hidden State Cosine Similarity: 0.9053
Accuracy improved. Continuing training.
Epoch 1 Completed | Avg Loss: 0.0461


                                                                                                                       

Validation Accuracy: 0.9174
Hidden State Cosine Similarity: 0.9047
No improvement. Patience counter: 1/2
Epoch 2 Completed | Avg Loss: 0.0454


                                                                                                                       

Validation Accuracy: 0.9163
Hidden State Cosine Similarity: 0.9041
No improvement. Patience counter: 2/2
Early stopping triggered.


In [129]:
torch.cuda.empty_cache()

In [35]:
# Evaluate teacher
teacher.eval()
correct = total = 0
with torch.no_grad():
    for batch in valid_dataloader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        outputs = teacher(input_ids=input_ids, attention_mask=attention_mask)
        preds = torch.argmax(outputs.logits, dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

print(f"Teacher Accuracy: {correct / total:.4f}")

Teacher Accuracy: 0.9243


In [37]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Teacher Params: {count_parameters(teacher) / 1e6:.2f}M")
print(f"Student Params: {count_parameters(student) / 1e6:.2f}M")

Teacher Params: 109.48M
Student Params: 66.96M


In [133]:
import time

def measure_inference_time(model, dataloader, name="Model"):
    model.eval()
    total_time = 0
    n = 0
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            start = time.perf_counter()
            _ = model(input_ids=input_ids, attention_mask=attention_mask)
            end = time.perf_counter()

            total_time += (end - start)
            n += 1
            if n >= 100: break  # 只测100条，够用了

    print(f"{name} Avg Inference Time: {total_time / n * 1000:.2f} ms/sample")

measure_inference_time(teacher, valid_dataloader, "Teacher")
measure_inference_time(student, valid_dataloader, "Student")

Teacher Avg Inference Time: 30.95 ms/sample
Student Avg Inference Time: 7.27 ms/sample
