In [1]:
'''
textbrewer implemented knowledge distillation
textbrewer package: 一个用于nlp模型的知识蒸馏工具包，旨在简化和加速模型蒸馏过程。
- 知识蒸馏是一种技术，通过将pre-trained模型(teacher model)的知识转移到较小的模型(student model)，从而在保留性能的同时减少计算资源消耗。
import textbrewer
from textbrewer import Generalstiller  # 用于执行通用的知识蒸馏过程。它支持多种蒸馏配置和训练配置，适用于大多数蒸馏场景。
# TrainingConfig用于配置训练过程中的一些参数，比如lr、batch_size、epochs等; DistillationConfig用于配置蒸馏过程参数，比如Temperature、alpha等。
from textbrewer import TrainingConfig, DistillationConfig

import torch
from transformers import BertForSequenceClassification, AdamW

# 定义teacher model和student model
teacher_model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
student_model = BertForSequenceClassification.from_pretrained('distillbert-base-uncased')

# 定义优化器
optimizer = AdamW(student_model.parameters(), lr=5e-5)

# 定义数据加载器
train_dataloader = ...
test_dataloader = ...

# 定义训练配置
training_config = TrainingConfig(
    gradient_accumulation_steps=1,
    ckpt_frequency=1,
)

# 定义蒸馏配置
distillation_config = DistillationConfig(
    temperature=4.0,
    intermediate_matches=[]  # 可以定义中间层匹配
)

# 创建蒸馏器
distiller = GeneralDistiller(
    train_config=training_config,
    distill_config=distillation_config,
    model_T=teacher_model,
    model_S=student_model,
)

# 开始蒸馏训练
with distiller:
    distiller.train(
        optimizer=optimizer,
        dataloader=train_dataloader,
        num_epochs=3,
        callback=None,
    )
'''


"\ntextbrewer implemented knowledge distillation\ntextbrewer package: 一个用于nlp模型的知识蒸馏工具包，旨在简化和加速模型蒸馏过程。\n- 知识蒸馏是一种技术，通过将pre-trained模型(teacher model)的知识转移到较小的模型(student model)，从而在保留性能的同时减少计算资源消耗。\nimport textbrewer\nfrom textbrewer import Generalstiller  # 用于执行通用的知识蒸馏过程。它支持多种蒸馏配置和训练配置，适用于大多数蒸馏场景。\n# TrainingConfig用于配置训练过程中的一些参数，比如lr、batch_size、epochs等; DistillationConfig用于配置蒸馏过程参数，比如Temperature、alpha等。\nfrom textbrewer import TrainingConfig, DistillationConfig\n\nimport torch\nfrom transformers import BertForSequenceClassification, AdamW\n\n# 定义teacher model和student model\nteacher_model = BertForSequenceClassification.from_pretrained('bert-base-uncased')\nstudent_model = BertForSequenceClassification.from_pretrained('distillbert-base-uncased')\n\n# 定义优化器\noptimizer = AdamW(student_model.parameters(), lr=5e-5)\n\n# 定义数据加载器\ntrain_dataloader = ...\ntest_dataloader = ...\n\n# 定义训练配置\ntraining_config = TrainingConfig(\n    gradient_accumulation_steps=1,\n    ckpt_frequency=1,\n)\n\n# 定义蒸馏配置\

## KD_msra_example

In [3]:
# 序列标注任务-命名实体识别NER: https://github.com/airaria/TextBrewer/blob/master/examples/notebook_examples/msra_ner.ipynb

In [4]:
import os
import numpy as np

import torch

from transformers import BertForSequenceClassification, BertTokenizer,BertConfig,BertForTokenClassification
from transformers import Trainer, TrainingArguments
from transformers import pipeline, AutoTokenizer

from datasets import load_dataset,load_metric
from sklearn.metrics import accuracy_score, precision_recall_fscore_support




In [5]:
device='cuda' if torch.cuda.is_available() else 'cpu'

### prepare dataset to train

In [6]:
task = "ner" #  "ner", "pos" or "chunk"
model_checkpoint = "bert-base-chinese"
batch_size = 8

datasets = load_dataset("msra_ner")
# 从datasets中获取train datset -> 获取features特征，这是一个dict，f"{task}_tags"是f-straing语法，将task变量值嵌入到string中，
# 以动态生成feature名称。-> feature获取特征的具体描述 -> names 获取标签名称列表list。
label_list = datasets["train"].features[f"{task}_tags"].feature.names

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [7]:
# map方法对train_dataset进行预处理操作，尤其是使用tokenizer对数据集中的文本进行分词、截断和填充操作。
# tokenizer分词对象，来自transformers.BertTokenizer或RobertaTokenizer; truncation=True对长度超过max_length的文本进行截断；padding填充到最大长度。
# batched=True，表示map方法每次处理一个batch的数据，而不是一次处理一个数据。

In [9]:
datasets.shape

{'train': (45001, 3), 'test': (3443, 3)}

In [10]:
label_list

['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC']

In [11]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

### utils

In [12]:
'''
命名实体识别NER任务中:
problem: tokenizer subword, 即一个word被过度细分成几个subwords，e.g. 单词sheepmeat，被分成3个subtokens: 'sheep','##me','##at'.
        由于label通常是在word级别进行标注的，既然word还会被切分成subtokens，那么意味着我们还需要对label进行subtokens的对齐。
        由于pre-trained model输入格式的要求，往往还需要加入一些特殊符号: [CLS], [SEP].
        len(example[f"{task}_tags"]), len(tokenized_input["input_ids"]) -> (31, 39)
solution: tokenizer.word_ids()方法可以帮助我们解决这个对齐问题。
        word_ids() return: [None, 0, 1, 1, 2, 3, 4, 5, 6, 7, 7, 8, 9, 10, 11, 11, 11, 12, 13, 14, 15, 16, 17, 18, 18, 18, None]
        可以看到，word_ids将每一个subtokens位置都对应了一个word的下标。比如第一个位置对应第0个word，第2、3个位置对应第1个word，特殊字符对应none。
        
我们通常将特殊字符的label设置为-100，在模型中-100通常会被忽略掉不计算loss。
两种对齐label的方式:
- 多个subtokens对齐一个token，对齐一个label。
- 多个subtokens的第一个subtoken对齐word，对齐一个label，其他subtokens直接赋予-100.
提供了这两种方式，通过label_all_tokens = True 切换
'''

'\n命名实体识别NER任务中:\nproblem: tokenizer subword, 即一个word被过度细分成几个subwords，e.g. 单词sheepmeat，被分成3个subtokens: \'sheep\',\'##me\',\'##at\'.\n        由于label通常是在word级别进行标注的，既然word还会被切分成subtokens，那么意味着我们还需要对label进行subtokens的对齐。\n        由于pre-trained model输入格式的要求，往往还需要加入一些特殊符号: [CLS], [SEP].\n        len(example[f"{task}_tags"]), len(tokenized_input["input_ids"]) -> (31, 39)\nsolution: tokenizer.word_ids()方法可以帮助我们解决这个对齐问题。\n        word_ids() return: [None, 0, 1, 1, 2, 3, 4, 5, 6, 7, 7, 8, 9, 10, 11, 11, 11, 12, 13, 14, 15, 16, 17, 18, 18, 18, None]\n        可以看到，word_ids将每一个subtokens位置都对应了一个word的下标。比如第一个位置对应第0个word，第2、3个位置对应第1个word，特殊字符对应none。\n        \n我们通常将特殊字符的label设置为-100，在模型中-100通常会被忽略掉不计算loss。\n两种对齐label的方式:\n- 多个subtokens对齐一个token，对齐一个label。\n- 多个subtokens的第一个subtoken对齐word，对齐一个label，其他subtokens直接赋予-100.\n提供了这两种方式，通过label_all_tokens = True 切换\n'

In [13]:
# 对input text进行分词和labels对齐
def tokenize_and_align_labels(examples):  # examples对应原始tokens indices
    # examples: 输入样本；truncation=True对长度超过max_length的文本进行截断; is_split_into_words表示输入是预先分词的。
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)  # 返回分词后的结果
    print('label before: {tokenized_inputs}')
    labels = []
    # 遍历样本
    for i, label in enumerate(examples[f"{task}_tags"]):  # i -> index; tags表示label
        word_ids = tokenized_inputs.word_ids(batch_index=i)  # word_idx获取分词后每个index对应的原始单词的索引indices
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:  # 每个单词的index
            # Special tokens have a word id that is None. We set the label to -100 so they are automatically
            # ignored in the loss function.
            if word_idx is None:  # index=None，表示特殊标记，[CLS], [SEP], label set -100.
                label_ids.append(-100)
            # We set the label for the first token of each word.
            elif word_idx != previous_word_idx:
                label_ids.append(label[word_idx])
            # For the other tokens in a word, we set the label to either the current label or -100, depending on
            # the label_all_tokens flag.
            else:
                label_ids.append(label[word_idx] if label_all_tokens else -100)  # 两种对齐方式
            previous_word_idx = word_idx

        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    print('label after: {tokenized_inputs}')
    return tokenized_inputs

In [14]:
# label_list: ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC']
def compute_metrics(p):  # p是一个包含pred和label的对象，由transformers.Trainer输出
    print(p.__dict__)
    predictions = p.predictions
    labels = p.label_ids
    predictions = np.argmax(predictions, axis=2)  # 选择分类概率最大的下标

    # Remove ignored index (special tokens): -100
    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)  # batch, batch的
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    results = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }

In [15]:
def compute_eval_metrics(p):
    print(p.__dict__)
    predictions = p.predictions[0]  # 这里predictions是一个元组或list，取第一个元素。
    labels = p.label_ids
    predictions = np.argmax(predictions, axis=2)

    # Remove ignored index (special tokens)
    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    results = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }

### teacher model training

In [16]:
'''
- BertForTokenClassfication是transformers中一个用于序列标注任务的模型，它是基于BERT模型，并在其上加了一个linear层，用于每个输入token的分类。
- TrainingArguments是transformers中用于配置training过程的参数类，e.g. lr, batch_size, epochs.
- Trainer是transformers中用于训练和评估模型的高级api，支持分布式训练。它封装了training和eval过程，简化了模型训练流程。
e.g.
# 初始化模型
model = BertForTokenClassification.from_pretrained("bert-base-uncased", num_labels=9)

# 定义训练参数
training_args = TrainingArguments(
    output_dir='./results',
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    logging_dir='./logs',
    logging_steps=10,
)

# 定义训练数据集和评估数据集（假设已经定义好了 train_dataset 和 eval_dataset）
train_dataset = ...
eval_dataset = ...

# 初始化 Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,  # 假设已经定义好了 compute_metrics 函数
)

# 开始训练
trainer.train()
'''

from transformers import BertForTokenClassification, TrainingArguments, Trainer

model = BertForTokenClassification.from_pretrained(model_checkpoint, num_labels=len(label_list))
model.to(device)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-chinese and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12

In [17]:
from transformers import DataCollatorForTokenClassification  # 创建数据整理器data collator，用于序列标注任务
'''
Data collator数据整理器负责将batch数据整理成适合模型输入的格式。
在序列标注任务中，每个输入样本可能有不同的长度，因此需要进行填充padding以保持batch内所有样本长度一致。
tokenizer是一个分词器实例，用于处理文本数据。它的作用包括：
- 将文本转换为词汇表中的索引index
- 添加特殊标记[CLS], [SEP]
- 填充padding
'''
data_collator = DataCollatorForTokenClassification(tokenizer)

In [18]:
# seqeval是一个专门用于序列标注任务的评估库，主要用于评估NER模型的性能，包含precision, recall, f1-score, accuracy等指标。
metric = load_metric("seqeval")

  metric = load_metric("seqeval")
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


In [19]:
tokenized_datasets = datasets.map(tokenize_and_align_labels, batched=True)

Map:   0%|          | 0/45001 [00:00<?, ? examples/s]

label before: {tokenized_inputs}
label after: {tokenized_inputs}
label before: {tokenized_inputs}
label after: {tokenized_inputs}
label before: {tokenized_inputs}
label after: {tokenized_inputs}
label before: {tokenized_inputs}
label after: {tokenized_inputs}
label before: {tokenized_inputs}
label after: {tokenized_inputs}
label before: {tokenized_inputs}
label after: {tokenized_inputs}
label before: {tokenized_inputs}
label after: {tokenized_inputs}
label before: {tokenized_inputs}
label after: {tokenized_inputs}
label before: {tokenized_inputs}
label after: {tokenized_inputs}
label before: {tokenized_inputs}
label after: {tokenized_inputs}
label before: {tokenized_inputs}
label after: {tokenized_inputs}
label before: {tokenized_inputs}
label after: {tokenized_inputs}
label before: {tokenized_inputs}
label after: {tokenized_inputs}
label before: {tokenized_inputs}
label after: {tokenized_inputs}
label before: {tokenized_inputs}
label after: {tokenized_inputs}
label before: {tokenized_

Map:   0%|          | 0/3443 [00:00<?, ? examples/s]

label before: {tokenized_inputs}
label after: {tokenized_inputs}
label before: {tokenized_inputs}
label after: {tokenized_inputs}
label before: {tokenized_inputs}
label after: {tokenized_inputs}
label before: {tokenized_inputs}
label after: {tokenized_inputs}


In [20]:
tokenized_datasets.shape

{'train': (45001, 7), 'test': (3443, 7)}

In [21]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'ner_tags', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 45001
    })
    test: Dataset({
        features: ['id', 'tokens', 'ner_tags', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 3443
    })
})

In [22]:
# 缩减数据规模，降低training时间消耗
tokenized_datasets["train"] = tokenized_datasets["train"].select(range(500))
tokenized_datasets["test"] = tokenized_datasets["test"].select(range(300))

In [23]:
tokenized_datasets.shape

{'train': (500, 7), 'test': (300, 7)}

In [24]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'ner_tags', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 500
    })
    test: Dataset({
        features: ['id', 'tokens', 'ner_tags', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 300
    })
})

In [25]:
args = TrainingArguments(  # TrainingArguments配置训练过程参数
    output_dir = f"test-{task}",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,  # training batch size
    per_device_eval_batch_size=batch_size,
    num_train_epochs=2,
    weight_decay=0.01,
    evaluation_strategy = "epoch",  # 评估策略为每个epoch训练周期结束后进行评估
#     do_train=True,      # 执行训练过程
#     do_eval=True,       # 执行评估过程
#     no_cuda=True,      # 是否使用cuda，False表示使用gpu
#     load_best_model_at_end=True,  # 表示在训练结束后加载在eval过程中性能最好的model
)



In [26]:
trainer = Trainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [27]:
trainer.train()

Epoch,Training Loss,Validation Loss,Precision,Recall,F1,Accuracy
1,No log,0.279432,0.312724,0.435705,0.364111,0.918999
2,No log,0.214606,0.43021,0.561798,0.487277,0.942869


{'predictions': array([[[ 3.1587839e+00, -1.1053336e+00, -7.7257478e-01, ...,
         -1.4152177e-01, -6.3647628e-01, -5.7730252e-01],
        [-2.2807223e-01,  8.4778100e-02,  6.7771420e-02, ...,
          6.3411999e-01,  1.2012570e+00, -3.1364501e-01],
        [-8.3442283e-01, -2.0065895e-01, -2.2755417e-01, ...,
          6.9223404e-01, -8.3874688e-02,  6.9512576e-01],
        ...,
        [ 4.2492695e+00, -6.0536838e-01, -4.0749884e-01, ...,
          3.3651084e-01, -1.6159770e+00, -1.0073408e+00],
        [ 3.9921668e+00, -4.6648353e-01, -3.8017747e-01, ...,
          2.5083235e-01, -1.4356084e+00, -1.0719131e+00],
        [ 3.8555334e+00, -4.5851111e-01, -3.8321680e-01, ...,
          2.5756279e-01, -1.4410714e+00, -1.0734185e+00]],

       [[ 3.3740771e+00, -1.1027575e+00, -3.2235649e-01, ...,
         -2.6564938e-01, -9.1990644e-01, -4.1914588e-01],
        [ 6.0701046e+00, -1.1544365e+00, -9.6927583e-01, ...,
         -4.1781992e-01, -1.4902309e+00, -1.4970781e+00],
        [

TrainOutput(global_step=126, training_loss=0.14215729728577628, metrics={'train_runtime': 485.4054, 'train_samples_per_second': 2.06, 'train_steps_per_second': 0.26, 'total_flos': 39841390815072.0, 'train_loss': 0.14215729728577628, 'epoch': 2.0})

In [28]:
torch.save(model.state_dict(), './outputs/msra_teacher_model.pt')  # save the teacher model weights to distill

### knowledge distillation

In [29]:
import textbrewer
from textbrewer import GeneralDistiller
from textbrewer import TrainingConfig, DistillationConfig
from transformers import BertForTokenClassification, BertConfig,BertTokenizer
from transformers import get_linear_schedule_with_warmup
from torch.optim import AdamW

In [30]:
from torch.utils.data import DataLoader, RandomSampler
train_dataset=tokenized_datasets["train"].remove_columns(['id','tokens','ner_tags'])  # 移除列
# 创建DataLoader对象，将数据集加载成一个可迭代的对象，以高效批量处理数据。
train_dataloader = DataLoader(train_dataset, sampler=RandomSampler(train_dataset), batch_size=32,collate_fn=data_collator) #prepare dataloader

In [31]:
'''
initialize the student model by BertConfig and prepare the teacher model
- bert_config_L3.json refers to a 3-layer Bert.
- bert_config.json refers to a standard 12-layer Bert.
'''
# 读取config文件，并配置到student model
bert_config_T3 = BertConfig.from_json_file('./config/bert_config_L3.json')  # BertConfig是transformers中用于配置和初始化BERT模型的类。
# 设置model在ffd时输出hidden states，这在知识蒸馏中很重要，因为我们要比较studnet和teacher模型的隐藏层输出。
bert_config_T3.output_hidden_states = True
bert_config_T3.num_labels = len(label_list)  # 设置标签数量，用于分类任务中

student_model = BertForTokenClassification(bert_config_T3)  # 初始化student model, 一个bert序列标注模型

bert_config = BertConfig.from_json_file('./config/bert_config.json')  # 读取teacher model配置文件
bert_config.output_hidden_states = True
bert_config.num_labels = len(label_list)

teacher_model = BertForTokenClassification(bert_config)   # 初始化teacher model
teacher_model.load_state_dict(torch.load('./outputs/msra_teacher_model.pt'))  # 加载trained model权重

teacher_model.to(device)
student_model.to(device)

BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-2): 3 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, 

In [32]:
# 处理batch数据
def proc_fn(batch):
  return {'input_ids':batch['input_ids'],
          'token_type_ids':batch['token_type_ids'],
          'attention_mask':batch['attention_mask'],  # 对应哪些token时padding的，哪些是真实数据。
          'labels':batch['labels']}

#### textbrewer

In [33]:
num_epochs = 20
num_training_steps = len(train_dataloader) * num_epochs

optimizer = AdamW(student_model.parameters(), lr=1e-5)

# 学习率调度器
scheduler_class = get_linear_schedule_with_warmup  # 使用线性学习率调度器
scheduler_args = {'num_warmup_steps':int(0.1*num_training_steps), 'num_training_steps':num_training_steps}  # 预热步数、训练步数

# adaptor函数，用于从模型中提取logits和hidden states
def simple_adaptor(batch, model_outputs):
  return {"logits":model_outputs.logits, 'hidden': model_outputs.hidden_states}

# 定义蒸馏配置
'''
intermediate_matches的作用是在知识蒸馏中定义teacher model和student model之间的特征匹配策略。通过这些配置，可以指定在哪些层级上进行
    特征匹配，使用什么样的loss来衡量匹配效果，并且可以对不同的匹配进行加权处理，以控制不同层级的重要性。
在文本蒸馏中，通常会选择一些中间层的hidden states作为特征进行匹配，帮助student model更好地学习到teacher model的知识。通过调整
    intermediate_matches的配置，可以优化知识的传递效果，从而提升student model的性能。
- layer_T，是teacher model的层索引index，表示要匹配的是教师模型的哪一层特征。
- layer_S，是student model的层索引index，表示教师模型的哪一层特征传递给学生模型的哪一层进行匹配。
- feature，指定要匹配的特征类型。
'''
distill_config = DistillationConfig(
    temperature = 4.0,
    intermediate_matches=[{"layer_T":0, "layer_S":0, "feature":"hidden", "loss":"hidden_mse", "weight":1},
               {"layer_T":4, "layer_S":1, "feature":"hidden", "loss":"hidden_mse", "weight":1},
               {"layer_T":8, "layer_S":2, "feature":"hidden", "loss":"hidden_mse", "weight":1},
               {"layer_T":12,"layer_S":3, "feature":"hidden", "loss":"hidden_mse", "weight":1}])

# 训练配置
train_config = TrainingConfig(device='cpu')

distiller = GeneralDistiller(  # GeneralDistiller初始化一个知识蒸馏器
    train_config=train_config, 
    distill_config=distill_config,
    model_T=teacher_model, 
    model_S=student_model, 
    adaptor_T=simple_adaptor, 
    adaptor_S=simple_adaptor)

In [34]:
train_config

TrainingConfig:
gradient_accumulation_steps : 1
ckpt_frequency : 1
ckpt_epoch_frequency : 1
ckpt_steps : None
log_dir : None
output_dir : ./saved_models
device : cpu
fp16 : False
fp16_opt_level : O1
data_parallel : False
local_rank : -1

In [35]:
with distiller:
    # distiller.train方法本身不返回任何值，它的作用是执行training过程，并更新学生模型的参数。
    # 保存student model，需要在训练结束后手动保存，torch.save(student_model.state_dict(), "student_model.pth")
    distiller.train(optimizer, train_dataloader, num_epochs, scheduler_class=scheduler_class, 
                    scheduler_args = scheduler_args, callback=None, batch_postprocessor=proc_fn)  
# batch_postprocessor 批处理后处理器，用于在数据从dataloader提取出来后，但在传递给model之前，对数据进行进一步处理。

In [36]:
torch.save(student_model.state_dict(), "./outputs/msra_student_model.pth")

### student model evaluating

In [37]:
bert_config_T3 = BertConfig.from_json_file('./config/bert_config_L3.json')

bert_config_T3.output_hidden_states = True
bert_config_T3.num_labels = len(label_list)
test_model = BertForTokenClassification(bert_config_T3)

In [39]:
test_model.load_state_dict(torch.load('./saved_models/gs320.pkl'))
test_model.to(device)

BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-2): 3 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, 

In [40]:
args = TrainingArguments(
    f"distill-test",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    do_train=False,
    do_eval=True,
    no_cuda=False,
    num_train_epochs=2,
    weight_decay=0.01,
)

In [41]:
trainer = Trainer(
    test_model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_eval_metrics
)

In [42]:
'''
trainer.train()和trainer.evaluate()是transformers中进行模型训练和评估的方法
- trainer.train()方法只使用train_datasets对模型进行训练，更新模型参数，调整学习率。
- trainer.evaluate()方法只使用eval_dataset对模型进行评估，不更新参数，不调整学习率。
'''

'\ntrainer.train()和trainer.evaluate()是transformers中进行模型训练和评估的方法\n- trainer.train()方法只使用train_datasets对模型进行训练，更新模型参数，调整学习率。\n- trainer.evaluate()方法只使用eval_dataset对模型进行评估，不更新参数，不调整学习率。\n'

In [43]:
trainer.evaluate()

{'predictions': (array([[[   4.1751356 ,   -1.3892034 ,   -0.62786216, ...,
           -1.1625446 ,   -1.6603606 ,   -1.0444247 ],
        [   4.2052107 ,   -1.0325062 ,   -0.6012637 , ...,
           -1.3700162 ,   -1.4891158 ,   -0.9249886 ],
        [   4.1645384 ,   -1.3668545 ,   -0.6859249 , ...,
           -1.3610991 ,   -1.591393  ,   -1.0895182 ],
        ...,
        [   4.2748175 ,   -1.3126764 ,   -0.77339554, ...,
           -1.236407  ,   -1.6573793 ,   -1.158958  ],
        [   4.0912776 ,   -1.2714908 ,   -0.78826404, ...,
           -1.3551276 ,   -1.523216  ,   -1.1341921 ],
        [   4.135216  ,   -1.3037676 ,   -0.7212115 , ...,
           -1.2934223 ,   -1.6777797 ,   -1.1420298 ]],

       [[   4.180667  ,   -1.389655  ,   -0.62356436, ...,
           -1.173089  ,   -1.6734393 ,   -1.0477221 ],
        [   4.2352104 ,   -1.1125957 ,   -0.5626025 , ...,
           -1.3212243 ,   -1.5171155 ,   -0.88254654],
        [   4.1899295 ,   -1.3223742 ,   -0.6985833 , ..

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'eval_loss': 0.8500558137893677,
 'eval_precision': 0.0,
 'eval_recall': 0.0,
 'eval_f1': 0.0,
 'eval_accuracy': 0.8492268041237113,
 'eval_runtime': 26.3152,
 'eval_samples_per_second': 11.4,
 'eval_steps_per_second': 1.444}