# Fine-tune ColPali 🛠️

[![Colab](https://img.shields.io/badge/Open_in_Colab-F9AB00?logo=googlecolab&logoColor=fff&style=for-the-badge)](https://colab.research.google.com/github/tonywu71/colpali-cookbooks/blob/main/examples/finetune_colpali.ipynb)
[![GitHub](https://img.shields.io/badge/ColPali_Cookbooks-100000?style=for-the-badge&logo=github&logoColor=white)](https://github.com/tonywu71/colpali-cookbooks)
[![arXiv](https://img.shields.io/badge/arXiv-2407.01449-b31b1b.svg?style=for-the-badge)](https://arxiv.org/abs/2407.01449)
[![Hugging Face](https://img.shields.io/badge/Vidore-FFD21E?style=for-the-badge&logo=huggingface&logoColor=000)](https://huggingface.co/vidore)
[![X](https://img.shields.io/badge/Thread-%23000000?style=for-the-badge&logo=X&logoColor=white)](https://x.com/tonywu_71/status/1839281156811874515)

## Introduction

With our new model *ColPali*, we propose to leverage VLMs to construct efficient multi-vector embeddings in the visual space for document retrieval. By feeding the ViT output patches from PaliGemma-3B to a linear projection, we create a multi-vector representation of documents. We train the model to maximize the similarity between these document embeddings and the query embeddings, following the ColBERT method.

ColPali replaces potentially complex and brittle layout-recognition and OCR pipelines with a single model that accounts for both the textual and visual content (layout, charts, etc.) of a document.

![ColPali Architecture](https://github.com/tonywu71/colpali-cookbooks/blob/main/assets/architecture/colpali_architecture.jpeg?raw=true)
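To make the ColBERT-style scoring mentioned above concrete, here is a minimal pure-Python sketch of late interaction (often called MaxSim). Each query token embedding is matched with its most similar document patch embedding, and these maxima are summed into a single relevance score. The helper names are hypothetical and the vectors are toy 2-dimensional values; real ColPali embeddings are higher-dimensional and scored in batches with PyTorch.

```python
from typing import List, Sequence

Vector = Sequence[float]


def dot(u: Vector, v: Vector) -> float:
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def maxsim_score(query_embeddings: Sequence[Vector], doc_embeddings: Sequence[Vector]) -> float:
    """Late interaction: for each query token, keep its best-matching document patch,
    then sum these maxima over all query tokens."""
    return sum(max(dot(q, d) for d in doc_embeddings) for q in query_embeddings)


def rank_documents(query_embeddings: Sequence[Vector], corpus: Sequence[Sequence[Vector]]) -> List[int]:
    """Return document indices sorted by decreasing late-interaction score."""
    scores = [maxsim_score(query_embeddings, doc) for doc in corpus]
    return sorted(range(len(corpus)), key=lambda i: scores[i], reverse=True)


# Toy example: a query with 2 token embeddings, documents with 3 and 2 patch embeddings.
query = [[1.0, 0.0], [0.0, 1.0]]
doc_a = [[0.9, 0.1], [0.2, 0.8], [0.0, 0.0]]
doc_b = [[0.5, 0.5], [0.1, 0.1]]

print(rank_documents(query, [doc_a, doc_b]))  # [0, 1]: doc_a ranks first
```

Because each query token only needs its single best patch, a page can score highly even when the relevant content occupies a small region of it.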

This notebook guides you through fine-tuning ColPali to improve its retrieval performance on the data distribution of your use case. In particular, we will fine-tune ColPali on [VDSID-French](https://huggingface.co/datasets/vidore/vdsid_french), a French-language document retrieval dataset.

## What if I want to use my own documents to fine-tune ColPali?

If you are a company, you probably want to fine-tune ColPali on your own documents. However, your documents probably lack the queries needed to train a visual retrieval model. Fear not: Daniel van Strien has published an awesome [🤗 blog post](https://danielvanstrien.xyz/posts/post-with-code/colpali/2024-09-23-generate_colpali_dataset.html) on how to use VLMs to generate quality queries for your PDFs and build a dataset that you can use for fine-tuning.

## Hardware Requirements

This notebook was tested on a GCP VM with an A100-40GB GPU. I recommend setting up this VM with [SkyPilot](https://github.com/skypilot-org/skypilot) using this [config](https://github.com/tonywu71/colpali-cookbooks/blob/main/skypilot/a100/config.yaml). You should also be able to run it on a smaller GPU, but you'll need a stronger quantization strategy and a smaller batch size.

```python
# ==========================     USER INPUTS     ==========================

# Define the name used for the model you will push to the HuggingFace Hub.
# Leave it empty to disable pushing the model.
hf_pushed_model_name = "tonywu71/finetune_colpali_v1_2-vdsid_french-4bit"

# Define the name used for the WandB experiment. Leave it empty to disable WandB logging.
# In particular, leave it empty if you don't have a WandB account.
wandb_experiment_name = "finetune_colpali_v1_2-vdsid_french-4bit"

# =========================================================================

if not wandb_experiment_name:
    print("WandB logging is disabled.")
```

## Installation

This notebook leverages [`colpali-engine`](https://github.com/illuin-tech/colpali), the official implementation of ColPali. This package also contains the training code (processor, collator, trainer...) for fine-tuning ColPali on your own dataset.

```python
!pip install -q -U "colpali-engine[train]>=0.3.0,<0.4.0"
```

## Login to a HuggingFace account

Because ColPali uses the [PaliGemma-3B](https://huggingface.co/google/paligemma-3b-mix-448) checkpoints, you need to accept its terms and conditions before using it. Once accepted, use the following cell to log in to your HuggingFace account.

```python
!pip install -q -U huggingface_hub
from huggingface_hub import login

login()
```

## Login to Weights & Biases (optional)

You can use Weights & Biases to log the training process. This step is optional.

```python
if wandb_experiment_name:
    !pip install -q -U wandb
    import wandb

    wandb.login()
```

```text
wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: tonywu_71. Use `wandb login --relogin` to force relogin
```


## Imports

```python
from pathlib import Path
from typing import cast

import torch
from colpali_engine.collators.visual_retriever_collator import VisualRetrieverCollator
from colpali_engine.loss import ColbertPairwiseCELoss
from colpali_engine.models import ColPali, ColPaliProcessor
from colpali_engine.trainer.contrastive_trainer import ContrastiveTrainer
from colpali_engine.utils.torch_utils import get_torch_device, tear_down_torch
from datasets import DatasetDict, load_dataset
from peft import LoraConfig
from torch import nn
from transformers import BitsAndBytesConfig, TrainerCallback, TrainingArguments


def print_trainable_parameters(model: nn.Module) -> None:
    """
    Print the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params:,} || all params: {all_param:,} || trainable%: {100 * trainable_params / all_param}"
    )
```

## Choose a quantization strategy

ColPali is quite a large model with 3B parameters. While you can load the model and run inference on an L4 GPU or an M1+ Mac (in BF16, you'll need ≈6GB of VRAM), training it requires much more VRAM once you account for the gradients and the AdamW optimizer states (≈48GB).
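The ≈6GB and ≈48GB figures are consistent with a common back-of-the-envelope breakdown. The sketch below assumes ~3B parameters and a rule-of-thumb cost of 16 bytes per parameter for mixed-precision AdamW training (BF16 weights and gradients, an FP32 master copy of the weights, two FP32 optimizer moments). It is not a measurement of ColPali and it ignores activations, which grow with the batch size.

```python
# Rough memory estimate for full fine-tuning with AdamW in mixed precision.
# Assumed per-parameter costs (a common rule of thumb, not measured for ColPali).
NUM_PARAMS = 3e9  # ColPali is built on PaliGemma-3B

BYTES_PER_PARAM = {
    "bf16_weights": 2,
    "bf16_gradients": 2,
    "fp32_master_weights": 4,
    "adamw_moments_fp32": 8,  # first and second moments, 4 bytes each
}


def to_gb(num_bytes: float) -> float:
    """Convert bytes to (decimal) gigabytes."""
    return num_bytes / 1e9


inference_gb = to_gb(NUM_PARAMS * BYTES_PER_PARAM["bf16_weights"])
training_gb = to_gb(NUM_PARAMS * sum(BYTES_PER_PARAM.values()))

print(f"BF16 weights for inference: ~{inference_gb:.0f} GB")
print(f"Full fine-tuning, excluding activations: ~{training_gb:.0f} GB")
```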

Therefore, we will use LoRA to limit the number of trainable parameters, as was done to train the original ColPali.
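LoRA freezes a pretrained weight matrix $W$ of shape $d_{out} \times d_{in}$ and learns a low-rank update $BA$, with $B$ of shape $d_{out} \times r$ and $A$ of shape $r \times d_{in}$. Each adapted linear layer therefore adds only $r(d_{in} + d_{out})$ trainable parameters. The layer width and rank below are hypothetical values chosen for illustration, not ColPali's actual LoRA configuration.

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters LoRA adds to one linear layer: A is (rank x d_in), B is (d_out x rank)."""
    return rank * (d_in + d_out)


def full_param_count(d_in: int, d_out: int) -> int:
    """Parameters of the frozen dense weight matrix (bias ignored)."""
    return d_in * d_out


# Hypothetical square projection of width 2048 adapted with rank 32.
d_in = d_out = 2048
rank = 32
lora = lora_param_count(d_in, d_out, rank)
full = full_param_count(d_in, d_out)
print(f"LoRA: {lora:,} trainable vs. full: {full:,} ({100 * lora / full:.2f}%)")
```

Since the optimizer states only exist for trainable parameters, this reduction also shrinks the AdamW memory cost estimated above.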

Even with LoRA, you might struggle to train ColPali on consumer GPUs because of the contrastive loss used in ColPali: the larger the batch size, the more VRAM is used. We therefore quantize the model to further reduce its memory footprint, and recommend using 4-bit quantization combined with LoRA, i.e. QLoRA.
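For intuition on why quantization helps, here is a back-of-the-envelope sketch of the storage needed for ~3B weights under each `QUANTIZATION_STRATEGY` option used below. It assumes 16, 8 and 4 bits per weight and ignores quantization constants, layers that bitsandbytes leaves unquantized, LoRA weights, optimizer states and activations, so real usage will be higher.

```python
from typing import Dict, Optional

NUM_PARAMS = 3e9  # approximate size of ColPali (PaliGemma-3B backbone)

# Assumed storage cost per weight for each QUANTIZATION_STRATEGY value in this notebook.
BITS_PER_WEIGHT: Dict[Optional[str], int] = {None: 16, "8bit": 8, "4bit": 4}


def weight_footprint_gb(strategy: Optional[str]) -> float:
    """Approximate memory (GB) needed just to store the model weights."""
    return NUM_PARAMS * BITS_PER_WEIGHT[strategy] / 8 / 1e9


for strategy, bits in BITS_PER_WEIGHT.items():
    label = strategy or "bf16"
    print(f"{label:>5} ({bits:>2} bits/weight): ~{weight_footprint_gb(strategy):.1f} GB")
```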

See the [🤗 quantization documentation](https://huggingface.co/docs/transformers/main/en/quantization/overview) for more information on quantization and this [🤗 blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes) for more details on QLoRA.

```python
# ==========================     USER INPUT     ==========================

# One of: None (no quantization, BF16), "8bit", or "4bit" (QLoRA).
QUANTIZATION_STRATEGY = "4bit"

# ========================================================================

# Automatically set the device
device = get_torch_device("auto")

# bitsandbytes quantization requires a CUDA GPU
if QUANTIZATION_STRATEGY and not str(device).startswith("cuda"):
    raise ValueError("This notebook requires a CUDA GPU to use quantization.")

# Prepare quantization config
if QUANTIZATION_STRATEGY is None:
    bnb_config = None
elif QUANTIZATION_STRATEGY == "8bit":
    bnb_config = BitsAndBytesConfig(
        load_in_8bit=True,
    )
elif QUANTIZATION_STRATEGY == "4bit":
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
else:
    raise ValueError(f"Invalid quantization strategy: {QUANTIZATION_STRATEGY}")
```

## Load the pre-trained model

For simplicity, we will continue training the LoRA adapter from the original ColPali model.

```python
# Pre-trained model name (with LoRA adapter)
model_name = "vidore/colpali-v1.2"

# Get the LoRA config from the pretrained model
lora_config = LoraConfig.from_pretrained(model_name)

# Load the model with the loaded pre-trained adapter
model = cast(
    ColPali,
    ColPali.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16,
        device_map=device,
    ),
)

if not model.active_adapters():
    raise ValueError("No adapter found in the model.")

# The LoRA weights are frozen by default. We need to unfreeze them to fine-tune the model.
for name, param in model.named_parameters():
    if "lora" in name:
        param.requires_grad = True

print_trainable_parameters(model)
```

```text
`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.

trainable params: 39,292,928 || all params: 1,766,287,216 || trainable%: 2.224605808390791
```


**Additional note:** It is also possible to train a new LoRA adapter on top of the existing one. This gives you extra flexibility in the LoRA config and facilitates adapter hot-swapping.

To do this:

1. Load the version of ColPali with the adapter already merged with the base model weights: `model = ColPali.from_pretrained("vidore/colpali-v1.2-merged")`.
2. Define your LoRA adapter using `lora_config = LoraConfig(...)`.
3. Use `model = get_peft_model(model, lora_config)` (from `peft`) to attach the new adapter to the model.
4. Run fine-tuning as usual.
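Steps 1 to 3 can be sketched as follows, assuming `colpali-engine` and `peft` are installed. The function name, the LoRA hyperparameters (`r`, `lora_alpha`, `lora_dropout`) and the `target_modules` regex are illustrative assumptions, not the configuration used to train the original ColPali adapter. To combine this with QLoRA, you would also pass `quantization_config=bnb_config` to `from_pretrained`. For step 4, pass the returned model to `ContrastiveTrainer` as in the rest of this notebook.

```python
def add_new_lora_adapter(device_map: str = "cuda:0"):
    """Hypothetical helper: load the merged ColPali and attach a fresh LoRA adapter.

    The LoRA values below are illustrative assumptions; tune them for your use case.
    """
    # Imports are kept inside the function so the sketch can be defined without a GPU.
    import torch
    from colpali_engine.models import ColPali
    from peft import LoraConfig, get_peft_model

    # Step 1: ColPali with the original adapter already merged into the base weights.
    model = ColPali.from_pretrained(
        "vidore/colpali-v1.2-merged",
        torch_dtype=torch.bfloat16,
        device_map=device_map,
    )

    # Step 2: define the new adapter (hypothetical hyperparameters).
    lora_config = LoraConfig(
        r=16,
        lora_alpha=16,
        lora_dropout=0.1,
        init_lora_weights="gaussian",
        bias="none",
        # Assumed regex: adapt the language model's attention/MLP projections
        # and ColPali's final projection layer.
        target_modules=r"(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)",
    )

    # Step 3: wrap the model so that only the new LoRA weights are trainable.
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    return model
```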

## Load the processor and the collator

```python
if lora_config.base_model_name_or_path is None:
    raise ValueError("Base model name or path is required in the LoRA config.")

processor = cast(
    ColPaliProcessor,
    ColPaliProcessor.from_pretrained(model_name),
)
collator = VisualRetrieverCollator(processor=processor)
```

## Load the dataset

VDSID-French is a subset of the [`vidore/vdsid`](https://huggingface.co/datasets/vidore/vdsid) dataset. It contains 5000 document-question-answer triplets from French documents, split into a train set of 4700 examples and a test set of 300 examples.

This dataset was created and chosen for this fine-tuning because ColPali was mainly trained on English documents. Thus, fine-tuning on French documents can help improve the model's multilingual capabilities.
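If you bring your own dataset (for example one generated as described in the section on using your own documents), a quick sanity check on its columns can save a failed run. This sketch assumes that the collator reads each example's `query` and `image` fields, which is our reading of why the next cell renames `page_image` to `image`; the helper name is hypothetical.

```python
from typing import Iterable, Set

# Assumption: the training collator reads the "query" and "image" fields of each example.
REQUIRED_COLUMNS: Set[str] = {"query", "image"}


def missing_columns(column_names: Iterable[str], required: Set[str] = REQUIRED_COLUMNS) -> Set[str]:
    """Return the required columns that are absent from a dataset split."""
    return set(required) - set(column_names)


# In the notebook, you could check a split with: missing_columns(ds["train"].column_names)
print(missing_columns(["query", "page_image", "answer"]))  # {'image'}
```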

```python
# Load the dataset
dataset_name = "vidore/vdsid_french"
ds = cast(DatasetDict, load_dataset(dataset_name))

# Rename the columns to match the trainer's requirements
ds = ds.rename_column("page_image", "image")
ds["train"] = ds["train"].shuffle(seed=42)

ds
```

```text
DatasetDict({
    train: Dataset({
        features: ['document_filename', 'document_url', 'search_query', 'search_topic', 'search_subtopic', 'search_language', 'search_filetype', 'page_number', 'page_description', 'page_language', 'page_contains_table', 'page_contains_figure', 'page_contains_paragraph', 'image', 'query_type', 'query_answerability', 'query_modality', 'query_language', 'query_reasoning', 'query', 'query_is_self_contained', 'query_is_self_contained_reasoning', 'answer'],
        num_rows: 4700
    })
    test: Dataset({
        features: ['document_filename', 'document_url', 'search_query', 'search_topic', 'search_subtopic', 'search_language', 'search_filetype', 'page_number', 'page_description', 'page_language', 'page_contains_table', 'page_contains_figure', 'page_contains_paragraph', 'image', 'query_type', 'query_answerability', 'query_modality', 'query_language', 'query_reasoning', 'query', 'query_is_self_contained', 'query_is_self_contained_reasoning', 'answer'],
```

## Define training args

Depending on your hardware, you might need to adjust the train batch size. This parameter is crucial for training ColPali because it uses an in-batch contrastive loss: the larger the batch size, the more negative samples each query is contrasted with, which generally leads to better retrieval performance.

```python
checkpoints_dir = Path("checkpoints")
checkpoints_dir.mkdir(exist_ok=True, parents=True)

training_args = TrainingArguments(
    output_dir=str(checkpoints_dir),
    hub_model_id=hf_pushed_model_name if hf_pushed_model_name else None,
    overwrite_output_dir=True,
    num_train_epochs=1.5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    gradient_checkpointing=False,
    eval_strategy="steps",
    save_steps=200,
    logging_steps=20,
    eval_steps=100,
    warmup_steps=100,
    learning_rate=5e-5,
    save_total_limit=1,
    report_to=["wandb"] if wandb_experiment_name else [],
)
```
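The run below ends at `global_step=440` and `epoch≈1.498`. Those numbers can be reproduced from the arguments above, assuming the Trainer computes optimizer steps per epoch as the number of micro-batches divided (rounding down) by `gradient_accumulation_steps`, then rounds the total up. The sketch also highlights a consequence for the contrastive loss: gradient accumulation raises the effective batch size seen by the optimizer, but, assuming the loss is computed per micro-batch, each query is still only contrasted with the other pages of its micro-batch.

```python
import math

# Values from the TrainingArguments above and the VDSID-French train split.
num_train_examples = 4700
per_device_train_batch_size = 4
gradient_accumulation_steps = 4
num_train_epochs = 1.5
num_devices = 1  # assumption: a single GPU, as in this notebook

# Forward/backward passes (micro-batches) needed to see the train set once.
micro_batches_per_epoch = math.ceil(num_train_examples / (per_device_train_batch_size * num_devices))
# Optimizer updates per epoch (rounded down, at least one).
update_steps_per_epoch = max(micro_batches_per_epoch // gradient_accumulation_steps, 1)
max_steps = math.ceil(num_train_epochs * update_steps_per_epoch)
final_epoch = max_steps * gradient_accumulation_steps / micro_batches_per_epoch

# Examples contributing to each optimizer update vs. negatives seen by the loss.
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_devices
in_batch_negatives_per_query = per_device_train_batch_size - 1

print(f"micro-batches per epoch: {micro_batches_per_epoch}")              # 1175
print(f"optimizer steps per epoch: {update_steps_per_epoch}")             # 293
print(f"total optimizer steps: {max_steps}")                              # 440
print(f"effective batch size: {effective_batch_size}")                    # 16
print(f"in-batch negatives per query: {in_batch_negatives_per_query}")    # 3
```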

## Create the trainer

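The trainer uses a pairwise cross-entropy (softplus) loss on ColBERT late-interaction scores, which contrasts each query's positive page with its hardest in-batch negative. Read the [ColPali paper](https://doi.org/10.48550/arXiv.2407.01449) for more details on this loss function.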
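As a rough illustration, here is a pure-Python sketch of our reading of this pairwise loss, written on top of a precomputed in-batch score matrix. It is not the `colpali-engine` implementation; check the library source (`ColbertPairwiseCELoss`) and the paper for the exact details. With `per_device_train_batch_size=4`, the score matrix would be 4×4.

```python
import math
from typing import List, Sequence


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def pairwise_ce_loss(scores: Sequence[Sequence[float]]) -> float:
    """Pairwise loss over an in-batch score matrix.

    scores[i][j] is the late-interaction score between query i and page j of the batch.
    Page i is the positive for query i; every other page in the batch is a negative.
    """
    batch_size = len(scores)
    if batch_size < 2:
        raise ValueError("At least two examples are needed to form in-batch negatives.")
    losses: List[float] = []
    for i in range(batch_size):
        positive = scores[i][i]
        hardest_negative = max(scores[i][j] for j in range(batch_size) if j != i)
        # Small when the positive clearly beats the hardest negative, large otherwise.
        losses.append(softplus(hardest_negative - positive))
    return sum(losses) / batch_size


well_separated = [[10.0, 2.0], [1.0, 8.0]]
confused = [[1.0, 5.0], [5.0, 1.0]]
print(pairwise_ce_loss(well_separated), pairwise_ce_loss(confused))
```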

```python
class EvaluateFirstStepCallback(TrainerCallback):
    """
    Run eval after the first training step.
    Used to have a more precise evaluation learning curve.
    """

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step == 1:
            control.should_evaluate = True


trainer = ContrastiveTrainer(
    model=model,
    train_dataset=ds["train"],
    eval_dataset=ds["test"],
    args=training_args,
    data_collator=collator,
    loss_func=ColbertPairwiseCELoss(),
    is_vision_model=True,
)

trainer.args.remove_unused_columns = False
trainer.add_callback(EvaluateFirstStepCallback())
```

## Evaluate the model before training

Let's see how ColPali performs on the test set prior to fine-tuning.

```python
eval_results = trainer.evaluate()
eval_results
```



```text
{'eval_loss': 0.04893038421869278,
 'eval_model_preparation_time': 0.0121,
 'eval_runtime': 41.4353,
 'eval_samples_per_second': 7.24,
 'eval_steps_per_second': 1.81}
```

## Fine-tune the model

Finally, time for fine-tuning! Run the following cell then go make yourself a cup of tea or take a walk while the model trains. 🚀 

On an A100-40GB GPU with the default training parameters, fine-tuning should take around 30 minutes. ⏱️

```python
# Prepare WandB logging
if wandb_experiment_name:
    wandb_tags = ["finetuning", "colpali"]

    if bnb_config:
        wandb_tags.append("quantization")

    run = wandb.init(
        project="colpali",
        name=wandb_experiment_name,
        job_type="finetuning",
        tags=wandb_tags,
        config={
            "model_name": model_name,
            "bitsandbytes_config": bnb_config.to_dict() if bnb_config else None,
            "dataset_name": dataset_name,
        },
    )

# Train the model
train_results = trainer.train()

train_results
```


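| Step | Training Loss | Validation Loss | Model Preparation Time |
|-----:|--------------:|----------------:|-----------------------:|
| 1 | No log | 0.04893 | 0.0121 |
| 100 | 0.017100 | 0.021548 | 0.0121 |
| 200 | 0.010000 | 0.03012 | 0.0121 |
| 300 | 0.009000 | 0.031809 | 0.0121 |
| 400 | 0.012000 | 0.031864 | 0.0121 |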




```text
TrainOutput(global_step=440, training_loss=0.022981878132982688, metrics={'train_runtime': 1922.3275, 'train_samples_per_second': 3.667, 'train_steps_per_second': 0.229, 'total_flos': 1.0598129422848e+17, 'train_loss': 0.022981878132982688, 'epoch': 1.4978723404255319})
```

## Evaluate the model after training

Now, let's see how the fine-tuned model performs on the test set. You should observe a drop in the evaluation loss. 

To further evaluate the model, you can use the [`vidore-benchmark`](https://github.com/illuin-tech/vidore-benchmark) to measure the retrieval performance (e.g. NDCG@5).
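For reference, here is a minimal sketch of NDCG@k with binary relevance for a single query, using the standard definition. It is not the `vidore-benchmark` implementation, which may differ in details (for example graded relevance or tie handling); a benchmark score would be the mean over all queries. The function name is hypothetical.

```python
import math
from typing import Sequence, Set


def ndcg_at_k(ranked_doc_ids: Sequence[str], relevant_doc_ids: Set[str], k: int = 5) -> float:
    """NDCG@k with binary relevance for a single query."""
    # Discounted gain of the retrieved ranking (rank is 0-based, so the top hit weighs 1/log2(2) = 1).
    dcg = sum(
        1.0 / math.log2(rank + 2)
        for rank, doc_id in enumerate(ranked_doc_ids[:k])
        if doc_id in relevant_doc_ids
    )
    # Best achievable gain: all relevant documents at the top of the ranking.
    ideal_hits = min(len(relevant_doc_ids), k)
    idcg = sum(1.0 / math.log2(rank + 2) for rank in range(ideal_hits))
    return dcg / idcg if idcg > 0 else 0.0


# The single relevant page is retrieved in third position.
print(ndcg_at_k(["p1", "p2", "p3"], {"p3"}))  # 0.5
```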

```python
eval_results = trainer.evaluate()
eval_results
```

```text
{'eval_loss': 0.03164391964673996,
 'eval_model_preparation_time': 0.0121,
 'eval_runtime': 40.3053,
 'eval_samples_per_second': 7.443,
 'eval_steps_per_second': 1.861,
 'epoch': 1.4978723404255319}
```

During my own experiments, I got the following learning curves:

<p align="center"><img width=800 src="https://github.com/tonywu71/colpali-cookbooks/blob/main/assets/finetuning/learning_curves.jpeg?raw=true"/></p>

A few observations:

- The training loss decreases overall, which indicates that the model is learning.
- The validation loss decreases rapidly over the first 100 steps, then goes back up and plateaus after about 200 steps. This is a sign of overfitting, so we probably should have stopped training a bit earlier.
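To make the early-stopping remark concrete, the sketch below picks the best evaluation step from the validation losses in the training log above and applies a simple patience rule. The helper names are hypothetical. Within the Trainer, `transformers.EarlyStoppingCallback` together with `load_best_model_at_end=True` and `metric_for_best_model` implements a similar patience-based rule (see the Transformers documentation for its exact requirements). Note that with `save_steps=200` and `save_total_limit=1`, the step-100 weights were not kept in this run; saving every 100 steps would be needed to recover them.

```python
from typing import Dict, Tuple

# Validation losses logged during the run above (step -> eval loss).
validation_losses: Dict[int, float] = {
    1: 0.04893,
    100: 0.021548,
    200: 0.03012,
    300: 0.031809,
    400: 0.031864,
}


def best_step(losses: Dict[int, float]) -> Tuple[int, float]:
    """Return the (step, loss) pair with the lowest validation loss."""
    step = min(losses, key=losses.__getitem__)
    return step, losses[step]


def should_stop(losses: Dict[int, float], patience: int) -> bool:
    """Stop once `patience` consecutive evaluations fail to improve on the best loss so far."""
    best = float("inf")
    evals_without_improvement = 0
    for step in sorted(losses):
        if losses[step] < best:
            best = losses[step]
            evals_without_improvement = 0
        else:
            evals_without_improvement += 1
            if evals_without_improvement >= patience:
                return True
    return False


print(best_step(validation_losses))  # (100, 0.021548)
```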

## Conclude the WandB run (optional)

```python
if wandb_experiment_name:
    run.finish()
```



## Push the model to the Hub (optional)

If satisfied with the fine-tuned model, you can push it to the Hub to share it with the community! 😍

You can find my fine-tuned ColPali model for reference at [`tonywu71/finetune_colpali_v1_2-vdsid_french-4bit`](https://huggingface.co/tonywu71/finetune_colpali_v1_2-vdsid_french-4bit). For more inspiration, check out the [🤗 Hub](https://huggingface.co/models?other=base_model:finetune:vidore/colpaligemma-3b-pt-448-base) to see the ColPali models that the community has already fine-tuned.

```python
if hf_pushed_model_name:
    trainer.push_to_hub(tags=["colpali"], dataset=dataset_name)
```




Congrats, you have successfully fine-tuned ColPali! 🎉

## Load your fine-tuned model (optional)

To use your freshly fine-tuned model in your own project, load it with the code from the following cell. 🫶🏼

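```python
# Unload the previous model and clean the GPU cache
del model
tear_down_torch()

# Load your fine-tuned ColPali from the Hub (the pushed repository holds the LoRA adapter).
# If you did not push the model, pass the path to a local checkpoint directory instead.
# The processor is not modified by fine-tuning, so the `processor` loaded earlier can be reused.
model = cast(
    ColPali,
    ColPali.from_pretrained(
        hf_pushed_model_name,
        torch_dtype=torch.bfloat16,
        device_map=device,
    ),
)
```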

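Once loaded, the model can be used for retrieval. The sketch below follows the usage pattern shown in the `colpali-engine` 0.3.x README (`process_images`, `process_queries`, `score_multi_vector`); the helper name `retrieve_top_k` is hypothetical. Embedding all images in a single batch is only reasonable for a handful of pages; for a larger corpus, embed the pages in batches and reuse the stored embeddings.

```python
from typing import List, Sequence


def retrieve_top_k(model, processor, queries: Sequence[str], images: List, top_k: int = 3):
    """Score each query against a small set of page images and return the top-k pages.

    `model` is the fine-tuned ColPali, `processor` the ColPaliProcessor, and `images`
    a list of PIL images. Returns (values, indices), each of shape
    (len(queries), min(top_k, len(images))).
    """
    import torch  # imported lazily so the helper can be defined without a GPU

    model.eval()
    batch_images = processor.process_images(images).to(model.device)
    batch_queries = processor.process_queries(list(queries)).to(model.device)

    with torch.no_grad():
        image_embeddings = model(**batch_images)
        query_embeddings = model(**batch_queries)

    # Late-interaction (MaxSim) scores: one row per query, one column per image.
    scores = processor.score_multi_vector(query_embeddings, image_embeddings)
    return scores.topk(min(top_k, len(images)), dim=1)


# Example usage in this notebook:
# values, indices = retrieve_top_k(model, processor, ds["test"][:2]["query"], ds["test"][:8]["image"])
```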