# How to Fine-Tune Multimodal Models or VLMs with Hugging Face TRL

Multimodal LLMs have made tremendous progress recently. We now have a diverse ecosystem of powerful open multimodal models, mostly Vision-Language Models (VLMs), including Meta AI's [Llama-3.2-11B-Vision](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct), Mistral AI's [Pixtral-12B](https://huggingface.co/mistralai/Pixtral-12B-2409), Qwen's [Qwen2-VL-7B](https://huggingface.co/Qwen/Qwen2-VL-7B), and Allen AI's [Molmo-7B-D-0924](https://huggingface.co/allenai/Molmo-7B-D-0924).

These VLMs can handle a variety of multimodal tasks, including image captioning, visual question answering, and image-text matching without additional training. However, to customize a model for your specific application, you may need to fine-tune it on your data to achieve higher quality results or to create a more efficient model for your use case.

This blog post walks you through how to fine-tune open VLMs using Hugging Face [TRL](https://huggingface.co/docs/trl/index), [Transformers](https://huggingface.co/docs/transformers/index) & [datasets](https://huggingface.co/docs/datasets/index) in 2024. We'll cover:

1. Define our multimodal use case
2. Setup development environment
3. Create and prepare the multimodal dataset
4. Fine-tune VLM using `trl` and the `SFTTrainer`
5. Test and evaluate the VLM

_Note: This blog was created to run on consumer-size GPUs (24GB), e.g. NVIDIA A10G or RTX 4090/3090, but can be easily adapted to run on bigger GPUs._

## 1. Define our multimodal use case 

When fine-tuning VLMs, it's crucial to clearly define your use case and the multimodal task you want to solve. This will guide your choice of base model and help you create an appropriate dataset for fine-tuning. If you haven't defined your use case yet, you might want to revisit your requirements.

It's worth noting that for most use cases fine-tuning might not be the first option. We recommend evaluating pre-trained models or API-based solutions before committing to fine-tuning your own model.

As an example, we'll use the following multimodal use case:

> We want to fine-tune a model that can generate detailed product descriptions based on product images and basic metadata. This model will be integrated into our e-commerce platform to help sellers create more compelling listings. The goal is to reduce the time it takes to create product descriptions and improve their quality and consistency.

Existing models might already be very good for this use case, but you might want to tune one to your specific needs. This image-to-text generation task is well-suited for fine-tuning VLMs, as it requires understanding visual features and combining them with textual information to produce coherent and relevant descriptions. I created a test dataset for this use case using Gemini 1.5: [philschmid/amazon-product-descriptions-vlm](https://huggingface.co/datasets/philschmid/amazon-product-descriptions-vlm).

## 2. Setup development environment

Our first step is to install the Hugging Face libraries and PyTorch, including trl, transformers and datasets. If you haven't heard of trl yet, don't worry. It is a library on top of transformers and datasets that makes it easier to fine-tune, apply RLHF to, and align open LLMs.


In [None]:
# Install PyTorch & other libraries
%pip install "torch==2.4.0" tensorboard pillow

# Install Hugging Face libraries
%pip install  --upgrade \
  "transformers==4.44.2" \
  "datasets==2.21.0" \
  "accelerate==0.33.0" \
  "evaluate==0.4.2" \
  "bitsandbytes==0.43.3" \
  "trl==0.9.6" \
  "peft==0.12.0" 

If you are using a GPU with Ampere architecture (e.g. NVIDIA A10G or RTX 4090/3090) or newer, you can use Flash Attention. Flash Attention is a method that reorders the attention computation and leverages classical techniques (tiling, recomputation) to significantly speed it up and reduce memory usage from quadratic to linear in sequence length. The TL;DR: it accelerates training by up to 3x. Learn more at [FlashAttention](https://github.com/Dao-AILab/flash-attention/tree/main).
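
To get a feel for the "quadratic" part of that claim, the sketch below estimates the size of the attention score matrix that a standard implementation would materialize. The head count, sequence lengths, and 2-byte (bf16) element size are illustrative assumptions, not measurements from this setup.

```python
# Back-of-envelope estimate: memory for the full attention score matrix
# in a standard (non-fused) attention implementation.
# All numbers are illustrative assumptions, not measurements.

def attention_scores_bytes(seq_len: int, num_heads: int, bytes_per_elem: int = 2) -> int:
    """Bytes needed to store one (seq_len x seq_len) score matrix per head."""
    return seq_len * seq_len * num_heads * bytes_per_elem

for seq_len in (1024, 2048, 4096):
    gib = attention_scores_bytes(seq_len, num_heads=32) / 1024**3
    print(f"seq_len={seq_len:>5}: {gib:.4f} GiB per layer and batch element")
```

Doubling the sequence length quadruples this buffer. Flash Attention avoids materializing it in full by computing attention tile by tile.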

_Note: If your machine has less than 96GB of RAM and lots of CPU cores, reduce the number of `MAX_JOBS`. On the `g6.2xlarge` we used `4`._

In [None]:
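import torch; assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'
# install flash-attn; MAX_JOBS limits the number of parallel compile jobs (see note above)
!pip install ninja packaging
!MAX_JOBS=4 pip install flash-attn --no-build-isolation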

_Installing flash attention can take quite a bit of time (10-45 minutes)._

We will use the [Hugging Face Hub](https://huggingface.co/models) as a remote model versioning service. This means we will automatically push our model, logs and information to the Hub during training. You must register on [Hugging Face](https://huggingface.co/join) for this. Once you have an account, we will use the `login` util from the `huggingface_hub` package to log into our account and store our token (access key) on disk.



In [None]:
from huggingface_hub import login

login(
  token="", # ADD YOUR TOKEN HERE
  add_to_git_credential=True
)
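
Hard-coding a token in a notebook makes it easy to leak by accident. A small sketch of an alternative, assuming the token has been exported as the `HF_TOKEN` environment variable (the helper name `read_hf_token` is hypothetical and not part of `huggingface_hub`):

```python
import os

def read_hf_token(env_var: str = "HF_TOKEN"):
    """Return the token stored in the given environment variable, or None if unset or empty."""
    token = os.environ.get(env_var, "").strip()
    return token or None

token = read_hf_token()
if token is None:
    print("HF_TOKEN is not set; export it before calling login().")
# The value can then be passed on: login(token=token, add_to_git_credential=True)
```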


## 3. Create and prepare the dataset

Once you have determined that fine-tuning is the right solution we need to create a dataset to fine-tune our model. We have to prepare the dataset in a format that the model can understand. 

In our example we will use [philschmid/amazon-product-descriptions-vlm](https://huggingface.co/datasets/philschmid/amazon-product-descriptions-vlm), which contains 1,350 Amazon products with titles, images, descriptions and metadata. We want to fine-tune our model to generate product descriptions based on the images, title and metadata. Therefore we need to create a prompt that includes the title, metadata and image, with the description as the completion.

TRL supports popular instruction and conversation dataset formats. This means we only need to convert our dataset to one of the supported formats and `trl` will take care of the rest. 
```json
"messages": [
  {"role": "system", "content": [{"type": "text", "text": "You are a helpful...."}]},
  {"role": "user", "content": [
    {"type": "text", "text": "How many dogs are in the image?"},
    {"type": "image"}
  ]},
  {"role": "assistant", "content": [{"type": "text", "text": "There are 3 dogs in the image."}]}
],
```
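
Each entry in a `content` list is its own dictionary with a single `type`. The sketch below builds one conversation in plain Python and checks that shape. The helper name `build_messages` and the sample strings are made up for illustration.

```python
import json

def build_messages(system_text: str, user_text: str, assistant_text: str) -> list:
    """Build one conversation with a text part and an image placeholder in the user turn."""
    return [
        {"role": "system", "content": [{"type": "text", "text": system_text}]},
        {"role": "user", "content": [
            {"type": "text", "text": user_text},
            {"type": "image"},  # placeholder only; the pixels are supplied separately
        ]},
        {"role": "assistant", "content": [{"type": "text", "text": assistant_text}]},
    ]

messages = build_messages(
    "You are a helpful assistant.",
    "How many dogs are in the image?",
    "There are 3 dogs in the image.",
)

# Every content part must declare one of the two types used in this post
for turn in messages:
    for part in turn["content"]:
        assert part["type"] in {"text", "image"}

print(json.dumps(messages, indent=2))
```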

In our example we are going to load our dataset using the Datasets library, apply our prompt, and convert it into the conversational format.

Let's start by defining our instruction prompt.

In [None]:
# Note: the image is not part of the prompt text; it is passed in separately through the "processor"
prompt= """Create a Short Product description based on the provided ##PRODUCT NAME## and ##CATEGORY## and image. 
Only return description. The description should be SEO optimized and for a better mobile search experience.

##PRODUCT NAME##: {product_name}
##CATEGORY##: {category}"""
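
To see what the model receives as text, the template can be filled with a sample row. The product name and category below are invented for illustration and do not come from the dataset.

```python
# Same template as above, repeated so the snippet runs on its own
prompt = """Create a Short Product description based on the provided ##PRODUCT NAME## and ##CATEGORY## and image.
Only return description. The description should be SEO optimized and for a better mobile search experience.

##PRODUCT NAME##: {product_name}
##CATEGORY##: {category}"""

# Hypothetical values standing in for one dataset row
filled = prompt.format(
    product_name="Stainless Steel Water Bottle 750ml",
    category="Sports & Outdoors",
)
print(filled)
```

The `##...##` markers contain no braces, so `str.format` leaves them untouched and only replaces the two placeholders.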


Now, we can format our dataset.

In [None]:
from datasets import load_dataset

# System message used for every sample
system_message = "You are an expert product description writer for Amazon."

# Convert one dataset row to OAI-style messages
def format_data(sample):
  return {
    "messages": [
      {"role": "system", "content": [{"type": "text", "text": system_message}]},
      {"role": "user", "content": [
        {"type": "text", "text": prompt.format(product_name=sample["Product Name"], category=sample["Category"])},
        {"type": "image"},  # placeholder; the image itself is kept in the "image" column
      ]},
      {"role": "assistant", "content": [{"type": "text", "text": sample["answer"]}]}
    ],
    "image": sample["image"],
  }

# Load dataset from the hub
dataset_id = "philschmid/amazon-product-descriptions-vlm"
dataset = load_dataset(dataset_id, split="train")

# Convert dataset to OAI messages, dropping every original column except the image
remove_columns = [col for col in dataset.column_names if col != "image"]
dataset = dataset.map(format_data, remove_columns=remove_columns, batched=False)

# split="train" already returns a single Dataset, so we index it directly
print(dataset[345]["messages"])

# save dataset to disk
dataset.to_json("train_dataset.json", orient="records")

## 4. Fine-tune VLM using `trl` and the `SFTTrainer` 

We are now ready to fine-tune our model. We will use the [SFTTrainer](https://huggingface.co/docs/trl/sft_trainer) from `trl` to fine-tune our model. The `SFTTrainer` makes it straightforward to run supervised fine-tuning on open LLMs and VLMs. The `SFTTrainer` is a subclass of the `Trainer` from the `transformers` library and supports all the same features, including logging, evaluation, and checkpointing, but adds additional quality-of-life features, including:
* Dataset formatting, including conversational and instruction format
* Training on completions only, ignoring prompts
* Packing datasets for more efficient training
* PEFT (parameter-efficient fine-tuning) support including Q-LoRA
* Preparing the model and tokenizer for conversational fine-tuning (e.g. adding special tokens)
 
We will use the PEFT features in our example. As the PEFT method we will use [QLoRA](https://arxiv.org/abs/2305.14314), a technique that uses quantization to reduce the memory footprint of large language models during fine-tuning without sacrificing performance. If you want to learn more about QLoRA and how it works, check out the [Making LLMs even more accessible with bitsandbytes, 4-bit quantization and QLoRA](https://huggingface.co/blog/4bit-transformers-bitsandbytes) blog post.

Now, let's get started! 🚀
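
As a rough illustration of why 4-bit quantization matters on a 24GB GPU, the sketch below estimates the memory needed just to hold the weights of a 7-billion-parameter model. It ignores quantization constants, activations, gradients, and optimizer state, so real usage will be higher.

```python
def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Approximate memory for model weights alone, in GB (10^9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9

num_params = 7e9  # assumed round figure for a "7B" model
for label, bits in (("bf16", 16), ("int8", 8), ("nf4", 4)):
    print(f"{label}: ~{weight_memory_gb(num_params, bits):.1f} GB")
```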

First, we need to load our dataset from disk. 

In [None]:
from datasets import load_dataset

# Load jsonl data from disk
dataset = load_dataset("json", data_files="train_dataset.json", split="train")

Next, we will load our multimodal LLM. For our use case we are going to use the Qwen 2 VL 7B model.
However, we can easily swap it for another model, including Meta AI's [Llama-3.2-11B-Vision](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct), Mistral AI's [Pixtral-12B](https://huggingface.co/mistralai/Pixtral-12B-2409) or any other VLM, by changing our `model_id` variable. We will use bitsandbytes to quantize our model to 4-bit.

_Note: Be aware the bigger the model the more memory it will require. In our example we will use the 7B version, which can be tuned on 24GB GPUs._

Correctly preparing the LLM, Tokenizer and Processor is crucial for training VLMs. The Processor is responsible for including the special tokens and image features in the input.

In [None]:
import torch
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig

# Hugging Face model id
model_id = "Qwen/Qwen2-VL-7B-Instruct"

# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)

# Load model and processor
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config
)
processor = AutoProcessor.from_pretrained(model_id)

The `SFTTrainer` supports a native integration with `peft`, which makes it easy to efficiently tune LLMs using, e.g., QLoRA. We only need to create our `LoraConfig` and provide it to the trainer. Our `LoraConfig` parameters are defined based on the [QLoRA paper](https://arxiv.org/pdf/2305.14314.pdf) and Sebastian Raschka's [blog post](https://magazine.sebastianraschka.com/p/practical-tips-for-finetuning-llms).

In [None]:
from peft import LoraConfig

# LoRA config based on QLoRA paper & Sebastian Raschka experiment
peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.05,
        r=8,
        bias="none",
        target_modules="all-linear",
        task_type="CAUSAL_LM", 
)
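
To make the effect of `r` and `lora_alpha` concrete: LoRA adds two small matrices of shape `(d_in, r)` and `(r, d_out)` next to each targeted linear layer and scales their product by `lora_alpha / r`. The sketch below counts the added parameters for a single layer. The 4096 x 4096 size is an assumed example, not a dimension read from the model.

```python
def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters added by LoRA to one linear layer (matrices A and B)."""
    return r * d_in + r * d_out

d_in, d_out = 4096, 4096   # assumed layer size for illustration
r, lora_alpha = 8, 16      # values from the LoraConfig above

full = d_in * d_out
added = lora_param_count(d_in, d_out, r)
print(f"frozen weights: {full:,}")
print(f"LoRA weights:   {added:,} ({added / full:.2%} of the layer)")
print(f"scaling factor: {lora_alpha / r}")
```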

Before we can start our training we need to define the hyperparameters (`TrainingArguments`) we want to use and make sure our inputs are correctly provided to the model. Unlike text-only supervised fine-tuning, we need to provide the image to the model as well. Our template only signals to the model that an image is present, but we need to provide the content as well. Therefore we create a custom `DataCollator` which formats the inputs correctly and includes the image features.

In [None]:
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="code-llama-3-1-8b-text-to-sql", # directory to save and repository id
    num_train_epochs=3,                     # number of training epochs
    per_device_train_batch_size=1,          # batch size per device during training
    gradient_accumulation_steps=8,          # number of steps to accumulate gradients before an optimizer update
    gradient_checkpointing=True,            # use gradient checkpointing to save memory
    optim="adamw_torch_fused",              # use fused adamw optimizer
    logging_steps=10,                       # log every 10 steps
    save_strategy="epoch",                  # save checkpoint every epoch
    learning_rate=2e-4,                     # learning rate, based on QLoRA paper
    bf16=True,                              # use bfloat16 precision
    tf32=True,                              # use tf32 precision
    max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                      # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",           # use constant learning rate scheduler
    push_to_hub=True,                       # push model to hub
    report_to="tensorboard",                # report metrics to tensorboard
    remove_unused_columns=False,            # keep the "messages" and "image" columns for the custom collator
)

# Create a data collator to encode text and image pairs
def collate_fn(examples):
    # Get the texts and images, and apply the chat template
    texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
    images = [example["image"] for example in examples]  # column name set in format_data

    # Tokenize the texts and process the images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # The labels are the input_ids, and we mask the padding tokens in the loss computation
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    # Ignore the image token index in the loss computation (model specific)
    image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
    labels[labels == image_token_id] = -100
    batch["labels"] = labels

    return batch
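
The masking step is easier to follow on a toy example. The sketch below uses plain Python lists instead of tensors, and the token ids (`PAD_ID`, `IMAGE_ID`) are made-up values rather than ids from a real tokenizer.

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default
PAD_ID = 0           # hypothetical padding token id
IMAGE_ID = 9         # hypothetical image placeholder token id

def mask_labels(input_ids, ignored_ids):
    """Copy input_ids and replace every id in ignored_ids with IGNORE_INDEX."""
    return [
        [IGNORE_INDEX if tok in ignored_ids else tok for tok in row]
        for row in input_ids
    ]

batch_input_ids = [
    [5, 9, 9, 17, 23, 0, 0],   # shorter sequence, right-padded
    [5, 9, 9, 31, 42, 57, 8],
]
labels = mask_labels(batch_input_ids, {PAD_ID, IMAGE_ID})
print(labels)
```

The inputs stay unchanged. Only the labels are masked, so padding and image placeholder positions contribute nothing to the loss.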

We now have every building block we need to create our `SFTTrainer` and start training our model.

In [None]:
from trl import SFTTrainer

max_seq_length = 1024 # max sequence length for the model, we keep it short as the descriptions are short

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    data_collator=collate_fn,
    peft_config=peft_config,
    max_seq_length=max_seq_length,
    tokenizer=processor.tokenizer,
    dataset_text_field="",  # dummy value, the custom collator builds the inputs
    packing=False,          # packing concatenates text samples and cannot carry the images along
    dataset_kwargs={
        "skip_prepare_dataset": True,  # skip TRL's text-only preprocessing
    }
)

Start training our model by calling the `train()` method on our `Trainer` instance. This will start the training loop and train our model for 3 epochs. Since we are using a PEFT method, we will only save the adapted model weights and not the full model.
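
For a rough sense of scale, the number of optimizer updates follows from the dataset size and the batch settings above. The sketch assumes a single GPU and the 1,350 samples mentioned earlier. The exact count reported by the `Trainer` can differ slightly depending on how the last partial batch of each epoch is handled.

```python
import math

num_samples = 1350                 # dataset size stated earlier in this post
per_device_train_batch_size = 1
gradient_accumulation_steps = 8
num_gpus = 1                       # assumption: single-GPU training
num_train_epochs = 3

effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
updates_per_epoch = math.ceil(num_samples / effective_batch_size)
total_updates = updates_per_epoch * num_train_epochs

print(f"effective batch size: {effective_batch_size}")
print(f"approx. updates per epoch: {updates_per_epoch}")
print(f"approx. total updates: {total_updates}")
```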

In [None]:
# start training, the model will be automatically saved to the hub and the output directory
trainer.train()

# save model 
trainer.save_model(args.output_dir)

# push to hub
trainer.push_to_hub(dataset_name=dataset_id)
if trainer.accelerator.is_main_process:
    processor.push_to_hub(args.hub_model_id)

Training with Flash Attention for 3 epochs on a dataset of 10k samples took 02:05:58 on a `g6.2xlarge`. The instance costs `$1.212/h`, which brings us to a total cost of only `$1.8`.

In [None]:
# free the memory again
del model
del trainer
torch.cuda.empty_cache()

## 5. Test Model and run Inference

After the training is done we want to evaluate and test our model. We will load different samples from the original dataset and evaluate the model on those samples, using a simple loop and accuracy as our metric. 

_Note: Evaluating Generative AI models is not a trivial task since 1 input can have multiple correct outputs. If you want to learn more about evaluating generative models, check out [Evaluate LLMs and RAG a practical example using Langchain and Hugging Face](https://www.philschmid.de/evaluate-llm) blog post._



In [None]:
import torch
from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM

model_id = "./code-llama-3-1-8b-text-to-sql"

# Load Model with PEFT adapter
model = AutoModelForCausalLM.from_pretrained(
  model_id,
  device_map="auto",
  torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# load into pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

Let's load our test dataset and try to generate an answer.

In [None]:
from datasets import load_dataset 
from random import randint


# Load our test dataset
eval_dataset = load_dataset("json", data_files="test_dataset.json", split="train")
rand_idx = randint(0, len(eval_dataset) - 1)  # randint is inclusive on both ends

# Test on sample 
prompt = pipe.tokenizer.apply_chat_template(eval_dataset[rand_idx]["messages"][:2], tokenize=False, add_generation_prompt=True)
outputs = pipe(prompt, max_new_tokens=256, do_sample=False, temperature=0.1, top_k=50, top_p=0.1, eos_token_id=pipe.tokenizer.eos_token_id, pad_token_id=pipe.tokenizer.pad_token_id)

print(f"Query:\n{eval_dataset[rand_idx]['messages'][1]['content']}")
print(f"Original Answer:\n{eval_dataset[rand_idx]['messages'][2]['content']}")
print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")

Nice! Our model was able to generate a SQL query based on the natural language instruction. Let's evaluate our model on the full 2,500 samples of our test dataset.
_Note: As mentioned above, evaluating generative models is not a trivial task. In our example we used the accuracy of the generated SQL based on the ground truth SQL query as our metric. An alternative way could be to automatically execute the generated SQL query and compare the results with the ground truth. This would be a more accurate metric but requires more work to setup._

## Bonus: Use TRL example script

TRL provides a simple example script to fine-tune multimodal models. You can find the script [here](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py). The script can be directly run from the command line and supports all the features of the `SFTTrainer`.
```bash
# Tested on 8x H100 GPUs
accelerate launch \
    --config_file=examples/accelerate_configs/deepspeed_zero3.yaml \
    examples/scripts/sft_vlm.py \
    --dataset_name HuggingFaceH4/llava-instruct-mix-vsft \
    --model_name_or_path llava-hf/llava-1.5-7b-hf \
    --per_device_train_batch_size 8 \
    --gradient_accumulation_steps 8 \
    --output_dir sft-llava-1.5-7b-hf \
    --bf16 \
    --torch_dtype bfloat16 \
    --gradient_checkpointing
```