In [9]:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import AutoModel,BitsAndBytesConfig
from transformers import Trainer,TrainingArguments
from peft import get_peft_model, prepare_model_for_kbit_training, TaskType, LoraConfig
from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
# Global hyperparameters and paths shared by the cells below.

# Locations of the base checkpoint and the training CSV.
base_model_path = "/mnt/data/chatglm3-6b-model"
train_data_path = "./static/datasets.csv"

# Seed used when shuffling the tokenized dataset.
seed = 42

# Token budgets applied when encoding each example
# (prompt side / response side, special tokens included).
max_inputs = 512
max_outputs = 1536

# LoRA adapter settings.
lora_rank = 16
lora_alpha = 32
lora_dropout = 0.05


# The data was generated by GPT: "virtual girlfriend" style dialogues.

In [10]:
# Load the training pairs from the CSV file.
dataset = load_dataset("csv", data_files=train_data_path)

# Load the tokenizer that ships with the base checkpoint.
# trust_remote_code=True allows the checkpoint's own tokenizer code to run.
tokenizer = AutoTokenizer.from_pretrained(
    base_model_path,
    trust_remote_code=True,
)


In [11]:
# Display the loaded DatasetDict (splits, column names, row counts).
dataset

DatasetDict({
    train: Dataset({
        features: ['man', 'wemen'],
        num_rows: 169
    })
})

In [12]:
# Tokenize one example and add the special tokens.
def tokenizer_function(example, tokenizer, ignore_lable_id=-100):
    """Encode one question/answer row into ``input_ids`` and ``labels``.

    The "man" column is used as the prompt and the "wemen" column as the
    response. Both are encoded without special tokens, clipped to the
    module-level ``max_inputs`` / ``max_outputs`` budgets, and then combined
    by ``tokenizer.build_inputs_with_special_tokens``.

    Args:
        example: Mapping with the fields "man" and "wemen".
        tokenizer: Object exposing ``encode`` and
            ``build_inputs_with_special_tokens``.
        ignore_lable_id: Value written into the label positions that cover
            the prompt (defaults to -100).

    Returns:
        dict with "input_ids" (the combined token sequence) and "labels"
        (the first ``len(prompt) + 2`` positions hold ``ignore_lable_id``,
        the remaining positions copy ``input_ids``).
    """
    # Reserve 2 slots on the prompt side and 1 on the response side for the
    # special tokens added below.
    prompt_ids = tokenizer.encode(example["man"], add_special_tokens=False)[: max_inputs - 2]
    reply_ids = tokenizer.encode(example["wemen"], add_special_tokens=False)[: max_outputs - 1]

    token_ids = tokenizer.build_inputs_with_special_tokens(prompt_ids, reply_ids)

    # Assumes the tokenizer prepends exactly two special tokens before the
    # prompt, so the masked span is the prompt plus those two tokens.
    masked_len = len(prompt_ids) + 2
    label_ids = [ignore_lable_id] * masked_len
    label_ids.extend(token_ids[masked_len:])
    return {"input_ids": token_ids, "labels": label_ids}


# Tokenize every training row one example at a time, dropping the raw text
# columns, then shuffle with the configured seed and flatten the resulting
# index mapping.
tokenized_dataset = (
    dataset["train"]
    .map(
        lambda example: tokenizer_function(example, tokenizer),
        batched=False,
        remove_columns=["wemen", "man"],
    )
    .shuffle(seed=seed)
    .flatten_indices()
)

In [13]:
# Display the tokenized dataset summary (columns and row count).
tokenized_dataset

Dataset({
    features: ['input_ids', 'labels'],
    num_rows: 169
})

In [14]:
from datasets import ClassLabel, Sequence
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    """Render ``num_examples`` distinct, randomly chosen rows as an HTML table.

    Row indices are drawn with ``random.randint`` and redrawn on duplicates.
    Columns typed as ``ClassLabel`` (or a ``Sequence`` of ``ClassLabel``) are
    shown with their label names instead of integer ids.

    Args:
        dataset: Object supporting ``len()``, indexing with a list of row
            indices, and a ``features`` mapping of column name to type.
        num_examples: Number of rows to display; must not exceed
            ``len(dataset)``.

    Raises:
        AssertionError: If ``num_examples`` is larger than the dataset.
    """
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."

    # Rejection sampling: keep drawing until enough unique indices are held.
    chosen = []
    while len(chosen) < num_examples:
        candidate = random.randint(0, len(dataset) - 1)
        if candidate not in chosen:
            chosen.append(candidate)

    frame = pd.DataFrame(dataset[chosen])
    for name, feature in dataset.features.items():
        if isinstance(feature, ClassLabel):
            # Replace integer class ids with their names.
            frame[name] = frame[name].transform(lambda idx: feature.names[idx])
        elif isinstance(feature, Sequence) and isinstance(feature.feature, ClassLabel):
            # Same replacement, applied to every id inside each sequence.
            frame[name] = frame[name].transform(
                lambda ids: [feature.feature.names[idx] for idx in ids]
            )
    display(HTML(frame.to_html()))

In [15]:
# Spot-check random tokenized rows (the function's default sample size is 10).
show_random_elements(tokenized_dataset)

Unnamed: 0,input_ids,labels
0,"[64790, 64792, 30910, 37040, 31123, 54546, 55411, 55058, 31708, 55465, 44248, 31123, 56558, 54607, 54546, 33338, 41071, 54547, 55296, 55674, 31404, 36718, 54547, 55296, 31514, 54728, 54929, 55268, 33876, 55227, 36229, 44248, 31123, 54546, 33021, 54701, 42354, 55296, 31926, 31404, 2]","[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 36718, 54547, 55296, 31514, 54728, 54929, 55268, 33876, 55227, 36229, 44248, 31123, 54546, 33021, 54701, 42354, 55296, 31926, 31404, 2]"
1,"[64790, 64792, 30910, 39661, 31123, 31869, 39229, 50444, 33748, 36354, 31123, 32193, 31897, 31740, 31669, 54537, 31155, 30910, 37040, 31123, 41608, 34281, 33458, 31740, 55282, 31155, 38307, 31983, 34319, 54591, 31123, 39807, 39508, 31674, 31902, 54664, 31740, 31123, 34110, 32190, 33588, 31155, 31925, 54622, 32814, 50444, 54537, 31123, 34318, 31937, 34897, 31903, 39396, 55282, 31514, 32469, 34329, 31844, 35122, 56645, 31155, 54725, 41487, 31903, 41230, 32316, 37316, 31123, 31803, 38149, 31123, 32817, 33514, 35263, 35341, 31802, 51965, 32402, 55282, 31155, 32194, 53125, 31123, 40207, 54701, 50444, 33748, 55370, 31123, 39661, 31155, 2]","[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 30910, 37040, 31123, 41608, 34281, 33458, 31740, 55282, 31155, 38307, 31983, 34319, 54591, 31123, 39807, 39508, 31674, 31902, 54664, 31740, 31123, 34110, 32190, 33588, 31155, 31925, 54622, 32814, 50444, 54537, 31123, 34318, 31937, 34897, 31903, 39396, 55282, 31514, 32469, 34329, 31844, 35122, 56645, 31155, 54725, 41487, 31903, 41230, 32316, 37316, 31123, 31803, 38149, 31123, 32817, 33514, 35263, 35341, 31802, 51965, 32402, 55282, 31155, 32194, 53125, 31123, 40207, 54701, 50444, 33748, 55370, 31123, 39661, 31155, 2]"
2,"[64790, 64792, 36474, 54591, 31123, 31869, 33071, 35367, 31514, 30910, 39661, 31123, 35398, 33071, 54657, 32884, 55282, 31155, 33057, 50165, 31123, 53128, 55771, 55771, 31123, 32805, 49495, 32729, 36804, 31155, 45360, 43324, 38493, 32693, 40657, 55282, 31514, 2]","[-100, -100, -100, -100, -100, -100, -100, -100, -100, 30910, 39661, 31123, 35398, 33071, 54657, 32884, 55282, 31155, 33057, 50165, 31123, 53128, 55771, 55771, 31123, 32805, 49495, 32729, 36804, 31155, 45360, 43324, 38493, 32693, 40657, 55282, 31514, 2]"
3,"[64790, 64792, 30910, 39661, 31123, 54546, 31869, 32056, 42917, 31123, 33149, 57350, 54738, 31123, 54868, 33149, 55450, 31155, 30910, 54835, 33519, 31123, 36731, 54622, 32483, 31624, 54657, 34697, 31123, 39229, 33485, 56389, 54668, 56123, 55491, 31123, 40322, 37972, 32024, 54868, 31123, 32043, 33168, 57149, 54578, 33503, 38425, 31123, 34933, 31897, 54591, 54727, 54530, 31155, 2]","[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 30910, 54835, 33519, 31123, 36731, 54622, 32483, 31624, 54657, 34697, 31123, 39229, 33485, 56389, 54668, 56123, 55491, 31123, 40322, 37972, 32024, 54868, 31123, 32043, 33168, 57149, 54578, 33503, 38425, 31123, 34933, 31897, 54591, 54727, 54530, 31155, 2]"
4,"[64790, 64792, 36474, 31717, 55398, 31123, 41608, 31897, 54622, 32103, 33115, 32566, 54631, 31155, 30910, 58070, 31123, 38505, 36778, 31876, 54536, 32566, 31627, 33115, 40895, 31123, 31828, 35094, 31820, 31818, 44393, 43963, 31822, 33115, 31155, 2]","[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 30910, 58070, 31123, 38505, 36778, 31876, 54536, 32566, 31627, 33115, 40895, 31123, 31828, 35094, 31820, 31818, 44393, 43963, 31822, 33115, 31155, 2]"
5,"[64790, 64792, 53456, 35416, 31749, 54652, 37909, 33893, 54948, 34317, 54537, 31123, 32044, 31643, 34628, 33764, 54537, 31123, 31894, 51688, 42001, 42425, 48726, 48046, 31155, 30910, 58070, 31123, 52029, 32131, 31897, 31123, 35416, 32536, 45520, 32436, 32696, 34317, 54542, 35550, 31123, 31772, 34992, 31820, 31676, 32088, 33075, 31155, 2]","[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 30910, 58070, 31123, 52029, 32131, 31897, 31123, 35416, 32536, 45520, 32436, 32696, 34317, 54542, 35550, 31123, 31772, 34992, 31820, 31676, 32088, 33075, 31155, 2]"
6,"[64790, 64792, 34211, 32483, 54701, 40120, 35574, 33450, 31123, 35323, 54948, 54661, 54537, 31404, 30910, 58147, 31404, 35574, 32967, 33114, 33085, 53757, 31810, 55282, 31123, 41236, 54622, 32300, 31688, 31642, 31850, 33634, 31514, 2]","[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 30910, 58147, 31404, 35574, 32967, 33114, 33085, 53757, 31810, 55282, 31123, 41236, 54622, 32300, 31688, 31642, 31850, 33634, 31514, 2]"
7,"[64790, 64792, 53456, 48525, 39534, 31809, 32697, 54530, 55551, 55323, 31123, 54591, 34311, 56645, 31404, 30910, 58070, 31123, 32664, 55551, 55323, 55282, 31514, 54929, 38953, 55370, 31404, 2]","[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 30910, 58070, 31123, 32664, 55551, 55323, 55282, 31514, 54929, 38953, 55370, 31404, 2]"
8,"[64790, 64792, 53779, 31624, 32483, 54657, 36858, 31123, 32507, 40296, 54539, 31155, 34211, 32185, 31822, 34697, 31123, 31624, 32967, 34778, 32774, 31669, 35933, 31155, 34607, 33085, 40167, 32187, 31796, 31638, 31862, 32721, 31635, 31123, 36545, 45342, 54622, 31155, 31925, 31123, 33021, 31937, 31674, 31902, 56645, 31123, 31844, 38921, 55433, 42917, 31155, 2]","[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 34211, 32185, 31822, 34697, 31123, 31624, 32967, 34778, 32774, 31669, 35933, 31155, 34607, 33085, 40167, 32187, 31796, 31638, 31862, 32721, 31635, 31123, 36545, 45342, 54622, 31155, 31925, 31123, 33021, 31937, 31674, 31902, 56645, 31123, 31844, 38921, 55433, 42917, 31155, 2]"
9,"[64790, 64792, 53456, 51879, 33149, 56750, 57763, 31123, 34110, 54574, 32056, 55079, 55014, 31155, 30910, 39661, 31123, 31844, 40914, 31404, 44104, 33485, 34649, 54819, 54802, 31404, 54622, 33671, 49682, 31123, 32192, 31627, 44762, 36310, 32321, 31155, 31844, 34855, 31123, 32248, 32553, 43146, 54747, 31123, 31855, 34022, 49665, 55187, 57204, 37915, 55555, 31123, 54688, 41406, 35087, 54563, 32122, 31123, 35805, 32523, 49060, 37900, 31155, 35094, 31820, 56024, 41107, 31123, 36280, 55379, 55176, 32176, 40102, 31123, 33485, 31996, 31155, 32192, 31627, 31123, 32192, 32010, 31920, 31123, 31749, 55283, 54718, 33533, 32088, 31404, 2]","[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 30910, 39661, 31123, 31844, 40914, 31404, 44104, 33485, 34649, 54819, 54802, 31404, 54622, 33671, 49682, 31123, 32192, 31627, 44762, 36310, 32321, 31155, 31844, 34855, 31123, 32248, 32553, 43146, 54747, 31123, 31855, 34022, 49665, 55187, 57204, 37915, 55555, 31123, 54688, 41406, 35087, 54563, 32122, 31123, 35805, 32523, 49060, 37900, 31155, 35094, 31820, 56024, 41107, 31123, 36280, 55379, 55176, 32176, 40102, 31123, 33485, 31996, 31155, 32192, 31627, 31123, 32192, 32010, 31920, 31123, 31749, 55283, 54718, 33533, 32088, 31404, 2]"


In [16]:
class Data_Collector:
    """Collate tokenized examples into padded ``input_ids`` / ``labels`` tensors.

    Every example is right-padded to the longest sequence in the batch
    (``input_ids`` with ``pad_token_id``, ``labels`` with ``ignore_lable_id``).
    When the longest sequence exceeds ``max_length``, every example is then
    truncated to ``max_length`` tokens.
    """

    def __init__(self, pad_token_id: int, max_length: int = 2048, ignore_lable_id: int = -100):
        """Store the padding and truncation settings.

        Args:
            pad_token_id: Token id appended to ``input_ids`` as padding.
            max_length: Upper bound on the sequence length of a batch.
            ignore_lable_id: Value appended to ``labels`` as padding.
        """
        self.pad_token_id = pad_token_id
        self.max_length = max_length
        self.ignore_lable_id = ignore_lable_id

    def __call__(self, batch_data):
        """Pad (and, if needed, truncate) a batch and stack it into tensors.

        Args:
            batch_data: Non-empty sequence of dicts, each with list-valued
                "input_ids" and "labels" entries.

        Returns:
            dict with "input_ids" and "labels", each a stacked
            ``torch.LongTensor`` with one row per example. Rows are ordered
            from the longest example to the shortest.

        Raises:
            ValueError: If ``batch_data`` is empty.
        """
        batch_max_len = max(len(sample["input_ids"]) for sample in batch_data)
        needs_truncation = batch_max_len > self.max_length

        # Longest examples first; the sort is stable, so ties keep batch order.
        ordered = sorted(batch_data, key=lambda sample: -len(sample["input_ids"]))

        input_ids, labels = [], []
        for sample in ordered:
            pad_len = batch_max_len - len(sample["input_ids"])
            padded_ids = sample["input_ids"] + [self.pad_token_id] * pad_len
            padded_labels = sample["labels"] + [self.ignore_lable_id] * pad_len
            if needs_truncation:
                # Truncate this example's own lists. The previous code sliced
                # the accumulator list and dropped the truncated labels, so
                # over-long batches produced mismatched or invalid tensors.
                padded_ids = padded_ids[:self.max_length]
                padded_labels = padded_labels[:self.max_length]
            input_ids.append(torch.LongTensor(padded_ids))
            labels.append(torch.LongTensor(padded_labels))

        return {"input_ids": torch.stack(input_ids), "labels": torch.stack(labels)}
# Collator instance handed to the Trainer below; max_length and the ignore
# label keep their constructor defaults (2048 and -100).
data_collector=Data_Collector(pad_token_id=tokenizer.pad_token_id)


In [17]:
# Load the base model with 4-bit NF4 quantization (double quantization on,
# float16 compute dtype).
q_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)
base_model = AutoModel.from_pretrained(
    base_model_path,
    quantization_config=q_config,
    device_map="auto",
    trust_remote_code=True,
)
# NOTE(review): the flag is forced on before k-bit preparation, presumably so
# gradient checkpointing can be enabled on this remote-code model - confirm.
base_model.supports_gradient_checkpointing = True
# The generation cache is switched off for training.
base_model.config.use_cache = False

# Prepare the quantized model for training and look up the LoRA target
# modules that PEFT registers under the "chatglm" key.
kbit_model = prepare_model_for_kbit_training(base_model)
target_model = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING["chatglm"]

# Wrap the prepared model with trainable LoRA adapters.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    target_modules=target_model,
    r=lora_rank,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    bias="none",
)
qlora_model = get_peft_model(kbit_model, lora_config)

Loading checkpoint shards: 100%|██████████| 7/7 [02:14<00:00, 19.17s/it]
You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it).Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model.


In [18]:

# Build the trainer.
output_dir = "phb/chatglm3-ft"
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,
    learning_rate=1e-3,
    num_train_epochs=3,
    lr_scheduler_type="linear",
    warmup_ratio=0.1,
    logging_steps=1,       # log the loss at every optimizer step
    save_strategy="steps",
    save_steps=10,         # write a checkpoint every 10 steps
    optim="adamw_torch",
    fp16=True,
    # Pass the notebook-level seed explicitly so the same constant controls
    # both dataset shuffling and training, instead of relying on the
    # library's default seed.
    seed=seed,
)
trainer = Trainer(
    model=qlora_model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collector,
)


Detected kernel version 4.19.24, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [19]:
# Run fine-tuning with the trainer configured above; the loss is logged at
# every step.
trainer.train()


Step,Training Loss
1,3.4555
2,4.0157
3,3.4879
4,3.3771
5,3.0112
6,3.3898
7,3.0331
8,2.5396
9,2.846
10,2.7452


Checkpoint destination directory phb/chatglm3-ft/checkpoint-10 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory phb/chatglm3-ft/checkpoint-20 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory phb/chatglm3-ft/checkpoint-30 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory phb/chatglm3-ft/checkpoint-40 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory phb/chatglm3-ft/checkpoint-50 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory phb/chatglm3-ft/checkpoint-60 already exists and is non-empty.Saving will proceed but saved results may be invalid.


TrainOutput(global_step=66, training_loss=2.0720901615691907, metrics={'train_runtime': 70.2552, 'train_samples_per_second': 7.217, 'train_steps_per_second': 0.939, 'total_flos': 2193800282247168.0, 'train_loss': 2.0720901615691907, 'epoch': 3.0})

In [20]:
# Save the trained model into output_dir.
# NOTE(review): trainer.model is the PEFT-wrapped model passed to the Trainer,
# so this presumably writes only the LoRA adapter weights - confirm.
trainer.model.save_pretrained(output_dir)