In [None]:
from datasets import load_dataset

# Open the pre-computed arXiv title embeddings in streaming mode, so the
# records are pulled on demand while iterating rather than loaded up front.
dataset = load_dataset(
    "Qdrant/arxiv-titles-instructorxl-embeddings",
    split="train",
    streaming=True,
)

In [None]:
# A single shared iterator is consumed sequentially, so the two slices below
# can never contain the same record.
dataset_iterator = iter(dataset)

# The first 60,000 records are the ones uploaded to the collection ...
train_dataset = [next(dataset_iterator) for _ in range(60_000)]

# ... and the following 1,000 records serve as held-out search queries.
test_dataset = [next(dataset_iterator) for _ in range(1_000)]

In [None]:
from qdrant_client import QdrantClient, models

# Connect to the Qdrant instance reachable at localhost:6333.
client = QdrantClient("http://localhost:6333")

# The vector size (768) is the size of the embeddings generated by the
# InstructorXL model; similarity is measured with the cosine distance.
client.create_collection(
    collection_name="arxiv-titles-instructorxl-embeddings",
    vectors_config=models.VectorParams(size=768, distance=models.Distance.COSINE),
)

In [None]:
import time

# Upload the training records. Each point reuses the dataset's own id and
# vector, and keeps the whole record as its payload.
# NOTE: upload_points is available as of qdrant-client v1.7.1.
client.upload_points(
    collection_name="arxiv-titles-instructorxl-embeddings",
    points=[
        models.PointStruct(
            id=item["id"],
            vector=item["vector"],
            payload=item,
        )
        for item in train_dataset
    ],
)

# Block until the collection reports GREEN, which means the indexing is
# finished. Poll with a pause between requests instead of spinning in a tight
# loop, which would flood the server with requests and burn a CPU core.
while True:
    collection_info = client.get_collection(collection_name="arxiv-titles-instructorxl-embeddings")
    if collection_info.status == models.CollectionStatus.GREEN:
        break
    if collection_info.status == models.CollectionStatus.RED:
        # RED signals a failed operation; waiting longer would loop forever.
        raise RuntimeError("Collection status is RED: indexing failed, check the Qdrant server logs")
    time.sleep(1)

In [None]:
def avg_precision_at_k(k: int):
    """Return the mean precision@k of approximate search over ``test_dataset``.

    For every test record two searches are run against the collection with the
    record's vector as the query: a default (approximate) search and a search
    with ``exact=True``. The precision of a single query is the number of ids
    shared by both result sets divided by ``k``.

    Args:
        k: Number of results requested from each search; also the denominator
            of the per-query precision.

    Returns:
        The arithmetic mean of the per-query precision values.
    """
    scores = []
    for query in test_dataset:
        # Default search mode (approximate nearest neighbours).
        approximate_hits = client.search(
            collection_name="arxiv-titles-instructorxl-embeddings",
            query_vector=query["vector"],
            limit=k,
        )

        # Same query, but with the exact search mode turned on.
        exact_hits = client.search(
            collection_name="arxiv-titles-instructorxl-embeddings",
            query_vector=query["vector"],
            limit=k,
            search_params=models.SearchParams(exact=True),
        )

        # precision@k is derived from the overlap of the returned point ids.
        approximate_ids = {hit.id for hit in approximate_hits}
        exact_ids = {hit.id for hit in exact_hits}
        scores.append(len(approximate_ids & exact_ids) / k)

    return sum(scores) / len(scores)

In [None]:
# Baseline measurement, taken before the HNSW parameters are changed.
print("avg(precision@5) = {}".format(avg_precision_at_k(k=5)))

In [None]:
import time

client.update_collection(
    collection_name="arxiv-titles-instructorxl-embeddings",
    hnsw_config=models.HnswConfigDiff(
        m=32,  # Increase the number of edges per node from the default 16 to 32
        ef_construct=200,  # Consider 200 neighbours during index construction instead of the default 100
    ),
)

# The new HNSW parameters only take effect once the index has been rebuilt.
# Wait for the collection to report GREEN again, otherwise the next precision
# measurement may run against an index that is still being rebuilt.
while True:
    # Pause before each check: this avoids a tight polling loop and is presumed
    # to give the server time to begin re-indexing before the first status read.
    time.sleep(1)
    collection_info = client.get_collection(collection_name="arxiv-titles-instructorxl-embeddings")
    if collection_info.status == models.CollectionStatus.GREEN:
        break
    if collection_info.status == models.CollectionStatus.RED:
        # RED signals a failed operation; waiting longer would loop forever.
        raise RuntimeError("Collection status is RED: re-indexing failed, check the Qdrant server logs")

In [None]:
# Repeat the measurement after the HNSW configuration update for comparison.
print("avg(precision@5) = {}".format(avg_precision_at_k(k=5)))