In [None]:
# !pip install -U huggingface_hub datasets qdrant-client seaborn

In [None]:
import os
import time
import json

import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm

from datasets import load_dataset
from qdrant_client import QdrantClient, models
from qdrant_client.models import PointStruct

In [None]:
# DBpedia entities with precomputed embeddings ("100K" per the dataset name),
# downloaded rather than streamed so it supports len(), slicing and column access.
DATASET_ID = "nirantk/dbpedia-entities-google-palm-gemini-embedding-001-100K"
dataset = load_dataset(DATASET_ID, split='train', streaming=False)
dataset

In [None]:
# Qdrant Cloud client.
# - The endpoint can be overridden with the QDRANT_URL environment variable;
#   it defaults to the cluster this benchmark was run against.
# - The API key is only ever read from QDRANT_API_KEY, so no secret lives in the notebook.
# - prefer_grpc=True asks the client to use the gRPC transport where available.
QDRANT_URL = os.getenv(
    "QDRANT_URL",
    "https://a4197291-1236-40e0-bf18-18e8843a05a2.us-east4-0.gcp.cloud.qdrant.io:6333",
)
client = QdrantClient(
    url=QDRANT_URL,
    api_key=os.getenv("QDRANT_API_KEY"),
    timeout=100,
    prefer_grpc=True,
)

# Setting up a Collection with Binary Quantization

In [None]:
# Qdrant collection targeted by every cell below (setup, upload, search, grid search).
collection_name: str = "gemini-embedding-001"

In [None]:
# Create the benchmark collection if it does not exist yet, with this configuration:
# - vectors: cosine distance, stored on disk;
# - binary quantization, with the quantized vectors kept in RAM (always_ram=True);
# - 2 shards and 5 default segments;
# - indexing_threshold=0 during the bulk upload. This is Qdrant's bulk-upload
#   recommendation, and the upload cell raises it to 20000 afterwards.
#
# The existence check keeps this cell idempotent. Unlike recreate_collection, re-running
# the notebook never drops an already-populated collection, while a fresh cluster still
# gets a collection before the get_collection() call below.
existing_collections = {c.name for c in client.get_collections().collections}
if collection_name not in existing_collections:
    client.create_collection(
        collection_name=collection_name,
        vectors_config=models.VectorParams(
            # Taken from the data (previously hard-coded to 768) so it always matches the upload.
            size=len(dataset[0]["embedding"]),
            distance=models.Distance.COSINE,
            on_disk=True,
        ),
        optimizers_config=models.OptimizersConfigDiff(
            default_segment_number=5,
            indexing_threshold=0,
        ),
        quantization_config=models.BinaryQuantization(
            binary=models.BinaryQuantizationConfig(always_ram=True),
        ),
        shard_number=2,
    )

In [None]:
collection_info = client.get_collection(collection_name=collection_name)

# One PointStruct per dataset row. Ids are row positions, so later cells can pick query
# points by index (e.g. points[32]). The list is built unconditionally because the search
# cells below reuse `points` as queries even when the upload is skipped.
points = [
    PointStruct(
        id=i,
        vector=row["embedding"],
        payload={"text": row["text"], "title": row["title"]},
    )
    for i, row in enumerate(dataset)
]

# Why points_count and not vectors_count: vectors_count is deprecated in newer Qdrant
# releases and may come back as None, in which case `== 0` would never trigger the upload.
# `not ...` also treats a missing count as "empty". A redundant re-upload is harmless
# because upserting the same ids overwrites the same points.
if not collection_info.points_count:
    print("Collection is empty. Begin indexing.")
    batch_size = 100
    for start in tqdm(range(0, len(points), batch_size)):
        client.upsert(
            collection_name=collection_name,
            points=points[start:start + batch_size],
        )
    # Raise the threshold from 0 (set at creation) so the optimizer builds the HNSW index
    # now that the bulk upload is done.
    client.update_collection(
        collection_name=collection_name,
        optimizers_config=models.OptimizersConfigDiff(
            indexing_threshold=20000
        ),
    )

In [None]:
# Post-upload sanity check:
# - points_count is how many points are stored.
# - indexed_vectors_count tracks how much of the collection the optimizer has indexed.
#   It presumably lags points_count until index building finishes; verify before benchmarking.
# (vectors_count is deprecated in newer Qdrant releases, so it is not used here.)
collection_info = client.get_collection(collection_name=collection_name)
collection_info.points_count, collection_info.indexed_vectors_count

In [None]:
# Smoke-test query: search with the stored vector of point 32, so that point itself
# should normally appear among the 5 hits.
# NOTE(review): exact=True is combined with quantization params (oversampling=2.0,
# rescore=False) here; confirm how Qdrant applies them in exact mode.
sample_search_params = models.SearchParams(
    exact=True,
    quantization=models.QuantizationSearchParams(
        oversampling=2.0,
        rescore=False,
        ignore=False,
    ),
)
client.search(
    collection_name=collection_name,
    query_vector=points[32].vector,
    search_params=sample_search_params,
    limit=5,
)

In [None]:
# Seeded 90/10 shuffle split of the dataset; only the 10% "test" part is kept.
# NOTE(review): `ds` is not referenced by the grid search below, which queries
# points[10:100] instead. Confirm which query set was intended.
_dataset_splits = dataset.train_test_split(seed=37, shuffle=True, test_size=0.1)
ds = _dataset_splits['test']

In [None]:
# Grid axes for the quantization benchmark: oversampling factors 1.0, 2.0, 3.0
# and rescoring on/off.
oversampling_range = np.arange(1.0, 3.1, 1.0)
rescore_range = [True, False]

def parameterized_search(
        point, 
        oversampling: float, 
        rescore: bool, 
        exact: bool, 
        collection_name: str, 
        ignore: bool = False,
        limit: int = 10
    ):
    """Run one Qdrant search for ``point.vector`` on the global ``client``.

    Args:
        point: Object exposing ``.vector`` (a PointStruct from the upload cell).
        oversampling: Quantization oversampling factor; only used when ``exact`` is False.
        rescore: Whether to rescore with original vectors; only used when ``exact`` is False.
        exact: True -> exact search with plain ``SearchParams(exact=True)``;
            False -> search with explicit quantization parameters.
        collection_name: Collection to query.
        ignore: Passed as ``QuantizationSearchParams.ignore`` in the non-exact path.
        limit: Number of hits to return.

    Returns:
        Whatever ``client.search`` returns (the scored hits).
    """
    if exact:
        # Exact mode: the quantization knobs are deliberately not forwarded.
        params = models.SearchParams(exact=exact)
    else:
        quantization = models.QuantizationSearchParams(
            ignore=ignore,
            rescore=rescore,
            oversampling=oversampling,
        )
        params = models.SearchParams(quantization=quantization, exact=exact)

    return client.search(
        collection_name=collection_name,
        query_vector=point.vector,
        search_params=params,
        limit=limit,
    )

import loguru

logger = loguru.logger
# loguru.logger is one shared object. Calling logger.add() on every re-run of this cell
# would stack another "logs.log" sink and write each record N times. So first remove
# the sink registered by a previous run (its id is kept in _logs_sink_id), then add a fresh one.
if "_logs_sink_id" in globals():
    try:
        logger.remove(_logs_sink_id)
    except ValueError:
        # Sink was already removed elsewhere; nothing to clean up.
        pass
_logs_sink_id = logger.add("logs.log", format="{time} {level} {message}", level="INFO")

# Grid search. For each query point and each top-k limit, an exact (full-scan) search is
# compared with the non-exact quantized search for every (oversampling, rescore) pair.
# Each comparison is written to results.json as one JSON record per line. Its "accuracy"
# is |exact ∩ approximate| / |exact|, i.e. recall@k of the approximate search.
limit_range = [100, 50, 20, 10, 5, 1]


def _overlap_ratio(reference_ids, candidate_ids):
    """Return the fraction of ``reference_ids`` that also occur in ``candidate_ids``.

    Returns None when ``reference_ids`` is empty (e.g. the exact search found nothing),
    so the caller can skip the record instead of raising ZeroDivisionError.
    """
    if not reference_ids:
        return None
    return len(set(reference_ids) & set(candidate_ids)) / len(reference_ids)


with open("results.json", "w") as f:
    for point in tqdm(points[10:100]):
        for limit in limit_range:
            # In exact mode parameterized_search ignores oversampling/rescore, so the
            # exact reference is fetched once per (point, limit), not once per grid cell.
            try:
                exact = parameterized_search(
                    point=point, oversampling=1.0, rescore=False, exact=True,
                    collection_name=collection_name, limit=limit,
                )
            except Exception as e:
                # Best-effort: a failed query (e.g. a timeout) skips this limit, not the run.
                print(f"Skipping point: {point.id}\n{e}")
                continue
            exact_ids = [item.id for item in exact]
            logger.info(f"Exact: {exact_ids}")

            for oversampling in oversampling_range:
                for rescore in rescore_range:
                    try:
                        hnsw = parameterized_search(
                            point=point, oversampling=oversampling, rescore=rescore, exact=False,
                            collection_name=collection_name, limit=limit,
                        )
                    except Exception as e:
                        print(f"Skipping point: {point.id}\n{e}")
                        continue
                    hnsw_ids = [item.id for item in hnsw]
                    logger.info(f"HNSW: {hnsw_ids}")

                    accuracy = _overlap_ratio(exact_ids, hnsw_ids)
                    if accuracy is None:
                        continue

                    result = {
                        "query_id": point.id,
                        # Plain float so logs show 1.0 rather than a numpy scalar repr.
                        "oversampling": float(oversampling),
                        "rescore": rescore,
                        "limit": limit,
                        "accuracy": accuracy,
                    }
                    f.write(json.dumps(result) + "\n")
                    logger.info(result)

In [None]:
# Load the grid-search records into a DataFrame. The grid-search cell writes one JSON
# object per line, hence lines=True. `pd` comes from the imports cell at the top.
results = pd.read_json("results.json", lines=True)

In [None]:
# Mean accuracy per (oversampling, rescore, limit), pivoted so that each row is a top-k
# limit and each column an (oversampling, rescore) pair. Limits 1 and 5 are left out.
average_accuracy = (
    results.loc[~results['limit'].isin([1, 5])]
    .groupby(['oversampling', 'rescore', 'limit'])['accuracy']
    .mean()
    .reset_index()
)
acc = average_accuracy.pivot(
    index='limit',
    columns=['oversampling', 'rescore'],
    values='accuracy',
)
acc