# 蒸馏学生模型

In [1]:
from transformers import (
    GPT2TokenizerFast,
    LlamaForCausalLM,
    LlamaConfig,
    GPT2LMHeadModel,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import  Subset
from random import sample
from pathlib import Path
from babylm_dataset import BabylmDataset

# 定义超参数
#############
LR = 2.5e-4
BATCH_SIZE = 32
SEQ_LENGTH = 128

TEMPERATURE = 2.0
ALPHA = 0.5
#############

teacher_dir1 = './models/llama-teacher'
teacher_dir2 = './models/gpt2-teacher'


MODEL_NAME = f'Baby-Llama-58M'
MODEL_OUTPUT = Path('./models') /  MODEL_NAME
EVAL_SAMPLES = 500

tokenizer_path = "./models/gpt-clean-16000.json"
tokenizer = GPT2TokenizerFast(tokenizer_file= str(tokenizer_path))
tokenizer.bos_token = "<s>"
tokenizer.eos_token = "</s>"
tokenizer.pad_token = "<pad>"

# in the original code I had random_chunk = False
# random_chunk=True is expected to improve the model performance a bit
data_train_path = "./data/train_10M_clean"
data_eval_path = "./data/dev_clean"
train_dataset = BabylmDataset(data_train_path, SEQ_LENGTH, tokenizer=tokenizer, random_chunk=True)
full_eval_dataset = BabylmDataset(data_eval_path, SEQ_LENGTH, tokenizer=tokenizer, offset=0)

eval_indices = sample(range(len(full_eval_dataset)), EVAL_SAMPLES)
eval_dataset = Subset(full_eval_dataset, eval_indices)



Loading data from data/train_10M_clean/tokenized_GPT2TokenizerFast_16000.pt
🔥 数据集总大小: 16912909
🔥 为了缩短训练时间，这里缩减为: 375842
Loading data from data/dev_clean/tokenized_GPT2TokenizerFast_16000.pt
🔥 数据集总大小: 17428872
🔥 为了缩短训练时间，这里缩减为: 87144


  self.data = torch.load(tokenized_file)


In [2]:
tokenizer.model_max_length = SEQ_LENGTH

config = LlamaConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=512,
    num_hidden_layers=16,
    intermediate_size=1024,
    num_attention_heads=8,
    bos_token_id=tokenizer.convert_tokens_to_ids("<s>"),
    eos_token_id=tokenizer.convert_tokens_to_ids("</s>"),
    pad_token_id=tokenizer.convert_tokens_to_ids("<pad>"),
    max_position_embeddings=2*SEQ_LENGTH,
)

student = LlamaForCausalLM(config)
# student = LlamaForCausalLM.from_pretrained(student_dir)


teacher1 = LlamaForCausalLM.from_pretrained(teacher_dir1)
teacher2 = GPT2LMHeadModel.from_pretrained(teacher_dir2)
teachers = [teacher1, teacher2]


data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=False,
)


print(f'model num parameters: student = {student.num_parameters()}')
print(f'model num parameters: teacher1 = {teacher1.num_parameters()}')
print(f'model num parameters: teacher2 = {teacher2.num_parameters()}')



#  Distillation Trainer
#  We modified the Trainer from this repo https://github.com/philschmid/knowledge-distillation-transformers-pytorch-sagemaker
# to work with an ensemble of teachers


class DistillationTrainingArguments(TrainingArguments):
    def __init__(self, *args, alpha=0.5, temperature=2.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.alpha = alpha
        self.temperature = temperature


class DistillationTrainer(Trainer):
    def __init__(self, *args, teacher_models=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.teachers = teacher_models
        for teacher in self.teachers:
            # place each teacher on same device as student
            self._move_model_to_device(teacher, self.model.device)
            teacher.eval()

    def compute_loss(self, model, inputs, return_outputs=False):
        # compute student output
        outputs_student = model(**inputs)
        student_loss = outputs_student.loss

        # compute teacher output
        with torch.no_grad():
            all_teacher_logits = []
            for teacher in self.teachers:
                outputs_teacher = teacher(**inputs)
                all_teacher_logits.append(outputs_teacher.logits)
            avg_teacher_logits = torch.stack(all_teacher_logits).mean(dim=0)

        # assert size
        assert outputs_student.logits.size() == avg_teacher_logits.size()

        # Soften probabilities and compute distillation loss
        loss_function = nn.KLDivLoss(reduction="batchmean")
        loss_logits = (
            loss_function(
                F.log_softmax(outputs_student.logits / self.args.temperature, dim=-1),
                F.softmax(avg_teacher_logits / self.args.temperature, dim=-1),
            )
            * (self.args.temperature ** 2)
        )
        # Return weighted student loss
        loss = self.args.alpha * student_loss + (1.0 - self.args.alpha) * loss_logits
        return (loss, outputs_student) if return_outputs else loss




model num parameters: student = 58343936
model num parameters: teacher1 = 359973888
model num parameters: teacher2 = 704928768


In [4]:
training_args = DistillationTrainingArguments(
    output_dir=MODEL_OUTPUT,
    overwrite_output_dir=True,
    save_strategy = "epoch",
    evaluation_strategy = "epoch",
    num_train_epochs=6,
    gradient_accumulation_steps=1,
    per_device_train_batch_size=BATCH_SIZE,
    save_total_limit=1,  # Set to zero to avoid saving
    warmup_steps=200, 
    lr_scheduler_type="cosine",
    learning_rate=LR,
    logging_steps=20,
    fp16=True,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    weight_decay=0.1,
    alpha=ALPHA,
    temperature=TEMPERATURE,
    no_cuda=True,
)


trainer = DistillationTrainer(
        student,
        training_args,
        teacher_models=teachers,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,

    )


trainer.train()


trainer.save_model(MODEL_OUTPUT)
tokenizer.save_pretrained(MODEL_OUTPUT)



  0%|          | 0/552 [00:00<?, ?it/s]

{'loss': 394.2176, 'grad_norm': 197.1733856201172, 'learning_rate': 2.5e-05, 'epoch': 0.22}
{'loss': 320.1827, 'grad_norm': 207.2023468017578, 'learning_rate': 5e-05, 'epoch': 0.43}
{'loss': 261.6258, 'grad_norm': 215.23489379882812, 'learning_rate': 7.5e-05, 'epoch': 0.65}
{'loss': 162.9615, 'grad_norm': 175.06642150878906, 'learning_rate': 0.0001, 'epoch': 0.87}


  0%|          | 0/63 [00:00<?, ?it/s]

{'eval_loss': 58.3864860534668, 'eval_runtime': 352.6741, 'eval_samples_per_second': 1.418, 'eval_steps_per_second': 0.179, 'epoch': 1.0}
{'loss': 109.23, 'grad_norm': 167.6467742919922, 'learning_rate': 0.000125, 'epoch': 1.09}
{'loss': 72.0003, 'grad_norm': 139.8084259033203, 'learning_rate': 0.00015, 'epoch': 1.3}
{'loss': 40.9822, 'grad_norm': 101.28172302246094, 'learning_rate': 0.000175, 'epoch': 1.52}
{'loss': 19.4506, 'grad_norm': 24.261682510375977, 'learning_rate': 0.0002, 'epoch': 1.74}
{'loss': 11.7896, 'grad_norm': 33.13188171386719, 'learning_rate': 0.00022500000000000002, 'epoch': 1.96}


  0%|          | 0/63 [00:00<?, ?it/s]

{'eval_loss': 16.758838653564453, 'eval_runtime': 352.6818, 'eval_samples_per_second': 1.418, 'eval_steps_per_second': 0.179, 'epoch': 2.0}
{'loss': 8.6971, 'grad_norm': 17.671709060668945, 'learning_rate': 0.00025, 'epoch': 2.17}
{'loss': 6.6607, 'grad_norm': 11.907523155212402, 'learning_rate': 0.0002480139005420145, 'epoch': 2.39}
{'loss': 5.2117, 'grad_norm': 5.838933944702148, 'learning_rate': 0.00024211871562497024, 'epoch': 2.61}
{'loss': 4.5462, 'grad_norm': 6.678253650665283, 'learning_rate': 0.00023250178002596255, 'epoch': 2.83}


  0%|          | 0/63 [00:00<?, ?it/s]

{'eval_loss': 10.838906288146973, 'eval_runtime': 352.0537, 'eval_samples_per_second': 1.42, 'eval_steps_per_second': 0.179, 'epoch': 3.0}
{'loss': 4.1238, 'grad_norm': 9.188096046447754, 'learning_rate': 0.0002194686967942823, 'epoch': 3.04}
{'loss': 3.6455, 'grad_norm': 4.341468334197998, 'learning_rate': 0.0002034336259226065, 'epoch': 3.26}
{'loss': 3.399, 'grad_norm': 5.3074846267700195, 'learning_rate': 0.0001849061233400071, 'epoch': 3.48}
{'loss': 3.2133, 'grad_norm': 4.197958469390869, 'learning_rate': 0.00016447494845187814, 'epoch': 3.7}
{'loss': 3.0286, 'grad_norm': 4.915552616119385, 'learning_rate': 0.00014278935478416067, 'epoch': 3.91}


  0%|          | 0/63 [00:00<?, ?it/s]

{'eval_loss': 10.316187858581543, 'eval_runtime': 352.3373, 'eval_samples_per_second': 1.419, 'eval_steps_per_second': 0.179, 'epoch': 4.0}
{'loss': 2.8323, 'grad_norm': 4.201409816741943, 'learning_rate': 0.00012053845827012746, 'epoch': 4.13}
{'loss': 2.7643, 'grad_norm': 4.0402913093566895, 'learning_rate': 9.842933880587791e-05, 'epoch': 4.35}
{'loss': 2.5842, 'grad_norm': 3.586437702178955, 'learning_rate': 7.716457095436378e-05, 'epoch': 4.57}
{'loss': 2.5513, 'grad_norm': 3.4587912559509277, 'learning_rate': 5.741989781805035e-05, 'epoch': 4.78}
{'loss': 2.5102, 'grad_norm': 3.301051378250122, 'learning_rate': 3.98227575507636e-05, 'epoch': 5.0}


  0%|          | 0/63 [00:00<?, ?it/s]

{'eval_loss': 9.898730278015137, 'eval_runtime': 352.8048, 'eval_samples_per_second': 1.417, 'eval_steps_per_second': 0.179, 'epoch': 5.0}
{'loss': 2.4339, 'grad_norm': 2.2749125957489014, 'learning_rate': 2.4932344884454963e-05, 'epoch': 5.22}
{'loss': 2.3826, 'grad_norm': 2.400008201599121, 'learning_rate': 1.3221841267536088e-05, 'epoch': 5.43}
{'loss': 2.3744, 'grad_norm': 2.3529045581817627, 'learning_rate': 5.063378298187843e-06, 'epoch': 5.65}
{'loss': 2.3903, 'grad_norm': 1.7968189716339111, 'learning_rate': 7.162122785128316e-07, 'epoch': 5.87}


  0%|          | 0/63 [00:00<?, ?it/s]

{'eval_loss': 9.67468547821045, 'eval_runtime': 352.9349, 'eval_samples_per_second': 1.417, 'eval_steps_per_second': 0.179, 'epoch': 6.0}
{'train_runtime': 16646.7096, 'train_samples_per_second': 1.058, 'train_steps_per_second': 0.033, 'train_loss': 52.86958230751148, 'epoch': 6.0}


('models/Baby-Llama-58M/tokenizer_config.json',
 'models/Baby-Llama-58M/special_tokens_map.json',
 'models/Baby-Llama-58M/vocab.json',
 'models/Baby-Llama-58M/merges.txt',
 'models/Baby-Llama-58M/added_tokens.json',
 'models/Baby-Llama-58M/tokenizer.json')