# Fine-tuning custom embedding models on synthetic data

Bootstrapping and maintaining production-ready RAG pipelines requires optimizing various components such as the LLM, vector database, embedding models and rerankers. In this notebook we will showcase how you can optimize and maintain your embedding models through synthetic data and human feedback. Besides ZenML, we will use two open source libraries: `argilla` and `distilabel`. Both libraries focus on optimizing model outputs by improving data quality; however, each takes a different approach to the same problem. `distilabel` provides a scalable and reliable approach to distilling knowledge from LLMs by generating synthetic data or providing AI feedback with LLMs as judges. `argilla` enables AI engineers and domain experts to collaborate on data projects by allowing them to organize and explore data within an interactive and engaging UI. Both libraries can be used individually, but they work better together.

- ⚗️ distilabel is a framework for synthetic data and AI feedback - [docs](https://distilabel.argilla.io)
- Argilla is a collaboration tool for AI engineers and domain experts to build high-quality datasets - [docs](https://docs.argilla.io)

## The dataset - vibe check

Before starting any project, it is always important to look at your data. Our data is publicly [available on the Hugging Face Hub](https://huggingface.co/datasets/zenml/rag_qa_embedding_questions_0_60_0) so we can have a quick look through their dataset viewer within an embedded iFrame. 

As we can see, our dataset contains a column called `page_content`, which was obtained from the ZenML docs.

TODO: add context about `page_content`, chunking etc. Could we look into semantic chunking?

<iframe src="https://huggingface.co/datasets/zenml/rag_qa_embedding_questions_0_60_0/embed/viewer" frameborder="0" width="100%" height="560px"></iframe>

Alternatively, we can download the entire dataset with `datasets.load_dataset`. There is only a single split (`train`), which we pass as an argument to the function.

In [None]:
from datasets import load_dataset

repo_name = "zenml/rag_qa_embedding_questions_0_60_0"

dataset = load_dataset(repo_name, split="train")

dataset

## Synthetic query generation with `distilabel`

The [`GenerateSentencePair`](https://distilabel.argilla.io/latest/components-gallery/tasks/generatesentencepair/) component from `distilabel` can be used to generate training datasets for embedding models. It is a pre-defined `Task` that, given an `anchor` sentence, generates a `positive` sentence related to the anchor. We will also generate unrelated `negative` sentences by passing `triplet=True`, and we will provide a `context` to guide the LLM towards more specific behavior.

Additionally, we will use the [`OpenAILLM`](https://distilabel.argilla.io/latest/components-gallery/llms/openaillm/) with `gpt-4o` and [`LoadDataFromHub`](https://distilabel.argilla.io/latest/components-gallery/steps/loaddatafromhub/) to load [our dataset](https://huggingface.co/datasets/zenml/rag_qa_embedding_questions_0_60_0).

In our case, we will use the `page_content` column from our dataset as the `anchor` to generate `positive` and `negative` sentences that serve as training data for the embedding model.

Now, let's capture this logic in a `distilabel` `Pipeline`!

In [None]:
import os

from distilabel.steps.tasks import GenerateSentencePair
from distilabel.llms import OpenAILLM
from distilabel.steps import LoadDataFromHub
from distilabel.pipeline import Pipeline

# TODO: I think we might optimize this a bit more.

context = (
"""
The text is a chunk from technical documentation of ZenML.
ZenML is an MLOps + LLMOps framework that makes your infrastructure and workflow metadata accessible to data science teams.
Along with prose explanations, the text chunk may include code snippets and logs but these are identifiable from the surrounding backticks.
"""
)

llm = OpenAILLM(model="gpt-4o", api_key=os.getenv("OPENAI_API_KEY"))

with Pipeline(name="generate_embedding_queries") as pipeline:
    load_dataset = LoadDataFromHub(
        num_examples=10,  # limits the run to 10 examples for demo purposes; remove to process the full dataset
        output_mappings={"page_content": "anchor"},
    )
    generate_sentence_pair = GenerateSentencePair(
        triplet=True,  # `False` to generate only positive
        action="query",
        llm=llm,
        input_batch_size=10,
        context=context,
    )

    load_dataset >> generate_sentence_pair
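
To make the data flow concrete, here is a minimal, library-free sketch of what the column renaming and the triplet output amount to conceptually. The helper `apply_output_mappings` and the example strings are hypothetical and only illustrate the idea; they are not part of `distilabel`.

```python
def apply_output_mappings(row, output_mappings):
    """Return a copy of `row` with columns renamed according to `output_mappings`."""
    return {output_mappings.get(name, name): value for name, value in row.items()}


# A made-up source row that mimics the shape of our dataset.
source_row = {
    "page_content": "ZenML pipelines are composed of steps.",
    "token_count": 8,
}

# `page_content` becomes `anchor`; other columns keep their names.
mapped_row = apply_output_mappings(source_row, {"page_content": "anchor"})

# With `triplet=True`, each row is expected to end up with three texts.
# The positive and negative strings below are invented for illustration.
triplet = {
    **mapped_row,
    "positive": "What are ZenML pipelines made of?",
    "negative": "How do I bake sourdough bread?",
}

print(sorted(triplet))
```

The `anchor` is the documentation chunk, the `positive` is a query the chunk should answer, and the `negative` is a query it should not answer.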

Next, we can execute this using `pipeline.run`. We will provide some `parameters` to specific components within our pipeline.

In [None]:
distiset = pipeline.run(
    parameters={
        load_dataset.name: {
            "repo_id": "zenml/rag_qa_embedding_questions_0_60_0",
            "split": "train",
        },
        generate_sentence_pair.name: {
            "llm": {
                "generation_kwargs": {
                    "temperature": 0.7,
                    "max_new_tokens": 512,
                }
            }
        },
    },
)

Let's vibe check our data again. If you are not happy with the results, you can either tweak the `parameters` or optimize the `context` prompt that is passed to the LLM.

In [None]:
from rich import print

example = distiset["default"]["train"][9]
example.pop("embedding", None)  # drop the long embedding vector, if present, before printing
print(example)

### (Optional) Push the distiset to the Hugging Face Hub

Synthetic data generation can be expensive because of the reliance on LLMs, so we first store our data on the Hub.

In [None]:
distiset.push_to_hub(
    repo_id="zenml/rag_qa_embedding_questions_0_60_0",
    token=os.getenv("HUGGINGFACE_API_KEY"),
    create_pr=True, # TODO: why?
)

### (Optional) Review synthetic query generation with `argilla` 

Data is never as clean as it could be, and this also holds for synthetically generated data, so it is always good to spend some time looking at your data. We will use Argilla to do this. If you are not familiar with Argilla, we recommend taking a look at the [Argilla quickstart docs](https://docs.argilla.io/latest/getting_started/quickstart/). Alternatively, you can use your Hugging Face account to log in to the [Argilla demo Space](https://argilla-argilla-template-space.hf.space).

To start exploring data, we first need to define an `argilla.Dataset`. We will create a basic dataset with an input `TextField` for the `anchor` and output `TextQuestion`s for the `positive` and `negative` sentences. Additionally, we will use `parent_section` and `token_count` as metadata properties and add a vector to allow for semantic search. Finally, we will also compute the current similarities between the embeddings of the `anchor-positive`, `positive-negative` and `anchor-negative` pairs.

In [None]:
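# Re-import: the name `load_dataset` was reused for the `LoadDataFromHub` step above
from datasets import load_dataset

dataset = load_dataset("zenml/rag_qa_embedding_questions_0_60_0", split="train")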

In [None]:
from sentence_transformers import SentenceTransformer
import torch

model_id = "sentence-transformers/all-MiniLM-L6-v2"  # Hugging Face model ID

model = SentenceTransformer(
    model_id, device="cuda" if torch.cuda.is_available() else "cpu"
)

In [None]:
import argilla as rg

settings = rg.Settings(
    fields=[
        rg.TextField("anchor")
    ],
    questions=[
        rg.TextQuestion("positive"),
        rg.TextQuestion("negative")
    ],
    metadata=[
        rg.TermsMetadataProperty("parent_section"),
        rg.IntegerMetadataProperty("token_count"),
        rg.FloatMetadataProperty("similarity-positive-negative"),
        rg.FloatMetadataProperty("similarity-anchor-positive"),
        rg.FloatMetadataProperty("similarity-anchor-negative"),
    ],
    vectors=[
        rg.VectorField("anchor-vector", dimensions=model.get_sentence_embedding_dimension())
    ]
)
ds = rg.Dataset(
    name="rag_qa_embedding_questions_0_60_0",
    settings=settings
)
ds.create()

Next, we will process the original Hugging Face dataset.

In [None]:
from sklearn.metrics.pairwise import cosine_similarity

def format_data(batch):
    def get_embeddings(batch_column):
        vectors = model.encode(batch_column)
        return [vector.tolist() for vector in vectors]
    batch["anchor-vector"] = get_embeddings(batch["anchor"])
    batch["positive-vector"] = get_embeddings(batch["positive"])
    batch["negative-vector"] = get_embeddings(batch["negative"])

    def get_similarities(a, b):
        similarities = []

        for vec_a, vec_b in zip(a, b):
            similarity = cosine_similarity([vec_a], [vec_b])[0][0]
            similarities.append(float(similarity))
        return similarities

    batch["similarity-positive-negative"] = get_similarities(batch["positive-vector"], batch["negative-vector"])
    batch["similarity-anchor-positive"] = get_similarities(batch["anchor-vector"], batch["positive-vector"])
    batch["similarity-anchor-negative"] = get_similarities(batch["anchor-vector"], batch["negative-vector"])
    return batch

dataset = dataset.map(format_data, batched=True, batch_size=1000)
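
The similarity metadata computed above is plain cosine similarity between pairs of embedding vectors. As a minimal sketch with made-up three-dimensional vectors (not real model output), the same quantity can be computed by hand:

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(u, v))
    norm_u = math.sqrt(sum(x * x for x in u))
    norm_v = math.sqrt(sum(y * y for y in v))
    return dot / (norm_u * norm_v)


# Toy "embeddings" chosen so the geometry is easy to see.
anchor = [1.0, 0.0, 1.0]
positive = [2.0, 0.0, 2.0]  # same direction as the anchor
negative = [0.0, 3.0, 0.0]  # orthogonal to the anchor

sim_anchor_positive = cosine(anchor, positive)
sim_anchor_negative = cosine(anchor, negative)
print(sim_anchor_positive, sim_anchor_negative)
```

Vectors pointing in the same direction score close to 1 regardless of their length, while orthogonal vectors score 0. A good triplet should therefore show a high `similarity-anchor-positive` and a low `similarity-anchor-negative`.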

Lastly, we will log the records to Argilla.

In [None]:
records = []
for idx, entry in enumerate(dataset):
    records.append(
        rg.Record(
            id=idx,
            fields={"anchor": entry["anchor"]},
            suggestions=[
                rg.Suggestion("positive", value=entry["positive"]),
                rg.Suggestion("negative", value=entry["negative"]),
            ],
            metadata={
                "parent_section": entry["parent_section"],
                "token_count": entry["token_count"],
                "similarity-positive-negative": entry["similarity-positive-negative"],
                "similarity-anchor-positive": entry["similarity-anchor-positive"],
                "similarity-anchor-negative": entry["similarity-anchor-negative"]
            },
            vectors={"anchor-vector": entry["anchor-vector"]}
        )
    )
ds.records.log(records)

Now we can explore the UI and filter out the bad apples. Tip: start by filtering on high values of `similarity-anchor-negative` or `similarity-positive-negative` and low values of `similarity-anchor-positive`.
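
The same filtering heuristic can also be sketched in code, for example to pre-select records before opening the UI. The thresholds and the helper name `flag_suspicious` below are assumptions chosen for illustration; sensible cut-offs depend on your embedding model and data.

```python
def flag_suspicious(row, max_negative_sim=0.8, min_positive_sim=0.3):
    """Return a list of reasons why a triplet looks suspicious (empty if it looks fine)."""
    reasons = []
    if row["similarity-anchor-negative"] >= max_negative_sim:
        reasons.append("negative too close to anchor")
    if row["similarity-positive-negative"] >= max_negative_sim:
        reasons.append("negative too close to positive")
    if row["similarity-anchor-positive"] <= min_positive_sim:
        reasons.append("positive too far from anchor")
    return reasons


# Made-up similarity scores for three records.
rows = [
    {"id": 0, "similarity-anchor-positive": 0.72,
     "similarity-anchor-negative": 0.15, "similarity-positive-negative": 0.20},
    {"id": 1, "similarity-anchor-positive": 0.25,
     "similarity-anchor-negative": 0.10, "similarity-positive-negative": 0.12},
    {"id": 2, "similarity-anchor-positive": 0.70,
     "similarity-anchor-negative": 0.85, "similarity-positive-negative": 0.90},
]

flagged = {}
for row in rows:
    reasons = flag_suspicious(row)
    if reasons:
        flagged[row["id"]] = reasons

print(flagged)
```

Records that end up in `flagged` are good candidates for a manual look in Argilla, rather than something to drop automatically.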

## Prepare the embedding dataset

This section closely follows [Phil Schmid's tutorial](https://www.philschmid.de/fine-tune-embedding-model-for-rag#5-evaluate-fine-tuned-model-against-baseline) on fine-tuning embedding models for RAG.

In [2]:
from datasets import load_dataset

# Load dataset from the hub
dataset = load_dataset("zenml/rag_qa_embedding_questions_0_60_0", split="train")

# Add an id column to the dataset
dataset = dataset.add_column("id", range(len(dataset)))

# split dataset into a 10% test set
dataset = dataset.train_test_split(test_size=0.1)

train_dataset_path = "../data/train_dataset.json"
test_dataset_path = "../data/test_dataset.json"
# save datasets to disk
dataset["train"].to_json(train_dataset_path, orient="records")
dataset["test"].to_json(test_dataset_path, orient="records")

1395773

## Create a baseline and evaluate the pretrained model

In [3]:
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import (
    InformationRetrievalEvaluator,
    SequentialEvaluator,
)
from sentence_transformers.util import cos_sim
from datasets import load_dataset, concatenate_datasets

model_id = "sentence-transformers/all-MiniLM-L6-v2"  # Hugging Face model ID
matryoshka_dimensions = [384, 256, 128, 64]  # Important: large to small

# Load a model
model = SentenceTransformer(
    model_id, device="cuda" if torch.cuda.is_available() else "cpu"
)

# load test dataset
test_dataset = load_dataset("json", data_files=test_dataset_path, split="train")
train_dataset = load_dataset("json", data_files=train_dataset_path, split="train")
corpus_dataset = concatenate_datasets([train_dataset, test_dataset])

# Convert the datasets to dictionaries
corpus = dict(
    zip(corpus_dataset["id"], corpus_dataset["positive"])
)  # Our corpus (cid => document)
queries = dict(
    zip(test_dataset["id"], test_dataset["anchor"])
)  # Our queries (qid => question)

# Create a mapping of relevant document (1 in our case) for each query
relevant_docs = {}  # Query ID to relevant documents (qid => set([relevant_cids])
for q_id in queries:
    relevant_docs[q_id] = [q_id]


matryoshka_evaluators = []
# Iterate over the different dimensions
for dim in matryoshka_dimensions:
    ir_evaluator = InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name=f"dim_{dim}",
        truncate_dim=dim,  # Truncate the embeddings to a certain dimension
        score_functions={"cosine": cos_sim},
    )
    matryoshka_evaluators.append(ir_evaluator)

# Create a sequential evaluator
evaluator = SequentialEvaluator(matryoshka_evaluators)

In [4]:
# Evaluate the model
results = evaluator(model)

# # COMMENT IN for full results
# print(results)

# Print the main score
for dim in matryoshka_dimensions:
    key = f"dim_{dim}_cosine_ndcg@10"
    print(f"{key}: {results[key]}")

dim_384_cosine_ndcg@10: 0.5163534966981647
dim_256_cosine_ndcg@10: 0.5007072121406431
dim_128_cosine_ndcg@10: 0.47107077962798377
dim_64_cosine_ndcg@10: 0.40703812333002265
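
The main score printed above is NDCG@10. As a rough sketch of what it measures, assume (as in the `relevant_docs` mapping above) that every query has exactly one relevant document. Then the ideal DCG is 1 and NDCG@k reduces to `1 / log2(rank + 1)` when the relevant document appears at position `rank` within the top `k`, and 0 otherwise. The ranks below are invented for illustration.

```python
import math


def ndcg_at_k(rank, k=10):
    """NDCG@k for a query with a single relevant document found at 1-based `rank`.

    `rank` is None when the relevant document was not retrieved at all.
    """
    if rank is None or rank > k:
        return 0.0
    # DCG is 1 / log2(rank + 1); the ideal DCG (hit at rank 1) equals 1.
    return 1.0 / math.log2(rank + 1)


def accuracy_at_k(rank, k):
    """1.0 if the relevant document is within the top k results, else 0.0."""
    return 1.0 if rank is not None and rank <= k else 0.0


# Hypothetical ranks of the relevant document for four queries.
ranks = [1, 3, None, 12]

mean_ndcg_at_10 = sum(ndcg_at_k(r) for r in ranks) / len(ranks)
mean_accuracy_at_1 = sum(accuracy_at_k(r, 1) for r in ranks) / len(ranks)
mean_accuracy_at_10 = sum(accuracy_at_k(r, 10) for r in ranks) / len(ranks)
print(mean_ndcg_at_10, mean_accuracy_at_1, mean_accuracy_at_10)
```

Unlike accuracy@k, NDCG rewards placing the relevant document higher in the ranking, which is why it is a convenient single number for comparing models.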


## Define loss function

In [5]:
from sentence_transformers import SentenceTransformerModelCardData, SentenceTransformer

model_id = "sentence-transformers/all-MiniLM-L6-v2"

# load the model with PyTorch SDPA attention (can use fused Flash Attention kernels when available)
model = SentenceTransformer(
    model_id,
    model_kwargs={"attn_implementation": "sdpa"},
    model_card_data=SentenceTransformerModelCardData(
        language="en",
        license="apache-2.0",
        model_name="BGE base Financial Matryoshka",  # NOTE: name carried over from the referenced tutorial; rename to describe your own model
    ),
)

In [6]:
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss

matryoshka_dimensions = [384, 256, 128, 64]  # Important: large to small
inner_train_loss = MultipleNegativesRankingLoss(model)
train_loss = MatryoshkaLoss(
    model, inner_train_loss, matryoshka_dims=matryoshka_dimensions
)
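
To build some intuition for the two losses combined above, here is a simplified, dependency-free sketch. It assumes the inner loss is a cross-entropy over scaled cosine similarities with in-batch negatives (using a scale of 20), and that the Matryoshka wrapper sums the inner loss over embeddings truncated to each prefix length with equal weights. The function names and toy vectors are made up; this is not the `sentence_transformers` implementation.

```python
import math


def normalize(v):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def in_batch_ranking_loss(anchors, positives, scale=20.0):
    """Mean cross-entropy where positives[i] is the target for anchors[i]
    and every other positive in the batch acts as a negative."""
    a = [normalize(v) for v in anchors]
    p = [normalize(v) for v in positives]
    total = 0.0
    for i, a_i in enumerate(a):
        scores = [scale * sum(x * y for x, y in zip(a_i, p_j)) for p_j in p]
        top = max(scores)  # subtract the max for numerical stability
        log_sum_exp = top + math.log(sum(math.exp(s - top) for s in scores))
        total += log_sum_exp - scores[i]
    return total / len(a)


def matryoshka_loss(anchors, positives, dims):
    """Sum the inner loss over embeddings truncated to each prefix length."""
    return sum(
        in_batch_ranking_loss([v[:d] for v in anchors], [v[:d] for v in positives])
        for d in dims
    )


# Toy 4-dimensional embeddings for a batch of two pairs.
anchors = [[1.0, 0.0, 0.2, 0.1], [0.0, 1.0, 0.1, 0.3]]
positives = [[0.9, 0.1, 0.0, 0.2], [0.1, 0.8, 0.3, 0.0]]

loss_full = in_batch_ranking_loss(anchors, positives)
loss_swapped = in_batch_ranking_loss(anchors, positives[::-1])
loss_matryoshka = matryoshka_loss(anchors, positives, dims=[4, 2])
print(loss_full, loss_swapped, loss_matryoshka)
```

Because the truncated prefixes contribute to the loss as well, the model is pushed to keep the most useful information in the leading dimensions, which is what makes the shorter embeddings usable at inference time.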

## Fine-tune the model

In [7]:
from sentence_transformers import SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers

# load train dataset again
train_dataset = load_dataset("json", data_files=train_dataset_path, split="train")

# define training arguments
args = SentenceTransformerTrainingArguments(
    output_dir="bge-base-financial-matryoshka",  # output directory and hugging face model ID
    num_train_epochs=4,  # number of epochs
    per_device_train_batch_size=32,  # train batch size
    gradient_accumulation_steps=16,  # for a global batch size of 512
    per_device_eval_batch_size=16,  # evaluation batch size
    warmup_ratio=0.1,  # warmup ratio
    learning_rate=2e-5,  # learning rate, 2e-5 is a good value
    lr_scheduler_type="cosine",  # use cosine learning rate scheduler
    optim="adamw_torch_fused",  # use fused adamw optimizer
    tf32=True,  # use tf32 precision
    bf16=True,  # use bf16 precision
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    eval_strategy="epoch",  # evaluate after each epoch
    save_strategy="epoch",  # save after each epoch
    logging_steps=10,  # log every 10 steps
    save_total_limit=3,  # save only the last 3 models
    load_best_model_at_end=True,  # load the best model when training ends
    metric_for_best_model="eval_dim_128_cosine_ndcg@10",  # Optimizing for the best ndcg@10 score for the 128 dimension
)
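
The "global batch size of 512" in the comments above is simple arithmetic, sketched below. The number of devices and the number of training examples are assumptions for illustration only, and the update count is a rough estimate; the exact number of steps the trainer takes may differ.

```python
import math

# Values copied from the training arguments above.
per_device_train_batch_size = 32
gradient_accumulation_steps = 16

# Assumptions for illustration only (not taken from the notebook).
num_devices = 1
num_train_examples = 1_500

# Number of examples that contribute to a single optimizer update.
examples_per_update = (
    per_device_train_batch_size * gradient_accumulation_steps * num_devices
)

# Rough estimate of optimizer updates per epoch.
approx_updates_per_epoch = math.ceil(num_train_examples / examples_per_update)
print(examples_per_update, approx_updates_per_epoch)
```

With a small dataset, a large accumulation factor yields only a handful of optimizer updates per epoch, which is worth keeping in mind when choosing `logging_steps` and `warmup_ratio`.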


In [8]:
from sentence_transformers import SentenceTransformerTrainer

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,  # training arguments
    train_dataset=train_dataset.select_columns(
        ["positive", "anchor"]
    ),  # training dataset
    loss=train_loss,
    evaluator=evaluator,
)

In [9]:
# start training; checkpoints are written to the output directory
trainer.train()

# save the best model
trainer.save_model()

# push model to hub
# trainer.model.push_to_hub("bge-base-financial-matryoshka")

Evaluation metrics reported during training (cosine similarity). `Training Loss` and `Validation Loss` were both reported as `No log` for these rows.

| Epoch | Metric | Dim 384 | Dim 256 | Dim 128 | Dim 64 |
|---|---|---|---|---|---|
| 0 | Accuracy@1 | 0.325301 | 0.301205 | 0.307229 | 0.240964 |
| 0 | Accuracy@3 | 0.542169 | 0.53012 | 0.475904 | 0.421687 |
| 0 | Accuracy@5 | 0.60241 | 0.60241 | 0.548193 | 0.475904 |
| 0 | Accuracy@10 | 0.728916 | 0.698795 | 0.662651 | 0.590361 |
| 0 | Precision@1 | 0.325301 | 0.301205 | 0.307229 | 0.240964 |
| 0 | Precision@3 | 0.180723 | 0.176707 | 0.158635 | 0.140562 |
| 0 | Precision@5 | 0.120482 | 0.120482 | 0.109639 | 0.095181 |
| 0 | Precision@10 | 0.072892 | 0.06988 | 0.066265 | 0.059036 |
| 0 | Recall@1 | 0.325301 | 0.301205 | 0.307229 | 0.240964 |
| 0 | Recall@3 | 0.542169 | 0.53012 | 0.475904 | 0.421687 |
| 0 | Recall@5 | 0.60241 | 0.60241 | 0.548193 | 0.475904 |
| 0 | Recall@10 | 0.728916 | 0.698795 | 0.662651 | 0.590361 |
| 0 | NDCG@10 | 0.515565 | 0.500707 | 0.472812 | 0.407038 |
| 0 | MRR@10 | 0.448738 | 0.437545 | 0.413843 | 0.349493 |
| 0 | MAP@100 | 0.456236 | 0.447064 | 0.422149 | 0.358603 |
| 2 | Accuracy@1 | 0.379518 | 0.343373 | 0.325301 | 0.307229 |
| 2 | Accuracy@3 | 0.596386 | 0.590361 | 0.536145 | 0.481928 |
| 2 | Accuracy@5 | 0.686747 | 0.686747 | 0.614458 | 0.572289 |
| 2 | Accuracy@10 | 0.76506 | 0.76506 | 0.73494 | 0.662651 |
| 2 | Precision@1 | 0.379518 | 0.343373 | 0.325301 | 0.307229 |
| 2 | Precision@3 | 0.198795 | 0.196787 | 0.178715 | 0.160643 |
| 2 | Precision@5 | 0.137349 | 0.137349 | 0.122892 | 0.114458 |
| 2 | Precision@10 | 0.076506 | 0.076506 | 0.073494 | 0.066265 |
| 2 | Recall@1 | 0.379518 | 0.343373 | 0.325301 | 0.307229 |
| 2 | Recall@3 | 0.596386 | 0.590361 | 0.536145 | 0.481928 |
| 2 | Recall@5 | 0.686747 | 0.686747 | 0.614458 | 0.572289 |
| 2 | Recall@10 | 0.76506 | 0.76506 | 0.73494 | 0.662651 |
| 2 | NDCG@10 | 0.568495 | 0.556675 | 0.518131 | 0.475449 |
| 2 | MRR@10 | 0.505474 | 0.489604 | 0.449969 | 0.416375 |
| 2 | MAP@100 | 0.512137 | 0.496523 | 0.457559 | 0.425072 |

Sequential Score: 0.358603 (epoch 0), 0.425072 (epoch 2).


## Evaluate fine-tuned model against baseline

In [10]:
from sentence_transformers import SentenceTransformer

fine_tuned_model = SentenceTransformer(
    args.output_dir, device="cuda" if torch.cuda.is_available() else "cpu"
)
# Evaluate the model
results = evaluator(fine_tuned_model)

# # COMMENT IN for full results
# print(results)

# Print the main score
for dim in matryoshka_dimensions:
    key = f"dim_{dim}_cosine_ndcg@10"
    print(f"{key}: {results[key]}")

dim_384_cosine_ndcg@10: 0.5684952818876474
dim_256_cosine_ndcg@10: 0.5566750064887807
dim_128_cosine_ndcg@10: 0.5181307756477362
dim_64_cosine_ndcg@10: 0.47544898005380076
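
To compare these scores with the baseline, the short sketch below computes the absolute and relative NDCG@10 gain per dimension. The numbers are copied from the two evaluation outputs in this notebook; the variable names are only for illustration.

```python
# NDCG@10 per embedding dimension, copied from the evaluation outputs.
baseline = {
    384: 0.5163534966981647,
    256: 0.5007072121406431,
    128: 0.47107077962798377,
    64: 0.40703812333002265,
}
fine_tuned = {
    384: 0.5684952818876474,
    256: 0.5566750064887807,
    128: 0.5181307756477362,
    64: 0.47544898005380076,
}

gains = {}
for dim, before in baseline.items():
    after = fine_tuned[dim]
    gains[dim] = {
        "absolute": after - before,
        "relative": (after - before) / before,
    }

for dim, gain in gains.items():
    print(f"dim {dim}: +{gain['absolute']:.4f} NDCG@10 ({gain['relative']:.1%} relative)")
```

Every dimension improves after fine-tuning, and the fine-tuned 128-dimension score ends up slightly above the 384-dimension baseline score.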
