In [19]:
!apt-get install git-lfs  # or: !git lfs install
!git lfs install
!pip install faiss-cpu faiss-gpu-cu12 transformers sentence-transformers

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
git-lfs is already the newest version (3.0.2-1ubuntu0.3).
0 upgraded, 0 newly installed, 0 to remove and 35 not upgraded.
Git LFS initialized.


In [20]:
import os
import warnings
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import faiss
import torch
from sentence_transformers import SentenceTransformer, CrossEncoder
warnings.filterwarnings('ignore')

In [21]:
if not os.path.exists("esci-data"):
    !git clone https://github.com/amazon-science/esci-data.git

In [22]:
MODELS = {
    "bge-small-en": "BAAI/bge-small-en",
    "mpnet-base-v2": "sentence-transformers/all-mpnet-base-v2"
}
RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L-12-v2"
USE_RERANKER = True
PRODUCT_LOCALE = "us"
ESCI_LABEL = "E"
N_QUERIES = 50
N_ROWS = 500
TOP_K = 10
RANDOM_STATE = 42

In [23]:
def read_input_files(base_path="esci-data/shopping_queries_dataset"):
    print("\nStep 1 Reading input files")
    files = {
        "examples": os.path.join(base_path, "shopping_queries_dataset_examples.parquet"),
        "products": os.path.join(base_path, "shopping_queries_dataset_products.parquet"),
    }
    data = {}
    for name, path in files.items():
        print(f"Loading {name} from {path}")
        df = pd.read_parquet(path) if path.endswith(".parquet") else pd.read_csv(path)
        print(f"Shape: {df.shape}")
        data[name] = df
    return data

In [24]:
def filter_examples(df):
    print("\nStep 2 Filtering examples")
    filtered = df[(df['product_locale'] == PRODUCT_LOCALE) & (df['esci_label'] == ESCI_LABEL)].copy()
    print(f"Filtered rows: {filtered.shape[0]}")
    return filtered

In [25]:
def sample_queries_and_rows(df):
    print("\nStep 3 Sampling queries and rows")
    np.random.seed(RANDOM_STATE)
    unique_queries = df['query'].unique()
    print(f"Total unique queries: {len(unique_queries)}")
    selected_queries = np.random.choice(unique_queries, N_QUERIES, replace=False)
    print(f"Selected queries: {len(selected_queries)}")
    df_filtered = df[df['query'].isin(selected_queries)]
    df_sampled = df_filtered.sample(min(N_ROWS, len(df_filtered)), random_state=RANDOM_STATE).reset_index(drop=True)
    print(f"Sampled rows: {df_sampled.shape[0]}")
    return df_sampled, selected_queries

In [26]:
def create_product_text(df_products):
    print("\nStep 4 Creating combined product text")
    text_cols = [c for c in df_products.columns if c.startswith("product_")]
    print(f"Text columns: {text_cols}")
    df = df_products.copy()
    df['combined_text'] = df[text_cols].astype(str).apply(
        lambda row: ' | '.join([v for v in row if v and v != 'nan']), axis=1
    )
    print(f"Created combined_text for {len(df)} products")
    return df

In [27]:
def generate_embeddings(df_products, model):
    print(f"Generating embeddings for {len(df_products)} products")
    embeddings = model.encode(df_products['combined_text'].tolist(),
                               batch_size=64, show_progress_bar=True, convert_to_numpy=True)
    print(f"Embeddings shape: {embeddings.shape}")
    return embeddings.astype(np.float32)

In [28]:
def create_ground_truth(df_sampled, df_products):
    print("\nStep 5 Creating ground truth mapping")
    pid_to_idx = {pid: i for i, pid in enumerate(df_products['product_id'])}
    ground_truth = {}
    for query in df_sampled['query'].unique():
        indices = [pid_to_idx[pid] for pid in df_sampled[df_sampled['query'] == query]['product_id'] if pid in pid_to_idx]
        if indices:
            ground_truth[query] = set(indices)
    print(f"  Ground truth created for {len(ground_truth)} queries")
    return ground_truth

In [29]:
def build_search_index(embeddings):
    print("\nStep 6 Building FAISS index")
    faiss.normalize_L2(embeddings)
    dim = embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(embeddings)
    print(f"Index built with {embeddings.shape[0]} vectors of dim {dim}")
    return index

In [30]:
def perform_search(index, query_embeddings, top_k=TOP_K):
    print(f"Performing FAISS search for {len(query_embeddings)} queries...")
    faiss.normalize_L2(query_embeddings)
    scores, indices = index.search(query_embeddings, top_k)
    print(f"Search complete. Returned shape: {indices.shape}")
    return scores, indices

In [31]:
def rerank_results(queries, search_indices, df_products_text):
    print("\n Step 7 Applying cross-encoder re-ranking...")
    reranker = CrossEncoder(RERANKER_MODEL)
    reranked_indices = []
    for i, (q, retrieved_idxs) in enumerate(zip(queries, search_indices), 1):
        candidates = [(q, df_products_text.iloc[idx]['combined_text']) for idx in retrieved_idxs]
        scores = reranker.predict(candidates)
        sorted_pairs = sorted(zip(retrieved_idxs, scores), key=lambda x: x[1], reverse=True)
        reranked_indices.append([idx for idx, _ in sorted_pairs])
        if i % 10 == 0 or i == len(queries):
            print(f"    Reranked {i}/{len(queries)} queries")
    print("Re-ranking complete.")
    return np.array(reranked_indices)

In [32]:
def evaluate(search_indices, queries, ground_truth, k_values=[1, 5, 10]):
    print("\nStep 8 Evaluating results...")
    results = {f"hits@{k}": [] for k in k_values}
    mrr_scores = []
    for i, q in enumerate(queries):
        if q not in ground_truth:
            continue
        relevant = ground_truth[q]
        retrieved = search_indices[i]
        for k in k_values:
            results[f"hits@{k}"].append(1 if relevant & set(retrieved[:k]) else 0)
        for rank, idx in enumerate(retrieved, 1):
            if idx in relevant:
                mrr_scores.append(1.0 / rank)
                break
        else:
            mrr_scores.append(0.0)
    print(f"Evaluation complete. MRR: {np.mean(mrr_scores):.4f}")
    return {k: np.mean(v) for k, v in results.items()} | {"mrr": np.mean(mrr_scores)}

In [33]:
def run_pipeline_for_model(model_name, model_path, df_sampled, selected_queries, df_products_text, ground_truth):
    print(f"\n Running pipeline for model: {model_name} ")
    model = SentenceTransformer(model_path)

    # Product embeddings
    product_embeddings = generate_embeddings(df_products_text, model)

    # Query embeddings
    print(f"Generating query embeddings for {len(selected_queries)} queries...")
    query_embeddings = model.encode(selected_queries, batch_size=32, convert_to_numpy=True).astype(np.float32)

    # Search
    index = build_search_index(product_embeddings)
    _, indices = perform_search(index, query_embeddings)

    # Optional re-ranking
    if USE_RERANKER:
        indices = rerank_results(selected_queries, indices, df_products_text)

    # Evaluate
    results = evaluate(indices, selected_queries, ground_truth)
    print(f"Completed model: {model_name} \n")
    return results

In [34]:
def run_comparison():
    datasets = read_input_files()
    df_examples = datasets["examples"]
    df_products = datasets["products"]

    df_filtered = filter_examples(df_examples)
    df_sampled, selected_queries = sample_queries_and_rows(df_filtered)

    unique_product_ids = df_sampled['product_id'].unique()
    df_products_filtered = df_products[df_products['product_id'].isin(unique_product_ids)]
    df_products_text = create_product_text(df_products_filtered)

    ground_truth = create_ground_truth(df_sampled, df_products_text)

    all_results = {}
    for model_name, model_path in MODELS.items():
        res = run_pipeline_for_model(model_name, model_path, df_sampled, selected_queries, df_products_text, ground_truth)
        all_results[model_name] = res

    df_results = pd.DataFrame(all_results).T
    print("\nFinal Model Comparison ")
    print(df_results.to_string(float_format="%.4f"))
    return df_results

In [35]:
comparison_results = run_comparison()


Step 1 Reading input files
Loading examples from esci-data/shopping_queries_dataset/shopping_queries_dataset_examples.parquet
Shape: (2621288, 9)
Loading products from esci-data/shopping_queries_dataset/shopping_queries_dataset_products.parquet
Shape: (1814924, 7)

Step 2 Filtering examples
Filtered rows: 1247558

Step 3 Sampling queries and rows
  Total unique queries: 97344
  Selected queries: 50
  Sampled rows: 500

Step 4 Creating combined product text
Text columns: ['product_id', 'product_title', 'product_description', 'product_bullet_point', 'product_brand', 'product_color', 'product_locale']
Created combined_text for 503 products

Step 5 Creating ground truth mapping
  Ground truth created for 49 queries

 Running pipeline for model: bge-small-en 
Generating embeddings for 503 products


Batches:   0%|          | 0/8 [00:00<?, ?it/s]

Embeddings shape: (503, 384)
   Generating query embeddings for 50 queries...

Step 6 Building FAISS index
Index built with 503 vectors of dim 384
Performing FAISS search for 50 queries...
Search complete. Returned shape: (50, 10)

 Step 7 Applying cross-encoder re-ranking...


config.json:   0%|          | 0.00/791 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/133M [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/132 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

    Reranked 10/50 queries
    Reranked 20/50 queries
    Reranked 30/50 queries
    Reranked 40/50 queries
    Reranked 50/50 queries
  Re-ranking complete.

Step 8 Evaluating results...
  Evaluation complete. MRR: 1.0000
 Completed model: bge-small-en 


 Running pipeline for model: mpnet-base-v2 
Generating embeddings for 503 products


Batches:   0%|          | 0/8 [00:00<?, ?it/s]

Embeddings shape: (503, 768)
   Generating query embeddings for 50 queries...

Step 6 Building FAISS index
Index built with 503 vectors of dim 768
Performing FAISS search for 50 queries...
Search complete. Returned shape: (50, 10)

 Step 7 Applying cross-encoder re-ranking...
    Reranked 10/50 queries
    Reranked 20/50 queries
    Reranked 30/50 queries
    Reranked 40/50 queries
    Reranked 50/50 queries
  Re-ranking complete.

Step 8 Evaluating results...
  Evaluation complete. MRR: 1.0000
 Completed model: mpnet-base-v2 


 Final Model Comparison 
               hits@1  hits@5  hits@10    mrr
bge-small-en   1.0000  1.0000   1.0000 1.0000
mpnet-base-v2  1.0000  1.0000   1.0000 1.0000
