In [188]:
import pandas as pd
from datasets import Dataset
from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM
import numpy as np
import torch
# Raise the element-count threshold above which torch abbreviates printed
# tensors, so the token-id tensors printed in the sanity check further down
# are shown in full rather than summarized.
torch.set_printoptions(threshold=10000)

In [189]:
# Load the fast Llama-2 tokenizer and register an explicit '[PAD]' token so
# that the padding='max_length' tokenization used later has a pad id.
# add_special_tokens returns the number of tokens it added, which is what
# this cell displays.
# NOTE(review): this grows the vocabulary, so the model used for training
# presumably needs its embeddings resized to match — confirm in the training code.
MODEL_NAME = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})

1

In [190]:
# `mode` is either "train" or "test". It doubles as the name of the local
# directory that holds the saved training set or test set, so it is passed
# straight to load_from_disk.
mode = "train"
data = Dataset.load_from_disk(mode)

数据量比较大，因此只选取一个子集来训练（或测试）。训练时取 10 万条：斯坦福的指令微调只用了 52k 条数据就能取得不错的效果，因此这里
先用一部分数据训练看看情况。测试集有 5 万条，也偏多，取 1000 条测试就行。（这里虽然名字是 test 测试集，但实际上是在训练时作为 eval 数据使用的。）

In [191]:
# Draw a random subset of the loaded dataset: 100,000 rows when building the
# training set, 1,000 rows otherwise (the "test" split).
SEED = 42  # fixed seed so Restart & Run All selects the same subset every time

size = 100000 if mode == "train" else 1000
# Sampling without replacement raises ValueError if more rows are requested
# than exist, so never ask for more than the dataset holds.
size = min(size, data.num_rows)

rng = np.random.default_rng(SEED)
indices = rng.choice(data.num_rows, size=size, replace=False)
small_data = data.select(indices)

In [192]:
# Display the sampled subset's summary (feature names and row count).
small_data

Dataset({
    features: ['convs'],
    num_rows: 100000
})

In [193]:
# Copy the 'convs' column of the sampled subset into a one-column DataFrame
# so each conversation can be tokenized row by row with pandas.
conv_texts = small_data['convs']
df = pd.DataFrame(conv_texts, columns=["convs"])

In [194]:
# Marker strings for the completion-only collator: '[INST]' is passed as the
# instruction template and '[/INST]' as the response template.
# NOTE(review): the TRL docs warn that Llama-2 tokenizes a template differently
# depending on surrounding context; if the collator reports that it cannot find
# the response template, pass token ids instead of strings — TODO confirm on
# real data.
instruction_template = '[INST]'
response_template = '[/INST]'
collator = DataCollatorForCompletionOnlyLM(
    response_template=response_template,
    instruction_template=instruction_template,
    tokenizer=tokenizer,
)

下面直接将文本转换为token，以免在运行的时候进行转换。max_length要和训练的参数保持一致。

In [195]:
def _tokenize_conv(text):
    """Tokenize one conversation string into a fixed-length encoding.

    The text is truncated and padded to exactly 1024 tokens; this max_length
    has to match the value used by the training configuration. No special
    tokens are added by the tokenizer here (presumably the conversation text
    already contains its own markers — TODO confirm against the raw data).
    """
    return tokenizer(
        text,
        add_special_tokens=False,
        truncation=True,
        padding='max_length',
        max_length=1024,
        return_length=True,
    )


# Tokenize every conversation up front so no tokenization happens at training
# time, then wrap the per-row encodings in a Hugging Face Dataset.
encoded_rows = df['convs'].apply(_tokenize_conv).to_list()
dataset = Dataset.from_list(encoded_rows)

In [196]:
# Display the tokenized dataset's summary (feature names and row count).
dataset

Dataset({
    features: ['input_ids', 'attention_mask', 'length'],
    num_rows: 100000
})

In [197]:
# Persist the tokenized dataset under a mode-specific directory
# ("small_train_tokens" or "small_test_tokens"). The target must be derived
# from `mode`: a hard-coded "small_test_tokens" path makes a train run write
# its data into the test directory.
save_name = f"small_{mode}_tokens"
dataset.save_to_disk(save_name)

Saving the dataset (0/2 shards):   0%|          | 0/100000 [00:00<?, ? examples/s]

可以使用下面的代码查看数据是否是正常的

In [198]:
# Sanity-check loader: feeds the tokenized rows through the completion-only
# collator one example at a time so a single collated batch can be inspected.
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=1,
    collate_fn=collator,
)

In [200]:
# Pull only the first collated batch and print it for visual inspection.
batch = next(iter(dataloader), None)
if batch is not None:
    print(batch)

{'input_ids': tensor([[    1,   518, 25580, 29962,  3532, 14816, 29903,  6778,    13,    13,
          3492,   526,   263,  8444, 29892,  3390,  1319,   322, 15993, 20255,
         29889, 29849,  1234,   408,  1371,  3730,   408,  1950, 29892,  1550,
          1641,  9109, 29889,    13, 10858,  6089,   881,   451,  3160,   738,
         10311,  1319, 29892,   443,   621,   936, 29892, 11021,   391, 29892,
          7916,   391, 29892,   304, 27375, 29892, 18215, 29892,   470, 27302,
          2793, 29889,  3529,  9801,   393,   596, 20890,   526,  5374,   635,
           443,  5365,  1463,   322,  6374,   297,  5469, 29889,    13,    13,
          3644,   263,  1139,   947,   451,  1207,   738,  4060, 29892,   470,
           338,   451,  2114,  1474, 16165,   261,   296, 29892,  5649,  2020,
          2012,   310, 22862,  1554,   451,  1959, 29889,   960,   366,  1016,
         29915, 29873,  1073,   278,  1234,   304,   263,  1139, 29892,  3113,
          1016, 29915, 29873,  6232,  