# Finetune an Embeddings Model using Databricks GPU Serverless

Databricks now supports serverless compute accelerated with graphics processing units (GPUs), which can be used to train or fine-tune custom models using the framework of your choice while achieving state-of-the-art efficiency, performance, and quality. This notebook describes how you can use GPU-accelerated Serverless to finetune an embeddings model on your own data using the Sentence Transformer package.

We were inspired by Phil Schmid's (Google DeepMind) article [Fine-tune Embedding models for Retrieval Augmented Generation](https://www.philschmid.de/fine-tune-embedding-model-for-rag) and have adapted the code to run on Databricks, utilizing the [MLflow Sentence Transformer Flavor](https://mlflow.org/docs/latest/ml/deep-learning/sentence-transformers/guide).

## Our Approach

Using the sentence-transformers package, we will finetune the `modernbert-embed-base` model on our own data. Finetuning an embeddings model can yield significant improvements in the retrieval step of a RAG application. Our knowledge base consists of roughly 40 scientific papers covering various GenAI topics.

We will do the following:

1. Load the PDFs (which are stored in an external volume) into a Delta table using Auto Loader with the `binaryFile` format.

2. Convert the binary file content into markdown text using `pymupdf4llm`.

3. Chunk the markdown text into smaller pieces (standard practice for RAG applications).

4. Generate a synthetic question for each chunk.

   For our loss function we will use the `MultipleNegativesRankingLoss`, which expects a dataset structured as the example below. However, we do not have an 'anchor' or question yet.

   ```json
   [
     {
       "anchor": "what is rag?",
       "positive": "rag stands for retrieval augmented generation and is a method to ...",
       "global_chunk_id": "f47ac10b-58cc-4372-a567-0e02b2c3d479"
     }
   ]
   ```

5. Evaluate the retrieval performance of the base `modernbert-embed-base` model.
6. Finetune the model using our chunks and synthetic questions.
7. Evaluate the retrieval performance of our finetuned model.
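
To make the role of the anchor/positive pairs concrete, here is a minimal pure-Python sketch of the idea behind an in-batch-negatives ranking loss: for each anchor, its own positive should score higher than the positives that belong to the other anchors in the batch. This is an illustration under simplifying assumptions (toy 2-dimensional vectors, cosine similarity, a scale factor of 20), not the library's implementation; `cosine` and `mnr_loss` are hypothetical helper names.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(u, v))
    norm_u = math.sqrt(sum(x * x for x in u))
    norm_v = math.sqrt(sum(y * y for y in v))
    return dot / (norm_u * norm_v)


def mnr_loss(anchors, positives, scale=20.0):
    """Mean cross-entropy where positives[i] is the correct match for anchors[i]
    and every other positive in the batch acts as a negative."""
    losses = []
    for i, anchor in enumerate(anchors):
        logits = [scale * cosine(anchor, pos) for pos in positives]
        # Numerically stable log-sum-exp over the whole batch
        max_logit = max(logits)
        log_sum_exp = max_logit + math.log(sum(math.exp(l - max_logit) for l in logits))
        losses.append(log_sum_exp - logits[i])
    return sum(losses) / len(losses)


# Toy embeddings: each anchor points in nearly the same direction as its positive.
anchors = [[1.0, 0.1], [0.1, 1.0]]
positives = [[0.9, 0.0], [0.0, 0.8]]

loss_matched = mnr_loss(anchors, positives)        # correct pairing
loss_swapped = mnr_loss(anchors, positives[::-1])  # deliberately wrong pairing
```

With the correct pairing the loss is close to zero, while the swapped pairing is penalised heavily. This is why every chunk needs a matching question before training can start.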


## Getting Started

Before getting started, you need to connect to a GPU-accelerated Serverless instance.

1. Navigate to the Environment side panel on the rightmost side of the notebook.
2. Set Accelerator to A10G for this demo.
3. You do not need to install any dependencies in the environment panel.
4. Select 3 as your Environment version.
5. Select Apply, then Confirm that you want to apply this environment to your notebook.

Now we can start!


In [0]:
# Install the required packages (version specifiers are quoted so that ">" is not treated as a shell redirect)
%pip install "pymupdf>=1.26.0" "pymupdf4llm>=0.0.24" langchain-text-splitters sentence-transformers==4.1.0 "transformers[torch]" datasets "markdownify>=1.1.0" plotly

dbutils.library.restartPython()

In [0]:
# Define the variables used across the notebook (change these to your own values)

CATALOG = "justin_zweep_gen_ai_demos"
SCHEMA = "embeddings"

volume_name = "papers"
pdf_volume = f"/Volumes/{CATALOG}/{SCHEMA}/{volume_name}"

bronze_table_name = "papers_bronze"
silver_table_name = "papers_silver"

chunks_table_name = "paper_chunks"
chunks_with_questions_table_name = "paper_chunks_w_questions"

# https://huggingface.co/nomic-ai/modernbert-embed-base
model_id = "nomic-ai/modernbert-embed-base"
finetune_model_id = "modernbert-embed-base-finetuned"


In [0]:
# Create objects in unity catalog
spark.sql(f"CREATE CATALOG IF NOT EXISTS {CATALOG}")
spark.sql(f"CREATE SCHEMA IF NOT EXISTS {CATALOG}.{SCHEMA}")

# After creating the volume, upload your PDFs to it.
spark.sql(f"CREATE VOLUME IF NOT EXISTS {CATALOG}.{SCHEMA}.{volume_name}")

## 1. Load the PDFs into a Delta table


In [0]:
from pyspark.sql.functions import col, current_timestamp, current_user

(
    spark.readStream.format("cloudFiles")
    .option("cloudFiles.format", "binaryFile")  # we read the PDF file content as binary
    .option("pathGlobFilter", "*.pdf")  # ensure we only load PDF files
    .load(pdf_volume)
    .select(
        col("content").alias("file_content"),
        col("path").alias("file_path"),
        col("modificationTime").alias("modification_time"),
        col("length").alias("file_size_bytes"),
        current_timestamp().alias("_load_timestamp"),
        current_user().alias("_load_user"),
    )
    .writeStream.format("delta")
    .outputMode("append")
    .option(
        "checkpointLocation", f"{pdf_volume}/_checkpoints_2"
    )  # set a checkpoint so each file is ingested only once
    .trigger(availableNow=True)
    .toTable(f"{CATALOG}.{SCHEMA}.{bronze_table_name}")
    .awaitTermination()
)

## 2. Convert the binary file content into markdown text

First we define a UDF that converts the binary data to markdown, then we apply this to the bronze table and store it in a silver table.


In [0]:
import traceback
import warnings

import fitz
import pymupdf4llm
import pyspark.sql.functions as func
from pyspark.sql.functions import col, udf
from pyspark.sql.types import StringType, StructField, StructType


@udf(
    returnType=StructType(
        [
            StructField("content", StringType(), True),
            StructField("parser_status", StringType(), True),
        ]
    )
)
def parse_file(content_bytes):
    try:
        pdf_doc = fitz.Document(stream=content_bytes, filetype="pdf")
        md_text = pymupdf4llm.to_markdown(pdf_doc)

        parsed_document = {
            "content": md_text.strip(),
            "parser_status": "SUCCESS",
        }
    except Exception as e:
        status = f"An error occurred: {e}\n{traceback.format_exc()}"
        warnings.warn(status)
        parsed_document = {
            "content": "",
            "parser_status": f"ERROR: {status}",
        }
    return parsed_document

In [0]:
from pyspark.sql.functions import reverse, split

(
    spark.read.table(f"{CATALOG}.{SCHEMA}.{bronze_table_name}")
    .withColumn("markdown_content", parse_file(col("file_content")))
    .withColumn("file_name", reverse(split("file_path", "/"))[0])
    .select("file_path", "markdown_content.*", "file_name")
    .write.mode("overwrite")
    .option("mergeSchema", "true")
    .saveAsTable(f"{CATALOG}.{SCHEMA}.{silver_table_name}")
)

In [0]:
spark.read.table(f"{CATALOG}.{SCHEMA}.{silver_table_name}").display()

## 3. Chunk the markdown text

First we define a UDF that chunks the markdown text, then we apply it to the silver table.


In [0]:
from pyspark.sql.functions import col, udf
from pyspark.sql.types import ArrayType, StringType


def create_markdown_chunker(max_sequence_length: int):
    chunk_schema = ArrayType(StringType())

    @udf(returnType=chunk_schema)
    def _chunk_markdown(markdown_text):
        if not markdown_text or markdown_text.strip() == "":
            return []
        try:
            from langchain_text_splitters import MarkdownHeaderTextSplitter
            from langchain_text_splitters import RecursiveCharacterTextSplitter

            headers_to_split_on = [
                ("#", "Header 1"),
                ("##", "Header 2"),
                ("###", "Header 3"),
                ("####", "Header 4"),
            ]
            markdown_splitter = MarkdownHeaderTextSplitter(
                headers_to_split_on, strip_headers=False
            )

            # A token is not the same as a character, but we treat them as equal here.
            # The token-based splitter does not work because serverless has a read-only
            # file system and the splitter needs to download a tokenizer.
            md_header_splits = markdown_splitter.split_text(markdown_text)
            chunk_size = int(max_sequence_length * 0.9)
            chunk_overlap = int(max_sequence_length * 0.1)
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=chunk_size, chunk_overlap=chunk_overlap
            )

            txt_splits = text_splitter.split_documents(md_header_splits)
            splits = [split.page_content for split in txt_splits]

            return splits

        except Exception:
            # Best effort: a document that fails to chunk yields no chunks instead of failing the job
            return []

    return _chunk_markdown


# ModernBERT can handle a sequence length of 8192 tokens. Here we keep chunks to roughly a quarter of that length, so we get more (smaller) chunks to finetune on.
markdown_chunker = create_markdown_chunker(max_sequence_length=int(8192 * 0.25))
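
The 90% chunk size and 10% overlap used by the chunker above can be illustrated with a much simpler fixed-window splitter. This is only a sketch of the size/overlap arithmetic: the real `RecursiveCharacterTextSplitter` prefers to break on separators such as paragraphs and sentences, which this hypothetical `sliding_window_chunks` helper does not attempt.

```python
def sliding_window_chunks(text: str, max_length: int) -> list[str]:
    """Split text into fixed-size character windows that overlap.

    chunk_size is 90% of max_length and chunk_overlap is 10%, mirroring the
    ratios used in the notebook's chunker.
    """
    chunk_size = int(max_length * 0.9)
    chunk_overlap = int(max_length * 0.1)
    step = chunk_size - chunk_overlap  # how far the window advances each time

    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # the window has reached the end of the text
    return chunks


# 250 characters with max_length=100 gives windows of 90 characters that advance by 80.
sample = "".join(str(i % 10) for i in range(250))
chunks = sliding_window_chunks(sample, max_length=100)
```

Each chunk shares its first 10 characters with the end of the previous one, so a sentence that straddles a boundary is still fully present in at least one chunk.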

In [0]:
(
    spark.read.table(f"{CATALOG}.{SCHEMA}.{silver_table_name}")
    .withColumn("chunks", markdown_chunker(col("content")))
    # Explode first, then hash: chunk_text must exist before md5 can reference it
    .select("file_name", func.explode("chunks").alias("chunk_text"))
    .withColumn("chunk_id", func.md5(col("chunk_text")))
    .write.mode("overwrite")
    .option("mergeSchema", "true")
    .saveAsTable(f"{CATALOG}.{SCHEMA}.{chunks_table_name}")
)


In [0]:
# Count the number of chunks
spark.read.table(f"{CATALOG}.{SCHEMA}.{chunks_table_name}").count()

In [0]:
# Ensure every PDF is represented in the chunks table
spark.read.table(f"{CATALOG}.{SCHEMA}.{chunks_table_name}").select(
    "file_name"
).distinct().count()

In [0]:
spark.read.table(f"{CATALOG}.{SCHEMA}.{chunks_table_name}").limit(10).display()

## 4. Generate a synthetic question


In [0]:
from pyspark.sql.functions import expr

# Load the chunks
chunks = spark.read.table(f"{CATALOG}.{SCHEMA}.{chunks_table_name}")

# Use the ai_query SQL function to generate a question for each chunk in the table.
chunks_with_question = chunks.withColumn(
    "ancher",
    expr(
        """ai_query(
            'databricks-meta-llama-3-3-70b-instruct',
            CONCAT(
                'You are a question-answer pair generator. Based on the provided chunk context, create a specific, detailed question that can be fully answered using the information in the chunks. **Requirements:** 1. Carefully read the provided chunks 2. Generate a SPECIFIC question that targets key information from one or more chunks - avoid generic questions like What is the main topic? 3. Try to create a question that requires synthesizing information from multiple chunks, but only if the chunks are related 4. Only generate a single question, no and does it .. 5. Answer the question comprehensively using ONLY the information available in the chunks, but do NOT mention the chunks 6. If the chunks don''t contain sufficient information to answer, assign a score of 0 7. Provide detailed, informative answers of at least 3 sentences 8. Focus on factual content, processes, methods, or specific findings mentioned in the text A key part is that the questions should mimic a real question that a user would ask about the topics in the chunks, while the user was not aware of the chunk''s existence **Example:** Context: The study evaluated three machine learning models for sentiment analysis. Model A achieved 85 percent accuracy, Model B reached 92 percent accuracy, and Model C obtained 78 percent accuracy on the test dataset. Good Question: Which machine learning model performed best in the sentiment analysis evaluation and what were the accuracy scores of all three models? Good Answer: Model B performed best in the sentiment analysis evaluation, achieving the highest accuracy score of 92 percent on the test dataset. Model A achieved 85 percent accuracy, while Model C obtained the lowest score of 78 percent accuracy. The study compared these three models to determine their relative performance on sentiment analysis tasks. ONLY RETURN THE QUESTION',
                chunk_text
            ),
            named_struct('max_tokens', 1200, 'temperature', 0.5)
        )"""
    ),
)

# Write the chunks and their generated questions to a separate table
chunks_with_question.write.mode("overwrite").saveAsTable(
    f"{CATALOG}.{SCHEMA}.{chunks_with_questions_table_name}"
)

In [0]:
chunks_with_question.limit(10).display()

### 4.1 Prepare the data for the finetuning

The sentence-transformers package needs a specific data structure. This section converts the Spark DataFrame into the required format.


In [0]:
from datasets import Dataset, concatenate_datasets

chunks_with_question = spark.read.table(
    f"{CATALOG}.{SCHEMA}.{chunks_with_questions_table_name}"
)

# Convert the Spark DataFrame to pandas and rename the columns to the names the
# loss expects (the question column is spelled "ancher" in the table written above).
df = chunks_with_question.toPandas()
df = df[["ancher", "chunk_text", "chunk_id"]]
df.columns = ["anchor", "positive", "global_chunk_id"]

# Convert from PandasDF to Dataset
dataset = Dataset.from_pandas(df, preserve_index=False)
dataset = dataset.add_column("id", list(range(len(dataset))))

# split dataset into a 10% test set
dataset = dataset.train_test_split(test_size=0.1)

# Use in-memory HuggingFace Datasets objects directly, avoid saving/loading from disk
train_dataset = dataset["train"]
test_dataset = dataset["test"]

In [0]:
# Might be able to simplify this section:
# https://sbert.net/docs/package_reference/sentence_transformer/losses.html#multiplenegativesrankingloss

# Combine train and test datasets into a single corpus
# This ensures we have all possible text chunks available for retrieval evaluation
corpus_dataset = concatenate_datasets([train_dataset, test_dataset])

# Convert datasets into dictionary format required by the InformationRetrievalEvaluator
# corpus: maps corpus IDs to their text chunks (documents)
# Format: {corpus_id: text_chunk}
corpus = dict(zip(corpus_dataset["id"], corpus_dataset["positive"]))

# queries: maps query IDs to their questions
# Format: {query_id: question_text}
queries = dict(zip(test_dataset["id"], test_dataset["anchor"]))

# Create a mapping between queries and their relevant documents
# This tells the evaluator which documents are correct matches for each query
relevant_docs = {}
for q_id, global_chunk_id in zip(test_dataset["id"], test_dataset["global_chunk_id"]):
    # Initialize empty list for each query if not already present
    if q_id not in relevant_docs:
        relevant_docs[q_id] = []

    # Find all corpus entries that share the same global_chunk_id
    # This handles cases where multiple questions can refer to the same text chunk
    matching_corpus_ids = [
        cid
        for cid, chunk in zip(corpus_dataset["id"], corpus_dataset["global_chunk_id"])
        if chunk == global_chunk_id
    ]
    # Add the matching corpus IDs to the relevant documents for this query
    relevant_docs[q_id].extend(matching_corpus_ids)
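
The three dictionaries built above can be easier to follow on a tiny example. The sketch below uses hypothetical records with the same column names and groups corpus IDs by `global_chunk_id` in a single pass, rather than rescanning the corpus for every query. It stores the relevant IDs as sets, which is an assumption about what is convenient here; the cell above uses lists.

```python
from collections import defaultdict

# Hypothetical records mirroring the dataset columns: id, anchor, positive, global_chunk_id
records = [
    {"id": 0, "anchor": "what is rag?", "positive": "rag stands for ...", "global_chunk_id": "a"},
    {"id": 1, "anchor": "how does rag retrieve?", "positive": "rag stands for ...", "global_chunk_id": "a"},
    {"id": 2, "anchor": "what is lora?", "positive": "lora is ...", "global_chunk_id": "b"},
]
test_ids = {1, 2}  # pretend these rows ended up in the test split

# corpus: every chunk is a retrieval candidate; queries: only the test questions
corpus = {r["id"]: r["positive"] for r in records}
queries = {r["id"]: r["anchor"] for r in records if r["id"] in test_ids}

# Group corpus IDs by chunk once, so duplicated chunks share their relevant IDs
ids_by_chunk = defaultdict(set)
for r in records:
    ids_by_chunk[r["global_chunk_id"]].add(r["id"])

relevant_docs = {
    r["id"]: set(ids_by_chunk[r["global_chunk_id"]])
    for r in records
    if r["id"] in test_ids
}
```

Query 1 is considered answered by corpus entries 0 and 1 because both carry the same chunk, while query 2 has a single relevant entry.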

## 5. Evaluate the retrieval performance with `modernbert-embed-base`


In [0]:
import torch
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerModelCardData,
    SentenceTransformerTrainingArguments,
    SentenceTransformerTrainer,
)
from sentence_transformers.evaluation import (
    InformationRetrievalEvaluator,
    SequentialEvaluator,
)
from sentence_transformers.util import cos_sim
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

import mlflow

In [0]:
# Dimensions of interest
matryoshka_dimensions = [768, 512, 256, 128, 64]  # Important: large to small

# Create empty list to hold evaluators
matryoshka_evaluators = []

# Create an evaluator for each above dimension
for dim in matryoshka_dimensions:
    # Define the evaluator
    ir_evaluator = InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name=f"dim_{dim}",
        truncate_dim=dim,  # Truncate the embeddings to the respective dimension
        score_functions={"cosine": cos_sim},
        show_progress_bar=True,
    )
    # Add to list
    matryoshka_evaluators.append(ir_evaluator)

# Create a sequential evaluator
# Able to run all our dimension specific InformationRetrievalEvaluators sequentially.
evaluator = SequentialEvaluator(matryoshka_evaluators)
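
Each evaluator above scores retrieval with embeddings truncated to a smaller dimension. As a rough illustration of what that means, the sketch below keeps only the first `dim` components of two toy vectors and compares them with cosine similarity. The vectors and the helper names `truncate_and_normalize` and `cosine_at_dim` are hypothetical, and real embeddings have 768 dimensions rather than 8.

```python
import math


def truncate_and_normalize(embedding, dim):
    """Keep the first `dim` components and rescale the result to unit length."""
    head = embedding[:dim]
    norm = math.sqrt(sum(x * x for x in head))
    return [x / norm for x in head]


def cosine_at_dim(u, v, dim):
    """Cosine similarity computed on embeddings truncated to `dim` components."""
    u_t = truncate_and_normalize(u, dim)
    v_t = truncate_and_normalize(v, dim)
    return sum(x * y for x, y in zip(u_t, v_t))


# Toy 8-dimensional "embeddings" for one query and one document
query = [0.6, 0.3, 0.1, -0.2, 0.05, 0.4, -0.1, 0.2]
doc = [0.5, 0.4, 0.0, -0.1, 0.10, -0.3, 0.2, 0.1]

# Similarity at full size and at two truncated sizes (large to small)
scores = {dim: cosine_at_dim(query, doc, dim) for dim in (8, 4, 2)}
```

The similarity changes as components are dropped. Matryoshka training aims to keep the ranking of documents stable across these truncated sizes, which is what the per-dimension evaluators measure.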

In [0]:
# Loading via SentenceTransformer
model = SentenceTransformer(
    model_id, device="cuda" if torch.cuda.is_available() else "cpu"
)

In [0]:
# Evaluate the model
with mlflow.start_run(run_name="base-model"):
    input_example = ["Sample domain-specific text"]
    output_example = model.encode(input_example)
    signature = mlflow.models.infer_signature(
        model_input=input_example,
        model_output=output_example,
    )

    model_info = mlflow.sentence_transformers.log_model(
        model=model,
        artifact_path="model",
        # input_example=input_example,
        # output_example=output_example,
        signature=signature,
        task="llm/v1/embeddings",
    )
    base_results = evaluator(model)

    for key, value in base_results.items():
        mlflow.log_metric(key, value)

## 6. Finetune the model using our chunks and synthetic questions


In [0]:
# Load the model with PyTorch's scaled dot-product attention (SDPA) implementation
model = SentenceTransformer(
    model_id,
    model_kwargs={"attn_implementation": "sdpa"},
    model_card_data=SentenceTransformerModelCardData(
        language="en",
        license="apache-2.0",
        model_name="ModernBERT Embed base test with Matryoshka",
    ),
)

In [0]:
# Initial Loss
base_loss = MultipleNegativesRankingLoss(model)

# Matryoshka Loss Wrapper
train_loss = MatryoshkaLoss(model, base_loss, matryoshka_dims=matryoshka_dimensions)

In [0]:
import os

with mlflow.start_run(run_name="fine_tuning_experiment"):
    run_id = mlflow.active_run().info.run_id
    checkpoint_location = f"{pdf_volume}/{run_id}"
    os.makedirs(checkpoint_location, exist_ok=True)

    args = SentenceTransformerTrainingArguments(
        output_dir=checkpoint_location,
        num_train_epochs=1,  # Reduced from 4 - sufficient for test
        per_device_train_batch_size=16,  # Reduced from 32
        gradient_accumulation_steps=2,  # Reduced from 16 - effective batch size of 32
        per_device_eval_batch_size=32,  # Increased for faster eval
        warmup_ratio=0.05,  # Reduced warmup for shorter training
        learning_rate=2e-5,  # Keep same - good starting point
        lr_scheduler_type="linear",  # Linear decay simpler for short training
        optim="adamw_torch_fused",  # Keep fused optimizer
        tf32=True,  # Keep for speed
        bf16=True,  # Keep for memory efficiency
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # Keep for ranking loss
        eval_strategy="epoch",  # Keep - only 1 eval now
        save_strategy="epoch",  # Keep
        logging_steps=5,  # Reduced from 10 - see progress faster
        save_total_limit=1,  # Reduced from 3 - save space
        load_best_model_at_end=True,  # Reload the best checkpoint (by metric_for_best_model) after training
        metric_for_best_model="eval_dim_768_cosine_ndcg@10",  # Keep same metric
        report_to="none",  # Keep disabled
    )

    # Log the SentenceTransformerTrainingArguments to MLflow
    mlflow.log_params(args.to_dict())

    # Initialise the trainer
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset.select_columns(["anchor", "positive"]),
        loss=train_loss,
        evaluator=evaluator,
    )

    # Run the training
    trainer.train()

    # Log metrics:
    for entry in trainer.state.log_history:
        step = entry.get("step", 0)
        for key, value in entry.items():
            if isinstance(value, (int, float)) and key != "step":
                mlflow.log_metric(key.replace("eval_", ""), value, step=step)

    # Log the fine-tuned model

    best_model = trainer.model
    input_example = ["Sample domain-specific text"]
    output_example = best_model.encode(input_example)
    signature = mlflow.models.infer_signature(
        model_input=input_example,
        model_output=output_example,
    )

    model_info = mlflow.sentence_transformers.log_model(
        model=best_model,
        artifact_path="model",
        # input_example=input_example,
        # output_example=output_example,
        signature=signature,
        task="llm/v1/embeddings",
        registered_model_name=f"{CATALOG}.{SCHEMA}.{finetune_model_id}",
    )

## 7. Evaluate the retrieval performance on our new model

To evaluate our finetuned embeddings model, we can look at various metrics. Each addresses a different aspect of retrieval, so select the metric that is most representative of your use case and optimize for it.

**Result Set Composition**

- Accuracy@k: Did we find at least one relevant document in top-k results?
  - `(queries with ≥1 relevant doc in top k) / (total queries)`
- Precision@k: What fraction of retrieved documents are relevant?
  - `(relevant docs in top k) / k`
- Recall@k: What fraction of all relevant documents did we find?
  - `(relevant docs in top k) / (total relevant docs)`
- F1@k: Balanced measure of precision and recall
  - `2 * (Precision@k * Recall@k) / (Precision@k + Recall@k)`

**Ranking Quality & Position**

- NDCG@k: How well are relevant documents ranked? (higher positions = better)
  - `DCG@k / IDCG@k` where `DCG@k = Σ(i=1 to k) rel_i / log2(i + 1)`
- MRR@k: How quickly do we find the first relevant result?
  - `(1/|Q|) Σ(i=1 to |Q|) 1/rank_i`
- MAP@k: Comprehensive ranking assessment across all relevant documents
  - `(1/|Q|) Σ(q=1 to |Q|) AP@k(q)` where `AP@k = (1/min(k, R)) Σ(r=1 to k) (P@r * rel(r))`

The definitions and explanations of the metrics are sourced from this [notebook](https://github.com/ALucek/ft-modernbert-domain/blob/main/FT_Embedding_Models_on_Domain_Specific_Data.ipynb).
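
As a worked example, the sketch below computes these metrics for a single query in plain Python, assuming binary relevance and a non-empty set of relevant documents. The evaluator averages such per-query values over all queries; `retrieval_metrics` is a hypothetical helper written for illustration, not part of sentence-transformers.

```python
import math


def retrieval_metrics(ranked_ids, relevant_ids, k):
    """Per-query retrieval metrics at cut-off k, with binary relevance."""
    top_k = ranked_ids[:k]
    hits = [1 if doc_id in relevant_ids else 0 for doc_id in top_k]
    num_hits = sum(hits)

    precision = num_hits / k
    recall = num_hits / len(relevant_ids)
    f1 = 0.0 if num_hits == 0 else 2 * precision * recall / (precision + recall)

    # Reciprocal rank of the first relevant document (0 if none is in the top k)
    reciprocal_rank = 0.0
    for rank, hit in enumerate(hits, start=1):
        if hit:
            reciprocal_rank = 1.0 / rank
            break

    # DCG discounts each hit by its rank; IDCG is the DCG of a perfect ranking
    dcg = sum(hit / math.log2(rank + 1) for rank, hit in enumerate(hits, start=1))
    ideal_hits = min(k, len(relevant_ids))
    idcg = sum(1.0 / math.log2(rank + 1) for rank in range(1, ideal_hits + 1))

    return {
        "accuracy": 1.0 if num_hits else 0.0,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "mrr": reciprocal_rank,
        "ndcg": dcg / idcg,
    }


# One relevant document ("d1"), retrieved at rank 2 out of 3
metrics = retrieval_metrics(ranked_ids=["d3", "d1", "d7"], relevant_ids={"d1"}, k=3)
```

Here the relevant chunk is found (accuracy and recall of 1), but only one of the three results is relevant (precision of 1/3) and it sits at rank 2, so MRR is 0.5 and NDCG is below 1.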


### 7.1 Extract metrics from mlflow


In [None]:
import pandas as pd

In [0]:
base_model_rn = mlflow.get_run("insert your run id")
base_model_results = pd.DataFrame([base_model_rn.data.metrics])

custom_model_run = mlflow.get_run("insert your run id")
custom_model_results = pd.DataFrame([custom_model_run.data.metrics])

In [0]:
base_model_results_stacked = base_model_results.stack().reset_index()[["level_1", 0]]
base_model_results_stacked.columns = ["metric", "score"]
base_model_results_stacked_clean = base_model_results_stacked["metric"].str.split(
    "_|@", expand=True
)[[1, 3, 4]]
base_model_results_stacked_clean.columns = ["dimension", "metric", "at_k"]
base_model_results_stacked_clean["score"] = base_model_results_stacked["score"]
base_model_results_stacked_clean = base_model_results_stacked_clean[
    base_model_results_stacked_clean.metric.notnull()
]
base_model_results_stacked_clean["model"] = model_id

In [0]:
custom_model_results_stacked = custom_model_results.stack().reset_index()[
    ["level_1", 0]
]
custom_model_results_stacked.columns = ["metric", "score"]
custom_model_results_stacked_clean = custom_model_results_stacked["metric"].str.split(
    "_|@", expand=True
)[[1, 3, 4]]
custom_model_results_stacked_clean.columns = ["dimension", "metric", "at_k"]
custom_model_results_stacked_clean["score"] = custom_model_results_stacked["score"]
custom_model_results_stacked_clean = custom_model_results_stacked_clean[
    custom_model_results_stacked_clean.metric.notnull()
]
custom_model_results_stacked_clean = custom_model_results_stacked_clean[
    custom_model_results_stacked_clean.metric != "cosine"
]
custom_model_results_stacked_clean["model"] = finetune_model_id

In [0]:
results = pd.concat(
    [base_model_results_stacked_clean, custom_model_results_stacked_clean]
)

### 7.2 Plot Results


In [0]:
import plotly.express as px

In [0]:
# f1 does not come out of the box, so let's calculate it.
f1_set = results[results.metric.isin(["precision", "recall"])]
f1_set = f1_set.pivot_table(
    index=["model", "dimension", "at_k"], columns="metric", values="score"
)
f1_set["f1"] = f1_set.apply(
    lambda x: 2 * (x.precision * x.recall) / (x.precision + x.recall), axis=1
)
f1_set = (
    f1_set.reset_index()
    .drop(columns=["precision", "recall"])
    .set_index(["dimension", "model", "at_k"])
    .stack()
    .rename("score")
    .reset_index()
)
f1_set["metric"] = "f1"

# add back to results df
results = pd.concat([results, f1_set])
results = results[results.dimension != "cosine"]

results["at_k"] = results["at_k"].astype(str)
results["dimension"] = results["dimension"].astype(str)

In [0]:
fig = px.bar(
    results[results.metric.isin(["accuracy", "precision", "recall", "f1"])],
    title="Result Set Composition Metrics",
    x="at_k",
    y="score",
    facet_col="dimension",
    facet_row="metric",
    color="model",
    barmode="group",
    height=1000,
    category_orders={
        "dimension": ["768", "512", "256", "128", "64"],
        "at_k": ["1", "3", "5", "10"],
    },
)

fig.update_yaxes(tickformat=".0%", dtick=0.2, range=[0, 1])

# As x is interval, we need to change the type to category to avoid empty bars.
fig.update_xaxes(type="category")

fig.show()

In [0]:
fig = px.bar(
    results[results.metric.isin(["mrr", "ndcg"])],
    title="Ranking Quality Metrics",
    x="at_k",
    y="score",
    facet_col="dimension",
    facet_row="metric",
    color="model",
    barmode="group",
    height=1000,
    category_orders={
        "dimension": ["768", "512", "256", "128", "64"],
    },
    text="score",
)

fig.update_yaxes(
    tickformat=".0%", dtick=0.1, range=[0, 1], showticklabels=False, title=None
)

fig.update_traces(texttemplate="%{text:.0%}", textposition="outside")

fig.show()

## Optional: Serve Model through a Serving Endpoint

Because the model is registered in Unity Catalog with the task `llm/v1/embeddings`, it can be deployed directly to a serving endpoint.
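
Once an endpoint exists, it can be queried over HTTP. The sketch below only builds the request and does not send it. The workspace URL, endpoint name, and token are placeholders, and the `/serving-endpoints/<name>/invocations` path and the `input` payload field are assumptions based on the embeddings task format, so check them against your workspace before relying on them.

```python
import json
import urllib.request


def build_embeddings_request(workspace_url, endpoint_name, token, texts):
    """Build (but do not send) a POST request for an embeddings serving endpoint."""
    url = f"{workspace_url.rstrip('/')}/serving-endpoints/{endpoint_name}/invocations"
    payload = json.dumps({"input": texts}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=payload,
        headers={
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        },
        method="POST",
    )


request = build_embeddings_request(
    workspace_url="https://example.cloud.databricks.com",  # placeholder workspace URL
    endpoint_name="modernbert-embed-base-finetuned",  # placeholder endpoint name
    token="<your-token>",  # placeholder access token
    texts=["what is rag?"],
)

# To actually call the endpoint you would run:
# with urllib.request.urlopen(request) as response:
#     result = json.loads(response.read())
```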

![Model Finetune](model-finetune.png)

![Model Serving](serving-finetune.png)
