In [11]:
# !pip install -U huggingface_hub datasets qdrant-client seaborn

In [16]:
import os
import time
import json

import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm

from datasets import load_dataset
from qdrant_client import QdrantClient, models
from qdrant_client.models import PointStruct

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /Users/nirantk/.cache/huggingface/token
Login successful


In [18]:
dataset = load_dataset("nirantk/dbpedia-entities-google-palm-gemini-embedding-001-100K", streaming=False, split='train')
dataset

Dataset({
    features: ['_id', 'title', 'text', 'embedding'],
    num_rows: 100000
})

In [4]:
client = QdrantClient(
    url="https://a4197291-1236-40e0-bf18-18e8843a05a2.us-east4-0.gcp.cloud.qdrant.io:6333", 
    api_key=os.getenv("QDRANT_API_KEY"),
    timeout=100,
    prefer_grpc=True
)

# Setting up a Collection with Binary Quantization

In [5]:
collection_name = "gemini-embedding-001"

In [6]:
# client.recreate_collection(
#     collection_name=f"{collection_name}",
#     vectors_config=models.VectorParams(
#         size=768,
#         distance=models.Distance.COSINE,
#         on_disk=True,
#     ),
#     optimizers_config=models.OptimizersConfigDiff(
#         default_segment_number=5,
#         indexing_threshold=0,
#     ),
#     quantization_config=models.BinaryQuantization(
#         binary=models.BinaryQuantizationConfig(always_ram=True),
#     ),
#     shard_number=2,
# )

In [29]:
collection_info = client.get_collection(collection_name=collection_name)

points=[
        {
            "id": i,
            "vector": embedding,
            "payload": {"text": data["text"], "title": data["title"]}
        }
        for i, (embedding, data) in enumerate(zip(dataset["embedding"], dataset))
    ]
points = [PointStruct(**point) for point in points]

if collection_info.vectors_count == 0:
    print("Collection is empty. Begin indexing.")
    bs = 100 # Batch size    
    for i in tqdm(range(0, len(points), bs)):
        slice_points = points[i:i+bs]  # Create a slice of bs points
        client.upsert(
            collection_name=collection_name,
            points=slice_points
        )
    client.update_collection(
        collection_name=f"{collection_name}",
        optimizer_config=models.OptimizersConfigDiff(
            indexing_threshold=20000
        )
    )

In [20]:
collection_info = client.get_collection(collection_name=collection_name)
collection_info.vectors_count

100000

In [49]:
client.search(
    collection_name=f"{collection_name}",
    query_vector=points[32].vector,
    search_params=models.SearchParams(
        quantization=models.QuantizationSearchParams(
            ignore=False,
            rescore=False,
            oversampling=2.0,
        ),
        exact=True,
    ),
    limit=5
)

[ScoredPoint(id=32, version=0, score=1.0, payload={'text': 'Sobrassada (Catalan pronunciation: [soβɾəˈsaðə]; Spanish: sobrasada) is a raw, cured sausage from the Balearic Islands made with ground pork, paprika and salt and other spices. Sobrassada, along with botifarró are traditional Balearic sausage meat products prepared in the laborious but festive rites that still mark the autumn and winter pig slaughter known as a matança (in Spanish, matanza) in Majorca and Eivissa.', 'title': 'Sobrassada'}, vector=None, shard_key=None),
 ScoredPoint(id=78000, version=780, score=0.7025156617164612, payload={'text': 'Tamborrada of Donostia (in Basque Donostiako Danborrada) is a celebratory drum festival held every year on January 20 in the city of San Sebastián, Spain. At midnight, in the Konstituzio Plaza in the "Alde Zaharra/Parte Vieja" (Old Town), the mayor raises the flag of San Sebastián. The festival lasts for 24 hours. Participants, dressed as cooks and soldiers, march in companies across

In [52]:
ds = dataset.train_test_split(test_size=0.1, shuffle=True, seed=37)['test']

In [41]:
oversampling_range = np.arange(1.0, 3.1, 1.0)
rescore_range = [True, False]

def parameterized_search(
        point, 
        oversampling: float, 
        rescore: bool, 
        exact: bool, 
        collection_name: str, 
        ignore: bool = False,
        limit: int = 10
    ):
    if exact:
        return client.search(
            collection_name=collection_name,
            query_vector=point.vector,
            search_params=models.SearchParams(exact=exact),
            limit=limit
        )
    else:
        return client.search(
            collection_name=collection_name,
            query_vector=point.vector,
            search_params=models.SearchParams(
                quantization=models.QuantizationSearchParams(
                    ignore=ignore,
                    rescore=rescore,
                    oversampling=oversampling,
                ),
                exact=exact,     
            ),
            limit=limit
        )

import loguru

logger = loguru.logger
logger.add("logs.log", format="{time} {level} {message}", level="INFO")

results = []
with open("results.json", "w+") as f:
    for point in tqdm(points[10:100]):
        # print(element.payload["text"])
        # print("Oversampling")

        ## Running Grid Search
        for oversampling in oversampling_range:
            for rescore in rescore_range:
                limit_range = [100, 50, 20, 10, 5, 1]
                for limit in limit_range:
                    try:
                        exact = parameterized_search(point=point, oversampling=oversampling, rescore=rescore, exact=True, collection_name=collection_name, limit=limit)
                        hnsw = parameterized_search(point=point, oversampling=oversampling, rescore=rescore, exact=False, collection_name=collection_name, limit=limit)
                    except Exception as e:
                        print(f"Skipping point: {point}\n{e}")
                        continue

                    exact_ids = [item.id for item in exact]
                    hnsw_ids = [item.id for item in hnsw]
                    logger.info(f"Exact: {exact_ids}")
                    logger.info(f"HNSW: {hnsw_ids}")

                    accuracy = len(set(exact_ids) & set(hnsw_ids)) / len(exact_ids)

                    if accuracy is None:
                        continue

                    result = {
                        "query_id": point.id,
                        "oversampling": oversampling,
                        "rescore": rescore,
                        "limit": limit,
                        "accuracy": accuracy,
                    }
                    f.write(json.dumps(result))
                    f.write("\n")
                    logger.info(result)

  0%|          | 0/90 [00:00<?, ?it/s][32m2023-12-12 17:41:16.110[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m60[0m - [1mExact: [10, 66255, 24225, 16929, 1920, 12330, 85959, 92205, 14612, 81919, 77440, 67979, 9427, 11005, 45368, 23711, 63166, 92916, 77967, 74118, 50111, 40430, 11354, 92981, 69729, 34201, 93541, 73537, 98492, 29672, 77704, 10561, 77488, 25158, 45210, 292, 24395, 75552, 75706, 17899, 16461, 37401, 76977, 18557, 27409, 79674, 86758, 16139, 85973, 54626, 82512, 21251, 61804, 93325, 53426, 11178, 59425, 71625, 63427, 51834, 70509, 49228, 44169, 75198, 37485, 69982, 67391, 16707, 96165, 52468, 98838, 54938, 28464, 8018, 31446, 96817, 86572, 25616, 34994, 78614, 40926, 71403, 48364, 22707, 70459, 91513, 98079, 20433, 46955, 59068, 42365, 5643, 12983, 82536, 39641, 95444, 1659, 9655, 37439, 38499][0m
[32m2023-12-12 17:41:16.112[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m61[0m - [1mHNSW: [10, 24225, 16929, 1920, 12330, 92205, 

In [39]:
df = pd.DataFrame(results)
df[df['score'] > 0.9].groupby(['oversampling', 'rescore', 'k']).mean()['present'].unstack().plot(kind='bar', figsize=(10, 5))
# sum(df.present) / len(df)

KeyError: 'score'