# Food / Diet / Nutrition RAG (Multi-Source)
### (TXT + PDF) → Ingestion → Chunking → Index (Chroma) → Query Improve → Hybrid Retrieval → Cross-Encoder Rerank → Self-Check → Evaluation → UI
.
**Supported source types**
- **TXT** (`.txt`)
- **PDF** (`.pdf`) — extracted with PyMuPDF (fallback: pdfplumber)
- **ODF** (`.odt`, `.ods`, `.odp`) — extracted with odfpy

**Folder convention** (recommended)
- `data/txt/` for TXT
- `data/pdf/` for PDFs
- `data/odf/` for ODF documents

You can also put files directly under `data/` and the loader will still find them.

## STEP 0 — Setup (PyCharm-friendly)
Ensures paths work even if the notebook is stored under `notebooks/`.

In [None]:

import os
from pathlib import Path
import importlib.util
import sys

# If the kernel was started inside a folder named "notebooks" (any letter case),
# move the working directory one level up so that relative paths such as
# `data/` are resolved against the parent folder.
cwd = Path.cwd()
if cwd.name.lower() == "notebooks":
    os.chdir(cwd.parent)

# Show the directory that relative paths will resolve against from now on.
print("CWD:", Path.cwd())

## STEP 1 — Install & Imports
Install dependencies (run once). The notebook has fallbacks when optional libs are missing.

In [None]:
# Imports (+ optional install only-if-missing)
# Works in Jupyter and is safe for a git repo: won't reinstall packages if already present.

import sys, importlib.util, subprocess

# --- list your packages here (pip name -> import name) ---
# Keys are the names handed to `pip install`; values are the module names that
# are probed with `is_installed` to decide whether the package is present.
REQUIRED = {
    "sentence-transformers": "sentence_transformers",
    "rank-bm25": "rank_bm25",
    "chromadb": "chromadb",
    "pyspellchecker": "spellchecker",
    "rapidfuzz": "rapidfuzz",
    "ipywidgets": "ipywidgets",
    "nltk": "nltk",
    "pymupdf": "fitz",          # PyMuPDF import is fitz
    "pdfplumber": "pdfplumber",
    "odfpy": "odf",
}

def is_installed(import_name: str) -> bool:
    """Tell whether the import system can locate ``import_name``.

    The module is only looked up via ``importlib.util.find_spec``; it is not
    imported by this check.
    """
    spec = importlib.util.find_spec(import_name)
    return spec is not None

# Pip names of every REQUIRED entry whose import name cannot be located.
missing = [pkg for pkg, module_name in REQUIRED.items() if not is_installed(module_name)]

print("Python executable:", sys.executable)
print("Missing packages:", missing or "None ")

# Set this to True only if you want the notebook to auto-install missing deps.
AUTO_INSTALL = False

if AUTO_INSTALL and missing:
    print("\nInstalling missing packages (same interpreter as notebook)...")
    # Use the running interpreter's pip so packages land in the kernel's environment.
    pip_cmd = [sys.executable, "-m", "pip", "install", "-U"] + missing
    subprocess.check_call(pip_cmd)
    print(" Install done.\n")

# ---- Now import everything (safe: will either work or clearly show what's missing) ----
import re, json, math, time, hashlib
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional
from pathlib import Path
import numpy as np

# Embeddings & reranker
try:
    from sentence_transformers import SentenceTransformer, CrossEncoder
except Exception:
    SentenceTransformer = None
    CrossEncoder = None

# Keyword retrieval
try:
    from rank_bm25 import BM25Okapi
except Exception:
    BM25Okapi = None

# Query improve utilities
try:
    from spellchecker import SpellChecker
except Exception:
    SpellChecker = None

try:
    from rapidfuzz import fuzz
except Exception:
    fuzz = None

# Vector DB persistence
try:
    import chromadb
    from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
except Exception:
    chromadb = None
    SentenceTransformerEmbeddingFunction = None

# UI
try:
    import ipywidgets as widgets
    from IPython.display import display, Markdown, clear_output
except Exception:
    widgets = None

# PDF extraction
try:
    import fitz  # PyMuPDF
except Exception:
    fitz = None

try:
    import pdfplumber
except Exception:
    pdfplumber = None

# ODF extraction
try:
    from odf.opendocument import load as odf_load
    from odf import text as odf_text, teletype as odf_teletype
except Exception:
    odf_load = None
    odf_text = None
    odf_teletype = None

# Summarise which optional components were importable above.
print("\n=== Availability ===")
for _label, _component in (
    ("SentenceTransformer", SentenceTransformer),
    ("CrossEncoder", CrossEncoder),
    ("BM25", BM25Okapi),
    ("Chroma", chromadb),
    ("ipywidgets", widgets),
    ("PyMuPDF", fitz),
    ("pdfplumber", pdfplumber),
    ("odfpy", odf_load),
):
    print(f"{_label}:", bool(_component))

# When auto-install is off, tell the user how to install what is missing.
if not AUTO_INSTALL and missing:
    for _hint in (
        "\n To install dependencies, either:",
        "1) set AUTO_INSTALL=True and re-run this cell, OR",
        "2) run in terminal (recommended):",
    ):
        print(_hint)
    print("   python -m pip install -U " + " ".join(missing))


## STEP 2 — Data Sources (separated by source type)
We keep files separate by source type (TXT/PDF/ODF) but also support searching under `data/`.

In [None]:
from dataclasses import dataclass
from typing import List

# Root data folder and the per-source-type sub-folders scanned by discover_files().
DATA_DIR = Path("data")
TXT_DIR = DATA_DIR / "txt"
PDF_DIR = DATA_DIR / "pdf"
ODF_DIR = DATA_DIR / "odf"

# Source type -> file extensions (lower-case, with leading dot) assigned to it.
SUPPORTED = {
    "txt": [".txt"],
    "pdf": [".pdf"],
    "odf": [".odt", ".ods", ".odp"],
}

@dataclass
class DocFile:
    """One discovered input document, as produced by discover_files()."""

    # One of the SUPPORTED keys: "txt", "pdf" or "odf".
    source_type: str
    # Path to the file, stored as a string.
    path: str
    # File name including its extension (no directory part).
    name: str

def discover_files() -> List[DocFile]:
    """Collect every supported document below the data folders.

    The type-specific folders are walked first and ``DATA_DIR`` last, so files
    placed directly under ``data/`` are still picked up. A file is classified
    purely by its (lower-cased) extension, and each resolved path is reported
    only once even though the roots overlap.

    Returns:
        ``DocFile`` entries ordered by ``(source_type, name)``.
    """
    # Extension -> type lookup; filled in reverse priority so that txt wins
    # over pdf, and pdf over odf, should an extension ever be listed twice.
    ext_to_type = {}
    for typ in ("odf", "pdf", "txt"):
        for ext in SUPPORTED[typ]:
            ext_to_type[ext] = typ

    found: List[DocFile] = []
    visited = set()
    for root in (TXT_DIR, PDF_DIR, ODF_DIR, DATA_DIR):
        if not root.exists():
            continue
        for candidate in root.rglob("*"):
            typ = ext_to_type.get(candidate.suffix.lower())
            if typ is None or not candidate.is_file():
                continue

            resolved = str(candidate.resolve())
            if resolved in visited:
                continue
            visited.add(resolved)
            found.append(DocFile(source_type=typ, path=str(candidate), name=candidate.name))

    found.sort(key=lambda d: (d.source_type, d.name))
    return found

# Discover the input documents once; later cells read `doc_files`.
doc_files = discover_files()
print("Total files:", len(doc_files))
# Preview at most the first 20 entries.
for f in doc_files[:20]:
    print(f"- [{f.source_type}] {f.path}")


## STEP 3 — Extraction (TXT / PDF / ODF)
Each source type is extracted into raw text.

- **TXT:** UTF-8 decode
- **PDF:** PyMuPDF (best-effort), fallback to pdfplumber
- **ODF:** odfpy (text of the paragraph elements)

In [None]:

def read_txt(fp: Path) -> str:
    """Return the file's text decoded as UTF-8.

    Bytes that are not valid UTF-8 are replaced instead of raising.
    """
    with fp.open("r", encoding="utf-8", errors="replace") as handle:
        return handle.read()

def read_pdf_pymupdf(fp: Path) -> str:
    """Extract the plain text of every page of a PDF with PyMuPDF.

    Best-effort extraction (no OCR). If your PDFs are scanned images, you need
    OCR (not included here).

    Args:
        fp: Path of the PDF file.

    Returns:
        The page texts joined with newlines, in page order.

    Raises:
        RuntimeError: If PyMuPDF (``fitz``) could not be imported.
    """
    if fitz is None:
        raise RuntimeError("PyMuPDF not available")
    doc = fitz.open(fp)
    try:
        return "\n".join(page.get_text("text") for page in doc)
    finally:
        # Release the file handle even when a page fails to extract.
        doc.close()

def read_pdf_pdfplumber(fp: Path) -> str:
    """Extract the text of every page of a PDF with pdfplumber.

    Pages for which ``extract_text()`` returns nothing contribute an empty
    string. Raises ``RuntimeError`` when pdfplumber could not be imported.
    """
    if pdfplumber is None:
        raise RuntimeError("pdfplumber not available")
    with pdfplumber.open(fp) as pdf:
        page_texts = [(page.extract_text() or "") for page in pdf.pages]
    return "\n".join(page_texts)

def read_odf(fp: Path) -> str:
    """Extract text from an OpenDocument file with odfpy (best-effort).

    Only ``text:p`` paragraph elements are read; their texts are joined with
    newlines. Raises ``RuntimeError`` when odfpy could not be imported.
    """
    if odf_load is None:
        raise RuntimeError("odfpy not available")
    document = odf_load(str(fp))
    paragraphs = document.getElementsByType(odf_text.P)
    return "\n".join([odf_teletype.extractText(node) for node in paragraphs])

def extract_text(doc: DocFile) -> str:
    """Return the raw text of ``doc`` using the reader for its source type.

    PDFs are read with PyMuPDF when it is installed, otherwise with pdfplumber.

    Raises:
        RuntimeError: For a PDF when neither PDF library is available.
        ValueError: When ``doc.source_type`` is not txt, pdf or odf.
    """
    fp = Path(doc.path)
    kind = doc.source_type

    if kind == "txt":
        return read_txt(fp)
    if kind == "odf":
        return read_odf(fp)
    if kind != "pdf":
        raise ValueError("Unknown source_type: " + kind)

    # PDF: pick the first extractor that was importable.
    if fitz is not None:
        return read_pdf_pymupdf(fp)
    if pdfplumber is not None:
        return read_pdf_pdfplumber(fp)
    raise RuntimeError("No PDF extractor available (install pymupdf or pdfplumber).")

# quick smoke test: extract the first discovered file and preview its text
if doc_files:
    t = extract_text(doc_files[0])
    print("First file:", doc_files[0].name, "| type:", doc_files[0].source_type)
    print("Chars:", len(t))
    # Only the first 300 characters are shown.
    print(t[:300], "...")
else:
    print("No files found under data/. Add files under data/txt, data/pdf, data/odf.")

## STEP 4 — Cleaning + Metadata + Chunking
We normalize text, detect rough structure (paragraphs/sentences), then create overlapping chunks.

Metadata stored per chunk:
- `source_type`, `source_file`, `doc_id`
- `topic` (food/nutrition/unknown), `language`
- optional: `reliability`, `date`, `section`

In [None]:

def detect_language_light(text: str) -> str:
    """Return "he" if ``text`` has any character in U+0590-U+05FF, else "en"."""
    hebrew_hit = re.search(r"[\u0590-\u05FF]", text)
    if hebrew_hit:
        return "he"
    return "en"

def clean_text(s: str) -> str:
    """Normalise raw extracted text before chunking.

    Steps, in order: unify line endings to ``\\n``; collapse runs of spaces and
    tabs to one space; collapse three or more newlines to a blank line; replace
    curly double/single quotes with their ASCII forms; strip the ends.

    The quote classes are written as Unicode escapes so they match exactly
    U+201C/U+201D and U+2018/U+2019 regardless of how this file is encoded.
    """
    s = s.replace("\r\n", "\n").replace("\r", "\n")
    s = re.sub(r"[ \t]+", " ", s)
    s = re.sub(r"\n{3,}", "\n\n", s)
    s = re.sub("[\u201c\u201d]", '"', s)
    s = re.sub("[\u2018\u2019]", "'", s)
    return s.strip()

def split_paragraphs(s: str) -> List[str]:
    """Split on blank lines and return the non-empty, stripped paragraphs."""
    paragraphs: List[str] = []
    for block in s.split("\n\n"):
        trimmed = block.strip()
        if trimmed:
            paragraphs.append(trimmed)
    return paragraphs

def split_sentences_best_effort(s: str) -> List[str]:
    """Split text after ``.``, ``!`` or ``?`` when followed by whitespace.

    Purely punctuation-based; empty fragments are dropped.
    """
    fragments = (frag.strip() for frag in re.split(r"(?<=[.!?])\s+", s.strip()))
    return [frag for frag in fragments if frag]

def guess_topic(text: str, filename: str) -> str:
    """Label a document "nutrition", "food" or "unknown" via keyword matching.

    Matching is case-insensitive substring search on the file name and on the
    text. Nutrition is checked first, so it wins when both topics match.
    """
    body, fname = text.lower(), filename.lower()
    nutrition_terms = ["vitamin","calorie","protein","fiber","sodium","cholesterol","diet","nutrition","omega","macro","micronutrient"]
    food_terms = ["ingredients","recipe","bake","cook","boil","meal","serving","mix","stir"]

    def mentions(haystack, needles):
        return any(needle in haystack for needle in needles)

    if mentions(fname, ["nutrition","diet","health"]) or mentions(body, nutrition_terms):
        return "nutrition"
    if mentions(fname, ["recipe","cook","meal"]) or mentions(body, food_terms):
        return "food"
    return "unknown"

# Chunking parameters passed to recursive_chunk() by ingest_all():
# the character budget per chunk and how many trailing characters of the
# previous chunk are prepended to the next one.
CHUNK_MAX_CHARS = 1200
CHUNK_OVERLAP_CHARS = 200

@dataclass
class Chunk:
    """One indexed text piece of a document, as built by ingest_all()."""

    # Identifier built from doc id, source type, paragraph/piece index and a text hash.
    chunk_id: str
    # File name of the source document without its extension.
    doc_id: str
    # "txt", "pdf" or "odf".
    source_type: str
    # Path of the source file (string).
    source_file: str
    # The chunk text itself.
    text: str
    # Per-chunk metadata (source_type, source_file, topic, language, ...).
    meta: Dict

def stable_id(s: str) -> str:
    """Return the first 10 hex digits of the SHA-1 of ``s`` (UTF-8 encoded).

    Characters that cannot be encoded are skipped rather than raising.
    """
    digest = hashlib.sha1()
    digest.update(s.encode("utf-8", errors="ignore"))
    return digest.hexdigest()[:10]

def _split_oversized(text: str, max_chars: int) -> List[str]:
    """Cut ``text`` into pieces of at most ``max_chars`` chars, preferring to break at a space."""
    pieces = []
    rest = text
    while len(rest) > max_chars:
        cut = rest.rfind(" ", 0, max_chars + 1)
        if cut <= 0:
            cut = max_chars  # no usable space in the window: hard cut
        pieces.append(rest[:cut].rstrip())
        rest = rest[cut:].lstrip()
    if rest:
        pieces.append(rest)
    return pieces


def recursive_chunk(paragraph: str, max_chars: int, overlap: int) -> List[str]:
    """Split a paragraph into sentence-aligned chunks with optional overlap.

    A paragraph that already fits is returned unchanged. Otherwise sentences
    are packed greedily into chunks of at most ``max_chars`` characters. A
    single sentence longer than ``max_chars`` (for example a table row or a
    list without sentence punctuation) is first cut into ``max_chars``-sized
    pieces, so it can no longer produce one oversized chunk.

    Args:
        paragraph: Text to split.
        max_chars: Character budget per chunk before overlap is added.
        overlap: Number of trailing characters of the previous chunk that are
            prepended to every chunk after the first (0 disables overlap).

    Returns:
        The chunks in reading order. With overlap, chunks after the first can
        be up to ``overlap + 1`` characters longer than ``max_chars``.
    """
    if len(paragraph) <= max_chars:
        return [paragraph]

    chunks, cur = [], ""
    for sent in split_sentences_best_effort(paragraph):
        if max_chars > 0 and len(sent) > max_chars:
            units = _split_oversized(sent, max_chars)
        else:
            units = [sent]
        for unit in units:
            if len(cur) + len(unit) + 1 <= max_chars:
                cur = (cur + " " + unit).strip()
            else:
                if cur:
                    chunks.append(cur)
                cur = unit
    if cur:
        chunks.append(cur)

    if overlap > 0 and len(chunks) > 1:
        out, prev = [], ""
        for c in chunks:
            # Prefix each chunk with the tail of the chunk before it.
            out.append((prev[-overlap:] + " " + c).strip() if prev else c)
            prev = c
        return out
    return chunks

def ingest_all(doc_files: List[DocFile]) -> List[Chunk]:
    """Turn every document into ``Chunk`` objects.

    For each file: extract text, clean it, derive document-level language and
    topic, split into paragraphs and chunk every paragraph with
    ``CHUNK_MAX_CHARS`` / ``CHUNK_OVERLAP_CHARS``. The chunk id combines the
    document id, source type, paragraph index, piece index and a hash of the
    piece text.
    """
    collected: List[Chunk] = []
    for doc in doc_files:
        cleaned = clean_text(extract_text(doc))
        doc_id = Path(doc.name).stem
        lang = detect_language_light(cleaned)
        topic = guess_topic(cleaned, doc.name)

        for para_idx, para in enumerate(split_paragraphs(cleaned)):
            pieces = recursive_chunk(para, CHUNK_MAX_CHARS, CHUNK_OVERLAP_CHARS)
            for piece_idx, piece in enumerate(pieces):
                # Language and topic are per document; each chunk gets its own dict.
                meta = {
                    "source_type": doc.source_type,
                    "source_file": doc.path,
                    "topic": topic,
                    "language": lang,
                    "reliability": "unknown",
                    "date": None,
                    "section": None
                }
                chunk = Chunk(
                    chunk_id=f"{doc_id}::{doc.source_type}::p{para_idx}::c{piece_idx}::{stable_id(piece)}",
                    doc_id=doc_id,
                    source_type=doc.source_type,
                    source_file=doc.path,
                    text=piece,
                    meta=meta,
                )
                collected.append(chunk)
    return collected

# Build the chunk list once; the indexing and retrieval cells read `chunks`.
chunks = ingest_all(doc_files)
print("Chunks:", len(chunks))
if chunks:
    # Show the id and the first 220 characters of the first chunk.
    print("Example:", chunks[0].chunk_id)
    print(chunks[0].text[:220], "...")

## STEP 5 — Indexing (Chroma persistence + fallback)
We index chunks in a persistent vector DB (Chroma). If Chroma is unavailable, we compute an in-memory embedding matrix.

In [None]:
# Embedding model name and the folder passed to chromadb.PersistentClient.
EMB_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
CHROMA_DIR = "chroma_food_rag"

# Embeddings are required from here on; stop early with a clear message.
if SentenceTransformer is None:
    raise RuntimeError("Install sentence-transformers to build embeddings.")

def chroma_safe_metadata(meta: dict) -> dict:
    """
    Return a copy of ``meta`` that Chroma accepts.

    Chroma metadata values must be only: bool | int | float | str.
    ``None`` values are dropped; any other non-primitive value is stringified.
    """
    primitives = (bool, int, float, str)
    return {
        key: (value if isinstance(value, primitives) else str(value))
        for key, value in meta.items()
        if value is not None
    }
# Embedding model for the in-memory fallback (only loaded when there is data).
emb_model = SentenceTransformer(EMB_MODEL_NAME) if chunks else None

# --- Persistent index (Chroma) -------------------------------------------
# Everything that touches `chroma_col` lives inside this `if`, so a missing
# Chroma install or an empty corpus skips straight to the fallback below
# instead of failing on undefined names / a None collection.
chroma_col = None
if chromadb is not None and SentenceTransformerEmbeddingFunction is not None and chunks:
    client = chromadb.PersistentClient(path=CHROMA_DIR)
    emb_fn = SentenceTransformerEmbeddingFunction(model_name=EMB_MODEL_NAME)
    chroma_col = client.get_or_create_collection(
        name="food_rag_chunks_multisource",
        embedding_function=emb_fn,
        metadata={"hnsw:space": "cosine"}
    )

    # Ids already stored; if they cannot be listed, treat the collection as empty.
    try:
        existing = set(chroma_col.get(include=[]).get("ids", []))
    except Exception:
        existing = set()

    new_ids, new_docs, new_metas = [], [], []
    for c in chunks:
        if c.chunk_id in existing:
            continue
        # Remember the id so a duplicate chunk id is not queued twice.
        existing.add(c.chunk_id)
        new_ids.append(c.chunk_id)
        new_docs.append(c.text)
        new_metas.append(chroma_safe_metadata(c.meta))

    if new_ids:
        # Add in slices of the same size chroma_sync_to_current() uses by default.
        ADD_BATCH_SIZE = 5000
        for start in range(0, len(new_ids), ADD_BATCH_SIZE):
            stop = start + ADD_BATCH_SIZE
            chroma_col.add(
                ids=new_ids[start:stop],
                documents=new_docs[start:stop],
                metadatas=new_metas[start:stop],
            )
        print(f"Added {len(new_ids)} new chunks to Chroma")
    else:
        print("Chroma already up to date")

# --- In-memory fallback --------------------------------------------------
# Without a Chroma collection, embed every chunk into one normalised matrix.
vecs = None
if chroma_col is None and emb_model is not None and chunks:
    vecs = emb_model.encode([c.text for c in chunks], normalize_embeddings=True, show_progress_bar=True).astype(np.float32)
    print("vecs:", vecs.shape)

In [None]:
#Chroma sync (make DB exactly match current chunks)

from pathlib import Path
import math

def chroma_sync_to_current(chroma_col, chunks, batch_size: int = 5000):
    """Make the Chroma collection hold exactly the ids of ``chunks``.

    Ids stored in Chroma but absent from ``chunks`` are deleted; ids present in
    ``chunks`` but not stored are added. Both operations run in slices of
    ``batch_size``. Progress and a final count are printed.

    Args:
        chroma_col: Chroma collection, or None (then only a message is printed).
        chunks: Current chunk objects (``chunk_id``, ``text`` and ``meta`` are read).
        batch_size: Maximum number of ids per delete/add call.
    """
    if chroma_col is None:
        print("chroma_col is None")
        return

    # current ids from Python ingestion
    current_ids = [c.chunk_id for c in chunks]
    current_set = set(current_ids)

    # stored ids in Chroma
    stored = chroma_col.get(include=[])  # ids only
    stored_ids = stored.get("ids", [])
    stored_set = set(stored_ids)

    to_delete = sorted(stored_set - current_set)

    # index chunks by id for quick lookup
    by_id = {c.chunk_id: c for c in chunks}
    to_add_ids = sorted(current_set - stored_set)
    to_add_docs = [by_id[cid].text for cid in to_add_ids]
    # IMPORTANT: Chroma metadata cannot contain None or non-primitive values.
    # Use the same sanitiser as the initial indexing cell so both write paths agree.
    to_add_metas = [chroma_safe_metadata(by_id[cid].meta) for cid in to_add_ids]

    print("Current chunks:", len(current_ids))
    print("Stored in Chroma:", len(stored_ids))
    print("Will delete:", len(to_delete))
    print("Will add:", len(to_add_ids))

    # delete extra ids
    if to_delete:
        for i in range(0, len(to_delete), batch_size):
            chroma_col.delete(ids=to_delete[i:i+batch_size])
        print("Deleted extras")

    # add missing
    if to_add_ids:
        for i in range(0, len(to_add_ids), batch_size):
            chroma_col.add(
                ids=to_add_ids[i:i+batch_size],
                documents=to_add_docs[i:i+batch_size],
                metadatas=to_add_metas[i:i+batch_size],
            )
        print("Added missing")

    # verify: Chroma stores each id once, so compare against the distinct ids
    new_count = len(chroma_col.get(include=[]).get("ids", []))
    print("Final Chroma count:", new_count)
    print("Should match current:", len(current_set))

# RUN IT: bring the collection in line with the chunks ingested in this session
chroma_sync_to_current(chroma_col, chunks)


In [None]:
# Verify latest files: if "Chunks in Python" > "Chunks in Chroma", the index is
# not actually up to date (usually because chunk_ids changed).
print("Chunks in Python:", len(chunks))

if chroma_col is not None:
    # Count the ids currently stored in the collection.
    count = len(chroma_col.get(include=[]).get("ids", []))
    print("Chunks in Chroma:", count)

## STEP 6 — Hybrid Retrieval (BM25 + embeddings) + Query Improvement
We:
1) create query variants
2) retrieve using embeddings + BM25
3) fuse/deduplicate results

This step improves recall without exploding the query:
1) Normalize (lowercase, units)
2) Domain-safe spell correction (lightweight)
3) Synonym expansion using a real vocabulary (MeSH Entry Terms)
4) Produce 2–4 short query variants and fuse results

In [None]:
# --- Make sure vocab  exists (load MeSH cache if available) ---
from pathlib import Path
import json

# JSON cache of the MeSH synonym map (phrase -> list of alternative phrases).
SYN_CACHE = Path("vocab/mesh_synonyms_cache.json")

# Keep an already-defined `syn_vocab` (e.g. from an earlier run of another
# cell); otherwise load the cache, or fall back to an empty mapping.
if "syn_vocab" not in globals():
    if SYN_CACHE.exists():
        syn_vocab = json.loads(SYN_CACHE.read_text("utf-8"))
        print(f"syn_vocab loaded from cache: {len(syn_vocab):,} entries")
    else:
        syn_vocab = {}   # fallback (no MeSH synonyms)
        print("syn_vocab not found (no cache). Using empty synonyms.")


In [None]:
# BM25 + Query Improve (spell + MeSH synonyms)

import re
from typing import List, Dict, Optional, Tuple

# --- BM25 ---
def tokenize(text: str) -> List[str]:
    """Lower-case ``text`` and return its runs of word characters (``\\w+``)."""
    lowered = text.lower()
    return [match.group(0) for match in re.finditer(r"\w+", lowered)]

# Keyword index over the tokenised chunk texts; stays None when rank_bm25 is
# missing or there are no chunks. Row i of the index corresponds to chunks[i].
bm25 = None
if BM25Okapi is not None and chunks:
    bm25 = BM25Okapi([tokenize(c.text) for c in chunks])
    print("BM25 ready")
else:
    print("BM25 not available")

# --- Query normalization ---
# Tokens that spell_fix_domain_safe() must never "correct" (domain acronyms,
# nutrient names and unit abbreviations).
PROTECTED_TERMS = {
    "bmi","bp","ldl","hdl","vit","omega","omega-3","dha","epa",
    "sodium","potassium","calcium","magnesium","iron","zinc","protein","fiber",
    "carbs","kcal","mg","g","kg","ml","iu"
}

# Spelled-out mass units and their abbreviations.
_UNIT_ABBREVIATIONS = {"kilogram": "kg", "milligram": "mg", "gram": "g"}
# A unit word (optionally plural) that is not glued to other letters, so words
# that merely contain "gram" ("program", "diagram", "grammar") are left alone.
# Digits may touch the unit ("30grams").
_UNIT_PATTERN = re.compile(r"(?<![a-z])(kilogram|milligram|gram)s?(?![a-z])")


def normalize_query(q: str) -> str:
    """Lower-case a query, collapse whitespace and abbreviate mass units.

    ``kilogram(s)`` -> ``kg``, ``milligram(s)`` -> ``mg``, ``gram(s)`` -> ``g``.
    Only whole unit words are rewritten; other words are never modified.
    """
    q = re.sub(r"\s+", " ", q.strip().lower())
    return _UNIT_PATTERN.sub(lambda m: _UNIT_ABBREVIATIONS[m.group(1)], q)

def meaningful_tokens(q: str) -> List[str]:
    """Return the tokens of ``q`` that are longer than one character."""
    kept: List[str] = []
    for tok in tokenize(q):
        if len(tok) > 1:
            kept.append(tok)
    return kept

# --- Spell correction (safe) ---
def spell_fix_domain_safe(q: str) -> str:
    """Normalise ``q`` and conservatively spell-correct its words.

    A token is left untouched when it contains anything other than ``a-z``
    (digits, hyphens, punctuation), is at most two characters long, or is one
    of ``PROTECTED_TERMS``. A proposed correction is discarded when rapidfuzz
    is available and rates it below 80 against the original token. Without
    pyspellchecker the normalised query is returned as is.
    """
    q = normalize_query(q)
    if SpellChecker is None:
        return q

    checker = SpellChecker()
    fixed = []
    for tok in q.split():
        untouchable = (
            re.search(r"[^a-z]", tok)      # keep tokens with digits/punct (omega-3, 2g, etc.)
            or len(tok) <= 2
            or tok in PROTECTED_TERMS
        )
        if untouchable:
            fixed.append(tok)
            continue

        suggestion = checker.correction(tok) or tok
        # accept only if close enough (avoid meaning changes)
        too_different = fuzz is not None and fuzz.ratio(tok, suggestion) < 80
        fixed.append(tok if too_different else suggestion)

    return " ".join(fixed)

# --- MeSH matching: generate n-gram keys from query to look up in syn_vocab ---
def candidate_keys_from_query(q: str, max_ngram: int = 3) -> List[str]:
    """List the word n-grams of the normalised query, longest n first.

    N-grams of size ``max_ngram`` down to 1 are produced left to right;
    duplicates are removed while keeping the first occurrence.
    """
    toks = normalize_query(q).split()
    grams = [
        " ".join(toks[start:start + size])
        for size in range(max_ngram, 0, -1)
        for start in range(len(toks) - size + 1)
    ]
    # dict keys keep insertion order, which gives an order-preserving de-dupe.
    return list(dict.fromkeys(grams))

# --- Filter bad variants---
# Regexes for inverted "pressure, blood"-style phrases that make bad queries.
BAD_PATTERNS = [r"\bpressure,\s*blood\b", r"\bhigh pressure,\s*blood\b"]
def is_good_variant(v: str) -> bool:
    """Accept a query variant unless it hits a bad pattern or exceeds 16 words."""
    lowered = v.strip().lower()
    for pattern in BAD_PATTERNS:
        if re.search(pattern, lowered):
            return False
    return len(lowered.split()) <= 16

# --- ONE expand_query function ---
def expand_query(q: str, max_alts: int = 4, syns_per_match: int = 2) -> List[str]:
    """Produce up to ``max_alts`` query variants (the corrected query first).

    Variants come from (1) swapping query n-grams for synonyms found in
    ``syn_vocab`` and (2) expanding the acronym "bp" to "blood pressure".
    Replacements are done on whole tokens only, so a key or acronym is never
    rewritten inside a longer word (e.g. "subpar" is left alone).

    Args:
        q: Raw user query.
        max_alts: Maximum number of variants returned.
        syns_per_match: Synonyms tried per matched n-gram.

    Returns:
        De-duplicated, normalised variants that pass ``is_good_variant``.
    """
    base = spell_fix_domain_safe(q)
    alts = [base]

    def swap_phrase(text: str, phrase: str, replacement: str) -> str:
        # Match the phrase only when it is not attached to other word characters.
        pattern = r"(?<!\w)" + re.escape(phrase) + r"(?!\w)"
        # A callable replacement keeps backslashes in `replacement` literal.
        return re.sub(pattern, lambda _m: replacement, text)

    # 1) Replace matched phrases with MeSH synonyms
    for key in candidate_keys_from_query(base, max_ngram=3):
        syns = syn_vocab.get(key, [])[:syns_per_match] if syn_vocab else []
        for s in syns:
            cand = normalize_query(swap_phrase(base, key, s))
            if cand != base and is_good_variant(cand) and len(meaningful_tokens(cand)) >= 2:
                alts.append(cand)
        if len(alts) >= max_alts:
            break

    # 2) Acronym expansion: "bp" as a standalone token -> "blood pressure"
    has_bp_token = re.search(r"(?<!\w)bp(?!\w)", base) is not None
    if has_bp_token and "blood pressure" not in base and len(alts) < max_alts:
        alts.append(swap_phrase(base, "bp", "blood pressure"))

    # dedupe + cap
    uniq, seen = [], set()
    for a in alts:
        a = normalize_query(a)
        if a and a not in seen and is_good_variant(a):
            uniq.append(a)
            seen.add(a)
    return uniq[:max_alts]

# --- Retrieval helpers ---
def minmax_norm(vals: List[float]) -> List[float]:
    """Rescale ``vals`` linearly to [0, 1].

    An empty input gives an empty list; when all values are (almost) equal
    every entry becomes 1.0.
    """
    if not vals:
        return []
    lo, hi = min(vals), max(vals)
    if hi - lo < 1e-9:
        return [1.0] * len(vals)
    return [(v - lo) / (hi - lo) for v in vals]

def retrieve_chroma(q: str, top_k: int = 10, where: Optional[Dict] = None) -> List[Dict]:
    """Query the Chroma collection and return one dict per hit.

    Each hit has ``chunk_id``, ``text``, ``meta`` and ``score``, where the
    score is ``1 - distance``. Returns an empty list when no collection exists.
    """
    if chroma_col is None:
        return []

    res = chroma_col.query(
        query_texts=[q],
        n_results=top_k,
        where=where,
        include=["documents", "metadatas", "distances"],  # <-- DO NOT include "ids"
    )
    # One query text was sent, so take the first (only) row of each field.
    ids, docs, metas, dists = (res[field][0] for field in ("ids", "documents", "metadatas", "distances"))
    return [
        {"chunk_id": cid, "text": doc, "meta": meta, "score": 1.0 - float(dist)}
        for cid, doc, meta, dist in zip(ids, docs, metas, dists)
    ]

def retrieve_embed_fallback(q: str, top_k: int = 10) -> List[Tuple[int, float]]:
    """Rank chunks against ``q`` using the in-memory embedding matrix.

    Returns ``(row index into vecs, dot-product similarity)`` pairs, best
    first, or an empty list when the matrix or the model is missing.
    """
    if emb_model is None or vecs is None:
        return []
    encoded = emb_model.encode([q], normalize_embeddings=True, show_progress_bar=False)
    query_vec = encoded.astype(np.float32)[0]
    sims = vecs @ query_vec
    best = np.argsort(-sims)[:top_k]
    return [(int(row), float(sims[row])) for row in best]

def retrieve_bm25(q: str, top_k: int = 10) -> List[Tuple[int, float]]:
    """Rank chunks against ``q`` with BM25.

    Returns ``(index into the BM25 corpus, score)`` pairs, best first, or an
    empty list when no BM25 index was built.
    """
    if bm25 is None:
        return []
    query_tokens = tokenize(q)
    scores = bm25.get_scores(query_tokens)
    best = np.argsort(-scores)[:top_k]
    return [(int(row), float(scores[row])) for row in best]

# chunk_id -> Chunk lookup, built once from the chunks present when this cell runs.
by_id = {c.chunk_id: c for c in chunks}

def fusion_retrieve(question: str, fused_topk: int = 25, per_query_k: int = 12, where: Optional[Dict]=None):
    """Retrieve candidate chunks with embeddings + BM25 over several query variants.

    For every variant from ``expand_query`` the embedding hits (Chroma when a
    collection exists, otherwise the in-memory matrix) and the BM25 hits are
    scored, and each chunk keeps the highest weighted score it received.

    Args:
        question: User question.
        fused_topk: Maximum number of fused candidates kept.
        per_query_k: Hits requested from each retriever per variant.
        where: Metadata filter forwarded to ``retrieve_chroma`` only.

    Returns:
        ``(results, variants)``: result dicts sorted by fused score
        (descending) and the list of query variants that were used.
    """
    variants = expand_query(question)
    # chunk_id -> best weighted score seen so far
    fused: Dict[str, float] = {}

    for q in variants:
        # embeddings (weight 0.6)
        if chroma_col is not None:
            for r in retrieve_chroma(q, top_k=per_query_k, where=where):
                fused[r["chunk_id"]] = max(fused.get(r["chunk_id"], 0.0), 0.6 * float(r["score"]))
        else:
            emb = retrieve_embed_fallback(q, top_k=per_query_k)
            # fallback similarities are min-max normalised per variant
            emb_scores = minmax_norm([s for _, s in emb])
            for (i, _), sc in zip(emb, emb_scores):
                fused[chunks[i].chunk_id] = max(fused.get(chunks[i].chunk_id, 0.0), 0.6 * sc)

        # bm25 (weight 0.4, min-max normalised per variant)
        # NOTE(review): `where` is not applied to BM25 hits, so filtered-out
        # chunks can still enter through this branch; confirm that is intended.
        bm = retrieve_bm25(q, top_k=per_query_k)
        bm_scores = minmax_norm([s for _, s in bm])
        for (i, _), sc in zip(bm, bm_scores):
            # NOTE(review): scores are combined with max(), not summed, so a chunk
            # found by both retrievers is not boosted; confirm that is intended.
            fused[chunks[i].chunk_id] = max(fused.get(chunks[i].chunk_id, 0.0), 0.4 * sc)

    ranked = sorted(fused.items(), key=lambda x: x[1], reverse=True)[:fused_topk]
    results = []
    for cid, sc in ranked:
        # ids returned by Chroma that are unknown to `by_id` are skipped
        c = by_id.get(cid)
        if not c:
            continue
        results.append({
            "chunk_id": c.chunk_id,
            "text": c.text,
            "meta": c.meta,
            "source_file": c.source_file,
            "source_type": c.source_type,
            "score": float(sc),
        })
    return results, variants

# Demo
# Run one example query through the hybrid retriever and show what came back.
q_demo = "foods for high blood pressure"
retrieved, variants = fusion_retrieve(q_demo, fused_topk=25)
print("Variants:", variants)
# "\u2014" is an em dash, shown when nothing was retrieved (written as an
# escape so the placeholder cannot be garbled by a wrong file encoding).
print("Top result:", retrieved[0]["chunk_id"] if retrieved else "\u2014")


## STEP 7 — Cross-Encoder Reranking
Reranking usually gives the biggest precision boost by scoring (question, chunk_text) pairs directly.

In [None]:

# Cross-encoder model name; `reranker` stays None when sentence-transformers
# (and therefore CrossEncoder) could not be imported.
RERANK_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
reranker = CrossEncoder(RERANK_MODEL) if CrossEncoder is not None else None

def cross_encoder_rerank(question: str, retrieved: List[Dict], top_n: int = 10) -> List[Dict]:
    """Re-order retrieved chunks by cross-encoder relevance to ``question``.

    Args:
        question: User question.
        retrieved: Candidate dicts; each must carry a ``"text"`` entry.
        top_n: Number of results to return.

    Returns:
        Copies of the best ``top_n`` candidates with an added
        ``"rerank_score"``, sorted best first. Without a reranker the first
        ``top_n`` candidates are returned unchanged. An empty candidate list
        yields an empty list without calling the model.
    """
    # Nothing to score: skip the model call, which is not reliable on an empty batch.
    if not retrieved:
        return []
    if reranker is None:
        return retrieved[:top_n]

    pairs = [(question, r["text"]) for r in retrieved]
    scores = reranker.predict(pairs)
    # Copy each candidate so the caller's dicts are not mutated.
    scored = [{**r, "rerank_score": float(s)} for r, s in zip(retrieved, scores)]
    scored.sort(key=lambda r: r["rerank_score"], reverse=True)
    return scored[:top_n]

# Fetch a wider candidate pool than in the retrieval demo, then rerank it.
retrieved, variants = fusion_retrieve(q_demo, fused_topk=40)   # more candidates
reranked = cross_encoder_rerank(q_demo, retrieved, top_n=10)

# Show the three best chunks; fall back to the fused score when no rerank score exists.
for r in reranked[:3]:
    print(r["chunk_id"], "| type:", r["source_type"], "| rerank:", round(r.get("rerank_score", r["score"]), 3))

## STEP 8 — Augmentation (prompt with citations)
We build a strict prompt:
- answer only from context
- if missing: "I don't know"
- end with citations (chunk IDs)

In [None]:

def build_context(chunks_list: List[Dict], max_chunks: int = 6, max_chars_each: int = 900) -> str:
    """Render the first ``max_chunks`` chunks as a context block for the prompt.

    Every chunk becomes a header line ``[chunk_id] (type=..., source=<file name>)``
    followed by its text, cut to ``max_chars_each`` characters. Entries are
    separated by a blank line.
    """
    entries = []
    for row in chunks_list[:max_chunks]:
        file_name = Path(row['source_file']).name
        header = f"[{row['chunk_id']}] (type={row['source_type']}, source={file_name})"
        excerpt = row['text'][:max_chars_each]
        entries.append(header + "\n" + excerpt)
    return "\n\n".join(entries)

def build_prompt(question: str, chunks_list: List[Dict]) -> str:
    """Compose the answer prompt: instructions, the question, the context, the rules.

    The context is whatever ``build_context`` renders for ``chunks_list`` with
    its default limits.
    """
    evidence = build_context(chunks_list)
    return f"""You are a helpful assistant.
Answer the question ONLY using the provided context.
If the context does not contain the answer, say: "I don't know from the provided documents."

Question:
{question}

Context:
{evidence}

Rules:
- Use ONLY context facts.
- End with: Citations: [chunk_id1, chunk_id2, ...]
"""

# Build the prompt for the demo question and preview its first 800 characters.
prompt = build_prompt(q_demo, reranked)
print(prompt[:800], "...")

## STEP 9 — Generation + Self-Correction (verifier)
What this step does:
 1) Generates a DRAFT answer from retrieved context using Ollama
 2) Runs a VERIFIER pass that removes unsupported claims and enforces citations
 3) If Ollama is not running, it DOES NOT crash — it switches to retrieval-only mode

### Requirements
- Ollama installed and running.
- A model pulled, e.g. `ollama pull llama3.1:8b` (or a smaller one if your laptop is weak: `ollama pull llama3.2:3b`).
- Python package: `pip install ollama`

Notes:
- The verifier is essential for RAG: it reduces unsupported claims.
- If you have weak hardware, switch to a smaller model:    OLLAMA_MODEL = "llama3.2:3b"

In [None]:

import requests, re
from pathlib import Path
from typing import List, Dict, Optional

# Base URL of the Ollama HTTP API and the model name sent with every chat request.
OLLAMA_HOST = "http://127.0.0.1:11434"
OLLAMA_MODEL = "llama3.2:3b"

def ensure_ollama_ready_http(model: str) -> bool:
    """Check that Ollama answers on ``OLLAMA_HOST`` and lists ``model``.

    Queries ``/api/tags``. Prints what is wrong (server unreachable or model
    not pulled) and returns False in that case; returns True when the model
    name appears in the list. Never raises.
    """
    try:
        response = requests.get(f"{OLLAMA_HOST}/api/tags", timeout=5)
        response.raise_for_status()

        models = []
        for entry in response.json().get("models", []):
            name = entry.get("name")
            if name:
                models.append(name)

        if model in models:
            print(" Ollama reachable | Model available:", model)
            return True

        print(f"Model '{model}' not found.")
        print(f"Run once in terminal:  ollama pull {model}")
        print("Available models (first 15):", models[:15])
        return False
    except Exception as e:
        print(" Ollama not reachable:", e)
        print("Install + run Ollama: https://ollama.com/download")
        return False

def ollama_chat(prompt: str, temperature: float = 0.2) -> str:
    """Send ``prompt`` to Ollama's ``/api/chat`` endpoint and return the reply text.

    Uses ``OLLAMA_MODEL`` with a fixed system message and a non-streaming
    request. HTTP errors are raised by ``raise_for_status``; a reply without
    content yields an empty string.
    """
    conversation = [
        {"role": "system", "content": "You strictly follow instructions and cite sources."},
        {"role": "user", "content": prompt},
    ]
    body = {
        "model": OLLAMA_MODEL,
        "messages": conversation,
        "stream": False,
        "options": {"temperature": float(temperature)},
    }
    response = requests.post(f"{OLLAMA_HOST}/api/chat", json=body, timeout=180)
    response.raise_for_status()
    message = response.json().get("message", {})
    return (message.get("content") or "").strip()


In [None]:

# ---- Context building ----
def build_context(chunks_list: List[Dict], max_chunks: int = 6, max_chars_each: int = 900) -> str:
    """Render the first ``max_chunks`` chunks as a context block, tolerating missing fields.

    Only ``chunk_id`` is required. A missing source file is shown as
    ``unknown``, a missing source type as ``?`` and a missing text as empty.
    Texts are cut to ``max_chars_each`` characters; entries are separated by a
    blank line.
    """
    entries = []
    for row in chunks_list[:max_chunks]:
        source_file = row.get("source_file")
        src = Path(source_file).name if source_file else "unknown"
        source_type = row.get('source_type', '?')
        excerpt = (row.get('text', '') or '')[:max_chars_each]
        entries.append(f"[{row['chunk_id']}] (type={source_type}, source={src})\n{excerpt}")
    return "\n\n".join(entries)

def print_top_chunks(chunks_list: List[Dict], n: int = 3, chars: int = 220):
    """Print a short preview (id, score, start of text) of the first ``n`` chunks.

    The rerank score is shown when present, otherwise the fused score
    (0 if neither exists). Newlines in the text are flattened to spaces.
    """
    print("\n--- TOP CHUNKS (preview) ---")
    for row in chunks_list[:n]:
        shown_score = row.get('rerank_score', row.get('score', 0))
        flat_text = (row.get("text") or "").replace("\n", " ")
        print(f"* {row.get('chunk_id')} | score={shown_score:.3f}")
        print(" ", flat_text[:chars], "...\n")

# ---- Prompt (forces citations per bullet) ----
def build_prompt(question: str, chunks_list: List[Dict]) -> str:
    """Compose the strict, citation-enforcing answer prompt.

    The first six chunks are rendered as context (up to 1100 characters each)
    and their ids are listed as the only citations the model may use.

    Args:
        question: User question.
        chunks_list: Retrieved chunk dicts; each needs a ``"chunk_id"``.

    Returns:
        The full prompt text.
    """
    # One limit for both the context and the allowed-id list keeps them in sync.
    max_chunks = 6
    context = build_context(chunks_list, max_chunks=max_chunks, max_chars_each=1100)
    allowed_ids = [r["chunk_id"] for r in chunks_list[:max_chunks]]

    # The bullet range is written with a plain hyphen so the prompt stays ASCII.
    return f"""You are a careful nutrition assistant in a RAG system.

Question:
{question}

Context (ONLY source of truth):
{context}

STRICT RULES:
- Use ONLY facts explicitly stated in the context.
- If the context does not contain the answer, say exactly:
I don't know from the provided documents.
- If you write bullets, EACH bullet MUST end with ONE citation like: [FULL_CHUNK_ID]
- You may cite ONLY these chunk IDs:
{allowed_ids}
- Final line must be: Citations: [id1, id2, ...]

Output format:
1) 1 short answer sentence
2) 0-5 bullets (only if supported)
3) Final line: Citations: [...]
"""

# ---- Draft generation ----
def generate_answer(question: str, chunks_list: List[Dict]) -> str:
    """Ask the model for a draft answer grounded in ``chunks_list``.

    Without any chunks the fixed "I don't know" answer is returned and the
    model is not called.
    """
    if chunks_list:
        prompt = build_prompt(question, chunks_list)
        return ollama_chat(prompt, temperature=0.2)
    return "I don't know from the provided documents.\nCitations: []"

# ---- Verifier (keeps only supported + cited bullets) ----
def build_verifier_prompt(question: str, draft: str, chunks_list: List[Dict]) -> str:
    """Compose the prompt for the verification pass over a draft answer.

    The first eight chunks are rendered as context (up to 1100 characters
    each) and their ids form the list of citations the verifier may keep.
    """
    verified_limit = 8
    evidence = build_context(chunks_list, max_chunks=verified_limit, max_chars_each=1100)
    citable_ids = [row["chunk_id"] for row in chunks_list[:verified_limit]]

    return f"""You are a verifier for a RAG system.

Question:
{question}

Context:
{evidence}

Draft answer:
{draft}

Rules:
1) Remove ANY claim not explicitly supported by the context.
2) Remove any bullet that does not end with: [FULL_CHUNK_ID]
3) Only allow citations from this list:
{citable_ids}
4) If not enough info, output exactly:
I don't know from the provided documents.
Citations: []

Return ONLY the final corrected answer (must end with Citations: [...]).
"""

def verifier_pass(question: str, draft: str, chunks_list: List[Dict]) -> str:
    """Run the verifier over ``draft`` and return the corrected answer.

    Uses temperature 0.0. Without any chunks the fixed "I don't know" answer
    is returned and the model is not called.
    """
    if chunks_list:
        prompt = build_verifier_prompt(question, draft, chunks_list)
        return ollama_chat(prompt, temperature=0.0)
    return "I don't know from the provided documents.\nCitations: []"

# ---- End-to-end Step 9 ----
def rag_answer(
    question: str,
    fused_topk: int = 60,
    per_query_k: int = 25,
    rerank_top_n: int = 10,
    verify: bool = True,
    where: Optional[Dict] = None,
    debug: bool = True,
):
    """Answer ``question`` end to end: retrieve, rerank, generate, verify.

    When Ollama or the model is not available the function does not raise.
    It runs in retrieval-only mode: retrieval and reranking still happen, the
    top chunks are printed, and no text is generated.

    Args:
        question: User question.
        fused_topk: Candidates kept after hybrid fusion.
        per_query_k: Hits requested per retriever and query variant.
        rerank_top_n: Chunks kept after reranking.
        verify: Run the verifier pass over the draft.
        where: Metadata filter forwarded to ``fusion_retrieve``.
        debug: Print the query variants and a preview of the top chunks.

    Returns:
        Dict with ``question``, ``variants``, ``top_chunks`` (chunk ids),
        ``draft``, ``final`` and ``mode`` (``"rag"`` or ``"retrieval_only"``).
        In retrieval-only mode ``draft`` is empty and ``final`` is a notice.
    """
    llm_ready = ensure_ollama_ready_http(OLLAMA_MODEL)

    retrieved, variants = fusion_retrieve(question, fused_topk=fused_topk, per_query_k=per_query_k, where=where)

    try:
        reranked = cross_encoder_rerank(question, retrieved, top_n=rerank_top_n)
    except Exception as exc:
        # Reranking is an optimisation: report the failure and keep the fused order.
        print("Reranker failed, using fused order instead:", exc)
        reranked = retrieved[:rerank_top_n]

    # In retrieval-only mode the chunk preview is the only useful output, so always show it.
    if debug or not llm_ready:
        print("Variants:", variants)
        print_top_chunks(reranked, n=3)

    if llm_ready:
        mode = "rag"
        draft = generate_answer(question, reranked)
        final = verifier_pass(question, draft, reranked) if verify else draft
    else:
        mode = "retrieval_only"
        draft = ""
        final = "(retrieval-only mode: Ollama is not available, so no answer was generated; see top_chunks)"

    print("\n--- DRAFT ---\n", draft)
    print("\n--- FINAL ---\n", final)

    return {
        "question": question,
        "variants": variants,
        "top_chunks": [r["chunk_id"] for r in reranked],
        "draft": draft,
        "final": final,
        "mode": mode,
    }

# Demo:
# _ = rag_answer("foods for high blood pressure", verify=True, debug=True)

In [None]:
# Step-by-step run of the pipeline for one question (same stages as rag_answer,
# but without the Ollama readiness check).
q_demo = "foods for high blood pressure"

retrieved, variants = fusion_retrieve(q_demo, fused_topk=60, per_query_k=25)
reranked = cross_encoder_rerank(q_demo, retrieved, top_n=10)

# Draft at temperature 0.2, then the verifier pass over that draft.
draft = ollama_chat(build_prompt(q_demo, reranked), temperature=0.2)
final = verifier_pass(q_demo, draft, reranked)

print("Variants:", variants)
print("\n--- FINAL ---\n", final)

In [None]:
# ---- Demo (optional) ----
# End-to-end run through rag_answer with the verifier enabled; the returned dict is discarded.
q_demo = "benefits of fiber"
_ = rag_answer(q_demo, verify=True)

In [None]:
# Build a real synonym vocabulary from MeSH (one-time) + cache to JSON
# Uses lxml(recover=True) so big XML parsing won't crash.

from pathlib import Path
from collections import defaultdict
import json

# If this import fails, install into the SAME env as your Jupyter kernel:
# In terminal:  python -m pip install lxml
from lxml import etree

# Input: MeSH descriptor XML, parsed by build_mesh_synonyms_cache().
MESH_XML = Path("vocab/mesh_desc.xml")
# Output: JSON file the synonym map is written to and later reloaded from.
CACHE = Path("vocab/mesh_synonyms_cache.json")
# Create the cache's parent directory up front so the write cannot fail on it.
CACHE.parent.mkdir(parents=True, exist_ok=True)

def norm(t: str) -> str:
    """Lower-case *t* and collapse every whitespace run into a single space.

    Falsy input (``None`` or the empty string) normalizes to ``""``.
    """
    lowered = t.lower() if t else ""
    return " ".join(lowered.split())

def build_mesh_synonyms_cache(mesh_path: Path, cache_path: Path, max_terms_per_head: int = 25) -> dict:
    """Build a two-way synonym map from a MeSH descriptor XML and cache it as JSON.

    For every ``DescriptorRecord`` the normalized ``DescriptorName/String`` is
    the head term. The normalized ``Term/String`` entries found under its
    ``ConceptList`` (empty ones and ones equal to the head dropped) are its
    synonyms; only the first ``max_terms_per_head`` of them, in document
    order, are linked. Each link is stored in both directions.

    Args:
        mesh_path: Location of the MeSH XML file.
        cache_path: Where the resulting JSON is written (UTF-8).
        max_terms_per_head: Cap on linked terms per descriptor record.

    Returns:
        Dict mapping each term to the sorted list of terms linked to it.

    Raises:
        FileNotFoundError: If ``mesh_path`` does not exist.
    """
    if not mesh_path.exists():
        raise FileNotFoundError(f"Missing MeSH XML at: {mesh_path.resolve()}")

    print("‚è≥ Parsing MeSH XML (this can take a few minutes once)...")
    xml_parser = etree.XMLParser(recover=True, huge_tree=True)
    tree = etree.parse(str(mesh_path), xml_parser)
    root = tree.getroot()

    links = defaultdict(set)

    for seen, record in enumerate(root.findall(".//DescriptorRecord"), start=1):
        name_el = record.find("./DescriptorName/String")
        head = norm(name_el.text) if name_el is not None else ""
        if not head:
            # Records without a usable head are counted but otherwise skipped
            # (including the progress message below).
            continue

        normalized = (
            norm(el.text)
            for el in record.findall(".//ConceptList/Concept/TermList/Term/String")
        )
        candidates = [term for term in normalized if term and term != head]

        for synonym in candidates[:max_terms_per_head]:
            links[head].add(synonym)
            links[synonym].add(head)

        if seen % 5000 == 0:
            print(f"  processed {seen} records...")

    # Sets are not JSON-serializable; sorted lists also make the file stable.
    vocab = {}
    for term, partners in links.items():
        vocab[term] = sorted(partners)

    payload = json.dumps(vocab, ensure_ascii=False)
    cache_path.write_text(payload, encoding="utf-8")
    print("Cached to:", cache_path.resolve())
    return vocab

# Load cache if exists, otherwise build it
# The only check is whether the cache file exists, so delete the JSON file to
# force a rebuild from the XML.
if CACHE.exists():
    syn_vocab = json.loads(CACHE.read_text("utf-8"))
    print("Loaded cached MeSH synonyms:", len(syn_vocab))
else:
    syn_vocab = build_mesh_synonyms_cache(MESH_XML, CACHE, max_terms_per_head=25)
    print("Built MeSH synonyms:", len(syn_vocab))

# Spot checks: an empty list means the term is not a key in the vocabulary.
print("Example hypertension:", syn_vocab.get("hypertension", [])[:10])
print("Example vitamin c:", syn_vocab.get("vitamin c", [])[:10])

## STEP 10 — UI
This creates a proper web interface reachable at http://127.0.0.1:7860

In [None]:
# Installs Gradio, which the next cell imports as `gr`.
# NOTE(review): `!pip` runs whichever pip is first on PATH, which may not be
# the running kernel's environment; `%pip install gradio` targets the kernel.
# Confirm before switching.
!pip install gradio

In [None]:
import gradio as gr
from pathlib import Path
import time
from typing import Dict, List
import traceback
import base64
import os

# ---------- 0. IMAGE LOADER ----------
def load_local_gif(paths_to_check=None):
    """Return the first readable GIF among the candidates as a base64 data URI.

    Args:
        paths_to_check: Optional iterable of file paths tried in order. When
            omitted, the built-in candidate list below is used.

    Returns:
        ``"data:image/gif;base64,<payload>"`` for the first candidate that is
        a file and can be read, or ``None`` when no candidate qualifies.
    """
    if paths_to_check is None:
        # Update this path if needed
        paths_to_check = [
            r"C:\Users\reychel\Documents\GitHub\food-rag-web\gif\giphy.gif",
            "gif/giphy.gif",
            "giphy.gif",
        ]
    for p in paths_to_check:
        if not os.path.isfile(p):
            continue
        try:
            with open(p, "rb") as f:
                raw = f.read()
        except OSError:
            # Unreadable candidate (permissions, vanished file, ...): try the
            # next one. Only I/O errors are swallowed; anything else is a bug
            # and should surface.
            continue
        return "data:image/gif;base64," + base64.b64encode(raw).decode("ascii")
    return None


GIF_SOURCE = load_local_gif()

# ---------- 1. CSS  ----------
# Stylesheet handed to gr.Blocks(css=...). The class names defined here
# (status-container, phase-*, smart-spinner, evidence-card, header-title,
# header-subtitle) are the ones used by the HTML built in ui_answer(),
# _format_evidence_html() and the title block; #search-btn / #clear-btn match
# the elem_id values of the two buttons.
green_mist_css = """
/* APP BACKGROUND */
body, .gradio-container {
    background: radial-gradient(circle at 50% 0%, #064e3b 0%, #020617 60%) !important;
    color: #ffffff !important;
    font-family: 'Inter', sans-serif !important;
}

/* üü¢ TEXTBOXES & INPUTS (Force White) */
textarea, input {
    background-color: #020617 !important;
    border: 1px solid #1e293b !important;
    color: #ffffff !important;
    font-weight: 500 !important;
    border-radius: 10px !important;
    opacity: 1 !important;
}

/* Force Read-Only to be WHITE */
textarea[readonly], textarea:disabled {
    color: #ffffff !important;
    -webkit-text-fill-color: #ffffff !important;
    background-color: #0f172a !important;
    opacity: 1 !important;
    border-color: #334155 !important;
}

/*  HIDE ALL INDIVIDUAL LOADING ANIMATIONS  */
.loading { display: none !important; }
.meta-text-container { display: none !important; }
.pending { border-color: transparent !important; }
.generating { border-color: transparent !important; }

/*  HEADERS (QUERY, SETTINGS -> NEON GREEN) */
.prose h1, .prose h2, .prose h3 {
    color: #4ade80 !important;
    font-weight: 800 !important;
    letter-spacing: 1px;
    margin-bottom: 5px;
    margin-top: 0px;
    opacity: 1 !important;
    text-transform: uppercase;
    line-height: 1.5;
}

/* MAIN TITLE */
.header-title {
    font-size: 3.5rem;
    font-weight: 800;
    background: linear-gradient(180deg, #ffffff 0%, #4ade80 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    text-shadow: 0 0 40px rgba(74, 222, 128, 0.4);
    text-align: center;
}
.header-subtitle {
    color: #86efac; text-align: center; font-size: 1.1rem; letter-spacing: 2px;
    text-transform: uppercase; margin-bottom: 30px; opacity: 0.9;
}

/* LABELS */
.block-label, label span {
    color: #4ade80 !important;
    font-weight: bold !important;
    font-size: 0.9rem !important;
    text-transform: uppercase;
}

/* üü¢ STATUS BAR CONTAINER */
.status-container {
    padding: 12px 20px;
    border-radius: 12px;
    margin-bottom: 20px;
    font-family: 'Courier New', monospace; font-weight: 700;
    display: flex; align-items: center; justify-content: space-between;
    background: rgba(15, 23, 42, 0.95); /* More opaque to fix grey issue */
    backdrop-filter: blur(5px);
    margin-top: 0px;
    height: 50px;
}

/* üü¢ NUCLEAR WHITE TEXT FIX FOR STATUS BAR */
.status-container span, .status-container div {
    color: #ffffff !important;
    -webkit-text-fill-color: #ffffff !important;
    opacity: 1 !important;
}

/* Status Phases */
.phase-search { border: 1px solid #3b82f6; box-shadow: 0 0 15px rgba(59, 130, 246, 0.2); }
.phase-rerank { border: 1px solid #f59e0b; box-shadow: 0 0 15px rgba(245, 158, 11, 0.2); }
.phase-gen    { border: 1px solid #a855f7; box-shadow: 0 0 15px rgba(168, 85, 247, 0.2); }
.phase-done   { border: 1px solid #22c55e; box-shadow: 0 0 20px rgba(34, 197, 94, 0.3); }

.smart-spinner {
    width: 20px; height: 20px; border: 3px solid #ffffff;
    border-top-color: transparent; border-radius: 50%;
    animation: spin 0.8s linear infinite;
}
@keyframes spin { 100% { transform: rotate(360deg); } }

/* EVIDENCE CARDS */
.evidence-card {
    background: #111827; border: 1px solid #1f2937;
    border-radius: 8px; padding: 12px; margin-bottom: 10px;
}
.evidence-card:hover { border-color: #4ade80; }

/* BUTTONS */
#search-btn {
    background: linear-gradient(90deg, #22c55e 0%, #16a34a 100%);
    border: 1px solid #4ade80; color: white; font-weight: 800; font-size: 1.1rem;
    height: 50px; border-radius: 10px;
}
#search-btn:hover { box-shadow: 0 0 20px rgba(34, 197, 94, 0.5); }

#clear-btn {
    background: #1e293b; color: #cbd5e1; border: 1px solid #334155;
    height: 50px; border-radius: 10px; font-weight: 600;
}
#clear-btn:hover { color: white; border-color: white; }
"""

# ---------- 2. LOGIC ----------

def _safe_filename(p):
    try: return Path(p).name if p else "unknown"
    except: return "unknown"

def _format_evidence_html(chunks: List[Dict]) -> str:
    if not chunks: return "<div style='color:#cbd5e1; text-align:center;'>No evidence found.</div>"
    out = ["<div style='max-height: 450px; overflow-y: auto; padding-right: 5px;'>"]
    for i, r in enumerate(chunks, 1):
        sf = _safe_filename(r.get("source_file"))
        sc = r.get("rerank_score", r.get("score", 0.0))
        txt = (r.get("text") or "").strip().replace("\n", " ")[:400] + "..."
        out.append(f"""
        <div class="evidence-card">
            <div style="display:flex; justify-content:space-between; margin-bottom:5px; color:#94a3b8; font-size:0.85em;">
                <span style="color:#ffffff; font-weight:bold;">{i}. {sf}</span>
                <span style="color:#4ade80;">Score: {sc:.2f}</span>
            </div>
            <div style="color:#e2e8f0; font-size:0.9em; line-height:1.5;">{txt}</div>
        </div>""")
    out.append("</div>")
    return "\n".join(out)

# JAVASCRIPT TIMER SCRIPT
# This tiny script finds the timer element and updates it every 100ms
# The interval handle is kept on window.ragInterval so a new run (or the
# clearInterval snippet emitted at the end of ui_answer) can stop the old one.
# NOTE(review): browsers do not execute <script> elements inserted through
# innerHTML; if gr.HTML renders its value that way this timer never starts and
# only the server-side elapsed values are shown. Confirm in the browser.
js_timer_script = """
<script>
    function startLiveTimer() {
        let startTime = Date.now();
        let timerElement = document.getElementById('live-timer');
        if(window.ragInterval) clearInterval(window.ragInterval);

        window.ragInterval = setInterval(() => {
            if(timerElement) {
                let elapsed = ((Date.now() - startTime) / 1000).toFixed(1);
                timerElement.innerText = elapsed + "s";
            } else {
                // Try to find element again if DOM updated
                timerElement = document.getElementById('live-timer');
            }
        }, 100);
    }
    // Auto-start when this HTML is loaded
    startLiveTimer();
</script>
"""

def ui_answer(question: str, top_k: int, progress=gr.Progress()):
    """Stream the RAG pipeline to the UI, one status update per phase.

    Generator callback for the Analyze button and the textbox submit. Every
    yield is a 6-tuple in the order of the event's ``outputs`` list: output
    column update, answer text, evidence HTML, query variants, draft text,
    status HTML.

    Args:
        question: Raw text from the query box. ``None`` or blank text hides
            the output column and ends the run.
        top_k: Number of chunks kept after reranking (cast to ``int``).
        progress: Gradio progress tracker, advanced at each phase.
    """
    fused_topk = 60
    per_query_k = 25
    verify = True
    start_time = time.time()

    if not (question or "").strip():
        # Hide only the output column. Sending visible=False to the answer
        # textbox as well would leave it hidden on later runs, because the
        # later yields only set its value.
        yield gr.update(visible=False), "", "", "", "", ""
        return

    top_n = int(top_k)

    # PHASE 1
    progress(0.1, desc="Scanning...")
    status_html = f"""
    <div class="status-container phase-search">
        <div style="display:flex; align-items:center;">
            <div class="smart-spinner" style="margin-right:15px;"></div>
            <span>SCANNING DATABASE...</span>
        </div>
        <span style="font-size:0.9em;">PHASE 1/4 | <span id="live-timer">0.0s</span></span>
    </div>
    {js_timer_script}
    """
    yield (gr.update(visible=True), gr.update(value="..."), "", "", "", status_html)

    try:
        retrieved, variants = fusion_retrieve(question, fused_topk=fused_topk, per_query_k=per_query_k)
    except Exception as e:
        print(f"Search Error: {e}")
        retrieved = []
        variants = [question]

    if not retrieved:
        yield (gr.update(visible=True), "No docs found.", "", "", "", "<div class='status-container phase-done' style='border-color:red; color:red;'>‚ùå NO RESULTS</div>")
        return

    # PHASE 2
    progress(0.4, desc="Reranking...")
    # Note: We re-inject the timer span (id="live-timer") but NOT the script again,
    # because the script is already running and looking for that ID.
    elapsed = time.time() - start_time
    status_html = f"""
    <div class="status-container phase-rerank">
        <div style="display:flex; align-items:center;">
            <div class="smart-spinner" style="margin-right:15px;"></div>
            <span>FOUND {len(retrieved)} DOCS -> RERANKING...</span>
        </div>
        <span style="font-size:0.9em;">PHASE 2/4 | <span id="live-timer">{elapsed:.1f}s</span></span>
    </div>
    """
    yield (gr.update(visible=True), "...", "", ", ".join(variants), "", status_html)

    try:
        reranked = cross_encoder_rerank(question, retrieved, top_n=top_n)
    except Exception as e:
        # Best-effort fallback to fusion order, logged like the search error.
        print(f"Rerank Error: {e}")
        reranked = retrieved[:top_n]

    # PHASE 3
    progress(0.7, desc="Thinking...")
    elapsed = time.time() - start_time
    status_html = f"""
    <div class="status-container phase-gen">
        <div style="display:flex; align-items:center;">
            <div class="smart-spinner" style="margin-right:15px;"></div>
            <span>ANALYZING & WRITING...</span>
        </div>
        <span style="font-size:0.9em;">PHASE 3/4 | <span id="live-timer">{elapsed:.1f}s</span></span>
    </div>
    """
    yield (gr.update(visible=True), "Generating...", "", ", ".join(variants), "", status_html)

    # Stop timer script hack (clearing interval)
    stop_script = "<script>if(window.ragInterval) clearInterval(window.ragInterval);</script>"

    try:
        draft = generate_answer(question, reranked)
        final = verifier_pass(question, draft, reranked) if verify else draft
    except Exception as e:
        # Without this the generator dies here and the UI stays on the
        # phase-3 spinner. Show the failure and keep the retrieved evidence.
        traceback.print_exc()
        yield (
            gr.update(visible=True),
            f"Generation failed: {e}",
            _format_evidence_html(reranked),
            ", ".join(variants),
            "",
            "<div class='status-container phase-done' style='border-color:red; color:red;'>GENERATION FAILED</div>" + stop_script,
        )
        return

    # PHASE 4
    progress(1.0, desc="Done!")
    total_time = time.time() - start_time

    status_html = f"""
    <div class="status-container phase-done">
        <div style="display:flex; align-items:center;">
            <span style="font-size:1.5em; margin-right:15px;">‚úÖ</span>
            <span>COMPLETE</span>
        </div>
        <span style="font-weight:bold;">{total_time:.2f}s</span>
    </div>
    {stop_script}
    """

    yield (
        gr.update(visible=True),
        final,
        _format_evidence_html(reranked),
        ", ".join(variants),
        draft,
        status_html
    )

# ---------- 3. LAYOUT ----------

# Dark base theme; the stylesheet in green_mist_css is applied on top of it.
theme = gr.themes.Base(primary_hue="green", neutral_hue="slate").set(
    body_background_fill="#020617", block_background_fill="#0f172a", block_border_width="0px"
)

with gr.Blocks(theme=theme, css=green_mist_css, title="FoodRAG Pro") as demo:

    # TITLE
    with gr.Row():
        with gr.Column():
            gr.HTML("""
            <div style="padding: 30px 0;">
                <div class="header-title">FoodRAG Pro</div>
                <div class="header-subtitle">Intelligent Nutritional Analysis</div>
            </div>
            """)

    with gr.Row():
        # LEFT: INPUT
        with gr.Column(scale=4):
            gr.Markdown("### QUERY")
            q = gr.Textbox(placeholder="E.g., Can I take magnesium with antibiotics?", lines=5, show_label=False)

            gr.Markdown("### SETTINGS")
            # Forwarded to ui_answer as top_k (chunks kept after reranking).
            top_k = gr.Slider(1, 10, value=5, step=1, label="Citations (K)")

            with gr.Row():
                btn = gr.Button("Analyze", variant="primary", elem_id="search-btn")
                clear = gr.Button("Reset", variant="secondary", elem_id="clear-btn")

        # RIGHT: OUTPUT
        # Starts hidden; ui_answer makes it visible on the first non-blank query.
        with gr.Column(scale=6, visible=False) as output_col:
            # Status Bar
            status_display = gr.HTML()

            gr.Markdown("### üß† ANSWER")
            answer = gr.Textbox(lines=8, show_label=False, interactive=False)

            gr.Markdown("### üìÇ SOURCES")
            evidence = gr.HTML()

            with gr.Accordion("üìù  Logs", open=False):
                variants_out = gr.Textbox(label="Expansion", interactive=False)
                draft_out = gr.Textbox(label="Draft", lines=4, interactive=False)

    # ACTIONS
    # Button click and Enter-submit run the same handler. The `outputs` order
    # must match the 6-tuples that ui_answer yields.
    btn.click(
        ui_answer,
        inputs=[q, top_k],
        outputs=[output_col, answer, evidence, variants_out, draft_out, status_display],
        show_progress="hidden"
    )
    q.submit(
        ui_answer,
        inputs=[q, top_k],
        outputs=[output_col, answer, evidence, variants_out, draft_out, status_display],
        show_progress="hidden"
    )

    def reset_ui():
        # One value per entry of the `outputs` list on clear.click below:
        # empty query, slider back to 5, output column hidden, rest blanked.
        return ("", 5, gr.update(visible=False), "", "", "", "", "")

    clear.click(reset_ui, outputs=[q, top_k, output_col, answer, evidence, variants_out, draft_out, status_display])

print(" Launching Final UI (Live Timer + White Text)...")
# NOTE(review): debug=True presumably keeps this cell running while the server
# is up, so later cells would not run until it is interrupted. Confirm.
demo.launch(share=False, debug=True)

## STEP 11 — Evaluation (retrieval)
The goal of the evaluation is to measure the quality of the retrieval and reranking components of the RAG pipeline.

Why MRR@10?
 - Works great when each query has 1+ relevant chunks
 - Rewards ranking the first relevant chunk as high as possible
 - Simple + very standard for retrieval/reranking evaluation

Requirements:
 `eval_questions.jsonl` with one JSON object per line:  {"query": "...", "relevant_chunk_ids": ["chunk_id1", "chunk_id2", ...]}

In [None]:
# =========================
# STEP 11 — Retrieval Evaluation (MRR@10)
# =========================

import json
from pathlib import Path
import numpy as np

# ---- config ----
EVAL_PATH = Path("eval_questions.jsonl")   # <-- make sure this file exists
# NOTE(review): K is not referenced by the rest of this cell; the evaluation
# call passes k=10 literally. Confirm whether it was meant to pass k=K.
K = 10

# ---- metric ----
def mrr_at_k(ranked_ids, relevant_set, k: int) -> float:
    """
    Reciprocal rank of the first relevant id within the top k, for one query.

    Returns 1 / rank (1-based) of the earliest entry of ranked_ids[:k] that is
    in relevant_set, or 0.0 when none of them is. Averaging this value over
    all queries gives MRR@k.
    """
    reciprocal_ranks = (
        1.0 / position
        for position, chunk_id in enumerate(ranked_ids[:k], start=1)
        if chunk_id in relevant_set
    )
    return next(reciprocal_ranks, 0.0)


# ---- sanity checks ----
# Printed before the evaluation runs so a missing or empty file is obvious;
# nothing here stops execution.
print("EVAL FILE PATH:", EVAL_PATH.resolve())
print("EVAL FILE EXISTS:", EVAL_PATH.exists())
print("EVAL FILE SIZE:", EVAL_PATH.stat().st_size if EVAL_PATH.exists() else "N/A")


def evaluate_mrr10(
    eval_path: Path,
    k: int = 10,
    fused_topk: int = 60,
    per_query_k: int = 25,
    show_per_query: bool = True,
):
    """Evaluate retrieval + reranking with MRR@k over a JSONL question set.

    Each non-blank line of ``eval_path`` is a JSON object with a ``query``
    string and a ``relevant_chunk_ids`` list. Each query is retrieved with
    ``fusion_retrieve``, reranked with ``cross_encoder_rerank`` and scored
    with ``mrr_at_k``.

    Args:
        eval_path: UTF-8 JSONL file with the evaluation questions.
        k: Cut-off for reranking and for the metric.
        fused_topk: Passed to ``fusion_retrieve`` as ``fused_topk``.
        per_query_k: Passed to ``fusion_retrieve`` as ``per_query_k``.
        show_per_query: When true, print a preview of the first 5 queries.

    Returns:
        ``(mean_mrr, per_query_results)``. ``mean_mrr`` is 0.0 for an empty
        question set. Each per-query dict has the keys ``query``,
        ``variants``, ``mrr10``, ``first_relevant_rank`` and ``top10``; the
        "10" in the key names is kept for existing callers even when ``k``
        differs.
    """
    print(">>> START evaluate_mrr10")

    # load evaluation set
    rows = [
        json.loads(l)
        for l in eval_path.read_text("utf-8").splitlines()
        if l.strip()
    ]

    scores = []
    per_query_results = []
    rerank_failures = 0

    for r in rows:
        q = r["query"].strip()
        relevant = set(r["relevant_chunk_ids"])

        # retrieval
        retrieved, variants = fusion_retrieve(
            q, fused_topk=fused_topk, per_query_k=per_query_k
        )

        # rerank (cross-encoder)
        try:
            reranked = cross_encoder_rerank(q, retrieved, top_n=k)
        except Exception as exc:
            # Fall back to fusion order, but report it: a silent fallback
            # would make the metric describe un-reranked retrieval while
            # looking like a reranker score.
            rerank_failures += 1
            print(f"[evaluate_mrr10] Rerank failed for {q!r} ({type(exc).__name__}: {exc}); scoring fusion order.")
            reranked = retrieved[:k]

        ranked_ids = [x["chunk_id"] for x in reranked]

        score = mrr_at_k(ranked_ids, relevant, k)
        scores.append(score)

        # The score is 1/rank of the first hit, so the rank is recovered from
        # it instead of scanning the list a second time.
        first_rank = round(1.0 / score) if score else None

        per_query_results.append({
            "query": q,
            "variants": variants,
            "mrr10": score,
            "first_relevant_rank": first_rank,
            "top10": ranked_ids,
        })

    mean_mrr = float(np.mean(scores)) if scores else 0.0

    print(f"\n=== Retrieval Evaluation: MRR@{k} ===")
    print("Queries:", len(scores))
    print(f"MRR@{k}: {mean_mrr:.3f}")
    if rerank_failures:
        print(f"WARNING: reranker failed on {rerank_failures} of {len(scores)} queries (fusion order scored instead).")

    if show_per_query:
        print("\nPer-query preview (first 5):")
        for item in per_query_results[:5]:
            rank_str = (
                item["first_relevant_rank"]
                if item["first_relevant_rank"] is not None
                else f"NOT IN TOP{k}"
            )
            print("-" * 60)
            print("Query:", item["query"])
            print("Variants:", item["variants"])
            print(f"First relevant rank: {rank_str}")
            print(f"MRR@{k}: {item['mrr10']:.3f}")
            print("Top 3:", item["top10"][:3])

    return mean_mrr, per_query_results


# ---- run evaluation ----
# Runs as soon as the cell executes. evaluate_mrr10 reads EVAL_PATH without
# checking that it exists, so a missing file surfaces as an error from there.
print(">>> CALLING evaluation now...")
mean_mrr10, details = evaluate_mrr10(EVAL_PATH, k=10)
print(">>> DONE. mean_mrr10 =", mean_mrr10)
