In [2]:
import pandas as pd
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
import pickle
import os
from pathlib import Path
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')
tqdm.pandas()

# Set up paths
ROOT_DIR = Path("..")
DATA_DIR = ROOT_DIR
OUTPUT_DIR = Path(".")
OUTPUT_DIR.mkdir(exist_ok=True)

print("Libraries imported successfully!")

Libraries imported successfully!


In [3]:
# Load the examples dataset (queries and relevance labels)
examples_df = pd.read_parquet(DATA_DIR / "shopping_queries_dataset_examples.parquet")
print(f"Examples dataset shape: {examples_df.shape}")
print(f"\nColumns: {examples_df.columns.tolist()}")
print(f"\nESCI label distribution:")
print(examples_df['esci_label'].value_counts())
print(f"\nSplit distribution:")
print(examples_df['split'].value_counts())
print(f"\nFirst few rows:")
examples_df.head()

Examples dataset shape: (2621288, 9)

Columns: ['example_id', 'query', 'query_id', 'product_id', 'product_locale', 'esci_label', 'small_version', 'large_version', 'split']

ESCI label distribution:
E    1708158
S     574313
I     263165
C      75652
Name: esci_label, dtype: int64

Split distribution:
train    1983272
test      638016
Name: split, dtype: int64

First few rows:


Unnamed: 0,example_id,query,query_id,product_id,product_locale,esci_label,small_version,large_version,split
0,0,revent 80 cfm,0,B000MOO21W,us,I,0,1,train
1,1,revent 80 cfm,0,B07X3Y6B1V,us,E,0,1,train
2,2,revent 80 cfm,0,B07WDM7MQQ,us,E,0,1,train
3,3,revent 80 cfm,0,B07RH6Z8KW,us,E,0,1,train
4,4,revent 80 cfm,0,B07QJ7WYFQ,us,E,0,1,train


In [4]:
# Load the products dataset
products_df = pd.read_parquet(DATA_DIR / "shopping_queries_dataset_products.parquet")
print(f"Products dataset shape: {products_df.shape}")
print(f"\nColumns: {products_df.columns.tolist()}")
print(f"\nMissing values:")
print(products_df.isnull().sum())
print(f"\nProduct locales:")
print(products_df['product_locale'].value_counts())
print(f"\nFirst few rows:")
products_df.head()

Products dataset shape: (1814924, 7)

Columns: ['product_id', 'product_title', 'product_description', 'product_bullet_point', 'product_brand', 'product_color', 'product_locale']

Missing values:
product_id                   0
product_title                0
product_description     877799
product_bullet_point    304116
product_brand           142756
product_color           691001
product_locale               0
dtype: int64

Product locales:
us    1215854
jp     339059
es     260011
Name: product_locale, dtype: int64

First few rows:


Unnamed: 0,product_id,product_title,product_description,product_bullet_point,product_brand,product_color,product_locale
0,B079VKKJN7,"11 Degrees de los Hombres Playera con Logo, Ne...",Esta playera con el logo de la marca Carrier d...,11 Degrees Negro Playera con logo\nA estrenar ...,11 Degrees,Negro,es
1,B079Y9VRKS,Camiseta Eleven Degrees Core TS White (M),,,11 Degrees,Blanco,es
2,B07DP4LM9H,11 Degrees de los Hombres Core Pull Over Hoodi...,La sudadera con capucha Core Pull Over de 11 G...,11 Degrees Azul Core Pull Over Hoodie\nA estre...,11 Degrees,Azul,es
3,B07G37B9HP,11 Degrees Poli Panel Track Pant XL Black,,,11 Degrees,,es
4,B07LCTGDHY,11 Degrees Gorra Trucker Negro OSFA (Talla úni...,,,11 Degrees,Negro (,es


In [5]:
# Data Exploration and Analysis
print("=" * 80)
print("DATA EXPLORATION")
print("=" * 80)

print(f"\n1. Products Dataset:")
print(f"   Total products: {len(products_df):,}")
print(f"   Unique product IDs: {products_df['product_id'].nunique():,}")
print(f"   Products with descriptions: {products_df['product_description'].notna().sum():,} ({products_df['product_description'].notna().sum()/len(products_df)*100:.1f}%)")
print(f"   Products with bullet points: {products_df['product_bullet_point'].notna().sum():,} ({products_df['product_bullet_point'].notna().sum()/len(products_df)*100:.1f}%)")

print(f"\n2. Examples Dataset:")
print(f"   Total examples: {len(examples_df):,}")
print(f"   Unique queries: {examples_df['query'].nunique():,}")
print(f"   Unique products in examples: {examples_df['product_id'].nunique():,}")

# Check overlap
products_in_examples = set(examples_df['product_id'].unique())
products_in_catalog = set(products_df['product_id'].unique())
overlap = products_in_examples.intersection(products_in_catalog)
print(f"   Products in both datasets: {len(overlap):,} ({len(overlap)/len(products_in_examples)*100:.1f}% of examples)")

print(f"\n3. ESCI Label Distribution:")
print(examples_df['esci_label'].value_counts(normalize=True) * 100)

DATA EXPLORATION

1. Products Dataset:
   Total products: 1,814,924
   Unique product IDs: 1,802,772
   Products with descriptions: 937,125 (51.6%)
   Products with bullet points: 1,510,808 (83.2%)

2. Examples Dataset:
   Total examples: 2,621,288
   Unique queries: 130,193
   Unique products in examples: 1,802,772
   Products in both datasets: 1,802,772 (100.0% of examples)

3. ESCI Label Distribution:
E    65.164835
S    21.909573
I    10.039530
C     2.886062
Name: esci_label, dtype: float64


In [6]:
# Data Cleaning and Preparation
print("=" * 80)
print("DATA CLEANING AND PREPARATION")
print("=" * 80)

# Filter products to only those in the examples dataset (for consistency)
print(f"\n1. Filtering products to those in examples dataset...")
products_df_clean = products_df[products_df['product_id'].isin(products_in_examples)].copy()
print(f"   Products after filtering: {len(products_df_clean):,}")

# Handle missing values and create combined text field
print(f"\n2. Creating combined text field for products...")

def create_product_text(row):
    """Combine product fields into a single text for embedding"""
    parts = []
    
    # Title (always present)
    if pd.notna(row['product_title']):
        parts.append(str(row['product_title']).strip())
    
    # Description
    if pd.notna(row['product_description']):
        desc = str(row['product_description']).strip()
        if desc and desc.lower() != 'none':
            parts.append(desc)
    
    # Bullet points
    if pd.notna(row['product_bullet_point']):
        bullets = str(row['product_bullet_point']).strip()
        if bullets and bullets.lower() != 'none':
            parts.append(bullets)
    
    # Brand
    if pd.notna(row['product_brand']):
        brand = str(row['product_brand']).strip()
        if brand and brand.lower() != 'none':
            parts.append(f"Brand: {brand}")
    
    # Color
    if pd.notna(row['product_color']):
        color = str(row['product_color']).strip()
        if color and color.lower() != 'none':
            parts.append(f"Color: {color}")
    
    # Join all parts
    combined = " ".join(parts)
    
    # Clean up multiple spaces
    combined = " ".join(combined.split())
    
    return combined if combined else "No description available"

products_df_clean['product_text'] = products_df_clean.progress_apply(create_product_text, axis=1)

# Check text length statistics
text_lengths = products_df_clean['product_text'].str.len()
print(f"   Text length statistics:")
print(f"     Mean: {text_lengths.mean():.0f} characters")
print(f"     Median: {text_lengths.median():.0f} characters")
print(f"     Min: {text_lengths.min()} characters")
print(f"     Max: {text_lengths.max()} characters")
print(f"     Products with empty text: {(products_df_clean['product_text'].str.len() == 0).sum()}")

# Remove products with empty text
products_df_clean = products_df_clean[products_df_clean['product_text'].str.len() > 0].copy()
print(f"\n   Final product count: {len(products_df_clean):,}")

# Display sample
print(f"\n3. Sample product text:")
print("-" * 80)
sample_idx = 0
sample_product = products_df_clean.iloc[sample_idx]
print(f"Product ID: {sample_product['product_id']}")
print(f"Title: {sample_product['product_title']}")
print(f"\nCombined Text (first 200 chars):")
print(sample_product['product_text'][:200] + "...")

DATA CLEANING AND PREPARATION

1. Filtering products to those in examples dataset...
   Products after filtering: 1,814,924

2. Creating combined text field for products...


100%|██████████████████████████████| 1814924/1814924 [01:45<00:00, 17234.82it/s]


   Text length statistics:
     Mean: 1037 characters
     Median: 710 characters
     Min: 1 characters
     Max: 8803 characters
     Products with empty text: 0

   Final product count: 1,814,924

3. Sample product text:
--------------------------------------------------------------------------------
Product ID: B079VKKJN7
Title: 11 Degrees de los Hombres Playera con Logo, Negro, L

Combined Text (first 200 chars):
11 Degrees de los Hombres Playera con Logo, Negro, L Esta playera con el logo de la marca Carrier de 11 Degrees viene en negro, con el logo de la marca en el pecho y un pequeño texto en la parte poste...


In [7]:
# Prepare ground truth for evaluation
print("=" * 80)
print("PREPARING GROUND TRUTH FOR EVALUATION")
print("=" * 80)

# Filter examples to only include products that exist in our cleaned catalog
examples_df_clean = examples_df[examples_df['product_id'].isin(products_df_clean['product_id'])].copy()
print(f"Examples after filtering: {len(examples_df_clean):,}")

# Create ground truth: for each query, collect relevant product IDs
# We'll consider 'E' (Exact), 'S' (Substitute), and 'C' (Complement) as relevant
# 'I' (Irrelevant) is not relevant
print(f"\nCreating ground truth labels...")
print(f"ESCI label meanings:")
print(f"  E = Exact (relevant)")
print(f"  S = Substitute (relevant)")
print(f"  C = Complement (relevant)")
print(f"  I = Irrelevant (not relevant)")

# Create ground truth dictionary: query -> set of relevant product IDs
ground_truth = {}
for query_id, group in tqdm(examples_df_clean.groupby('query_id'), total=len(examples_df_clean['query_id'].unique())):
    query_text = group['query'].iloc[0]
    # Get relevant products (E, S, C)
    relevant_products = set(
        group[group['esci_label'].isin(['E', 'S', 'C'])]['product_id'].tolist()
    )
    if len(relevant_products) > 0:
        ground_truth[query_text] = relevant_products

print(f"\nGround truth created for {len(ground_truth):,} unique queries")
print(f"Average relevant products per query: {np.mean([len(v) for v in ground_truth.values()]):.1f}")

# Save ground truth for later use
ground_truth_file = OUTPUT_DIR / "ground_truth.pkl"
with open(ground_truth_file, 'wb') as f:
    pickle.dump(ground_truth, f)
print(f"Ground truth saved to {ground_truth_file}")

PREPARING GROUND TRUTH FOR EVALUATION
Examples after filtering: 2,621,288

Creating ground truth labels...
ESCI label meanings:
  E = Exact (relevant)
  S = Substitute (relevant)
  C = Complement (relevant)
  I = Irrelevant (not relevant)


100%|█████████████████████████████████| 130652/130652 [00:49<00:00, 2632.76it/s]



Ground truth created for 130,193 unique queries
Average relevant products per query: 18.0
Ground truth saved to ground_truth.pkl


In [9]:
# Initialize Embedding Model
print("=" * 80)
print("INITIALIZING EMBEDDING MODEL")
print("=" * 80)

# Use a lightweight but effective model
model_name = "all-MiniLM-L6-v2"  # 384 dimensions, fast and efficient
print(f"Loading sentence transformer model: {model_name}")

embedding_model = SentenceTransformer(model_name)
embedding_dim = embedding_model.get_sentence_embedding_dimension()

print(f"Model loaded successfully!")
print(f"Embedding dimension: {embedding_dim}")
print(f"Model max sequence length: {embedding_model.max_seq_length}")

INITIALIZING EMBEDDING MODEL
Loading sentence transformer model: all-MiniLM-L6-v2
Model loaded successfully!
Embedding dimension: 384
Model max sequence length: 256


In [None]:
# Generate Embeddings and Build FAISS Indices Incrementally
# This approach processes directly from dataframe to avoid memory issues
print("=" * 80)
print("GENERATING EMBEDDINGS AND BUILDING FAISS INDICES")
print("=" * 80)

import gc  # For memory cleanup

# Process directly from dataframe - no large lists in memory
n_products = len(products_df_clean)
chunk_size = 2000  # Smaller chunks for better memory management
batch_size = 32  # Smaller batch size for embedding model

print(f"Total products: {n_products:,}")
print(f"Processing in chunks of {chunk_size:,} products (directly from dataframe)")
print(f"Embedding batch size: {batch_size}")

# Initialize FAISS indices
print("\nInitializing FAISS indices...")
flat_index = faiss.IndexFlatL2(embedding_dim)

# For IVF, we need to train first
nlist = min(100, n_products // 10)
quantizer = faiss.IndexFlatL2(embedding_dim)
ivf_index = faiss.IndexIVFFlat(quantizer, embedding_dim, nlist)
ivf_trained = False

# HNSW can be built incrementally
m = 32
ef_construction = 200
ef_search = 50
hnsw_index = faiss.IndexHNSWFlat(embedding_dim, m)
hnsw_index.hnsw.efConstruction = ef_construction
hnsw_index.hnsw.efSearch = ef_search

# Process dataframe in chunks directly
print(f"\nProcessing embeddings and building indices...")
n_chunks = (n_products + chunk_size - 1) // chunk_size

# Store product IDs as we process (needed for later)
product_ids = []

for chunk_idx in tqdm(range(n_chunks), desc="Processing chunks"):
    start_idx = chunk_idx * chunk_size
    end_idx = min((chunk_idx + 1) * chunk_size, n_products)
    
    # Get chunk directly from dataframe (only loads this chunk into memory)
    chunk_df = products_df_clean.iloc[start_idx:end_idx].copy()
    chunk_texts = chunk_df['product_text'].tolist()  # Small list, only for this chunk
    chunk_ids = chunk_df['product_id'].tolist()
    product_ids.extend(chunk_ids)  # Store IDs for later use
    
    # Generate embeddings for this chunk
    chunk_embeddings = embedding_model.encode(
        chunk_texts,
        batch_size=batch_size,
        show_progress_bar=False,
        convert_to_numpy=True,
        normalize_embeddings=True,
        device='cpu'
    ).astype('float32')
    
    # Add to Flat index
    flat_index.add(chunk_embeddings)
    
    # For IVF: train on first chunk, then add incrementally
    if not ivf_trained:
        print(f"   Training IVF index on first chunk...")
        ivf_index.train(chunk_embeddings)
        ivf_trained = True
    ivf_index.add(chunk_embeddings)
    
    # Add to HNSW index
    hnsw_index.add(chunk_embeddings)
    
    # Clean up memory
    del chunk_df, chunk_texts, chunk_embeddings
    gc.collect()
    
    # Print progress every 20 chunks
    if (chunk_idx + 1) % 20 == 0:
        print(f"   Processed {end_idx:,} / {n_products:,} products ({end_idx/n_products*100:.1f}%)")

# Set IVF nprobe
ivf_index.nprobe = 10

print(f"\nAll embeddings generated and indices built!")
print(f"  Flat index: {flat_index.ntotal:,} vectors")
print(f"  IVF index: {ivf_index.ntotal:,} vectors (nlist={nlist}, nprobe=10)")
print(f"  HNSW index: {hnsw_index.ntotal:,} vectors (m={m})")
print(f"  Product IDs stored: {len(product_ids):,}")

GENERATING EMBEDDINGS AND BUILDING FAISS INDICES
Total products: 1,814,924
Processing in chunks of 2,000 products (directly from dataframe)
Embedding batch size: 32

Initializing FAISS indices...

Processing embeddings and building indices...



Processing chunks:   0%|                                | 0/908 [00:00<?, ?it/s]

In [None]:
# Index Statistics
print("=" * 80)
print("INDEX STATISTICS")
print("=" * 80)

print(f"\n1. FlatL2 Index (Exact Search):")
print(f"   Vectors: {flat_index.ntotal:,}")
print(f"   Index size: {flat_index.ntotal * embedding_dim * 4 / 1024 / 1024:.2f} MB")

print(f"\n2. IVF Index (Approximate Search):")
print(f"   Vectors: {ivf_index.ntotal:,}")
print(f"   Parameters: nlist={nlist}, nprobe=10")

print(f"\n3. HNSW Index (Approximate Search):")
print(f"   Vectors: {hnsw_index.ntotal:,}")
print(f"   Parameters: m={m}, ef_construction={ef_construction}, ef_search={ef_search}")

print("\n✓ All indices ready for use!")

In [None]:
# Test the Indices
print("=" * 80)
print("TESTING THE INDICES")
print("=" * 80)

# Get a sample query from the examples
test_query = list(ground_truth.keys())[0]
print(f"\nTest query: '{test_query}'")

# Generate query embedding
query_embedding = embedding_model.encode(
    [test_query],
    convert_to_numpy=True,
    normalize_embeddings=True
)[0].astype('float32')

print(f"Query embedding shape: {query_embedding.shape}")

# Search in each index
k = 10
print(f"\nSearching for top {k} results in each index...")

# Flat index
import time
start = time.time()
flat_distances, flat_indices = flat_index.search(query_embedding.reshape(1, -1), k)
flat_time = time.time() - start
print(f"\n1. Flat Index:")
print(f"   Search time: {flat_time*1000:.2f}ms")
print(f"   Top {k} results:")
for i, (dist, idx) in enumerate(zip(flat_distances[0], flat_indices[0]), 1):
    product_id = product_ids[idx]
    product_title = products_df_clean[products_df_clean['product_id'] == product_id]['product_title'].iloc[0]
    print(f"     {i}. [{product_id}] {product_title[:60]}... (distance: {dist:.4f})")

# IVF index
start = time.time()
ivf_distances, ivf_indices = ivf_index.search(query_embedding.reshape(1, -1), k)
ivf_time = time.time() - start
print(f"\n2. IVF Index:")
print(f"   Search time: {ivf_time*1000:.2f}ms")
print(f"   Top {k} results:")
for i, (dist, idx) in enumerate(zip(ivf_distances[0], ivf_indices[0]), 1):
    product_id = product_ids[idx]
    product_title = products_df_clean[products_df_clean['product_id'] == product_id]['product_title'].iloc[0]
    print(f"     {i}. [{product_id}] {product_title[:60]}... (distance: {dist:.4f})")

# HNSW index
start = time.time()
hnsw_distances, hnsw_indices = hnsw_index.search(query_embedding.reshape(1, -1), k)
hnsw_time = time.time() - start
print(f"\n3. HNSW Index:")
print(f"   Search time: {hnsw_time*1000:.2f}ms")
print(f"   Top {k} results:")
for i, (dist, idx) in enumerate(zip(hnsw_distances[0], hnsw_indices[0]), 1):
    product_id = product_ids[idx]
    product_title = products_df_clean[products_df_clean['product_id'] == product_id]['product_title'].iloc[0]
    print(f"     {i}. [{product_id}] {product_title[:60]}... (distance: {dist:.4f})")

print(f"\nSpeed comparison:")
print(f"  Flat:  {flat_time*1000:.2f}ms")
print(f"  IVF:   {ivf_time*1000:.2f}ms ({flat_time/ivf_time:.2f}x faster)")
print(f"  HNSW:  {hnsw_time*1000:.2f}ms ({flat_time/hnsw_time:.2f}x faster)")

In [None]:
# Save Everything to Disk
print("=" * 80)
print("SAVING TO DISK")
print("=" * 80)

# Create output directory
output_dir = OUTPUT_DIR / "data"
output_dir.mkdir(exist_ok=True)

# 1. Save FAISS indices
print("\n1. Saving FAISS indices...")

# Flat index
flat_index_file = output_dir / "faiss_index_flat.bin"
faiss.write_index(flat_index, str(flat_index_file))
print(f"   Flat index saved to: {flat_index_file}")

# IVF index
ivf_index_file = output_dir / "faiss_index_ivf.bin"
faiss.write_index(ivf_index, str(ivf_index_file))
print(f"   IVF index saved to: {ivf_index_file}")

# HNSW index
hnsw_index_file = output_dir / "faiss_index_hnsw.bin"
faiss.write_index(hnsw_index, str(hnsw_index_file))
print(f"   HNSW index saved to: {hnsw_index_file}")

# 2. Save product IDs mapping
print("\n2. Saving product IDs mapping...")
product_ids_file = output_dir / "product_ids.pkl"
with open(product_ids_file, 'wb') as f:
    pickle.dump(product_ids, f)
print(f"   Product IDs saved to: {product_ids_file}")

# 3. Save cleaned products dataframe
print("\n3. Saving cleaned products dataframe...")
products_file = output_dir / "products_clean.parquet"
products_df_clean.to_parquet(products_file, index=False)
print(f"   Products dataframe saved to: {products_file}")

# 4. Save metadata
print("\n4. Saving metadata...")
metadata = {
    'embedding_dim': embedding_dim,
    'model_name': model_name,
    'n_products': len(product_ids),
    'index_types': ['flat', 'ivf', 'hnsw'],
    'ivf_params': {'nlist': nlist, 'nprobe': nprobe},
    'hnsw_params': {'m': m, 'ef_construction': ef_construction, 'ef_search': ef_search}
}

metadata_file = output_dir / "metadata.pkl"
with open(metadata_file, 'wb') as f:
    pickle.dump(metadata, f)
print(f"   Metadata saved to: {metadata_file}")

# 5. Save ground truth (already saved earlier, but confirm)
print("\n5. Ground truth already saved to: ground_truth.pkl")

print("\n" + "=" * 80)
print("ALL DATA SAVED SUCCESSFULLY!")
print("=" * 80)
print(f"\nOutput directory: {output_dir}")
print(f"\nFiles created:")
print(f"  - faiss_index_flat.bin")
print(f"  - faiss_index_ivf.bin")
print(f"  - faiss_index_hnsw.bin")
print(f"  - product_ids.pkl")
print(f"  - products_clean.parquet")
print(f"  - metadata.pkl")
print(f"  - ../ground_truth.pkl")

In [None]:
# Summary Statistics
print("=" * 80)
print("SUMMARY")
print("=" * 80)

print(f"\nDataset Statistics:")
print(f"  Original products: {len(products_df):,}")
print(f"  Cleaned products: {len(products_df_clean):,}")
print(f"  Products with embeddings: {flat_index.ntotal:,}")
print(f"  Embedding dimension: {embedding_dim}")
print(f"  Estimated embedding size: {flat_index.ntotal * embedding_dim * 4 / 1024 / 1024:.2f} MB")

print(f"\nIndex Statistics:")
print(f"  Flat index vectors: {flat_index.ntotal:,}")
print(f"  IVF index vectors: {ivf_index.ntotal:,} (nlist={nlist}, nprobe=10)")
print(f"  HNSW index vectors: {hnsw_index.ntotal:,} (m={m})")

print(f"\nGround Truth:")
print(f"  Unique queries: {len(ground_truth):,}")
print(f"  Average relevant products per query: {np.mean([len(v) for v in ground_truth.values()]):.1f}")

print(f"\nOutput Files:")
print(f"  All files saved to: {output_dir}")
print(f"  Ready for retrieval and evaluation!")

print("\n" + "=" * 80)
print("DATA PREPARATION COMPLETE!")
print("=" * 80)