In [None]:
# Cell 1 — Install deps (Colab)
!pip -q install open_clip_torch faiss-cpu pillow requests pandas numpy networkx tqdm torch sentence-transformers --no-cache-dir


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m264.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m31.4/31.4 MB[0m [31m32.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m263.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
# Cell 2 — Imports, paths, helpers, vocab, seed
from __future__ import annotations

import os, io, re, json, math, hashlib, random
from dataclasses import dataclass
from typing import List, Dict, Any, Optional, Tuple
from types import SimpleNamespace

import numpy as np
import pandas as pd
import requests
from PIL import Image
from tqdm.auto import tqdm

import faiss
import networkx as nx

import torch
import torch.nn as nn
import torch.nn.functional as F
import open_clip
from sentence_transformers import SentenceTransformer

# ------------------- Paths -------------------
QUERIES_JSONL  = "/content/not_discontinued_no_brand_small_data_with_constraints.jsonl"   # has intents + constraints
CATALOG_JSONL  = "/content/meta_All_Beauty_not_discontinued.jsonl"                        # catalog to build KG and search
# CATALOG_JSONL  = "/content/testing_meta.jsonl"
OUT_DIR        = "/content/mm_index"
IMG_CACHE      = os.path.join(OUT_DIR, "images")
os.makedirs(OUT_DIR, exist_ok=True)
os.makedirs(IMG_CACHE, exist_ok=True)

# ------------------- Encoders ----------------
CLIP_MODEL_NAME = "ViT-L-14"
CLIP_PRETRAINED = "openai"
E5_MODEL_NAME   = "intfloat/e5-base-v2"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ------------------- Reproducibility ----------
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
if DEVICE == "cuda": torch.cuda.manual_seed_all(SEED)

# ------------------- Helpers ------------------
def md5(s: str) -> str:
    return hashlib.md5(s.encode("utf-8")).hexdigest()

def ensure_dir(p: str):
    os.makedirs(p, exist_ok=True)

def read_jsonl(path: str):
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                yield json.loads(line)

def normalize_images_field(images_field):
    """Normalize images into [{'variant': str, 'url': str}, ...]."""
    out = []
    if images_field is None:
        return out
    if isinstance(images_field, str):
        out.append({"variant": "MAIN", "url": images_field})
        return out
    if isinstance(images_field, list):
        for it in images_field:
            if isinstance(it, str):
                out.append({"variant": "", "url": it})
            elif isinstance(it, dict):
                for k in ("hi_res", "large", "thumb", "url"):
                    url = it.get(k)
                    if url:
                        out.append({"variant": str(it.get("variant", "")), "url": url})
                        break
    return out

def load_image(url: str, timeout: float = 10.0) -> Optional[Image.Image]:
    try:
        cache_path = os.path.join(IMG_CACHE, md5(url) + ".jpg")
        if not os.path.exists(cache_path):
            r = requests.get(url, timeout=timeout)
            r.raise_for_status()
            Image.open(io.BytesIO(r.content)).convert("RGB").save(cache_path)
        return Image.open(cache_path).convert("RGB")
    except Exception:
        return None

def extract_summary_field(s) -> str:
    if s is None:
        return ""
    if isinstance(s, str):
        return s
    if isinstance(s, dict):
        for k in ("short", "summary", "text"):
            if s.get(k):
                return str(s[k])
        return " ".join(str(v) for v in s.values())
    if isinstance(s, list):
        return " ".join(map(str, s))
    return str(s)

def build_text_blob(rec: dict) -> str:
    title = str(rec.get("title") or "")
    summ  = extract_summary_field(rec.get("summary"))
    feats = rec.get("features") or []
    feats = " ".join(map(str, feats[:12])) if isinstance(feats, list) else str(feats)
    desc  = rec.get("description") or []
    desc  = " ".join(map(str, desc[:12])) if isinstance(desc, list) else str(desc)
    details = rec.get("details") or {}
    if isinstance(details, dict):
        det = " ".join([f"{k}: {v}" for k, v in details.items() if v is not None])
    else:
        det = str(details)
    txt = " \n ".join([title, summ, feats, desc, det])
    return re.sub(r"\s+", " ", txt).strip()

def _sanitize_value(v: str) -> str:
    s = str(v or "").strip().lower()
    s = re.sub(r"\s+", " ", s)
    return s

def _split_multi(v: str) -> List[str]:
    s = _sanitize_value(v)
    parts = re.split(r"[;,/|]| and ", s)
    out = [p.strip() for p in parts if p and p.strip()]
    return out or ([s] if s else [])

# ------------------- Attribute vocab & weights -------------------
item_forms = {
    'cream','liquid','gel','pair','powder','spray','oil','bar','lotion',
    'pencil','stick','wand','balm','wrap','scrunchie','individual','elastic',
    'sheet','clay','serum','foam','wax','butter','clip','wipes','spiral',
    'ribbon','mask','aerosol','strip'
}
materials = {
    'human hair','synthetic','plastic','acrylic','human','metal','cotton',
    'rubber','faux mink','silicone','silk','ceramic','nylon','polyester',
    'mink fur','stainless steel','wood','acrylonitrile butadiene styrene (abs)'
}
hair_types = {'straight','wavy','curly','kinky','coily','all','dry','thick','fine','normal','frizzy','color','damaged'}
age_ranges = {'adult','kid','child','baby','all ages'}
material_features = {'natural','cruelty free','organic','latex free','non-toxic','reusable','vegan','disposable','biodegradable warning','gluten free','certified organic'}
colors = {'black','pink','white','blue','brown','red','natural','clear','gold','silver','green','purple','multicolor','beige'}
skin_types = {'all','sensitive','dry','acne prone','oily','normal','combination'}
styles = {'modern','french','straight','compact','curly','african','classic','art deco','wavy','earloop'}

ATTR_KEYS: Dict[str, Dict[str, Any]] = {
    "item_form":        {"vocab": item_forms,         "beta": 1},
    "material":         {"vocab": materials,          "beta": 1},
    "hair_type":        {"vocab": hair_types,         "beta": 1},
    "age_range":        {"vocab": age_ranges,         "beta": 1},
    "material_feature": {"vocab": material_features,  "beta": 1},
    "color":            {"vocab": colors,             "beta": 1},
    "skin_type":        {"vocab": skin_types,         "beta": 1},
    "style":            {"vocab": styles,             "beta": 1},
}
DETAIL_KEYS = list(ATTR_KEYS.keys())


In [None]:
# Cell 3 — KG (brand + selected attributes ONLY), item→item (e5), corpus builder

# ---------------- Edge budgets ----------------
LAMBDA_ATTR = 0.6   # total budget for product↔(brand/attr) edges
LAMBDA_SIM  = 0.4   # total budget for product↔product similarity edges

def _norm_brand(s: Optional[str]) -> Optional[str]:
    if s is None: return None
    s = s.strip()
    s = re.sub(r"\s+", " ", s)
    return s if s else None

class OptimizedKnowledgeGraph:
    """
    Product graph with:
      - Nodes: product, brand, attr(key,value)
      - Edges: product↔brand, product↔attr, product↔product (similar_to)
    NOTE: No category nodes are built.
    Price/rating remain PRODUCT PROPERTIES (no numeric nodes).
    """
    def __init__(self):
        self.graph = nx.MultiDiGraph()

    # ---------- Node builders ----------
    def _add_product_node(self, product: Dict):
        pid = product.get('parent_asin') or product.get('asin')
        if not pid: return None
        node_id = f"product_{pid}"
        if not self.graph.has_node(node_id):
            self.graph.add_node(
                node_id,
                type='product',
                parent_asin=pid,
                title=str(product.get('title') or ""),
                main_category=str(product.get('main_category') or ""),  # stored as prop only
                price=float(product['price']) if (product.get('price') not in (None, "", "None")) else None,
                average_rating=float(product['average_rating']) if (product.get('average_rating') not in (None, "", "None")) else (
                    float(product['avg_rating']) if (product.get('avg_rating') not in (None, "", "None")) else None
                ),
                rating_number=product.get('rating_number')
            )
        return node_id

    def _brand_node_id(self, brand_str: Optional[str]) -> Optional[str]:
        if not brand_str: return None
        return f"brand_{brand_str}"  # keep brand as-is to match constraints exactly

    def _attr_node_id(self, key: str, value: str) -> str:
        return f"attr|{key}|{value}"

    # ---------- Graph construction ----------
    def build_knowledge_graph(self, meta_data: List[Dict]) -> nx.MultiDiGraph:
        G = self.graph
        # pass 1: product nodes
        for product in meta_data:
            self._add_product_node(product)

        # pass 2: brand + selected-attr edges (NO category nodes)
        for product in meta_data:
            pid = product.get('parent_asin') or product.get('asin')
            if not pid: continue
            pnode = f"product_{pid}"

            # brand
            brand_raw = (product.get('details') or {}).get('Brand') or product.get('store')
            brand = _norm_brand(brand_raw)
            bnode = self._brand_node_id(brand)
            if bnode:
                if not G.has_node(bnode):
                    G.add_node(bnode, type='brand', brand=brand)
                G.add_edge(pnode, bnode, relation='belongs_to_brand', etype='attr')
                G.add_edge(bnode, pnode, relation='contains_product', etype='attr')

            # selected details → attr nodes
            details = product.get('details') or {}
            raw_map = {
                "Item Form": "item_form",
                "Material": "material",
                "Hair Type": "hair_type",
                "Age Range (Description)": "age_range",
                "Material Feature": "material_feature",
                "Color": "color",
                "Skin Type": "skin_type",
                "Style": "style",
            }
            for raw_key, canon_key in raw_map.items():
                if canon_key not in ATTR_KEYS:
                    continue
                val = details.get(raw_key)
                if not val:
                    continue
                vocab = ATTR_KEYS[canon_key]["vocab"]
                values = _split_multi(val)
                for v in values:
                    if v in vocab:
                        anode = self._attr_node_id(canon_key, v)
                        if not G.has_node(anode):
                            G.add_node(anode, type='attr', key=canon_key, value=v)
                        G.add_edge(pnode, anode, relation='has_attr', key_name=canon_key, etype='attr')
                        G.add_edge(anode, pnode, relation='attr_of', key_name=canon_key, etype='attr')

        return G

    # ---------- Item→Item similarity edges (text-only, e5) ----------
    def add_item_item_edges_from_text(
        self,
        items_df: pd.DataFrame,
        e5_emb: np.ndarray,
        k: int = 10,                 # reduced to Top-10
        sim_threshold: float = 0.5,   # increased threshold
        same_category_only: bool = True
    ):
        """
        Add undirected (two directed) similarity edges between products based on e5 cosine.
        Assumes e5_emb are L2-normalized.
        Optionally restrict to same category using items_df['category'] strings (no category nodes are built).
        """
        G = self.graph
        ids = items_df["id"].astype(str).tolist()
        cat = items_df["category"].astype(str).fillna("").tolist()

        index = faiss.IndexFlatIP(e5_emb.shape[1])
        index.add(e5_emb.astype(np.float32))

        D, I = index.search(e5_emb.astype(np.float32), min(k+1, len(ids)))
        for i, (scores, nbrs) in enumerate(zip(D, I)):
            pid_i = ids[i]
            cat_i = cat[i]
            for s, j in zip(scores, nbrs):
                if j == i:
                    continue
                if s < sim_threshold:
                    continue
                if same_category_only and cat[j] != cat_i:
                    continue
                pid_j = ids[int(j)]
                u = f"product_{pid_i}"
                v = f"product_{pid_j}"
                self.graph.add_edge(u, v, relation='similar_to', etype='sim', weight_raw=float(s))
                self.graph.add_edge(v, u, relation='similar_to', etype='sim', weight_raw=float(s))

    # ---------- Edge weight finalization ----------
    def finalize_edge_weights(self):
        """
        Assign a single 'weight' to every edge, using:
          - attr edges: LAMBDA_ATTR * beta_key / log(1 + deg(attr_node))
          - brand treated as attr with beta = 1.0
          - sim edges: LAMBDA_SIM * cosine
        """
        G = self.graph
        deg = dict(G.degree())

        for u, v, k, data in G.edges(keys=True, data=True):
            if data.get("etype") == "sim":
                s = float(data.get("weight_raw", 1.0))
                data["weight"] = float(LAMBDA_SIM * s)
            else:
                beta = 1.0
                if data.get("relation") in ("has_attr", "attr_of"):
                    key_name = data.get("key_name")
                    if key_name and key_name in ATTR_KEYS:
                        beta = float(ATTR_KEYS[key_name]["beta"])
                # choose the hub node (attr/brand)
                hub = v if G.nodes[v].get("type") in ("attr","brand") else u
                hdeg = max(1, deg.get(hub, 1))
                data["weight"] = float(LAMBDA_ATTR * beta / math.log(1.0 + hdeg))



In [None]:
class CLIPEncoder:
    def __init__(self, model_name: str = CLIP_MODEL_NAME, pretrained: str = CLIP_PRETRAINED, device: str = DEVICE):
        self.device = device
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(model_name, pretrained=pretrained)
        self.model = self.model.to(device)
        self.tokenizer = open_clip.get_tokenizer(model_name)
        self.model.eval()

    @torch.no_grad()
    def encode_text(self, texts: List[str], batch_size: int = 64) -> np.ndarray:
        vecs: List[np.ndarray] = []
        texts = list(texts)
        for i in range(0, len(texts), batch_size):
            toks = self.tokenizer(texts[i:i+batch_size]).to(self.device)
            feats = self.model.encode_text(toks)
            feats = feats / feats.norm(dim=-1, keepdim=True)
            vecs.append(feats.float().cpu().numpy())
        if vecs:
            return np.vstack(vecs)
        dummy = self.model.encode_text(self.tokenizer(["."]).to(self.device))
        return np.zeros((0, dummy.shape[-1]), dtype=np.float32)

    @torch.no_grad()
    def encode_images(self, images: List[Image.Image], batch_size: int = 32) -> np.ndarray:
        vecs: List[np.ndarray] = []
        images = list(images)
        for i in range(0, len(images), batch_size):
            ims = [self.preprocess(im).to(self.device) for im in images[i:i+batch_size]]
            if not ims: continue
            batch = torch.stack(ims)
            feats = self.model.encode_image(batch)
            feats = feats / feats.norm(dim=-1, keepdim=True)
            vecs.append(feats.float().cpu().numpy())
        if vecs:
            return np.vstack(vecs)
        dummy = self.model.encode_text(self.tokenizer(["."]).to(self.device))
        return np.zeros((0, dummy.shape[-1]), dtype=np.float32)

class FaissIndex:
    def __init__(self, dim: int):
        self.index = faiss.IndexFlatIP(dim)
        self.ids: List[str] = []

    def add(self, embeddings: np.ndarray, ids: List[str]):
        assert embeddings.shape[0] == len(ids)
        self.index.add(embeddings.astype(np.float32))
        self.ids.extend(ids)

    def search(self, queries: np.ndarray, topk: int) -> Tuple[np.ndarray, List[List[str]]]:
        D, I = self.index.search(queries.astype(np.float32), topk)
        id_lists = [[self.ids[int(j)] for j in row] for row in I]
        return D, id_lists

@dataclass
class Corpus:
    df: pd.DataFrame
    text_index: FaissIndex
    image_index: FaissIndex
    text_emb: np.ndarray
    img_emb: np.ndarray
    e5_emb: np.ndarray
    kg: OptimizedKnowledgeGraph
    enc: CLIPEncoder
    e5: SentenceTransformer

def build_catalog_corpus(jsonl_path: str, out_dir: str,
                         sim_k: int = 10, sim_threshold: float = 0.5) -> Corpus:
    ensure_dir(out_dir)
    raw_records = [rec for rec in read_jsonl(jsonl_path)]

    rows = []
    for i, rec in enumerate(raw_records):
        pid = rec.get("parent_asin") or rec.get("asin") or md5(rec.get("title", f"row{i}"))
        title = str(rec.get("title") or "")
        brand = (rec.get("details") or {}).get("Brand") or rec.get("store")
        category = rec.get("main_category") if isinstance(rec.get("main_category"), str) else None
        price = rec.get("price")
        average_rating = rec.get("average_rating");  average_rating = rec.get("avg_rating") if average_rating is None else average_rating
        text_blob = build_text_blob(rec)

        rows.append({
            "id": pid,
            "title": title,
            "brand": (brand if brand is None else _sanitize_value(brand)),
            "category": (None if category is None else _sanitize_value(category)),  # kept in DF for analytics/same_category_only
            "price": float(price) if isinstance(price, (int,float,str)) and str(price) not in ("", "None") else None,
            "average_rating": float(average_rating) if isinstance(average_rating, (int,float,str)) and str(average_rating) not in ("", "None") else None,
            "text": text_blob,
            "raw": rec,
        })

    df = pd.DataFrame(rows)
    print(f"Loaded catalog: {len(df)} products")

    # ---- Build KG with selected attributes (no numeric & no category nodes) ----
    kg = OptimizedKnowledgeGraph()
    kg.build_knowledge_graph(raw_records)

    # ---- Encoders & vectors ----
    enc = CLIPEncoder(CLIP_MODEL_NAME, CLIP_PRETRAINED, DEVICE)
    text_emb = enc.encode_text(df["text"].tolist())
    dim = text_emb.shape[1]

    # images (catalog only) — fast path
    img_emb = fast_encode_main_images(enc, df, batch_size=FAST_BATCH, num_workers=NUM_WORKERS, fp16=USE_FP16)

    # E5 (text-only) for item→item edges
    e5 = SentenceTransformer(E5_MODEL_NAME, device=DEVICE)
    e5_inputs = (df["title"].astype(str) + " \n " + df["text"].astype(str)).tolist()
    e5_emb = e5.encode(e5_inputs, batch_size=128, convert_to_numpy=True, normalize_embeddings=True)
    assert e5_emb.shape[0] == len(df)

    # ---- Build item→item edges, finalize weights ----
    kg.add_item_item_edges_from_text(
        items_df=df, e5_emb=e5_emb, k=sim_k, sim_threshold=sim_threshold, same_category_only=True
    )
    kg.finalize_edge_weights()

    # ---- Build FAISS indices (optional later) ----
    text_index = FaissIndex(text_emb.shape[1]); text_index.add(text_emb, df["id"].tolist())
    mask = (np.linalg.norm(img_emb, axis=1) > 0)
    image_index = FaissIndex(img_emb.shape[1]); image_index.add(img_emb[mask], df.loc[mask, "id"].tolist())

    return Corpus(df=df, text_index=text_index, image_index=image_index,
                  text_emb=text_emb, img_emb=img_emb, e5_emb=e5_emb,
                  kg=kg, enc=enc, e5=e5)

In [None]:
# Cell 5.fast — Fast MAIN image encoding (parallel download + DataLoader + AMP)

from concurrent.futures import ThreadPoolExecutor, as_completed
import requests, io, os
from torch.utils.data import Dataset, DataLoader

# ---------- Speed knobs ----------
FAST_MAX_WORKERS = min(32, os.cpu_count() or 16)  # parallel HTTP + CPU decode
IMG_TIMEOUT = 6.0
RETRIES = 2
FAST_BATCH = 128          # try 64–128 for ViT-L/14; 128–256 for ViT-B/32
NUM_WORKERS = max(2, (os.cpu_count() or 4)//2)  # DataLoader workers
USE_FP16 = True           # AMP on GPU
ONLY_MAIN = True          # keep True; set False to fallback to first available

# Shared HTTP session with pooling/retries
_session = requests.Session()
_adapter = requests.adapters.HTTPAdapter(
    pool_connections=FAST_MAX_WORKERS, pool_maxsize=FAST_MAX_WORKERS, max_retries=RETRIES
)
_session.mount("http://", _adapter); _session.mount("https://", _adapter)

def _pick_main_url(rec):
    imgs = normalize_images_field(rec.get("images"))
    if ONLY_MAIN:
        imgs = [im for im in imgs if str(im.get("variant","")).upper() == "MAIN"] or imgs
    for it in imgs:
        u = it.get("url")
        if u: return u
    return None

def predownload_main_images(df: pd.DataFrame) -> list:
    """Return index_to_path list aligned to df; parallel download missing to IMG_CACHE."""
    index_to_path = [None]*len(df)
    to_fetch = []
    for idx, rec in enumerate(df["raw"].tolist()):
        url = _pick_main_url(rec)
        if not url:
            continue
        cache_path = os.path.join(IMG_CACHE, md5(url) + ".jpg")
        index_to_path[idx] = cache_path
        if not os.path.exists(cache_path):
            to_fetch.append((url, cache_path))

    def _fetch(url, path):
        try:
            r = _session.get(url, timeout=IMG_TIMEOUT)
            r.raise_for_status()
            Image.open(io.BytesIO(r.content)).convert("RGB").save(path)
            return True
        except Exception:
            return False

    if to_fetch:
        with ThreadPoolExecutor(max_workers=FAST_MAX_WORKERS) as ex:
            futs = [ex.submit(_fetch, u, p) for (u, p) in to_fetch]
            for _ in tqdm(as_completed(futs), total=len(futs), desc="Downloading MAIN images (parallel)"):
                pass

    return index_to_path

class DiskImageDataset(Dataset):
    def __init__(self, index_to_path, preprocess):
        self.samples = [(i, p) for i, p in enumerate(index_to_path) if p and os.path.exists(p)]
        self.preprocess = preprocess

    def __len__(self): return len(self.samples)

    def __getitem__(self, idx):
        i, path = self.samples[idx]
        try:
            im = Image.open(path).convert("RGB")
            return self.preprocess(im), i
        except Exception:
            return None  # collate will drop

def _collate(batch):
    batch = [b for b in batch if b is not None]
    if not batch:
        return None, None
    imgs, idxs = zip(*batch)
    return torch.stack(imgs, 0), torch.tensor(idxs, dtype=torch.long)

@torch.no_grad()
def fast_encode_main_images(enc: CLIPEncoder, df: pd.DataFrame,
                            batch_size=FAST_BATCH, num_workers=NUM_WORKERS, fp16=USE_FP16) -> np.ndarray:
    # dim via text projection (fast) or tiny encode fallback
    dim = enc.model.text_projection.shape[1] if hasattr(enc.model, "text_projection") else enc.encode_text(["."], 1).shape[1]
    img_vecs = np.zeros((len(df), dim), dtype=np.float32)

    index_to_path = predownload_main_images(df)
    ds = DiskImageDataset(index_to_path, enc.preprocess)
    if len(ds) == 0:
        return img_vecs

    loader = DataLoader(
        ds, batch_size=batch_size, shuffle=False, num_workers=num_workers,
        pin_memory=True, persistent_workers=(num_workers > 0),
        prefetch_factor=(2 if num_workers > 0 else None), collate_fn=_collate
    )

    if DEVICE == "cuda":
        torch.backends.cudnn.benchmark = True

    for batch in tqdm(loader, total=len(loader), desc="Encoding MAIN images (fast)"):
        if batch[0] is None:
            continue
        imgs, idxs = batch
        imgs = imgs.to(enc.device, non_blocking=True)
        if fp16 and enc.device.startswith("cuda"):
            with torch.cuda.amp.autocast(dtype=torch.float16):
                feats = enc.model.encode_image(imgs)
        else:
            feats = enc.model.encode_image(imgs)
        feats = F.normalize(feats, p=2, dim=-1).float().cpu().numpy()
        img_vecs[idxs.numpy()] = feats

    return img_vecs


In [None]:
# Cell 5 — Build corpus (from catalog), expand queries (from queries file), split

corpus = build_catalog_corpus(CATALOG_JSONL, OUT_DIR, sim_k=5, sim_threshold=0.5)
items_df = corpus.df.copy()

def expand_queries_from_records(jsonl_path: str, items_df: pd.DataFrame) -> pd.DataFrame:
    """
    Reads queries file (intents + constraints) and builds q-rows for products
    that exist in the catalog (items_df). Evaluates retrieval over the full catalog.
    """
    pid_set = set(items_df["id"].tolist())
    rows = []
    for rec in read_jsonl(jsonl_path):
        pid = rec.get("parent_asin") or rec.get("asin")
        if not pid or pid not in pid_set:
            continue
        gen = rec.get("generated") or {}
        intents = gen.get("intents") or []
        constraints_list = gen.get("constraints") or []
        C0 = constraints_list[0] if (isinstance(constraints_list, list) and len(constraints_list)>0) else {}
        for j, intent in enumerate(intents):
            rows.append({"qid": f"{pid}::q{j}", "pid": pid, "intent_text": str(intent), "constraints": C0})
    return pd.DataFrame(rows)

queries_df = expand_queries_from_records(QUERIES_JSONL, items_df)
print("Total queries (intersecting catalog):", len(queries_df))




Loaded catalog: 39038 products


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


open_clip_model.safetensors:   0%|          | 0.00/1.71G [00:00<?, ?B/s]



Downloading MAIN images (parallel):   0%|          | 0/39038 [00:00<?, ?it/s]

Encoding MAIN images (fast):   0%|          | 0/305 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(dtype=torch.float16):


modules.json:   0%|          | 0.00/387 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/57.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/650 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/314 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/200 [00:00<?, ?B/s]

Total queries (intersecting catalog): 3655


In [None]:
# Cell 6 — KG stats: counts by node type and edge relation (plus triplets and degree stats)

from collections import Counter, defaultdict
import numpy as np

assert 'corpus' in globals() and hasattr(corpus, 'kg'), "Run Cell 5 to build `corpus` first."
G = corpus.kg.graph

print(f"Graph: {type(G).__name__} | directed={G.is_directed()} | multigraph={G.is_multigraph()}")
print(f"Total nodes: {G.number_of_nodes():,} | Total edges: {G.number_of_edges():,}\n")

# --- Node types ---
node_type_counter = Counter(data.get('type', '(missing)') for _, data in G.nodes(data=True))
print("Node types:")
for t, c in sorted(node_type_counter.items(), key=lambda x: (-x[1], x[0])):
    print(f"  - {t}: {c:,}")
print()

# --- Edge relations ---
edge_rel_counter = Counter()
edge_rel_triplet_counter = Counter()  # (src_type, relation, dst_type)

for u, v, k, data in G.edges(keys=True, data=True):
    rel = data.get('relation', '(missing)')
    edge_rel_counter[rel] += 1
    su = G.nodes[u].get('type', '(missing)')
    sv = G.nodes[v].get('type', '(missing)')
    edge_rel_triplet_counter[(su, rel, sv)] += 1

print("Edge relations:")
for rel, c in sorted(edge_rel_counter.items(), key=lambda x: (-x[1], x[0])):
    print(f"  - {rel}: {c:,}")
print()

print("Edge type triplets (src_type --relation--> dst_type):")
for (su, rel, sv), c in sorted(edge_rel_triplet_counter.items(), key=lambda x: (-x[1], x[0])):
    print(f"  - {su} --{rel}--> {sv}: {c:,}")

# --- Degree stats per node type (quick sanity) ---
def _degree_stats_by_type():
    per_type_deg = defaultdict(list)
    for n, data in G.nodes(data=True):
        t = data.get('type', '(missing)')
        per_type_deg[t].append(G.degree(n))
    print("\nDegree stats by node type (min / median / mean / max):")
    for t, degs in per_type_deg.items():
        arr = np.asarray(degs, dtype=np.int64)
        print(f"  - {t:>18}: {arr.min():,} / {int(np.median(arr)):,} / {arr.mean():.2f} / {arr.max():,}")

_degree_stats_by_type()


Graph: MultiDiGraph | directed=True | multigraph=True
Total nodes: 52,431 | Total edges: 560,102

Node types:
  - product: 39,038
  - brand: 13,285
  - attr: 108

Edge relations:
  - similar_to: 388,458
  - attr_of: 51,400
  - has_attr: 51,400
  - belongs_to_brand: 34,422
  - contains_product: 34,422

Edge type triplets (src_type --relation--> dst_type):
  - product --similar_to--> product: 388,458
  - attr --attr_of--> product: 51,400
  - product --has_attr--> attr: 51,400
  - brand --contains_product--> product: 34,422
  - product --belongs_to_brand--> brand: 34,422

Degree stats by node type (min / median / mean / max):
  -            product: 0 / 22 / 24.30 / 206
  -              brand: 2 / 2 / 5.18 / 422
  -               attr: 16 / 308 / 951.85 / 14,420


In [None]:
# Cell 6 — Metrics (evaluated against the FULL catalog candidate set)

def accuracy_at_k(ranks: List[int], K: int) -> float:
    n = max(1, len(ranks))
    return sum(1 for r in ranks if r <= K) / n

def ndcg_at_k_single(ranks: List[int], K: int) -> float:
    val = 0.0
    for r in ranks:
        if r <= K:
            val += 1.0 / math.log2(r + 1.0)
    n = max(1, len(ranks))
    return val / n

def ranks_from_scores(pid_lists: List[List[str]], offsets: List[int], flat_scores: np.ndarray, true_pids: List[str]) -> List[int]:
    ranks = []
    cur = 0
    for qi, pids in enumerate(pid_lists):
        K = len(pids)
        if K == 0:
            ranks.append(10**9); continue
        sl = flat_scores[cur:cur+K]
        order = np.argsort(-sl)
        ranked = [pids[i] for i in order]
        tpid = true_pids[qi]
        r = next((i+1 for i,p in enumerate(ranked) if p==tpid), len(ranked)+1)
        ranks.append(r)
        cur += K
    return ranks

def recall_at_k_presence(pid_lists, truths, K: int):
    """
    Presence-based Recall@K on candidate lists (over the full catalog),
    plus basic coverage and rank stats for hits.
    """
    hits = 0
    ranks = []
    non_empty = 0
    for pids, tpid in zip(pid_lists, truths):
        if pids: non_empty += 1
        try:
            r = pids.index(tpid) + 1
            if r <= K:
                hits += 1
                ranks.append(r)
        except ValueError:
            pass
    n = len(truths)
    rec = hits / max(1, n)
    med_rank = float(np.median(ranks)) if ranks else float('inf')
    mean_rank = float(np.mean(ranks)) if ranks else float('inf')
    coverage = non_empty / max(1, n)
    return rec, hits, med_rank, mean_rank, coverage


In [None]:
# Cell 4 — Unified SINGLE-PPR (brand treated as normal attr) + progress bar + VAL report

from types import SimpleNamespace
from tqdm.auto import tqdm
import numpy as np
import pandas as pd
import networkx as nx
import math
import re

# ---------- Tunables ----------
PPR_ALPHA = 0.80
PPR_HOPS  = 2        # enable product→{brand|attr|similar}→product

# Keep DETAIL_KEYS consistent with earlier cells
DETAIL_KEYS = ["item_form","material","hair_type","age_range","material_feature","color","skin_type","style"]

# ---------- Parsers & normalizers ----------
def _parse_float_safe(x):
    if x is None: return None
    if isinstance(x, (int, float, np.integer, np.floating)):
        try:
            xf = float(x);  return None if np.isnan(xf) else xf
        except Exception: return None
    try:
        s = str(x).strip()
    except Exception:
        return None
    if s == "" or s.lower() in {"none","null","nan","n/a"}: return None
    try:
        return float(s)
    except Exception:
        return None

def _norm_str(v):
    if v is None: return None
    s = str(v).strip()
    return s if s else None

def _as_list(x):
    if x is None:
        return []
    if isinstance(x, (list, tuple, set)):
        return [str(v).strip().lower() for v in x if str(v).strip()]
    s = str(x).strip()
    if not s:
        return []
    parts = re.split(r"[;,/|]", s)
    out = [p.strip().lower() for p in parts if p.strip()]
    return out if out else [s.lower()]

def make_constraints(C_raw, items_df=None) -> SimpleNamespace:
    """
    Normalize raw constraints into:
      brand, price_min, price_max, rating_min (parsed),
      details: { key -> [values] } (vocab-gated).
    """
    if isinstance(C_raw, SimpleNamespace):
        d = dict(C_raw.__dict__)
    elif isinstance(C_raw, dict):
        d = C_raw
    else:
        d = {}

    brand      = _norm_str(d.get("brand"))
    category   = _norm_str(d.get("category"))
    price_min  = _parse_float_safe(d.get("price_min"))
    price_max  = _parse_float_safe(d.get("price_max"))
    rating_min = _parse_float_safe(d.get("rating_min"))
    if price_min is not None and price_max is not None and price_min > price_max:
        price_min, price_max = price_max, price_min

    details = {}
    for k in DETAIL_KEYS:
        vals = _as_list(d.get(k))
        if not vals:
            continue
        if k in ATTR_KEYS:
            vocab = ATTR_KEYS[k]["vocab"]
            vals = [vv for vv in vals if vv in vocab]
        if vals:
            details[k] = vals

    return SimpleNamespace(
        brand=brand, category=category,
        price_min=price_min, price_max=price_max, rating_min=rating_min,
        details=details
    )

# ---------- Graph utils ----------
def _expand_k_hop(G: nx.MultiDiGraph, starts: list[str], hops: int) -> set:
    """Collect nodes within k hops (directed: predecessors + successors)."""
    frontier = set(starts)
    curr = set(starts)
    for _ in range(max(0, int(hops))):
        nxt = set()
        for n in curr:
            if n not in G:
                continue
            if G.is_directed():
                nxt.update(G.predecessors(n)); nxt.update(G.successors(n))
            else:
                nxt.update(G.neighbors(n))
        curr = nxt - frontier
        frontier |= nxt
    return frontier

def _catalog_sig_from_ids(ids: list[str]) -> str:
    return md5("\x1e".join(ids))

# ---------- HARD price/rating filters BEFORE PPR ----------
def _allowed_product_nodes(corpus, C: SimpleNamespace) -> set:
    """
    Returns a set of product node ids (e.g., 'product_B000...') that satisfy
    price_min/price_max/rating_min. Missing price/rating FAILS the constraint
    if that constraint is provided.
    """
    df = corpus.df
    ids = df["id"].astype(str).tolist()
    prices  = pd.to_numeric(df["price"], errors="coerce").to_numpy()           # NaN for missing
    ratings = pd.to_numeric(df["average_rating"], errors="coerce").to_numpy()  # NaN for missing

    pmn  = getattr(C, "price_min", None)
    pmx  = getattr(C, "price_max", None)
    rmin = getattr(C, "rating_min", None)

    mask = np.ones(len(ids), dtype=bool)
    if pmn is not None:
        mask &= (~np.isnan(prices)) & (prices >= float(pmn))
    if pmx is not None:
        mask &= (~np.isnan(prices)) & (prices <= float(pmx))
    if rmin is not None:
        mask &= (~np.isnan(ratings)) & (ratings >= float(rmin))

    allowed = {f"product_{pid}" for pid, keep in zip(ids, mask) if keep}
    return allowed

# ---------- Treat BRAND like any other attribute ----------
ATTR_KEYS.setdefault("brand", {"vocab": None, "beta": 1.0})

def _finalize_edge_weights_brand_equal(self):
    """
    Assign 'weight' to every edge:
      - sim edges:  LAMBDA_SIM * cosine
      - attr edges (incl BRAND): LAMBDA_ATTR * beta_key / log(1 + deg(hub))
    BRAND edges use key_name='brand' so they follow the same rule as any other attribute.
    """
    G = self.graph
    deg = dict(G.degree())
    for u, v, k, data in G.edges(keys=True, data=True):
        if data.get("etype") == "sim":
            s = float(data.get("weight_raw", 1.0))
            data["weight"] = float(LAMBDA_SIM * s)
            continue

        relation = data.get("relation")
        key_name = data.get("key_name")
        if relation in ("belongs_to_brand", "contains_product"):
            key_name = "brand"

        beta = 1.0
        if key_name in ATTR_KEYS:
            beta = float(ATTR_KEYS[key_name].get("beta", 1.0))

        hub = v if G.nodes[v].get("type") in ("attr", "brand") else u
        hdeg = max(1, deg.get(hub, 1))
        data["weight"] = float(LAMBDA_ATTR * beta / math.log(1.0 + hdeg))

# Monkey-patch onto the existing class
OptimizedKnowledgeGraph.finalize_edge_weights = _finalize_edge_weights_brand_equal

# If a corpus already exists, re-finalize weights immediately
if 'corpus' in globals():
    corpus.kg.finalize_edge_weights()

# ---------- Unified personalization (brand included like any other key) ----------
def _build_unified_personalization(G: nx.MultiDiGraph, C: SimpleNamespace) -> dict[str, float]:
    pers: dict[str, float] = {}

    # brand seed (weighted by ATTR_KEYS['brand']['beta'])
    if getattr(C, "brand", None):
        bnode = f"brand_{C.brand}"
        if bnode in G:
            w = float(ATTR_KEYS["brand"].get("beta", 1.0))
            pers[bnode] = pers.get(bnode, 0.0) + w

    # detail seeds
    if getattr(C, "details", None):
        for key, values in (C.details or {}).items():
            if not values:
                continue
            beta = float(ATTR_KEYS.get(key, {}).get("beta", 1.0))
            for v in values:
                anode = f"attr|{key}|{v}"
                if anode in G:
                    pers[anode] = pers.get(anode, 0.0) + beta
    return pers

# ---------- SINGLE PPR (unified) ----------
_PPR_UNIFIED_CACHE: dict = {}

def compute_ppr_prior_vector_unified(
    corpus,
    C_raw,
    alpha: float = PPR_ALPHA,
    hops: int = PPR_HOPS,
    normalize: bool = True,
) -> np.ndarray:
    """
    ONE PageRank with personalization on brand + attributes together.
    BRAND is treated exactly like any other attribute (beta via ATTR_KEYS['brand']).
    Hard price/rating filters are applied by pruning products from the subgraph before PR.
    """
    C = C_raw if isinstance(C_raw, SimpleNamespace) else make_constraints(C_raw)
    G = corpus.kg.graph
    ids_all = corpus.df["id"].astype(str).tolist()

    # HARD filters
    allowed_products = _allowed_product_nodes(corpus, C)
    if (getattr(C, "price_min", None) is not None or
        getattr(C, "price_max", None) is not None or
        getattr(C, "rating_min", None) is not None):
        if not allowed_products:
            return np.zeros((len(ids_all),), dtype=np.float32)

    # Unified seeds
    pers = _build_unified_personalization(G, C)
    if not pers:
        return np.zeros((len(ids_all),), dtype=np.float32)

    # Cache key
    detail_key = tuple((k, tuple(sorted(vs))) for k, vs in sorted((getattr(C, "details", {}) or {}).items()))
    key = (
        _catalog_sig_from_ids(ids_all),
        C.brand, C.category,
        _parse_float_safe(C.price_min), _parse_float_safe(C.price_max), _parse_float_safe(C.rating_min),
        detail_key,
        float(alpha), int(hops), float(LAMBDA_ATTR), float(LAMBDA_SIM),
        float(ATTR_KEYS["brand"].get("beta", 1.0))
    )
    if key in _PPR_UNIFIED_CACHE:
        return _PPR_UNIFIED_CACHE[key]

    # Build subgraph H from k-hop union (+ one-hop padding)
    seed_nodes = list(pers.keys())
    keep = _expand_k_hop(G, seed_nodes, hops=hops)
    extra = set()
    for n in list(keep):
        if n not in G:
            continue
        if G.is_directed():
            nbrs = set(G.predecessors(n)) | set(G.successors(n))
        else:
            nbrs = set(G.neighbors(n))
        extra |= nbrs
    keep |= extra
    H = G.subgraph(keep).copy()

    # Prune disallowed products
    if allowed_products:
        drop = [n for n in list(H.nodes()) if (G.nodes[n].get("type") == "product" and n not in allowed_products)]
        if drop:
            H.remove_nodes_from(drop)

    if H.number_of_nodes() == 0:
        v = np.zeros((len(ids_all),), dtype=np.float32)
        _PPR_UNIFIED_CACHE[key] = v
        return v

    # Normalize personalization on H
    s = sum(pers.get(n, 0.0) for n in pers if n in H)
    persH = {n: (pers[n] / s) for n in pers if n in H and pers[n] > 0.0}
    if not persH:
        v = np.zeros((len(ids_all),), dtype=np.float32)
        _PPR_UNIFIED_CACHE[key] = v
        return v

    pr = nx.pagerank(H, alpha=alpha, personalization=persH, max_iter=100, tol=1e-6, weight="weight")

    v = np.array([float(pr.get(f"product_{pid}", 0.0)) for pid in ids_all], dtype=np.float32)
    if normalize and v.size:
        vmin, vmax = float(np.min(v)), float(np.max(v))
        if vmax > vmin:
            v = (v - vmin) / (vmax - vmin)
        else:
            v = np.zeros_like(v, dtype=np.float32)

    v = v.astype(np.float32)
    _PPR_UNIFIED_CACHE[key] = v
    return v

# ---------- Candidate generation with tqdm ----------
def ppr_topk_unified(constraints_series: pd.Series, topk: int, corpus, desc: str = "PPR (unified)"):
    ids = corpus.df["id"].astype(str).tolist()
    out = []
    for C_raw in tqdm(constraints_series.tolist(), total=len(constraints_series), desc=desc):
        v = compute_ppr_prior_vector_unified(corpus, C_raw)
        order = np.argsort(-v)[:topk]
        out.append([ids[i] for i in order])
    return out

print("[OK] Brand treated as normal attribute; unified single-PPR is ready (with tqdm).")




[OK] Brand treated as normal attribute; unified single-PPR is ready (with tqdm).


In [None]:
# Cell 7 — 60/20/20 split + unified PPR evaluation on VAL and TEST (full-catalog scoring, with tqdm)

from tqdm.auto import tqdm
import numpy as np
import pandas as pd
import random

# -----------------------------
# Fallback metric helpers (only if missing)
# -----------------------------
try:
    accuracy_at_k
except NameError:
    def accuracy_at_k(ranks, K: int) -> float:
        n = max(1, len(ranks))
        return sum(1 for r in ranks if r <= K) / n

try:
    ndcg_at_k_single
except NameError:
    import math
    def ndcg_at_k_single(ranks, K: int) -> float:
        val = 0.0
        for r in ranks:
            if r <= K:
                val += 1.0 / math.log2(r + 1.0)
        n = max(1, len(ranks))
        return val / n

try:
    recall_at_k_presence
except NameError:
    def recall_at_k_presence(pid_lists, truths, K: int):
        hits = 0
        ranks = []
        non_empty = 0
        for pids, tpid in zip(pid_lists, truths):
            if pids: non_empty += 1
            try:
                r = pids.index(tpid) + 1
                if r <= K:
                    hits += 1
                    ranks.append(r)
            except ValueError:
                pass
        n = len(truths)
        rec = hits / max(1, n)
        med_rank = float(np.median(ranks)) if ranks else float('inf')
        mean_rank = float(np.mean(ranks)) if ranks else float('inf')
        coverage = non_empty / max(1, n)
        return rec, hits, med_rank, mean_rank, coverage

# -----------------------------
# Train/Val/Test split at 60/20/20 by product IDs
# -----------------------------
assert 'queries_df' in globals(), "queries_df not found. Build it from your queries JSONL first."

SEED = 42 if 'SEED' not in globals() else SEED

def split_queries_60_20_20(qdf: pd.DataFrame, seed: int = SEED):
    uniq_pids = list(sorted(set(qdf["pid"].astype(str).tolist())))
    rnd = random.Random(seed); rnd.shuffle(uniq_pids)
    n = len(uniq_pids)
    n_train = int(0.6 * n)
    n_val   = int(0.2 * n)
    train_p = set(uniq_pids[:n_train])
    val_p   = set(uniq_pids[n_train:n_train+n_val])
    test_p  = set(uniq_pids[n_train+n_val:])
    sel = lambda s: qdf[qdf["pid"].isin(s)].reset_index(drop=True)
    return sel(train_p), sel(val_p), sel(test_p)

train_q, val_q, test_q = split_queries_60_20_20(queries_df, SEED)
print(f"[Split] Products — train:{len(set(train_q.pid))}  val:{len(set(val_q.pid))}  test:{len(set(test_q.pid))}")
print(f"[Split] Queries  — train:{len(train_q)}  val:{len(val_q)}  test:{len(test_q)}")

# -----------------------------
# Unified PPR evaluator (VAL/TEST)
# -----------------------------
TOPK_PPR = 200
KS = (10, 50, 200)

def evaluate_unified_ppr_on_split(corpus, qdf: pd.DataFrame, topk: int = TOPK_PPR, ks=KS, tag: str = "SPLIT"):
    ids = corpus.df["id"].astype(str).tolist()
    id2row = {pid: i for i, pid in enumerate(ids)}
    constraints_list = qdf["constraints"].tolist()
    truths = qdf["pid"].astype(str).tolist()

    ranks = []
    pid_topk = []

    for i in tqdm(range(len(truths)), desc=f"Scoring & ranking ({tag})"):
        C_raw = constraints_list[i]
        tpid = truths[i]

        # Single PPR over full catalog (vector aligns to ids order)
        v = compute_ppr_prior_vector_unified(corpus, C_raw)
        order = np.argsort(-v)
        pid_topk.append([ids[j] for j in order[:topk]])

        idx_true = id2row.get(tpid, None)
        if idx_true is None:
            ranks.append(10**9)
        else:
            pos = np.where(order == idx_true)[0]
            ranks.append(int(pos[0]) + 1 if len(pos) else 10**9)

    # Presence-based Recall@K on candidate lists
    print(f"\n[{tag}] Candidate presence (top-{topk}) over {len(truths)} queries")
    for K in ks:
        rec, hits, med_r, mean_r, coverage = recall_at_k_presence(pid_topk, truths, K)
        print(f"  Recall@{K}: {rec:.4f}  hits={hits}  coverage={coverage:.3f}  "
              f"median_rank={med_r:.1f}  mean_rank={mean_r:.1f}")

    # Rank-based metrics from full order
    print(f"\n[{tag}] Rank-based metrics from full order")
    for K in ks:
        acc = accuracy_at_k(ranks, K)
        ndcg = ndcg_at_k_single(ranks, K)
        print(f"  Acc@{K}: {acc:.4f}  NDCG@{K}: {ndcg:.4f}")

    return {"ranks": ranks, "pid_topk": pid_topk}



[Split] Products — train:438  val:146  test:147
[Split] Queries  — train:2190  val:730  test:735


In [None]:
# === Single cell (FIX #4 + Subset: 50 VAL / 50 TEST; progress bars + HARD FILTERS; incremental reporting) ===
# Evaluates OpenAI Retrieval (File Search over Vector Store) on small subsets of VAL/TEST.
# Metrics: Accuracy@K, Recall@K, NDCG@K for K in {1,3,5,10,50}
# Versions:
#   V1 = intent only (no text constraints; no hard filter)
#   V2 = intent + constraints text (and HARD FILTER on price_min/price_max/rating_min)
#
# REQS:
#   pip install --upgrade openai tqdm
# PRESETS from earlier cells:
#   - items_df with columns: ["id","title","text","price","average_rating"]
#   - val_q, test_q with ["qid","pid","intent_text","constraints"]

import os, json, re, math, time, sys
from typing import List, Dict, Any, Optional, Set
from tqdm.auto import tqdm
import numpy as np
import pandas as pd

# ---------------------------
# OpenAI client
# ---------------------------
from openai import OpenAI
OPENAI_API_KEY = "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"
client = OpenAI(api_key=OPENAI_API_KEY)

# ---------------------------
# Paths & constants
# ---------------------------
OUT_DIR = "/content/mm_index"
os.makedirs(OUT_DIR, exist_ok=True)
CATALOG_TXT = os.path.join(OUT_DIR, "catalog_for_openai_retrieval.txt")
VS_META_PATH = os.path.join(OUT_DIR, "openai_vector_store_meta.json")

MODEL = "gpt-4o-mini"
K_LIST = [1, 3, 5, 10, 50]
TOPK_RETRIEVE = max(K_LIST)
EXTRA_FACTOR = 3  # over-retrieve before hard-filtering in V2 (kept small for speed)
SEED = 42
VAL_N = 50
TEST_N = 50

# ---------------------------
# Metrics (binary relevance, 1 GT per query)
# ---------------------------
def compute_ranks(pred_lists: List[List[str]], truths: List[str]) -> List[int]:
    ranks = []
    for preds, true_pid in zip(pred_lists, truths):
        ranks.append(preds.index(true_pid) + 1 if true_pid in preds else 10**9)
    return ranks

def accuracy_at_k(ranks: List[int], K: int) -> float:
    n = max(1, len(ranks))
    return sum(1 for r in ranks if r <= K) / n

def recall_at_k(ranks: List[int], K: int) -> float:
    return accuracy_at_k(ranks, K)  # single relevant item

def ndcg_at_k(ranks: List[int], K: int) -> float:
    dcg = sum(1.0 / math.log2(r + 1.0) for r in ranks if r <= K)
    return dcg / max(1, len(ranks))  # IDCG=1 per query

def pretty_print_scores(label: str, ranks: List[int]):
    print(f"\n[{label}] Metrics (n={len(ranks)})")
    for K in K_LIST:
        print(f"  @ {K:>2d}  Acc: {accuracy_at_k(ranks,K):.4f}   "
              f"Recall: {recall_at_k(ranks,K):.4f}   NDCG: {ndcg_at_k(ranks,K):.4f}")
    sys.stdout.flush()

# ---------------------------
# Catalog writer used by Retrieval (one big text file; includes PID tags)
# ---------------------------
def write_catalog_txt(corpus_df: pd.DataFrame, path: str):
    with open(path, "w", encoding="utf-8") as f, tqdm(total=len(corpus_df), desc="Writing catalog.txt", unit="item") as pbar:
        for _, row in corpus_df.iterrows():
            pid = str(row["id"])
            title = str(row.get("title", "") or "")
            txt = str(row.get("text", "") or "")
            f.write(f"PID: {pid}\nTITLE: {title}\nCONTENT:\n{txt}\n---\n")
            pbar.update(1)

# ---------------------------
# Vector Store (create once; upload the catalog)
# ---------------------------
def get_or_create_vector_store(catalog_txt_path: str) -> str:
    if os.path.exists(VS_META_PATH):
        try:
            meta = json.load(open(VS_META_PATH, "r"))
            if meta.get("vector_store_id"):
                tqdm.write(f"[Reusing] Vector Store: {meta['vector_store_id']}")
                return meta["vector_store_id"]
        except Exception:
            pass

    if not os.path.exists(catalog_txt_path):
        write_catalog_txt(items_df[["id","title","text"]], catalog_txt_path)

    tqdm.write("[Create] Creating vector store…")
    vs = client.vector_stores.create(name=f"beauty-catalog-{int(time.time())}")

    tqdm.write("[Upload] Uploading catalog (OpenAI will index)…")
    with open(catalog_txt_path, "rb") as fh:
        client.vector_stores.file_batches.upload_and_poll(vector_store_id=vs.id, files=[fh])
    tqdm.write("[Upload] Done.")

    json.dump({"vector_store_id": vs.id, "catalog_file": catalog_txt_path}, open(VS_META_PATH, "w"))
    return vs.id

VECTOR_STORE_ID = get_or_create_vector_store(CATALOG_TXT)
tqdm.write(f"[OK] Vector Store ready: {VECTOR_STORE_ID}")

# ---------------------------
# Query builders
# ---------------------------
def serialize_constraints(C: Dict[str, Any]) -> str:
    if not isinstance(C, dict):
        return ""
    flat = []
    for k in ["brand","price_min","price_max","rating_min"]:
        if k in C and C[k] not in (None, "", "None"):
            flat.append(f"{k}={C[k]}")
    for k, v in C.items():
        if k in {"brand","price_min","price_max","rating_min"}: continue
        if v in (None, "", [], {}): continue
        flat.append(f"{k}={v}")
    return "; ".join(flat)

def mk_query_v1(intent_text: str, C: Optional[Dict[str, Any]] = None) -> str:
    return intent_text  # ignores constraints

def mk_query_v2(intent_text: str, C: Optional[Dict[str, Any]] = None) -> str:
    sc = serialize_constraints(C or {})
    return f"{intent_text}\nConstraints: {sc}" if sc else intent_text

# ---------------------------
# HARD FILTER (price_min/max, rating_min) for V2
# Missing numeric fields FAIL the constraint if that bound is provided.
# ---------------------------
_ALLOWED_CACHE: Dict[tuple, List[str]] = {}

def allowed_pid_set_for_constraints(C: Dict[str, Any]) -> List[str]:
    pmn, pmx, rmin = C.get("price_min"), C.get("price_max"), C.get("rating_min")
    key = (pmn, pmx, rmin)
    if key in _ALLOWED_CACHE:
        return _ALLOWED_CACHE[key]

    df = items_df.copy()
    prices  = pd.to_numeric(df["price"], errors="coerce")
    ratings = pd.to_numeric(df["average_rating"], errors="coerce")

    def _f(v):
        try: return float(v)
        except: return None
    pmn, pmx, rmin = _f(pmn), _f(pmx), _f(rmin)

    mask = pd.Series(True, index=df.index)
    if pmn is not None: mask &= prices.notna() & (prices >= pmn)
    if pmx is not None: mask &= prices.notna() & (prices <= pmx)
    if rmin is not None: mask &= ratings.notna() & (ratings >= rmin)

    allowed = df.loc[mask, "id"].astype(str).tolist()
    _ALLOWED_CACHE[key] = allowed
    return allowed

# ---------------------------
# Retrieval via Responses + File Search
# IMPORTANT: Put vector_store_ids directly on the tool AND set tool_choice='required'
# so the model must call file_search.
# ---------------------------
JSON_INSTRUCTIONS = (
    "You are a strict retrieval assistant. Use the FILE SEARCH tool to fetch relevant passages. "
    'EXTRACT product IDs ONLY by matching "PID: <ID>" present in retrieved text. '
    "Do NOT guess or fabricate IDs. If unsure, return an empty list. "
    'Return JSON only: {"pids": ["..."]}.'
)

def _parse_response_text(resp) -> str:
    txt = getattr(resp, "output_text", None)
    if txt: return txt
    try:
        parts = []
        for item in getattr(resp, "output", []) or []:
            for c in getattr(item, "content", []) or []:
                if getattr(c, "type", "") == "output_text" and getattr(c, "text", None):
                    parts.append(c.text)
        return "\n".join(parts)
    except Exception:
        return ""

def retrieve_pids(query: str, topk_final: int, extraction_cap: int, allowed_set: Optional[Set[str]]) -> List[str]:
    if allowed_set is not None and len(allowed_set) == 0:
        return []

    limit = extraction_cap if allowed_set else topk_final

    try:
        resp = client.responses.create(
            model=MODEL,
            input=[
                {"role": "system", "content": JSON_INSTRUCTIONS + f" Limit to at most {limit} unique IDs."},
                {"role": "user", "content": query},
            ],
            tools=[{
                "type": "file_search",
                "vector_store_ids": [VECTOR_STORE_ID],
            }],
            tool_choice="required",  # valid values: 'none', 'auto', 'required'
            max_output_tokens=800,
        )
    except Exception as e:
        # If the request fails for a query, return empty list so the loop can continue
        tqdm.write(f"[WARN] Retrieval error: {e}")
        return []

    txt = _parse_response_text(resp)

    # Try strict JSON parse first
    preds: List[str] = []
    try:
        data = json.loads(txt)
        if isinstance(data, dict) and isinstance(data.get("pids"), list):
            preds = [str(x) for x in data["pids"]][:limit]
    except Exception:
        # Fallback: regex scan as a last resort
        preds = re.findall(r"PID:\s*([A-Za-z0-9_-]+)", txt)[:limit]

    # Deduplicate (preserve order)
    seen, uniq = set(), []
    for p in preds:
        if p not in seen:
            seen.add(p); uniq.append(p)

    # HARD FILTER (V2 only)
    if allowed_set is not None:
        uniq = [p for p in uniq if p in allowed_set]

    return uniq[:topk_final]

# ---------------------------
# Batch evaluation with per-query progress bar
# ---------------------------
def evaluate_split_with_retrieval(split_df: pd.DataFrame, make_query_fn, label: str, use_constraints: bool) -> Dict[str, Any]:
    truths = split_df["pid"].astype(str).tolist()
    pred_lists: List[List[str]] = []

    with tqdm(total=len(split_df), desc=f"Retrieving [{label}]", unit="q") as pbar:
        for _, row in split_df.iterrows():
            intent = str(row["intent_text"])
            C = row.get("constraints", {}) or {}

            q = make_query_fn(intent, C)  # mk_query_v1 ignores C; mk_query_v2 uses C
            allowed_ids = set(allowed_pid_set_for_constraints(C)) if use_constraints else None

            preds = retrieve_pids(
                query=q,
                topk_final=TOPK_RETRIEVE,
                extraction_cap=TOPK_RETRIEVE * EXTRA_FACTOR,
                allowed_set=allowed_ids
            )
            pred_lists.append(preds)
            pbar.update(1)

    ranks = compute_ranks(pred_lists, truths)
    return {"ranks": ranks, "preds": pred_lists, "truths": truths}

# ---------------------------
# Prep catalog (if needed)
# ---------------------------
if not os.path.exists(CATALOG_TXT):
    write_catalog_txt(items_df[["id","title","text"]], CATALOG_TXT)

# ---------------------------
# Build 50/50 subsets (deterministic)
# ---------------------------
assert "val_q" in globals() and "test_q" in globals(), "val_q and test_q must exist."
val_50  = (val_q.sample(n=min(VAL_N, len(val_q)), random_state=SEED).reset_index(drop=True))
test_50 = (test_q.sample(n=min(TEST_N, len(test_q)), random_state=SEED).reset_index(drop=True))
print(f"[Subset sizes] VAL: {len(val_50)}  TEST: {len(test_50)}")
sys.stdout.flush()

# ---------------------------
# Run evaluations (on subsets only) with CONSECUTIVE REPORTING
# ---------------------------
# V1: intent only (no hard filters)
res_v1_val50 = evaluate_split_with_retrieval(val_50,  make_query_fn=mk_query_v1, label="VAL50 · V1 (intent)", use_constraints=False)
pretty_print_scores("VAL50 — V1 (intent only)", res_v1_val50["ranks"])  # report immediately

res_v1_tst50 = evaluate_split_with_retrieval(test_50, make_query_fn=mk_query_v1, label="TEST50 · V1 (intent)", use_constraints=False)
pretty_print_scores("TEST50 — V1 (intent only)", res_v1_tst50["ranks"])  # report immediately

# V2: intent + constraints (with hard filters)
res_v2_val50 = evaluate_split_with_retrieval(val_50,  make_query_fn=mk_query_v2, label="VAL50 · V2 (intent+constraints)", use_constraints=True)
pretty_print_scores("VAL50 — V2 (intent + constraints, hard-filtered)", res_v2_val50["ranks"])  # report immediately

res_v2_tst50 = evaluate_split_with_retrieval(test_50, make_query_fn=mk_query_v2, label="TEST50 · V2 (intent+constraints)", use_constraints=True)
pretty_print_scores("TEST50 — V2 (intent + constraints, hard-filtered)", res_v2_tst50["ranks"])  # report immediately

# Optional helper for inspecting subset predictions
def show_example_subset(i: int, split="val", version=1):
    if split.lower().startswith("v"):
        df, res = val_50, (res_v1_val50 if version==1 else res_v2_val50)
    else:
        df, res = test_50, (res_v1_tst50 if version==1 else res_v2_tst50)
    row = df.iloc[i]
    print("\n--- Example ---")
    print("QID:", row["qid"])
    print("Intent:", row["intent_text"])
    print("Constraints:", row["constraints"])
    print("Truth PID:", row["pid"])
    print("Top-10 preds:", res["preds"][i][:10])

print("\nTip: show_example_subset(0, split='val', version=2)  # inspect a single query from the VAL subset.")
sys.stdout.flush()


[Reusing] Vector Store: vs_6914815a36fc8191b48fc2741d533180
[OK] Vector Store ready: vs_6914815a36fc8191b48fc2741d533180
[Subset sizes] VAL: 50  TEST: 50


Retrieving [VAL50 · V1 (intent)]:   0%|          | 0/50 [00:00<?, ?q/s]


[VAL50 — V1 (intent only)] Metrics (n=50)
  @  1  Acc: 0.0200   Recall: 0.0200   NDCG: 0.0200
  @  3  Acc: 0.0600   Recall: 0.0600   NDCG: 0.0426
  @  5  Acc: 0.0600   Recall: 0.0600   NDCG: 0.0426
  @ 10  Acc: 0.0800   Recall: 0.0800   NDCG: 0.0489
  @ 50  Acc: 0.0800   Recall: 0.0800   NDCG: 0.0489


Retrieving [TEST50 · V1 (intent)]:   0%|          | 0/50 [00:00<?, ?q/s]


[TEST50 — V1 (intent only)] Metrics (n=50)
  @  1  Acc: 0.0600   Recall: 0.0600   NDCG: 0.0600
  @  3  Acc: 0.0800   Recall: 0.0800   NDCG: 0.0700
  @  5  Acc: 0.1000   Recall: 0.1000   NDCG: 0.0786
  @ 10  Acc: 0.1200   Recall: 0.1200   NDCG: 0.0857
  @ 50  Acc: 0.1200   Recall: 0.1200   NDCG: 0.0857


Retrieving [VAL50 · V2 (intent+constraints)]:   0%|          | 0/50 [00:00<?, ?q/s]


[VAL50 — V2 (intent + constraints, hard-filtered)] Metrics (n=50)
  @  1  Acc: 0.2800   Recall: 0.2800   NDCG: 0.2800
  @  3  Acc: 0.3200   Recall: 0.3200   NDCG: 0.3052
  @  5  Acc: 0.3400   Recall: 0.3400   NDCG: 0.3130
  @ 10  Acc: 0.3400   Recall: 0.3400   NDCG: 0.3130
  @ 50  Acc: 0.3400   Recall: 0.3400   NDCG: 0.3130


Retrieving [TEST50 · V2 (intent+constraints)]:   0%|          | 0/50 [00:00<?, ?q/s]


[TEST50 — V2 (intent + constraints, hard-filtered)] Metrics (n=50)
  @  1  Acc: 0.3600   Recall: 0.3600   NDCG: 0.3600
  @  3  Acc: 0.4400   Recall: 0.4400   NDCG: 0.4105
  @  5  Acc: 0.4400   Recall: 0.4400   NDCG: 0.4105
  @ 10  Acc: 0.4400   Recall: 0.4400   NDCG: 0.4105
  @ 50  Acc: 0.4400   Recall: 0.4400   NDCG: 0.4105

Tip: show_example_subset(0, split='val', version=2)  # inspect a single query from the VAL subset.


In [None]:
# -----------------------------
# Run VAL and TEST evaluations on the big catalog
# -----------------------------
val_results  = evaluate_unified_ppr_on_split(corpus, val_q,  topk=TOPK_PPR, ks=KS, tag="VAL")
test_results = evaluate_unified_ppr_on_split(corpus, test_q, topk=TOPK_PPR, ks=KS, tag="TEST")


Scoring & ranking (VAL):   0%|          | 0/730 [00:00<?, ?it/s]


[VAL] Candidate presence (top-200) over 730 queries
  Recall@10: 0.8288  hits=605  coverage=1.000  median_rank=1.0  mean_rank=2.1
  Recall@50: 0.9110  hits=665  coverage=1.000  median_rank=1.0  mean_rank=4.1
  Recall@200: 0.9315  hits=680  coverage=1.000  median_rank=1.0  mean_rank=5.7

[VAL] Rank-based metrics from full order
  Acc@10: 0.8288  NDCG@10: 0.6610
  Acc@50: 0.9110  NDCG@50: 0.6800
  Acc@200: 0.9315  NDCG@200: 0.6834


Scoring & ranking (TEST):   0%|          | 0/735 [00:00<?, ?it/s]


[TEST] Candidate presence (top-200) over 735 queries
  Recall@10: 0.8571  hits=630  coverage=1.000  median_rank=1.0  mean_rank=2.4
  Recall@50: 0.9184  hits=675  coverage=1.000  median_rank=2.0  mean_rank=3.6
  Recall@200: 0.9456  hits=695  coverage=1.000  median_rank=2.0  mean_rank=5.8

[TEST] Rank-based metrics from full order
  Acc@10: 0.8571  NDCG@10: 0.6495
  Acc@50: 0.9184  NDCG@50: 0.6642
  Acc@200: 0.9456  NDCG@200: 0.6685


In [None]:
# Cell — Budgeted sweeps with requested metrics only (no PPR Acc@1/10)

from tqdm.auto import tqdm
import numpy as np, math, pandas as pd

# ------------ Config (small sweeps) ------------
ALPHAS_FOR_SWEEP = [0.75, 0.8, 0.85]   # Phase A
ATTR_FRACS_FOR_SWEEP = [0.50, 0.60, 0.70, 0.80]  # Phase B

FIXED_LAMBDA_ATTR_FOR_ALPHA = 0.60
FIXED_ALPHA_FOR_RATIO = 0.80

TOPK_PPR = 200
SPLIT = "VAL"

# ------------ Preconditions ------------
assert 'corpus' in globals() and 'compute_ppr_prior_vector_unified' in globals()
assert 'val_q' in globals() and 'test_q' in globals()
assert 'LAMBDA_ATTR' in globals() and 'LAMBDA_SIM' in globals()

# ------------ Helpers ------------
def _clear_ppr_caches():
    if '_PPR_UNIFIED_CACHE' in globals():
        try: _PPR_UNIFIED_CACHE.clear()
        except: pass
    if '_PPR_COMP_CACHE' in globals():
        try: _PPR_COMP_CACHE.clear()
        except: pass

def _rank_of_truth(v: np.ndarray, idx_true: int) -> int:
    if idx_true is None or idx_true < 0:
        return 10**9
    s_true = v[idx_true]
    return int(np.sum(v > s_true)) + 1  # 1 + #strictly-greater

def _ndcg_at_k_from_ranks(ranks, K: int) -> float:
    vals = [(1.0 / math.log2(1 + r)) if r <= K else 0.0 for r in ranks]
    return float(np.mean(vals)) if ranks else 0.0

def recall_at_k_presence(pid_lists, truths, K: int):
    hits = 0
    ranks, non_empty = [], 0
    for pids, tpid in zip(pid_lists, truths):
        if pids: non_empty += 1
        try:
            r = pids.index(tpid) + 1
            if r <= K:
                hits += 1
                ranks.append(r)
        except ValueError:
            pass
    n = len(truths)
    rec = hits / max(1, n)
    med_rank = float(np.median(ranks)) if ranks else float('inf')
    mean_rank = float(np.mean(ranks)) if ranks else float('inf')
    coverage = non_empty / max(1, n)
    return rec, hits, coverage, med_rank, mean_rank

def _eval_split_metrics(alpha: float):
    qdf = val_q if SPLIT.upper() == "VAL" else test_q
    ids = corpus.df["id"].astype(str).tolist()
    id2row = {pid: i for i, pid in enumerate(ids)}
    constraints = qdf["constraints"].tolist()
    truths = qdf["pid"].astype(str).tolist()

    # Build PPR-only candidate lists (Top-200) and full-order ranks
    pid_topk = []
    ranks_full = []

    for C_raw, tpid in tqdm(zip(constraints, truths), total=len(truths),
                            desc=f"[{SPLIT}] Scoring α={alpha:.2f}"):
        v = compute_ppr_prior_vector_unified(corpus, C_raw, alpha=alpha)  # normalized [0,1]
        order = np.argsort(-v)
        pid_topk.append([ids[j] for j in order[:TOPK_PPR]])

        idx_true = id2row.get(tpid, None)
        ranks_full.append(_rank_of_truth(v, idx_true))

    # Presence-based Recall@{10,50,200}
    for K in (10, 50, 200):
        rec, hits, coverage, med_r, mean_r = recall_at_k_presence(pid_topk, truths, K)
        print(f"  Recall@{K}: {rec:.4f}  hits={hits}  coverage={coverage:.3f}  "
              f"median_rank={med_r:.1f}  mean_rank={mean_r:.1f}")

    # Rank-based (full order) Acc@{10,50,200} and NDCG@{10,50,200}
    print(f"\n[{SPLIT}] Rank-based metrics from full order")
    for K in (10, 50, 200):
        acc = float(sum(1 for r in ranks_full if r <= K)) / max(1, len(ranks_full))
        ndcg = _ndcg_at_k_from_ranks(ranks_full, K)
        print(f"  Acc@{K}: {acc:.4f}  NDCG@{K}: {ndcg:.4f}")

# ------------ Phase A: fix lambda_attr=0.6, sweep alpha ------------
print("=== Phase A: Fix λ_attr=0.60 (λ_sim=0.40), sweep α ===")
LAMBDA_ATTR = float(FIXED_LAMBDA_ATTR_FOR_ALPHA)
LAMBDA_SIM  = float(1.0 - LAMBDA_ATTR)
corpus.kg.finalize_edge_weights()
_clear_ppr_caches()
for a in ALPHAS_FOR_SWEEP:
    print(f"\n[α={a:.2f}]")
    _eval_split_metrics(alpha=a)

# ------------ Phase B: fix alpha=0.80, sweep lambda_attr ------------
print("\n=== Phase B: Fix α=0.80, sweep λ_attr ∈ {0.50, 0.60, 0.70, 0.80} ===")
for lam_attr in ATTR_FRACS_FOR_SWEEP:
    LAMBDA_ATTR = float(lam_attr)
    LAMBDA_SIM  = float(1.0 - lam_attr)
    corpus.kg.finalize_edge_weights()
    _clear_ppr_caches()
    print(f"\n[λ_attr={LAMBDA_ATTR:.2f}, λ_sim={LAMBDA_SIM:.2f}]")
    _eval_split_metrics(alpha=FIXED_ALPHA_FOR_RATIO)


=== Phase A: Fix λ_attr=0.60 (λ_sim=0.40), sweep α ===

[α=0.75]


[VAL] Scoring α=0.75:   0%|          | 0/730 [00:00<?, ?it/s]

  Recall@10: 0.8356  hits=610  coverage=1.000  median_rank=1.0  mean_rank=2.2
  Recall@50: 0.9110  hits=665  coverage=1.000  median_rank=1.0  mean_rank=4.0
  Recall@200: 0.9315  hits=680  coverage=1.000  median_rank=1.0  mean_rank=5.6

[VAL] Rank-based metrics from full order
  Acc@10: 0.8973  NDCG@10: 0.7259
  Acc@50: 0.9726  NDCG@50: 0.7431
  Acc@200: 0.9932  NDCG@200: 0.7465

[α=0.80]


[VAL] Scoring α=0.80:   0%|          | 0/730 [00:00<?, ?it/s]

  Recall@10: 0.8288  hits=605  coverage=1.000  median_rank=1.0  mean_rank=2.1
  Recall@50: 0.9110  hits=665  coverage=1.000  median_rank=1.0  mean_rank=4.1
  Recall@200: 0.9315  hits=680  coverage=1.000  median_rank=1.0  mean_rank=5.7

[VAL] Rank-based metrics from full order
  Acc@10: 0.8904  NDCG@10: 0.7236
  Acc@50: 0.9726  NDCG@50: 0.7426
  Acc@200: 0.9932  NDCG@200: 0.7459

[α=0.85]


[VAL] Scoring α=0.85:   0%|          | 0/730 [00:00<?, ?it/s]

  Recall@10: 0.8288  hits=605  coverage=1.000  median_rank=1.0  mean_rank=2.1
  Recall@50: 0.9110  hits=665  coverage=1.000  median_rank=1.0  mean_rank=4.1
  Recall@200: 0.9315  hits=680  coverage=1.000  median_rank=1.0  mean_rank=5.8

[VAL] Rank-based metrics from full order
  Acc@10: 0.8904  NDCG@10: 0.7235
  Acc@50: 0.9726  NDCG@50: 0.7423
  Acc@200: 0.9932  NDCG@200: 0.7456

=== Phase B: Fix α=0.80, sweep λ_attr ∈ {0.50, 0.60, 0.70, 0.80} ===

[λ_attr=0.50, λ_sim=0.50]


[VAL] Scoring α=0.80:   0%|          | 0/730 [00:00<?, ?it/s]

  Recall@10: 0.8288  hits=605  coverage=1.000  median_rank=1.0  mean_rank=2.1
  Recall@50: 0.9110  hits=665  coverage=1.000  median_rank=1.0  mean_rank=4.1
  Recall@200: 0.9315  hits=680  coverage=1.000  median_rank=1.0  mean_rank=5.7

[VAL] Rank-based metrics from full order
  Acc@10: 0.8904  NDCG@10: 0.7237
  Acc@50: 0.9726  NDCG@50: 0.7427
  Acc@200: 0.9932  NDCG@200: 0.7460

[λ_attr=0.60, λ_sim=0.40]


[VAL] Scoring α=0.80:   0%|          | 0/730 [00:00<?, ?it/s]

  Recall@10: 0.8288  hits=605  coverage=1.000  median_rank=1.0  mean_rank=2.1
  Recall@50: 0.9110  hits=665  coverage=1.000  median_rank=1.0  mean_rank=4.1
  Recall@200: 0.9315  hits=680  coverage=1.000  median_rank=1.0  mean_rank=5.7

[VAL] Rank-based metrics from full order
  Acc@10: 0.8904  NDCG@10: 0.7236
  Acc@50: 0.9726  NDCG@50: 0.7426
  Acc@200: 0.9932  NDCG@200: 0.7459

[λ_attr=0.70, λ_sim=0.30]


[VAL] Scoring α=0.80:   0%|          | 0/730 [00:00<?, ?it/s]

  Recall@10: 0.8288  hits=605  coverage=1.000  median_rank=1.0  mean_rank=2.1
  Recall@50: 0.9110  hits=665  coverage=1.000  median_rank=1.0  mean_rank=4.1
  Recall@200: 0.9315  hits=680  coverage=1.000  median_rank=1.0  mean_rank=5.8

[VAL] Rank-based metrics from full order
  Acc@10: 0.8904  NDCG@10: 0.7236
  Acc@50: 0.9726  NDCG@50: 0.7425
  Acc@200: 0.9932  NDCG@200: 0.7458

[λ_attr=0.80, λ_sim=0.20]


[VAL] Scoring α=0.80:   0%|          | 0/730 [00:00<?, ?it/s]

  Recall@10: 0.8288  hits=605  coverage=1.000  median_rank=1.0  mean_rank=2.1
  Recall@50: 0.9041  hits=660  coverage=1.000  median_rank=1.0  mean_rank=3.7
  Recall@200: 0.9315  hits=680  coverage=1.000  median_rank=1.0  mean_rank=5.8

[VAL] Rank-based metrics from full order
  Acc@10: 0.8904  NDCG@10: 0.7394
  Acc@50: 0.9658  NDCG@50: 0.7572
  Acc@200: 0.9932  NDCG@200: 0.7617


In [None]:
# Cell 8 — Logistic fusion (CE + IMG) on PPR@50; train on TRAIN_SUB, eval on VAL/TEST

from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sentence_transformers import CrossEncoder
from tqdm.auto import tqdm
import numpy as np, torch, pandas as pd, math, random

# =========================
# CONFIG — EDIT HERE
# =========================
TOPK_PPR   = 50             # candidates per query
CE_BATCH   = 256            # CE batch size (shrink if GPU OOM with larger models)
SUB_FRAC   = 1.0           # train blender on 1/8 of TRAIN (set 1.0 for full TRAIN)

# Tuned PPR / KG weights
BEST_ALPHA  = 0.80          # teleport/prob. of following edges
LAMBDA_ATTR = 0.80          # global budget for product<->(brand/attr) edges
LAMBDA_SIM  = 0.20          # global budget for product<->product sim edges

# Optional: fine control of brand seed emphasis (kept equal by default)
# ATTR_KEYS["brand"]["beta"] = 1.0    # e.g., 0.7 to mute, 1.3 to boost

# Cross-encoder choice
RERANK_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
# RERANK_MODEL = "cross-encoder/ms-marco-MiniLM-L-12-v2"
# RERANK_MODEL = "mixedbread-ai/mxbai-rerank-large-v1"

# Logistic regression hyperparams
LR_SOLVER = "liblinear"
LR_C = 1.0
LR_CLASS_WEIGHT = "balanced"
LR_MAX_ITER = 1000

# =========================
# Preconditions
# =========================
assert 'compute_ppr_prior_vector_unified' in globals(), "Need unified PPR helper."
assert 'train_q' in globals() and 'val_q' in globals() and 'test_q' in globals(), "Missing TRAIN/VAL/TEST splits."
assert 'corpus' in globals(), "Need built corpus with encoders and embeddings."

DEVICE = globals().get("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")

# =========================
# Apply KG weights & clear PPR caches
# =========================
corpus.kg.finalize_edge_weights()
for _cache_name in ("_PPR_UNIFIED_CACHE", "_PPR_COMP_CACHE"):
    if _cache_name in globals():
        try:
            globals()[_cache_name].clear()
        except Exception:
            pass

# =========================
# Helper: PPR Top-K using the tuned alpha
# =========================
def ppr_topk_unified(constraints_series: pd.Series, topk: int, corpus, desc: str = "PPR (unified)", alpha: float = BEST_ALPHA):
    ids = corpus.df["id"].astype(str).tolist()
    out = []
    for C_raw in tqdm(constraints_series.tolist(), total=len(constraints_series), desc=f"{desc} @PPR{topk} (α={alpha:.2f})"):
        v = compute_ppr_prior_vector_unified(corpus, C_raw, alpha=alpha)
        order = np.argsort(-v)[:topk]
        out.append([ids[i] for i in order])
    return out

# =========================
# CE/IMG helpers
# =========================
def pack_pairs(q_texts: list[str], pid_lists: list[list[str]], use_text: bool = True):
    """Build CE (query, doc) pairs for candidate pids; returns (pairs, offsets)."""
    id2text = dict(zip(corpus.df["id"].astype(str), corpus.df["text"].astype(str)))
    pairs, offsets = [], []
    for qi, pids in enumerate(pid_lists):
        offsets.append(len(pairs))
        qt = q_texts[qi]
        for pid in pids:
            dt = id2text.get(pid, "")
            pairs.append([qt, dt] if use_text else [qt, ""])
    return pairs, offsets

@torch.no_grad()
def _gather_img_sims_from_qemb(q_clip: torch.Tensor, pid_lists: list[list[str]]) -> np.ndarray:
    """Return flattened IMG similarities aligned with pid_lists; missing imgs -> 0."""
    img_mat = torch.from_numpy(corpus.img_emb).to(q_clip.device, dtype=torch.float32)
    norms = img_mat.norm(dim=1, keepdim=True) + 1e-12
    img_mat = img_mat / norms
    id2row = {pid: i for i, pid in enumerate(corpus.df["id"].astype(str).tolist())}

    flat = []
    for qi, pids in enumerate(pid_lists):
        qv = q_clip[qi].unsqueeze(0)  # 1 x D
        for pid in pids:
            j = id2row.get(pid, None)
            if j is None:
                flat.append(0.0); continue
            v = float((qv @ img_mat[j:j+1].T).squeeze().cpu().item())
            flat.append(v)
    return np.asarray(flat, dtype=np.float32)

def _build_Xy(ce_scores: np.ndarray, img_sims: np.ndarray,
              pid_lists: list[list[str]], offsets: list[int], true_pids: list[str]):
    """Construct feature matrix (CE, IMG) and labels aligned to flattened candidates."""
    assert ce_scores.shape[0] == img_sims.shape[0], "Feature length mismatch."
    X = np.stack([ce_scores, img_sims], axis=1).astype(np.float32)
    y = []
    cur = 0
    for qi, pids in enumerate(pid_lists):
        K = len(pids); tpid = true_pids[qi]
        for k in range(K):
            y.append(1 if pids[k] == tpid else 0)
        cur += K
    y = np.asarray(y, dtype=np.int32)
    return X, y

def _ranks_from_flat_scores(pid_lists, offsets, flat_scores, true_pids):
    ranks = []
    cur = 0
    for qi, pids in enumerate(pid_lists):
        K = len(pids)
        if K == 0:
            ranks.append(10**9); continue
        sl = flat_scores[cur:cur+K]
        order = np.argsort(-sl)
        ranked = [pids[i] for i in order]
        tpid = true_pids[qi]
        r = next((i+1 for i,p in enumerate(ranked) if p==tpid), len(ranked)+1)
        ranks.append(r)
        cur += K
    return ranks

# Metric fallbacks (use earlier ones if already defined)
try:
    accuracy_at_k
except NameError:
    def accuracy_at_k(ranks, K: int) -> float:
        n = max(1, len(ranks))
        return sum(1 for r in ranks if r <= K) / n

try:
    ndcg_at_k_single
except NameError:
    def ndcg_at_k_single(ranks, K: int) -> float:
        val = 0.0
        for r in ranks:
            if r <= K:
                val += 1.0 / math.log2(r + 1.0)
        n = max(1, len(ranks))
        return val / n

# =========================
# TRAIN_SUB (fit blender)
# =========================
train_sub = train_q.sample(frac=SUB_FRAC, random_state=42).reset_index(drop=True)
print(f"[TRAIN_SUB] Using {len(train_sub)}/{len(train_q)} queries (~{SUB_FRAC:.3f} of TRAIN)")

# PPR candidates with tuned alpha
pids_train = ppr_topk_unified(train_sub["constraints"], TOPK_PPR, corpus=corpus, desc="PPR (TRAIN_SUB)")
pids_val   = ppr_topk_unified(val_q["constraints"],   TOPK_PPR, corpus=corpus, desc="PPR (VAL)")
pids_test  = ppr_topk_unified(test_q["constraints"],  TOPK_PPR, corpus=corpus, desc="PPR (TEST)")

# Cross-Encoder
reranker = CrossEncoder(RERANK_MODEL, device=DEVICE)

# CE pairs
pairs_train, offsets_train = pack_pairs(train_sub["intent_text"].astype(str).tolist(), pids_train, use_text=True)
pairs_val,   offsets_val   = pack_pairs(val_q["intent_text"].astype(str).tolist(),   pids_val,   use_text=True)
pairs_test,  offsets_test  = pack_pairs(test_q["intent_text"].astype(str).tolist(),  pids_test,  use_text=True)

# CE scores (prob of relevance)
ce_scores_train = (np.asarray(reranker.predict(pairs_train, batch_size=CE_BATCH, show_progress_bar=True), dtype=np.float32)
                   if len(pairs_train) else np.zeros((0,), dtype=np.float32))
ce_scores_val   = (np.asarray(reranker.predict(pairs_val,   batch_size=CE_BATCH, show_progress_bar=True), dtype=np.float32)
                   if len(pairs_val) else np.zeros((0,), dtype=np.float32))
ce_scores_test  = (np.asarray(reranker.predict(pairs_test,  batch_size=CE_BATCH, show_progress_bar=True), dtype=np.float32)
                   if len(pairs_test) else np.zeros((0,), dtype=np.float32))

# CLIP text emb for queries (normalize)
q_clip_train = torch.tensor(corpus.enc.encode_text(train_sub["intent_text"].astype(str).tolist()), dtype=torch.float32, device=DEVICE)
q_clip_val   = torch.tensor(corpus.enc.encode_text(val_q["intent_text"].astype(str).tolist()),   dtype=torch.float32, device=DEVICE)
q_clip_test  = torch.tensor(corpus.enc.encode_text(test_q["intent_text"].astype(str).tolist()),  dtype=torch.float32, device=DEVICE)
q_clip_train = q_clip_train / (q_clip_train.norm(dim=1, keepdim=True) + 1e-12)
q_clip_val   = q_clip_val   / (q_clip_val.norm(dim=1, keepdim=True) + 1e-12)
q_clip_test  = q_clip_test  / (q_clip_test.norm(dim=1, keepdim=True) + 1e-12)

# IMG sims aligned to candidate lists
img_sims_train = _gather_img_sims_from_qemb(q_clip_train, pids_train)
img_sims_val   = _gather_img_sims_from_qemb(q_clip_val,   pids_val)
img_sims_test  = _gather_img_sims_from_qemb(q_clip_test,  pids_test)

# Labels
y_true_train = train_sub["pid"].astype(str).tolist()
y_true_val   = val_q["pid"].astype(str).tolist()
y_true_test  = test_q["pid"].astype(str).tolist()

# Design matrices
X_tr, y_tr = _build_Xy(ce_scores_train, img_sims_train, pids_train, offsets_train, y_true_train)
X_va, y_va = _build_Xy(ce_scores_val,   img_sims_val,   pids_val,   offsets_val,   y_true_val)
X_te, y_te = _build_Xy(ce_scores_test,  img_sims_test,  pids_test,  offsets_test,  y_true_test)

print(f"[TRAIN_SUB] blender set: X={X_tr.shape}, positives={int(y_tr.sum())}/{len(y_tr)}")
print(f"[VAL]       blender set: X={X_va.shape}, positives={int(y_va.sum())}/{len(y_va)}")
print(f"[TEST]      candidates : X={X_te.shape}")

# =========================
# Fit blender
# =========================
blender = make_pipeline(
    StandardScaler(with_mean=True, with_std=True),
    LogisticRegression(
        solver=LR_SOLVER,
        class_weight=LR_CLASS_WEIGHT,
        max_iter=LR_MAX_ITER,
        C=LR_C
    )
)
print("Fitting logistic blender on TRAIN_SUB…")
if len(X_tr) and len(np.unique(y_tr)) > 1:
    blender.fit(X_tr, y_tr)
else:
    print("WARNING: TRAIN_SUB has insufficient positives/negatives for LR; skipping fit.")

# Inspect learned weights (after scaling)
if hasattr(blender, "named_steps") and "logisticregression" in blender.named_steps:
    lr = blender.named_steps['logisticregression']
    coef = lr.coef_.ravel(); b = lr.intercept_[0]
    print(f"Learned fusion (after scaling): w_CE={coef[0]:+.3f}, w_IMG={coef[1]:+.3f}, b={b:+.3f}")

# =========================
# Rank & report
# =========================
def _report(tag, pid_lists, offsets, X, y_true):
    scores = blender.predict_proba(X)[:,1].astype(np.float32) if len(X) else np.zeros((0,), dtype=np.float32)
    ranks  = _ranks_from_flat_scores(pid_lists, offsets, scores, y_true)
    print(f"\n[{tag}] Logistic fusion (CE + IMG) @ PPR{TOPK_PPR} (α={BEST_ALPHA:.2f}, λ_attr={LAMBDA_ATTR:.2f}/λ_sim={LAMBDA_SIM:.2f})")
    for K in (1,3,5,10):  print(f"  Accuracy@{K}: {accuracy_at_k(ranks, K):.4f}")
    for K in (1,5,10,50): print(f"  NDCG@{K}:     {ndcg_at_k_single(ranks, K):.4f}")
    for K in (1,5,10,50): print(f"  Recall@{K}:   {accuracy_at_k(ranks, K):.4f}")  # presence proxy from ranks

# If already ran baseline PPR metrics, they live in val_results/test_results:
if 'val_results' in globals():
    print("\n[Baseline] PPR-only metrics already computed in previous cell (see val_results/test_results).")

_report("TRAIN_SUB", pids_train, offsets_train, X_tr, y_true_train)
_report("VAL",       pids_val,   offsets_val,   X_va, y_true_val)
_report("TEST",      pids_test,  offsets_test,  X_te, y_true_test)


[TRAIN_SUB] Using 2190/2190 queries (~1.000 of TRAIN)


PPR (TRAIN_SUB) @PPR50 (α=0.80):   0%|          | 0/2190 [00:00<?, ?it/s]

PPR (VAL) @PPR50 (α=0.80):   0%|          | 0/730 [00:00<?, ?it/s]

PPR (TEST) @PPR50 (α=0.80):   0%|          | 0/735 [00:00<?, ?it/s]

Batches:   0%|          | 0/428 [00:00<?, ?it/s]

Batches:   0%|          | 0/143 [00:00<?, ?it/s]

Batches:   0%|          | 0/144 [00:00<?, ?it/s]

[TRAIN_SUB] blender set: X=(109500, 2), positives=1920/109500
[VAL]       blender set: X=(36500, 2), positives=660/36500
[TEST]      candidates : X=(36750, 2)
Fitting logistic blender on TRAIN_SUB…
Learned fusion (after scaling): w_CE=+0.942, w_IMG=+0.859, b=-1.246

[Baseline] PPR-only metrics already computed in previous cell (see val_results/test_results).

[TRAIN_SUB] Logistic fusion (CE + IMG) @ PPR50 (α=0.80, λ_attr=0.80/λ_sim=0.20)
  Accuracy@1: 0.3712
  Accuracy@3: 0.5502
  Accuracy@5: 0.6306
  Accuracy@10: 0.7347
  NDCG@1:     0.3712
  NDCG@5:     0.5085
  NDCG@10:     0.5423
  NDCG@50:     0.5759
  Recall@1:   0.3712
  Recall@5:   0.6306
  Recall@10:   0.7347
  Recall@50:   0.8767

[VAL] Logistic fusion (CE + IMG) @ PPR50 (α=0.80, λ_attr=0.80/λ_sim=0.20)
  Accuracy@1: 0.3945
  Accuracy@3: 0.5877
  Accuracy@5: 0.6658
  Accuracy@10: 0.7795
  NDCG@1:     0.3945
  NDCG@5:     0.5410
  NDCG@10:     0.5778
  NDCG@50:     0.6073
  Recall@1:   0.3945
  Recall@5:   0.6658
  Recall@10: 

In [20]:
#copy
# Cell 8 — Logistic fusion (CE + IMG) on PPR@50; train on TRAIN_SUB, eval on VAL/TEST

from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sentence_transformers import CrossEncoder
from tqdm.auto import tqdm
import numpy as np, torch, pandas as pd, math, random

# =========================
# CONFIG — EDIT HERE
# =========================
TOPK_PPR   = 50             # candidates per query
CE_BATCH   = 256            # CE batch size (shrink if GPU OOM with larger models)
SUB_FRAC   = 1.0           # train blender on 1/8 of TRAIN (set 1.0 for full TRAIN)

# Tuned PPR / KG weights
BEST_ALPHA  = 0.80          # teleport/prob. of following edges
LAMBDA_ATTR = 0.80          # global budget for product<->(brand/attr) edges
LAMBDA_SIM  = 0.20          # global budget for product<->product sim edges

# Optional: fine control of brand seed emphasis (kept equal by default)
# ATTR_KEYS["brand"]["beta"] = 1.0    # e.g., 0.7 to mute, 1.3 to boost

# Cross-encoder choice
RERANK_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
# RERANK_MODEL = "cross-encoder/ms-marco-MiniLM-L-12-v2"
# RERANK_MODEL = "mixedbread-ai/mxbai-rerank-large-v1"

# Logistic regression hyperparams
LR_SOLVER = "liblinear"
LR_C = 1.0
LR_CLASS_WEIGHT = "balanced"
LR_MAX_ITER = 1000

# =========================
# Preconditions
# =========================
assert 'compute_ppr_prior_vector_unified' in globals(), "Need unified PPR helper."
assert 'train_q' in globals() and 'val_q' in globals() and 'test_q' in globals(), "Missing TRAIN/VAL/TEST splits."
assert 'corpus' in globals(), "Need built corpus with encoders and embeddings."

DEVICE = globals().get("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")

# =========================
# Apply KG weights & clear PPR caches
# =========================
corpus.kg.finalize_edge_weights()
for _cache_name in ("_PPR_UNIFIED_CACHE", "_PPR_COMP_CACHE"):
    if _cache_name in globals():
        try:
            globals()[_cache_name].clear()
        except Exception:
            pass

# =========================
# Helper: PPR Top-K using the tuned alpha
# =========================
def ppr_topk_unified(constraints_series: pd.Series, topk: int, corpus, desc: str = "PPR (unified)", alpha: float = BEST_ALPHA):
    ids = corpus.df["id"].astype(str).tolist()
    out = []
    for C_raw in tqdm(constraints_series.tolist(), total=len(constraints_series), desc=f"{desc} @PPR{topk} (α={alpha:.2f})"):
        v = compute_ppr_prior_vector_unified(corpus, C_raw, alpha=alpha)
        order = np.argsort(-v)[:topk]
        out.append([ids[i] for i in order])
    return out

# =========================
# CE/IMG helpers
# =========================
def pack_pairs(q_texts: list[str], pid_lists: list[list[str]], use_text: bool = True):
    """Build CE (query, doc) pairs for candidate pids; returns (pairs, offsets)."""
    id2text = dict(zip(corpus.df["id"].astype(str), corpus.df["text"].astype(str)))
    pairs, offsets = [], []
    for qi, pids in enumerate(pid_lists):
        offsets.append(len(pairs))
        qt = q_texts[qi]
        for pid in pids:
            dt = id2text.get(pid, "")
            pairs.append([qt, dt] if use_text else [qt, ""])
    return pairs, offsets

@torch.no_grad()
def _gather_img_sims_from_qemb(q_clip: torch.Tensor, pid_lists: list[list[str]]) -> np.ndarray:
    """Return flattened IMG similarities aligned with pid_lists; missing imgs -> 0."""
    img_mat = torch.from_numpy(corpus.img_emb).to(q_clip.device, dtype=torch.float32)
    norms = img_mat.norm(dim=1, keepdim=True) + 1e-12
    img_mat = img_mat / norms
    id2row = {pid: i for i, pid in enumerate(corpus.df["id"].astype(str).tolist())}

    flat = []
    for qi, pids in enumerate(pid_lists):
        qv = q_clip[qi].unsqueeze(0)  # 1 x D
        for pid in pids:
            j = id2row.get(pid, None)
            if j is None:
                flat.append(0.0); continue
            v = float((qv @ img_mat[j:j+1].T).squeeze().cpu().item())
            flat.append(v)
    return np.asarray(flat, dtype=np.float32)

def _build_Xy(ce_scores: np.ndarray, img_sims: np.ndarray,
              pid_lists: list[list[str]], offsets: list[int], true_pids: list[str]):
    """Construct feature matrix (CE, IMG) and labels aligned to flattened candidates."""
    assert ce_scores.shape[0] == img_sims.shape[0], "Feature length mismatch."
    X = np.stack([ce_scores, img_sims], axis=1).astype(np.float32)
    y = []
    cur = 0
    for qi, pids in enumerate(pid_lists):
        K = len(pids); tpid = true_pids[qi]
        for k in range(K):
            y.append(1 if pids[k] == tpid else 0)
        cur += K
    y = np.asarray(y, dtype=np.int32)
    return X, y

def _ranks_from_flat_scores(pid_lists, offsets, flat_scores, true_pids):
    ranks = []
    cur = 0
    for qi, pids in enumerate(pid_lists):
        K = len(pids)
        if K == 0:
            ranks.append(10**9); continue
        sl = flat_scores[cur:cur+K]
        order = np.argsort(-sl)
        ranked = [pids[i] for i in order]
        tpid = true_pids[qi]
        r = next((i+1 for i,p in enumerate(ranked) if p==tpid), len(ranked)+1)
        ranks.append(r)
        cur += K
    return ranks

# Metric fallbacks (use earlier ones if already defined)
try:
    accuracy_at_k
except NameError:
    def accuracy_at_k(ranks, K: int) -> float:
        n = max(1, len(ranks))
        return sum(1 for r in ranks if r <= K) / n

try:
    ndcg_at_k_single
except NameError:
    def ndcg_at_k_single(ranks, K: int) -> float:
        val = 0.0
        for r in ranks:
            if r <= K:
                val += 1.0 / math.log2(r + 1.0)
        n = max(1, len(ranks))
        return val / n

# =========================
# TRAIN_SUB (fit blender)
# =========================
train_sub = train_q.sample(frac=SUB_FRAC, random_state=42).reset_index(drop=True)
print(f"[TRAIN_SUB] Using {len(train_sub)}/{len(train_q)} queries (~{SUB_FRAC:.3f} of TRAIN)")

# PPR candidates with tuned alpha
pids_train = ppr_topk_unified(train_sub["constraints"], TOPK_PPR, corpus=corpus, desc="PPR (TRAIN_SUB)")
pids_val   = ppr_topk_unified(val_q["constraints"],   TOPK_PPR, corpus=corpus, desc="PPR (VAL)")
pids_test  = ppr_topk_unified(test_q["constraints"],  TOPK_PPR, corpus=corpus, desc="PPR (TEST)")

# Cross-Encoder
reranker = CrossEncoder(RERANK_MODEL, device=DEVICE)

# CE pairs
pairs_train, offsets_train = pack_pairs(train_sub["intent_text"].astype(str).tolist(), pids_train, use_text=True)
pairs_val,   offsets_val   = pack_pairs(val_q["intent_text"].astype(str).tolist(),   pids_val,   use_text=True)
pairs_test,  offsets_test  = pack_pairs(test_q["intent_text"].astype(str).tolist(),  pids_test,  use_text=True)

# CE scores (prob of relevance)
ce_scores_train = (np.asarray(reranker.predict(pairs_train, batch_size=CE_BATCH, show_progress_bar=True), dtype=np.float32)
                   if len(pairs_train) else np.zeros((0,), dtype=np.float32))
ce_scores_val   = (np.asarray(reranker.predict(pairs_val,   batch_size=CE_BATCH, show_progress_bar=True), dtype=np.float32)
                   if len(pairs_val) else np.zeros((0,), dtype=np.float32))
ce_scores_test  = (np.asarray(reranker.predict(pairs_test,  batch_size=CE_BATCH, show_progress_bar=True), dtype=np.float32)
                   if len(pairs_test) else np.zeros((0,), dtype=np.float32))

# CLIP text emb for queries (normalize)
q_clip_train = torch.tensor(corpus.enc.encode_text(train_sub["intent_text"].astype(str).tolist()), dtype=torch.float32, device=DEVICE)
q_clip_val   = torch.tensor(corpus.enc.encode_text(val_q["intent_text"].astype(str).tolist()),   dtype=torch.float32, device=DEVICE)
q_clip_test  = torch.tensor(corpus.enc.encode_text(test_q["intent_text"].astype(str).tolist()),  dtype=torch.float32, device=DEVICE)
q_clip_train = q_clip_train / (q_clip_train.norm(dim=1, keepdim=True) + 1e-12)
q_clip_val   = q_clip_val   / (q_clip_val.norm(dim=1, keepdim=True) + 1e-12)
q_clip_test  = q_clip_test  / (q_clip_test.norm(dim=1, keepdim=True) + 1e-12)

# IMG sims aligned to candidate lists
img_sims_train = _gather_img_sims_from_qemb(q_clip_train, pids_train)
img_sims_val   = _gather_img_sims_from_qemb(q_clip_val,   pids_val)
img_sims_test  = _gather_img_sims_from_qemb(q_clip_test,  pids_test)

# Labels
y_true_train = train_sub["pid"].astype(str).tolist()
y_true_val   = val_q["pid"].astype(str).tolist()
y_true_test  = test_q["pid"].astype(str).tolist()

# Design matrices
X_tr, y_tr = _build_Xy(ce_scores_train, img_sims_train, pids_train, offsets_train, y_true_train)
X_va, y_va = _build_Xy(ce_scores_val,   img_sims_val,   pids_val,   offsets_val,   y_true_val)
X_te, y_te = _build_Xy(ce_scores_test,  img_sims_test,  pids_test,  offsets_test,  y_true_test)

print(f"[TRAIN_SUB] blender set: X={X_tr.shape}, positives={int(y_tr.sum())}/{len(y_tr)}")
print(f"[VAL]       blender set: X={X_va.shape}, positives={int(y_va.sum())}/{len(y_va)}")
print(f"[TEST]      candidates : X={X_te.shape}")

# =========================
# Fit blender
# =========================
blender = make_pipeline(
    StandardScaler(with_mean=True, with_std=True),
    LogisticRegression(
        solver=LR_SOLVER,
        class_weight=LR_CLASS_WEIGHT,
        max_iter=LR_MAX_ITER,
        C=LR_C
    )
)
print("Fitting logistic blender on TRAIN_SUB…")
if len(X_tr) and len(np.unique(y_tr)) > 1:
    blender.fit(X_tr, y_tr)
else:
    print("WARNING: TRAIN_SUB has insufficient positives/negatives for LR; skipping fit.")

# Inspect learned weights (after scaling)
if hasattr(blender, "named_steps") and "logisticregression" in blender.named_steps:
    lr = blender.named_steps['logisticregression']
    coef = lr.coef_.ravel(); b = lr.intercept_[0]
    print(f"Learned fusion (after scaling): w_CE={coef[0]:+.3f}, w_IMG={coef[1]:+.3f}, b={b:+.3f}")

# =========================
# Rank & report
# =========================
def _report(tag, pid_lists, offsets, X, y_true):
    scores = blender.predict_proba(X)[:,1].astype(np.float32) if len(X) else np.zeros((0,), dtype=np.float32)
    ranks  = _ranks_from_flat_scores(pid_lists, offsets, scores, y_true)
    print(f"\n[{tag}] Logistic fusion (CE + IMG) @ PPR{TOPK_PPR} (α={BEST_ALPHA:.2f}, λ_attr={LAMBDA_ATTR:.2f}/λ_sim={LAMBDA_SIM:.2f})")
    for K in (1,3,5,10):  print(f"  Accuracy@{K}: {accuracy_at_k(ranks, K):.4f}")
    for K in (1,5,10,50): print(f"  NDCG@{K}:     {ndcg_at_k_single(ranks, K):.4f}")
    for K in (1,5,10,50): print(f"  Recall@{K}:   {accuracy_at_k(ranks, K):.4f}")  # presence proxy from ranks

# If already ran baseline PPR metrics, they live in val_results/test_results:
if 'val_results' in globals():
    print("\n[Baseline] PPR-only metrics already computed in previous cell (see val_results/test_results).")

_report("TRAIN_SUB", pids_train, offsets_train, X_tr, y_true_train)
_report("VAL",       pids_val,   offsets_val,   X_va, y_true_val)
_report("TEST",      pids_test,  offsets_test,  X_te, y_true_test)


[TRAIN_SUB] Using 2190/2190 queries (~1.000 of TRAIN)


PPR (TRAIN_SUB) @PPR50 (α=0.80):   0%|          | 0/2190 [00:00<?, ?it/s]

PPR (VAL) @PPR50 (α=0.80):   0%|          | 0/730 [00:00<?, ?it/s]

PPR (TEST) @PPR50 (α=0.80):   0%|          | 0/735 [00:00<?, ?it/s]

config.json:   0%|          | 0.00/794 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/132 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

Batches:   0%|          | 0/428 [00:00<?, ?it/s]

Batches:   0%|          | 0/143 [00:00<?, ?it/s]

Batches:   0%|          | 0/144 [00:00<?, ?it/s]

[TRAIN_SUB] blender set: X=(109500, 2), positives=1920/109500
[VAL]       blender set: X=(36500, 2), positives=660/36500
[TEST]      candidates : X=(36750, 2)
Fitting logistic blender on TRAIN_SUB…
Learned fusion (after scaling): w_CE=+0.942, w_IMG=+0.859, b=-1.246

[TRAIN_SUB] Logistic fusion (CE + IMG) @ PPR50 (α=0.80, λ_attr=0.80/λ_sim=0.20)
  Accuracy@1: 0.3712
  Accuracy@3: 0.5502
  Accuracy@5: 0.6306
  Accuracy@10: 0.7347
  NDCG@1:     0.3712
  NDCG@5:     0.5085
  NDCG@10:     0.5423
  NDCG@50:     0.5759
  Recall@1:   0.3712
  Recall@5:   0.6306
  Recall@10:   0.7347
  Recall@50:   0.8767

[VAL] Logistic fusion (CE + IMG) @ PPR50 (α=0.80, λ_attr=0.80/λ_sim=0.20)
  Accuracy@1: 0.3945
  Accuracy@3: 0.5877
  Accuracy@5: 0.6658
  Accuracy@10: 0.7795
  NDCG@1:     0.3945
  NDCG@5:     0.5410
  NDCG@10:     0.5778
  NDCG@50:     0.6073
  Recall@1:   0.3945
  Recall@5:   0.6658
  Recall@10:   0.7795
  Recall@50:   0.9041

[TEST] Logistic fusion (CE + IMG) @ PPR50 (α=0.80, λ_attr=0.80/

In [None]:
# Cell 8b — CE-only comparison (no IMG blend), reusing existing legacies

from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
import numpy as np, math

# --------- Preconditions (must be defined by the previous cell) ----------
req = [
    'pids_train','offsets_train','ce_scores_train','y_true_train','train_sub',
    'pids_val','offsets_val','ce_scores_val','y_true_val',
    'pids_test','offsets_test','ce_scores_test','y_true_test',
    '_ranks_from_flat_scores'
]
missing = [r for r in req if r not in globals()]
assert not missing, f"Missing from previous cell: {missing}"

# Metric fallbacks (use already-loaded ones if present)
try:
    accuracy_at_k
except NameError:
    def accuracy_at_k(ranks, K: int) -> float:
        n = max(1, len(ranks))
        return sum(1 for r in ranks if r <= K) / n

try:
    ndcg_at_k_single
except NameError:
    def ndcg_at_k_single(ranks, K: int) -> float:
        vals = [(1.0 / math.log2(1 + r)) if r <= K else 0.0 for r in ranks]
        return float(np.mean(vals)) if ranks else 0.0

# -------------------- RAW-CE (no learning; rank by CE score) --------------------
print("[RAW-CE] Ranking with raw Cross-Encoder scores only (no image, no LR)")

ranks_tr_raw = _ranks_from_flat_scores(pids_train, offsets_train, ce_scores_train, y_true_train)
ranks_va_raw = _ranks_from_flat_scores(pids_val,   offsets_val,   ce_scores_val,   y_true_val)
ranks_te_raw = _ranks_from_flat_scores(pids_test,  offsets_test,  ce_scores_test,  y_true_test)

def _report(tag, ranks):
    print(f"\n[{tag} • RAW-CE] CE-only")
    for K in (1,3,5,10):  print(f"  Accuracy@{K}: {accuracy_at_k(ranks, K):.4f}")
    for K in (1,5,10,50): print(f"  NDCG@{K}:     {ndcg_at_k_single(ranks, K):.4f}")
    for K in (1,5,10,50): print(f"  Recall@{K}:   {accuracy_at_k(ranks, K):.4f}")  # presence proxy via ranks

_report("TRAIN_SUB", ranks_tr_raw)
_report("VAL",       ranks_va_raw)
_report("TEST",      ranks_te_raw)

# -------------------- LR(CE) — logistic on CE feature only --------------------
print("\n[LR(CE)] Logistic regression on CE-only feature")

# Build X,y using CE-only (reshape to Nx1)
X_tr_ce = ce_scores_train.reshape(-1, 1).astype(np.float32)
X_va_ce = ce_scores_val.reshape(-1, 1).astype(np.float32)
X_te_ce = ce_scores_test.reshape(-1, 1).astype(np.float32)

# Labels aligned to flattened candidates
def _build_y(pid_lists, true_pids):
    y = []
    for qi, pids in enumerate(pid_lists):
        tpid = true_pids[qi]
        y.extend([1 if p==tpid else 0 for p in pids])
    return np.asarray(y, dtype=np.int32)

y_tr_ce = _build_y(pids_train, y_true_train)
y_va_ce = _build_y(pids_val,   y_true_val)
y_te_ce = _build_y(pids_test,  y_true_test)

print(f"[CE-ONLY] TRAIN_SUB: X={X_tr_ce.shape}, positives={int(y_tr_ce.sum())}/{len(y_tr_ce)}")
print(f"[CE-ONLY] VAL      : X={X_va_ce.shape}, positives={int(y_va_ce.sum())}/{len(y_va_ce)}")
print(f"[CE-ONLY] TEST     : X={X_te_ce.shape}")

blender_ce = make_pipeline(
    StandardScaler(with_mean=True, with_std=True),
    LogisticRegression(
        solver="liblinear",
        class_weight="balanced",
        max_iter=1000,
        C=1.0
    )
)
print("Fitting LR(CE) on TRAIN_SUB…")
if len(X_tr_ce) and len(np.unique(y_tr_ce)) > 1:
    blender_ce.fit(X_tr_ce, y_tr_ce)
    lr = blender_ce.named_steps['logisticregression']
    print(f"LR(CE) weights (after scaling): w_CE={lr.coef_.ravel()[0]:+.3f}, b={lr.intercept_[0]:+.3f}")
else:
    print("WARNING: TRAIN_SUB has insufficient positives/negatives for LR; skipping fit.")

# Rank with LR(CE) scores
scores_tr_ce = (blender_ce.predict_proba(X_tr_ce)[:,1].astype(np.float32)
                if len(X_tr_ce) and len(np.unique(y_tr_ce))>1 else ce_scores_train.astype(np.float32))
scores_va_ce = (blender_ce.predict_proba(X_va_ce)[:,1].astype(np.float32)
                if len(X_va_ce) and len(np.unique(y_tr_ce))>1 else ce_scores_val.astype(np.float32))
scores_te_ce = (blender_ce.predict_proba(X_te_ce)[:,1].astype(np.float32)
                if len(X_te_ce) and len(np.unique(y_tr_ce))>1 else ce_scores_test.astype(np.float32))

ranks_tr_lrce = _ranks_from_flat_scores(pids_train, offsets_train, scores_tr_ce, y_true_train)
ranks_va_lrce = _ranks_from_flat_scores(pids_val,   offsets_val,   scores_va_ce, y_true_val)
ranks_te_lrce = _ranks_from_flat_scores(pids_test,  offsets_test,  scores_te_ce, y_true_test)

def _report_lr(tag, ranks):
    print(f"\n[{tag} • LR(CE)] CE-only")
    for K in (1,3,5,10):  print(f"  Accuracy@{K}: {accuracy_at_k(ranks, K):.4f}")
    for K in (1,5,10,50): print(f"  NDCG@{K}:     {ndcg_at_k_single(ranks, K):.4f}")
    for K in (1,5,10,50): print(f"  Recall@{K}:   {accuracy_at_k(ranks, K):.4f}")

_report_lr("TRAIN_SUB", ranks_tr_lrce)
_report_lr("VAL",       ranks_va_lrce)
_report_lr("TEST",      ranks_te_lrce)


[RAW-CE] Ranking with raw Cross-Encoder scores only (no image, no LR)

[TRAIN_SUB • RAW-CE] CE-only
  Accuracy@1: 0.3942
  Accuracy@3: 0.5584
  Accuracy@5: 0.6423
  Accuracy@10: 0.7409
  NDCG@1:     0.3942
  NDCG@5:     0.5253
  NDCG@10:     0.5569
  NDCG@50:     0.5930
  Recall@1:   0.3942
  Recall@5:   0.6423
  Recall@10:   0.7409
  Recall@50:   0.8942

[VAL • RAW-CE] CE-only
  Accuracy@1: 0.4151
  Accuracy@3: 0.6000
  Accuracy@5: 0.6795
  Accuracy@10: 0.7658
  NDCG@1:     0.4151
  NDCG@5:     0.5574
  NDCG@10:     0.5853
  NDCG@50:     0.6172
  Recall@1:   0.4151
  Recall@5:   0.6795
  Recall@10:   0.7658
  Recall@50:   0.9041

[TEST • RAW-CE] CE-only
  Accuracy@1: 0.4163
  Accuracy@3: 0.5728
  Accuracy@5: 0.6680
  Accuracy@10: 0.7769
  NDCG@1:     0.4163
  NDCG@5:     0.5490
  NDCG@10:     0.5843
  NDCG@50:     0.6179
  Recall@1:   0.4163
  Recall@5:   0.6680
  Recall@10:   0.7769
  Recall@50:   0.9184

[LR(CE)] Logistic regression on CE-only feature
[CE-ONLY] TRAIN_SUB: X=(13700, 

In [None]:
# Cell 8c — CE vs Fusion per-query comparison (VAL & TEST) + Recall deltas

import numpy as np
import math

# ---- Preconditions (produced by prior cells) ----
need = [
    "pids_val","offsets_val","ce_scores_val","X_va","y_true_val",
    "pids_test","offsets_test","ce_scores_test","X_te","y_true_test",
    "blender"
]
missing = [n for n in need if n not in globals()]
assert not missing, f"Missing: {missing}. Run the CE+IMG fusion cell first."

# ---- Helper fallbacks (reuse if already defined) ----
try:
    _ranks_from_flat_scores
except NameError:
    def _ranks_from_flat_scores(pid_lists, offsets, flat_scores, true_pids):
        ranks, cur = [], 0
        for qi, pids in enumerate(pid_lists):
            K = len(pids)
            if K == 0:
                ranks.append(10**9); continue
            sl = flat_scores[cur:cur+K]
            order = np.argsort(-sl)
            ranked = [pids[i] for i in order]
            tpid = true_pids[qi]
            r = next((i+1 for i,p in enumerate(ranked) if p==tpid), len(ranked)+1)
            ranks.append(r)
            cur += K
        return ranks

try:
    accuracy_at_k
except NameError:
    def accuracy_at_k(ranks, K: int) -> float:
        n = max(1, len(ranks))
        return sum(1 for r in ranks if r <= K) / n

try:
    ndcg_at_k_single
except NameError:
    def ndcg_at_k_single(ranks, K: int) -> float:
        n = max(1, len(ranks))
        vals = [(1.0/np.log2(1+r)) if r <= K else 0.0 for r in ranks]
        return float(np.mean(vals)) if n else 0.0

# Recall@K == Acc@K in single-relevant-item setting, but we keep a named helper for clarity
def recall_at_k_from_ranks(ranks, K: int) -> float:
    return accuracy_at_k(ranks, K)

# ---- Local helpers ----
def _finite(x):
    return [r for r in x if r < 10**9]

def _rank_summary(name, ranks):
    fr = _finite(ranks)
    if not fr:
        print(f"  {name}: no finite ranks")
        return
    arr = np.asarray(fr, dtype=np.float32)
    print(f"  {name}: mean={arr.mean():.2f}, median={np.median(arr):.1f}, p25={np.percentile(arr,25):.1f}, p75={np.percentile(arr,75):.1f}")

def _rescued_lost_counts(ranks_ce, ranks_fuse, Ks=(1,5,10,50)):
    out = []
    for K in Ks:
        rescued = sum(int((rc > K) and (rf <= K)) for rc, rf in zip(ranks_ce, ranks_fuse))
        lost    = sum(int((rc <= K) and (rf > K)) for rc, rf in zip(ranks_ce, ranks_fuse))
        out.append((K, rescued, lost))
    return out

def _compare_one(split_tag, pids, offsets, ce_scores, X, y_true,
                 ks_acc=(1,3,5,10), ks_ndcg=(5,10,50), ks_recall=(1,5,10,50)):
    # CE-only ranks
    ranks_ce = _ranks_from_flat_scores(pids, offsets, ce_scores, y_true)

    # Fusion ranks (LR on [CE, IMG])
    fuse_scores = blender.predict_proba(X)[:,1].astype(np.float32) if len(X) else np.zeros((0,), np.float32)
    ranks_fuse = _ranks_from_flat_scores(pids, offsets, fuse_scores, y_true)

    n = len(ranks_ce)
    improved = sum(int(rf < rc) for rc, rf in zip(ranks_ce, ranks_fuse))
    hurt     = sum(int(rf > rc) for rc, rf in zip(ranks_ce, ranks_fuse))
    tie      = n - improved - hurt

    print(f"\n[{split_tag}] CE vs Fusion (per-query deltas)")
    print(f"  Improved: {improved} ({improved/n:.1%})   Hurt: {hurt} ({hurt/n:.1%})   Tie: {tie} ({tie/n:.1%})")

    # Rescue/Loss counts at multiple K
    for K, rescued, lost in _rescued_lost_counts(ranks_ce, ranks_fuse, Ks=ks_recall):
        print(f"  Rescued@{K}: {rescued:4d}   Lost@{K}: {lost:4d}")

    # Accuracy deltas
    for K in ks_acc:
        acc_ce, acc_fu = accuracy_at_k(ranks_ce, K), accuracy_at_k(ranks_fuse, K)
        print(f"  Acc@{K}: CE={acc_ce:.4f} → Fuse={acc_fu:.4f} (Δ={acc_fu-acc_ce:+.4f})")

    # Recall deltas (same as Acc@K in this single-label setting, printed explicitly)
    for K in ks_recall:
        rec_ce, rec_fu = recall_at_k_from_ranks(ranks_ce, K), recall_at_k_from_ranks(ranks_fuse, K)
        hits_ce = sum(int(r <= K) for r in ranks_ce)
        hits_fu = sum(int(r <= K) for r in ranks_fuse)
        print(f"  Recall@{K}: CE={rec_ce:.4f} ({hits_ce}/{n}) → Fuse={rec_fu:.4f} ({hits_fu}/{n}) (Δ={rec_fu-rec_ce:+.4f})")

    # NDCG deltas
    for K in ks_ndcg:
        nd_ce, nd_fu = ndcg_at_k_single(ranks_ce, K), ndcg_at_k_single(ranks_fuse, K)
        print(f"  NDCG@{K}: CE={nd_ce:.4f} → Fuse={nd_fu:.4f} (Δ={nd_fu-nd_ce:+.4f})")

    # Rank summaries
    _rank_summary("Ranks (CE)", ranks_ce)
    _rank_summary("Ranks (Fusion)", ranks_fuse)

# ---- Run for VAL and TEST ----
_compare_one("VAL",  pids_val,  offsets_val,  ce_scores_val,  X_va, y_true_val)
_compare_one("TEST", pids_test, offsets_test, ce_scores_test, X_te, y_true_test)



[VAL] CE vs Fusion (per-query deltas)
  Improved: 187 (25.6%)   Hurt: 174 (23.8%)   Tie: 369 (50.5%)
  Rescued@1:   49   Lost@1:   64
  Rescued@5:   36   Lost@5:   46
  Rescued@10:   40   Lost@10:   30
  Rescued@50:    0   Lost@50:    0
  Acc@1: CE=0.4151 → Fuse=0.3945 (Δ=-0.0205)
  Acc@3: CE=0.6000 → Fuse=0.5877 (Δ=-0.0123)
  Acc@5: CE=0.6795 → Fuse=0.6658 (Δ=-0.0137)
  Acc@10: CE=0.7658 → Fuse=0.7795 (Δ=+0.0137)
  Recall@1: CE=0.4151 (303/730) → Fuse=0.3945 (288/730) (Δ=-0.0205)
  Recall@5: CE=0.6795 (496/730) → Fuse=0.6658 (486/730) (Δ=-0.0137)
  Recall@10: CE=0.7658 (559/730) → Fuse=0.7795 (569/730) (Δ=+0.0137)
  Recall@50: CE=0.9041 (660/730) → Fuse=0.9041 (660/730) (Δ=+0.0000)
  NDCG@5: CE=0.5574 → Fuse=0.5410 (Δ=-0.0163)
  NDCG@10: CE=0.5853 → Fuse=0.5778 (Δ=-0.0074)
  NDCG@50: CE=0.6172 → Fuse=0.6073 (Δ=-0.0099)
  Ranks (CE): mean=9.77, median=2.0, p25=1.0, p75=9.0
  Ranks (Fusion): mean=9.41, median=2.0, p25=1.0, p75=9.0

[TEST] CE vs Fusion (per-query deltas)
  Improved: 184

In [None]:
# Cell — E5-only and IMG-only rerankers on PPR@K candidates (keep other params same)

import numpy as np, math, torch

# ---------- Preconditions ----------
need = [
    "corpus", "val_q", "test_q",
    "pids_val", "offsets_val", "y_true_val",
    "pids_test","offsets_test","y_true_test"
]
missing = [n for n in need if n not in globals()]
assert not missing, f"Missing: {missing}. Run your PPR candidate generation & splits first."

# ---------- Metric helpers (reuse if already defined) ----------
try:
    _ranks_from_flat_scores
except NameError:
    def _ranks_from_flat_scores(pid_lists, offsets, flat_scores, true_pids):
        ranks, cur = [], 0
        for qi, pids in enumerate(pid_lists):
            K = len(pids)
            if K == 0:
                ranks.append(10**9); continue
            sl = flat_scores[cur:cur+K]
            order = np.argsort(-sl)
            ranked = [pids[i] for i in order]
            tpid = true_pids[qi]
            r = next((i+1 for i,p in enumerate(ranked) if p==tpid), len(ranked)+1)
            ranks.append(r)
            cur += K
        return ranks

try:
    accuracy_at_k
except NameError:
    def accuracy_at_k(ranks, K: int) -> float:
        n = max(1, len(ranks))
        return sum(1 for r in ranks if r <= K) / n

try:
    ndcg_at_k_single
except NameError:
    def ndcg_at_k_single(ranks, K: int) -> float:
        vals = [(1.0/np.log2(1+r)) if r <= K else 0.0 for r in ranks]
        return float(np.mean(vals)) if ranks else 0.0

def _report(tag, ranks):
    print(f"\n[{tag}] Reranker-only (no CE, no fusion)")
    for K in (1,3,5,10):  print(f"  Accuracy@{K}: {accuracy_at_k(ranks, K):.4f}")
    for K in (1,5,10,50): print(f"  NDCG@{K}:     {ndcg_at_k_single(ranks, K):.4f}")
    for K in (1,5,10,50):
        rec = accuracy_at_k(ranks, K)  # single-label → Recall@K == Acc@K
        print(f"  Recall@{K}:   {rec:.4f}")

# ---------- E5-only reranker ----------
# Use the *same* item E5 embedding style as built in the corpus (no special prefixes)
E5_BATCH = 256
id2row = {pid: i for i, pid in enumerate(corpus.df["id"].astype(str).tolist())}
e5_items = corpus.e5_emb.astype(np.float32)  # already L2-normalized at build time

def _encode_e5(texts):
    return corpus.e5.encode(texts, batch_size=E5_BATCH, normalize_embeddings=True, convert_to_numpy=True).astype(np.float32)

def _gather_e5_sims_from_qemb(q_e5: np.ndarray, pid_lists: list[list[str]]) -> np.ndarray:
    flat = []
    for qi, pids in enumerate(pid_lists):
        qv = q_e5[qi]
        for pid in pids:
            j = id2row.get(pid, None)
            if j is None:
                flat.append(0.0); continue
            flat.append(float(np.dot(qv, e5_items[j])))
    return np.asarray(flat, dtype=np.float32)

# Encode queries (VAL/TEST)
q_e5_val  = _encode_e5(val_q["intent_text"].astype(str).tolist())
q_e5_test = _encode_e5(test_q["intent_text"].astype(str).tolist())

# Flattened E5 scores aligned to candidate lists
e5_scores_val  = _gather_e5_sims_from_qemb(q_e5_val,  pids_val)
e5_scores_test = _gather_e5_sims_from_qemb(q_e5_test, pids_test)

# Ranks
ranks_val_e5  = _ranks_from_flat_scores(pids_val,  offsets_val,  e5_scores_val,  y_true_val)
ranks_test_e5 = _ranks_from_flat_scores(pids_test, offsets_test, e5_scores_test, y_true_test)

print("=== E5-only (cosine on SentenceTransformer embeddings) ===")
_report("VAL",  ranks_val_e5)
_report("TEST", ranks_test_e5)

# ---------- IMG-only reranker (CLIP text–image cosine) ----------
# Use existing helper & encodings if available; otherwise compute q_clip_* now.
try:
    q_clip_val
    q_clip_test
except NameError:
    q_clip_val  = torch.tensor(corpus.enc.encode_text(val_q["intent_text"].astype(str).tolist()),
                               dtype=torch.float32, device=corpus.enc.device)
    q_clip_test = torch.tensor(corpus.enc.encode_text(test_q["intent_text"].astype(str).tolist()),
                               dtype=torch.float32, device=corpus.enc.device)
    q_clip_val  = q_clip_val / (q_clip_val.norm(dim=1, keepdim=True) + 1e-12)
    q_clip_test = q_clip_test / (q_clip_test.norm(dim=1, keepdim=True) + 1e-12)

# If helper exists, reuse; else define it.
try:
    _gather_img_sims_from_qemb
except NameError:
    @torch.no_grad()
    def _gather_img_sims_from_qemb(q_clip: torch.Tensor, pid_lists: list[list[str]]) -> np.ndarray:
        img_mat = torch.from_numpy(corpus.img_emb).to(q_clip.device, dtype=torch.float32)
        img_mat = img_mat / (img_mat.norm(dim=1, keepdim=True) + 1e-12)
        id2row = {pid: i for i, pid in enumerate(corpus.df["id"].astype(str).tolist())}
        flat = []
        for qi, pids in enumerate(pid_lists):
            qv = q_clip[qi].unsqueeze(0)  # 1 x D
            for pid in pids:
                j = id2row.get(pid, None)
                if j is None:
                    flat.append(0.0); continue
                v = float((qv @ img_mat[j:j+1].T).squeeze().cpu().item())
                flat.append(v)
        return np.asarray(flat, dtype=np.float32)

img_scores_val  = _gather_img_sims_from_qemb(q_clip_val,  pids_val)
img_scores_test = _gather_img_sims_from_qemb(q_clip_test, pids_test)

ranks_val_img  = _ranks_from_flat_scores(pids_val,  offsets_val,  img_scores_val,  y_true_val)
ranks_test_img = _ranks_from_flat_scores(pids_test, offsets_test, img_scores_test, y_true_test)

print("\n=== IMG-only (cosine between CLIP text query and catalog image embeddings) ===")
_report("VAL",  ranks_val_img)
_report("TEST", ranks_test_img)


=== E5-only (cosine on SentenceTransformer embeddings) ===

[VAL] Reranker-only (no CE, no fusion)
  Accuracy@1: 0.3589
  Accuracy@3: 0.5438
  Accuracy@5: 0.6342
  Accuracy@10: 0.7507
  NDCG@1:     0.3589
  NDCG@5:     0.5056
  NDCG@10:     0.5431
  NDCG@50:     0.5787
  Recall@1:   0.3589
  Recall@5:   0.6342
  Recall@10:   0.7507
  Recall@50:   0.9041

[TEST] Reranker-only (no CE, no fusion)
  Accuracy@1: 0.3633
  Accuracy@3: 0.5741
  Accuracy@5: 0.6748
  Accuracy@10: 0.7782
  NDCG@1:     0.3633
  NDCG@5:     0.5280
  NDCG@10:     0.5613
  NDCG@50:     0.5945
  Recall@1:   0.3633
  Recall@5:   0.6748
  Recall@10:   0.7782
  Recall@50:   0.9184

=== IMG-only (cosine between CLIP text query and catalog image embeddings) ===

[VAL] Reranker-only (no CE, no fusion)
  Accuracy@1: 0.2000
  Accuracy@3: 0.3589
  Accuracy@5: 0.4658
  Accuracy@10: 0.6082
  NDCG@1:     0.2000
  NDCG@5:     0.3357
  NDCG@10:     0.3822
  NDCG@50:     0.4494
  Recall@1:   0.2000
  Recall@5:   0.4658
  Recall@10: 

In [None]:
# Cell — CLIP-Text-only and CLIP-IMG-only rerankers on existing PPR@K candidates

import numpy as np, math, torch

# ---------- Preconditions ----------
need = [
    "corpus", "val_q", "test_q",
    "pids_val", "offsets_val", "y_true_val",
    "pids_test","offsets_test","y_true_test"
]
missing = [n for n in need if n not in globals()]
assert not missing, f"Missing: {missing}. Run your PPR candidate generation & splits first."

DEVICE_ = corpus.enc.device if hasattr(corpus, "enc") else ("cuda" if torch.cuda.is_available() else "cpu")

# ---------- Metric helpers (reuse if already defined) ----------
try:
    _ranks_from_flat_scores
except NameError:
    def _ranks_from_flat_scores(pid_lists, offsets, flat_scores, true_pids):
        ranks, cur = [], 0
        for qi, pids in enumerate(pid_lists):
            K = len(pids)
            if K == 0:
                ranks.append(10**9); continue
            sl = flat_scores[cur:cur+K]
            order = np.argsort(-sl)
            ranked = [pids[i] for i in order]
            tpid = true_pids[qi]
            r = next((i+1 for i,p in enumerate(ranked) if p==tpid), len(ranked)+1)
            ranks.append(r)
            cur += K
        return ranks

try:
    accuracy_at_k
except NameError:
    def accuracy_at_k(ranks, K: int) -> float:
        n = max(1, len(ranks))
        return sum(1 for r in ranks if r <= K) / n

try:
    ndcg_at_k_single
except NameError:
    def ndcg_at_k_single(ranks, K: int) -> float:
        vals = [(1.0/np.log2(1+r)) if r <= K else 0.0 for r in ranks]
        return float(np.mean(vals)) if ranks else 0.0

def _report(tag, ranks):
    print(f"\n[{tag}] Reranker-only (no CE, no fusion)")
    for K in (1,3,5,10):  print(f"  Accuracy@{K}: {accuracy_at_k(ranks, K):.4f}")
    for K in (1,5,10,50): print(f"  NDCG@{K}:     {ndcg_at_k_single(ranks, K):.4f}")
    for K in (1,5,10,50):
        rec = accuracy_at_k(ranks, K)  # single-label → Recall@K == Acc@K
        print(f"  Recall@{K}:   {rec:.4f}")

# ---------- Shared lookups ----------
id2row = {pid: i for i, pid in enumerate(corpus.df["id"].astype(str).tolist())}

# ===================== CLIP TEXT-ONLY RERANKER =====================
# Use CLIP text encoder for BOTH the query and the product "text blob".
# Catalog text embeddings already exist: corpus.text_emb (L2-normalized in build step).
clip_text_items = torch.from_numpy(corpus.text_emb).to(DEVICE_, dtype=torch.float32)
clip_text_items = clip_text_items / (clip_text_items.norm(dim=1, keepdim=True) + 1e-12)

# Encode queries (VAL/TEST) with CLIP text encoder; reuse if available
try:
    q_clip_text_val
    q_clip_text_test
except NameError:
    q_clip_text_val  = torch.tensor(
        corpus.enc.encode_text(val_q["intent_text"].astype(str).tolist()),
        dtype=torch.float32, device=DEVICE_
    )
    q_clip_text_test = torch.tensor(
        corpus.enc.encode_text(test_q["intent_text"].astype(str).tolist()),
        dtype=torch.float32, device=DEVICE_
    )
    q_clip_text_val  = q_clip_text_val / (q_clip_text_val.norm(dim=1, keepdim=True) + 1e-12)
    q_clip_text_test = q_clip_text_test / (q_clip_text_test.norm(dim=1, keepdim=True) + 1e-12)

@torch.no_grad()
def _gather_clip_text_sims_from_qemb(q_clip_text: torch.Tensor, pid_lists: list[list[str]]) -> np.ndarray:
    flat = []
    for qi, pids in enumerate(pid_lists):
        qv = q_clip_text[qi].unsqueeze(0)  # 1 x D
        for pid in pids:
            j = id2row.get(pid, None)
            if j is None:
                flat.append(0.0); continue
            sim = float((qv @ clip_text_items[j:j+1].T).squeeze().detach().cpu().item())
            flat.append(sim)
    return np.asarray(flat, dtype=np.float32)

clip_text_scores_val  = _gather_clip_text_sims_from_qemb(q_clip_text_val,  pids_val)
clip_text_scores_test = _gather_clip_text_sims_from_qemb(q_clip_text_test, pids_test)

ranks_val_clip_text  = _ranks_from_flat_scores(pids_val,  offsets_val,  clip_text_scores_val,  y_true_val)
ranks_test_clip_text = _ranks_from_flat_scores(pids_test, offsets_test, clip_text_scores_test, y_true_test)

print("=== CLIP-Text-only (cosine: CLIP text query ↔ CLIP text product) ===")
_report("VAL",  ranks_val_clip_text)
_report("TEST", ranks_test_clip_text)

# ===================== CLIP IMAGE-ONLY RERANKER =====================
# Cosine between CLIP text query and catalog image embeddings.
# Reuse q_clip_text_* above (same CLIP text encoder), and corpus.img_emb (already L2-normalized in build).
clip_img_items = torch.from_numpy(corpus.img_emb).to(DEVICE_, dtype=torch.float32)
clip_img_items = clip_img_items / (clip_img_items.norm(dim=1, keepdim=True) + 1e-12)

@torch.no_grad()
def _gather_clip_img_sims_from_qemb(q_clip_text: torch.Tensor, pid_lists: list[list[str]]) -> np.ndarray:
    flat = []
    for qi, pids in enumerate(pid_lists):
        qv = q_clip_text[qi].unsqueeze(0)  # 1 x D
        for pid in pids:
            j = id2row.get(pid, None)
            if j is None:
                flat.append(0.0); continue
            sim = float((qv @ clip_img_items[j:j+1].T).squeeze().detach().cpu().item())
            flat.append(sim)
    return np.asarray(flat, dtype=np.float32)

clip_img_scores_val  = _gather_clip_img_sims_from_qemb(q_clip_text_val,  pids_val)
clip_img_scores_test = _gather_clip_img_sims_from_qemb(q_clip_text_test, pids_test)

ranks_val_clip_img  = _ranks_from_flat_scores(pids_val,  offsets_val,  clip_img_scores_val,  y_true_val)
ranks_test_clip_img = _ranks_from_flat_scores(pids_test, offsets_test, clip_img_scores_test, y_true_test)

print("\n=== CLIP-IMG-only (cosine: CLIP text query ↔ CLIP image product) ===")
_report("VAL",  ranks_val_clip_img)
_report("TEST", ranks_test_clip_img)


=== CLIP-Text-only (cosine: CLIP text query ↔ CLIP text product) ===

[VAL] Reranker-only (no CE, no fusion)
  Accuracy@1: 0.1466
  Accuracy@3: 0.2986
  Accuracy@5: 0.3959
  Accuracy@10: 0.5425
  NDCG@1:     0.1466
  NDCG@5:     0.2734
  NDCG@10:     0.3200
  NDCG@50:     0.4024
  Recall@1:   0.1466
  Recall@5:   0.3959
  Recall@10:   0.5425
  Recall@50:   0.9041

[TEST] Reranker-only (no CE, no fusion)
  Accuracy@1: 0.1714
  Accuracy@3: 0.2898
  Accuracy@5: 0.3741
  Accuracy@10: 0.5306
  NDCG@1:     0.1714
  NDCG@5:     0.2755
  NDCG@10:     0.3259
  NDCG@50:     0.4103
  Recall@1:   0.1714
  Recall@5:   0.3741
  Recall@10:   0.5306
  Recall@50:   0.9184

=== CLIP-IMG-only (cosine: CLIP text query ↔ CLIP image product) ===

[VAL] Reranker-only (no CE, no fusion)
  Accuracy@1: 0.2000
  Accuracy@3: 0.3589
  Accuracy@5: 0.4658
  Accuracy@10: 0.6082
  NDCG@1:     0.2000
  NDCG@5:     0.3357
  NDCG@10:     0.3822
  NDCG@50:     0.4494
  Recall@1:   0.2000
  Recall@5:   0.4658
  Recall@10: 

## Analysis


In [None]:
# Fixed: CE vs Fusion per-query comparison (no illegal variable names)

import numpy as np

# ---- Preconditions ----
need = ["pids_val","offsets_val","ce_scores_val","X_va","y_true_val",
        "pids_test","offsets_test","ce_scores_test","X_te","y_true_test","blender"]
for n in need:
    assert n in globals(), f"Missing {n}. Run previous cells first."

# Fallbacks if helpers weren't defined
try:
    _ranks_from_flat_scores
except NameError:
    def _ranks_from_flat_scores(pid_lists, offsets, flat_scores, true_pids):
        ranks, cur = [], 0
        for qi, pids in enumerate(pid_lists):
            K = len(pids)
            if K == 0:
                ranks.append(10**9); continue
            sl = flat_scores[cur:cur+K]
            order = np.argsort(-sl)
            ranked = [pids[i] for i in order]
            tpid = true_pids[qi]
            r = next((i+1 for i,p in enumerate(ranked) if p==tpid), len(ranked)+1)
            ranks.append(r)
            cur += K
        return ranks

try:
    accuracy_at_k
except NameError:
    def accuracy_at_k(ranks, K: int) -> float:
        n = max(1, len(ranks))
        return sum(1 for r in ranks if r <= K) / n

try:
    ndcg_at_k_single
except NameError:
    def ndcg_at_k_single(ranks, K: int) -> float:
        n = max(1, len(ranks))
        vals = [(1.0/np.log2(1+r)) if r <= K else 0.0 for r in ranks]
        return float(np.mean(vals)) if n else 0.0

def _compare_one(split_tag, pids, offsets, ce_scores, X, y_true, ks=(1,3,5,10)):
    # CE-only ranks
    ranks_ce = _ranks_from_flat_scores(pids, offsets, ce_scores, y_true)
    # Fusion ranks
    fuse_scores = blender.predict_proba(X)[:,1].astype(np.float32) if len(X) else np.zeros((0,), np.float32)
    ranks_fuse = _ranks_from_flat_scores(pids, offsets, fuse_scores, y_true)

    n = len(ranks_ce)
    improved = sum(int(rf < rc) for rc, rf in zip(ranks_ce, ranks_fuse))
    hurt     = sum(int(rf > rc) for rc, rf in zip(ranks_ce, ranks_fuse))
    tie      = n - improved - hurt

    rescued_at_10 = sum(int((rc > 10) and (rf <= 10)) for rc, rf in zip(ranks_ce, ranks_fuse))
    lost_at_10    = sum(int((rc <= 10) and (rf > 10)) for rc, rf in zip(ranks_ce, ranks_fuse))

    print(f"\n[{split_tag}] CE vs Fusion (per-query deltas)")
    print(f"  Improved: {improved}   Hurt: {hurt}   Tie: {tie}")
    print(f"  Rescued@10: {rescued_at_10}   Lost@10: {lost_at_10}")

    for K in ks:
        acc_ce, acc_fu = accuracy_at_k(ranks_ce, K), accuracy_at_k(ranks_fuse, K)
        print(f"  Acc@{K}: CE={acc_ce:.4f} → Fuse={acc_fu:.4f} (Δ={acc_fu-acc_ce:+.4f})")
    for K in (5,10,50):
        nd_ce, nd_fu = ndcg_at_k_single(ranks_ce, K), ndcg_at_k_single(ranks_fuse, K)
        print(f"  NDCG@{K}: CE={nd_ce:.4f} → Fuse={nd_fu:.4f} (Δ={nd_fu-nd_ce:+.4f})")

# ---- Run for VAL and TEST ----
_compare_one("VAL",  pids_val,  offsets_val,  ce_scores_val,  X_va, y_true_val)
_compare_one("TEST", pids_test, offsets_test, ce_scores_test, X_te, y_true_test)



[VAL] CE vs Fusion (per-query deltas)
  Improved: 237   Hurt: 198   Tie: 295
  Rescued@10: 35   Lost@10: 38
  Acc@1: CE=0.3671 → Fuse=0.3329 (Δ=-0.0342)
  Acc@3: CE=0.5205 → Fuse=0.5247 (Δ=+0.0041)
  Acc@5: CE=0.6068 → Fuse=0.6000 (Δ=-0.0068)
  Acc@10: CE=0.7068 → Fuse=0.7027 (Δ=-0.0041)
  NDCG@5: CE=0.4920 → Fuse=0.4760 (Δ=-0.0160)
  NDCG@10: CE=0.5243 → Fuse=0.5094 (Δ=-0.0149)
  NDCG@50: CE=0.5562 → Fuse=0.5513 (Δ=-0.0049)

[TEST] CE vs Fusion (per-query deltas)
  Improved: 236   Hurt: 197   Tie: 302
  Rescued@10: 49   Lost@10: 41
  Acc@1: CE=0.3537 → Fuse=0.3537 (Δ=+0.0000)
  Acc@3: CE=0.5088 → Fuse=0.5238 (Δ=+0.0150)
  Acc@5: CE=0.5986 → Fuse=0.6027 (Δ=+0.0041)
  Acc@10: CE=0.6857 → Fuse=0.6966 (Δ=+0.0109)
  NDCG@5: CE=0.4808 → Fuse=0.4861 (Δ=+0.0053)
  NDCG@10: CE=0.5089 → Fuse=0.5166 (Δ=+0.0078)
  NDCG@50: CE=0.5533 → Fuse=0.5600 (Δ=+0.0066)


In [None]:
# Cell 8c — Error attribution @10 across stages: PPR vs Rerank (CE) vs Fusion (CE+IMG)

import numpy as np

# --------- Preconditions ----------
need = ["pids_val","pids_test","offsets_val","offsets_test",
        "ce_scores_val","ce_scores_test","X_va","X_te",
        "blender","y_true_val","y_true_test","TOPK_PPR"]
for n in need:
    assert n in globals(), f"Missing {n}. Run previous cells first."

def _slice_scores(flat, offsets, pid_lists, qi):
    s = offsets[qi]; e = s + len(pid_lists[qi])
    return flat[s:e]

def _attr_report(tag, pid_lists, offsets, ce_scores, X, truths, topk_display=10):
    # Fusion scores (CE+IMG)
    fuse_scores = (blender.predict_proba(X)[:,1].astype(np.float32) if len(X) else np.zeros((0,), dtype=np.float32))

    nQ = len(truths)
    dropped_final = 0
    drop_ppr = 0
    drop_rerank = 0
    drop_fusion = 0

    rescued_by_fusion = 0   # CE miss @10 -> Fusion hit @10
    neutral_hit = 0         # CE hit @10 and Fusion hit @10
    both_miss = 0           # both miss @10

    for qi in range(nQ):
        pids = pid_lists[qi]
        tpid = truths[qi]

        # Stage flags
        in_ppr = tpid in pids

        # CE-only ranking for this query
        ce_slice = _slice_scores(ce_scores, offsets, pid_lists, qi)
        ce_order_idx = np.argsort(-ce_slice)
        ce_topk = [pids[i] for i in ce_order_idx[:topk_display]]
        ce_hit = tpid in ce_topk

        # Fusion ranking for this query
        fu_slice = _slice_scores(fuse_scores, offsets, pid_lists, qi)
        fu_order_idx = np.argsort(-fu_slice)
        fu_topk = [pids[i] for i in fu_order_idx[:topk_display]]
        fu_hit = tpid in fu_topk

        # Counters for improvement/neutral/miss comparison
        if (not ce_hit) and fu_hit:
            rescued_by_fusion += 1
        elif ce_hit and fu_hit:
            neutral_hit += 1
        elif (not ce_hit) and (not fu_hit):
            both_miss += 1

        # Attribute final errors (Acc@10 miss after fusion)
        if not fu_hit:
            dropped_final += 1
            if not in_ppr:
                drop_ppr += 1
            elif not ce_hit:
                drop_rerank += 1
            else:
                drop_fusion += 1

    def pct(x):
        return (100.0 * x / max(1, dropped_final))

    print(f"\n[{tag}] Acc@10 error attribution over {nQ} queries (TOPK_PPR={TOPK_PPR})")
    print(f"  Final misses (Fusion Acc@10=0): {dropped_final}")
    print(f"    • Dropped at PPR (truth not in candidates): {drop_ppr}  ({pct(drop_ppr):.1f}%)")
    print(f"    • Dropped at Rerank (in PPR, CE@10 miss):  {drop_rerank}  ({pct(drop_rerank):.1f}%)")
    print(f"    • Dropped at Fusion (CE@10 hit → out@10):  {drop_fusion}  ({pct(drop_fusion):.1f}%)")

    print(f"\n[{tag}] CE ↔ Fusion comparison @10 (all queries)")
    print(f"  Rescued by Fusion (CE miss → Fusion hit): {rescued_by_fusion}")
    print(f"  Neutral hits       (CE hit  ∧ Fusion hit): {neutral_hit}")
    print(f"  Both miss          (CE miss ∧ Fusion miss): {both_miss}")

# ---- Run for VAL and TEST ----
_attr_report("VAL",  pids_val,  offsets_val,  ce_scores_val,  X_va, y_true_val,  topk_display=10)
_attr_report("TEST", pids_test, offsets_test, ce_scores_test, X_te, y_true_test, topk_display=10)



[VAL] Acc@10 error attribution over 730 queries (TOPK_PPR=200)
  Final misses (Fusion Acc@10=0): 217
    • Dropped at PPR (truth not in candidates): 50  (23.0%)
    • Dropped at Rerank (in PPR, CE@10 miss):  129  (59.4%)
    • Dropped at Fusion (CE@10 hit → out@10):  38  (17.5%)

[VAL] CE ↔ Fusion comparison @10 (all queries)
  Rescued by Fusion (CE miss → Fusion hit): 35
  Neutral hits       (CE hit  ∧ Fusion hit): 478
  Both miss          (CE miss ∧ Fusion miss): 179

[TEST] Acc@10 error attribution over 735 queries (TOPK_PPR=200)
  Final misses (Fusion Acc@10=0): 223
    • Dropped at PPR (truth not in candidates): 40  (17.9%)
    • Dropped at Rerank (in PPR, CE@10 miss):  142  (63.7%)
    • Dropped at Fusion (CE@10 hit → out@10):  41  (18.4%)

[TEST] CE ↔ Fusion comparison @10 (all queries)
  Rescued by Fusion (CE miss → Fusion hit): 49
  Neutral hits       (CE hit  ∧ Fusion hit): 463
  Both miss          (CE miss ∧ Fusion miss): 182


In [22]:
# Cell 8d — Deeper error analysis for CE+IMG fusion (VAL & TEST)
import numpy as np

# ---- Preconditions ----
need = [
    "pids_val", "offsets_val", "ce_scores_val", "img_sims_val", "X_va", "y_true_val",
    "pids_test", "offsets_test", "ce_scores_test", "img_sims_test", "X_te", "y_true_test",
    "blender"
]
missing = [n for n in need if n not in globals()]
assert not missing, f"Missing: {missing}. Run the CE+IMG fusion cell first."

# Reuse ranks helper if available; otherwise define a local one
try:
    _ranks_from_flat_scores
except NameError:
    def _ranks_from_flat_scores(pid_lists, offsets, flat_scores, true_pids):
        ranks, cur = [], 0
        for qi, pids in enumerate(pid_lists):
            K = len(pids)
            if K == 0:
                ranks.append(10**9); continue
            sl = flat_scores[cur:cur+K]
            order = np.argsort(-sl)
            ranked = [pids[i] for i in order]
            tpid = true_pids[qi]
            r = next((i+1 for i,p in enumerate(ranked) if p==tpid), len(ranked)+1)
            ranks.append(r)
            cur += K
        return ranks

def _group_stats(split_tag, pids, offsets, ce_scores, img_sims, X, y_true, K=10):
    """Break down where fusion helps/hurts and inspect CE/IMG signals on the TRUE item."""
    print(f"\n=== {split_tag}: CE+IMG fusion error analysis @K={K} ===")

    # CE-only ranks
    ranks_ce = _ranks_from_flat_scores(pids, offsets, ce_scores, y_true)
    # Fusion ranks (LR(CE, IMG))
    fuse_scores = blender.predict_proba(X)[:, 1].astype(np.float32) if len(X) else np.zeros((0,), np.float32)
    ranks_fu = _ranks_from_flat_scores(pids, offsets, fuse_scores, y_true)

    n = len(ranks_ce)
    groups = {
        "rescued": {"mask": [], "ce_rank": [], "fu_rank": [], "ce_true": [], "img_true": [], "fu_true": []},
        "lost":    {"mask": [], "ce_rank": [], "fu_rank": [], "ce_true": [], "img_true": [], "fu_true": []},
        "both_hit":  {"mask": [], "ce_rank": [], "fu_rank": [], "ce_true": [], "img_true": [], "fu_true": []},
        "both_miss": {"mask": [], "ce_rank": [], "fu_rank": [], "ce_true": [], "img_true": [], "fu_true": []},
    }

    cur = 0
    for qi, pids_q in enumerate(pids):
        Kq = len(pids_q)
        ce_r = ranks_ce[qi]
        fu_r = ranks_fu[qi]

        # Determine group at top-K
        if ce_r > K and fu_r <= K:
            g = "rescued"
        elif ce_r <= K and fu_r > K:
            g = "lost"
        elif ce_r <= K and fu_r <= K:
            g = "both_hit"
        else:
            g = "both_miss"

        # Locate true item within this query's slice
        tpid = y_true[qi]
        try:
            local_idx = pids_q.index(tpid)
        except ValueError:
            cur += Kq
            continue  # should not happen, but be safe

        sl_ce  = ce_scores[cur:cur+Kq]
        sl_img = img_sims[cur:cur+Kq]
        sl_fu  = fuse_scores[cur:cur+Kq]

        ce_true  = sl_ce[local_idx]
        img_true = sl_img[local_idx]
        fu_true  = sl_fu[local_idx]

        groups[g]["mask"].append(qi)
        groups[g]["ce_rank"].append(ce_r)
        groups[g]["fu_rank"].append(fu_r)
        groups[g]["ce_true"].append(ce_true)
        groups[g]["img_true"].append(img_true)
        groups[g]["fu_true"].append(fu_true)

        cur += Kq

    for name, stats in groups.items():
        m = len(stats["mask"])
        if m == 0:
            print(f"\n[{name}] no queries")
            continue
        frac = m / max(1, n)
        ce_r  = np.array(stats["ce_rank"], dtype=np.float32)
        fu_r  = np.array(stats["fu_rank"], dtype=np.float32)
        ce_t  = np.array(stats["ce_true"], dtype=np.float32)
        img_t = np.array(stats["img_true"], dtype=np.float32)
        fu_t  = np.array(stats["fu_true"], dtype=np.float32)

        print(f"\n[{name}] {m}/{n} queries ({frac:.1%})")
        print(f"  CE rank (true):   mean={ce_r.mean():.2f}, median={np.median(ce_r):.1f}")
        print(f"  Fusion rank:      mean={fu_r.mean():.2f}, median={np.median(fu_r):.1f}")
        print(f"  CE score (true):  mean={ce_t.mean():.3f}, std={ce_t.std():.3f}")
        print(f"  IMG score (true): mean={img_t.mean():.3f}, std={img_t.std():.3f}")
        print(f"  Fusion score:     mean={fu_t.mean():.3f}, std={fu_t.std():.3f}")

    # Optional: simple correlation between CE and IMG features on positives
    all_ce_true  = np.concatenate([np.array(v["ce_true"])  for v in groups.values() if v["ce_true"]])
    all_img_true = np.concatenate([np.array(v["img_true"]) for v in groups.values() if v["img_true"]])
    if all_ce_true.size and all_img_true.size:
        corr = np.corrcoef(all_ce_true, all_img_true)[0,1]
        print(f"\n[positives] Corr(CE_true, IMG_true) = {corr:.3f}")

# Run for VAL and TEST
_group_stats("VAL",  pids_val,  offsets_val,  ce_scores_val,  img_sims_val,  X_va, y_true_val,  K=10)
_group_stats("TEST", pids_test, offsets_test, ce_scores_test, img_sims_test, X_te, y_true_test, K=10)



=== VAL: CE+IMG fusion error analysis @K=10 ===

[rescued] 40/730 queries (5.5%)
  CE rank (true):   mean=17.02, median=14.0
  Fusion rank:      mean=6.25, median=7.0
  CE score (true):  mean=-6.226, std=4.913
  IMG score (true): mean=0.249, std=0.023
  Fusion score:     mean=0.511, std=0.245

[lost] 30/730 queries (4.1%)
  CE rank (true):   mean=5.67, median=6.0
  Fusion rank:      mean=16.10, median=13.5
  CE score (true):  mean=0.677, std=5.094
  IMG score (true): mean=0.210, std=0.025
  Fusion score:     mean=0.659, std=0.216

[both_hit] 529/730 queries (72.5%)
  CE rank (true):   mean=2.24, median=1.0
  Fusion rank:      mean=2.38, median=1.0
  CE score (true):  mean=1.598, std=4.419
  IMG score (true): mean=0.244, std=0.039
  Fusion score:     mean=0.771, std=0.189

[both_miss] 61/730 queries (8.4%)
  CE rank (true):   mean=25.02, median=21.0
  Fusion rank:      mean=21.44, median=20.0
  CE score (true):  mean=-5.128, std=5.082
  IMG score (true): mean=0.234, std=0.034
  Fusion 