In [1]:
# 这份代码修改自仓库： https://github.com/timinar/BabyLlama

# 训练教师模型：GPT-2 和 Llama

> 论文中写道：
>
> "The GPT-2 model has 24 layers, 16 attention heads, an embedding dimension of 1536, intermediate size of 6144, and maximum sequence length of 128, resulting in 705M parameters. It was trained for 6 epochs with a batch size of 256 and maximum learning rate³ of $2.5 \cdot 10^{-4}$. The LLaMA model has 24 layers, 8 attention heads, a hidden size of 1024, intermediate size of 3072, and maximum sequence length of 256, resulting in 360M parameters. It was trained for 4 epochs with a batch size of 128 and maximum learning rate of $3 \cdot 10^{-4}$."

In [2]:
# Prepare data: tokenizer, train/eval datasets, and the language-modeling collator.
from transformers import DataCollatorForLanguageModeling
from transformers import GPT2TokenizerFast
from babylm_dataset import BabylmDataset
from random import sample, seed
from torch.utils.data import Subset

data_train_path = "./data/train_10M_clean"
data_eval_path = "./data/dev_clean"
tokenizer_path = "./models/gpt-clean-16000.json"

SEQ_LENGTH = 128      # tokens per example; shared by both teacher runs below
EVAL_SEED = 2024      # fixes the eval subset so every model is scored on the same examples
EVAL_SAMPLES = 200    # requested size of the eval subset

tokenizer = GPT2TokenizerFast(tokenizer_file=tokenizer_path)
tokenizer.bos_token = "<s>"
tokenizer.eos_token = "</s>"
tokenizer.pad_token = "<pad>"
tokenizer.model_max_length = SEQ_LENGTH

# The dataset size can be changed in BabylmDataset.__init__ (the recorded
# output shows both datasets were truncated there to shorten training).
train_dataset = BabylmDataset(data_train_path, SEQ_LENGTH, tokenizer=tokenizer, random_chunk=True)
full_eval_dataset = BabylmDataset(data_eval_path, SEQ_LENGTH, tokenizer=tokenizer, offset=0)

# Seed the global `random` module, then draw the eval subset from it.
# NOTE(review): the global seed presumably also drives train_dataset's random
# chunk offsets (random_chunk=True) -- confirm in babylm_dataset.py.
seed(EVAL_SEED)
# Clamp the request: random.sample raises ValueError when asked for more items
# than exist, which happens if the eval set is shrunk below EVAL_SAMPLES chunks.
n_eval = min(EVAL_SAMPLES, len(full_eval_dataset))
eval_indices = sample(range(len(full_eval_dataset)), n_eval)
eval_dataset = Subset(full_eval_dataset, eval_indices)

# mlm=False: causal (next-token) language modeling, no token masking.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=False,
)

Loading data from data/train_10M_clean/tokenized_GPT2TokenizerFast_16000.pt
🔥 数据集总大小: 16912909
🔥 为了缩短训练时间，这里缩减为: 375842
Loading data from data/dev_clean/tokenized_GPT2TokenizerFast_16000.pt
🔥 数据集总大小: 17428872
🔥 为了缩短训练时间，这里缩减为: 87144


  self.data = torch.load(tokenized_file)


In [3]:
# Train the GPT-2 teacher (architecture from the paper quoted above).
from transformers import (
    GPT2Config, GPT2LMHeadModel,
)
from transformers import Trainer, TrainingArguments

model_config = GPT2Config(
    vocab_size=tokenizer.vocab_size,
    # Position table is twice the training sequence length (the paper quotes 128).
    n_positions=2*tokenizer.model_max_length,
    n_embd=1536,
    n_layer=24,
    n_head=16,
    # n_inner is left at its default (None -> 4 * n_embd = 6144), which matches
    # the paper's intermediate size.
    pad_token_id=tokenizer.convert_tokens_to_ids("<pad>"),
)
model = GPT2LMHeadModel(model_config)

output_dir = "./models/gpt2-teacher"

training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    save_strategy="epoch",
    evaluation_strategy="epoch",  # renamed `eval_strategy` in newer transformers releases
    num_train_epochs=6,
    # Effective batch = 12 * 2 = 24 sequences per optimizer step; the paper used 256.
    # Reduced here so training fits on a CPU.
    gradient_accumulation_steps=2,
    per_device_train_batch_size=12,
    # Limits how many checkpoints stay on disk; with load_best_model_at_end=True
    # the best checkpoint is always retained as well. Note: 0 does NOT skip
    # saving -- it disables the limit. Skipping saves needs save_strategy="no",
    # which is incompatible with load_best_model_at_end.
    save_total_limit=1,
    warmup_steps=300,
    lr_scheduler_type="cosine",
    learning_rate=2.5e-4,
    logging_steps=20,
    fp16=False,
    # After training, reload the checkpoint with the lowest eval_loss.
    # NOTE(review): in the recorded run eval_loss was lowest after epoch 1 and
    # rose afterwards while train loss kept falling, so the saved teacher is the
    # epoch-1 checkpoint. Reloading may warn that `lm_head.weight` is missing;
    # GPT-2 ties it to the input embeddings, so this is presumably benign.
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    torch_compile=False,
    use_cpu=True,  # train on CPU; set False to use a GPU with enough memory (replaces deprecated `no_cuda`)
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

trainer.train()
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)



  0%|          | 0/732 [00:00<?, ?it/s]

{'loss': 6.2285, 'grad_norm': 19.691848754882812, 'learning_rate': 1.6666666666666667e-05, 'epoch': 0.16}
{'loss': 4.1185, 'grad_norm': 5.937752723693848, 'learning_rate': 3.3333333333333335e-05, 'epoch': 0.33}
{'loss': 3.535, 'grad_norm': 6.225916862487793, 'learning_rate': 5e-05, 'epoch': 0.49}
{'loss': 3.3047, 'grad_norm': 5.6779255867004395, 'learning_rate': 6.666666666666667e-05, 'epoch': 0.65}
{'loss': 3.1285, 'grad_norm': 4.134920597076416, 'learning_rate': 8.333333333333333e-05, 'epoch': 0.82}
{'loss': 2.9916, 'grad_norm': 4.114354133605957, 'learning_rate': 0.0001, 'epoch': 0.98}


  0%|          | 0/25 [00:00<?, ?it/s]

{'eval_loss': 6.713097095489502, 'eval_runtime': 88.6904, 'eval_samples_per_second': 2.255, 'eval_steps_per_second': 0.282, 'epoch': 1.0}
{'loss': 2.8335, 'grad_norm': 3.389439582824707, 'learning_rate': 0.00011666666666666667, 'epoch': 1.14}
{'loss': 2.8309, 'grad_norm': 2.93381929397583, 'learning_rate': 0.00013333333333333334, 'epoch': 1.3}
{'loss': 2.7779, 'grad_norm': 3.458967685699463, 'learning_rate': 0.00015, 'epoch': 1.47}
{'loss': 2.7734, 'grad_norm': 2.757002830505371, 'learning_rate': 0.00016666666666666666, 'epoch': 1.63}
{'loss': 2.7185, 'grad_norm': 2.951125144958496, 'learning_rate': 0.00018333333333333334, 'epoch': 1.79}
{'loss': 2.6367, 'grad_norm': 2.8377842903137207, 'learning_rate': 0.0002, 'epoch': 1.96}


  0%|          | 0/25 [00:00<?, ?it/s]

{'eval_loss': 6.771847724914551, 'eval_runtime': 87.665, 'eval_samples_per_second': 2.281, 'eval_steps_per_second': 0.285, 'epoch': 2.0}
{'loss': 2.5443, 'grad_norm': 2.7709591388702393, 'learning_rate': 0.00021666666666666668, 'epoch': 2.11}
{'loss': 2.6008, 'grad_norm': 2.589423894882202, 'learning_rate': 0.00023333333333333333, 'epoch': 2.28}
{'loss': 2.5744, 'grad_norm': 2.468208074569702, 'learning_rate': 0.00025, 'epoch': 2.44}
{'loss': 2.477, 'grad_norm': 3.0269947052001953, 'learning_rate': 0.00024868020482261805, 'epoch': 2.6}
{'loss': 2.4907, 'grad_norm': 2.3736724853515625, 'learning_rate': 0.0002447486890394361, 'epoch': 2.77}
{'loss': 2.4772, 'grad_norm': 2.330203056335449, 'learning_rate': 0.00023828847337958127, 'epoch': 2.93}


  0%|          | 0/25 [00:00<?, ?it/s]

{'eval_loss': 6.848089694976807, 'eval_runtime': 87.1521, 'eval_samples_per_second': 2.295, 'eval_steps_per_second': 0.287, 'epoch': 3.0}
{'loss': 2.3498, 'grad_norm': 2.5652854442596436, 'learning_rate': 0.00022943597642661705, 'epoch': 3.09}
{'loss': 2.3245, 'grad_norm': 2.2860045433044434, 'learning_rate': 0.0002183781339051245, 'epoch': 3.25}
{'loss': 2.339, 'grad_norm': 2.232682704925537, 'learning_rate': 0.0002053484512108174, 'epoch': 3.42}
{'loss': 2.2986, 'grad_norm': 2.2428107261657715, 'learning_rate': 0.00019062207254182, 'epoch': 3.58}
{'loss': 2.2996, 'grad_norm': 2.464080572128296, 'learning_rate': 0.00017450997075489462, 'epoch': 3.74}
{'loss': 2.241, 'grad_norm': 2.0995633602142334, 'learning_rate': 0.00015735238063781508, 'epoch': 3.91}


  0%|          | 0/25 [00:00<?, ?it/s]

{'eval_loss': 6.915830135345459, 'eval_runtime': 87.684, 'eval_samples_per_second': 2.281, 'eval_steps_per_second': 0.285, 'epoch': 4.0}
{'loss': 2.1507, 'grad_norm': 2.2546820640563965, 'learning_rate': 0.0001395116142656538, 'epoch': 4.07}
{'loss': 2.1271, 'grad_norm': 2.361213207244873, 'learning_rate': 0.00012136441015711107, 'epoch': 4.23}
{'loss': 2.1107, 'grad_norm': 2.0600945949554443, 'learning_rate': 0.00010329397779163371, 'epoch': 4.39}
{'loss': 2.0559, 'grad_norm': 2.0332272052764893, 'learning_rate': 8.568190548104832e-05, 'epoch': 4.56}
{'loss': 2.0615, 'grad_norm': 1.9290919303894043, 'learning_rate': 6.890010247494224e-05, 'epoch': 4.72}
{'loss': 2.0547, 'grad_norm': 1.9184967279434204, 'learning_rate': 5.330294545611927e-05, 'epoch': 4.88}


  0%|          | 0/25 [00:00<?, ?it/s]

{'eval_loss': 6.9010701179504395, 'eval_runtime': 87.82, 'eval_samples_per_second': 2.277, 'eval_steps_per_second': 0.285, 'epoch': 5.0}
{'loss': 1.9642, 'grad_norm': 1.976872444152832, 'learning_rate': 3.9219795266408314e-05, 'epoch': 5.04}
{'loss': 1.9338, 'grad_norm': 1.9681077003479004, 'learning_rate': 2.6948041885053036e-05, 'epoch': 5.2}
{'loss': 1.9191, 'grad_norm': 1.9897916316986084, 'learning_rate': 1.6746824526945162e-05, 'epoch': 5.37}
{'loss': 1.901, 'grad_norm': 1.9622206687927246, 'learning_rate': 8.831559471647183e-06, 'epoch': 5.53}
{'loss': 1.9304, 'grad_norm': 1.8895570039749146, 'learning_rate': 3.3693911775220242e-06, 'epoch': 5.69}
{'loss': 1.9119, 'grad_norm': 1.924885630607605, 'learning_rate': 4.756627385318069e-07, 'epoch': 5.86}


  0%|          | 0/25 [00:00<?, ?it/s]

{'eval_loss': 6.873990058898926, 'eval_runtime': 87.3604, 'eval_samples_per_second': 2.289, 'eval_steps_per_second': 0.286, 'epoch': 5.96}


There were missing keys in the checkpoint model loaded: ['lm_head.weight'].


{'train_runtime': 27515.7373, 'train_samples_per_second': 0.64, 'train_steps_per_second': 0.027, 'train_loss': 2.572689468091954, 'epoch': 5.96}


('./models/gpt2-teacher/tokenizer_config.json',
 './models/gpt2-teacher/special_tokens_map.json',
 './models/gpt2-teacher/vocab.json',
 './models/gpt2-teacher/merges.txt',
 './models/gpt2-teacher/added_tokens.json',
 './models/gpt2-teacher/tokenizer.json')

In [4]:
# Train the LLaMA teacher (architecture from the paper quoted above).
from transformers import (
    LlamaConfig, LlamaForCausalLM,
)
from transformers import Trainer, TrainingArguments

model_config = LlamaConfig(
    vocab_size=tokenizer.vocab_size,
    # Position table is twice the training sequence length. NOTE(review): this
    # run reuses train_dataset built with SEQ_LENGTH (128 tokens per example),
    # whereas the paper trained LLaMA with a maximum sequence length of 256.
    max_position_embeddings=2*tokenizer.model_max_length,
    hidden_size=1024,
    intermediate_size=3072,
    num_hidden_layers=24,
    num_attention_heads=8,
    tie_word_embeddings=False,  # separate input-embedding and output-projection weights
    pad_token_id=tokenizer.convert_tokens_to_ids("<pad>"),
)
model = LlamaForCausalLM(model_config)

output_dir = "./models/llama-teacher"

training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    save_strategy="epoch",
    evaluation_strategy="epoch",  # renamed `eval_strategy` in newer transformers releases
    num_train_epochs=4,
    # Effective batch = 16 * 2 = 32 sequences per optimizer step; the paper used 128.
    # Reduced here so training fits on a CPU.
    gradient_accumulation_steps=2,
    per_device_train_batch_size=16,
    # Limits how many checkpoints stay on disk; with load_best_model_at_end=True
    # the best checkpoint is always retained as well. Note: 0 does NOT skip
    # saving -- it disables the limit. Skipping saves needs save_strategy="no",
    # which is incompatible with load_best_model_at_end.
    save_total_limit=1,
    warmup_steps=300,
    lr_scheduler_type="cosine",
    learning_rate=3e-4,
    logging_steps=20,
    fp16=False,
    # After training, reload the checkpoint with the lowest eval_loss.
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    torch_compile=False,
    use_cpu=True,  # train on CPU; set False to use a GPU with enough memory (replaces deprecated `no_cuda`)
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

trainer.train()
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)



  0%|          | 0/368 [00:00<?, ?it/s]

{'loss': 16.6259, 'grad_norm': 6.6172637939453125, 'learning_rate': 1.9999999999999998e-05, 'epoch': 0.22}
{'loss': 10.6818, 'grad_norm': 5.324873924255371, 'learning_rate': 3.9999999999999996e-05, 'epoch': 0.43}
{'loss': 8.1373, 'grad_norm': 2.213642120361328, 'learning_rate': 5.9999999999999995e-05, 'epoch': 0.65}
{'loss': 7.0195, 'grad_norm': 2.61836576461792, 'learning_rate': 7.999999999999999e-05, 'epoch': 0.87}


  0%|          | 0/25 [00:00<?, ?it/s]

{'eval_loss': 6.424047946929932, 'eval_runtime': 44.2429, 'eval_samples_per_second': 4.52, 'eval_steps_per_second': 0.565, 'epoch': 1.0}
{'loss': 6.3157, 'grad_norm': 2.8587779998779297, 'learning_rate': 9.999999999999999e-05, 'epoch': 1.09}
{'loss': 5.9112, 'grad_norm': 2.363039970397949, 'learning_rate': 0.00011999999999999999, 'epoch': 1.3}
{'loss': 5.6126, 'grad_norm': 2.3234667778015137, 'learning_rate': 0.00014, 'epoch': 1.52}
{'loss': 5.4582, 'grad_norm': 1.9862345457077026, 'learning_rate': 0.00015999999999999999, 'epoch': 1.74}
{'loss': 5.2141, 'grad_norm': 2.734612464904785, 'learning_rate': 0.00017999999999999998, 'epoch': 1.96}


  0%|          | 0/25 [00:00<?, ?it/s]

{'eval_loss': 6.415191650390625, 'eval_runtime': 44.059, 'eval_samples_per_second': 4.539, 'eval_steps_per_second': 0.567, 'epoch': 2.0}
{'loss': 5.0382, 'grad_norm': 1.8031628131866455, 'learning_rate': 0.00019999999999999998, 'epoch': 2.17}
{'loss': 4.913, 'grad_norm': 1.8266865015029907, 'learning_rate': 0.00021999999999999995, 'epoch': 2.39}
{'loss': 4.6786, 'grad_norm': 1.9183133840560913, 'learning_rate': 0.00023999999999999998, 'epoch': 2.61}
{'loss': 4.6846, 'grad_norm': 1.8740675449371338, 'learning_rate': 0.00026, 'epoch': 2.83}


  0%|          | 0/25 [00:00<?, ?it/s]

{'eval_loss': 6.405597686767578, 'eval_runtime': 44.1624, 'eval_samples_per_second': 4.529, 'eval_steps_per_second': 0.566, 'epoch': 3.0}
{'loss': 4.5987, 'grad_norm': 1.6913737058639526, 'learning_rate': 0.00028, 'epoch': 3.04}
{'loss': 4.3079, 'grad_norm': 1.8193702697753906, 'learning_rate': 0.0003, 'epoch': 3.26}
{'loss': 4.3523, 'grad_norm': 1.8672795295715332, 'learning_rate': 0.00024039519545688846, 'epoch': 3.48}
{'loss': 4.2241, 'grad_norm': 1.3985413312911987, 'learning_rate': 0.00010895055148918756, 'epoch': 3.7}
{'loss': 4.0808, 'grad_norm': 1.3998284339904785, 'learning_rate': 1.0129165589346643e-05, 'epoch': 3.91}


  0%|          | 0/25 [00:00<?, ?it/s]

{'eval_loss': 6.422872543334961, 'eval_runtime': 43.8794, 'eval_samples_per_second': 4.558, 'eval_steps_per_second': 0.57, 'epoch': 4.0}
{'train_runtime': 8266.189, 'train_samples_per_second': 1.42, 'train_steps_per_second': 0.045, 'train_loss': 6.166872097098309, 'epoch': 4.0}


('./models/llama-teacher/tokenizer_config.json',
 './models/llama-teacher/special_tokens_map.json',
 './models/llama-teacher/vocab.json',
 './models/llama-teacher/merges.txt',
 './models/llama-teacher/added_tokens.json',
 './models/llama-teacher/tokenizer.json')