In [None]:
################ before you do it, make sure you have the latest version of transformer,bitsandbytes,peft,accelerate properly installed 

import torch
import random
from datasets import Dataset
import bitsandbytes as bnb
import pandas as pd
from typing import Optional
import transformers
from transformers import AutoTokenizer, AutoModel,BitsAndBytesConfig,Trainer, TrainingArguments
from peft import get_peft_model, LoraConfig, TaskType,prepare_model_for_int8_training
from peft.tuners.lora import LoraLayer






In [3]:
###### load model with bnb_4bit 
model = AutoModel.from_pretrained(
        "chatglm-6b",
        load_in_4bit=True,
        device_map='auto',
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4'
        ),
        torch_dtype=torch.bfloat16,
        trust_remote_code=True
    )
tokenizer = AutoTokenizer.from_pretrained("chatglm-6b", trust_remote_code=True)

###### froze model weights before adding lora layers
model = prepare_model_for_int8_training(model)


Loading checkpoint shards: 100%|██████████| 8/8 [01:43<00:00, 12.99s/it]


In [4]:

########### identify which layers from the model are good to add lora layers for
def find_all_linear_names(model):
    cls = bnb.nn.Linear4bit 
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    if 'lm_head' in lora_module_names: # needed for 16-bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)


########### create lora layers and add them to relevant layers in the model

config = LoraConfig(
        r=64,
        lora_alpha=16,
        target_modules=find_all_linear_names(model),
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )

model = get_peft_model(model, config)



In [11]:
#################### further adjust dtype for some layers in the model 
def update_somelayers_dtype(model):
    for name, module in model.named_modules():
        if isinstance(module, LoraLayer): 
            module = module.to(torch.bfloat16)
        if 'norm' in name:
            module = module.to(torch.bfloat16)  ###### in the qlora library this is set to float32. Setting to bf16 may cause problems
        if 'lm_head' in name or 'embed_tokens' in name:
            if hasattr(module, 'weight'):
                if module.weight.dtype == torch.float32:
                    module = module.to(torch.bfloat16)

update_somelayers_dtype(model)

In [None]:
###################### ok load my data. There are two columns: input, output, for example:
#####################  input                    output
#####################  hello                   how are you?
#####################  what's your name ?      Eion Mask

dataset = pd.read_csv('rabbit.csv')[['input','output']]
dataset = dataset.rename({'input':'context','output':'target'},axis=1)
dataset = Dataset.from_pandas(dataset)


###################### tokenize the input and output
def preprocess(example):
    prompt = example["context"]
    target = example["target"]
    prompt_ids = tokenizer.encode(prompt, max_length=1024, truncation=True)
    target_ids = tokenizer.encode(
        target, max_length=512, truncation=True, add_special_tokens=False
    )
    input_ids = prompt_ids + target_ids + [tokenizer.eos_token_id]
    example['input_ids'] = input_ids
    example['seq_len'] = len(prompt_ids)
    return example
   

tokenized_dataset = dataset.map(function=preprocess,remove_columns=['context','target'])

In [6]:
###################### this is the messy data collator
##################### to understand how it works, 
####################  you need to check the GLM paper about 
################### how attention masks,  position embeddings,etc., are organized

def get_masks_and_position_ids(
    seq, seq_len, context_length, device, gmask=False, position_encoding_2d=True
):
    mask_position = (
        seq_len - 2
    )  
    attention_mask = torch.ones((1, context_length, context_length), device=device)
    attention_mask.tril_()
    attention_mask[..., : mask_position - 1] = 1
    attention_mask = (attention_mask < 0.5).bool()

    if position_encoding_2d:
        seq_length = seq_len - 1  # is equal to `seq_length = seq.index(150004)`
        position_ids = torch.arange(context_length, dtype=torch.long, device=device)
        if not gmask:
            position_ids[seq_length:] = mask_position
        block_position_ids = torch.cat(
            (
                torch.zeros(seq_length, dtype=torch.long, device=device),
                torch.arange(
                    context_length - seq_length, dtype=torch.long, device=device
                )
                + 1,
            )
        )
        position_ids = torch.stack((position_ids, block_position_ids), dim=0)
    else:
        position_ids = torch.arange(context_length, dtype=torch.long, device=device)
        if not gmask:
            position_ids[context_length - 1 :] = mask_position
    return attention_mask, position_ids

def data_collator(features: list) -> dict:
    # print(features)
    len_ids = [len(feature["input_ids"]) for feature in features]
    longest = max(len_ids) + 1
    input_ids = []
    attention_mask_list = []
    position_ids_list = []
    labels_list = []
    for ids_l, feature in sorted(zip(len_ids, features), key=lambda x: -x[0]):
        ids = feature["input_ids"]
        seq_len = feature["seq_len"]
        labels = (
            [-100] * (seq_len - 1)
            + ids[(seq_len - 1) :]
            + [tokenizer.eos_token_id]
            + [-100] * (longest - ids_l - 1)
        )
        ids = ids + [tokenizer.eos_token_id] * (longest - ids_l)
        _ids = torch.LongTensor(ids)
        attention_mask, position_ids = get_masks_and_position_ids(
            ids, seq_len, longest, _ids.device, gmask=False
        )
        labels_list.append(torch.LongTensor(labels))
        input_ids.append(_ids)
        attention_mask_list.append(attention_mask)
        position_ids_list.append(position_ids)
    input_ids = torch.stack(input_ids)
    labels = torch.stack(labels_list)
    attention_mask = torch.stack(attention_mask_list)
    position_ids = torch.stack(position_ids_list)
    return {
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
    }

In [9]:
#del dataset

In [14]:
# import gc
# gc.collect()
# torch.cuda.empty_cache()

In [9]:
################# I don't care evalution. So I set it eval_data to none. Pass your actual eval_data here. 
################# Do remember to set remove_unused_columns to false 
################# otherwise it keeps complaining that the seq_len feature is missing! 

outdir ='chatglm-rabbit'
trainer = transformers.Trainer(
    model=model,
    train_dataset=tokenized_dataset,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=32,
        warmup_steps=100,
        num_train_epochs=3,
        learning_rate=3e-4,
        logging_steps=100,
        optim="adamw_torch",
        save_strategy="steps",
        save_steps=250,
        output_dir=outdir,
        save_total_limit=3,
        evaluation_strategy= "no",
        eval_steps= None,
        report_to='none',
        remove_unused_columns=False,
    ),
    data_collator=data_collator
)


In [12]:
model.config.use_cache = False

trainer.train()

Step,Training Loss
100,1.9567
