# How to Fine-Tune LLMs in 2024 with TRL

Large Language Models (LLMs) have seen a lot of progress in the last year. We went from no ChatGPT competitor to a whole zoo of LLMs, including Google's [Gemma](https://huggingface.co/blog/gemma), Meta AI's [Llama 2](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf), Mistral's [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) & [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) models, TII's [Falcon](https://huggingface.co/tiiuae/falcon-40b), and many more.
Those LLMs can be used for a variety of tasks, including chatbots, question answering, and summarization, without any additional training. However, if you want to customize a model for your application, you may need to fine-tune it on your data, either to achieve higher-quality results than prompting alone or to save cost by training a smaller, more efficient model.


This notebook walks you through how to fine-tune open LLMs using Hugging Face [TRL](https://huggingface.co/docs/trl/index), [Transformers](https://huggingface.co/docs/transformers/index) & [datasets](https://huggingface.co/docs/datasets/index) in 2024. In the notebook, we are going to:

1. Define our use case
2. Setup development environment
3. Create and prepare the dataset
4. Fine-tune LLM using `trl` and the `SFTTrainer`
5. Test and evaluate the LLM
6. Next steps

_Note: This notebook was created to run on consumer-size GPUs (16GB), e.g. the NVIDIA T4 available for free in Google Colab, but can be easily adapted to run on bigger GPUs._


## 1. Define our use case

When fine-tuning LLMs, it is important that you know your use case and the task you want to solve. This will help you choose the right model and create a dataset to fine-tune it. If you haven't defined your use case yet, you might want to go back to the drawing board.
Not all use cases require fine-tuning, and it is always recommended to evaluate and try out already fine-tuned models or API-based models before fine-tuning your own model.

As an example, we are going to use the following use case:

> We want to fine-tune the [Google Gemma 2 2B](https://huggingface.co/blog/gemma-july-update) model, which was primarily trained on English data, so that it can follow Japanese instructions better.


## 2. Setup development environment

Our first step is to install the Hugging Face libraries and PyTorch, including trl, transformers and datasets. If you haven't heard of trl yet, don't worry. It is a new library on top of transformers and datasets, which makes it easier to fine-tune, apply RLHF to, and align open LLMs.


In [None]:
# Install Pytorch & other libraries
!pip install --upgrade torchvision torch tensorboard

# Install Hugging Face libraries
!pip install --upgrade \
  transformers \
  datasets \
  accelerate \
  evaluate \
  bitsandbytes

# Install trl
!pip install --upgrade peft trl

If you are using a GPU with Ampere architecture (e.g. NVIDIA A10G or RTX 3090) or newer (e.g. RTX 4090), you can use Flash Attention. Flash Attention is a method that reorders the attention computation and leverages classical techniques (tiling, recomputation) to significantly speed it up and reduce memory usage from quadratic to linear in sequence length. This accelerates training by up to 3x. Learn more at [FlashAttention](https://github.com/Dao-AILab/flash-attention/tree/main).

_Note: If your machine has less than 96GB of RAM and lots of CPU cores, reduce the number of `MAX_JOBS`. On the `g5.2xlarge` we used `4`._

In [None]:
# import torch; assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'
# # install flash-attn
# !pip install ninja packaging
# !MAX_JOBS=4 pip install flash-attn --no-build-isolation

_Installing flash attention can take quite a bit of time (10-45 minutes)._
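
To make the "tiling" idea behind Flash Attention more concrete, here is a minimal pure-Python sketch for a single query vector. It is not the real kernel (which runs fused on the GPU and also uses recomputation in the backward pass), and it omits the usual scaling by the square root of the head dimension. The function names `naive_attention` and `tiled_attention` are made up for this illustration. The point is only that keys and values can be processed block by block with a running softmax, so the full score vector never has to be kept in memory.

```python
import math

def naive_attention(q, keys, values):
    """Reference: compute all scores at once, then softmax, then weighted sum."""
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]

def tiled_attention(q, keys, values, block_size=2):
    """Process keys/values in blocks, keeping only running softmax statistics."""
    dim = len(values[0])
    running_max = float("-inf")   # largest score seen so far
    running_sum = 0.0             # softmax denominator relative to running_max
    acc = [0.0] * dim             # unnormalized weighted sum of values
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale everything accumulated so far to the new maximum
        scale = math.exp(running_max - new_max)
        running_sum *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vd for a, vd in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]

# Toy data (made up): 5 keys/values of dimension 2
q = [1.0, 0.5]
keys = [[0.1, 0.2], [1.0, -1.0], [0.5, 0.5], [2.0, 0.0], [-1.0, 3.0]]
values = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0], [0.5, 0.5]]
print(naive_attention(q, keys, values))
print(tiled_attention(q, keys, values, block_size=2))
```

Both calls should print the same vector up to floating-point rounding, even though the tiled version only ever looks at `block_size` scores at a time.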

Accessing the Google Gemma model on Hugging Face requires login and consent, so make sure to put in the right token.

In [None]:
from huggingface_hub import login
from google.colab import userdata
import os

HUGGINGFACE_TOKEN = userdata.get('HUGGINGFACE_TOKEN')

login(
  token=HUGGINGFACE_TOKEN, # ADD YOUR TOKEN HERE
  add_to_git_credential=True
)

## 3. Create and prepare the dataset

Once you have determined that fine-tuning is the right solution, you need to create a dataset to fine-tune your model. The dataset should be a diverse set of demonstrations of the task you want to solve. There are several ways to create such a dataset, including:
* Using existing open-source datasets, e.g., [Spider](https://huggingface.co/datasets/spider)
* Using LLMs to create synthetic datasets, e.g., [Alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca)
* Using humans to create datasets, e.g., [Dolly](https://huggingface.co/datasets/databricks/databricks-dolly-15k)
* Using a combination of the above methods, e.g., [Orca](https://huggingface.co/datasets/Open-Orca/OpenOrca)

Each of the methods has its own advantages and disadvantages and depends on the budget, time, and quality requirements. For example, using an existing dataset is the easiest but might not be tailored to your specific use case, while using humans might be the most accurate but can be time-consuming and expensive. It is also possible to combine several methods to create an instruction dataset, as shown in [Orca: Progressive Learning from Complex Explanation Traces of GPT-4.](https://arxiv.org/abs/2306.02707)

In our case, Google Gemma was pre-trained on a multilingual corpus, but its instruction tuning was done entirely in English. The pretrained model retains some ability to understand and output Japanese, but it is clearly not good enough. So the idea is to use more Japanese instruction/answer pairs to fine-tune the pretrained model so that it can better follow Japanese instructions. Note that the [instruction tuned version of Google Gemma 2](https://huggingface.co/google/gemma-2-2b-it) also has a basic capability to follow instructions in Japanese, but we are not going to use the instruction tuned version.

We will use a Japanese dataset called [Mustain/JapaneseQADataset](https://huggingface.co/datasets/Mustain/JapaneseQADataset), which contains questions and human answers across several subjects.

The latest release of `trl` supports popular instruction and conversation dataset formats. This means we only need to convert our dataset to one of the supported formats and `trl` will take care of the rest. Those formats include:
* conversational format
```json
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
```
* instruction format

```json
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
```
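
The two formats carry the same information, so a record in the instruction format can be turned into the conversational one with a few lines of plain Python. The sketch below is only an illustration: the helper name `instruction_to_conversation` and the example record are made up, and the system message is optional.

```python
def instruction_to_conversation(record, system_message=None):
    """Map a {"prompt", "completion"} record to a {"messages": [...]} record."""
    messages = []
    if system_message:
        messages.append({"role": "system", "content": system_message})
    messages.append({"role": "user", "content": record["prompt"]})
    messages.append({"role": "assistant", "content": record["completion"]})
    return {"messages": messages}

# Hypothetical record, only for illustration
example = {"prompt": "What is the capital of Japan?", "completion": "Tokyo."}
converted = instruction_to_conversation(example, system_message="You are a helpful assistant.")
print(converted)
```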

In our example we are going to load our open-source dataset using the Hugging Face Datasets library and then convert it into the conversational format, where we include a Japanese system message for our assistant. We'll then save the dataset as JSON Lines files, which we can then use to fine-tune our model. We randomly downsample the dataset to 10,000 samples.

_Note: This step can be different for your use case. For example, if you already have a dataset in this format, e.g. from working with OpenAI, you can skip this step and go directly to the fine-tuning step._

In [None]:
from datasets import load_dataset

# Convert dataset to OAI messages
system_message = """あなたは知識豊富な人工知能アシスタントです。ユーザーは日本語であなたに質問をし、あなたは自身の知識に基づいて日本語で誠実に回答します。"""

def create_conversation(sample):
    # Split the text into Human and assistant parts
    parts = sample['text'].split('### ')

    # Extract the user question (Human part)
    user_content = parts[1].split(': ', 1)[1].strip()

    # Extract the assistant answer
    assistant_content = parts[2].split(': ', 1)[1].strip()

    return {
        "messages": [
            {"role": "system", "content": f'{system_message}'},
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": assistant_content}
        ]
    }

# Load dataset from the hub
dataset_dict = load_dataset("Mustain/JapaneseQADataset")
dataset = dataset_dict['train']
dataset = dataset.shuffle(seed=42)
dataset = dataset.select(range(10000))

print(create_conversation(dataset[0]))

# Convert dataset to OAI messages
dataset = dataset.map(create_conversation, batched=False)

# Split dataset into train and test
dataset = dataset.train_test_split(test_size=0.1)

# save datasets to disk
dataset["train"].to_json("train_dataset.json", orient="records")
dataset["test"].to_json("test_dataset.json", orient="records")
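
The parsing in `create_conversation` relies on each `text` field having a two-turn layout with `### ` markers and a `: ` after each speaker label. The sketch below shows that splitting logic in isolation on a made-up string, and returns `None` instead of raising when a sample does not match. Both the sample text and the helper name `split_human_assistant` are assumptions for illustration; check a few real rows of the dataset before relying on this layout.

```python
def split_human_assistant(text):
    """Split '### Human: ... ### Assistant: ...' into (question, answer).

    Returns None if the text does not have the expected two-turn layout.
    """
    parts = text.split("### ")
    if len(parts) < 3:
        return None
    try:
        question = parts[1].split(": ", 1)[1].strip()
        answer = parts[2].split(": ", 1)[1].strip()
    except IndexError:
        # A speaker label without ': ' after it
        return None
    return question, answer

# Hypothetical sample in the assumed layout
sample_text = "### Human: 日本の首都はどこですか? ### Assistant: 東京です。"
print(split_human_assistant(sample_text))
print(split_human_assistant("text without markers"))
```

A check like this can be used to filter out malformed rows before calling `dataset.map`.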

### Test the non-instruction tuned model

Let's do a test run on the non-tuned model and see what we get:

In [None]:
from datasets import load_dataset
from random import randint
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

model_id = "google/gemma-2-2b" # or bigger base models

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Load our test dataset
eval_dataset = load_dataset("json", data_files="test_dataset.json", split="train")

# Test on sample
rand_idx = 12 # Store index to compare output with fine-tuned model later
messages = eval_dataset[rand_idx]["messages"]
# Concatenate the system and user messages into a plain-text prompt
prompt = messages[0]["content"] + messages[1]["content"]
# Greedy decoding (do_sample=False); sampling parameters such as temperature or top_p would be ignored
outputs = pipe(prompt, max_new_tokens=256, do_sample=False, eos_token_id=pipe.tokenizer.eos_token_id, pad_token_id=pipe.tokenizer.pad_token_id)

print(f"Query:\n{prompt}")
print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")

Obviously this answer is not what we are looking for.

In [None]:
del model, tokenizer, pipe # free memory

## 4. Fine-tune LLM using `trl` and the `SFTTrainer`

We are now ready to fine-tune our model. We will use the [SFTTrainer](https://huggingface.co/docs/trl/sft_trainer) from `trl` to fine-tune our model. The `SFTTrainer` makes it straightforward to run supervised fine-tuning on open LLMs. The `SFTTrainer` is a subclass of the `Trainer` from the `transformers` library and supports all the same features, including logging, evaluation, and checkpointing, but adds additional quality-of-life features, including:
* Dataset formatting, including conversational and instruction format
* Training on completions only, ignoring prompts
* Packing datasets for more efficient training
* PEFT (parameter-efficient fine-tuning) support including Q-LoRA
* Preparing the model and tokenizer for conversational fine-tuning (e.g. adding special tokens)

We will use the dataset formatting, packing and PEFT features in our example. As our PEFT method we will use [QLoRA](https://arxiv.org/abs/2305.14314), a technique that uses quantization to reduce the memory footprint of large language models during fine-tuning without sacrificing performance. If you want to learn more about QLoRA and how it works, check out the [Making LLMs even more accessible with bitsandbytes, 4-bit quantization and QLoRA](https://huggingface.co/blog/4bit-transformers-bitsandbytes) blog post.
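
To give an intuition for what packing does, here is a minimal sketch in plain Python. It concatenates tokenized examples into one stream and cuts it into chunks of exactly `max_seq_length` tokens, so no compute is spent on padding. This is an illustration under simplifying assumptions, not `trl`'s actual implementation: the helper name `pack_sequences` is made up, the token ids are toy values, and an EOS separator is inserted between examples (in this notebook the chat template already terminates each message, so no extra separator is added).

```python
def pack_sequences(tokenized_examples, max_seq_length, eos_token_id):
    """Concatenate examples (separated by EOS) and cut into fixed-size chunks."""
    stream = []
    for ids in tokenized_examples:
        stream.extend(ids)
        stream.append(eos_token_id)
    # Drop the incomplete tail so every chunk has exactly max_seq_length tokens
    n_chunks = len(stream) // max_seq_length
    return [stream[i * max_seq_length:(i + 1) * max_seq_length] for i in range(n_chunks)]

# Toy token ids: three short "examples"
examples = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
packed = pack_sequences(examples, max_seq_length=4, eos_token_id=0)
print(packed)
```

Note that a single packed chunk can contain the end of one example and the start of the next, which is the trade-off packing makes for efficiency.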

Now, let's get started! 🚀

First, we need to load our dataset from disk.

In [None]:
from datasets import load_dataset

# Load jsonl data from disk
dataset = load_dataset("json", data_files="train_dataset.json", split="train")

Next, we will load our LLM. For our use case we are going to use Gemma 2 2B. Gemma is a family of lightweight, state-of-the-art open models from Google, built from the same research and technology used to create the Gemini models.
But we can easily swap out the model for another model, e.g. [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) or [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) models, TII [Falcon](https://huggingface.co/tiiuae/falcon-40b), or any other LLMs by changing our `model_id` variable. We will use bitsandbytes to quantize our model to 4-bit.

_Note: Be aware that the larger the model, the more memory it will require. In our example, we will use the 2B version, which, quantized to 4-bit, was fine-tuned here on a 16GB NVIDIA T4. If you have a GPU with less memory, you may need to use alternative techniques or a smaller model._


Correctly preparing the LLM and tokenizer for training chat/conversational models is crucial. We need to add new special tokens to the tokenizer and model and teach the model to understand the different roles in a conversation. In `trl` we have a convenient method called [setup_chat_format](https://huggingface.co/docs/trl/main/en/sft_trainer#add-special-tokens-for-chat-format), which:
* Adds special tokens to the tokenizer, e.g. `<|im_start|>` and `<|im_end|>`, to indicate the start and end of each message in a conversation.
* Resizes the model’s embedding layer to accommodate the new tokens.
* Sets the `chat_template` of the tokenizer, which is used to format the input data into a chat-like format. The default is `chatml` from OpenAI.
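
To see what the `chatml` format looks like once applied, the sketch below renders a list of messages by hand. It is a plain-Python approximation of the template for illustration only; in the notebook the tokenizer's `apply_chat_template` does this for you. The helper name `to_chatml` is made up.

```python
def to_chatml(messages, add_generation_prompt=False):
    """Render messages as ChatML text: each turn is wrapped in <|im_start|> ... <|im_end|>."""
    text = ""
    for m in messages:
        text += f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n"
    if add_generation_prompt:
        # Open an assistant turn so the model continues with its answer
        text += "<|im_start|>assistant\n"
    return text

chat = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
print(to_chatml(chat, add_generation_prompt=True))
```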

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
from trl import setup_chat_format

# Hugging Face model id
model_id = "google/gemma-2-2b" # or bigger base models

# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    # attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.padding_side = 'right' # to prevent warnings

# set chat template to OAI chatML, remove if you start from a fine-tuned model
model, tokenizer = setup_chat_format(model, tokenizer)



The `SFTTrainer` supports a native integration with `peft`, which makes it easy to efficiently tune LLMs using, e.g., QLoRA. We only need to create our `LoraConfig` and provide it to the trainer. Our `LoraConfig` parameters are defined based on the [QLoRA paper](https://arxiv.org/pdf/2305.14314.pdf) and Sebastian Raschka's [blog post](https://magazine.sebastianraschka.com/p/practical-tips-for-finetuning-llms).

In [None]:
from peft import LoraConfig

# LoRA config based on QLoRA paper & Sebastian Raschka experiment
peft_config = LoraConfig(
        lora_alpha=128,
        lora_dropout=0.05,
        r=256,
        bias="none",
        target_modules="all-linear",
        task_type="CAUSAL_LM",
)
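
To get a feel for what `r` and `lora_alpha` mean, the following back-of-the-envelope sketch counts the trainable parameters LoRA adds to one linear layer. LoRA learns two small matrices, `A` (shape `r x d_in`) and `B` (shape `d_out x r`), and scales their product by `lora_alpha / r`. The layer size used here (2048 x 2048) is a hypothetical value for illustration, not a number taken from the Gemma 2 configuration.

```python
def lora_trainable_params(d_in, d_out, r):
    """Parameters added by LoRA to one linear layer: A is (r x d_in), B is (d_out x r)."""
    return r * (d_in + d_out)

r, lora_alpha = 256, 128       # values from the LoraConfig above
d_in = d_out = 2048            # hypothetical layer size, for illustration only

full_params = d_in * d_out
lora_params = lora_trainable_params(d_in, d_out, r)
scaling = lora_alpha / r       # factor applied to the low-rank update B @ A

print(f"full layer: {full_params:,} parameters")
print(f"LoRA adapter: {lora_params:,} parameters ({lora_params / full_params:.0%} of the layer)")
print(f"scaling factor: {scaling}")
```

With a rank this high the adapter is not tiny relative to a single layer; lowering `r` reduces memory use and the number of trained parameters proportionally.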

Before we can start our training we need to define the hyperparameters (`TrainingArguments`) we want to use.

In [None]:
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="gemma-japanese",            # directory to save and repository id
    num_train_epochs=3,                     # number of training epochs
    per_device_train_batch_size=1,          # batch size per device during training
    gradient_accumulation_steps=2,          # number of micro-batches to accumulate gradients over before each optimizer update
    gradient_checkpointing=True,            # use gradient checkpointing to save memory
    optim="adamw_torch_fused",              # use fused adamw optimizer
    logging_steps=10,                       # log every 10 steps
    save_strategy="epoch",                  # save checkpoint every epoch
    learning_rate=2e-4,                     # learning rate, based on QLoRA paper
    bf16=True,                              # use bfloat16 precision; requires a GPU with bf16 support (set to False otherwise)
    tf32=True,                              # use tf32 precision; requires an Ampere or newer GPU (set to False otherwise)
    max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                      # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",           # use constant learning rate scheduler
    push_to_hub=True,                       # push model to hub
    report_to="tensorboard",                # report metrics to tensorboard
)
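
As a rough sanity check on these settings, the sketch below estimates the effective batch size and the number of optimizer steps. It assumes a single GPU and the 9,000 training samples left after the 90/10 split, and it ignores packing, which changes the number of sequences the trainer actually sees, so treat the result as an order-of-magnitude estimate rather than what the `Trainer` will report. The helper name `optimizer_steps` is made up.

```python
import math

def optimizer_steps(num_sequences, per_device_batch_size, grad_accum_steps, num_devices=1, epochs=1):
    """Return (effective batch size, approximate number of optimizer updates)."""
    effective_batch = per_device_batch_size * grad_accum_steps * num_devices
    steps_per_epoch = math.ceil(num_sequences / effective_batch)
    return effective_batch, steps_per_epoch * epochs

# 9,000 training samples, batch size 1, gradient accumulation 2, 3 epochs
effective_batch, total_steps = optimizer_steps(9000, 1, 2, epochs=3)
print(f"effective batch size: {effective_batch}")
print(f"approximate optimizer steps: {total_steps}")
```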

We now have every building block we need to create our `SFTTrainer` and start training our model.

In [None]:
from trl import SFTTrainer

max_seq_length = 1024 # max sequence length for model and packing of the dataset

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    peft_config=peft_config,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    packing=True,
    dataset_kwargs={
        "add_special_tokens": False,  # We template with special tokens
        "append_concat_token": False, # No need to add additional separator token
    }
)

Start training our model by calling the `train()` method on our `Trainer` instance. This will start the training loop and train our model for 3 epochs. Since we are using a PEFT method, we will only save the adapted model weights and not the full model.

In [None]:
from datetime import datetime

# Get and print the current time before training
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
print(f"Training started at: {current_time}")

# start training, the model will be automatically saved to the hub and the output directory
trainer.train()

# save model
trainer.save_model()

It took ~3.5 hrs on the free T4 GPU in Google Colab to finish training.

In [None]:
# free the memory again
del model
del trainer
torch.cuda.empty_cache()

### Merge LoRA adapter into the original model

When using QLoRA, we only train adapters and not the full model. This means that when saving the model during training we only save the adapter weights. If you want to save the full model, which makes it easier to use with Text Generation Inference, you can merge the adapter weights into the model weights using the `merge_and_unload` method and then save the model with the `save_pretrained` method. This will save a standalone model that can be used for inference without `peft`.

_Note: You might require > 30GB CPU Memory._

In [None]:
# Optional: merge the PEFT adapter into the base model (skip this cell if you only need the adapter)
from peft import AutoPeftModelForCausalLM

# Load PEFT model on CPU
model = AutoPeftModelForCausalLM.from_pretrained(
    args.output_dir,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
# Merge LoRA and base model and save
merged_model = model.merge_and_unload()
merged_model.save_pretrained(args.output_dir,safe_serialization=True, max_shard_size="2GB")
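
Conceptually, merging folds the low-rank update into the base weight: `W_merged = W + (lora_alpha / r) * B @ A`. The toy example below works this out with tiny hand-written matrices in plain Python. It is a sketch of the arithmetic only, with made-up numbers and made-up helper names (`matmul`, `merge_lora`); `merge_and_unload` does the equivalent on the real model weights.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def merge_lora(W, A, B, lora_alpha, r):
    """Return W + (lora_alpha / r) * (B @ A)."""
    scaling = lora_alpha / r
    delta = matmul(B, A)  # (d_out x r) @ (r x d_in) -> (d_out x d_in)
    return [[w + scaling * d for w, d in zip(w_row, d_row)]
            for w_row, d_row in zip(W, delta)]

# Toy 2x2 base weight with a rank-1 adapter
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0, 2.0]]        # r=1, d_in=2
B = [[0.5], [1.0]]      # d_out=2, r=1
merged = merge_lora(W, A, B, lora_alpha=2, r=1)
print(merged)
```

After merging, a single matrix multiplication with `merged` gives the same result as applying the base weight and the adapter separately, which is why the merged model no longer needs the adapter at inference time.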

## 5. Test Model and run Inference

After the training is done we want to evaluate and test our model. We will load a sample from the test dataset, generate an answer with the fine-tuned model, and compare it with the reference answer.

_Note: Evaluating Generative AI models is not a trivial task since one input can have multiple correct outputs._
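
If you want a rough automatic signal in addition to reading the outputs, one simple option is a character n-gram overlap score between the generated and the reference answer. The sketch below is a hypothetical helper, not part of this notebook's pipeline or of the `evaluate` library. It works on characters so that it needs no Japanese tokenizer, and it only measures surface overlap, so a correct answer phrased differently will still score low.

```python
from collections import Counter

def char_ngram_f1(prediction, reference, n=2):
    """F1 over character n-grams (whitespace removed); 1.0 means identical n-gram multisets."""
    def ngrams(text):
        text = "".join(text.split())
        return Counter(text[i:i + n] for i in range(len(text) - n + 1))

    pred, ref = ngrams(prediction), ngrams(reference)
    overlap = sum((pred & ref).values())  # n-grams shared by both, with multiplicity
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)

# Made-up strings for illustration
print(char_ngram_f1("東京です", "東京です"))
print(char_ngram_f1("東京です", "東京都です"))
print(char_ngram_f1("abc", "xyz"))
```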


In [None]:
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, pipeline

peft_model_id = args.output_dir

# Load Model with PEFT adapter
model = AutoPeftModelForCausalLM.from_pretrained(
  peft_model_id,
  device_map="auto",
  torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
# load into pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

Let's take the same sample from our test dataset and try to generate an answer.

In [None]:
from datasets import load_dataset
from random import randint

# Test on sample
prompt = pipe.tokenizer.apply_chat_template(eval_dataset[rand_idx]["messages"][:2], tokenize=False, add_generation_prompt=True)
# Greedy decoding (do_sample=False); sampling parameters such as temperature or top_p would be ignored
outputs = pipe(prompt, max_new_tokens=256, do_sample=False, eos_token_id=pipe.tokenizer.eos_token_id, pad_token_id=pipe.tokenizer.pad_token_id)

print(f"Query:\n{eval_dataset[rand_idx]['messages'][1]['content']}")
print(f"Original Answer:\n{eval_dataset[rand_idx]['messages'][2]['content']}")
print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")

Now our model is able to follow instructions in Japanese much better.

## 6. Next steps

In this tutorial we have only used a small dataset to fine-tune the Gemma 2 2B model. It is by no means enough to make it perfect. The 2B model is also small, so don't expect it to do all kinds of magic for you. The point is that you understand the process and can:

1. Scale it up (more data, more compute, a bigger model such as Gemma 2 9B) to build a more powerful model *or*
2. Add domain data so that the model can do well in your specific area of interest