<a href="https://colab.research.google.com/github/tahreemrasul/fine_tune_embedding_model_rag/blob/main/Finetune_Embedding_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-tune an Embedding Model for RAG Applications using 🤗

Embedding models are crucial for successful RAG applications, but they are usually trained on general-purpose data, which limits their effectiveness for company- or domain-specific use. Customizing embeddings on your domain-specific data can significantly boost the retrieval performance of your RAG application. With the release of Sentence Transformers v3, it is easier than ever to fine-tune embedding models.

In this tutorial, I'll show how to fine-tune an embedding model for RAG applications using a pre-built triplet dataset from the Hugging Face Hub. We are going to:

1. Load and prepare the embedding dataset
2. Load a pretrained model
3. Define the loss function
4. Fine-tune the embedding model with `SentenceTransformerTrainer`
5. Evaluate the fine-tuned model


## 1. Install Packages

Before we begin, we need to install the required packages. This includes libraries for dataset handling, model training, and evaluation.

In [None]:
!pip install --upgrade --quiet transformers[torch] sentence-transformers datasets


## 2. Load Dataset

An embedding dataset typically consists of text pairs (question, answer/context) or triplets that represent relationships or similarities between sentences. The format you choose (or have available) also determines which loss functions you can use. Common formats for embedding datasets:

* **Positive Pair:** Text pairs of related sentences (query, context | query, answer), suitable for tasks like similarity or semantic search. Example datasets: `sentence-transformers/sentence-compression`, `sentence-transformers/natural-questions`.
* **Triplets:** Text triplets consisting of (anchor, positive, negative). Example datasets: `sentence-transformers/quora-duplicates`, `nirantk/triplets`.
* **Pair with Similarity Score:** Sentence pairs with a similarity score indicating how related they are. Example datasets: `sentence-transformers/stsb`, `PhilipMay/stsb_multi_mt`.

Learn more at [Dataset Overview](https://sbert.net/docs/sentence_transformer/dataset_overview.html).

We'll use the `datasets` library to load a pre-built dataset. The dataset used here is `sentence-transformers/all-nli`, which provides sentence triplets for training an embedding model. The triplet structure includes:

- **Anchor:** The original sentence or query.
- **Positive:** A correct or relevant response to the anchor.
- **Negative:** An incorrect or irrelevant response to the anchor.

Let's start by loading and exploring the dataset.

In [None]:
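from datasets import load_dataset

# Load the "triplet" subset of AllNLI: (anchor, positive, negative) columns
dataset = load_dataset("sentence-transformers/all-nli", "triplet")

# Keep the first 5,000 training triplets so training runs quickly
train_dataset = dataset["train"].select(range(5000))
eval_dataset = dataset["dev"]   # used for the validation loss during training
test_dataset = dataset["test"]  # held out for the final evaluation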

The `triplet` subset has 557,850 training, 6,584 dev, and 6,609 test examples.

In [None]:
train_dataset

Dataset({
    features: ['anchor', 'positive', 'negative'],
    num_rows: 5000
})

In [None]:
train_dataset.to_pandas()
# Anchor: The original sentence or query.
# Positive answer: A correct or relevant response to the anchor.
# Negative answer: An incorrect or irrelevant response to the anchor

| | anchor | positive | negative |
|---|---|---|---|
| 0 | A person on a horse jumps over a broken down a... | A person is outdoors, on a horse. | A person is at a diner, ordering an omelette. |
| 1 | Children smiling and waving at camera | There are children present | The kids are frowning |
| 2 | A boy is jumping on skateboard in the middle o... | The boy does a skateboarding trick. | The boy skates down the sidewalk. |
| 3 | Two blond women are hugging one another. | There are women showing affection. | The women are sleeping. |
| 4 | A few people in a restaurant setting, one of t... | The diners are at a restaurant. | The people are sitting at desks in school. |
| ... | ... | ... | ... |
| 4995 | The people are outside. | People on ATVs and dirt bikes are traveling al... | A woman in a pink shirt is handing a bag to th... |
| 4996 | The people are outside. | People on ATVs and dirt bikes are traveling al... | A small group of adult males enjoy a conversat... |
| 4997 | The people are outside. | People on ATVs and dirt bikes are traveling al... | Two guys and one girl are sitting at a table i... |
| 4998 | The people are outside. | People on ATVs and dirt bikes are traveling al... | People sitting on black chairs on a bus. |


## 3. Load Model

With the dataset loaded, we can now load the embedding model we want to fine-tune.

For our example, we will use `BAAI/bge-base-en-v1.5` as our starting point. It is one of the strongest open embedding models for its size: with only 109M parameters and a hidden dimension of 768, it scores 63.55 on the MTEB leaderboard.

In [None]:
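from sentence_transformers import SentenceTransformer

# Other popular embedding models you could start from:
# https://huggingface.co/nomic-ai/nomic-embed-text-v1  (requires trust_remote_code=True)
# https://huggingface.co/BAAI/bge-large-en

# Load the pretrained base model described above
model = SentenceTransformer("BAAI/bge-base-en-v1.5")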


## 4. Define the Loss Function and Training Arguments

In [None]:
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

When fine-tuning embedding models, we select a loss function based on our dataset format. For positive text pairs, and for (anchor, positive, negative) triplets like ours, we can use `MultipleNegativesRankingLoss`. It works well even if you only have positive pairs, because it uses the other examples in the batch as negatives: with a batch of `n` pairs, each anchor gets `n-1` in-batch negative samples. When a third column of hard negatives is present, as in `all-nli`, those negatives are added to the candidates as well.
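To make the in-batch negatives idea concrete, here is a minimal plain-PyTorch sketch of the objective. It is an illustration, not the library's implementation; the function name is hypothetical, and the cosine similarity scaled by 20 reflects what we understand to be `MultipleNegativesRankingLoss`'s defaults. Each anchor is scored against every candidate in the batch, and cross-entropy pushes the matching positive (column `i` for anchor `i`) above all the others.

```python
import torch
import torch.nn.functional as F


def in_batch_negatives_loss(anchors, positives, negatives=None, scale=20.0):
    """Sketch of the in-batch-negatives objective behind MultipleNegativesRankingLoss.

    anchors, positives, negatives: (n, d) embedding tensors for one batch.
    Row i of `positives` is the match for row i of `anchors`; every other
    positive (and every explicit negative, if given) acts as a negative.
    """
    candidates = positives if negatives is None else torch.cat([positives, negatives], dim=0)
    # Scaled cosine similarity of every anchor to every candidate: (n, n) or (n, 2n)
    scores = F.normalize(anchors, dim=-1) @ F.normalize(candidates, dim=-1).T * scale
    # For anchor i the correct candidate sits in column i
    labels = torch.arange(anchors.size(0), device=anchors.device)
    return F.cross_entropy(scores, labels)


# Perfectly matched toy embeddings give a loss near 0; mismatched pairs give a large loss
a = torch.eye(4, 8)
print(in_batch_negatives_loss(a, a).item())
print(in_batch_negatives_loss(a, a.flip(0)).item())
```

Because the negatives come for free from the rest of the batch, larger batches give each anchor more negatives to be ranked against.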

In [None]:
# Define the loss function
loss = MultipleNegativesRankingLoss(model)

We use the `SentenceTransformerTrainingArguments` class to specify all the training parameters.

In [None]:
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir="models/bge-base-all-nli-triplet",
    # Optional training parameters:
    num_train_epochs=3,                         # number of epochs
    per_device_train_batch_size=32,             # train batch size per device
    gradient_accumulation_steps=16,             # effective batch of 512 per optimizer step on one GPU
                                                # (MNRL's in-batch negatives still come from each 32-example batch)
    per_device_eval_batch_size=16,              # evaluation batch size
    warmup_ratio=0.1,                           # warm up over the first 10% of steps
    learning_rate=2e-5,                         # learning rate; 2e-5 is a common choice for fine-tuning
    lr_scheduler_type="cosine",                 # cosine learning rate decay after warmup
    optim="adamw_torch_fused",                  # fused AdamW optimizer
    tf32=False,                                 # set True on Ampere or newer GPUs for faster matmuls
    bf16=True,                                  # bf16 mixed precision; needs a bf16-capable GPU (e.g. A100, L4), set False on a T4
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # avoid duplicate texts in a batch, which would act as false in-batch negatives
    eval_strategy="epoch",                      # evaluate after each epoch
    save_strategy="epoch",                      # save a checkpoint after each epoch
    logging_steps=10,                           # log every 10 steps
    save_total_limit=3,                         # keep at most 3 checkpoints
    load_best_model_at_end=True,                # reload the best checkpoint when training ends
    metric_for_best_model="eval_loss",          # "best" means lowest validation loss
)

## 5. Train

We are now ready to fine-tune our model. We use the `SentenceTransformerTrainer`, a subclass of the `Trainer` from the `transformers` library, so it supports the same features, including logging, evaluation, and checkpointing.

In [None]:
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss
)

Start training by calling the `train()` method on our `SentenceTransformerTrainer` instance. This runs the training loop for 3 epochs (`num_train_epochs=3`).

In [None]:
trainer.train()

| Epoch | Training Loss | Validation Loss |
|---|---|---|
| 2 | 0.7304 | 0.720638 |



TrainOutput(global_step=27, training_loss=1.0126445028516982, metrics={'train_runtime': 120.6986, 'train_samples_per_second': 124.277, 'train_steps_per_second': 0.224, 'total_flos': 0.0, 'train_loss': 1.0126445028516982, 'epoch': 2.751592356687898})
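The `TrainOutput` numbers follow from the training arguments. The plain-arithmetic sketch below (not a library call) reproduces `global_step=27`, assuming, as in the Hugging Face `Trainer`, that optimizer steps per epoch equal the number of micro-batches divided by `gradient_accumulation_steps`, rounded down. As a cross-check, `train_samples_per_second × train_runtime` ≈ 124.277 × 120.70 s ≈ 15,000, i.e. 5,000 examples × 3 epochs, and the reported `epoch` of about 2.75 is consistent with 27 × 16 = 432 micro-batches out of 157 per epoch (432 / 157 ≈ 2.75).

```python
import math

num_examples = 5000        # train_dataset = dataset["train"].select(range(5000))
per_device_batch = 32      # per_device_train_batch_size
grad_accum = 16            # gradient_accumulation_steps
num_epochs = 3             # num_train_epochs

# Effective (global) batch per optimizer step on a single device
global_batch = per_device_batch * grad_accum

# Micro-batches per epoch (the last one is partial: 5000 / 32 = 156.25)
micro_batches_per_epoch = math.ceil(num_examples / per_device_batch)

# Optimizer updates per epoch, assuming leftover micro-batches are floored away
steps_per_epoch = micro_batches_per_epoch // grad_accum
total_steps = steps_per_epoch * num_epochs

print(global_batch, micro_batches_per_epoch, steps_per_epoch, total_steps)
```

The 512-example global batch only affects the gradient update: the contrastive comparison inside `MultipleNegativesRankingLoss` still happens within each 32-example micro-batch.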

## 6. Test

During training we only tracked the validation loss on the dev split. To measure retrieval quality at the end, we evaluate the model on the held-out `test_dataset` with a `TripletEvaluator`, which checks for each triplet whether the anchor is closer to the positive than to the negative.

In [None]:
from sentence_transformers.evaluation import TripletEvaluator

test_evaluator = TripletEvaluator(
    anchors=test_dataset["anchor"],
    positives=test_dataset["positive"],
    negatives=test_dataset["negative"],
    name="all-nli-test",
)

test_evaluator(model)

{'all-nli-test_cosine_accuracy': 0.9378120744439401,
 'all-nli-test_dot_accuracy': 0.06218792555605992,
 'all-nli-test_manhattan_accuracy': 0.9355424421243759,
 'all-nli-test_euclidean_accuracy': 0.9378120744439401,
 'all-nli-test_max_accuracy': 0.9378120744439401}
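In the output above, the Euclidean accuracy equals the cosine accuracy, which is what you would expect for L2-normalized embeddings (squared Euclidean distance is then 2 − 2 × cosine similarity). The dot accuracy, however, is exactly 1 minus the cosine accuracy (0.0622 + 0.9378 = 1). Since dot product and cosine similarity rank identically on normalized vectors, that dot value most likely reflects a quirk in how the evaluator version used here scores dot products, not a property of the model. The sketch below illustrates the metric on synthetic, normalized vectors; it is a minimal NumPy illustration with hypothetical helper names, not the library's implementation.

```python
import numpy as np


def triplet_accuracy(anchors, positives, negatives, similarity):
    """Fraction of triplets where the anchor is more similar to its positive than to its negative."""
    pos = similarity(anchors, positives)
    neg = similarity(anchors, negatives)
    return float(np.mean(pos > neg))


def cosine(x, y):
    # Row-wise cosine similarity between matching rows of x and y
    return np.sum(x * y, axis=1) / (np.linalg.norm(x, axis=1) * np.linalg.norm(y, axis=1))


def dot(x, y):
    # Row-wise dot product between matching rows of x and y
    return np.sum(x * y, axis=1)


rng = np.random.default_rng(0)
a, p, n = (rng.normal(size=(100, 16)) for _ in range(3))
p = a + 0.5 * p  # synthetic positives: noisy copies of the anchors
# L2-normalize every embedding, as a normalized embedding model would
a, p, n = (v / np.linalg.norm(v, axis=1, keepdims=True) for v in (a, p, n))

cos_acc = triplet_accuracy(a, p, n, cosine)
dot_acc = triplet_accuracy(a, p, n, dot)
print(cos_acc, dot_acc)  # identical for normalized vectors
```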

## 7. Save the Model

Finally, save the model locally and push it to the Hugging Face Hub.

In [None]:
from huggingface_hub import login

# Log in with your own Hugging Face access token (it needs write permission);
# avoid committing a real token in the notebook
login(token="your_HF_key_here", add_to_git_credential=True)  # ADD YOUR TOKEN HERE

# Save the trained model locally
model.save_pretrained("models/bge-base-all-nli-triplet/final")

# (Optional) Push it to the Hugging Face Hub under your account
model.push_to_hub("bge-base-all-nli-triplet")


'https://huggingface.co/trasul/bge-base-all-nli-triplet/commit/29965aa3720937fe532c3c3b27339d5f69c917cd'