In [36]:
#!/usr/bin/env python3
"""
Jupyter Notebook version of full_pipeline_lsh.py
"""

import os
import math
import random
import sys
from multiprocessing import Pool, cpu_count
from datasketch import MinHash, MinHashLSH
from pymongo import MongoClient, UpdateOne
from tqdm import tqdm
import numpy as np

# ---------------- CONFIG ----------------
# Connection string and collection names can be overridden through environment
# variables; the default URI targets the `cs5600` database on a local MongoDB.
MONGO_URI = os.getenv("MONGO_URI", "mongodb://localhost:27017/cs5600")
PRODUCTS_COLL = os.getenv("PRODUCTS_COLL", "products")
SIG_COLL = os.getenv("SIG_COLL", "productsignatures")

# Grid explored by tune_parameters(): shingle length k, signature length, band count.
K_VALUES = [2, 3, 5, 7, 10]
HASH_VALUES = [10, 20, 50, 100, 150]
NUM_BANDS_VALUES = [4, 5, 10, 15, 20, 25]
# Signatures are precomputed once at the largest length and sliced for shorter ones.
MAX_HASHES = max(HASH_VALUES)
NUM_WORKERS = max(1, min(cpu_count() - 1, 6))  # leave one core free, cap at 6, never below 1
# Passed to tune_parameters() as sample_size by the tuning cell.
TUNING_SAMPLE = 1000
# Same values as the top_k / bulk_batch defaults of final_compute_and_write().
TOP_K = 10
BULK_BATCH = 500
# ----------------------------------------

# ---------------- DB ----------------
# Module-level MongoDB handles shared by the cells below.
# NOTE(review): get_default_database() relies on MONGO_URI naming a database in
# its path (the default URI does); confirm this holds for any overridden URI.
client = MongoClient(MONGO_URI)
db = client.get_default_database()
products_col = db[PRODUCTS_COLL]     # source documents (asin / title / description)
signatures_col = db[SIG_COLL]        # destination for signatures and similar-item lists
# ------------------------------------


In [19]:
def get_char_shingles(text, k):
    """
    Split ``text`` into overlapping character k-grams.

    The input is stringified, lower-cased and has every run of whitespace
    collapsed to a single space (leading/trailing whitespace is dropped)
    before the sliding window is applied.

    Returns a list of substrings of length ``k`` in left-to-right order,
    duplicates included. A falsy ``text`` or a normalised string shorter
    than ``k`` yields an empty list.
    """
    if text:
        normalized = " ".join(str(text).lower().split())
        last_start = len(normalized) - k
        if last_start >= 0:
            return [normalized[pos:pos + k] for pos in range(last_start + 1)]
    return []


def compute_minhash(text, k, num_perm=MAX_HASHES):
    """
    Build a MinHash signature over the character k-shingles of ``text``.

    Parameters:
        text: raw text; normalised by get_char_shingles().
        k: shingle length in characters.
        num_perm: number of permutations (signature length).

    Returns:
        A ``MinHash`` with ``num_perm`` permutations. When ``text`` produces
        no shingles the MinHash is returned without any update.
    """
    m = MinHash(num_perm=num_perm)
    # MinHash models a *set*: each update only lowers a per-permutation
    # minimum, so repeating a shingle cannot change the signature. Dropping
    # duplicates first avoids re-hashing the same shingle (short k repeats a lot).
    for s in set(get_char_shingles(text, k)):
        m.update(s.encode("utf8"))
    return m


In [None]:
def _worker_precompute(product):
    """
    Pool worker: compute every MinHash needed for one product document.

    Reads ``asin``, ``title`` and ``description`` from ``product`` (missing or
    falsy values become empty strings; a list description is joined with
    spaces) and returns a flat dict holding the asin plus, for each k in
    K_VALUES, the signatures of the title (``pst_k{k}``), the description
    (``psd_k{k}``) and title+description (``pstd_k{k}``).
    """
    raw_title = product.get("title") or ""
    raw_desc = product.get("description") or ""
    description = " ".join(raw_desc) if isinstance(raw_desc, list) else raw_desc

    # Field prefix -> text to sign; insertion order fixes the output key order.
    texts = {
        "pst": raw_title,
        "psd": description,
        "pstd": (raw_title + " " + description).strip(),
    }

    record = {"asin": product.get("asin")}
    for k in K_VALUES:
        for prefix, text in texts.items():
            record[f"{prefix}_k{k}"] = compute_minhash(text, k)
    return record


def precompute_all(products):
    """
    Compute MinHash signatures for all products in a process pool.

    Work is distributed with ``imap_unordered`` over NUM_WORKERS processes, so
    results arrive in arbitrary order; they are re-keyed by asin afterwards.

    Returns:
        dict mapping asin -> {k: {"pst": MinHash, "psd": MinHash, "pstd": MinHash}}
        for every k in K_VALUES.
    """
    print(f"⚡ Precomputing MinHash objects ({NUM_WORKERS} workers)...")
    with Pool(processes=NUM_WORKERS) as pool:
        stream = pool.imap_unordered(_worker_precompute, products)
        # Drain the iterator inside the `with` so the pool is still alive.
        records = list(tqdm(stream, total=len(products), file=sys.stdout))

    all_sigs = {}
    for record in records:
        per_k = {}
        for k in K_VALUES:
            per_k[k] = {
                "pst": record[f"pst_k{k}"],
                "psd": record[f"psd_k{k}"],
                "pstd": record[f"pstd_k{k}"],
            }
        all_sigs[record["asin"]] = per_k
    print("✅ Precompute finished.")
    return all_sigs


In [21]:
def compute_bands_from_signature(sig, num_hashes, num_bands):
    """
    Cut the first ``num_hashes`` values of ``sig.hashvalues`` into LSH bands.

    Each band is a tuple of ``max(1, num_hashes // num_bands)`` consecutive
    values. Bands that would start past the end of the truncated signature are
    empty and are left out, so fewer than ``num_bands`` tuples may be returned.
    """
    values = sig.hashvalues[:num_hashes]
    width = max(1, num_hashes // num_bands)
    slices = (
        tuple(values[idx * width:(idx + 1) * width])
        for idx in range(num_bands)
    )
    return [chunk for chunk in slices if chunk]


def candidate_probability(s, r, b):
    """
    Evaluate ``1 - (1 - s**r)**b`` for similarity ``s``, ``r`` rows per band
    and ``b`` bands (the banding S-curve used by tune_parameters()).
    """
    band_match = s ** r
    no_band_matches = (1.0 - band_match) ** b
    return 1.0 - no_band_matches


In [22]:
def tune_parameters(products, all_sigs, sample_size=TUNING_SAMPLE, seed=None):
    """
    Grid-search (k, num_hashes, num_bands) and return the best-scoring combination.

    For every combination the title+description ("pstd") signatures of a
    product sample are banded, and the number of products that share at least
    one band *at the same band position* with an earlier product is counted.

    Score = candidate_probability(0.8) - candidate_probability(0.4), minus a
    0.5 penalty when the observed collision rate is below 0.3 or above 0.9.

    Parameters:
        products: list of product dicts carrying an "asin" key.
        all_sigs: output of precompute_all().
        sample_size: number of products to sample; falsy or >= len(products)
            means every product is used.
        seed: optional seed for the sample so a tuning run can be reproduced;
            ``None`` keeps the previous behaviour (global ``random`` state).

    Returns:
        dict with keys k, num_hashes, num_bands, r, overlaps, recall, pLow,
        pHigh, score — or ``None`` if no combination was evaluated. ``recall``
        is the collision rate described above (overlaps / sample size), not a
        recall measured against ground truth.
    """
    print("📊 Tuning parameters (k, num_hashes, num_bands)...")
    if sample_size and sample_size < len(products):
        rng = random.Random(seed) if seed is not None else random
        sampled = rng.sample(products, sample_size)
    else:
        sampled = products
    best = None

    for k in K_VALUES:
        for num_hashes in HASH_VALUES:
            for num_bands in NUM_BANDS_VALUES:
                if num_bands > num_hashes:
                    # More bands than hash values cannot be realised: the
                    # signature only yields num_hashes one-row bands, and
                    # MinHashLSH rejects params where b * r exceeds num_perm.
                    # Scoring it would let the tuner pick a config that
                    # build_lsh_all() cannot construct.
                    continue
                r = num_hashes // num_bands  # >= 1 thanks to the guard above
                seen = set()
                overlaps = 0

                for p in sampled:
                    sigs = all_sigs.get(p.get("asin"))
                    if not sigs:
                        continue
                    bands = compute_bands_from_signature(sigs[k]["pstd"], num_hashes, num_bands)
                    # Key each band by its position: an LSH index keeps one
                    # table per band, so equal tuples in different bands are
                    # not a collision.
                    keyed = list(enumerate(bands))
                    if any(entry in seen for entry in keyed):
                        overlaps += 1
                    seen.update(keyed)

                recall = overlaps / max(1, len(sampled))
                pLow = candidate_probability(0.4, r, num_bands)
                pHigh = candidate_probability(0.8, r, num_bands)
                theoryScore = pHigh - pLow
                penalty = 0.5 if (recall < 0.3 or recall > 0.9) else 0.0
                score = theoryScore - penalty

                if best is None or score > best["score"]:
                    best = {"k": k, "num_hashes": num_hashes, "num_bands": num_bands, "r": r,
                            "overlaps": overlaps, "recall": recall, "pLow": pLow, "pHigh": pHigh, "score": score}
    print("✅ Best configuration:", best)
    return best


In [45]:
# def build_lsh(all_sigs, k, num_hashes, num_bands):
#     r = max(1, num_hashes // num_bands)
#     lsh = MinHashLSH(num_perm=num_hashes, params=(num_bands, r))
#     asin_to_minhash = {}

#     for asin, sigs in all_sigs.items():
#         m = MinHash(num_perm=num_hashes)
#         m.hashvalues = sigs[k]["pstd"].hashvalues[:num_hashes]
#         lsh.insert(asin, m)
#         asin_to_minhash[asin] = m
#     return lsh, asin_to_minhash



from datasketch import MinHash, MinHashLSH
from pymongo import UpdateOne
from tqdm import tqdm
import sys
import numpy as np

def build_lsh_all(all_sigs, k, num_hashes, num_bands):
    """
    Build one MinHashLSH index per signature type (pst, psd, pstd).

    Every precomputed signature for shingle length ``k`` is truncated to its
    first ``num_hashes`` values, wrapped in a fresh MinHash and inserted under
    its asin. Rows per band is ``max(1, num_hashes // num_bands)``.

    Returns:
        ((lsh_pst, lsh_psd, lsh_pstd), asin_to_minhash) where
        ``asin_to_minhash[kind][asin]`` is the truncated MinHash that was
        inserted into the index of that kind.
    """
    kinds = ("pst", "psd", "pstd")
    rows = max(1, num_hashes // num_bands)

    indexes = {
        kind: MinHashLSH(num_perm=num_hashes, params=(num_bands, rows))
        for kind in kinds
    }
    asin_to_minhash = {kind: {} for kind in kinds}

    for asin, sigs in all_sigs.items():
        per_kind = sigs[k]
        for kind in kinds:
            truncated = MinHash(num_perm=num_hashes)
            # Overwrite the fresh hash values with the stored signature prefix.
            truncated.hashvalues = per_kind[kind].hashvalues[:num_hashes]
            indexes[kind].insert(asin, truncated)
            asin_to_minhash[kind][asin] = truncated

    return (indexes["pst"], indexes["psd"], indexes["pstd"]), asin_to_minhash


def get_topk_similars(asin, m, lsh, asin_to_minhash, top_k=10):
    """
    Return the ``top_k`` LSH candidates most similar to ``m``.

    ``lsh.query(m)`` supplies candidate keys; the querying ``asin`` itself is
    dropped, each remaining candidate is scored with ``m.jaccard(...)`` against
    its entry in ``asin_to_minhash``, and results are ordered by descending
    score (ties keep the order in which the index returned them).

    Returns a list of ``{"asin": <key>, "score": <float>}`` dicts.
    """
    scored = [
        (m.jaccard(asin_to_minhash[key]), key)
        for key in lsh.query(m)
        if key != asin
    ]
    ranked = sorted(scored, key=lambda pair: pair[0], reverse=True)
    return [{"asin": key, "score": float(value)} for value, key in ranked[:top_k]]


def final_compute_and_write(products, all_sigs, best, signatures_col, top_k=10, bulk_batch=500):
    """
    Compute each product's top-k similar items and upsert the results.

    Builds the three LSH indexes (title / description / title+description)
    with the tuned parameters in ``best``, queries every product against each
    index and writes one document per asin containing the truncated
    signatures (``pst_sig``, ``psd_sig``, ``pstd_sig``) and the ``similar``
    lists. Writes are flushed with ``bulk_write`` every ``bulk_batch`` updates.

    Parameters:
        products: list of product dicts carrying an "asin" key.
        all_sigs: output of precompute_all().
        best: dict with keys "k", "num_hashes", "num_bands" (see tune_parameters()).
        signatures_col: destination collection (needs ``bulk_write``).
        top_k: number of similar items kept per signature type.
        bulk_batch: number of upserts per bulk_write call.

    Raises:
        KeyError: if a product's asin has no entry in ``all_sigs``.
    """
    print(f"\n🚀 Computing top-{top_k} similars with LSH...")
    k = best["k"]
    num_hashes = best["num_hashes"]
    num_bands = best["num_bands"]

    (lsh_pst, lsh_psd, lsh_pstd), asin_to_minhash = build_lsh_all(all_sigs, k, num_hashes, num_bands)
    indexes = {"pst": lsh_pst, "psd": lsh_psd, "pstd": lsh_pstd}
    updates = []

    for p in tqdm(products, total=len(products), file=sys.stdout):
        asin = p["asin"]

        fields = {"asin": asin}
        similars = {}
        for kind, lsh in indexes.items():
            # Query with the MinHash that was truncated to num_hashes for the
            # index. The precomputed signature in all_sigs has MAX_HASHES
            # values, and both lsh.query() and jaccard() reject a length
            # mismatch, so using it directly only works when
            # num_hashes == MAX_HASHES.
            query = asin_to_minhash[kind][asin]
            similars[kind] = get_topk_similars(asin, query, lsh, asin_to_minhash[kind], top_k)
            fields[f"{kind}_sig"] = query.hashvalues.tolist()
        fields["similar"] = similars

        updates.append(UpdateOne({"asin": asin}, {"$set": fields}, upsert=True))

        if len(updates) >= bulk_batch:
            signatures_col.bulk_write(updates)
            updates = []

    # Flush the final partial batch.
    if updates:
        signatures_col.bulk_write(updates)
    print("✅ All results saved to DB.")



In [None]:
# def final_compute_and_write(products, all_sigs, best):
#     print("\n🚀 Computing top-10 similars with LSH...")
#     k = best["k"]
#     num_hashes = best["num_hashes"]
#     num_bands = best["num_bands"]

#     (lsh_pst, lsh_psd, lsh_pstd), asin_to_minhash = build_lsh_all(all_sigs, k, num_hashes, num_bands)
#     updates = []

#     for p in tqdm(products, total=len(products), file=sys.stdout):
#         asin = p["asin"]
#         sigs = all_sigs[asin][k]

#         # Wrap hash lists into MinHash objects for querying
#         m_pst = MinHash(num_perm=num_hashes)
#         m_pst.hashvalues = np.array(sigs["pst"][:num_hashes], dtype='uint64')

#         m_psd = MinHash(num_perm=num_hashes)
#         m_psd.hashvalues = np.array(sigs["psd"][:num_hashes], dtype='uint64')

#         m_pstd = MinHash(num_perm=num_hashes)
#         m_pstd.hashvalues = np.array(sigs["pstd"][:num_hashes], dtype='uint64')

#         similars = {
#             "pst": get_topk_similars(asin, m_pst, lsh_pst, asin_to_minhash["pst"]),
#             "psd": get_topk_similars(asin, m_psd, lsh_psd, asin_to_minhash["psd"]),
#             "pstd": get_topk_similars(asin, m_pstd, lsh_pstd, asin_to_minhash["pstd"]),
#         }

#         updates.append(UpdateOne(
#             {"asin": asin},
#             {"$set": {
#                 "asin": asin,
#                 "pst_sig": sigs["pst"][:num_hashes],
#                 "psd_sig": sigs["psd"][:num_hashes],
#                 "pstd_sig": sigs["pstd"][:num_hashes],
#                 "similar": similars
#             }},
#             upsert=True
#         ))

#         if len(updates) >= BULK_BATCH:
#             signatures_col.bulk_write(updates)
#             updates = []

#     if updates:
#         signatures_col.bulk_write(updates)

#     print("✅ All results saved to DB.")



In [25]:
# products = list(products_col.find({}, {"asin":1, "title":1, "description":1}))
# print(f"📦 Found {len(products)} products.")
# if not products:
#     print("No products in DB — exiting.")
# else:
#     all_sigs = precompute_all(products)
#     best = tune_parameters(products, all_sigs, sample_size=TUNING_SAMPLE)
#     final_compute_and_write(products, all_sigs, best)


In [26]:
# Load every product into memory, projecting only the fields the pipeline reads
# (asin, title, description).
products = list(products_col.find({}, {"asin":1, "title":1, "description":1}))
print(f"📦 Found {len(products)} products.")


📦 Found 30239 products.


In [27]:
# Precompute MinHash signatures for every product; skipped when the collection
# is empty, in which case `all_sigs` is left undefined by this cell.
if products:
    all_sigs = precompute_all(products)


⚡ Precomputing MinHash objects (6 workers)...
100%|██████████| 30239/30239 [04:11<00:00, 120.14it/s]
✅ Precompute finished.


In [28]:
# Pick (k, num_hashes, num_bands) on a TUNING_SAMPLE-sized sample. The `and`
# short-circuits, so `all_sigs` is only looked up when products were loaded.
if products and all_sigs:
    best = tune_parameters(products, all_sigs, sample_size=TUNING_SAMPLE)


📊 Tuning parameters (k, num_hashes, num_bands)...
✅ Best configuration: {'k': 2, 'num_hashes': 150, 'num_bands': 20, 'r': 7, 'overlaps': 595, 'recall': 0.595, 'pLow': 0.032262951677010765, 'pHigh': 0.9909703154270351, 'score': 0.9587073637500243}


In [47]:
# Build the indexes with the tuned parameters and persist the results.
# TOP_K and BULK_BATCH are passed explicitly so that editing them in the
# config cell takes effect instead of silently using the function defaults.
if products and all_sigs and best:
    final_compute_and_write(products, all_sigs, best, signatures_col,
                            top_k=TOP_K, bulk_batch=BULK_BATCH)




🚀 Computing top-10 similars with LSH...
100%|██████████| 30239/30239 [04:30<00:00, 111.96it/s]
✅ All results saved to DB.
