#### review
![avatar](./imgs/img_1.png)

#### Distillation training arguments & trainer

In [2]:
import dataclasses

import numpy as np
import torch
import torch.nn.functional as F
import wandb
from torch import nn
from transformers import Trainer, TrainingArguments

In [3]:
@dataclasses.dataclass
class DistillTrainingArguments(TrainingArguments):
    """TrainingArguments extended with the two knowledge-distillation hyperparameters.

    The extra hyperparameters are declared as dataclass fields, the same way
    ``Seq2SeqTrainingArguments`` extends ``TrainingArguments``. They are not
    assigned in a hand-written ``__init__``. As a result they appear in
    ``to_dict()`` / ``to_json_string()``, and hence in the run config that
    logging integrations (e.g. wandb) record. They are still accepted as
    keyword arguments exactly as before: ``DistillTrainingArguments(..., alpha=0.5, temperature=2.0)``.

    Attributes:
        alpha: Weight of the student's hard-label cross-entropy loss. The
            distillation (KL) loss is weighted by ``1 - alpha``. Must lie in [0, 1].
        temperature: Softmax temperature applied to both student and teacher
            logits before the KL term. Must be > 0.

    Raises:
        ValueError: If ``alpha`` or ``temperature`` is out of range.
    """

    alpha: float = dataclasses.field(
        default=0.5,
        metadata={"help": "Weight of the student cross-entropy loss; the KD loss gets 1 - alpha."},
    )
    temperature: float = dataclasses.field(
        default=2.0,
        metadata={"help": "Softmax temperature used to soften teacher and student logits."},
    )

    def __post_init__(self):
        # Let TrainingArguments finish its own setup/validation first.
        super().__post_init__()
        # alpha outside [0, 1] would give one of the two loss terms a negative weight.
        if not 0.0 <= self.alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {self.alpha}")
        # The logits are divided by the temperature, so it must be strictly positive.
        if self.temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {self.temperature}")

In [5]:
class DistillTrainer(Trainer):
    """Trainer that blends the student's cross-entropy with a distillation loss.

    The loss optimized for every batch is::

        alpha * CE_student + (1 - alpha) * T**2 * KL(softmax(t / T) || softmax(s / T))

    Here ``s`` and ``t`` are the student and teacher logits. ``alpha`` and ``T``
    are read from ``self.args.alpha`` and ``self.args.temperature``, so ``args``
    must expose both attributes (e.g. a ``DistillTrainingArguments``).

    Args:
        *args, **kwargs: Forwarded unchanged to ``transformers.Trainer``.
        teacher_model: Model whose logits serve as soft targets. It is called with
            the same ``inputs`` as the student, so it must live on the same device.
            If it is an ``nn.Module``, it is switched to eval mode here.
    """

    def __init__(self, *args, teacher_model=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model
        # Trainer only toggles train/eval on the student. Keep the teacher in eval
        # mode so dropout does not make the soft targets noisy.
        if isinstance(teacher_model, nn.Module):
            teacher_model.eval()

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the combined hard-label + distillation loss for one batch.

        Args:
            model: The student model being trained.
            inputs: Batch dict passed as keyword arguments to both student and
                teacher. It must include labels so the student returns ``.loss``.
            return_outputs: If True, return ``(loss, student_outputs)``; else only ``loss``.
            num_items_in_batch: Passed by newer ``Trainer.training_step`` versions.
                It is accepted for compatibility and is unused here.

        Returns:
            The scalar loss, or ``(loss, student_outputs)`` when ``return_outputs`` is True.

        Raises:
            ValueError: If no teacher model was provided, or if the student
                returned no loss (no labels in ``inputs``).
        """
        if self.teacher_model is None:
            raise ValueError("DistillTrainer requires a `teacher_model` to compute the distillation loss.")

        s_output = model(**inputs)
        s_ce = s_output.loss
        if s_ce is None:
            raise ValueError("The student model returned no loss; make sure `inputs` contains labels.")
        s_logits = s_output.logits

        # The teacher only supplies targets, so no gradients are tracked through it.
        with torch.no_grad():
            t_logits = self.teacher_model(**inputs).logits

        temperature = self.args.temperature
        alpha = self.args.alpha
        # KL divergence between the temperature-softened distributions. kl_div expects
        # log-probabilities for the student and probabilities for the teacher.
        # "batchmean" matches the mathematical KL definition. The T**2 factor keeps
        # the soft-target gradients on the same scale as the hard-label term.
        loss_kd = temperature ** 2 * F.kl_div(
            F.log_softmax(s_logits / temperature, dim=-1),
            F.softmax(t_logits / temperature, dim=-1),
            reduction="batchmean",
        )

        loss = alpha * s_ce + (1 - alpha) * loss_kd
        return (loss, s_output) if return_outputs else loss

#### pipeline
##### datasets

In [9]:
import os

# Send outbound HTTP(S) traffic (e.g. Hugging Face Hub downloads) through a proxy.
# The default address is specific to the author's local network. Override it
# with the PROXY_URL environment variable instead of editing this cell. Proxies
# already configured in the environment are left untouched.
PROXY_URL = os.environ.get("PROXY_URL", "http://192.168.0.103:2022")
os.environ.setdefault('HTTP_PROXY', PROXY_URL)
os.environ.setdefault('HTTPS_PROXY', PROXY_URL)

In [10]:
from datasets import load_dataset

# Fetch the "plus" configuration of clinc_oos (or reuse the local cache when present).
clinc = load_dataset(path="clinc_oos", name="plus")

Using the latest cached version of the module from /home/zhuyuedlut/.cache/huggingface/modules/datasets_modules/datasets/clinc_oos/abcc41d382f8137f039adc747af44714941e8196e845dfbdd8ae7a7e020e6ba1 (last modified on Wed Jul 12 21:58:36 2023) since it couldn't be found locally at clinc_oos., or remotely on the Hugging Face Hub.


Downloading and preparing dataset clinc_oos/plus to /home/zhuyuedlut/.cache/huggingface/datasets/clinc_oos/plus/1.0.0/abcc41d382f8137f039adc747af44714941e8196e845dfbdd8ae7a7e020e6ba1...


HF google storage unreachable. Downloading and preparing it from source


In [None]:
# Display the loaded dataset: its splits, their features and row counts.
clinc

In [None]:
# Peek at the first 10 training examples (slicing a split returns a dict of column lists).
clinc['train'][:10]

In [None]:
# Label feature of the train split's `intent` column. Its class count is
# presumably what sizes the student model's classification head below.
intents = clinc["train"].features["intent"]
# ClassLabel keeps `names` and `num_classes` in sync, so this is the class count.
num_labels = len(intents.names)
num_labels

#### Student model initialization