# 因果语言模型训练实例

## Step1 导入相关包

In [1]:
from datasets import load_dataset, Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    TrainingArguments,
    Trainer,
    BloomForCausalLM,
)

  from .autonotebook import tqdm as notebook_tqdm


## Step2 加载数据集

In [2]:
# Location of the pre-saved dataset on local disk.
WIKI_DATASET_DIR = "./wiki_cn_filtered/"
ds = Dataset.load_from_disk(WIKI_DATASET_DIR)

In [3]:
# Display the dataset summary (feature names and number of rows).
ds

Dataset({
    features: ['source', 'completion'],
    num_rows: 10000
})

In [4]:
# Show the first record to see what a raw example looks like.
ds[0]

{'source': 'wikipedia.zh2307',
 'completion': "西安交通大学博物馆（Xi'an Jiaotong University Museum）是一座位于西安交通大学的博物馆，馆长是锺明善。\n历史\n2004年9月20日开始筹建，2013年4月8日正式建成开馆，位于西安交通大学兴庆校区陕西省西安市咸宁西路28号。建筑面积6,800平米，展厅面积4,500平米，馆藏文物4,900余件。包括历代艺术文物馆、碑石书法馆、西部农民画馆、邢良坤陶瓷艺术馆、陕西秦腔博物馆和书画展厅共五馆一厅。\n营业时间\n* 周一至周六：上午九点至十二点，下午一点至五点\n* 周日闭馆"}

## Step3 数据集处理

In [5]:
# Checkpoint whose tokenizer is used for all preprocessing below.
MODEL_CHECKPOINT = "Langboat/bloom-389m-zh"
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)


def process_func(examples, max_length=384):
    """Tokenize a batch of documents for causal language-model training.

    Args:
        examples: Batch mapping as supplied by ``Dataset.map(..., batched=True)``.
            Only ``examples["completion"]`` (an iterable of strings) is read.
        max_length: Upper bound passed to the tokenizer; longer sequences are
            truncated. Defaults to 384, the value previously hard-coded here.

    Returns:
        The result of calling the module-level ``tokenizer`` on the batch.
        No labels are created here.

    Note:
        The EOS token is appended before truncation, so (assuming the
        tokenizer truncates on the right, which is not configured here)
        documents longer than ``max_length`` tokens lose their EOS marker.
    """
    eos = tokenizer.eos_token
    contents = [text + eos for text in examples["completion"]]
    # Labels are not built here; DataCollatorForLanguageModeling adds them
    # when batches are assembled.
    return tokenizer(contents, max_length=max_length, truncation=True)

In [6]:
# Tokenize in batches and drop the raw text columns so that only the
# tokenizer's outputs remain in the dataset.
raw_columns = ds.column_names
tokenized_ds = ds.map(process_func, batched=True, remove_columns=raw_columns)
tokenized_ds

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 10000
})

In [7]:
from torch.utils.data import DataLoader

# Small loader used only to inspect what the collator produces.
# mlm=False selects causal-LM (not masked-LM) collation.
preview_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
dl = DataLoader(tokenized_ds, batch_size=2, collate_fn=preview_collator)

In [8]:
# Inspect the first collated batch together with its index.
# Token id 3 is the padding token (<pad>).
# Token id 2 is the end-of-sequence token (</s>).
(0, next(iter(dl)))

(0,
 {'input_ids': tensor([[    3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3

In [9]:
# Look up the tokenizer's padding token and its id.
tokenizer.pad_token, tokenizer.pad_token_id

('<pad>', 3)

In [10]:
# Look up the tokenizer's end-of-sequence token and its id.
tokenizer.eos_token, tokenizer.eos_token_id

('</s>', 2)

## Step4 创建模型

In [11]:
# Load the pretrained causal LM weights from the same checkpoint as the tokenizer.
MODEL_CHECKPOINT = "Langboat/bloom-389m-zh"
model = AutoModelForCausalLM.from_pretrained(MODEL_CHECKPOINT)

In [12]:
# Display the module tree of the loaded model.
model

BloomForCausalLM(
  (transformer): BloomModel(
    (word_embeddings): Embedding(42437, 1024)
    (word_embeddings_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (h): ModuleList(
      (0-23): 24 x BloomBlock(
        (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (self_attention): BloomAttention(
          (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)
          (dense): Linear(in_features=1024, out_features=1024, bias=True)
          (attention_dropout): Dropout(p=0.0, inplace=False)
        )
        (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): BloomMLP(
          (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)
          (gelu_impl): BloomGelu()
          (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)
        )
      )
    )
    (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (l

In [13]:
from torch.nn import ModuleList

# Shrink the model by keeping every 6th transformer block
# (original indices 0, 6, 12, 18 of the 24 blocks).
kept_layer_ids = range(0, 24, 6)

# Only prune while the model is still deeper than the target, so re-running
# this cell on an already-pruned model is a no-op instead of an IndexError
# on the shortened ModuleList.
if len(model.transformer.h) > len(kept_layer_ids):
    model.transformer.h = ModuleList([model.transformer.h[i] for i in kept_layer_ids])

# Keep the config in sync with the real depth right away, so anything saved
# from this model later (e.g. checkpoints written during training) records
# the pruned layer count rather than the original one.
model.config.n_layer = len(model.transformer.h)
model

BloomForCausalLM(
  (transformer): BloomModel(
    (word_embeddings): Embedding(42437, 1024)
    (word_embeddings_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (h): ModuleList(
      (0-3): 4 x BloomBlock(
        (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (self_attention): BloomAttention(
          (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)
          (dense): Linear(in_features=1024, out_features=1024, bias=True)
          (attention_dropout): Dropout(p=0.0, inplace=False)
        )
        (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): BloomMLP(
          (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)
          (gelu_impl): BloomGelu()
          (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)
        )
      )
    )
    (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (lm_

In [14]:
# Freeze the token embedding table so the optimizer does not update it.
# NOTE(review): if the LM head shares this weight (tied embeddings), the output
# projection is frozen as well — confirm that this is intended.
embedding_layer = model.transformer.word_embeddings
for weight in embedding_layer.parameters():
    weight.requires_grad = False

## Step5 配置训练参数

In [15]:
args = TrainingArguments(
    output_dir="./causal_lm",
    num_train_epochs=1,
    # One sample per step, accumulated over 8 steps before each optimizer update.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    # Gradient checkpointing and fp16 mixed precision are both enabled.
    gradient_checkpointing=True,
    fp16=True,
    # Emit a log record every 10 optimizer steps.
    logging_steps=10,
)

## Step6 创建训练器

In [16]:
# The collator pads each batch and builds the labels for causal LM training
# (mlm=False disables masked-language-model masking).
causal_lm_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_ds,
    data_collator=causal_lm_collator,
    tokenizer=tokenizer,
)

## Step7 模型训练

In [17]:
# Run training; the returned TrainOutput is displayed as the cell result.
trainer.train()

  0%|          | 0/1250 [00:00<?, ?it/s]`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
  1%|          | 11/1250 [00:02<04:11,  4.93it/s]

{'loss': 25.0814, 'grad_norm': 27.7113037109375, 'learning_rate': 4.972e-05, 'epoch': 0.01}


  2%|▏         | 21/1250 [00:04<04:01,  5.09it/s]

{'loss': 9.8204, 'grad_norm': 14.930438995361328, 'learning_rate': 4.932e-05, 'epoch': 0.02}


  2%|▏         | 31/1250 [00:06<04:05,  4.96it/s]

{'loss': 9.0141, 'grad_norm': 8.550276756286621, 'learning_rate': 4.8920000000000006e-05, 'epoch': 0.02}


  3%|▎         | 41/1250 [00:08<03:57,  5.10it/s]

{'loss': 8.6502, 'grad_norm': 6.589648723602295, 'learning_rate': 4.852e-05, 'epoch': 0.03}


  4%|▍         | 50/1250 [00:10<03:48,  5.25it/s]

{'loss': 8.4842, 'grad_norm': 9.226041793823242, 'learning_rate': 4.812000000000001e-05, 'epoch': 0.04}


  5%|▍         | 60/1250 [00:12<04:06,  4.83it/s]

{'loss': 8.2818, 'grad_norm': 10.640965461730957, 'learning_rate': 4.7720000000000004e-05, 'epoch': 0.05}


  6%|▌         | 71/1250 [00:14<03:57,  4.97it/s]

{'loss': 8.1291, 'grad_norm': 7.599116802215576, 'learning_rate': 4.732e-05, 'epoch': 0.06}


  6%|▋         | 81/1250 [00:16<03:58,  4.90it/s]

{'loss': 8.0644, 'grad_norm': 5.996996879577637, 'learning_rate': 4.6920000000000005e-05, 'epoch': 0.06}


  7%|▋         | 91/1250 [00:18<03:45,  5.14it/s]

{'loss': 7.9662, 'grad_norm': 6.339902877807617, 'learning_rate': 4.652e-05, 'epoch': 0.07}


  8%|▊         | 101/1250 [00:20<03:44,  5.12it/s]

{'loss': 7.7175, 'grad_norm': 10.138394355773926, 'learning_rate': 4.612e-05, 'epoch': 0.08}


  9%|▉         | 111/1250 [00:22<03:47,  5.02it/s]

{'loss': 7.7277, 'grad_norm': 7.408307075500488, 'learning_rate': 4.572e-05, 'epoch': 0.09}


 10%|▉         | 121/1250 [00:24<03:41,  5.10it/s]

{'loss': 7.5669, 'grad_norm': 6.882763385772705, 'learning_rate': 4.532e-05, 'epoch': 0.1}


 10%|█         | 131/1250 [00:26<03:39,  5.10it/s]

{'loss': 7.352, 'grad_norm': 6.7376275062561035, 'learning_rate': 4.4920000000000004e-05, 'epoch': 0.1}


 11%|█▏        | 141/1250 [00:28<03:32,  5.22it/s]

{'loss': 7.3839, 'grad_norm': 8.50362491607666, 'learning_rate': 4.452e-05, 'epoch': 0.11}


 12%|█▏        | 151/1250 [00:30<03:34,  5.12it/s]

{'loss': 7.3005, 'grad_norm': 8.517974853515625, 'learning_rate': 4.412e-05, 'epoch': 0.12}


 13%|█▎        | 161/1250 [00:32<03:33,  5.09it/s]

{'loss': 7.2619, 'grad_norm': 8.635285377502441, 'learning_rate': 4.372e-05, 'epoch': 0.13}


 14%|█▎        | 171/1250 [00:34<03:35,  5.01it/s]

{'loss': 7.2138, 'grad_norm': 7.927448272705078, 'learning_rate': 4.332e-05, 'epoch': 0.14}


 14%|█▍        | 181/1250 [00:36<03:27,  5.14it/s]

{'loss': 7.2152, 'grad_norm': 6.33795690536499, 'learning_rate': 4.292e-05, 'epoch': 0.14}


 15%|█▌        | 191/1250 [00:38<03:33,  4.97it/s]

{'loss': 7.1099, 'grad_norm': 6.475816249847412, 'learning_rate': 4.2520000000000006e-05, 'epoch': 0.15}


 16%|█▌        | 201/1250 [00:39<03:25,  5.11it/s]

{'loss': 6.9577, 'grad_norm': 6.652378559112549, 'learning_rate': 4.212e-05, 'epoch': 0.16}


 17%|█▋        | 211/1250 [00:41<03:29,  4.95it/s]

{'loss': 7.0232, 'grad_norm': 6.48954963684082, 'learning_rate': 4.172e-05, 'epoch': 0.17}


 18%|█▊        | 221/1250 [00:43<03:15,  5.26it/s]

{'loss': 6.9848, 'grad_norm': 8.78028678894043, 'learning_rate': 4.1320000000000004e-05, 'epoch': 0.18}


 18%|█▊        | 231/1250 [00:45<03:17,  5.17it/s]

{'loss': 6.8951, 'grad_norm': 8.50175952911377, 'learning_rate': 4.092e-05, 'epoch': 0.18}


 19%|█▉        | 241/1250 [00:47<03:17,  5.11it/s]

{'loss': 6.9336, 'grad_norm': 6.770652770996094, 'learning_rate': 4.0520000000000005e-05, 'epoch': 0.19}


 20%|██        | 251/1250 [00:49<03:16,  5.08it/s]

{'loss': 6.7765, 'grad_norm': 7.195130348205566, 'learning_rate': 4.012e-05, 'epoch': 0.2}


 21%|██        | 261/1250 [00:51<03:16,  5.02it/s]

{'loss': 6.7157, 'grad_norm': 8.293961524963379, 'learning_rate': 3.972e-05, 'epoch': 0.21}


 22%|██▏       | 271/1250 [00:53<03:12,  5.08it/s]

{'loss': 6.8086, 'grad_norm': 9.06152057647705, 'learning_rate': 3.932e-05, 'epoch': 0.22}


 22%|██▏       | 281/1250 [00:55<03:02,  5.32it/s]

{'loss': 6.6481, 'grad_norm': 12.948522567749023, 'learning_rate': 3.892e-05, 'epoch': 0.22}


 23%|██▎       | 291/1250 [00:57<03:03,  5.22it/s]

{'loss': 6.7655, 'grad_norm': 12.9730806350708, 'learning_rate': 3.8520000000000004e-05, 'epoch': 0.23}


 24%|██▍       | 301/1250 [00:59<03:10,  4.98it/s]

{'loss': 6.6854, 'grad_norm': 5.789427757263184, 'learning_rate': 3.812e-05, 'epoch': 0.24}


 25%|██▍       | 311/1250 [01:01<03:08,  4.99it/s]

{'loss': 6.6405, 'grad_norm': 6.855485916137695, 'learning_rate': 3.772e-05, 'epoch': 0.25}


 26%|██▌       | 321/1250 [01:03<03:02,  5.10it/s]

{'loss': 6.8144, 'grad_norm': 10.671525955200195, 'learning_rate': 3.732e-05, 'epoch': 0.26}


 26%|██▋       | 331/1250 [01:05<02:57,  5.18it/s]

{'loss': 6.4784, 'grad_norm': 8.773934364318848, 'learning_rate': 3.692e-05, 'epoch': 0.26}


 27%|██▋       | 341/1250 [01:07<03:00,  5.03it/s]

{'loss': 6.5977, 'grad_norm': 8.347003936767578, 'learning_rate': 3.652e-05, 'epoch': 0.27}


 28%|██▊       | 351/1250 [01:09<02:59,  5.00it/s]

{'loss': 6.459, 'grad_norm': 6.387214660644531, 'learning_rate': 3.6120000000000007e-05, 'epoch': 0.28}


 29%|██▉       | 361/1250 [01:11<03:02,  4.88it/s]

{'loss': 6.6761, 'grad_norm': 7.835637092590332, 'learning_rate': 3.5720000000000004e-05, 'epoch': 0.29}


 30%|██▉       | 371/1250 [01:13<02:53,  5.07it/s]

{'loss': 6.6177, 'grad_norm': 12.695526123046875, 'learning_rate': 3.532e-05, 'epoch': 0.3}


 30%|███       | 381/1250 [01:15<02:50,  5.11it/s]

{'loss': 6.4098, 'grad_norm': 7.163578510284424, 'learning_rate': 3.4920000000000004e-05, 'epoch': 0.3}


 31%|███▏      | 391/1250 [01:17<02:50,  5.04it/s]

{'loss': 6.4449, 'grad_norm': 6.923464298248291, 'learning_rate': 3.452e-05, 'epoch': 0.31}


 32%|███▏      | 401/1250 [01:19<02:50,  4.98it/s]

{'loss': 6.6301, 'grad_norm': 6.532619953155518, 'learning_rate': 3.412e-05, 'epoch': 0.32}


 33%|███▎      | 411/1250 [01:21<02:42,  5.15it/s]

{'loss': 6.4502, 'grad_norm': 7.821795463562012, 'learning_rate': 3.372e-05, 'epoch': 0.33}


 34%|███▎      | 421/1250 [01:23<02:45,  5.02it/s]

{'loss': 6.5521, 'grad_norm': 7.884487628936768, 'learning_rate': 3.332e-05, 'epoch': 0.34}


 34%|███▍      | 431/1250 [01:25<02:40,  5.10it/s]

{'loss': 6.2737, 'grad_norm': 12.052517890930176, 'learning_rate': 3.292e-05, 'epoch': 0.34}


 35%|███▌      | 441/1250 [01:27<02:39,  5.06it/s]

{'loss': 6.2823, 'grad_norm': 6.344372749328613, 'learning_rate': 3.252e-05, 'epoch': 0.35}


 36%|███▌      | 451/1250 [01:29<02:35,  5.15it/s]

{'loss': 6.4038, 'grad_norm': 7.512506008148193, 'learning_rate': 3.212e-05, 'epoch': 0.36}


 37%|███▋      | 461/1250 [01:31<02:34,  5.10it/s]

{'loss': 6.474, 'grad_norm': 7.014685153961182, 'learning_rate': 3.172e-05, 'epoch': 0.37}


 38%|███▊      | 471/1250 [01:33<02:36,  4.99it/s]

{'loss': 6.3745, 'grad_norm': 7.1853156089782715, 'learning_rate': 3.132e-05, 'epoch': 0.38}


 38%|███▊      | 481/1250 [01:35<02:33,  5.01it/s]

{'loss': 6.2991, 'grad_norm': 6.805530548095703, 'learning_rate': 3.092e-05, 'epoch': 0.38}


 39%|███▉      | 490/1250 [01:37<02:29,  5.08it/s]

{'loss': 6.136, 'grad_norm': 7.782148838043213, 'learning_rate': 3.0520000000000006e-05, 'epoch': 0.39}


 40%|████      | 500/1250 [01:39<02:28,  5.04it/s]

{'loss': 6.2536, 'grad_norm': 9.031482696533203, 'learning_rate': 3.0120000000000003e-05, 'epoch': 0.4}


 41%|████      | 511/1250 [01:41<02:33,  4.80it/s]

{'loss': 6.2781, 'grad_norm': 6.593645095825195, 'learning_rate': 2.9720000000000003e-05, 'epoch': 0.41}


 42%|████▏     | 521/1250 [01:43<02:22,  5.12it/s]

{'loss': 6.2619, 'grad_norm': 7.887132167816162, 'learning_rate': 2.9320000000000004e-05, 'epoch': 0.42}


 42%|████▏     | 531/1250 [01:45<02:20,  5.11it/s]

{'loss': 6.1763, 'grad_norm': 9.341087341308594, 'learning_rate': 2.8920000000000004e-05, 'epoch': 0.42}


 43%|████▎     | 541/1250 [01:47<02:26,  4.84it/s]

{'loss': 6.1902, 'grad_norm': 6.864844799041748, 'learning_rate': 2.852e-05, 'epoch': 0.43}


 44%|████▍     | 551/1250 [01:49<02:24,  4.83it/s]

{'loss': 6.1737, 'grad_norm': 8.743148803710938, 'learning_rate': 2.8120000000000002e-05, 'epoch': 0.44}


 45%|████▍     | 561/1250 [01:51<02:11,  5.24it/s]

{'loss': 6.2513, 'grad_norm': 9.898697853088379, 'learning_rate': 2.7720000000000002e-05, 'epoch': 0.45}


 46%|████▌     | 571/1250 [01:53<02:19,  4.88it/s]

{'loss': 6.0621, 'grad_norm': 11.046476364135742, 'learning_rate': 2.7320000000000003e-05, 'epoch': 0.46}


 46%|████▋     | 581/1250 [01:55<02:10,  5.15it/s]

{'loss': 6.1444, 'grad_norm': 7.026841640472412, 'learning_rate': 2.692e-05, 'epoch': 0.46}


 47%|████▋     | 591/1250 [01:57<02:12,  4.96it/s]

{'loss': 6.2561, 'grad_norm': 7.96036434173584, 'learning_rate': 2.652e-05, 'epoch': 0.47}


 48%|████▊     | 601/1250 [01:59<02:08,  5.05it/s]

{'loss': 6.1677, 'grad_norm': 7.887172698974609, 'learning_rate': 2.612e-05, 'epoch': 0.48}


 49%|████▉     | 611/1250 [02:01<02:03,  5.18it/s]

{'loss': 6.1498, 'grad_norm': 7.269429683685303, 'learning_rate': 2.572e-05, 'epoch': 0.49}


 50%|████▉     | 621/1250 [02:03<02:06,  4.97it/s]

{'loss': 6.154, 'grad_norm': 8.294212341308594, 'learning_rate': 2.5319999999999998e-05, 'epoch': 0.5}


 50%|█████     | 631/1250 [02:05<02:03,  5.00it/s]

{'loss': 6.1495, 'grad_norm': 7.3320698738098145, 'learning_rate': 2.4920000000000002e-05, 'epoch': 0.5}


 51%|█████▏    | 641/1250 [02:07<02:02,  4.96it/s]

{'loss': 6.1452, 'grad_norm': 8.789143562316895, 'learning_rate': 2.4520000000000002e-05, 'epoch': 0.51}


 52%|█████▏    | 651/1250 [02:09<01:57,  5.09it/s]

{'loss': 6.0305, 'grad_norm': 7.396072864532471, 'learning_rate': 2.412e-05, 'epoch': 0.52}


 53%|█████▎    | 661/1250 [02:11<01:52,  5.22it/s]

{'loss': 5.8859, 'grad_norm': 6.592757225036621, 'learning_rate': 2.372e-05, 'epoch': 0.53}


 54%|█████▎    | 671/1250 [02:13<01:58,  4.90it/s]

{'loss': 6.0789, 'grad_norm': 7.619958877563477, 'learning_rate': 2.332e-05, 'epoch': 0.54}


 54%|█████▍    | 681/1250 [02:15<01:52,  5.08it/s]

{'loss': 6.0635, 'grad_norm': 7.9850568771362305, 'learning_rate': 2.292e-05, 'epoch': 0.54}


 55%|█████▌    | 691/1250 [02:17<01:46,  5.26it/s]

{'loss': 5.9889, 'grad_norm': 8.171202659606934, 'learning_rate': 2.252e-05, 'epoch': 0.55}


 56%|█████▌    | 701/1250 [02:19<01:49,  4.99it/s]

{'loss': 5.9422, 'grad_norm': 6.664157390594482, 'learning_rate': 2.212e-05, 'epoch': 0.56}


 57%|█████▋    | 711/1250 [02:21<01:40,  5.34it/s]

{'loss': 5.8901, 'grad_norm': 10.180691719055176, 'learning_rate': 2.1720000000000002e-05, 'epoch': 0.57}


 58%|█████▊    | 721/1250 [02:23<01:46,  4.95it/s]

{'loss': 6.065, 'grad_norm': 6.782871723175049, 'learning_rate': 2.1320000000000003e-05, 'epoch': 0.58}


 58%|█████▊    | 731/1250 [02:25<01:41,  5.13it/s]

{'loss': 6.0656, 'grad_norm': 7.444766998291016, 'learning_rate': 2.092e-05, 'epoch': 0.58}


 59%|█████▉    | 741/1250 [02:27<01:38,  5.19it/s]

{'loss': 6.0267, 'grad_norm': 8.486977577209473, 'learning_rate': 2.052e-05, 'epoch': 0.59}


 60%|██████    | 751/1250 [02:29<01:38,  5.05it/s]

{'loss': 5.9551, 'grad_norm': 6.476780414581299, 'learning_rate': 2.012e-05, 'epoch': 0.6}


 61%|██████    | 761/1250 [02:31<01:37,  5.02it/s]

{'loss': 5.8448, 'grad_norm': 9.83322525024414, 'learning_rate': 1.972e-05, 'epoch': 0.61}


 62%|██████▏   | 771/1250 [02:33<01:30,  5.29it/s]

{'loss': 5.9843, 'grad_norm': 8.426373481750488, 'learning_rate': 1.932e-05, 'epoch': 0.62}


 62%|██████▏   | 781/1250 [02:34<01:31,  5.13it/s]

{'loss': 5.8684, 'grad_norm': 9.699344635009766, 'learning_rate': 1.8920000000000002e-05, 'epoch': 0.62}


 63%|██████▎   | 791/1250 [02:36<01:30,  5.06it/s]

{'loss': 5.9971, 'grad_norm': 6.695237159729004, 'learning_rate': 1.8520000000000002e-05, 'epoch': 0.63}


 64%|██████▍   | 801/1250 [02:38<01:29,  5.04it/s]

{'loss': 5.9993, 'grad_norm': 7.506155490875244, 'learning_rate': 1.812e-05, 'epoch': 0.64}


 65%|██████▍   | 811/1250 [02:40<01:28,  4.98it/s]

{'loss': 5.9589, 'grad_norm': 8.042360305786133, 'learning_rate': 1.772e-05, 'epoch': 0.65}


 66%|██████▌   | 821/1250 [02:42<01:22,  5.18it/s]

{'loss': 5.9443, 'grad_norm': 6.797581195831299, 'learning_rate': 1.732e-05, 'epoch': 0.66}


 66%|██████▋   | 831/1250 [02:44<01:18,  5.36it/s]

{'loss': 5.8965, 'grad_norm': 7.741754531860352, 'learning_rate': 1.692e-05, 'epoch': 0.66}


 67%|██████▋   | 841/1250 [02:46<01:24,  4.83it/s]

{'loss': 5.9826, 'grad_norm': 6.744418144226074, 'learning_rate': 1.652e-05, 'epoch': 0.67}


 68%|██████▊   | 851/1250 [02:48<01:15,  5.27it/s]

{'loss': 5.8525, 'grad_norm': 7.191407680511475, 'learning_rate': 1.612e-05, 'epoch': 0.68}


 69%|██████▉   | 861/1250 [02:50<01:18,  4.95it/s]

{'loss': 5.9853, 'grad_norm': 8.631134033203125, 'learning_rate': 1.5720000000000002e-05, 'epoch': 0.69}


 70%|██████▉   | 871/1250 [02:52<01:09,  5.45it/s]

{'loss': 5.7081, 'grad_norm': 9.768930435180664, 'learning_rate': 1.5320000000000002e-05, 'epoch': 0.7}


 70%|███████   | 881/1250 [02:54<01:12,  5.08it/s]

{'loss': 5.9669, 'grad_norm': 6.751153945922852, 'learning_rate': 1.4920000000000001e-05, 'epoch': 0.7}


 71%|███████▏  | 891/1250 [02:56<01:11,  5.02it/s]

{'loss': 5.9145, 'grad_norm': 7.695727348327637, 'learning_rate': 1.452e-05, 'epoch': 0.71}


 72%|███████▏  | 901/1250 [02:58<01:05,  5.36it/s]

{'loss': 5.7608, 'grad_norm': 8.050070762634277, 'learning_rate': 1.412e-05, 'epoch': 0.72}


 73%|███████▎  | 911/1250 [03:00<01:06,  5.13it/s]

{'loss': 5.8249, 'grad_norm': 8.1297025680542, 'learning_rate': 1.3719999999999999e-05, 'epoch': 0.73}


 74%|███████▎  | 921/1250 [03:02<01:06,  4.96it/s]

{'loss': 5.7837, 'grad_norm': 6.902261257171631, 'learning_rate': 1.3320000000000001e-05, 'epoch': 0.74}


 74%|███████▍  | 931/1250 [03:04<01:03,  5.01it/s]

{'loss': 6.0135, 'grad_norm': 6.36602258682251, 'learning_rate': 1.2920000000000002e-05, 'epoch': 0.74}


 75%|███████▌  | 941/1250 [03:06<00:59,  5.22it/s]

{'loss': 5.926, 'grad_norm': 7.123579502105713, 'learning_rate': 1.252e-05, 'epoch': 0.75}


 76%|███████▌  | 951/1250 [03:08<00:57,  5.18it/s]

{'loss': 5.8279, 'grad_norm': 6.975743770599365, 'learning_rate': 1.2120000000000001e-05, 'epoch': 0.76}


 77%|███████▋  | 961/1250 [03:10<00:58,  4.93it/s]

{'loss': 5.9738, 'grad_norm': 7.319375991821289, 'learning_rate': 1.172e-05, 'epoch': 0.77}


 78%|███████▊  | 971/1250 [03:12<00:55,  5.04it/s]

{'loss': 5.765, 'grad_norm': 7.393123626708984, 'learning_rate': 1.132e-05, 'epoch': 0.78}


 78%|███████▊  | 981/1250 [03:14<00:54,  4.95it/s]

{'loss': 5.7586, 'grad_norm': 6.508246898651123, 'learning_rate': 1.092e-05, 'epoch': 0.78}


 79%|███████▉  | 991/1250 [03:16<00:51,  5.05it/s]

{'loss': 5.7729, 'grad_norm': 6.114186763763428, 'learning_rate': 1.0520000000000001e-05, 'epoch': 0.79}


 80%|████████  | 1000/1250 [03:17<00:48,  5.17it/s]

{'loss': 5.6357, 'grad_norm': 6.639893054962158, 'learning_rate': 1.012e-05, 'epoch': 0.8}


 81%|████████  | 1011/1250 [03:20<00:47,  5.07it/s]

{'loss': 5.81, 'grad_norm': 8.2906494140625, 'learning_rate': 9.72e-06, 'epoch': 0.81}


 82%|████████▏ | 1021/1250 [03:22<00:46,  4.90it/s]

{'loss': 5.8975, 'grad_norm': 6.351805210113525, 'learning_rate': 9.32e-06, 'epoch': 0.82}


 82%|████████▏ | 1031/1250 [03:24<00:41,  5.28it/s]

{'loss': 5.7507, 'grad_norm': 7.1663007736206055, 'learning_rate': 8.920000000000001e-06, 'epoch': 0.82}


 83%|████████▎ | 1041/1250 [03:26<00:40,  5.12it/s]

{'loss': 5.7756, 'grad_norm': 7.548929214477539, 'learning_rate': 8.52e-06, 'epoch': 0.83}


 84%|████████▍ | 1051/1250 [03:28<00:39,  5.00it/s]

{'loss': 5.7896, 'grad_norm': 6.4999260902404785, 'learning_rate': 8.12e-06, 'epoch': 0.84}


 85%|████████▍ | 1061/1250 [03:30<00:36,  5.12it/s]

{'loss': 5.7791, 'grad_norm': 7.038336753845215, 'learning_rate': 7.72e-06, 'epoch': 0.85}


 86%|████████▌ | 1071/1250 [03:32<00:34,  5.13it/s]

{'loss': 5.5303, 'grad_norm': 7.09450626373291, 'learning_rate': 7.32e-06, 'epoch': 0.86}


 86%|████████▋ | 1081/1250 [03:34<00:32,  5.23it/s]

{'loss': 5.6981, 'grad_norm': 7.594775199890137, 'learning_rate': 6.92e-06, 'epoch': 0.86}


 87%|████████▋ | 1091/1250 [03:36<00:30,  5.13it/s]

{'loss': 5.7311, 'grad_norm': 7.273234844207764, 'learning_rate': 6.519999999999999e-06, 'epoch': 0.87}


 88%|████████▊ | 1101/1250 [03:38<00:28,  5.25it/s]

{'loss': 5.6545, 'grad_norm': 6.342525482177734, 'learning_rate': 6.12e-06, 'epoch': 0.88}


 89%|████████▉ | 1111/1250 [03:40<00:27,  5.13it/s]

{'loss': 5.5172, 'grad_norm': 7.224043846130371, 'learning_rate': 5.72e-06, 'epoch': 0.89}


 90%|████████▉ | 1121/1250 [03:42<00:24,  5.17it/s]

{'loss': 5.6402, 'grad_norm': 7.006211757659912, 'learning_rate': 5.32e-06, 'epoch': 0.9}


 90%|█████████ | 1131/1250 [03:43<00:23,  5.17it/s]

{'loss': 5.5356, 'grad_norm': 6.270766258239746, 'learning_rate': 4.92e-06, 'epoch': 0.9}


 91%|█████████▏| 1141/1250 [03:45<00:21,  5.15it/s]

{'loss': 5.6393, 'grad_norm': 6.178441524505615, 'learning_rate': 4.52e-06, 'epoch': 0.91}


 92%|█████████▏| 1151/1250 [03:47<00:17,  5.51it/s]

{'loss': 5.6057, 'grad_norm': 7.177580833435059, 'learning_rate': 4.12e-06, 'epoch': 0.92}


 93%|█████████▎| 1161/1250 [03:49<00:17,  5.10it/s]

{'loss': 5.703, 'grad_norm': 6.6502203941345215, 'learning_rate': 3.72e-06, 'epoch': 0.93}


 94%|█████████▎| 1171/1250 [03:51<00:14,  5.30it/s]

{'loss': 5.6061, 'grad_norm': 5.831533908843994, 'learning_rate': 3.3200000000000004e-06, 'epoch': 0.94}


 94%|█████████▍| 1181/1250 [03:53<00:12,  5.48it/s]

{'loss': 5.6175, 'grad_norm': 6.3972320556640625, 'learning_rate': 2.92e-06, 'epoch': 0.94}


 95%|█████████▌| 1191/1250 [03:55<00:11,  5.13it/s]

{'loss': 5.6815, 'grad_norm': 5.972646236419678, 'learning_rate': 2.52e-06, 'epoch': 0.95}


 96%|█████████▌| 1201/1250 [03:57<00:09,  4.97it/s]

{'loss': 5.7556, 'grad_norm': 5.876369476318359, 'learning_rate': 2.12e-06, 'epoch': 0.96}


 97%|█████████▋| 1211/1250 [03:59<00:07,  5.03it/s]

{'loss': 5.8282, 'grad_norm': 6.661118507385254, 'learning_rate': 1.72e-06, 'epoch': 0.97}


 98%|█████████▊| 1221/1250 [04:01<00:05,  5.17it/s]

{'loss': 5.9197, 'grad_norm': 7.870673656463623, 'learning_rate': 1.32e-06, 'epoch': 0.98}


 98%|█████████▊| 1231/1250 [04:03<00:03,  5.14it/s]

{'loss': 5.4518, 'grad_norm': 6.345099925994873, 'learning_rate': 9.2e-07, 'epoch': 0.98}


 99%|█████████▉| 1241/1250 [04:05<00:01,  5.26it/s]

{'loss': 5.6396, 'grad_norm': 6.199156761169434, 'learning_rate': 5.2e-07, 'epoch': 0.99}


100%|██████████| 1250/1250 [04:06<00:00,  5.26it/s]

{'loss': 5.679, 'grad_norm': 6.600950717926025, 'learning_rate': 1.2e-07, 'epoch': 1.0}


100%|██████████| 1250/1250 [04:07<00:00,  5.05it/s]

{'train_runtime': 247.5437, 'train_samples_per_second': 40.397, 'train_steps_per_second': 5.05, 'train_loss': 6.5083125213623045, 'epoch': 1.0}





TrainOutput(global_step=1250, training_loss=6.5083125213623045, metrics={'train_runtime': 247.5437, 'train_samples_per_second': 40.397, 'train_steps_per_second': 5.05, 'total_flos': 853652006830080.0, 'train_loss': 6.5083125213623045, 'epoch': 1.0})

## Step8 模型推理

In [18]:
from transformers import pipeline

# Wrap the trained model and its tokenizer in a text-generation pipeline
# placed on device index 0.
pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0,
)

In [28]:
# Record the pruned depth in the config. Reading it from the live module list
# (rather than repeating the pruning range) keeps config and model in
# agreement even if the pruning cell is edited later.
model.config.n_layer = len(model.transformer.h)

In [31]:
# Continue a prompt taken from the start of a training document.
museum_prompt = "西安交通大学博物馆（Xi'an Jiaotong University Museum）是一座位于西安"
pipe(
    museum_prompt,
    do_sample=True,
    max_length=128,
    # Per the original author's note, generation raises an error without this.
    use_cache=False,
)



[{'generated_text': "西安交通大学博物馆（Xi'an Jiaotong University Museum）是一座位于西安市西安市的一座地下文物馆，建有长安图书馆馆馆，位于西安大图书馆。\n历史\n西安东体育场\n西安地铁与西安地铁1号为西安地铁车站，位于西安地铁2号线3号线上。车站南侧，东地铁9号线2号线站台与月台，东侧，地下两层地面为西安西北路地下；地下2层2层9层，车站有100层2层9层层，高9层的12层，2层4层楼1层。\n历史\n1座车站3层7层"}]

In [33]:
# Continue a free-form prompt, sampling only from the 5 most likely tokens.
news_prompt = "下面是一则游戏新闻。小编报道，近日，游戏产业发展的非常"
pipe(
    news_prompt,
    do_sample=True,
    top_k=5,
    max_length=128,
    # Per the original author's note, generation raises an error without this.
    use_cache=False,
)



[{'generated_text': '下面是一则游戏新闻。小编报道，近日，游戏产业发展的非常丰富，在游戏中使用，游戏中的玩家，玩家将玩家玩家游戏，玩家将游戏游戏游戏，游戏玩家游戏，游戏玩家可以设定设定设定。游戏玩家可以游戏玩法玩法。\n玩家将玩家游戏游戏游戏，游戏玩法为游戏，游戏的游戏游戏游戏，游戏是游戏游戏，游戏游戏玩法。\n游戏游戏玩家玩家玩家游戏\n游戏玩法\n游戏\n游戏游戏是玩家游戏玩家游戏游戏，游戏玩家可以设定为玩家游戏游戏，游戏玩家玩家玩家玩家在游戏玩法玩法玩法，玩家玩家玩家玩家可以玩家'}]