In [1]:
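# Benchmark Qdrant binary quantization (BQ): upload pre-computed OpenAI embeddings of
# DBpedia entities and compare approximate (quantized) search against exact search.
# Connection settings are read from environment variables (`os.getenv`), which
# `load_dotenv()` below populates from a local `.env` file.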
import json
import os

import loguru
import numpy as np
import pandas as pd
from datasets import load_dataset
from dotenv import load_dotenv
from qdrant_client import QdrantClient, models
from qdrant_client.models import PointStruct
from tqdm import tqdm

# Populate os.environ from a local .env file (QDRANT_URL, QDRANT_API_KEY, ...).
load_dotenv()

# Module-wide logger; additionally persist INFO-and-above records to logs.log.
logger = loguru.logger
logger.add("logs.log", level="INFO", format="{time} {level} {message}")

1

In [2]:
# Which OpenAI embedding variant to benchmark. Together, model name and
# dimensionality select the pre-embedded DBpedia dataset on the Hugging Face Hub
# (and are reused to name the Qdrant collection).
MODEL_NAME = "text-embedding-3-large"
DIMENSIONS = 1536
DATASET_NAME = "Qdrant/dbpedia-entities-openai3-{}-{}-100K".format(MODEL_NAME, DIMENSIONS)
# DATASET_NAME = "Qdrant/dbpedia-entities-openai3-small-512-100K"

In [3]:
# Load the pre-embedded DBpedia entities and turn every row into a Qdrant point:
# id = row position, vector = the pre-computed embedding, payload = text + title.
dataset = load_dataset(DATASET_NAME, split="train", streaming=False)
points = [
    PointStruct(
        id=idx,
        vector=vector,
        payload={"text": row["text"], "title": row["title"]},
    )
    for idx, (vector, row) in enumerate(zip(dataset["embedding"], dataset))
]

In [4]:
# Connect to the Qdrant instance configured through the environment, with a
# request timeout of 100 (seconds).
# NOTE(review): if QDRANT_URL is unset, url=None is passed straight through —
# confirm the client's fallback behaviour is acceptable.
qdrant_url = os.getenv("QDRANT_URL")
qdrant_api_key = os.getenv("QDRANT_API_KEY")
client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key, timeout=100)

In [5]:
# Collection name encodes the embedding model and dimensionality under test.
collection_name = "-".join(["dbpedia", MODEL_NAME, str(DIMENSIONS)])

In [6]:
# (Re)create the collection from scratch: cosine distance over DIMENSIONS-sized
# vectors, binary quantization kept in RAM, two shards. indexing_threshold=0
# defers index building during the bulk upload; it is raised again afterwards.
# NOTE(review): `recreate_collection` is deprecated in newer qdrant-client
# releases in favour of delete/create — revisit when upgrading the client.
vector_params = models.VectorParams(size=DIMENSIONS, distance=models.Distance.COSINE)
bq_config = models.BinaryQuantization(
    binary=models.BinaryQuantizationConfig(always_ram=True),
)
client.recreate_collection(
    collection_name=collection_name,
    vectors_config=vector_params,
    optimizers_config=models.OptimizersConfigDiff(indexing_threshold=0),
    quantization_config=bq_config,
    shard_number=2,
)

True

In [7]:
# Bulk-upload the prepared points in batches of BATCH_SIZE, but only into an
# empty collection so that re-running this cell does not upload everything twice.
BATCH_SIZE = 500

# Gate on an exact count of stored points instead of `collection_info.vectors_count`.
# That field is documented as approximate and is Optional in the client model.
# NOTE(review): on newer Qdrant servers it may come back as None, and `None == 0`
# would silently skip the upload — confirm for your server version.
stored_points = client.count(collection_name=collection_name, exact=True).count

if stored_points == 0:
    logger.info("Collection is empty. Begin upsert.")
    for start in tqdm(range(0, len(points), BATCH_SIZE)):
        batch = points[start : start + BATCH_SIZE]
        client.upsert(collection_name=collection_name, points=batch)
else:
    # Make the skip visible instead of silently doing nothing.
    logger.info(f"Collection already holds {stored_points} points. Skipping upsert.")

[32m2024-01-29 14:16:58.923[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1mCollection is empty. Begin upsert.[0m
100%|██████████| 200/200 [06:13<00:00,  1.87s/it]


In [8]:
# Sanity check: every prepared point should now be stored in the collection.
# An exact count is used because `vectors_count` is only an approximate figure.
uploaded_count = client.count(collection_name=collection_name, exact=True).count
assert uploaded_count == len(points), (
    f"expected {len(points)} points in {collection_name!r}, found {uploaded_count}"
)
uploaded_count

100000

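## Turn on indexing

The collection was created with `indexing_threshold=0` so that no index was built during the bulk upload. Raising the threshold now lets the optimizer build the index over the uploaded vectors.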

In [9]:
# Bulk upload is done: raise indexing_threshold from 0 so the optimizer starts
# indexing the collection.
# NOTE(review): 20000 appears to match Qdrant's default threshold — confirm.
# `optimizers_config` is the parameter's real name (the same one used when the
# collection was created); `optimizer_config` only works as a deprecated alias.
client.update_collection(
    collection_name=collection_name,
    optimizers_config=models.OptimizersConfigDiff(indexing_threshold=20000),
)

True

In [10]:
# Smoke test: query with a vector that is already stored (point 32). Its own id
# should come back first with a score of ~1.0.
# NOTE(review): exact=True is combined with quantization params here — check
# whether Qdrant applies oversampling/rescore at all for exact searches.
demo_search_params = models.SearchParams(
    exact=True,
    quantization=models.QuantizationSearchParams(
        oversampling=2.0,
        rescore=False,
        ignore=False,
    ),
)
client.search(
    collection_name=collection_name,
    query_vector=points[32].vector,
    limit=5,
    search_params=demo_search_params,
)

[ScoredPoint(id=32, version=0, score=1.0000001, payload={'text': "Ray Bidwell Collins (December 10, 1889 – July 11, 1965) was an American character actor in stock and Broadway theatre, radio, films and television. With 900 stage roles to his credit, he became one of the most successful actors in the developing field of radio drama. A friend and associate of Orson Welles for many years, Collins went to Hollywood with the Mercury Theatre company and made his feature film debut in Citizen Kane, as Kane's ruthless political rival.", 'title': 'Ray Collins (actor)'}, vector=None, shard_key=None),
 ScoredPoint(id=6108, version=12, score=0.5336138, payload={'text': 'Charles Henry Collingwood (born 30 May 1943), is a British actor.', 'title': 'Charles Collingwood (actor)'}, vector=None, shard_key=None),
 ScoredPoint(id=45230, version=90, score=0.51626015, payload={'text': 'Raymond ‘Ray’ Carney (born February 28, 1947), is an American scholar and critic, primarily known for his work as a film th

## Create a validation split to compare exact search with approximate (BQ) search

In [None]:
# Hold out a reproducible 10% sample (seed 37) of the dataset for evaluation.
# NOTE(review): `ds` is not used by the grid search below, which queries
# points[10:100] instead — either wire it in or drop this cell.
eval_splits = dataset.train_test_split(test_size=0.1, shuffle=True, seed=37)
ds = eval_splits["test"]

In [None]:
# Grid-search axes: oversampling factors 1x/2x/3x (float64 array) and rescoring on/off.
# Listed explicitly rather than via a float arange with a fudged stop value.
oversampling_range = np.array([1.0, 2.0, 3.0])
rescore_range = [True, False]


def parameterized_search(
    point,
    oversampling: float,
    rescore: bool,
    exact: bool,
    collection_name: str,
    ignore: bool = False,
    limit: int = 10,
):
    """Search ``collection_name`` for the neighbours of ``point.vector``.

    Args:
        point: Query object; only its ``.vector`` attribute is read.
        oversampling: Quantization oversampling factor (approximate mode only).
        rescore: Whether quantized candidates are rescored (approximate mode only).
        exact: If True, request an exact search and ignore every quantization
            argument; otherwise search with the given quantization params.
        collection_name: Qdrant collection to query.
        ignore: Forwarded as ``QuantizationSearchParams.ignore`` (approximate mode only).
        limit: Number of hits to return.

    Returns:
        The result of ``client.search`` on the module-level ``client``.
    """
    if exact:
        params = models.SearchParams(exact=exact)
    else:
        quantization = models.QuantizationSearchParams(
            ignore=ignore,
            rescore=rescore,
            oversampling=oversampling,
        )
        params = models.SearchParams(quantization=quantization, exact=exact)

    return client.search(
        collection_name=collection_name,
        query_vector=point.vector,
        search_params=params,
        limit=limit,
    )


# Grid search over (oversampling, rescore, limit) for a fixed set of query points.
# For each combination, the ids returned by the approximate (quantized) search are
# compared with the exact-search ids. Their overlap |exact ∩ approx| / |exact|
# (recall@limit against exact search) is stored under the key "accuracy".
# Records are written to results.json as JSON Lines (one object per line).
limit_range = [100, 50, 20, 10, 5, 1]

with open("results.json", "w", encoding="utf-8") as f:
    for point in tqdm(points[10:100]):
        # Exact search ignores oversampling/rescore (see `parameterized_search`),
        # so compute the ground truth once per limit rather than once per grid cell.
        exact_ids_by_limit = {}
        for limit in limit_range:
            try:
                exact = parameterized_search(
                    point=point,
                    oversampling=1.0,  # unused when exact=True
                    rescore=False,  # unused when exact=True
                    exact=True,
                    collection_name=collection_name,
                    limit=limit,
                )
            except Exception as e:
                # Best effort: log by id (not the whole 1536-d vector) and move on.
                logger.warning(
                    f"Skipping exact search for point {point.id} (limit={limit}): {e}"
                )
                continue
            exact_ids_by_limit[limit] = [item.id for item in exact]

        # Same iteration order as before (oversampling -> rescore -> limit), so the
        # records in results.json keep their order.
        for oversampling in oversampling_range:
            for rescore in rescore_range:
                for limit in limit_range:
                    exact_ids = exact_ids_by_limit.get(limit)
                    if not exact_ids:
                        # No ground truth (exact search failed or returned no hits);
                        # dividing by len(exact_ids) would also raise ZeroDivisionError.
                        continue
                    try:
                        hnsw = parameterized_search(
                            point=point,
                            oversampling=oversampling,
                            rescore=rescore,
                            exact=False,
                            collection_name=collection_name,
                            limit=limit,
                        )
                    except Exception as e:
                        logger.warning(
                            f"Skipping approximate search for point {point.id} "
                            f"(oversampling={oversampling}, rescore={rescore}, "
                            f"limit={limit}): {e}"
                        )
                        continue

                    hnsw_ids = [item.id for item in hnsw]
                    logger.info(f"Exact: {exact_ids}")
                    logger.info(f"HNSW: {hnsw_ids}")

                    accuracy = len(set(exact_ids) & set(hnsw_ids)) / len(exact_ids)
                    result = {
                        "query_id": point.id,
                        "oversampling": oversampling,
                        "rescore": rescore,
                        "limit": limit,
                        "accuracy": accuracy,
                    }
                    f.write(json.dumps(result) + "\n")
                    logger.info(result)

In [None]:
# Reload the grid-search records (JSON Lines: one object per line) as a DataFrame.
# `pandas` is imported with the other dependencies in the first cell.
results = pd.read_json("results.json", lines=True)

# Fail early with a clear message if the grid search wrote nothing usable
# (e.g. every search failed), instead of a KeyError in the aggregation below.
_expected_columns = {"query_id", "oversampling", "rescore", "limit", "accuracy"}
assert _expected_columns <= set(results.columns), (
    f"results.json is missing fields: {sorted(_expected_columns - set(results.columns))}"
)

In [None]:

# Mean accuracy per (oversampling, rescore, limit), leaving out the two smallest
# limits (1 and 5), then pivoted so that each row is a limit and each column an
# (oversampling, rescore) configuration.
# results.to_csv("results.csv", index=False)
kept_rows = results.loc[~results["limit"].isin([1, 5])]
average_accuracy = (
    kept_rows.groupby(["oversampling", "rescore", "limit"])["accuracy"]
    .mean()
    .reset_index()
)
acc = average_accuracy.pivot(
    index="limit",
    columns=["oversampling", "rescore"],
    values="accuracy",
)
acc