In [14]:
# ================= Cell 1: Imports & Config =================
import os
import sys
from multiprocessing import Pool, cpu_count
from tqdm import tqdm
from datasketch import MinHash, MinHashLSH
from pymongo import MongoClient, UpdateOne

# ---------------- CONFIG ----------------
MONGO_URI = os.getenv("MONGO_URI", "mongodb://localhost:27017/cs5600")
PRODUCTS_COLL = os.getenv("PRODUCTS_COLL", "products")
SIG_COLL = os.getenv("SIG_COLL", "productsignatures")

K_VALUES = [2, 3, 5, 7, 10]
# HASH_VALUES = [10, 20, 30, 50, 100, 150]
HASH_VALUES = [150]
NUM_BANDS_VALUES = [4, 5, 10, 15, 20]
NUM_WORKERS = max(1, min(cpu_count() - 1, 12))
TOP_K = 10
BULK_BATCH = 500
TUNING_SAMPLE = 100  # Top-100 products with most similar_asins
# ------------------------------------------------------------

# ================= Cell 2: MongoDB =================
client = MongoClient(MONGO_URI)
db = client.get_default_database()
products_col = db[PRODUCTS_COLL]
signatures_col = db[SIG_COLL]

# ================= Cell 3: Word Shingles & MinHash =================
def get_word_shingles(text, k):
    if not text:
        return []
    words = str(text).lower().split()
    if len(words) < k:
        return [" ".join(words)] if words else ["dummy"]
    return [" ".join(words[i:i+k]) for i in range(len(words) - k + 1)]

def compute_minhash_words(text, k, num_perm=150):
    m = MinHash(num_perm=num_perm)
    shingles = get_word_shingles(text, k)
    for s in shingles:
        m.update(s.encode("utf8"))
    return m

# ================= Cell 4: Precompute MinHash =================
def _worker_precompute_words(product):
    asin = product.get("asin")
    title = product.get("title") or ""
    desc_field = product.get("description") or ""
    desc = " ".join(desc_field) if isinstance(desc_field, list) else desc_field
    hybrid = (title + " " + desc).strip()

    out = {"asin": asin}
    for k in K_VALUES:
        out[f"pst_k{k}"] = compute_minhash_words(title, k)
        out[f"psd_k{k}"] = compute_minhash_words(desc, k)
        out[f"pstd_k{k}"] = compute_minhash_words(hybrid, k)
    return out

def precompute_all(products):
    print(f"⚡ Precomputing MinHash objects ({NUM_WORKERS} workers)...")
    results = []
    with Pool(processes=NUM_WORKERS) as pool:
        for r in tqdm(pool.imap_unordered(_worker_precompute_words, products),
                      total=len(products), file=sys.stdout):
            results.append(r)

    all_sigs = {}
    for r in results:
        asin = r["asin"]
        all_sigs[asin] = {k: {"pst": r[f"pst_k{k}"], 
                              "psd": r[f"psd_k{k}"], 
                              "pstd": r[f"pstd_k{k}"]} for k in K_VALUES}
    print("✅ Precompute finished.")
    return all_sigs

# ================= Cell 5: Build LSH =================
def build_lsh_all(all_sigs, k, num_hashes, num_bands):
    r = max(1, num_hashes // num_bands)
    lsh_pst = MinHashLSH(num_perm=num_hashes, params=(num_bands, r))
    lsh_psd = MinHashLSH(num_perm=num_hashes, params=(num_bands, r))
    lsh_pstd = MinHashLSH(num_perm=num_hashes, params=(num_bands, r))
    asin_to_minhash = {"pst": {}, "psd": {}, "pstd": {}}

    for asin, sigs in all_sigs.items():
        m_pst = MinHash(num_perm=num_hashes)
        m_pst.hashvalues = sigs[k]["pst"].hashvalues[:num_hashes]
        lsh_pst.insert(asin, m_pst)
        asin_to_minhash["pst"][asin] = m_pst

        m_psd = MinHash(num_perm=num_hashes)
        m_psd.hashvalues = sigs[k]["psd"].hashvalues[:num_hashes]
        lsh_psd.insert(asin, m_psd)
        asin_to_minhash["psd"][asin] = m_psd

        m_pstd = MinHash(num_perm=num_hashes)
        m_pstd.hashvalues = sigs[k]["pstd"].hashvalues[:num_hashes]
        lsh_pstd.insert(asin, m_pstd)
        asin_to_minhash["pstd"][asin] = m_pstd

    return (lsh_pst, lsh_psd, lsh_pstd), asin_to_minhash

# ================= Cell 6: Top-K Similar =================
def get_topk_similars(asin, m, lsh, asin_to_minhash, top_k=TOP_K):
    candidates = lsh.query(m)
    sims = []
    for cand in candidates:
        if cand == asin:
            continue
        score = m.jaccard(asin_to_minhash[cand])
        sims.append((score, cand))
    sims.sort(reverse=True, key=lambda x: x[0])
    return [{"asin": c, "score": float(s)} for s, c in sims[:top_k]]

# ================= Cell 7: Grid Search for Best Models =================
def tune_best_model(products, all_sigs, top_k=TOP_K, sample_size=TUNING_SAMPLE):
    print("📊 Tuning best (k, num_hashes, num_bands) based on Precision@10...")
    # Pick top-100 products with most similar_asins
    products = sorted([p for p in products if p.get("similar_asins")], 
                      key=lambda x: len(x.get("similar_asins", [])), reverse=True)[:sample_size]
    best_models = {"pst": None, "psd": None, "pstd": None}

    for key in ["pst", "psd", "pstd"]:
        best_score = -1
        for k in K_VALUES:
            for num_hashes in HASH_VALUES:
                for num_bands in NUM_BANDS_VALUES:
                    r = max(1, num_hashes // num_bands)
                    if num_bands * r > num_hashes:
                        continue

                    lsh = MinHashLSH(num_perm=num_hashes, params=(num_bands, r))
                    asin_to_mh = {}
                    for p in products:
                        asin = p["asin"]
                        m = all_sigs[asin][k][key]
                        m_sliced = MinHash(num_perm=num_hashes)
                        m_sliced.hashvalues = m.hashvalues[:num_hashes].copy()

                        lsh.insert(asin, m_sliced)
                        asin_to_mh[asin] = m_sliced

                    total_prec = 0
                    valid_count = 0
                    for p in products:
                        asin = p["asin"]
                        topk = get_topk_similars(asin, asin_to_mh[asin], lsh, asin_to_mh, top_k)
                        actual = set(p.get("similar_asins", []))
                        pred = set([x["asin"] for x in topk])
                        if actual and pred:
                            total_prec += len(pred & actual) / min(len(pred), len(actual))
                            valid_count += 1

                    if valid_count == 0:
                        continue

                    avg_prec = total_prec / valid_count
                    if avg_prec > best_score:
                        best_score = avg_prec
                        best_models[key] = {"k": k, "num_hashes": num_hashes, "num_bands": num_bands, "r": r, "precision": avg_prec}

        print(f"✅ Best model for {key}: {best_models[key]}")
    return best_models

# ================= Cell 8: Final Compute & Save =================
def final_compute_and_write(products, all_sigs, best_models, signatures_col, top_k=TOP_K, bulk_batch=BULK_BATCH):
    print("\n🚀 Computing top-k similars with LSH using best models...")

    key_to_idx = {"pst": 0, "psd": 1, "pstd": 2}
    updates = []

    for key in ["pst", "psd", "pstd"]:
        best = best_models[key]
        k_best = best["k"]
        num_hashes = best["num_hashes"]
        num_bands = best["num_bands"]

        # Build LSH and get MinHash dict
        lsh, asin_to_mh = build_lsh_all(all_sigs, k_best, num_hashes, num_bands)

        for p in tqdm(products, total=len(products), file=sys.stdout):
            asin = p["asin"]
            sigs = all_sigs[asin][k_best]

            # Slice MinHash object
            m = MinHash(num_perm=num_hashes)
            m.hashvalues = sigs[key].hashvalues[:num_hashes]

            # Query LSH
            lsh_obj = lsh[key_to_idx[key]]
            topk = get_topk_similars(asin, m, lsh_obj, asin_to_mh[key], top_k)

            # Save only that subfield in MongoDB
            updates.append(UpdateOne(
                {"asin": asin},
                {"$set": {
                    "asin": asin,
                    f"{key}_sig": sigs[key].hashvalues[:num_hashes].tolist(),
                    f"similar.{key}": topk
                }},
                upsert=True
            ))

            if len(updates) >= bulk_batch:
                signatures_col.bulk_write(updates)
                updates = []

    if updates:
        signatures_col.bulk_write(updates)

    print("✅ All results saved to DB.")

# ================= Cell 9: Run Pipeline =================
products = list(products_col.find({}, {"asin":1, "title":1, "description":1, "similar_asins":1}))
print(f"📦 Found {len(products)} products.")

📦 Found 30239 products.


In [11]:
if products:
    all_sigs = precompute_all(products)


⚡ Precomputing MinHash objects (12 workers)...
100%|██████████| 30239/30239 [01:36<00:00, 311.85it/s]
✅ Precompute finished.


In [16]:
if products:
    best_models = tune_best_model(products, all_sigs)

📊 Tuning best (k, num_hashes, num_bands) based on Precision@10...
✅ Best model for pst: {'k': 2, 'num_hashes': 150, 'num_bands': 20, 'r': 7, 'precision': 0.16666666666666666}
✅ Best model for psd: {'k': 2, 'num_hashes': 150, 'num_bands': 15, 'r': 10, 'precision': 0.25}
✅ Best model for pstd: {'k': 2, 'num_hashes': 150, 'num_bands': 15, 'r': 10, 'precision': 0.375}


In [18]:
if products:
    final_compute_and_write(products, all_sigs, best_models, signatures_col, top_k=TOP_K)


🚀 Computing top-k similars with LSH using best models...
100%|██████████| 30239/30239 [00:35<00:00, 862.24it/s] 
100%|██████████| 30239/30239 [00:56<00:00, 530.89it/s]
100%|██████████| 30239/30239 [00:39<00:00, 759.08it/s] 
✅ All results saved to DB.
