In [36]:
from qdrant_client import QdrantClient,models
from pymongo import MongoClient
import polars as pl

# Client for the Qdrant instance on localhost:6333, with gRPC explicitly not preferred.
qclient = QdrantClient(
    host="localhost",
    port=6333,
    prefer_grpc=False,
)

In [2]:
# Name of the Qdrant collection that the scroll loop below reads from.
collection_name = "ondc-query"

In [3]:
# Page through the whole collection, 100 points at a time, keeping vectors.
# The scroll offset is a point id to start from, so the first request must use
# None (start of the collection); a dummy starting value such as 1 can skip
# points that sort before it. Qdrant returns None as the next offset once the
# last page has been delivered.
offset = None
all_points = []
while True:
    points, offset = qclient.scroll(
        collection_name=collection_name,
        limit=100,
        offset=offset,
        with_vectors=True,
    )
    all_points.extend(points)
    if offset is None:
        break
    

In [4]:
# Keep only the payload of every scrolled point.
points = [record.payload for record in all_points]

In [5]:
# Build a frame from the collected payloads.
df = pl.DataFrame(data=points)

In [6]:
# Lowercase the queries, collect the product ids recorded for each distinct
# query, and store how many ids each query collected in `p_count`.
df = (
    df.select(
        pl.col("query").str.to_lowercase().alias("query"),
        pl.col("product_id"),
    )
    .group_by("query")
    .agg(pl.col("product_id"))
    .with_columns(
        # Native list length instead of calling a Python lambda per row; the
        # cast keeps the column as Int64, which is what a Python `len` result
        # is inferred as.
        pl.col("product_id").list.len().cast(pl.Int64).alias("p_count")
    )
)

In [7]:
# df['p_count'].value_counts().sort("count",descending=True)[:6]["count"].plot(kind="bar")
# Queries whose product-id list holds exactly 2, 3 and 4 entries respectively.
df_2, df_3, df_4 = (
    df.filter(pl.col("p_count") == wanted) for wanted in (2, 3, 4)
)

In [9]:
# Queries unique to a product
# Keep the rows with a single product id, unwrap the one-element list into a
# plain value, and drop the now-constant count column.
df_1 = (
    df.filter(pl.col("p_count") == 1)
    .explode("product_id")
    .drop("p_count")
)


In [10]:
def get_query_id(struct: dict[str,str]):
    """Find the id of the "ondc-query-gen" point matching a query/product pair.

    All points stored for the product are scanned page by page, so a match is
    found even when the product has more points than a single scroll page
    returns.

    Args:
        struct: mapping with a "product_id" key and a "query" key. The query
            is compared against the lowercased payload query, so it must
            already be lowercase to match.

    Returns:
        The id of the first point whose lowercased payload "query" equals
        the given query, or None when no point for that product matches.
    """
    product_id = struct['product_id']
    query = struct['query']
    product_filter = models.Filter(
        must=[
            models.FieldCondition(
                key="product_id",
                match=models.MatchValue(value=product_id)
            ),
        ]
    )

    offset = None
    while True:
        page, offset = qclient.scroll(
            collection_name="ondc-query-gen",
            scroll_filter=product_filter,
            limit=100,
            offset=offset,
        )
        # We need the query_id
        for point in page:
            if point.payload["query"].lower() == query:
                return point.id
        # A None offset means the last page has been read without a match.
        if offset is None:
            return None

In [11]:
# Resolve a query id for every (query, product_id) row of df_1 by passing each
# row to get_query_id as a struct.
pair_expr = pl.struct(["query", "product_id"])
qid1 = df_1.select(
    [
        pair_expr.map_elements(get_query_id, strategy="thread_local").alias("query_id"),
    ]
)

In [31]:
# Manual spot check of the lookup using the first row of df_1.
pid = df_1[0]['product_id'][0]
q = df_1[0]['query'][0]
# NOTE(review): the return value of this call is discarded; only the last
# expression of the cell (p[0].id) is displayed.
get_query_id({"product_id":pid, "query":q})
# Same product_id-filtered scroll that get_query_id performs, repeated here so
# the raw result page can be inspected directly.
p,o = qclient.scroll(
        collection_name="ondc-query-gen",
        scroll_filter=models.Filter(
            must=[
                models.FieldCondition(
                    key="product_id",
                    match=models.MatchValue(value=pid)
                ),
            ]
        )
    )
# Id of the first point returned for this product (no query comparison here).
p[0].id

'1a5af65b-b54f-4db1-975c-2e925b58f367'

In [13]:
# Place query_id in front of the query/product_id columns. Selecting the two
# columns explicitly keeps this cell re-runnable: once query_id has been merged
# into df_1, stacking the whole frame again would fail on the duplicate column.
df_1 = qid1.hstack(df_1.select(["query", "product_id"]))

In [15]:
# Persist the single-product query relevance table as zstd-compressed parquet.
qrel_path = "qrel_1.parquet"
df_1.write_parquet(qrel_path, compression="zstd")

In [34]:
# First query id of the table, kept for the (currently disabled) search below.
qid = df_1.get_column("query_id")[0]
# qclient.search(
#     collection_name="ondc-query-gen",
#     query=pl.col("query").filter(pl.col("query_id")==qid).first(),
#     limit=10
# )

In [None]:
import requests

# Call the recommend endpoint of the "ondc-query-gen" collection over REST,
# using one hard-coded point id as the positive example and the "sparse"
# vector, and display the decoded JSON response.
a = requests.post(
    url="http://localhost:6333/collections/ondc-query-gen/points/recommend",
    json={
        "limit": 10,
        "positive": ["00005bc8-fb19-4806-bcf6-b0ae0fb123a2"],
        "offset": 0,
        # "with_payload": True,
        # "lookup_from":{
        #     "collection": "ondc-index"
        # },
        "with_vector": False,
        # 'with_payload': ["product_name"],
        "using": "sparse",
    },
    # Without a timeout the call blocks forever if the server stops responding.
    timeout=30,
)
a.json()


In [50]:
import joblib
# Restore the previously dumped points from disk.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code; only load an all_points.pkl that was produced by a trusted run.
all_points = joblib.load(filename="all_points.pkl")

In [55]:
# Rebuild every point with a sequential integer id (its position in the list),
# carrying over the payload and vector unchanged.
all_points = [
    models.PointStruct(id=new_id, payload=record.payload, vector=record.vector)
    for new_id, record in enumerate(all_points)
]

In [56]:
# Push the re-identified points into the "ondc-query" collection.
qclient.upload_points(collection_name="ondc-query", points=all_points)