# Understanding the `pylate` API: A Minimal Example

This notebook provides a minimal, step-by-step guide to using the `pylate` library. We will cover the essential components of the API, including:

1.  **Setup and Imports**: Getting your environment ready.
2.  **Loading a Model**: How to load a pre-trained ColBERT-style model.
3.  **Inference (Encoding & Retrieval)**: How to encode documents and queries, build an index, and retrieve relevant documents.
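4.  **Fine-Tuning**: How to prepare a triplet dataset and set up the components needed to fine-tune a base model on custom data.
5.  **The Training Loop**: How to run the fine-tuning with the `SentenceTransformerTrainer`.
6.  **Saving and Loading**: How to load your fine-tuned model from a saved checkpoint for later use.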

By the end of this notebook, you will have a clear understanding of the core workflow for both using and training models with `pylate`.

### 1. Setup and Imports

First, let's import the necessary libraries. We need `torch` for tensor operations, `datasets` to handle our data, and various modules from `pylate` and `sentence_transformers` for modeling, training, and evaluation.

⚠️ **Python 3.12 Compatibility Note**

On older PyTorch releases (before 2.4), `torch.compile` (TorchDynamo) is **not supported on Python 3.12+**.  
ModernBERT, the backbone of the ColBERT model used in this notebook,
decorates some internal functions with `@torch.compile`, which raises a
`RuntimeError` in that setup.  

For production or training workloads we **strongly recommend** either
upgrading PyTorch or creating a virtual environment with Python 3.10
(or 3.9 / 3.11):
```bash
conda create -n pylate-310 python=3.10 pytorch torchvision -c pytorch
conda activate pylate-310
pip install pylate sentence-transformers datasets plaid-index
```
The next code-cell patches `torch.compile` so this notebook can still run on
Python 3.12, **but JIT acceleration will be disabled**.

In [1]:
import sys
import torch

# -------------------------------------------------------------
# Disable torch.compile on Python 3.12+ to avoid Dynamo errors
# -------------------------------------------------------------
if sys.version_info >= (3, 12):
    if hasattr(torch, "compile"):
        def _compile_noop(model=None, *args, **kwargs):
            """Fallback replacement that returns the original model/function."""
            return model if model is not None else (lambda x: x)
        torch.compile = _compile_noop
        print("[INFO] torch.compile disabled (Python 3.12+ detected). "
              "Consider using Python 3.10 for full Torch Dynamo support.")
else:
    print(f"[INFO] Running on Python {sys.version.split()[0]}. torch.compile intact.")

[INFO] Running on Python 3.10.18. torch.compile intact.
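The fallback above has to cover both ways `torch.compile` is commonly applied: as a bare decorator (`@torch.compile`) and as a decorator factory (`@torch.compile(dynamic=True)`). The sketch below reproduces the same no-op on stand-in functions, without importing `torch`, to show that both forms leave the decorated function unchanged. The names `compile_noop`, `double`, and `triple` are illustrative only.

```python
def compile_noop(model=None, *args, **kwargs):
    """Stand-in for the patched torch.compile: returns its input unchanged."""
    # Bare decorator form: the function is passed directly as `model`.
    if model is not None:
        return model
    # Factory form: called with options only, so return an identity decorator.
    return lambda fn: fn


@compile_noop
def double(x):
    return 2 * x


@compile_noop(dynamic=True)
def triple(x):
    return 3 * x


print(double(4), triple(4))
```

Both functions behave exactly as if they had never been decorated, which is why the patch disables acceleration rather than breaking the model code.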


In [2]:
import os

import torch
from datasets import Dataset
from sentence_transformers import (
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)

from pylate import evaluation, indexes, losses, models, retrieve, utils

# Create an output directory for our fine-tuned model
os.makedirs("output/pylate-minimal-example", exist_ok=True)



### 2. Loading a Pre-trained Model

`pylate` makes it easy to load any ColBERT-style model from the Hugging Face Hub. We'll use `lightonai/GTE-ModernColBERT-v1`, which is the base model for the `Reason-ModernColBERT` we aim to recreate.

The `pylate.models.ColBERT` class handles the model architecture. It wraps a standard transformer model and adds the necessary layers for late-interaction retrieval.

In [4]:
print("Loading pre-trained model...")
model_id = "lightonai/GTE-ModernColBERT-v1"
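# "mps" only exists on Apple Silicon, so fall back to CUDA or CPU elsewhere.
device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
model = models.ColBERT(model_name_or_path=model_id, device=device)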
print("Model loaded successfully!")

Loading pre-trained model...
Model loaded successfully!


### 3. Inference: Encoding, Indexing, and Retrieval

Now that we have a model, let's use it for its primary purpose: retrieval. This is a three-step process:
1.  **Encode Documents**: Convert your document collection into vector representations.
2.  **Index Documents**: Store these vectors in an efficient search index.
3.  **Encode Query & Retrieve**: Convert a search query into a vector and use it to find the most similar documents in the index.

#### 3.1. Document Preparation & Encoding

ColBERT is a "late-interaction" model, which means it represents documents as a set of vectors (one for each token) rather than a single vector. This preserves more granular information.
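To make the idea concrete, here is a minimal pure-Python sketch of the MaxSim scoring rule used by ColBERT-style late interaction: each query token vector is matched with its best-scoring document token vector, and those maxima are summed. The toy 2-dimensional vectors and the helper names `dot` and `maxsim_score` are assumptions for illustration; `pylate` computes the same kind of score on batched tensors.

```python
def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def maxsim_score(query_vectors, document_vectors):
    """Sum, over query tokens, of the best dot product against any document token."""
    return sum(
        max(dot(q, d) for d in document_vectors)
        for q in query_vectors
    )


# Two query tokens, each represented by a toy 2-dimensional vector.
query = [[1.0, 0.0], [0.0, 1.0]]
doc_a = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]  # has an exact match for each query token
doc_b = [[0.5, 0.5]]                          # only a partial match for both

score_a = maxsim_score(query, doc_a)
score_b = maxsim_score(query, doc_b)
print(score_a, score_b)
```

Because every query token looks for its own best match, a document is rewarded for covering each part of the query, which a single pooled vector cannot express as directly.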

When encoding, we must tell the model whether we are encoding a **query** or a **document**. This is done with the `is_query` flag.

- **`is_query=False`**: Use this for documents. ColBERT-style models typically prepend a document marker token and truncate the text to a maximum document length.
- **`is_query=True`**: Use this for queries. ColBERT-style models typically prepend a query marker token and pad short queries to a fixed length with mask tokens (query augmentation).
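The difference between the two modes can be sketched on plain token lists. This is only an illustration of the general ColBERT recipe: the marker strings `[Q]` and `[D]`, the `[MASK]` padding, the tiny length limits, and the function name `prepare_tokens` are all assumptions, not `pylate` defaults or internals.

```python
def prepare_tokens(tokens, is_query, query_length=8, document_length=6):
    """Illustrative preprocessing; markers and lengths are assumed values."""
    if is_query:
        marked = (["[Q]"] + tokens)[:query_length]
        # Pad short queries with mask tokens up to a fixed length.
        return marked + ["[MASK]"] * (query_length - len(marked))
    # Documents get their own marker and are truncated, not padded.
    return (["[D]"] + tokens)[:document_length]


query_tokens = prepare_tokens(["capital", "of", "france"], is_query=True)
doc_tokens = prepare_tokens(
    ["paris", "is", "the", "capital", "of", "france"], is_query=False
)
print(query_tokens)
print(doc_tokens)
```

Encoding a document with `is_query=True` (or the reverse) would therefore feed the model a differently shaped input than it saw during training, which is why the flag matters.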

In [5]:
documents = [
    "The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, France.",
    "Photosynthesis is a process used by plants, algae, and certain bacteria to convert light energy into chemical energy.",
    "A CPU, or Central Processing Unit, is the primary component of a computer that executes instructions.",
    "Paris is the capital and most populous city of France."
]
document_ids = ["doc1", "doc2", "doc3", "doc4"]

print("Encoding documents...")
document_embeddings = model.encode(
    documents,
    is_query=False,  # Critical for encoding documents
    show_progress_bar=True
)

print(f"Encoded {len(document_embeddings)} documents.")
print("Shape of the first document's embedding:", document_embeddings[0].shape)

Encoding documents...


Encoding documents (bs=32): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  3.42it/s]

Encoded 4 documents.
Shape of the first document's embedding: (22, 128)





#### 3.2. Indexing Documents

To perform fast retrieval over thousands or millions of documents, we need to store their embeddings in a specialized index. `pylate` integrates with efficient index backends such as `PLAID`, which is optimized for ColBERT's multi-vector representations, and `Voyager`, an HNSW-based approximate nearest-neighbor index. Here, we use `Voyager`.

The index files are written to the directory given by `index_folder`, so the index persists on disk between runs. Setting `override=True` rebuilds it from scratch each time the cell is executed.
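Whatever backend is used, a multi-vector index has to map every token vector back to the document it came from, so that nearest-neighbor hits on tokens can be turned into candidate documents. The toy class below sketches that bookkeeping with an exhaustive search. The class name `TokenIndex`, the `candidates` method, and the `k_token` parameter are hypothetical; a real index replaces the sorted scan with an approximate nearest-neighbor structure.

```python
def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


class TokenIndex:
    """Toy multi-vector index: every token vector remembers its document ID."""

    def __init__(self):
        self.entries = []  # list of (document_id, token_vector) pairs

    def add_documents(self, documents_ids, documents_embeddings):
        for doc_id, vectors in zip(documents_ids, documents_embeddings):
            for vector in vectors:
                self.entries.append((doc_id, vector))

    def candidates(self, query_vectors, k_token=2):
        """IDs of documents owning the k_token closest vectors per query token."""
        found = set()
        for q in query_vectors:
            ranked = sorted(self.entries, key=lambda e: dot(q, e[1]), reverse=True)
            found.update(doc_id for doc_id, _ in ranked[:k_token])
        return found


toy_index = TokenIndex()
toy_index.add_documents(
    documents_ids=["doc1", "doc2"],
    documents_embeddings=[[[1.0, 0.0], [0.9, 0.1]], [[0.0, 1.0]]],
)
print(toy_index.candidates([[1.0, 0.0]], k_token=2))
```

The candidate set is usually much smaller than the collection, so the more expensive late-interaction scoring only needs to run on those few documents.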

In [9]:
print("Creating a Voyager index...")
index = indexes.Voyager(
    index_folder="output/pylate-minimal-index",  # Directory to store index files
    index_name="minimal_example_index",
    override=True,  # Overwrite if it already exists
)

print("Adding documents to the index...")
index.add_documents(
    documents_ids=document_ids,
    documents_embeddings=document_embeddings,
)
print("Documents added successfully.")

Creating a Voyager index...
Adding documents to the index...


Adding documents to the index (bs=2000): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 94.57it/s]

Documents added successfully.





#### 3.3. Query Encoding & Retrieval

Now we encode our search query using `is_query=True` and use the `ColBERT` retriever to search the index.

In [10]:
query = "What is the capital of France?"

print("Encoding query...")
query_embedding = model.encode(
    [query], # Note: encode expects a list of strings
    is_query=True # Critical for encoding queries
)

# Initialize the retriever with our index
retriever = retrieve.ColBERT(index=index)

print("Performing retrieval...")
search_results = retriever.retrieve(
    queries_embeddings=query_embedding,
    k=2 # Retrieve the top 2 most relevant documents
)

print("\nSearch Results:")
for hit in search_results[0]: # Results for the first (and only) query
    print(f"  Document ID: {hit['id']}, Score: {hit['score']:.4f}")

Encoding query...
Performing retrieval...


Retrieving documents (bs=50): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  3.70it/s]


Search Results:
  Document ID: doc4, Score: 29.9538
  Document ID: doc1, Score: 29.5036
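The retriever returns one list of hits per query, and each hit only carries an `id` and a `score`. A small lookup is enough to map those IDs back to the original text. The sketch below hard-codes the two hits printed above so that it runs on its own; the helper name `attach_text` is an assumption for illustration.

```python
document_ids = ["doc1", "doc4"]
documents = [
    "The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, France.",
    "Paris is the capital and most populous city of France.",
]
# Same shape as the retriever output above: one list of hits per query.
search_results = [[
    {"id": "doc4", "score": 29.9538},
    {"id": "doc1", "score": 29.5036},
]]

id_to_text = dict(zip(document_ids, documents))


def attach_text(hits, id_to_text):
    """Return copies of the hits with the matching document text added."""
    return [{**hit, "text": id_to_text[hit["id"]]} for hit in hits]


enriched = attach_text(search_results[0], id_to_text)
for hit in enriched:
    print(f"{hit['score']:.4f}  {hit['id']}  {hit['text']}")
```

Note how close the two scores are: the Eiffel Tower document also mentions Paris and France, which makes it a natural hard negative for the fine-tuning step.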





### 4. Fine-Tuning a Model

While pre-trained models are powerful, fine-tuning them on a domain-specific dataset can significantly boost performance. The training process in `pylate` is built on the `sentence-transformers` `Trainer` API, making it familiar and robust.

The key components are:
- **Dataset**: A collection of training examples, typically triplets of (query, positive_document, negative_document).
- **Model**: The base model to be fine-tuned.
- **Loss Function**: A function that calculates how "wrong" the model's predictions are, guiding it to improve. For retrieval, `Contrastive` loss is common.
- **Trainer**: An object that orchestrates the entire training loop.

#### 4.1. Preparing a Training Dataset

We'll create a tiny dataset of triplets. Each triplet consists of:
- `query`: The search query.
- `positive`: A document that is relevant to the query.
- `negative`: A document that is *not* relevant to the query.

The model learns to score the `(query, positive)` pair higher than the `(query, negative)` pair.
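One common way to express this objective is a softmax cross-entropy over the candidate scores, optionally divided by a temperature. The sketch below is a generic formulation under that assumption, not necessarily the exact formula implemented by `losses.Contrastive`; the function name `contrastive_loss` and the example scores are illustrative.

```python
import math


def contrastive_loss(positive_score, negative_scores, temperature=1.0):
    """Cross-entropy of the positive against all candidates after temperature scaling."""
    logits = [positive_score / temperature] + [s / temperature for s in negative_scores]
    # Log-sum-exp with the maximum subtracted for numerical stability.
    peak = max(logits)
    log_total = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return log_total - logits[0]


# Positive ranked above the negative: small loss.
good = contrastive_loss(positive_score=29.9, negative_scores=[29.5])
# Positive ranked below the negative: larger loss.
bad = contrastive_loss(positive_score=29.5, negative_scores=[29.9])
# A small temperature magnifies the same score gap.
sharp = contrastive_loss(positive_score=29.9, negative_scores=[29.5], temperature=0.02)
print(good, bad, sharp)
```

The loss shrinks as the positive pulls ahead of the negative, and a lower temperature makes even a small score gap count as a confident decision.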

In [11]:
train_samples = {
    "query": [
        "What is the capital of France?", 
        "What does a CPU do?"
    ],
    "positive": [
        "Paris is the capital and most populous city of France.",
        "A CPU, or Central Processing Unit, is the primary component of a computer that executes instructions."
    ],
    "negative": [
        "The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, France.", # Related but not the direct answer
        "Photosynthesis is a process used by plants to convert light energy into chemical energy." # Unrelated
    ]
}

train_dataset = Dataset.from_dict(train_samples)
eval_dataset = Dataset.from_dict(train_samples) # Using the same for simplicity

print("Sample training data:")
print(train_dataset[0])

Sample training data:
{'query': 'What is the capital of France?', 'positive': 'Paris is the capital and most populous city of France.', 'negative': 'The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, France.'}


#### 4.2. Setting up Training Components

Now we define the model, loss, evaluator, and training arguments.

In [12]:
# 1. Model: We'll fine-tune the same model we loaded earlier.
# For a real scenario, you might start from a more general base like 'bert-base-uncased'.
training_model = models.ColBERT(model_name_or_path=model_id)

# 2. Loss Function: Contrastive loss pushes positive pairs closer and negative pairs further apart.
# A temperature around 0.02 is often a good starting point.
train_loss = losses.Contrastive(model=training_model, temperature=0.02)

# 3. Evaluator: This will compute metrics on the evaluation set during training.
dev_evaluator = evaluation.ColBERTTripletEvaluator(
    anchors=eval_dataset["query"],
    positives=eval_dataset["positive"],
    negatives=eval_dataset["negative"],
)

# 4. Data Collator: This prepares batches of data for the ColBERT model.
data_collator = utils.ColBERTCollator(training_model.tokenize)

# 5. Training Arguments: Configure the training process.
output_dir = "output/pylate-minimal-example"
args = SentenceTransformerTrainingArguments(
    output_dir=output_dir,
    num_train_epochs=1,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    fp16=torch.cuda.is_available(),  # Use mixed precision if a GPU is available
    save_strategy="epoch",
    eval_strategy="epoch",  # Called `evaluation_strategy` in older transformers releases
    logging_steps=1,
)

The 'main_distance_function' parameter is deprecated. Please use 'main_similarity_function' instead. 'main_distance_function' will be removed in a future release.


### 5. The Training Loop

With all the components ready, we can initialize the `SentenceTransformerTrainer` and start the fine-tuning process.

In [13]:
trainer = SentenceTransformerTrainer(
    model=training_model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=train_loss,
    evaluator=dev_evaluator,
    data_collator=data_collator,
)

print("Starting fine-tuning...")
# This will run the training loop. It will take a few moments even on this tiny dataset.
trainer.train()
print("Fine-tuning complete!")

Starting fine-tuning...




| Epoch | Training Loss | Validation Loss | Accuracy |
|-------|---------------|-----------------|----------|
| 1     | 0.0           | 0.0             | 1.0      |




Fine-tuning complete!
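The accuracy reported for a triplet evaluation is usually the fraction of triplets in which the positive document scores higher than the negative one. The sketch below assumes that definition; the exact computation inside `ColBERTTripletEvaluator` is not shown in this notebook, and the scores used here are made-up values for illustration.

```python
def triplet_accuracy(positive_scores, negative_scores):
    """Fraction of triplets where the positive outscores the negative."""
    correct = sum(p > n for p, n in zip(positive_scores, negative_scores))
    return correct / len(positive_scores)


all_correct = triplet_accuracy([29.9, 31.2], [29.5, 12.4])
half_correct = triplet_accuracy([29.9, 10.0], [29.5, 12.4])
print(all_correct, half_correct)
```

Keep in mind that the evaluation set in this notebook is the same two triplets the model was trained on, so a perfect accuracy here says little about how the model generalizes.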


### 6. Saving and Loading the Fine-Tuned Model

With `save_strategy="epoch"`, the trainer writes a checkpoint at the end of each epoch to a sub-directory of `output_dir` named `checkpoint-<global_step>`. You can load such a checkpoint for inference just like you loaded the pre-trained model from the Hub.

In [16]:
print(f"Model saved in: {output_dir}")

# Load the fine-tuned model from the output directory
print("\nLoading fine-tuned model...")
# Checkpoints are named after the global training step, not the epoch.
# One epoch over 2 samples with batch size 2 is a single step, so the
# checkpoint lands in output_dir/checkpoint-1.
checkpoint_dir = f"{output_dir}/checkpoint-1"
fine_tuned_model = models.ColBERT(model_name_or_path=checkpoint_dir)
print("Fine-tuned model loaded successfully.")

# You can now use this model for inference just like before
fine_tuned_query_embedding = fine_tuned_model.encode(
    [query],
    is_query=True
)
print("\nSuccessfully encoded a query with the fine-tuned model.")

Model saved in: output/pylate-minimal-example

Loading fine-tuned model...
Fine-tuned model loaded successfully.

Successfully encoded a query with the fine-tuned model.
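Hard-coding `checkpoint-1` only works for this toy run, because the suffix depends on how many training steps were taken. A small helper can pick the newest checkpoint instead. The sketch below assumes the `checkpoint-<step>` naming seen above and demonstrates itself on a temporary directory; the function name `latest_checkpoint` is illustrative.

```python
import os
import re
import tempfile


def latest_checkpoint(output_dir):
    """Return the checkpoint-<step> sub-directory with the highest step, or None."""
    pattern = re.compile(r"^checkpoint-(\d+)$")
    steps = []
    for name in os.listdir(output_dir):
        match = pattern.match(name)
        if match and os.path.isdir(os.path.join(output_dir, name)):
            steps.append(int(match.group(1)))
    if not steps:
        return None
    return os.path.join(output_dir, f"checkpoint-{max(steps)}")


# Demonstrate on a throwaway directory so the sketch runs without training.
with tempfile.TemporaryDirectory() as demo_dir:
    for step in (1, 2, 10):
        os.makedirs(os.path.join(demo_dir, f"checkpoint-{step}"))
    newest = os.path.basename(latest_checkpoint(demo_dir))
    empty = latest_checkpoint(tempfile.mkdtemp())

print(newest)
```

Comparing the step numbers as integers matters: a plain string sort would rank `checkpoint-2` after `checkpoint-10`.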


### Conclusion

This notebook has walked you through the fundamental API of the `pylate` library. You've learned how to:
- Load a pre-trained ColBERT model.
- Encode documents and queries correctly.
- Build a retrieval index and search it.
- Set up a complete fine-tuning pipeline.
- Train a model and save the result.

With this foundation, you are now ready to apply `pylate` to larger, real-world datasets and build powerful, reasoning-intensive retrieval systems.