In [None]:
# Cell 1
import os
import math
import random
import sys
from multiprocessing import Pool, cpu_count
from datasketch import MinHash, MinHashLSH
from pymongo import MongoClient, UpdateOne
from tqdm import tqdm
import numpy as np

# ---------------- CONFIG ----------------
# MongoDB connection string and collection names; each one can be overridden
# through the environment variable of the same name.
MONGO_URI = os.getenv("MONGO_URI", "mongodb://localhost:27017/cs5600")
PRODUCTS_COLL = os.getenv("PRODUCTS_COLL", "products")
SIG_COLL = os.getenv("SIG_COLL", "productsignatures")

# Hyper-parameter grid explored by tune_best_model.
K_VALUES = [2, 3, 5, 7, 10]  # character-shingle lengths
HASH_VALUES = [10, 20, 50, 100, 150]  # signature lengths (num_perm) to evaluate
NUM_BANDS_VALUES = [4, 5, 10, 15, 20, 25]  # LSH band counts to evaluate
MAX_HASHES = max(HASH_VALUES)  # default num_perm of compute_minhash; shorter signatures are sliced from it
NUM_WORKERS = max(1, min(cpu_count() - 1, 12))  # one core left free, capped at 12, never below 1
TUNING_SAMPLE = None  # default value of tune_best_model's sample_size argument
TOP_K = 10  # default number of neighbours kept per product
BULK_BATCH = 500  # default number of UpdateOne operations per bulk_write call
# ----------------------------------------


In [None]:
# Cell 2
# Open the MongoDB connection and bind the two collections used by the rest
# of the notebook (source products and the signatures/similars output).
client = MongoClient(MONGO_URI)
# get_default_database() resolves the database from the MONGO_URI itself.
db = client.get_default_database()
products_col = db[PRODUCTS_COLL]
signatures_col = db[SIG_COLL]


In [None]:
# Cell 3
def get_char_shingles(text, k):
    """Return every overlapping k-character window of the normalised text.

    The input is converted with ``str()``, lower-cased, and all whitespace
    runs are collapsed to a single space before windows are taken.

    Args:
        text: Value to shingle; falsy values yield an empty list.
        k: Window length in characters.

    Returns:
        list[str]: Windows in left-to-right order (duplicates kept), or an
        empty list when the normalised text is shorter than ``k``.
    """
    if not text:
        return []
    normalized = " ".join(str(text).lower().split())
    last_start = len(normalized) - k
    if last_start < 0:
        # Text is too short to hold even one full window.
        return []
    return [normalized[start:start + k] for start in range(last_start + 1)]

def compute_minhash(text, k, num_perm=MAX_HASHES):
    """Build a MinHash signature from the character k-shingles of ``text``.

    Args:
        text: Source text; shingled with ``get_char_shingles``.
        k: Shingle length in characters.
        num_perm: Number of permutations (signature length).

    Returns:
        MinHash: The signature. It is left in its freshly-initialised state
        when the text produces no shingles.
    """
    m = MinHash(num_perm=num_perm)
    # MinHash keeps a per-permutation minimum, so feeding the same shingle
    # twice cannot change the signature and the update order is irrelevant.
    # Deduplicating first skips the redundant hashing work.
    for shingle in set(get_char_shingles(text, k)):
        m.update(shingle.encode("utf8"))
    return m


In [None]:
# Cell 4
def _worker_precompute(product):
    """Compute all MinHash signatures for one product document.

    For each shingle length in ``K_VALUES`` three signatures are produced:
    title only (``pst``), description only (``psd``) and title + description
    (``pstd``).

    Args:
        product: Mapping read with ``.get`` for ``asin``, ``title`` and
            ``description``. A list-valued description is joined with
            single spaces.

    Returns:
        dict: ``{"asin": ..., "pst_k<k>": MinHash, "psd_k<k>": MinHash,
        "pstd_k<k>": MinHash}`` with one triple per ``k``.
    """
    record = {"asin": product.get("asin")}

    title_text = product.get("title") or ""
    raw_description = product.get("description") or ""
    if isinstance(raw_description, list):
        description_text = " ".join(raw_description)
    else:
        description_text = raw_description
    combined_text = (title_text + " " + description_text).strip()

    for k in K_VALUES:
        record.update({
            f"pst_k{k}": compute_minhash(title_text, k),
            f"psd_k{k}": compute_minhash(description_text, k),
            f"pstd_k{k}": compute_minhash(combined_text, k),
        })
    return record

def precompute_all(products):
    """Compute MinHash signatures for every product in a worker pool.

    Args:
        products: Sized collection of product documents, each handed to
            ``_worker_precompute``.

    Returns:
        dict: ``all_sigs[asin][k] -> {"pst": MinHash, "psd": MinHash,
        "pstd": MinHash}`` for every ``k`` in ``K_VALUES``.
    """
    print(f"⚡ Precomputing MinHash objects ({NUM_WORKERS} workers)...")
    with Pool(processes=NUM_WORKERS) as pool:
        progress = tqdm(
            pool.imap_unordered(_worker_precompute, products),
            total=len(products),
            file=sys.stdout,
        )
        # Drain the iterator while the pool is still alive.
        worker_outputs = list(progress)

    all_sigs = {}
    for record in worker_outputs:
        # Regroup the flat "<family>_k<k>" keys into a per-k mapping.
        per_k = {}
        for k in K_VALUES:
            per_k[k] = {
                "pst": record[f"pst_k{k}"],
                "psd": record[f"psd_k{k}"],
                "pstd": record[f"pstd_k{k}"],
            }
        all_sigs[record["asin"]] = per_k
    print("✅ Precompute finished.")
    return all_sigs


In [None]:
# # Cell 5
# def build_lsh_cached(all_sigs, k, num_hashes, num_bands):
#     """
#     Build separate LSHs for pst, psd, pstd using precomputed MinHash objects.
#     Returns (lsh_pst, lsh_psd, lsh_pstd), asin_to_minhash dict
#     """
#     r = max(1, num_hashes // num_bands)
#     lsh_pst = MinHashLSH(num_perm=num_hashes, params=(num_bands, r))
#     lsh_psd = MinHashLSH(num_perm=num_hashes, params=(num_bands, r))
#     lsh_pstd = MinHashLSH(num_perm=num_hashes, params=(num_bands, r))
#     asin_to_minhash = {"pst": {}, "psd": {}, "pstd": {}}

#     for asin, sigs in all_sigs.items():
#         # pst
#         m_pst = MinHash(num_perm=num_hashes)
#         m_pst.hashvalues = sigs[k]["pst"].hashvalues[:num_hashes]
#         lsh_pst.insert(asin, m_pst)
#         asin_to_minhash["pst"][asin] = m_pst
#         # psd
#         m_psd = MinHash(num_perm=num_hashes)
#         m_psd.hashvalues = sigs[k]["psd"].hashvalues[:num_hashes]
#         lsh_psd.insert(asin, m_psd)
#         asin_to_minhash["psd"][asin] = m_psd
#         # pstd
#         m_pstd = MinHash(num_perm=num_hashes)
#         m_pstd.hashvalues = sigs[k]["pstd"].hashvalues[:num_hashes]
#         lsh_pstd.insert(asin, m_pstd)
#         asin_to_minhash["pstd"][asin] = m_pstd

#     return (lsh_pst, lsh_psd, lsh_pstd), asin_to_minhash

def build_lsh_all(all_sigs, k, num_hashes, num_bands):
    """Build one LSH index per signature family from precomputed signatures.

    Every stored signature is truncated to its first ``num_hashes`` hash
    values before insertion.
    NOTE(review): this assumes the truncated prefix is itself a valid
    ``num_hashes``-permutation signature — confirm against datasketch.

    Args:
        all_sigs: ``all_sigs[asin][k][family] -> MinHash``.
        k: Shingle length selecting which precomputed signatures to use.
        num_hashes: Signature length for the indexes.
        num_bands: Number of LSH bands; rows per band is
            ``max(1, num_hashes // num_bands)``.

    Returns:
        tuple: ``((lsh_pst, lsh_psd, lsh_pstd), asin_to_minhash)`` where
        ``asin_to_minhash[family][asin]`` is the truncated MinHash that was
        inserted into that family's index.
    """
    rows_per_band = max(1, num_hashes // num_bands)
    families = ("pst", "psd", "pstd")

    indexes = {}
    asin_to_minhash = {}
    for family in families:
        indexes[family] = MinHashLSH(num_perm=num_hashes, params=(num_bands, rows_per_band))
        asin_to_minhash[family] = {}

    for asin, sigs in all_sigs.items():
        for family in families:
            # Slice MinHash to match num_hashes
            truncated = MinHash(num_perm=num_hashes)
            truncated.hashvalues = sigs[k][family].hashvalues[:num_hashes]
            indexes[family].insert(asin, truncated)
            asin_to_minhash[family][asin] = truncated

    return (indexes["pst"], indexes["psd"], indexes["pstd"]), asin_to_minhash



In [None]:
# Cell 6
def get_topk_similars(asin, m, lsh, asin_to_minhash, top_k=TOP_K):
    """Return the ``top_k`` most similar indexed items for one query signature.

    Args:
        asin: Key of the query item; it is excluded from its own results.
        m: Query MinHash.
        lsh: MinHashLSH index to pull candidates from.
        asin_to_minhash: Mapping from indexed key to the MinHash stored for
            it, used to score each candidate.
        top_k: Maximum number of neighbours to return.

    Returns:
        list[dict]: ``{"asin": key, "score": float}`` entries ordered by
        descending estimated Jaccard similarity; ties are ordered by key so
        the result does not depend on candidate iteration order.
    """
    # A MinHash that never received a shingle (missing or too-short text)
    # is bit-identical to every other empty one, so it would "match" all of
    # them with similarity 1.0. Those are not real neighbours.
    if m.is_empty():
        return []

    scored = [
        (m.jaccard(asin_to_minhash[cand]), cand)
        for cand in lsh.query(m)
        if cand != asin
    ]
    # Highest score first; break ties on the key for reproducible output.
    scored.sort(key=lambda pair: (-pair[0], str(pair[1])))
    return [{"asin": cand, "score": float(score)} for score, cand in scored[:top_k]]


In [None]:
# # Cell 7
# def final_compute_and_write(products, all_sigs, best, signatures_col, top_k=TOP_K, bulk_batch=BULK_BATCH):
#     print("\n🚀 Computing top-10 similars with LSH...")
#     k = best["k"]
#     num_hashes = best["num_hashes"]
#     num_bands = best["num_bands"]

#     (lsh_pst, lsh_psd, lsh_pstd), asin_to_minhash = build_lsh_cached(all_sigs, k, num_hashes, num_bands)
#     updates = []

#     for p in tqdm(products, total=len(products), file=sys.stdout):
#         asin = p["asin"]
#         sigs = all_sigs[asin][k]
#         similars = {
#             "pst": get_topk_similars(asin, sigs["pst"], lsh_pst, asin_to_minhash["pst"], top_k),
#             "psd": get_topk_similars(asin, sigs["psd"], lsh_psd, asin_to_minhash["psd"], top_k),
#             "pstd": get_topk_similars(asin, sigs["pstd"], lsh_pstd, asin_to_minhash["pstd"], top_k),
#         }
#         updates.append(UpdateOne(
#             {"asin": asin},
#             {"$set": {
#                 "asin": asin,
#                 "pst_sig": sigs["pst"].hashvalues[:num_hashes].tolist(),
#                 "psd_sig": sigs["psd"].hashvalues[:num_hashes].tolist(),
#                 "pstd_sig": sigs["pstd"].hashvalues[:num_hashes].tolist(),
#                 "similar": similars
#             }},
#             upsert=True
#         ))
#         if len(updates) >= bulk_batch:
#             signatures_col.bulk_write(updates)
#             updates = []

#     if updates:
#         signatures_col.bulk_write(updates)
#     print("✅ All results saved to DB.")

def final_compute_and_write(products, all_sigs, best, signatures_col, top_k=TOP_K, bulk_batch=BULK_BATCH, keys=None):
    """Compute top-k LSH neighbours for every product and upsert them to MongoDB.

    Args:
        products: Sized iterable of product documents; each must have "asin".
        all_sigs: ``all_sigs[asin][k][family] -> MinHash`` as produced by
            ``precompute_all``.
        best: Model parameters with keys "k", "num_hashes" and "num_bands".
        signatures_col: Collection receiving the upserts (matched on "asin").
        top_k: Neighbours kept per product and family.
        bulk_batch: Number of operations buffered per ``bulk_write`` call.
        keys: Optional iterable restricting the work to some of the families
            "pst", "psd", "pstd". When given, only ``<family>_sig`` and
            ``similar.<family>`` are set, so data already stored for the
            other families is left untouched. When ``None`` (default) all
            three families are written and the whole "similar" sub-document
            is replaced.

    Raises:
        ValueError: If ``best`` is None or ``keys`` names an unknown family.
    """
    if best is None:
        raise ValueError("final_compute_and_write: `best` is None; no model parameters to use")

    all_families = ("pst", "psd", "pstd")
    selected = all_families if keys is None else tuple(keys)
    unknown = [family for family in selected if family not in all_families]
    if unknown:
        raise ValueError(f"final_compute_and_write: unknown signature families {unknown}")

    print("\n🚀 Computing top-k similars with LSH...")

    k = best["k"]
    num_hashes = best["num_hashes"]
    num_bands = best["num_bands"]

    lsh_indexes, asin_to_minhash = build_lsh_all(all_sigs, k, num_hashes, num_bands)
    lsh_by_family = dict(zip(all_families, lsh_indexes))

    updates = []
    for p in tqdm(products, total=len(products), file=sys.stdout):
        asin = p["asin"]

        fields = {"asin": asin}
        similars = {}
        for family in selected:
            # build_lsh_all already produced the truncated signature for this
            # asin; reuse it instead of slicing a fresh MinHash per query.
            sliced = asin_to_minhash[family][asin]
            fields[f"{family}_sig"] = sliced.hashvalues.tolist()
            similars[family] = get_topk_similars(
                asin, sliced, lsh_by_family[family], asin_to_minhash[family], top_k
            )

        if keys is None:
            fields["similar"] = similars
        else:
            # Dotted paths update one family without replacing the siblings.
            for family, neighbours in similars.items():
                fields[f"similar.{family}"] = neighbours

        updates.append(UpdateOne({"asin": asin}, {"$set": fields}, upsert=True))

        if len(updates) >= bulk_batch:
            signatures_col.bulk_write(updates)
            updates = []

    # Flush the final partial batch.
    if updates:
        signatures_col.bulk_write(updates)

    print("✅ All results saved to DB.")





In [None]:
# # Cell 8
# def tune_best_model(products, all_sigs, top_k=TOP_K, sample_size=TUNING_SAMPLE):
#     print("📊 Tuning best (k, num_hashes, num_bands) based on Precision@10...")
#     sampled = random.sample(products, sample_size) if sample_size < len(products) else products
#     best_models = {"pst": None, "psd": None, "pstd": None}

#     for key in ["pst", "psd", "pstd"]:
#         best_score = -1
#         for k in K_VALUES:
#             for num_hashes in HASH_VALUES:
#                 for num_bands in NUM_BANDS_VALUES:
#                     r = max(1, num_hashes // num_bands)
#                     # Build LSH only once per combination
#                     lsh = MinHashLSH(num_perm=num_hashes, params=(num_bands, r))
#                     asin_to_mh = {}
#                     for p in sampled:
#                         asin = p["asin"]
#                         m = all_sigs[asin][k][key]
#                         m_sliced = MinHash(num_perm=num_hashes)
#                         m_sliced.hashvalues = m.hashvalues[:num_hashes]
#                         lsh.insert(asin, m_sliced)
#                         asin_to_mh[asin] = m_sliced
#                     # Evaluate Precision@10
#                     total_prec = 0
#                     for p in sampled:
#                         asin = p["asin"]
#                         topk = get_topk_similars(asin, asin_to_mh[asin], lsh, asin_to_mh, top_k)
#                         # Check overlap with similar_asins from DB
#                         actual = set(p.get("similar_asins", []))
#                         pred = set([x["asin"] for x in topk])
#                         # if actual:
#                         #     total_prec += len(pred & actual) / min(len(pred), len(actual))
#                         if actual and pred:
#                             total_prec += len(pred & actual) / min(len(pred), len(actual))
#                     avg_prec = total_prec / len([p for p in sampled if p.get("similar_asins")])
#                     if avg_prec > best_score:
#                         best_score = avg_prec
#                         best_models[key] = {"k": k, "num_hashes": num_hashes, "num_bands": num_bands, "r": r, "precision": avg_prec}
#         print(f"✅ Best model for {key}: {best_models[key]}")
#     return best_models

def tune_best_model(products, all_sigs, top_k=TOP_K, sample_size=TUNING_SAMPLE):
    """Grid-search (k, num_hashes, num_bands) separately for each signature family.

    For every combination an LSH index is built over the evaluated products
    and each product's retrieved neighbours are compared with its
    "similar_asins" field. The per-product score is
    ``len(pred & actual) / min(len(pred), len(actual))``; products with no
    ground truth or no retrieved neighbours are not counted.

    Args:
        products: List of product documents with "asin" and, optionally,
            "similar_asins".
        all_sigs: ``all_sigs[asin][k][family] -> MinHash``.
        top_k: Number of neighbours retrieved per product.
        sample_size: If set and smaller than ``len(products)``, tune on a
            random sample of that size (seed ``random`` beforehand for
            reproducible runs). Ground-truth ASINs outside the sample are
            not indexed and therefore cannot be retrieved. ``None`` uses
            every product.

    Returns:
        dict: ``{"pst": model, "psd": model, "pstd": model}`` where model is
        ``{"k", "num_hashes", "num_bands", "r", "precision"}`` or ``None``
        when no combination could be scored for that family.
    """
    print(f"📊 Tuning best (k, num_hashes, num_bands) based on Precision@{top_k}...")

    if sample_size is not None and 0 < sample_size < len(products):
        sampled = random.sample(products, sample_size)
    else:
        sampled = products

    # Ground truth does not depend on the grid point, so build it once.
    # `or []` also covers documents where the field is stored as null.
    labelled = [(p["asin"], set(p.get("similar_asins") or [])) for p in sampled]

    best_models = {"pst": None, "psd": None, "pstd": None}

    for key in ["pst", "psd", "pstd"]:
        best_score = -1
        for k in K_VALUES:
            for num_hashes in HASH_VALUES:
                # The truncated signatures depend only on (key, k, num_hashes);
                # build them once and share them across all band settings.
                asin_to_mh = {}
                for asin, _ in labelled:
                    m_sliced = MinHash(num_perm=num_hashes)
                    m_sliced.hashvalues = all_sigs[asin][k][key].hashvalues[:num_hashes]
                    asin_to_mh[asin] = m_sliced

                for num_bands in NUM_BANDS_VALUES:
                    r = max(1, num_hashes // num_bands)

                    # Skip invalid LSH params
                    if num_bands * r > num_hashes:
                        continue

                    # Build LSH
                    lsh = MinHashLSH(num_perm=num_hashes, params=(num_bands, r))
                    for asin, m_sliced in asin_to_mh.items():
                        lsh.insert(asin, m_sliced)

                    # Evaluate
                    total_prec = 0
                    valid_count = 0
                    for asin, actual in labelled:
                        if not actual:
                            # Never scored, so no need to query the index.
                            continue
                        topk = get_topk_similars(asin, asin_to_mh[asin], lsh, asin_to_mh, top_k)
                        pred = {x["asin"] for x in topk}
                        if pred:
                            total_prec += len(pred & actual) / min(len(pred), len(actual))
                            valid_count += 1

                    if valid_count == 0:
                        continue

                    avg_prec = total_prec / valid_count
                    if avg_prec > best_score:
                        best_score = avg_prec
                        best_models[key] = {"k": k, "num_hashes": num_hashes, "num_bands": num_bands, "r": r, "precision": avg_prec}

        print(f"✅ Best model for {key}: {best_models[key]}")
    return best_models



In [None]:
# Cell 9
# Load every product into memory, projecting only the fields used downstream
# (asin/title/description for signatures, similar_asins for tuning).
products = list(products_col.find({}, {"asin":1, "title":1, "description":1, "similar_asins":1}))
print(f"📦 Found {len(products)} products.")

# all_sigs is only defined when at least one product was found; the later
# cells guard on `products` for the same reason.
if products:
    all_sigs = precompute_all(products)


In [None]:
# Grid-search the LSH parameters; yields one entry per signature family
# ("pst", "psd", "pstd"), which may be None if nothing could be scored.
if products:
    best_models = tune_best_model(products, all_sigs)

In [None]:
# Persist signatures and neighbours using the tuned parameters.
# NOTE(review): this call form writes all three signature families on every
# iteration, so the stored documents end up reflecting only the last model
# that ran — confirm whether each family should be written with its own model.
if products:
    for key in ["pst", "psd", "pstd"]:
        best = best_models.get(key)
        if best is None:
            # Tuning leaves None when no parameter combination could be scored.
            print(f"\nSkipping {key}: no tuned model available.")
            continue
        print(f"\nProcessing top-{TOP_K} for {key} using best model...")
        final_compute_and_write(products, all_sigs, best, signatures_col, top_k=TOP_K)
