In [None]:
import os

# Keep Hugging Face / PyTorch caches in a directory under the current working
# directory. These variables are assigned before torch and
# sentence_transformers are imported further down in this cell.
PATH = f"{os.getcwd()}/.cache/huggingface"
os.environ.update(
    {
        "HF_HOME": PATH,
        "HF_DATASETS_CACHE": PATH,
        "TORCH_HOME": PATH,
    }
)

import torch
from sentence_transformers import SentenceTransformer
from tqdm.notebook import tqdm
from qdrant_client import QdrantClient
from qdrant_client.models import (
    PointStruct,
    Distance,
    VectorParams,
    SparseVectorParams,
    Modifier,
    Prefetch,
    SparseVector,
    FusionQuery,
    Fusion,
)
import pandas as pd
import math
from tqdm.notebook import tqdm
from BM25 import BM25
from pprint import pprint
import numpy as np

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(_device_name)

# Dense encoder: SentenceTransformer checkpoint stored under ./trained_models.
model = SentenceTransformer("./trained_models/all_mpnet_base_v2", device=DEVICE)

# Sparse encoder: project-local BM25 helper, pointed at the stopword directory
# and configured for English and Bengali.
_stopwords_dir = os.path.abspath("./stopwords")
bm25 = BM25(stopwords_dir=_stopwords_dir, languages=["english", "bengali"])

In [None]:
# Connection settings for the Qdrant instance and the collection this notebook
# creates, fills and queries.
COLLECTION_NAME = "product_collection_all_mpnet_base_v2_trained"
QDRANT_URL = "http://localhost:6333"
QDRANT_TIMEOUT = 600  # passed straight through as QdrantClient's `timeout`

client = QdrantClient(url=QDRANT_URL, timeout=QDRANT_TIMEOUT)

In [None]:
# Start from a clean slate: drop the collection under COLLECTION_NAME, then
# recreate it with one named dense vector and one named sparse vector.
client.delete_collection(collection_name=COLLECTION_NAME)

# NOTE(review): size=768 is presumably the embedding width of `model` --
# confirm whenever the encoder checkpoint changes.
dense_vectors_config = {
    "dense_vector": VectorParams(size=768, distance=Distance.COSINE),
}
sparse_vectors_config = {
    "sparse_vector": SparseVectorParams(modifier=Modifier.IDF),
}

client.create_collection(
    collection_name=COLLECTION_NAME,
    vectors_config=dense_vectors_config,
    sparse_vectors_config=sparse_vectors_config,
)

In [None]:
# Load the product catalogue and replace NaN cells with None, which is the
# value format_product_details checks for when a description is missing.
# NOTE(review): on older pandas releases `replace(np.nan, None)` may forward-fill
# instead of inserting None -- confirm against the installed pandas version.
product_info_df = pd.read_csv("./datasets/final_5000_products.csv").replace(
    np.nan, None
)
product_info_df.head(5)

In [29]:
def format_product_details(name, price, description):
    """Render one product as the plain-text document used for embedding.

    Args:
        name: Product title; interpolated after the "Name: " label.
        price: Product price; interpolated as-is and followed by " taka".
        description: Free-text description, or None when the product has none.

    Returns:
        A string with a "Name" line and a "Price" line. When ``description``
        is not None it is appended on a further line; any other value,
        including an empty string, is still appended.
    """
    sections = [f"Name: {name}", f"Price: {price} taka"]
    if description is not None:
        sections.append(f"{description}")
    return "\n".join(sections)

In [30]:
# Batching parameters: rows per request, total row count, and the number of
# batches that count implies (the last batch may be partial).
batch_size = 10
total_row = len(product_info_df)
total_batch = math.ceil(total_row / batch_size)

In [31]:
# One formatted text document per product row, in DataFrame order.
documents = [
    format_product_details(title, price, description)
    for title, description, price in zip(
        product_info_df["title"].tolist(),
        product_info_df["description"].tolist(),
        product_info_df["price"].tolist(),
    )
]

In [None]:
# Hand the full list of formatted documents to the BM25 helper, then print its
# `avg_len` attribute (presumably the average document length computed by the
# call above -- verify against the BM25 module).
bm25.calculate_avg_doc_len(documents)
print(bm25.avg_len)

In [None]:
# Embed the catalogue in slices of `batch_size` rows and upsert each slice
# into Qdrant. The DataFrame index label of a row is used as its point id.
for start in tqdm(range(0, total_row, batch_size)):
    batch = product_info_df.iloc[start : start + batch_size]

    # Pull every column out once as native Python values; the same lists feed
    # both the embedding text and the stored payload.
    point_ids = batch.index.tolist()
    titles = batch["title"].tolist()
    descriptions = batch["description"].tolist()
    prices = batch["price"].tolist()

    texts_for_embedding = [
        format_product_details(title, price, description)
        for title, description, price in zip(titles, descriptions, prices)
    ]
    dense_vectors = model.encode(texts_for_embedding)
    sparse_vectors = bm25.raw_embed(texts_for_embedding)

    points = [
        PointStruct(
            id=point_id,
            vector={
                # Plain list of Python floats rather than a NumPy row, so the
                # request model receives the type it declares.
                "dense_vector": dense_vector.tolist(),
                # Built the same way as in `query`, which unpacks the
                # raw_embed output into SparseVector.
                "sparse_vector": SparseVector(**sparse_vector),
            },
            payload={
                "title": title,
                "description": description,
                "price": price,
            },
        )
        for point_id, title, description, price, dense_vector, sparse_vector in zip(
            point_ids, titles, descriptions, prices, dense_vectors, sparse_vectors
        )
    ]

    # wait=True is passed so the call returns only after Qdrant reports the result.
    operation_info = client.upsert(
        collection_name=COLLECTION_NAME, wait=True, points=points
    )
    print(operation_info, end="\r")

In [None]:
def query(
    query_text: str,
    limit: int = 5,
    hybrid: bool = False,
    prefetch_limit: int = 10,
):
    """Search the product collection for ``query_text``.

    Args:
        query_text: Free-text search string.
        limit: Maximum number of points to return.
        hybrid: When False (the default) only the dense vector is searched.
            When True, dense and BM25 sparse candidates are prefetched and
            merged with reciprocal rank fusion.
        prefetch_limit: Candidates fetched per vector type before fusion;
            only used when ``hybrid`` is True.

    Returns:
        A list of ``{"score": ..., "payload": ...}`` dicts, one per returned
        point, in the order Qdrant returned them.
    """
    dense_vector = model.encode([query_text])[0].tolist()

    if hybrid:
        # The sparse embedding is only computed when it is actually queried.
        sparse_vector = bm25.raw_embed([query_text])[0]
        prefetch = [
            Prefetch(query=dense_vector, using="dense_vector", limit=prefetch_limit),
            Prefetch(
                query=SparseVector(**sparse_vector),
                using="sparse_vector",
                limit=prefetch_limit,
            ),
        ]
        results = client.query_points(
            collection_name=COLLECTION_NAME,
            prefetch=prefetch,
            query=FusionQuery(fusion=Fusion.RRF),
            with_payload=True,
            limit=limit,
        )
    else:
        results = client.query_points(
            collection_name=COLLECTION_NAME,
            query=dense_vector,
            using="dense_vector",
            with_payload=True,
            limit=limit,
        )

    return [{"score": point.score, "payload": point.payload} for point in results.points]

In [None]:
# Run the search helper on a sample query and pretty-print the returned
# score/payload dicts.
query_result = query("small smartphone")
pprint(query_result)