# How to Fine-Tune LLMs in 2024 with TRL

<a target="_blank" href="https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/beef48ac9a12718d62fa6180eaa46ebac5d825f1/training/fine-tune-gemma-with-trl-qlora.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>


Large Language Models (LLMs) have seen a lot of progress in the last year. We went from no ChatGPT competitor to a whole zoo of LLMs, including Meta AI's [Llama 3](https://huggingface.co/blog/llama31), Mistral AI's [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) & [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) models, TII's [Falcon](https://huggingface.co/tiiuae/falcon-40b), and many more. 
Those LLMs can be used for a variety of tasks, including chatbots, question answering, and summarization, without any additional training. However, if you want to customize a model for your application, you may need to fine-tune it on your data, either to achieve higher-quality results than prompting alone or to save cost by training a smaller, more efficient model.

 
This blog post walks you through how to fine-tune open LLMs using Hugging Face [TRL](https://huggingface.co/docs/trl/index) and [Transformers](https://huggingface.co/docs/transformers/index). You will learn:

1. What Q-LoRA is and the system requirements for fine-tuning a model
2. How to set up the development environment
3. How to create and prepare the dataset
4. How to fine-tune an LLM using `trl` and the `SFTTrainer`
5. How to test and evaluate the LLM

_Note: This blog was created to run on a free Google Colaboratory account using an NVIDIA T4 GPU with 16GB of memory, but it can easily be adapted to run on bigger GPUs and bigger models; see the memory requirements in section 1._


## 1. What is Q-LoRA and what are the system requirements?

[Quantized Low-Rank Adaptation (QLoRA)](https://arxiv.org/abs/2305.14314) has emerged as a popular method for efficiently fine-tuning LLMs, as it dramatically reduces computational resource requirements while maintaining high performance. In QLoRA, the pretrained model is quantized to 4-bit and its weights are frozen. Trainable adapter layers (LoRA) are then attached, and only the adapter layers are trained. Afterwards, the adapter weights can be merged into the base model or kept as a separate adapter. 
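To make the adapter idea concrete, here is a minimal pure-Python sketch of a LoRA-adapted linear layer, assuming the standard LoRA formulation h = W x + (alpha / r) * B A x with a frozen W and trainable A and B. The helper names are hypothetical, the 4-bit quantization of W is left out, and libraries such as `peft` do this with tensors rather than Python lists.

```python
def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def lora_forward(W, A, B, x, alpha, r):
    """Return W x + (alpha / r) * B (A x): frozen path plus scaled low-rank update."""
    base = matvec(W, x)                # frozen (in QLoRA: 4-bit) base weight
    update = matvec(B, matvec(A, x))   # trainable low-rank path
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]


def trainable_fraction(d_in, d_out, r):
    """Share of parameters LoRA trains for one d_out x d_in weight: r * (d_in + d_out) / (d_in * d_out)."""
    return r * (d_in + d_out) / (d_in * d_out)


# Toy example: 3x2 frozen weight with a rank-1 adapter
W = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
A = [[0.5, -0.5]]            # r x d_in, randomly initialized in practice
B = [[0.0], [0.0], [0.0]]    # d_out x r, initialized to zero
x = [1.0, 1.0]

print(lora_forward(W, A, B, x, alpha=16, r=1))  # [3.0, 7.0, 11.0], identical to W x
print(trainable_fraction(4096, 4096, r=16))     # 0.0078125, i.e. about 0.8%
```

Because B starts at zero, training begins from the unchanged base model, and for a square 4096 x 4096 projection a rank-16 adapter holds under 1% of the parameters of the full weight.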

The memory efficiency of QLoRA makes fine-tuning accessible across various hardware configurations. 

| Model Size | Minimum VRAM Requirement |
| :-- | :-- |
| 1B | 5 GB |
| 4B | 12 GB |
| 12B | 24 GB |
| 27B | 40 GB |
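As a rough sanity check on the table, this back-of-the-envelope sketch computes how much memory the base weights alone occupy, assuming exactly 4 bits per parameter after quantization. It ignores quantization constants, activations, LoRA weights, gradients, optimizer states, and CUDA overhead, which is why the table's minimums are well above these numbers.

```python
def weight_memory_gb(num_params, bits_per_param=4):
    """Approximate weight memory in GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


# Sizes from the table above, plus 8B (the model used later in this post)
for size_b in (1, 4, 8, 12, 27):
    four_bit = weight_memory_gb(size_b * 1e9)
    sixteen_bit = weight_memory_gb(size_b * 1e9, bits_per_param=16)
    print(f"{size_b}B params: ~{four_bit:.1f} GB in 4-bit vs ~{sixteen_bit:.1f} GB in 16-bit")
```

The 4x reduction relative to 16-bit weights is what lets, for example, an 8B model's weights (about 4 GB instead of 16 GB) fit next to activations and optimizer state on a single consumer GPU.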



## 2. Set up the development environment

The first step is to install the Hugging Face libraries, including `trl` and `transformers`. If you haven't heard of `trl` yet, don't worry. It is a library built on top of `transformers` and `datasets` that makes it easier to fine-tune, apply RLHF to, and align open LLMs. 


```python
# Install PyTorch & other libraries
%pip install "torch>=2.4.0" tensorboard

# Install Hugging Face libraries
%pip install --upgrade \
  "transformers==4.49.0" \
  "datasets==3.3.2" \
  "accelerate==1.4.0" \
  "evaluate==0.4.3" \
  "bitsandbytes==0.45.3" \
  "trl==0.15.2" \
  "peft==0.14.0"

# COMMENT IN: if you are running on a GPU that supports the BF16 data type and Flash Attention, e.g. L4
# %pip install flash-attn
```

_Note: If you are using a GPU with the Ampere architecture (e.g. NVIDIA A100) or newer (e.g. NVIDIA L4), you can use Flash Attention. Flash Attention is a method that reorders the attention computation and leverages classical techniques (tiling, recomputation) to significantly speed it up and reduce memory usage from quadratic to linear in sequence length. TL;DR: it can accelerate training by up to 3x. To use it, load the model with `attn_implementation="flash_attention_2"`. Learn more at [FlashAttention](https://github.com/Dao-AILab/flash-attention/tree/main)._
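If you want the same notebook to run on both a T4 and a newer GPU, one option is to choose the attention implementation at runtime. The helper below is hypothetical; it assumes FlashAttention-2 needs compute capability 8.0 or higher and falls back to `sdpa` (PyTorch's scaled dot-product attention), the two `attn_implementation` values that appear in this post.

```python
import importlib.util


def choose_attn_implementation(capability_major, flash_attn_installed):
    """Hypothetical helper returning an `attn_implementation` value for from_pretrained."""
    # FlashAttention-2 targets Ampere or newer GPUs (compute capability >= 8.0)
    if capability_major >= 8 and flash_attn_installed:
        return "flash_attention_2"
    # Otherwise fall back to PyTorch's scaled dot-product attention
    return "sdpa"


# Check whether the flash-attn package is importable, without importing it
flash_attn_installed = importlib.util.find_spec("flash_attn") is not None

# In the notebook, pass torch.cuda.get_device_capability()[0]; a T4 reports (7, 5)
print(choose_attn_implementation(7, flash_attn_installed))  # -> sdpa
```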

Before we can start training, we have to make sure that we have accepted the terms of the model to be able to use it. You can accept the license on Hugging Face by clicking the "Agree and access repository" button on the model page.

After you have accepted the license, you need a valid Hugging Face token to access the model. If you are running inside Google Colab, you can securely provide your Hugging Face token through Colab secrets; otherwise, you can set the token as an environment variable by running `os.environ["HF_TOKEN"] = "YOUR_TOKEN"`.

Make sure your token has write access too, as you will push your model to the Hub during training.

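```python
import os

try:
    # In Google Colab, read the token from the Colab secrets
    from google.colab import userdata
    os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")
except ImportError:
    # Outside Colab, export HF_TOKEN in your shell or replace the placeholder below
    os.environ.setdefault("HF_TOKEN", "OR_ADD_YOUR_TOKEN_HERE")
```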

## 3. Create and prepare the dataset

When fine-tuning LLMs, it is important that you know your use case and the task you want to solve. This will help you create a dataset to fine-tune your model. If you haven't defined your use case yet, you might want to go back to the drawing board. 

As an example, this blog focuses on the following use case:

> Fine-tune a model that can generate SQL queries based on a natural language instruction, which can then be integrated into a data analysis tool. The goal is to reduce the time it takes to create a SQL query and make it easier for non-technical users to create SQL queries. 

Text-to-SQL can be a good use case for fine-tuning LLMs, as it is a complex task that requires a lot of (internal) knowledge about the data and the SQL language. 

Once you have determined that fine-tuning is the right solution, you need to create a dataset to fine-tune your model. The dataset should be a diverse set of demonstrations of the task(s) you want to solve. There are several ways to create such a dataset, including:
* Using existing open-source datasets, e.g., [Spider](https://huggingface.co/datasets/spider)
* Using LLMs to create synthetic datasets, e.g., [Alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca)
* Using humans to create datasets, e.g., [Dolly](https://huggingface.co/datasets/databricks/databricks-dolly-15k)
* Using a combination of the above methods, e.g., [Orca](https://huggingface.co/datasets/Open-Orca/OpenOrca)

Each of these methods has its own advantages and disadvantages, and the right choice depends on your budget, time, and quality requirements. For example, using an existing dataset is the easiest but might not be tailored to your specific use case, while using humans might be the most accurate but can be time-consuming and expensive. It is also possible to combine several methods to create an instruction dataset, as shown in [Orca: Progressive Learning from Complex Explanation Traces of GPT-4.](https://arxiv.org/abs/2306.02707)


This example uses an existing dataset ([gretelai/synthetic_text_to_sql](https://huggingface.co/datasets/gretelai/synthetic_text_to_sql)), a high-quality synthetic Text-to-SQL dataset including natural language instructions, schema definitions, reasoning, and the corresponding SQL query.

[Hugging Face TRL](https://huggingface.co/docs/trl/en/index) supports automatic templating of conversational dataset formats. This means we only need to convert our dataset into the right JSON objects, and `trl` will take care of templating and putting it into the right format.

```json
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
```
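Under the hood, templating turns each `messages` list into a single string with role markers. The sketch below renders the ChatML layout, the format `trl`'s `setup_chat_format` installs by default; it is a simplified, hypothetical stand-in for `tokenizer.apply_chat_template`, and models that ship their own chat template (Llama 3.1 Instruct, for example) use different special tokens.

```python
def render_chatml(messages, add_generation_prompt=False):
    """Render a list of {"role", "content"} dicts in the ChatML layout."""
    text = ""
    for message in messages:
        text += f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>\n"
    if add_generation_prompt:
        # Open an assistant turn for the model to complete at inference time
        text += "<|im_start|>assistant\n"
    return text


sample = {"messages": [
    {"role": "system", "content": "You are..."},
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."},
]}

# Training: the full conversation, including the assistant answer
print(render_chatml(sample["messages"]))
# Inference: system + user turns, followed by an open assistant turn
print(render_chatml(sample["messages"][:2], add_generation_prompt=True))
```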

The [gretelai/synthetic_text_to_sql](https://huggingface.co/datasets/gretelai/synthetic_text_to_sql) dataset contains over 100k samples. To keep the example simple, we downsample it to 12,500 samples and split them into 10,000 training and 2,500 test samples. 

You can now use the 🤗 Datasets library to load the dataset and create a prompt template that combines the natural language instruction and the schema definition, and adds a system message for our assistant.

_Note: This step can be different for your use case. For example, if you already have a dataset, e.g. from working with OpenAI, you can skip this step and go directly to the fine-tuning step._

```python
from datasets import load_dataset

# System message for the assistant
system_message = """You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA."""

# User prompt that combines the user query and the schema
user_prompt = """Given the <USER_QUERY> and the <SCHEMA>, generate the corresponding SQL command to retrieve the desired data, considering the query's syntax, semantics, and schema constraints.

<SCHEMA>
{context}
</SCHEMA>

<USER_QUERY>
{question}
</USER_QUERY>
"""

def create_conversation(sample):
  # Map one dataset row to the OpenAI-style "messages" format
  return {
    "messages": [
      {"role": "system", "content": system_message},
      {"role": "user", "content": user_prompt.format(question=sample["sql_prompt"], context=sample["sql_context"])},
      {"role": "assistant", "content": sample["sql"]}
    ]
  }

# Load dataset from the hub
dataset = load_dataset("gretelai/synthetic_text_to_sql", split="train")
# Shuffle (seeded for reproducibility) and keep 12,500 samples
dataset = dataset.shuffle(seed=42).select(range(12500))

# Convert dataset to OAI messages, dropping the original columns
dataset = dataset.map(create_conversation, remove_columns=dataset.column_names, batched=False)
# split dataset into 10,000 training samples and 2,500 test samples
dataset = dataset.train_test_split(test_size=2500/12500)

print(dataset["train"][345]["messages"])

# save datasets to disk
dataset["train"].to_json("train_dataset.json", orient="records")
dataset["test"].to_json("test_dataset.json", orient="records")
```
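If you bring your own data (for example, an export from an OpenAI fine-tuning job), a quick structural check before training can save a failed run. The function below is a hypothetical sketch that assumes every record follows the single-turn system, user, assistant layout produced above; in the notebook you could run it over a split with `Dataset.filter` to list offending rows.

```python
EXPECTED_ROLES = ["system", "user", "assistant"]


def is_valid_conversation(record):
    """Hypothetical check: one system, user and assistant turn, in that order, all non-empty."""
    messages = record.get("messages")
    if not isinstance(messages, list):
        return False
    if [message.get("role") for message in messages] != EXPECTED_ROLES:
        return False
    # Every turn needs non-empty string content
    return all(isinstance(m.get("content"), str) and m["content"].strip() != "" for m in messages)


# Made-up records for illustration
good = {"messages": [
    {"role": "system", "content": "You are a text to SQL query translator."},
    {"role": "user", "content": "How many rows are in the sales table?"},
    {"role": "assistant", "content": "SELECT COUNT(*) FROM sales;"},
]}
missing_answer = {"messages": good["messages"][:2]}

print(is_valid_conversation(good))            # True
print(is_valid_conversation(missing_answer))  # False
```

Multi-turn datasets need a looser check, for example allowing alternating user and assistant turns after an optional system message.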

## 4. Fine-tune LLM using `trl` and the `SFTTrainer` 

You are now ready to fine-tune your model. Hugging Face TRL's [SFTTrainer](https://huggingface.co/docs/trl/sft_trainer) makes it straightforward to run supervised fine-tuning of open LLMs. The `SFTTrainer` is a subclass of the `Trainer` from the `transformers` library and supports all the same features, including logging, evaluation, and checkpointing, but adds additional quality-of-life features, including:
* Dataset formatting, including conversational and instruction format
* Training on completions only, ignoring prompts
* Packing datasets for more efficient training
* PEFT (parameter-efficient fine-tuning) support including Q-LoRA
* Preparing the model and tokenizer for conversational fine-tuning (e.g. adding special tokens)
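"Training on completions only" boils down to label masking. In the sketch below (hypothetical token ids, hypothetical helper name), prompt positions get the label -100, the default `ignore_index` of PyTorch's `CrossEntropyLoss`, so only the assistant's tokens contribute to the loss. Hugging Face causal LM models shift the labels by one position internally, which is why `input_ids` and `labels` stay aligned here. As configured later in this post, no completion-only collator is passed to the trainer, so the loss covers the whole sequence.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_labels(prompt_ids, completion_ids):
    """Concatenate prompt and completion; mask the prompt part of the labels."""
    input_ids = prompt_ids + completion_ids
    labels = [IGNORE_INDEX] * len(prompt_ids) + completion_ids
    return input_ids, labels


# Made-up token ids for a prompt and its SQL completion
input_ids, labels = build_labels([101, 7, 8, 9], [42, 43, 2])
print(input_ids)  # [101, 7, 8, 9, 42, 43, 2]
print(labels)     # [-100, -100, -100, -100, 42, 43, 2]
```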

First, load the prepared dataset from disk. 

```python
from datasets import load_dataset

# Load jsonl data from disk
dataset = load_dataset("json", data_files="train_dataset.json", split="train")
```

This example uses Llama 3.1 8B; you can swap it out for another model by changing the `model_id` variable. We will use bitsandbytes to quantize the model to 4-bit.

_Note: Be aware that the bigger the model, the more memory it requires. In our example we use the 8B version, which can be tuned on 24GB GPUs. If you have a smaller GPU, choose a smaller model (see the table in section 1)._

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from trl import setup_chat_format

# Hugging Face model id
model_id = "meta-llama/Meta-Llama-3.1-8B" # or `mistralai/Mistral-7B-v0.1`

# Use bfloat16 on GPUs that support it (compute capability >= 8.0, i.e. Ampere or newer)
if torch.cuda.get_device_capability()[0] >= 8:
    torch_dtype = torch.bfloat16
else:
    torch_dtype = torch.float16

# define model kwargs
model_kwargs = dict(
    attn_implementation="sdpa", # What attention implementation to use, alternative to flash_attention_2
    torch_dtype=torch_dtype, # Dtype for the non-quantized layers
    use_cache=False, # Disable the KV cache, as we use gradient checkpointing
    device_map="auto", # Let accelerate place the model on the available devices
)

# BitsAndBytesConfig int-4 config
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,                 # store the base weights in 4-bit
    bnb_4bit_use_double_quant=True,    # also quantize the quantization constants
    bnb_4bit_quant_type="nf4",         # NormalFloat4 data type from the QLoRA paper
    bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
    bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
)
# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Base models may ship without a chat template; if so, add the ChatML template and special tokens
if tokenizer.chat_template is None:
    model, tokenizer = setup_chat_format(model, tokenizer)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
```
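`bnb_4bit_use_double_quant=True` enables the double quantization described in the QLoRA paper, which quantizes the per-block quantization constants themselves. The worked arithmetic below reproduces the paper's estimate, assuming its block sizes (one 32-bit constant per 64 weights, and one 32-bit constant per 256 of the 8-bit second-level constants) and roughly 8B parameters.

```python
BLOCK_1 = 64   # weights per first-level quantization block
BLOCK_2 = 256  # first-level constants per second-level block

# Without double quantization: one 32-bit constant per 64 weights
single_quant_bits = 32 / BLOCK_1
# With double quantization: 8-bit constants, plus one 32-bit constant per 256 of them
double_quant_bits = 8 / BLOCK_1 + 32 / (BLOCK_1 * BLOCK_2)
saved_bits = single_quant_bits - double_quant_bits

num_params = 8e9  # assumption: roughly 8B parameters, as for Llama 3.1 8B
saved_gb = num_params * saved_bits / 8 / 1e9

print(f"{single_quant_bits:.3f} -> {double_quant_bits:.3f} bits of constants per weight")  # 0.500 -> 0.127
print(f"~{saved_gb:.2f} GB saved")  # ~0.37 GB saved
```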

The `SFTTrainer` supports a native integration with `peft`, which makes it easy to efficiently tune LLMs using QLoRA. You only need to create a `LoraConfig` and provide it to the trainer.

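```python
from peft import LoraConfig

peft_config = LoraConfig(
        lora_alpha=16,                # scaling factor; the update is scaled by lora_alpha / r
        lora_dropout=0.05,            # dropout applied to the adapter inputs
        r=16,                         # rank of the low-rank update matrices
        bias="none",                  # do not train bias terms
        target_modules="all-linear",  # add adapters to every linear layer except the output layer
        task_type="CAUSAL_LM",
        modules_to_save=["lm_head", "embed_tokens"], # fully train and save these (needed when special tokens are added); you might need to change this for qwen or other models
)
```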
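To see how small the adapters are, here is a back-of-the-envelope count of the parameters this config trains. It assumes the published Llama 3.1 8B architecture (hidden size 4096, MLP size 14336, 32 layers, 8 key/value heads of dimension 128, a vocabulary of 128,256 tokens, untied input and output embeddings) and ignores any embedding resizing done by `setup_chat_format`; check your model's `config.json` if you swap models.

```python
HIDDEN, INTERMEDIATE, LAYERS = 4096, 14336, 32
KV_DIM = 8 * 128   # 8 key/value heads x head dimension 128
VOCAB = 128_256
R = 16             # matches r=16 in the LoraConfig above

# (in_features, out_features) of the linear layers "all-linear" targets in one decoder layer
projections = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

# Each adapted layer gets A (r x in) and B (out x r): r * (in + out) parameters
lora_per_layer = sum(R * (d_in + d_out) for d_in, d_out in projections.values())
lora_total = lora_per_layer * LAYERS

# modules_to_save trains full copies of embed_tokens and lm_head (vocab x hidden each)
modules_to_save_total = 2 * VOCAB * HIDDEN

print(f"LoRA adapters:   {lora_total:,}")             # 41,943,040
print(f"modules_to_save: {modules_to_save_total:,}")  # 1,050,673,152
```

Under these assumptions, the fully trained embedding and output layers account for far more trainable parameters than the adapters themselves, so dropping `modules_to_save` is worth considering when your tokenizer already has a chat template and no new tokens are added.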

Before you can start your training, you need to define the hyperparameters (`SFTConfig`) you want to use.

```python
from trl import SFTConfig

args = SFTConfig(
    output_dir="code-llama-3-1-8b-text-to-sql", # directory to save and repository id
    max_seq_length=1024,                    # max sequence length for model and packing of the dataset
    packing=True,                           # groups multiple samples in the dataset into a single sequence
    num_train_epochs=3,                     # number of training epochs
    per_device_train_batch_size=1,          # batch size per device during training
    gradient_accumulation_steps=4,          # forward/backward passes to accumulate before each optimizer update
    gradient_checkpointing=True,            # use gradient checkpointing to save memory
    optim="adamw_torch_fused",              # use fused adamw optimizer
    logging_steps=10,                       # log every 10 steps
    save_strategy="epoch",                  # save checkpoint every epoch
    learning_rate=2e-4,                     # learning rate, based on QLoRA paper
    fp16=torch_dtype == torch.float16,      # float16 mixed precision on GPUs without bf16 support (e.g. T4)
    bf16=torch_dtype == torch.bfloat16,     # bfloat16 mixed precision on Ampere or newer GPUs
    max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                      # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",           # use constant learning rate scheduler
    push_to_hub=True,                       # push model to hub
    report_to="tensorboard",                # report metrics to tensorboard
    dataset_kwargs={
        "add_special_tokens": False,  # We template with special tokens
        "append_concat_token": False, # No need to add additional separator token
    },
)
```

We now have every building block we need to create our `SFTTrainer` and start training our model.

```python
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    peft_config=peft_config,
    processing_class=tokenizer,  # replaces the older `tokenizer` argument
)
```
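With `packing=True` and `gradient_accumulation_steps=4`, the number of sequences the model sees per update differs from the raw dataset size. The sketch below is a simplified, hypothetical illustration rather than TRL's implementation: it joins tokenized samples with an EOS separator, cuts the stream into `max_seq_length` chunks, and roughly estimates the number of optimizer updates. The 3,000 packed sequences in the example are made up.

```python
import math


def pack(tokenized_samples, max_seq_length, eos_id):
    """Simplified packing: join samples with an EOS separator, then cut fixed-size chunks."""
    stream = [tok for sample in tokenized_samples for tok in sample + [eos_id]]
    # The final chunk may be shorter than max_seq_length
    return [stream[i:i + max_seq_length] for i in range(0, len(stream), max_seq_length)]


def optimizer_steps(num_sequences, per_device_batch, grad_accum, num_devices, epochs):
    """Rough number of optimizer updates for a training run."""
    # One update happens after grad_accum forward/backward passes on every device
    effective_batch = per_device_batch * grad_accum * num_devices
    return math.ceil(num_sequences / effective_batch) * epochs


# Toy token ids: three samples packed into chunks of 4 tokens, EOS id 0
print(pack([[1, 2, 3], [4, 5], [6, 7, 8, 9]], max_seq_length=4, eos_id=0))
# [[1, 2, 3, 0], [4, 5, 0, 6], [7, 8, 9, 0]]

# Hypothetical: 3,000 packed sequences with this post's settings on one GPU
print(optimizer_steps(3000, per_device_batch=1, grad_accum=4, num_devices=1, epochs=3))  # 2250
```

Note how sample boundaries fall inside chunks: in this simple form, tokens can attend to the end of the previous sample, a trade-off basic packing accepts in exchange for wasting no compute on padding.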

Start training by calling the `train()` method.

```python
# start training; checkpoints are saved to the output directory and pushed to the Hub (push_to_hub=True)
trainer.train()

# save the final model to the output directory and push it to the Hub
trainer.save_model()
```

### Merge LoRA adapter into the original model

When using QLoRA, you only train adapters and not the full model. This means that when saving the model during training, you only save the adapter weights and not the full model. If you want to save the full model, which makes it easier to use with serving stacks like vLLM or TGI, you can merge the adapter weights into the model weights using the `merge_and_unload` method and then save the model with the `save_pretrained` method. This saves a standalone model that can be used for inference.

_Note: This requires > 30GB CPU Memory._

```python
#### COMMENT IN TO MERGE PEFT AND BASE MODEL ####
# from peft import AutoPeftModelForCausalLM

# # Load PEFT model on CPU
# model = AutoPeftModelForCausalLM.from_pretrained(
#     args.output_dir,
#     torch_dtype=torch.float16,
#     low_cpu_mem_usage=True,
# )
# # Merge LoRA and base model and save
# merged_model = model.merge_and_unload()
# merged_model.save_pretrained(args.output_dir, safe_serialization=True, max_shard_size="2GB")
```
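Merging is exact up to floating-point rounding because W x + (alpha / r) * B (A x) equals (W + (alpha / r) * B A) x. The pure-Python sketch below, with hypothetical helper names and toy matrices, checks that identity; `merge_and_unload` does the equivalent for every adapted layer on real tensors, which is one reason the merge step loads the base model in float16 rather than in its quantized form.

```python
def matvec(M, v):
    """Matrix (list of rows) times vector."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def matmul(X, Y):
    """Matrix product of two lists-of-rows."""
    return [[sum(X[i][k] * Y[k][j] for k in range(len(Y))) for j in range(len(Y[0]))]
            for i in range(len(X))]


def lora_output(W, A, B, x, alpha, r):
    """Unmerged path: W x + (alpha / r) * B (A x)."""
    scale = alpha / r
    return [w + scale * u for w, u in zip(matvec(W, x), matvec(B, matvec(A, x)))]


def merge(W, A, B, alpha, r):
    """Merged weight: W + (alpha / r) * B A."""
    scale = alpha / r
    delta = matmul(B, A)  # (d_out x r) @ (r x d_in) -> d_out x d_in
    return [[w + scale * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(W, delta)]


W = [[1.0, 2.0], [3.0, 4.0]]  # frozen base weight
A = [[0.5, 0.25]]             # r = 1
B = [[2.0], [-1.0]]           # already trained, so non-zero
x = [4.0, 8.0]

print(lora_output(W, A, B, x, alpha=2, r=1))    # [36.0, 36.0]
print(matvec(merge(W, A, B, alpha=2, r=1), x))  # [36.0, 36.0]
```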

## 5. Test Model and run Inference

After the training is done you'll want to evaluate and test your model. You can load different samples from the test dataset and evaluate the model on those samples.

_Note: Evaluating generative AI models is not a trivial task, since one input can have multiple correct outputs. This example only focuses on manual evaluation and vibe checks._

```python
import torch
from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM

model_id = "./code-llama-3-1-8b-text-to-sql"

# Load Model with PEFT adapter
model = AutoModelForCausalLM.from_pretrained(
  model_id,
  device_map="auto",
  torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# load into pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
```

Let's load the test dataset and try to generate a SQL query for a random sample.

```python
from datasets import load_dataset
from random import randint
import re

# Load our test dataset
eval_dataset = load_dataset("json", data_files="test_dataset.json", split="train")
rand_idx = randint(0, len(eval_dataset) - 1)  # randint includes both bounds

# Test on sample: keep only the system and user messages and let the model answer
prompt = pipe.tokenizer.apply_chat_template(eval_dataset[rand_idx]["messages"][:2], tokenize=False, add_generation_prompt=True)
# Greedy decoding (do_sample=False), so sampling parameters such as temperature are not needed
outputs = pipe(prompt, max_new_tokens=256, do_sample=False, eos_token_id=pipe.tokenizer.eos_token_id, pad_token_id=pipe.tokenizer.pad_token_id)

# Extract the user query (re.DOTALL lets "." match the newlines inside the tags) and original answer
user_query = re.search(r"<USER_QUERY>(.*?)</USER_QUERY>", eval_dataset[rand_idx]["messages"][1]["content"], re.DOTALL).group(1).strip()
print(f"Query:\n{user_query}")
print(f"Original Answer:\n{eval_dataset[rand_idx]['messages'][2]['content']}")
print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")
```

Nice! Your model was able to generate a SQL query based on the natural language instruction. 


_Note: As mentioned above, evaluating generative models is not a trivial task. In this example, you can use the accuracy of the generated SQL compared to the ground-truth SQL query as your metric. An alternative would be to automatically execute the generated SQL query and compare its results with those of the ground-truth query. This would be a more accurate metric but requires more work to set up._
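Both metrics from the note can be sketched with the standard library. The helpers below are hypothetical: `exact_match` compares crudely normalized query strings, while `execution_match` builds an in-memory SQLite database and compares result rows as multisets (so row order is ignored). This assumes the schema context contains SQLite-compatible `CREATE TABLE` and `INSERT` statements, which will not hold for every sample; the sales table below is made up.

```python
import re
import sqlite3
from collections import Counter


def normalize_sql(query):
    """Crude normalization: trim, drop a trailing semicolon, collapse whitespace, lowercase."""
    return re.sub(r"\s+", " ", query.strip().rstrip(";").strip()).lower()


def exact_match(predicted, reference):
    return normalize_sql(predicted) == normalize_sql(reference)


def execution_match(schema_sql, predicted, reference):
    """Run both queries on an in-memory SQLite database; invalid predicted SQL counts as a miss."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.executescript(schema_sql)
        expected = Counter(conn.execute(reference).fetchall())
        try:
            actual = Counter(conn.execute(predicted).fetchall())
        except sqlite3.Error:
            return False
        return actual == expected
    finally:
        conn.close()


# Made-up sample in the style of a schema context and a reference query
schema = """
CREATE TABLE sales (id INTEGER, region TEXT, amount REAL);
INSERT INTO sales VALUES (1, 'EU', 10.0), (2, 'US', 5.0), (3, 'EU', 2.5);
"""
reference = "SELECT SUM(amount) FROM sales WHERE region = 'EU';"
predicted = "SELECT sum(amount) FROM sales WHERE region='EU'"

print(exact_match(predicted, reference))              # False: spacing around "=" differs
print(execution_match(schema, predicted, reference))  # True: both return 12.5
```

The example shows why execution-based comparison is the more forgiving and more accurate metric: two queries that differ only in formatting still count as correct.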