In [1]:
# # SST-2 数据集上的知识蒸馏：DeBERTa-v3-base (Teacher) -> TinyBERT-6L (Student)

In [2]:
# 确保安装了必要的库
!pip install evaluate


Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting fsspec>=2021.05.0 (from fsspec[http]>=2021.05.0->evaluate)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading evaluate-0.4.3-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.12.0-py3-none-any.whl (183 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m183.9/183.9 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fsspec, evaluate
  Attempting uninstall: fsspec
    Found existing installation: fsspec 2025.3.2
    Uninstalling fsspec-2025.3.2:
      Successfully uninstalled fsspec-2025.3.2
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gcsfs 2024.10.0 requi

In [3]:
# 导入核心库
import gc
import logging
import os
import sys
from typing import Dict, Any, List # 用于类型提示

import evaluate
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset, DatasetDict
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,  # 用于评估阶段
    PreTrainedTokenizerBase,
    BatchEncoding # 确保导入
)


2025-04-26 07:28:03.133653: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1745652483.321847      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1745652483.376005      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [4]:
# Configure the root logger: INFO level, timestamped format, output to stdout.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.StreamHandler(sys.stdout) # write straight to standard output
    ],
    force=True # replace any logging configuration that is already installed
)
# Logger used by the cells of this notebook.
logger = logging.getLogger(__name__)


In [5]:
# Pick the compute device: CUDA when torch reports it as available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"使用设备: {device}")


2025-04-26 07:28:20,028 [INFO] 使用设备: cuda


In [6]:
# Create every output directory up front (all of them carry an "sst2" suffix).
# `exist_ok=True` makes the cell safe to re-run.
for output_dir in (
    "./result_sst2",
    "./logs_sst2",
    "./teacher_checkpoints_sst2",  # assumes a teacher fine-tuned on SST-2 already exists
    "./distill_checkpoints_tinybert_sst2_pretokenized",  # distillation checkpoints
    "./eval_output_tinybert_sst2_pretokenized",  # evaluation output
    "./distill_logs_tinybert_sst2_pretokenized",  # TensorBoard logs
):
    os.makedirs(output_dir, exist_ok=True)


In [7]:
# ==============================================================================
# ## 第零步：定义模型和分词器标识符
# ==============================================================================


In [8]:
# Hugging Face Hub IDs of the teacher and student base models.
teacher_model_id: str = 'microsoft/deberta-v3-base'
student_model_id: str = "huawei-noah/TinyBERT_General_6L_768D"

# *** Key setting: path of the teacher model that was fine-tuned on SST-2 ***
# This must point to an existing DeBERTa-v3-base checkpoint fine-tuned on SST-2,
# e.g. teacher_model_finetuned_sst2_path: str = 'path/to/your/deberta-v3-base-finetuned-sst2'
# If you do not have one, fine-tune a teacher first or use an SST-2 fine-tuned model from the Hugging Face Hub.
teacher_model_finetuned_sst2_path: str = '/kaggle/input/deberta-v3-base-finetuned-sst2/deberta-v3-base-finetuned-sst2' # <--- change this!
logger.warning(f"请确保教师模型路径 '{teacher_model_finetuned_sst2_path}' 指向一个在 SST-2 上微调好的模型！")

# Directory the final distilled student model is saved to.
final_student_model_path: str = "tinybert-student-distilled-sst2-pretokenized" # "sst2" suffix

# Maximum tokenized sequence length (SST-2 sentences are short, so this can be reduced; 512 would also work).
MAX_LENGTH: int = 128 # a length better suited to SST-2




In [9]:
# ==============================================================================
# ## 第一步：加载数据集、分词器，并进行 *预分词*
# ==============================================================================


In [10]:
# 1. Load the SST-2 dataset (the "sst2" configuration of GLUE).
logger.info("加载 SST-2 (GLUE) 数据集...")
try:
    sst2_dataset: DatasetDict = load_dataset("glue", "sst2")
except Exception as load_error:
    # Log the failure, then let the original exception propagate unchanged.
    logger.error(f"加载 SST-2 数据集失败: {load_error}")
    raise

# Show the split / column structure of what was loaded.
print("\nSST-2 数据集信息:", sst2_dataset, sep="\n")


2025-04-26 07:28:20,119 [INFO] 加载 SST-2 (GLUE) 数据集...


README.md:   0%|          | 0.00/35.3k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/3.11M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/72.8k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/148k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/67349 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/872 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1821 [00:00<?, ? examples/s]


SST-2 数据集信息:
DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1821
    })
})


In [11]:
# 2. Load the tokenizer of the teacher model (DeBERTa), using the base model's Hub ID.
logger.info(f"加载教师模型的分词器: {teacher_model_id}")
try:
    teacher_tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(teacher_model_id, use_fast=True)
except Exception as e:
    # Log which tokenizer failed to load, then re-raise the original error.
    logger.error(f"加载教师分词器 '{teacher_model_id}' 失败: {e}")
    raise


2025-04-26 07:28:30,171 [INFO] 加载教师模型的分词器: microsoft/deberta-v3-base


tokenizer_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/579 [00:00<?, ?B/s]

spm.model:   0%|          | 0.00/2.46M [00:00<?, ?B/s]



In [12]:
# 3. Load the tokenizer of the student model (TinyBERT).
logger.info(f"加载学生模型的分词器: {student_model_id}")
try:
    student_tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(student_model_id, use_fast=True)
    # Cap the student tokenizer's maximum length at the notebook-wide MAX_LENGTH.
    student_tokenizer.model_max_length = MAX_LENGTH
except Exception as e:
    # Log which tokenizer failed to load, then re-raise the original error.
    logger.error(f"加载学生分词器 '{student_model_id}' 失败: {e}")
    raise


2025-04-26 07:28:34,401 [INFO] 加载学生模型的分词器: huawei-noah/TinyBERT_General_6L_768D


config.json:   0%|          | 0.00/390 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

In [13]:
# 4. Define the pre-tokenization function (reads the SST-2 "sentence" column).
logger.info("定义预分词函数 (适配 SST-2)...")
def tokenize_for_distillation_sst2(examples: Dict[str, List]) -> Dict[str, List]:
    """
    Tokenize a batch of SST-2 examples with both the teacher and the student tokenizer.

    Args:
        examples: Batch dict that provides the raw text under "sentence" and the
            class ids under "label".

    Returns:
        A dict holding the labels under "labels", every teacher tokenizer output
        under a "teacher_"-prefixed key and every student tokenizer output under
        a "student_"-prefixed key. Sequences are truncated to MAX_LENGTH and left
        unpadded (padding is deferred to collation time).
    """
    sentences = examples["sentence"]
    encode_options = {"truncation": True, "max_length": MAX_LENGTH, "padding": False}

    # Teacher first, then student; the order fixes the key order of the result.
    encodings_by_role = [
        ("teacher", teacher_tokenizer(sentences, **encode_options)),
        ("student", student_tokenizer(sentences, **encode_options)),
    ]

    merged: Dict[str, List] = {"labels": examples["label"]}
    for role, encoding in encodings_by_role:
        merged.update({f"{role}_{field}": values for field, values in encoding.items()})
    return merged


2025-04-26 07:28:37,402 [INFO] 定义预分词函数 (适配 SST-2)...


In [14]:
# 5. Apply the pre-tokenization function to every split of the dataset.
logger.info("对 SST-2 数据集进行预分词 (可能需要一些时间)...")
# Record the original column names so they can be dropped after tokenization.
original_columns = sst2_dataset["train"].column_names
logger.info(f"SST-2 原始列名: {original_columns}") # typically ['sentence', 'label', 'idx']

tokenized_datasets: DatasetDict = sst2_dataset.map(
    tokenize_for_distillation_sst2, # SST-2-specific tokenization function
    batched=True,
    remove_columns=original_columns, # drop all original columns
    # Use half of the available cores, but never fewer than one worker:
    # `os.cpu_count()` may be None or 1, and `1 // 2 == 0` is not a valid `num_proc`.
    num_proc=max(1, (os.cpu_count() or 1) // 2)
)
logger.info("SST-2 数据集预分词完成。")

# Show the structure of one processed example.
print("\n预分词后的 SST-2 训练集样本示例:")
print(tokenized_datasets["train"][0])


2025-04-26 07:28:37,424 [INFO] 对 SST-2 数据集进行预分词 (可能需要一些时间)...
2025-04-26 07:28:37,425 [INFO] SST-2 原始列名: ['sentence', 'label', 'idx']


Map (num_proc=2):   0%|          | 0/67349 [00:00<?, ? examples/s]

Map (num_proc=2):   0%|          | 0/872 [00:00<?, ? examples/s]

Map (num_proc=2):   0%|          | 0/1821 [00:00<?, ? examples/s]

2025-04-26 07:28:43,850 [INFO] SST-2 数据集预分词完成。

预分词后的 SST-2 训练集样本示例:
{'labels': 0, 'teacher_input_ids': [1, 5021, 353, 69422, 292, 262, 14019, 2339, 2], 'teacher_token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0], 'teacher_attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1], 'student_input_ids': [101, 5342, 2047, 3595, 8496, 2013, 1996, 18643, 3197, 102], 'student_token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'student_attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}


In [15]:
# *** 自定义 Data Collator 类定义 (无需修改) ***
import torch
from transformers import PreTrainedTokenizerBase, BatchEncoding
from typing import List, Dict, Any

class DistillationDataCollator:
    """
    Data collator for distillation on pre-tokenized examples.

    Every incoming feature dict is expected to carry the teacher inputs under
    ``teacher_*`` keys, the student inputs under ``student_*`` keys and,
    optionally, a class id under ``labels``. Teacher and student inputs are
    padded separately, each with its own tokenizer, and returned in one flat
    batch that uses the same ``teacher_`` / ``student_`` prefixes.
    """

    def __init__(self, teacher_tokenizer: "PreTrainedTokenizerBase", student_tokenizer: "PreTrainedTokenizerBase"):
        self.teacher_tokenizer = teacher_tokenizer
        self.student_tokenizer = student_tokenizer

    @staticmethod
    def _model_keys(tokenizer) -> List[str]:
        """Return the un-prefixed input keys to collect for ``tokenizer``."""
        keys = ["input_ids", "attention_mask"]
        if "token_type_ids" in tokenizer.model_input_names:
            keys.append("token_type_ids")
        return keys

    @staticmethod
    def _extract(features: List[Dict[str, Any]], prefix: str, keys: List[str]) -> List[Dict[str, Any]]:
        """Collect ``prefix``-ed keys from each feature with the prefix removed; features with none are skipped."""
        extracted = []
        for feature in features:
            item = {key: feature[prefix + key] for key in keys if prefix + key in feature}
            if item:
                extracted.append(item)
        return extracted

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """
        Build one padded batch from a list of pre-tokenized features.

        Args:
            features: Per-example dicts as described in the class docstring.

        Returns:
            A dict with ``teacher_*`` and ``student_*`` tensors and, when the
            features carry labels, a ``labels`` tensor of dtype ``torch.long``.
            An empty input list yields an empty dict.

        Raises:
            ValueError: If no teacher or student inputs can be extracted, or if
                only some of the features provide teacher inputs, student inputs
                or labels (the resulting tensors would be misaligned).
        """
        if not features:
            return {}

        # 1. Split the features into teacher inputs, student inputs and labels.
        teacher_features = self._extract(features, "teacher_", self._model_keys(self.teacher_tokenizer))
        student_features = self._extract(features, "student_", self._model_keys(self.student_tokenizer))

        if not teacher_features or not student_features:
            raise ValueError("未能从特征中提取教师或学生输入。Features sample keys: " + str(features[0].keys()))
        if len(teacher_features) != len(features) or len(student_features) != len(features):
            # A partially tokenized batch would silently pair rows of different examples.
            raise ValueError(
                "Inconsistent batch: "
                f"{len(teacher_features)} teacher inputs and {len(student_features)} student inputs "
                f"for {len(features)} features."
            )

        labels = [feature["labels"] for feature in features if "labels" in feature]
        if not labels:
            # Label-free batches are allowed; warn once per batch rather than once per example.
            logger.warning("在 collate 过程中未找到 'labels' 键。")
        elif len(labels) != len(features):
            raise ValueError(
                f"Inconsistent batch: only {len(labels)} of {len(features)} features carry 'labels'."
            )

        # 2. Pad each side with its own tokenizer.
        try:
            padded_teacher_batch = self.teacher_tokenizer.pad(
                teacher_features, padding=True, return_tensors="pt",
            )
        except Exception as pad_error:
            logger.error(f"教师分词器填充失败: {pad_error}. Teacher features keys: {[item.keys() for item in teacher_features[:2]]}")
            raise
        try:
            padded_student_batch = self.student_tokenizer.pad(
                student_features, padding=True, return_tensors="pt",
            )
        except Exception as pad_error:
            logger.error(f"学生分词器填充失败: {pad_error}. Student features keys: {[item.keys() for item in student_features[:2]]}")
            raise

        # 3. Assemble the final batch, restoring the role prefixes.
        batch: Dict[str, torch.Tensor] = {}
        if labels:
            batch["labels"] = torch.tensor(labels, dtype=torch.long)
        batch.update({f"teacher_{key}": value for key, value in padded_teacher_batch.items()})
        batch.update({f"student_{key}": value for key, value in padded_student_batch.items()})
        return batch


In [16]:
# 6. Select the training and validation splits (SST-2 ships its own validation split).
# The train and validation splits of the tokenized dataset are used directly.
train_dataset = tokenized_datasets["train"]
val_dataset = tokenized_datasets["validation"] # SST-2's own validation split
# test_dataset_pretokenized = tokenized_datasets["test"] # the SST-2 test split is usually unlabeled, so it is not used at this stage

logger.info(f"训练集大小: {len(train_dataset)}")
logger.info(f"验证集大小: {len(val_dataset)}")
# logger.info(f"用于训练中评估的测试集大小: {len(test_dataset_pretokenized)}") # no longer needed


2025-04-26 07:28:43,902 [INFO] 训练集大小: 67349
2025-04-26 07:28:43,903 [INFO] 验证集大小: 872


In [17]:
# 7. Data collator.
#    The custom DistillationDataCollator pads the pre-tokenized teacher and student inputs.
logger.info("使用自定义的 DistillationDataCollator 处理预分词数据。")
data_collator = DistillationDataCollator(
    teacher_tokenizer=teacher_tokenizer,
    student_tokenizer=student_tokenizer
)


2025-04-26 07:28:43,924 [INFO] 使用自定义的 DistillationDataCollator 处理预分词数据。


In [18]:
# 8. Load the evaluation metric (accuracy is the metric used for SST-2 here).
logger.info("加载 'accuracy' 评估指标...")
try:
    # Several metrics could be loaded here side by side.
    accuracy_metric = evaluate.load("accuracy")
    # glue_metric = evaluate.load("glue", "sst2") # alternative: load the GLUE SST-2 metric directly
except Exception as e:
    # Log the failure, then re-raise the original error.
    logger.error(f"加载评估指标失败: {e}")
    raise

def compute_metrics(eval_pred):
    """
    Compute accuracy from an evaluation prediction.

    Args:
        eval_pred: A pair ``(predictions, labels)``. ``predictions`` may be a
            tuple, in which case only its first element is used; both members
            may be torch tensors or array-likes.

    Returns:
        Whatever ``accuracy_metric.compute`` returns for the argmax (over axis 1)
        of the predictions against the labels.
    """
    def _as_numpy(value):
        # Tensors are detached and copied to host memory; anything else passes through untouched.
        return value.detach().cpu().numpy() if isinstance(value, torch.Tensor) else value

    logits, references = eval_pred
    if isinstance(logits, tuple):
        logits = logits[0]

    predicted_ids = np.argmax(_as_numpy(logits), axis=1)
    return accuracy_metric.compute(predictions=predicted_ids, references=_as_numpy(references))


2025-04-26 07:28:43,947 [INFO] 加载 'accuracy' 评估指标...


Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

In [19]:
# ==============================================================================
# ## 第二步：加载微调好的教师模型 (DeBERTa-v3-base on SST-2)
# ==============================================================================


In [20]:
# Fail fast when the fine-tuned SST-2 teacher checkpoint is missing.
logger.info(f"检查微调好的 SST-2 教师模型路径: {teacher_model_finetuned_sst2_path}")
if not os.path.exists(teacher_model_finetuned_sst2_path):
    logger.error(f"错误：未找到微调好的 SST-2 教师模型路径 '{teacher_model_finetuned_sst2_path}'。")
    logger.error("请确保教师模型已在 SST-2 上微调并保存，或修改路径。")
    raise FileNotFoundError(f"SST-2 教师模型未在路径 {teacher_model_finetuned_sst2_path} 找到")

# The path exists: load the teacher as a two-class sequence classifier.
logger.info(f"加载微调好的 SST-2 教师模型: {teacher_model_finetuned_sst2_path}")
try:
    teacher_model_for_distill = AutoModelForSequenceClassification.from_pretrained(
        teacher_model_finetuned_sst2_path,
        num_labels=2,  # SST-2 is a binary classification task
    )
    logger.info("成功加载微调教师模型。")
except Exception as load_error:
    # Log the failure, then let the original exception propagate unchanged.
    logger.error(f"加载微调教师模型 '{teacher_model_finetuned_sst2_path}' 失败: {load_error}")
    raise


2025-04-26 07:28:46,041 [INFO] 检查微调好的 SST-2 教师模型路径: /kaggle/input/deberta-v3-base-finetuned-sst2/deberta-v3-base-finetuned-sst2
2025-04-26 07:28:46,046 [INFO] 加载微调好的 SST-2 教师模型: /kaggle/input/deberta-v3-base-finetuned-sst2/deberta-v3-base-finetuned-sst2
2025-04-26 07:28:46,787 [INFO] 成功加载微调教师模型。


In [21]:
# ==============================================================================
# ## 第三步：加载预训练的学生模型 (TinyBERT)
# ==============================================================================


In [22]:
# The student's classification head is not part of the pre-trained checkpoint and is
# initialised randomly on load, so seed torch first to make that initialisation reproducible.
STUDENT_INIT_SEED: int = 42
torch.manual_seed(STUDENT_INIT_SEED)

logger.info(f"加载预训练学生模型 ({student_model_id}) 用于蒸馏...")
try:
    student_model = AutoModelForSequenceClassification.from_pretrained(
        student_model_id,
        num_labels=2, # two classes, matching the SST-2 task
        ignore_mismatched_sizes=True
    )
    logger.info("成功加载 TinyBERT 学生模型。")
except Exception as e:
    # Log the failure, then re-raise the original error.
    logger.error(f"加载学生模型 {student_model_id} 失败: {e}")
    raise


2025-04-26 07:28:46,893 [INFO] 加载预训练学生模型 (huawei-noah/TinyBERT_General_6L_768D) 用于蒸馏...


pytorch_model.bin:   0%|          | 0.00/287M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at huawei-noah/TinyBERT_General_6L_768D and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


2025-04-26 07:28:49,504 [INFO] 成功加载 TinyBERT 学生模型。


In [23]:
# ==============================================================================
# ## 第四步：设置并运行知识蒸馏训练 (使用预分词数据)
# ==============================================================================


In [24]:
# 1. 定义蒸馏训练参数类 (无需修改)
class DistillationTrainingArguments(TrainingArguments):
    """
    Training arguments extended with the two distillation hyper-parameters.

    Args:
        alpha: Weight of the student's own (hard-label) loss; the distillation
            loss is weighted by ``1 - alpha``. Must lie in ``[0, 1]``.
        temperature: Softmax temperature applied to both teacher and student
            logits. Must be strictly positive because the logits are divided by it.
        *args, **kwargs: Forwarded unchanged to ``TrainingArguments``.

    Raises:
        ValueError: If ``alpha`` is outside ``[0, 1]`` or ``temperature`` is not positive.
    """
    def __init__(self, *args, alpha: float = 0.5, temperature: float = 2.0, **kwargs):
        # Validate before the base-class constructor so bad values fail immediately
        # instead of producing a NaN / meaningless loss during training.
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {alpha!r}")
        if not temperature > 0:
            raise ValueError(f"temperature must be > 0, got {temperature!r}")
        super().__init__(*args, **kwargs)
        self.alpha = alpha
        self.temperature = temperature


In [25]:
# 2. 定义蒸馏训练器 (无需修改 compute_loss)
class PreTokenizedDistillationTrainer(Trainer):
    """
    Trainer that distills a teacher into the trained (student) model using
    batches of pre-tokenized, role-prefixed inputs.

    Args:
        teacher_model: The teacher network. It is moved to ``self.args.device``
            and put in eval mode; it is only ever run under ``torch.no_grad()``.
        *args, **kwargs: Forwarded unchanged to ``Trainer``.

    Raises:
        ValueError: If ``teacher_model`` is not provided.
    """
    def __init__(self, *args, teacher_model=None, **kwargs):
        # Check the teacher first so a missing argument fails before the base class is set up.
        if teacher_model is None:
            raise ValueError("必须提供教师模型 (teacher_model)。")
        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        self._move_model_to_device(self.teacher, self.args.device)
        self.teacher.eval()

    @staticmethod
    def _select_prefixed(inputs: Dict[str, torch.Tensor], prefix: str, device) -> Dict[str, torch.Tensor]:
        """Return the entries whose key starts with ``prefix``, with only that leading prefix removed, on ``device``."""
        return {
            key[len(prefix):]: value.to(device)
            for key, value in inputs.items()
            if key.startswith(prefix)
        }

    def compute_loss(self, model, inputs: Dict[str, torch.Tensor], return_outputs=False,**kwargs):
        """
        Compute the combined hard-label / distillation loss for one batch.

        Args:
            model: The student model being trained.
            inputs: Batch dict with ``student_*`` and ``teacher_*`` tensors and
                optionally ``labels``.
            return_outputs: When true, also return the student's model outputs.
            **kwargs: Accepted for compatibility with the base-class signature; unused.

        Returns:
            ``loss`` or ``(loss, student_outputs)``. With labels the loss is
            ``alpha * student_loss + (1 - alpha) * distillation_loss``; when the
            student produces no loss, the distillation loss alone is returned.
        """
        device = self.args.device
        labels = inputs.get("labels")
        if labels is not None:
            labels = labels.to(device)

        # --- Student forward pass ---
        student_inputs = self._select_prefixed(inputs, "student_", device)
        if labels is not None:
            student_inputs["labels"] = labels
        outputs_student = model(**student_inputs)
        student_loss = getattr(outputs_student, "loss", None)
        logits_student = outputs_student.logits

        # --- Teacher forward pass (no gradients) ---
        teacher_inputs = self._select_prefixed(inputs, "teacher_", device)
        with torch.no_grad():
            logits_teacher = self.teacher(**teacher_inputs).logits.to(device)

        # --- Distillation loss: KL divergence between temperature-softened distributions ---
        temperature = self.args.temperature
        alpha = self.args.alpha
        distillation_loss = F.kl_div(
            F.log_softmax(logits_student.float() / temperature, dim=-1),
            F.softmax(logits_teacher.float() / temperature, dim=-1),
            reduction="batchmean",
            log_target=False,
        ) * (temperature ** 2)  # rescaled by T^2, as in the original implementation

        # --- Combine both losses ---
        if student_loss is not None:
            loss = alpha * student_loss + (1.0 - alpha) * distillation_loss.to(student_loss.dtype)
        else:
            loss = distillation_loss

        return (loss, outputs_student) if return_outputs else loss


In [26]:
# 3. Configure the distillation training arguments (adapted to SST-2).
# Output and logging directories carry the SST-2 suffix.
distill_output_dir: str = "./distill_checkpoints_tinybert_sst2_pretokenized"
distill_logging_dir: str = './distill_logs_tinybert_sst2_pretokenized'

# Use the custom DistillationTrainingArguments, which also accepts alpha / temperature.
distillation_args = DistillationTrainingArguments(
    output_dir=distill_output_dir,
    # SST-2 is fairly small, so the number of epochs or steps could be reduced
    num_train_epochs=3,             # keep 3 epochs, or adjust as needed
    # max_steps=1,               # alternatively cap the number of steps
    per_device_train_batch_size=32, # the batch size can be raised moderately
    per_device_eval_batch_size=64,
    gradient_accumulation_steps=1,
    learning_rate=5e-5,             # may need tuning
    warmup_ratio=0.1,
    weight_decay=0.01,
    logging_dir=distill_logging_dir,
    logging_strategy="steps",
    logging_steps=50,              # log every 50 steps
    eval_strategy="epoch",         # evaluate once per epoch
    # eval_steps=100,                # only relevant when evaluating by steps
    save_strategy="epoch",         # save once per epoch
    # save_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy", # pick the best checkpoint by accuracy
    greater_is_better=True,
    save_total_limit=2,
    fp16=torch.cuda.is_available(),
    report_to="tensorboard",
    # --- distillation-specific parameters ---
    alpha=0.5,
    temperature=4.0,
    # --- key parameter (keep as is): presumably required so the prefixed
    #     teacher_*/student_* columns are not dropped before collation ---
    remove_unused_columns=False
)


In [27]:
# 4. Create the pre-tokenized distillation trainer (fed with the SST-2 splits).
distill_trainer = PreTokenizedDistillationTrainer(
    model=student_model,
    teacher_model=teacher_model_for_distill,
    args=distillation_args,
    train_dataset=train_dataset,    # SST-2 training split
    eval_dataset=val_dataset,       # SST-2 validation split
    data_collator=data_collator,    # custom distillation collator
    compute_metrics=compute_metrics,
    tokenizer=student_tokenizer
)


  super().__init__(*args, **kwargs)


model.safetensors:   0%|          | 0.00/287M [00:00<?, ?B/s]

In [28]:
# 5. Run the distillation training.
logger.info("开始在 SST-2 数据集上进行预分词知识蒸馏...")
try:
    train_result = distill_trainer.train()
    logger.info("SST-2 知识蒸馏训练完成。")
    # Log and persist the training metrics and the trainer state.
    metrics = train_result.metrics
    distill_trainer.log_metrics("train", metrics)
    distill_trainer.save_metrics("train", metrics)
    distill_trainer.save_state()
except Exception as e:
    # Log with the full traceback, then re-raise so the cell still fails visibly.
    logger.error(f"SST-2 蒸馏训练失败: {e}", exc_info=True)
    raise
finally:
    # Whether training succeeded or failed, move the teacher to the CPU (best effort:
    # a failure here is only logged as a warning).
    if hasattr(distill_trainer, 'teacher') and hasattr(distill_trainer.teacher, 'to'):
        try:
            distill_trainer.teacher.to('cpu')
            logger.info("教师模型已移至 CPU。")
        except Exception as e_cpu:
            logger.warning(f"将教师模型移至 CPU 时出错: {e_cpu}")



2025-04-26 07:28:55,342 [INFO] 开始在 SST-2 数据集上进行预分词知识蒸馏...


You're using a DebertaV2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Could not estimate the number of tokens of the input, floating-point operations will not be computed


Epoch,Training Loss,Validation Loss,Accuracy
1,0.5023,0.64833,0.91055
2,0.2428,0.762463,0.908257
3,0.1421,0.637158,0.912844


2025-04-26 07:42:48,662 [INFO] SST-2 知识蒸馏训练完成。
***** train metrics *****
  epoch                    =        3.0
  total_flos               =        0GF
  train_loss               =     0.4446
  train_runtime            = 0:13:52.88
  train_samples_per_second =    242.586
  train_steps_per_second   =      7.582
2025-04-26 07:42:49,351 [INFO] 教师模型已移至 CPU。


In [29]:
# 6. Save the trainer's final student model together with its tokenizer.
#    (`load_best_model_at_end=True` is set in the training arguments.)
logger.info(f"保存最终蒸馏学生模型到: {final_student_model_path}")
distill_trainer.save_model(final_student_model_path)
logger.info(f"同时保存学生分词器到: {final_student_model_path}")
if student_tokenizer:
    # Store the tokenizer next to the model so both can be reloaded from one path.
    student_tokenizer.save_pretrained(final_student_model_path)


2025-04-26 07:42:49,378 [INFO] 保存最终蒸馏学生模型到: tinybert-student-distilled-sst2-pretokenized
2025-04-26 07:42:50,040 [INFO] 同时保存学生分词器到: tinybert-student-distilled-sst2-pretokenized


In [30]:
# 7. Release the models and the trainer, then free cached GPU memory.
logger.info("清理模型和训练器...")
# `pop` with a default is a no-op for names that are already gone, so the cell can be re-run.
for _name in ("teacher_model_for_distill", "student_model", "distill_trainer"):
    globals().pop(_name, None)
# Objects kept alive only by reference cycles are not freed by dropping the names alone;
# collect them now so their memory can actually be released before the cache is emptied.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    logger.info("GPU 显存已清理。")


2025-04-26 07:42:50,083 [INFO] 清理模型和训练器...
2025-04-26 07:42:50,186 [INFO] GPU 显存已清理。


In [31]:
# ==============================================================================
# ## 第五步：评估蒸馏后的学生模型 (TinyBERT) on SST-2 Validation Set
# ==============================================================================
# **重要**: 最终评估使用 *原始* SST-2 验证集文本，并 *只用学生分词器* 处理。


In [32]:
# 1. Load the final distilled student model from the SST-2 output path.
logger.info(f"加载最终蒸馏学生模型进行评估: {final_student_model_path}")
try:
    if not os.path.exists(final_student_model_path):
        raise FileNotFoundError(f"最终学生模型路径不存在: {final_student_model_path}")
    # Load as a two-class classifier, move it to the selected device and switch to eval mode.
    final_student_model = AutoModelForSequenceClassification.from_pretrained(
        final_student_model_path, num_labels=2
    ).to(device)
    final_student_model.eval()
except Exception as e:
    # Covers the FileNotFoundError raised above as well; log and re-raise.
    logger.error(f"加载最终蒸馏学生模型 '{final_student_model_path}' 失败: {e}")
    raise


2025-04-26 07:42:50,236 [INFO] 加载最终蒸馏学生模型进行评估: tinybert-student-distilled-sst2-pretokenized


In [33]:
# 2. Load the student tokenizer used for evaluation from the same path as the model.
logger.info(f"加载学生分词器进行评估: {final_student_model_path}")
try:
    if not os.path.exists(final_student_model_path):
        raise FileNotFoundError(f"学生分词器路径不存在: {final_student_model_path}")
    tokenizer_for_eval: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(final_student_model_path)
except Exception as e:
    # Covers the FileNotFoundError raised above as well; log and re-raise.
    logger.error(f"加载学生分词器 '{final_student_model_path}' 失败: {e}")
    raise


2025-04-26 07:42:50,398 [INFO] 加载学生分词器进行评估: tinybert-student-distilled-sst2-pretokenized


In [34]:
# 3. Prepare the dataset for the final evaluation (the raw SST-2 validation split).
logger.info("加载原始 SST-2 验证集用于最终评估...")
try:
    # Load only the validation split.
    original_eval_dataset = load_dataset("glue", "sst2", split="validation")
except Exception as e:
    # Log the failure, then re-raise the original error.
    logger.error(f"加载原始 SST-2 验证集失败: {e}")
    raise

logger.info(f"最终评估验证集大小: {len(original_eval_dataset)}")

# The tokenization function defined next reads the SST-2 "sentence" column.
logger.info("使用学生分词器对原始验证集进行分词以供评估...")
def tokenize_for_student_eval_sst2(examples: Dict[str, List]) -> Any:
    """
    Tokenize a batch of SST-2 sentences with the student tokenizer only (final evaluation).

    Args:
        examples: Batch dict providing the raw text under "sentence".

    Returns:
        The result of ``tokenizer_for_eval`` for those sentences, truncated to
        MAX_LENGTH and left unpadded.
    """
    sentences = examples["sentence"]
    encode_options = {"truncation": True, "max_length": MAX_LENGTH, "padding": False}
    return tokenizer_for_eval(sentences, **encode_options)

# Apply the student-only tokenization to the raw validation split.
# Record the original column names so everything except the label can be dropped.
original_eval_columns = original_eval_dataset.column_names
tokenized_eval_for_final = original_eval_dataset.map(
    tokenize_for_student_eval_sst2, # SST-2-specific tokenization function
    batched=True,
    remove_columns=[col for col in original_eval_columns if col != 'label'], # drop every original column except 'label'
    # Use half of the available cores, but never fewer than one worker:
    # `os.cpu_count()` may be None or 1, and `1 // 2 == 0` is not a valid `num_proc`.
    num_proc=max(1, (os.cpu_count() or 1) // 2)
)
logger.info("最终评估验证集已使用学生分词器处理完毕。")
print("\n评估用验证集样本示例 (已用学生分词器处理):")
if len(tokenized_eval_for_final) > 0:
    print(tokenized_eval_for_final[0])
else:
    logger.warning("评估用的验证集为空。")



2025-04-26 07:42:50,451 [INFO] 加载原始 SST-2 验证集用于最终评估...
2025-04-26 07:42:55,055 [INFO] 最终评估验证集大小: 872
2025-04-26 07:42:55,056 [INFO] 使用学生分词器对原始验证集进行分词以供评估...


Map (num_proc=2):   0%|          | 0/872 [00:00<?, ? examples/s]

2025-04-26 07:42:55,395 [INFO] 最终评估验证集已使用学生分词器处理完毕。

评估用验证集样本示例 (已用学生分词器处理):
{'label': 1, 'input_ids': [101, 2009, 1005, 1055, 1037, 11951, 1998, 2411, 12473, 4990, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}


In [35]:
# 4. Create the data collator for evaluation, built on the evaluation tokenizer.
logger.info("创建用于评估的数据整理器...")
data_collator_for_eval = DataCollatorWithPadding(tokenizer=tokenizer_for_eval)


2025-04-26 07:42:55,422 [INFO] 创建用于评估的数据整理器...


In [36]:
# 5. Create a separate, standard Trainer instance that is used only for evaluation.
logger.info("创建评估用 Trainer...")
eval_output_dir: str = './eval_output_tinybert_sst2_pretokenized'
# Evaluation-only arguments: no training, no external reporting.
eval_args = TrainingArguments(
    output_dir=eval_output_dir,
    per_device_eval_batch_size=64,
    do_train=False,
    do_eval=True,
    report_to="none",
    remove_unused_columns=False
)

eval_trainer = Trainer(
    model=final_student_model,
    args=eval_args,
    eval_dataset=tokenized_eval_for_final, # SST-2 validation split tokenized with the student tokenizer
    data_collator=data_collator_for_eval,
    tokenizer=tokenizer_for_eval,
    compute_metrics=compute_metrics,
)


2025-04-26 07:42:55,446 [INFO] 创建评估用 Trainer...


  eval_trainer = Trainer(


In [37]:
# 6. Evaluate the distilled student on the validation split.
logger.info("在最终 SST-2 验证集上评估蒸馏后的学生模型 (TinyBERT)...")
try:
    if len(tokenized_eval_for_final) == 0:
        # Nothing to evaluate: record an empty result instead of calling the trainer.
        logger.warning("评估验证集为空，跳过评估。")
        evaluation_results = {}
    else:
        evaluation_results = eval_trainer.evaluate(eval_dataset=tokenized_eval_for_final)
except Exception as eval_error:
    # Log with the full traceback, then re-raise so the cell still fails visibly.
    logger.error(f"评估失败: {eval_error}", exc_info=True)
    raise


2025-04-26 07:42:55,505 [INFO] 在最终 SST-2 验证集上评估蒸馏后的学生模型 (TinyBERT)...


In [38]:
# 7. Print and persist the evaluation results.
if not evaluation_results:
    logger.info("没有评估结果可打印或保存。")
else:
    logger.info("最终学生模型 (预分词蒸馏) 在 SST-2 验证集上的评估结果:")
    # Only the "eval_"-prefixed entries are logged individually, with four decimals.
    for metric_name in evaluation_results:
        if metric_name.startswith("eval_"):
            logger.info(f"  {metric_name}: {evaluation_results[metric_name]:.4f}")

    eval_trainer.log_metrics("eval", evaluation_results)
    eval_trainer.save_metrics("eval", evaluation_results)

print("\n脚本执行完毕。")

2025-04-26 07:42:56,291 [INFO] 最终学生模型 (预分词蒸馏) 在 SST-2 验证集上的评估结果:
2025-04-26 07:42:56,292 [INFO]   eval_loss: 0.4698
2025-04-26 07:42:56,292 [INFO]   eval_model_preparation_time: 0.0018
2025-04-26 07:42:56,293 [INFO]   eval_accuracy: 0.9128
2025-04-26 07:42:56,294 [INFO]   eval_runtime: 0.7565
2025-04-26 07:42:56,294 [INFO]   eval_samples_per_second: 1152.6910
2025-04-26 07:42:56,295 [INFO]   eval_steps_per_second: 18.5070
***** eval metrics *****
  eval_accuracy               =     0.9128
  eval_loss                   =     0.4698
  eval_model_preparation_time =     0.0018
  eval_runtime                = 0:00:00.75
  eval_samples_per_second     =   1152.691
  eval_steps_per_second       =     18.507

脚本执行完毕。
