In [1]:
import os

PATH = os.getcwd() + "/.cache/huggingface"
os.environ["HF_HOME"] = PATH
os.environ["HF_DATASETS_CACHE"] = PATH
os.environ["TORCH_HOME"] = PATH
os.environ["HSA_OVERRIDE_GFX_VERSION"] = "10.3.0"
os.environ["HIP_VISIBLE_DEVICES"] = "0"

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from tqdm.notebook import tqdm
from qdrant_client import QdrantClient
from qdrant_client.models import (
    PointStruct,
    Distance,
    VectorParams,
    SparseVectorParams,
    Modifier,
)
import pandas as pd
import math
from tqdm.notebook import tqdm
from pprint import pprint
from fastembed.sparse.bm25 import Bm25

In [2]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("../models/all-MiniLM-L6-v2-tokenizer")
model = AutoModel.from_pretrained("../models/all-MiniLM-L6-v2-model")
model = model.to(DEVICE)

bm25_embed_model = Bm25("Qdrant/bm25", cache_dir=".cache/")

In [3]:
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]
    input_mask_expanded = (
        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    )
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )


def dense_embedding(texts: list[str]):
    encoded_queries = tokenizer(
        texts, padding=True, truncation=True, return_tensors="pt"
    )

    with torch.no_grad():
        queries_outputs = model(**encoded_queries)

    embeddings = mean_pooling(queries_outputs, encoded_queries["attention_mask"])
    embeddings = F.normalize(embeddings, p=2, dim=1)
    return embeddings.cpu().numpy()


def sparse_embedding(texts: list[str]):
    embedding = list(bm25_embed_model.passage_embed(texts))
    return embedding

In [4]:
COLLECTION_NAME = "product_collection_all-MiniLM-L6-v2"

In [9]:
client = QdrantClient(url="http://localhost:6333", timeout=600)
client.create_collection(
    collection_name=COLLECTION_NAME,
    vectors_config={
        "modern_bert_base": VectorParams(size=384, distance=Distance.COSINE)
    },
    sparse_vectors_config={"bm25": SparseVectorParams(modifier=Modifier.IDF)},
)

True

In [10]:
product_info_df = pd.read_csv("./final_5000_products.csv")
product_info_df.head(5)

Unnamed: 0,title,price,description
0,ZKTeco uFace302 Multi-Biometric T&A Access Con...,28500,ZKTeco uFace302 Multi-Biometric T&A Access Con...
1,HP 682 Black Original Ink Advantage Cartridge,1300,"HP 682 Black Original Ink Advantage Cartridge,..."
2,HP 508A Yellow Original LaserJet Toner (Bundle...,24500,What is the price of HP 508A Yellow Toner in B...
3,Transcend ESD300P 1TB Type-C Portable SSD,10500,"Transcend ESD300P 1TB Type-C Portable SSD,The ..."
4,HP 307A Yellow LaserJet Toner Cartridge,30500,"HP 307A Yellow LaserJet Toner Cartridge,HP 307..."


In [15]:
total_row = product_info_df.shape[0]
batch_size = 20
total_batch = math.ceil(total_row / batch_size)

In [16]:
start_idx = 0
progress_bar = tqdm(total=total_batch)
cnt = 0

while True:
    end_idx = min(total_row, start_idx + batch_size)
    # print(f"{start_idx=} {end_idx=}")

    batch = product_info_df.iloc[start_idx:end_idx]
    
    titles = batch["title"].tolist()
    descriptions = batch["description"].tolist()
    prices = batch["price"].tolist()

    texts_for_embedding = [f"Title:{title} | Description: {description} | Price: {price}" for title, description, price in zip(titles, descriptions, prices)]
    dense_vectors = dense_embedding(texts_for_embedding)
    sparse_vectors = sparse_embedding(texts_for_embedding)

    
    points = []
    for idx, (batch_idx, row) in enumerate(batch.iterrows()):
        title = row["title"]
        description = row["description"]
        price = row["price"]

        points.append(
            PointStruct(
                id=batch_idx,
                vector={
                    "modern_bert_base": dense_vectors[idx],
                    "bm25": sparse_vectors[idx].as_object()
                },
                payload={
                    "title": title,
                    "description": description,
                    "price": price,
                },
            )
        )

    operation_info = client.upsert(
        collection_name=COLLECTION_NAME, wait=True, points=points
    )
    print(operation_info, end="\r")
    progress_bar.update(1)

    if end_idx >= total_row:
        break

    start_idx = end_idx

  0%|          | 0/250 [00:00<?, ?it/s]

operation_id=249 status=<UpdateStatus.COMPLETED: 'completed'>