In [None]:
# https://github.com/lightonai/pylate
from pylate import indexes, models, retrieve

# Configuration: the values a reader is most likely to change.
MODEL_NAME = "colbert-ir/colbertv2.0"  # Hugging Face checkpoint to load
INDEX_FOLDER = "pylate-index"  # index folder, relative to the working directory
INDEX_NAME = "index"

# Load the ColBERT encoder.
# NOTE(review): loading this checkpoint logs "No sentence-transformers model
# found" plus an embedding-resize message (see output). These appear to be
# expected for this checkpoint rather than errors — confirm against PyLate docs.
model = models.ColBERT(
    model_name_or_path=MODEL_NAME,
)

# Create the PLAID index.
# NOTE(review): `override=True` presumably discards any existing index named
# INDEX_NAME in INDEX_FOLDER, so each Run All starts from an empty index.
# Do not reuse this setting against an index you want to keep.
index = indexes.PLAID(
    index_folder=INDEX_FOLDER,
    index_name=INDEX_NAME,
    override=True,
)

# Retriever that searches the index built above.
retriever = retrieve.ColBERT(index=index)

  from .autonotebook import tqdm as notebook_tqdm
No sentence-transformers model found with name colbert-ir/colbertv2.0.
The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


In [2]:
documents = [
    "ColBERT’s late-interaction keeps token-level embeddings to deliver cross-encoder-quality ranking at near-bi-encoder speed, enabling fine-grained relevance, robustness across domains, and hardware-friendly scalable search.",

    "PLAID compresses ColBERT token vectors via product quantization to shrink storage by 10×, uses two-stage centroid scoring for sub-200 ms latency, and plugs directly into existing ColBERT pipelines.",

    "PyLate is a library built on top of Sentence Transformers, designed to simplify and optimize fine-tuning, inference, and retrieval with state-of-the-art ColBERT models. It enables easy fine-tuning on both single and multiple GPUs, providing flexibility for various hardware setups. PyLate also streamlines document retrieval and allows you to load a wide range of models, enabling you to construct ColBERT models from most pre-trained language models.",
]

# One id per document, derived from its 1-based position ("1", "2", ...).
# Deriving the ids means the two lists cannot drift out of sync when
# documents are added or removed.
documents_ids = [str(position) for position in range(1, len(documents) + 1)]

# Encode the documents
documents_embeddings = model.encode(
    documents,
    batch_size=32,
    is_query=False,  # document mode; queries are encoded with is_query=True
    show_progress_bar=True,
)

# Add the documents ids and embeddings to the PLAID index.
# add_documents returns the index object; the trailing `;` stops Jupyter
# from displaying its uninformative repr as the cell output.
index.add_documents(
    documents_ids=documents_ids,
    documents_embeddings=documents_embeddings,
);

Encoding documents (bs=32): 100%|██████████| 1/1 [00:02<00:00,  2.69s/it]


<pylate.indexes.fast_plaid.FastPlaid at 0x3231f09d0>

In [4]:
# Queries to run against the index. They are kept in a named list so that the
# i-th result list in `scores` can be matched back to queries[i].
queries = ["What is Pylate library?", "Explain me what is colbert."]

queries_embeddings = model.encode(
    queries,
    batch_size=32,
    is_query=True,  # query mode; documents were encoded with is_query=False
    show_progress_bar=True,
)

# Top-k search. k is an upper bound: a query gets back at most k hits, and
# fewer when the index holds fewer than k documents.
scores = retriever.retrieve(
    queries_embeddings=queries_embeddings,
    k=10,
)

# Expect one ranked list of {"id": ..., "score": ...} dicts per query.
assert len(scores) == len(queries), "expected one result list per query"

scores

Encoding queries (bs=32): 100%|██████████| 1/1 [00:00<00:00, 14.06it/s]


[[{'id': '3', 'score': 27.17529296875}, {'id': '2', 'score': 5.95947265625}, {'id': '1', 'score': 5.05743408203125}], [{'id': '1', 'score': 17.4180908203125}, {'id': '3', 'score': 16.8282470703125}, {'id': '2', 'score': 15.77056884765625}]]
