# Fine-tuning a text generation model

In [3]:
from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer, 
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline, 
    logging, 
    TextStreamer
)
from peft import (
    LoraConfig, 
    PeftModel, 
    prepare_model_for_kbit_training, 
    get_peft_model
)
import os
import torch
import wandb 
import platform
import warnings
from datasets import load_dataset
from trl import SFTTrainer


In [4]:
# Identifiers used throughout the notebook: the base checkpoint to fine-tune,
# the dataset to train on, and the name the fine-tuned model is saved under.
base_model = "mistralai/Mistral-7B-v0.1"
dataset_name = "gathnex/Gath_baize"
new_model = "LEO_mistral_7b"

In [5]:
# Load the "train" split of the Gath_baize dataset
dataset = load_dataset(dataset_name, split="train")

# Preview the first training sample. Index the row first and the column
# second: this fetches a single example, whereas dataset["chat_sample"][0]
# goes through the whole column just to take its first element.
dataset[0]["chat_sample"]

'The conversation between Human and AI assisatance named Gathnex [INST] Generate a headline given a content block.\nThe Sony Playstation 5 is the latest version of the console. It has improved graphics and faster processing power.\n[/INST] Experience Amazing Graphics and Speed with the New Sony Playstation 5'

In [6]:
# Quantization settings passed to the model loader:
# 4-bit loading with the "nf4" quantization type, bfloat16 as the compute
# dtype, and double quantization turned off.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False,
)

In [None]:
# Load the base model (Mistral 7B) with the quantization config defined above.
# device_map={"": 0} assigns the root module, i.e. the whole model, to device 0.
model = AutoModelForCausalLM.from_pretrained(
    base_model, quantization_config=bnb_config, device_map={"": 0}
)

# use_cache is switched off for training to silence the warnings;
# re-enable it before running inference.
model.config.use_cache = False
model.config.pretraining_tp = 1

# Turn on gradient checkpointing for the training run
model.gradient_checkpointing_enable()

In [7]:
# Load the tokenizer that belongs to the base model
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)

# Padding: reuse the EOS token as the pad token and pad on the right
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Append an EOS token to every encoded sample
tokenizer.add_eos_token = True

# Display the BOS/EOS flags to confirm the settings took effect
tokenizer.add_bos_token, tokenizer.add_eos_token

(True, True)

In [None]:
# Authenticate with Weights & Biases.
# The API key is read from the WANDB_API_KEY environment variable so that no
# secret is stored in the notebook. When the variable is unset, `key` is None
# and wandb.login() is left to its own credential lookup.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

run = wandb.init(
    # W&B project under which this run is logged
    project="Local Mistral7B finetuning",
    job_type="training",
    anonymous="allow",
    # Hyperparameters and run metadata recorded with the run.
    # Keep these values in sync with the TrainingArguments cell.
    config={
        "learning_rate": 5e-5,  # same value as TrainingArguments(learning_rate=...)
        "architecture": "LLM",
        "dataset": "gathnex/Gath_baize",
        "epochs": 1,
    },
)

In [9]:
# Prepare the quantized model for k-bit training before adapters are attached
model = prepare_model_for_kbit_training(model)

# LoRA settings: rank 16, alpha 16, dropout 0.05, no bias terms, causal-LM
# task type. Adapters target the modules named q_proj, k_proj, v_proj, o_proj
# and gate_proj.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"],
)

# Wrap the prepared model according to the LoRA configuration
model = get_peft_model(model, peft_config)

In [None]:
# Training arguments.
# Adjust these hyperparameters to the hardware you are running on.
# NOTE(review): warmup_ratio is set, but lr_scheduler_type="constant" selects a
# schedule without a warmup phase ("constant_with_warmup" is the variant that
# has one), so the warmup value looks unused. Confirm which was intended.
training_arguments = TrainingArguments(
    output_dir="./results",
    # Run length: one epoch; max_steps=-1 leaves the step count uncapped
    num_train_epochs=1,
    max_steps=-1,
    # Batching
    per_device_train_batch_size=4,
    auto_find_batch_size=True,
    gradient_accumulation_steps=2,
    group_by_length=True,
    # Optimizer and learning-rate schedule
    optim="paged_adamw_8bit",
    learning_rate=5e-5,
    weight_decay=0.001,
    max_grad_norm=0.3,
    lr_scheduler_type="constant",
    warmup_ratio=0.3,
    # Mixed precision is disabled for both fp16 and bf16
    fp16=False,
    bf16=False,
    # Checkpointing and logging
    save_steps=5000,
    logging_steps=30,
    report_to="wandb",
)

# Supervised fine-tuning trainer: trains on the "chat_sample" text column of
# the dataset, without packing and without an explicit sequence-length limit.
trainer = SFTTrainer(
    model=model,
    args=training_arguments,
    train_dataset=dataset,
    dataset_text_field="chat_sample",
    tokenizer=tokenizer,
    peft_config=peft_config,
    max_seq_length=None,
    packing=False,
)

In [None]:
# Run training and save the fine-tuned model under `new_model`.
# If training or saving raises (including a keyboard interrupt from the
# notebook), close the W&B run with a non-zero exit code so it is not left
# open, then re-raise the original error.
try:
    trainer.train()

    # Save the fine-tuned model
    trainer.model.save_pretrained(new_model)
except BaseException:
    wandb.finish(exit_code=1)
    raise

# Re-enable use_cache (disabled for training) and switch to evaluation mode
model.config.use_cache = True
model.eval()

# Stop wandb
wandb.finish()