#!/usr/bin/env python3
"""
FINAL TASK-2(b) HYBRID QUERY FEDERATION SYSTEM - FULL FEATURE MAX
- Fully merged: hybrid similarity reranker, embedding-based expansion,
  safer numeric extraction, top-K BNS candidates, offence->section mapping,
  structured JSON LLM synthesis (jurisdiction-locked), and fallback JSON generation.
- Requirements:
    pip install sentence-transformers rapidfuzz numpy requests google-generativeai
  (google-generativeai optional if you want real Gemini; fallback will work without it)
- Usage:
    python full_pipeline.py
"""

import sqlite3
import json
import requests
import re
import os
import time
from pprint import pprint
import traceback

# optional Gemini wrapper
try:
    import google.generativeai as genai
except Exception:
    genai = None

# hybrid similarity imports
import numpy as np
from sentence_transformers import SentenceTransformer
from rapidfuzz import distance, fuzz

# ------------------------------------------------------------
# CONFIGURATION
# ------------------------------------------------------------
# Remote CRPC search service and local data files.
CRPC_SERVER = "http://192.168.226.115:5000"
BNS_DB_PATH = "bns.db"
CASES_JSON = "legal_cases.json"
# Optional Gemini API key taken from the environment; empty string when unset.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")

# Weights of the three hybrid-similarity components.
ALPHA = 0.3  # Jaccard token overlap
BETA = 0.3   # edit similarity
GAMMA = 0.4  # embedding cosine
EMBEDDING_MODEL = "all-MiniLM-L6-v2"
# Number of top-ranked candidates to show.
HYBRID_TOP_K = 5

# Candidate expansion: with this many BNS hits or fewer, widen the pool
# via embedding search, fetching up to EMBED_EXPAND_LIMIT similar sections.
EMBED_EXPAND_THRESHOLD = 5
EMBED_EXPAND_LIMIT = 80

# Upper bound on the number of candidates that get re-ranked.
MAX_CAND = 300

# ------------------------------------------------------------
# DB / Cases load
# ------------------------------------------------------------
# Warn early when the database file is missing; the connection below is
# still opened so that the module can be imported.
if not os.path.exists(BNS_DB_PATH):
    print("WARNING: BNS DB not found at", BNS_DB_PATH)
BNS = sqlite3.connect(BNS_DB_PATH)
BNS.row_factory = sqlite3.Row

# Case-law corpus is optional: any failure to read or parse it leaves an
# empty list. The file is opened in a `with` block so the handle is closed,
# and decoded as UTF-8 rather than the platform default encoding.
try:
    with open(CASES_JSON, "r", encoding="utf-8") as _cases_file:
        LEGAL_CASES = json.load(_cases_file)
except Exception:
    LEGAL_CASES = []

# ------------------------------------------------------------
# Embedding model cache
# ------------------------------------------------------------
# Process-wide SentenceTransformer instance, created on first use.
_embedding_model = None

def get_embedding_model():
    """Return the shared embedding model, constructing it on the first call."""
    global _embedding_model
    if _embedding_model is not None:
        return _embedding_model
    _embedding_model = SentenceTransformer(EMBEDDING_MODEL)
    return _embedding_model

# ------------------------------------------------------------
# IPC → BNS OFFICIAL MAPPING (kept)
# ------------------------------------------------------------
# Keys are IPC section numbers as strings (the form produced by the query
# rewriter); values are lists of BNS section numbers as integers.
IPC_TO_BNS = {
    "420": [316, 317, 318, 319],
    "417": [316],
    "415": [315],
    "376": [63, 64, 65],
    "302": [103],
    "304": [104],
    "307": [109],
    "354": [73],
}

# ------------------------------------------------------------
# Helper: Gemini wrapper (optional) + robust fallback (including JSON)
# ------------------------------------------------------------
def call_gemini(prompt, max_tokens=512, timeout=10.0):
    """
    Use the real Gemini API if configured. Otherwise use a deterministic fallback.

    prompt:     text sent to the model (or to the fallback generator).
    max_tokens: output-token cap, passed through generation_config.
    timeout:    request timeout in seconds, passed through request_options.

    Always returns a string. If the API call fails, the error is printed and
    the plain fallback text is returned. The error note is deliberately kept
    out of the return value, because callers parse it as JSON.
    """
    # real Gemini if available
    if GEMINI_API_KEY and genai is not None:
        try:
            genai.configure(api_key=GEMINI_API_KEY)
            model = genai.GenerativeModel("models/gemini-2.5-pro")
            # generate_content() takes the token cap via generation_config and
            # the timeout via request_options; it does not accept them as
            # direct keyword arguments.
            # Production code should also set safety params.
            response = model.generate_content(
                prompt,
                generation_config={"max_output_tokens": max_tokens},
                request_options={"timeout": timeout},
            )
            if hasattr(response, "text"):
                return response.text
            return str(response)
        except Exception as e:
            # best-effort: report the failure and fall through to the fallback
            print(f"[Gemini Error] {str(e)}")

    # fallback
    return _fallback_gemini(prompt)

def _fallback_gemini(prompt):
    """
    Deterministic fallback generator:
    - If prompt asks for a JSON mapping (contains 'return a json object',
      '"mappings"' or 'output only valid json'), emit a JSON mapping heuristically.
    - Else if prompt asks to extract crimes, emit structured JSON of crimes.
    - Else a short canned summary.

    Keywords are matched at the start of a word, so "kidnap" still matches
    "kidnapped" but "cut" no longer fires on "prosecution", nor "rob" on
    "problem", nor "kill" on "skill".
    Returns a string (JSON text for the first two cases).
    """
    prompt_l = (prompt or "").lower()

    def mentions(stem):
        # word-start match against the lowercased prompt
        return re.search(r"\b" + re.escape(stem), prompt_l) is not None

    # JSON mapping request (structured LLM)
    if "return a json object" in prompt_l or '"mappings"' in prompt_l or "output only valid json" in prompt_l:
        # naive mapping table - expand as per needs
        naive = {
            "murder": {"bns":[103], "ipc":[302], "reason":"Intentional killing → murder provisions"},
            "kidnap": {"bns":[142], "ipc":[364], "reason":"Abduction/kidnapping provisions"},
            "kidnapping": {"bns":[142], "ipc":[364], "reason":"Abduction/kidnapping provisions"},
            "rape": {"bns":[65], "ipc":[],"reason":"Sexual assault; POCSO if victim minor"},
            "extortion": {"bns":[232], "ipc":[384], "reason":"Threats/demands of money"},
            "threat": {"bns":[232], "ipc":[503], "reason":"Criminal intimidation / threats"},
            "dismember": {"bns":[201], "ipc":[201], "reason":"Destruction of evidence / dismemberment"},
            "robbery": {"bns":[313], "ipc":[392], "reason":"Robbery provisions"},
            "cheat": {"bns":[173], "ipc":[420], "reason":"Cheating / fraud"}
        }
        mappings = []
        emitted = set()
        for k, v in naive.items():
            if not mentions(k):
                continue
            # "kidnap" and "kidnapping" carry identical suggestions; emit one
            signature = (tuple(v["bns"]), tuple(v.get("ipc", [])))
            if signature in emitted:
                continue
            emitted.add(signature)
            mappings.append({
                "offence": k,
                "suggested_bns": v["bns"],
                "suggested_ipc": v.get("ipc", []),
                "reason": v.get("reason", "")
            })
        # If none matched heuristically, include a generic mapping hint
        if not mappings:
            mappings.append({
                "offence": "unknown",
                "suggested_bns": [],
                "suggested_ipc": [],
                "reason": "No direct keyword matched; please rely on retrieval candidates."
            })
        return json.dumps({"mappings": mappings}, indent=2)

    # extraction prompt asked for crimes JSON
    if "extract all crimes" in prompt_l or "identify all crimes" in prompt_l or '"crimes"' in prompt_l:
        stem_to_crime = [
            ("kidnap", "Kidnapping"),
            ("abduct", "Kidnapping"),
            ("rape", "Rape"),
            ("sexual assault", "Rape"),
            ("murder", "Murder"),
            ("kill", "Murder"),
            ("strangle", "Murder"),
            ("threat", "Criminal Intimidation"),
            ("extort", "Extortion"),
            ("money", "Extortion"),
            ("cut", "Destruction of Evidence"),
            ("dismember", "Destruction of Evidence"),
            ("rob", "Robbery"),
            ("steal", "Theft")
        ]
        detected = []
        for stem, crime in stem_to_crime:
            if crime not in detected and mentions(stem):
                detected.append(crime)
        return json.dumps({"crimes":[{"crime":c,"confidence":0.8} for c in detected]}, indent=2)

    # long narrative fallback
    if len(prompt_l) > 300:
        return "LLM summary (fallback): Narrative contains multiple offences; refer to relevant sections."

    # default short fallback
    return "LLM answer (fallback): Refer to the fetched sections from BNS/CRPC and case-law for details."

# ------------------------------------------------------------
# 1. QUERY REWRITE (with safer numeric extraction)
# ------------------------------------------------------------
def rewrite_query(q):
    """
    Turn a free-text query into search keywords, explicit section numbers and
    coarse intent flags.

    q: the user's query (None is treated as an empty string).

    Returns a dict with:
      bns_keywords / crpc_keywords : de-duplicated keyword lists, in rule order
      sections                     : section numbers (strings) that follow an
                                     explicit "section" / "sec" / "s." marker
      raw_query                    : the query as given ("" for None)
      intent                       : boolean flags is_crime_story,
                                     is_section_lookup, is_definition, is_comparison

    Trigger words are matched at the start of a word, so "rob" matches
    "robbed" but not "problem". Bare numbers (ages, amounts) are never
    treated as sections; only explicitly marked ones are.
    """
    q0 = q or ""
    ql = q0.lower()

    def mentions(stems):
        return any(re.search(r"\b" + re.escape(s), ql) for s in stems)

    # (trigger stems, BNS keywords to add, CRPC keywords to add)
    crime_rules = (
        (("rape", "sexual"),
         ("rape", "sexual", "woman"), ("rape",)),
        # "defraud" is listed explicitly because word-start matching would
        # otherwise miss the "fraud" inside it
        (("cheat", "fraud", "defraud", "dishonest", "deceiv"),
         ("cheat", "fraud", "dishonest"), ("cheating", "420")),
        (("murder", "kill", "homicide", "strangle", "stab"),
         ("murder", "kill", "death", "homicide"), ("murder", "302")),
        (("kidnap", "abduct", "abduction"),
         ("kidnap", "abduct"), ("kidnapping",)),
        (("threat", "intimidat", "blackmail", "extort", "demand money"),
         ("threat", "intimidation", "extortion"), ("intimidation", "extortion")),
        (("rob", "robbery", "steal", "theft"),
         ("robbery", "theft"), ("robbery", "theft")),
    )
    bns_kw, crpc_kw = [], []
    for triggers, bns_terms, crpc_terms in crime_rules:
        if mentions(triggers):
            bns_kw.extend(bns_terms)
            crpc_kw.extend(crpc_terms)

    # explicit sections only (safer): a 2-4 digit number after a section marker
    explicit_sections = re.findall(r"\b(?:section|sec|s\.)\s+(\d{2,4})\b", ql)
    sections = list(dict.fromkeys(explicit_sections))

    intent = {
        "is_crime_story": mentions(["kidnap","murder","rape","threat","extort","rob","stalk","abduct"]),
        "is_section_lookup": bool(explicit_sections),
        "is_definition": mentions(["what is","explain","define","meaning of"]),
        "is_comparison": mentions(["compare","difference","vs","versus"])
    }

    return {
        "bns_keywords": list(dict.fromkeys(bns_kw)),
        "crpc_keywords": list(dict.fromkeys(crpc_kw)),
        "sections": sections,
        "raw_query": q0,
        "intent": intent
    }

# ------------------------------------------------------------
# 2. BNS SEARCH (SECTION & KEYWORD)
# ------------------------------------------------------------
def bns_by_section(s):
    """
    Look up BNS rows by section number.

    s is coerced with int(); anything that cannot be coerced yields [].
    Returns a list of dicts with keys section_id, title and text.
    """
    cursor = BNS.cursor()
    try:
        section_number = int(s)
    except Exception:
        return []
    query = """
        SELECT section AS section_id, section__name AS title, description AS text
        FROM bns_sections
        WHERE section = ?
    """
    cursor.execute(query, (section_number,))
    return list(map(dict, cursor.fetchall()))

def bns_by_keyword(kw):
    """
    Case-insensitive substring search over BNS section names and descriptions.

    kw: keyword to look for. An empty, blank or None keyword returns []
        instead of matching every row. '%' and '_' inside kw are matched
        literally rather than acting as LIKE wildcards.
    Returns a list of dicts with keys section_id, title and text.
    """
    if not kw or not kw.strip():
        return []
    # escape the escape character first, then the two LIKE wildcards
    escaped = kw.lower().replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
    k = f"%{escaped}%"
    cur = BNS.cursor()
    cur.execute("""
        SELECT section AS section_id, section__name AS title, description AS text
        FROM bns_sections
        WHERE LOWER(section__name) LIKE ? ESCAPE '\\'
           OR LOWER(description) LIKE ? ESCAPE '\\'
    """, (k, k))
    return [dict(r) for r in cur.fetchall()]

# ------------------------------------------------------------
# 3. CRPC SEARCH (REMOTE)
# ------------------------------------------------------------
def crpc_search(kw):
    """
    Query the remote CRPC service for kw.

    Tries once with a short timeout and, if that attempt raises (network
    error or undecodable body), once more with a longer one. A non-OK HTTP
    status ends the search immediately. Always returns a list: a list
    payload is returned as-is, any other non-empty payload is wrapped in a
    one-element list, and every failure yields [].
    """
    if not kw:
        return []
    for attempt_timeout in (1.5, 3.0):
        try:
            r = requests.get(f"{CRPC_SERVER}/search_crpc",
                             params={"q": kw, "limit": 10},
                             timeout=attempt_timeout)
            if not r.ok:
                return []
            payload = r.json()
        except (requests.RequestException, ValueError):
            # best-effort service: retry with the next timeout, if any
            continue
        if isinstance(payload, list):
            return payload
        # callers extend a list with the result, so never hand back a bare
        # dict or scalar
        return [payload] if payload else []
    return []

# ------------------------------------------------------------
# 4. CASE SEARCH
# ------------------------------------------------------------
def case_search(keywords):
    """
    Return the loaded cases whose JSON text contains any of the first six
    keywords (case-insensitive). With no cases loaded the result is [];
    with cases loaded but nothing matching, the first two cases are returned.
    """
    if not LEGAL_CASES:
        return []

    def mentions_keyword(case):
        blob = json.dumps(case).lower()
        return any(k.lower() in blob for k in keywords[:6])

    matched = [case for case in LEGAL_CASES if mentions_keyword(case)]
    return matched or LEGAL_CASES[:2]

# ------------------------------------------------------------
# EMBEDDING-BASED CANDIDATE EXPANSION
# ------------------------------------------------------------
# (rows, embeddings) of the whole BNS table, keyed by database path, so the
# corpus is read and encoded once per process instead of on every expansion.
_BNS_CORPUS_CACHE = {}


def _load_bns_corpus():
    """Return (rows, embeddings) for every BNS section, computing them on first use."""
    cached = _BNS_CORPUS_CACHE.get(BNS_DB_PATH)
    if cached is not None:
        return cached

    conn = sqlite3.connect(BNS_DB_PATH)
    try:
        conn.row_factory = sqlite3.Row
        cur = conn.cursor()
        cur.execute("SELECT section AS section_id, section__name AS title, description AS text FROM bns_sections")
        rows = [dict(r) for r in cur.fetchall()]
    finally:
        # close the connection even when the query fails
        conn.close()

    # `or ""` keeps NULL columns from being embedded as the literal "None"
    texts = [f"{r.get('title') or ''} {r.get('text') or ''}".strip() for r in rows]
    if texts:
        embs = np.asarray(
            get_embedding_model().encode(texts, convert_to_numpy=True, show_progress_bar=False),
            dtype=float,
        )
    else:
        embs = np.zeros((0, 0), dtype=float)

    _BNS_CORPUS_CACHE[BNS_DB_PATH] = (rows, embs)
    return rows, embs


def expand_candidates_with_embeddings(query, current_candidates, all_limit=EMBED_EXPAND_LIMIT):
    """
    If current_candidates is small, fetch top-N similar sections from entire BNS using embeddings.

    query:              text to compare against (None is treated as "").
    current_candidates: row dicts already found; kept at the front of the
                        result (None is treated as an empty list).
    all_limit:          how many of the most similar sections to append.

    Returns current_candidates followed by the most similar sections by
    cosine similarity, de-duplicated on section_id (first occurrence wins).
    """
    rows, row_embs = _load_bns_corpus()

    extended = []
    if rows:
        encoded = get_embedding_model().encode([query or ""], convert_to_numpy=True, show_progress_bar=False)
        q_emb = np.asarray(encoded[0], dtype=float)
        dots = row_embs @ q_emb
        denom = np.linalg.norm(row_embs, axis=1) * np.linalg.norm(q_emb)
        # cosine similarity, left at 0.0 where either vector has zero norm
        sims = np.zeros(len(rows), dtype=float)
        nonzero = denom != 0
        sims[nonzero] = dots[nonzero] / denom[nonzero]
        # stable sort keeps table order among equal similarities
        order = np.argsort(-sims, kind="stable")[:all_limit]
        # hand out copies so callers cannot mutate the cached rows
        extended = [dict(rows[int(i)]) for i in order]

    # include current_candidates at front, dedupe
    seen = set()
    out = []
    for s in list(current_candidates or []) + extended:
        sid = s.get("section_id")
        if sid not in seen:
            out.append(s)
            seen.add(sid)
    return out

# ------------------------------------------------------------
# HYBRID SIMILARITY UTILITIES
# ------------------------------------------------------------
def tokenize_set(s: str):
    """
    Lowercase s, replace every character other than a-z, 0-9 and whitespace
    with a space, and return the set of resulting tokens longer than one
    character. Falsy input gives an empty set.
    """
    if not s:
        return set()
    cleaned = re.sub(r"[^a-z0-9\s]", " ", s.lower())
    return {token for token in cleaned.split() if len(token) > 1}

def jaccard_similarity(a: str, b: str) -> float:
    """
    Jaccard overlap of the token sets of a and b: |intersection| / |union|.
    Two empty token sets score 1.0; exactly one empty set scores 0.0.
    """
    left, right = tokenize_set(a), tokenize_set(b)
    if not (left or right):
        return 1.0
    if not (left and right):
        return 0.0
    return len(left & right) / len(left | right)

def edit_similarity(a: str, b: str) -> float:
    """
    Normalized Levenshtein similarity of a and b.
    Two empty inputs score 1.0; exactly one empty input scores 0.0.
    If the Levenshtein call fails, fuzz.token_sort_ratio / 100 is used instead.
    """
    if not a or not b:
        both_empty = (not a) and (not b)
        return 1.0 if both_empty else 0.0
    try:
        return float(distance.Levenshtein.normalized_similarity(a, b))
    except Exception:
        return float(fuzz.token_sort_ratio(a, b) / 100.0)

def embedding_cosine_sim(a_emb: np.ndarray, b_emb: np.ndarray) -> float:
    """
    Cosine similarity of two vectors, rescaled from [-1, 1] to [0, 1].
    Returns 0.0 when either vector is None or has zero norm.
    """
    if any(vec is None for vec in (a_emb, b_emb)):
        return 0.0
    dot_product = float(np.dot(a_emb, b_emb))
    magnitude = float(np.linalg.norm(a_emb) * np.linalg.norm(b_emb))
    if magnitude == 0:
        return 0.0
    cosine = dot_product / magnitude
    return (cosine + 1.0) / 2.0

def hybrid_score(text_a: str, text_b: str,
                 alpha=ALPHA, beta=BETA, gamma=GAMMA,
                 a_emb=None, b_emb=None) -> dict:
    """
    Weighted blend of Jaccard, edit and embedding similarity between two texts.

    alpha, beta, gamma weight the Jaccard, edit and embedding components.
    a_emb / b_emb are optional precomputed embeddings; if either is missing,
    both texts are encoded here.
    Returns {"score", "jaccard", "edit", "embed"}.
    """
    token_overlap = jaccard_similarity(text_a, text_b)
    edit_component = edit_similarity(text_a, text_b)

    if a_emb is None or b_emb is None:
        encoder = get_embedding_model()
        vectors = encoder.encode([text_a, text_b], convert_to_numpy=True, show_progress_bar=False)
        a_emb = vectors[0]
        b_emb = vectors[1]
    embed_component = embedding_cosine_sim(a_emb, b_emb)

    return {
        "score": alpha * token_overlap + beta * edit_component + gamma * embed_component,
        "jaccard": token_overlap,
        "edit": edit_component,
        "embed": embed_component,
    }

# ------------------------------------------------------------
# 5. BEST BNS SECTION SELECTION (returns top-N ranked list)
# ------------------------------------------------------------
def choose_best_bns(query, bns_hits, rerank_top_n=HYBRID_TOP_K):
    """
    Re-rank BNS candidate rows against the query with the hybrid score.

    query:        user query (None is treated as "").
    bns_hits:     list of row dicts (section_id, title, text); empty/None -> [].
    rerank_top_n: how many of the best-scoring rows to return.

    A small pool (<= EMBED_EXPAND_THRESHOLD) is first widened by embedding
    search; a failure there is printed and the original pool is used.
    At most MAX_CAND candidates are scored.
    Returns copies of the rows, best first, each with added "_score",
    "_jaccard", "_edit" and "_embed" float fields.
    """
    q = (query or "").strip()
    if not bns_hits:
        return []

    # Expand if small
    if len(bns_hits) <= EMBED_EXPAND_THRESHOLD:
        try:
            bns_hits = expand_candidates_with_embeddings(q, bns_hits, all_limit=EMBED_EXPAND_LIMIT)
        except Exception as e:
            print("embedding expansion failed:", e)

    pool = list(bns_hits)[:MAX_CAND]
    # `or ""` keeps a NULL title/text from entering the scored text as "None"
    combined_texts = [
        f"{s.get('title') or ''} {s.get('text') or ''}".strip()
        for s in pool
    ]

    # encode the query and all candidates in one batch
    model = get_embedding_model()
    all_emb = model.encode([q] + combined_texts, convert_to_numpy=True, show_progress_bar=False)
    q_emb = all_emb[0]

    scored = []
    for row, combined, row_emb in zip(pool, combined_texts, all_emb[1:]):
        sc = hybrid_score(q, combined, a_emb=q_emb, b_emb=row_emb)
        entry = dict(row)  # copy so the caller's row is not mutated
        entry["_score"] = float(sc["score"])
        entry["_jaccard"] = float(sc["jaccard"])
        entry["_edit"] = float(sc["edit"])
        entry["_embed"] = float(sc["embed"])
        scored.append(entry)

    scored.sort(key=lambda x: x["_score"], reverse=True)
    return scored[:rerank_top_n]

# ------------------------------------------------------------
# 6. CRIME EXTRACTION (rules + LLM) - unchanged
# ------------------------------------------------------------
def rule_based_extract_crimes(text):
    """
    Detect offence categories in text using keyword rules.

    Keywords are matched at the start of a word, so "rob" matches "robbed"
    but not "problem", and "rape" does not fire on "therapeutic".
    Returns the matching category labels in rule order (None/"" -> []).
    """
    t = (text or "").lower()
    rules = (
        ("Kidnapping", ("kidnap", "abduct", "abduction")),
        ("Rape", ("rape", "sexual assault", "sexual")),
        ("Murder", ("murder", "kill", "homicide", "strangle", "stab")),
        ("Criminal Intimidation / Extortion", ("threat", "intimidation", "blackmail", "extort")),
        ("Destruction of Evidence / Dismemberment", ("dismember", "cut body", "chop", "body parts")),
        ("Robbery / Theft", ("rob", "steal", "theft")),
    )
    return [
        label
        for label, stems in rules
        if any(re.search(r"\b" + re.escape(stem), t) for stem in stems)
    ]

def llm_extract_crimes_full(text):
    """
    Ask the LLM (or its fallback) to list the crimes in a narrative.

    The reply is searched for its outermost {...} span before parsing, so
    JSON wrapped in extra text (code fences, notes) is still accepted.
    Returns the de-duplicated, non-empty string values of "crime" from the
    reply's "crimes" list. If no JSON object can be parsed from the reply,
    falls back to rule_based_extract_crimes(text).
    """
    prompt = (
        "Extract all crimes from the following narrative and return JSON: "
        '{"crimes":[{"crime":"...","details":"..."}]} '
        "Narrative:\n\n" + (text or "")
    )
    res = call_gemini(prompt)

    parsed = None
    if isinstance(res, str):
        start, end = res.find("{"), res.rfind("}")
        if start != -1 and end > start:
            try:
                parsed = json.loads(res[start:end + 1])
            except ValueError:
                parsed = None

    if isinstance(parsed, dict):
        items = parsed.get("crimes")
        if not isinstance(items, list):
            items = []
        # keep only well-formed entries; names must be non-empty strings
        crimes = [
            c["crime"].strip()
            for c in items
            if isinstance(c, dict) and isinstance(c.get("crime"), str) and c["crime"].strip()
        ]
        return list(dict.fromkeys(crimes))

    return rule_based_extract_crimes(text)

def hybrid_extract_crimes(text):
    """
    Run both the rule-based and the LLM crime extractors on text.
    Returns {"rules": [...], "llm": [...], "merged": [...]}, where merged is
    rules followed by llm with duplicates removed (first occurrence kept).
    """
    from_rules = rule_based_extract_crimes(text)
    from_llm = llm_extract_crimes_full(text)
    return {
        "rules": from_rules,
        "llm": from_llm,
        "merged": list(dict.fromkeys(from_rules + from_llm)),
    }

# ------------------------------------------------------------
# 7. Offence -> Section Mapping Helper
# ------------------------------------------------------------
def map_offences_to_sections(query, offences, bns_candidates):
    """
    For each offence string, produce a ranked list of candidate sections (top-K).
    Uses IPC->BNS seeds, keyword lookup, and reranking on the candidate pool.

    query:          the user query the candidates are scored against.
    offences:       iterable of offence names; None, non-string and blank
                    entries are skipped rather than aborting the mapping.
    bns_candidates: row dicts already retrieved (None is treated as []).

    Returns {offence: [ranked row dicts]} keyed by the offence as given.
    """
    # small IPC heuristic table (extend as needed)
    IPC_MAP = {
        "murder": "302",
        "kidnapping": "364",
        "kidnap": "364",
        "rape": None,
        "extortion": "384",
        "criminal intimidation": None,
        "criminal intimidation / extortion": None
    }

    pool = list(bns_candidates or [])
    mapping = {}

    for off in (offences or []):
        if not isinstance(off, str) or not off.strip():
            continue
        off_l = off.strip().lower()
        seeds = []

        # use IPC_MAP if available
        ipc = IPC_MAP.get(off_l)
        if ipc and ipc in IPC_TO_BNS:
            for b in IPC_TO_BNS[ipc]:
                seeds += bns_by_section(b)

        # keyword fallback from the offence tokens
        for tok in re.findall(r"\w+", off_l):
            try:
                seeds += bns_by_keyword(tok)
            except Exception as exc:
                # best-effort: one failed lookup must not lose the others
                print(f"keyword lookup failed for {tok!r}:", exc)

        # always include the current bns_candidates if provided
        seeds += pool

        # dedupe seeds on section_id, keeping first occurrences
        flat = []
        seen = set()
        for s in seeds:
            if not s:
                continue
            sid = s.get("section_id")
            if sid not in seen:
                flat.append(s)
                seen.add(sid)

        # NOTE(review): candidates are scored against the whole query, not
        # against the offence name; confirm that this is intended.
        mapping[off] = choose_best_bns(query, flat, rerank_top_n=HYBRID_TOP_K)

    return mapping

# ------------------------------------------------------------
# 8. FINAL SYNTHESIS (top-K BNS output + structured LLM JSON mapping)
# ------------------------------------------------------------
def final_answer(query, bns_hits, crpc_hits, cases, crime_provenance=None):
    """
    Build the final text answer.

    query:            the user query.
    bns_hits:         BNS row dicts to re-rank; the top HYBRID_TOP_K are shown.
    crpc_hits:        CRPC search results; only the first is shown.
    cases:            matched cases; only their count is reported.
    crime_provenance: optional {"rules", "llm", "merged"} dict of detected offences.

    The answer has a deterministic summary followed by an "LLM Synthesis"
    part. The LLM is asked for a JSON offence->section mapping; if its reply
    does not contain such a mapping, the deterministic offence mapping (if
    any) and the raw reply are shown instead.
    Returns the answer as a single string.
    """
    # get top-K BNS candidates
    top_list = choose_best_bns(query, bns_hits, rerank_top_n=HYBRID_TOP_K)

    answer_lines = ["Final Legal Summary:"]

    if top_list:
        answer_lines.append("\n• Top BNS candidate sections (ranked):")
        for t in top_list:
            score = float(t.get("_score") or 0.0)
            answer_lines.append(f"  - BNS {t.get('section_id')}: {t.get('title')}  (score={score:.3f})")
            snippet = str(t.get("text") or "")[:300].replace("\n", " ")
            if snippet:
                answer_lines.append(f"     Snippet: {snippet}")
    else:
        answer_lines.append("\n• No relevant BNS section found.")

    if crpc_hits:
        c = crpc_hits[0]
        if isinstance(c, dict):
            secid = c.get("section_id", c.get("section", "N/A"))
            title = c.get("title", c.get("section__name", "CRPC match"))
            answer_lines.append(f"\n• Relevant CRPC Section {secid}: {title}")
        else:
            answer_lines.append(f"\n• Relevant CRPC hit (raw): {str(c)[:120]}")

    if cases:
        answer_lines.append(f"\n• {len(cases)} related case(s) identified.")

    if crime_provenance:
        pr = crime_provenance
        answer_lines.append("\nDetected Offences (provenance):")
        answer_lines.append(f" - rules: {pr.get('rules', [])}")
        answer_lines.append(f" - llm: {pr.get('llm', [])}")
        answer_lines.append(f" - merged: {pr.get('merged', [])}")

    answer_lines.append("\n(End of deterministic summary.)")

    # The prompt tells the model to use the deterministic candidate list, so
    # supply it. Only section numbers are sent: the keyword-based fallback
    # scans the whole prompt, and section titles would trigger its rules.
    candidate_ids = ", ".join(str(t.get("section_id")) for t in top_list)
    candidate_line = f"Deterministic candidate BNS sections: {candidate_ids}\n\n" if candidate_ids else ""

    # Jurisdiction-locked LLM prompt requesting JSON mapping
    llm_prompt = (
        "You are a legal assistant that MUST answer ONLY using Indian law (BNS, BNSS, BSA, POCSO). "
        "Do NOT use U.S., UK, or any foreign legal frameworks. "
        "Return a JSON object with this schema:\n"
        '{ "mappings": ['
        '  {"offence":"<offence name>", "suggested_bns":[<section_numbers>], '
        '   "suggested_ipc":[<ipc_numbers_optional>], "reason":"<one-line reason>"}'
        '  , ... ] }\n\n'
        "Use the query below and the deterministic candidate list when possible. Output ONLY valid JSON.\n\n"
        + candidate_line
        + (query or "")
    )

    llm_block = call_gemini(llm_prompt)

    # try to parse JSON from LLM: prefer the outermost {...} span so that
    # text around the object does not break parsing
    llm_json = None
    try:
        m = re.search(r"(\{.*\})", llm_block, flags=re.S)
        llm_json = json.loads(m.group(1) if m else llm_block)
    except (TypeError, ValueError):
        llm_json = None

    mappings = llm_json.get("mappings") if isinstance(llm_json, dict) else None

    answer_lines.append("\nLLM Synthesis:")

    if isinstance(mappings, list):
        for entry in mappings:
            if not isinstance(entry, dict):
                continue  # ignore malformed entries instead of crashing
            off = entry.get("offence")
            bns_list = entry.get("suggested_bns", [])
            ipc_list = entry.get("suggested_ipc", [])
            reason = entry.get("reason", "")
            answer_lines.append(f" - {off}: BNS {bns_list} IPC {ipc_list} ; {reason}")
    else:
        # fallback: show deterministic offence mapping (if any) and raw LLM text.
        # The mapping is only built here because it is only shown here and
        # costs extra lookups and a re-rank per offence.
        offence_mapping = {}
        if crime_provenance and top_list:
            try:
                offence_mapping = map_offences_to_sections(query, crime_provenance.get("merged", []), bns_hits)
            except Exception:
                offence_mapping = {}
        if offence_mapping:
            answer_lines.append("Deterministic offence -> candidate sections (from re-ranker):")
            for off, secs in offence_mapping.items():
                ranked = ", ".join(
                    f"{s.get('section_id')}(score={float(s.get('_score') or 0.0):.3f})"
                    for s in secs
                )
                answer_lines.append(f" - {off}: " + ranked)
        answer_lines.append("\nLLM Raw Output (fallback):")
        answer_lines.append(str(llm_block))

    return "\n".join(answer_lines)

# ------------------------------------------------------------
# 9. MAIN PIPELINE - unified ask() entrypoint
# ------------------------------------------------------------
def run_pipeline(query):
    """Alias for ask(): run the full pipeline on query and return its result dict."""
    return ask(query)

def _dedupe_by_section(rows):
    """Return rows without repeated section_id values, keeping first occurrences."""
    seen = set()
    unique = []
    for row in rows:
        sid = row.get("section_id")
        if sid not in seen:
            seen.add(sid)
            unique.append(row)
    return unique


def _crime_search_terms(crime_prov, limit=6):
    """Split merged offence labels such as "Robbery / Theft" into at most `limit` search terms."""
    if not crime_prov:
        return []
    terms = []
    for label in crime_prov.get("merged", []):
        if not isinstance(label, str):
            continue
        for part in label.split("/"):
            part = part.strip()
            if part and part not in terms:
                terms.append(part)
    return terms[:limit]


def ask(query):
    """
    Run the whole pipeline for one query and print each stage.

    Steps: rewrite the query, optionally extract offences (crime stories or
    queries longer than 120 characters), collect BNS rows (explicit sections
    first, keywords otherwise), widen a small BNS pool with embeddings,
    query the remote CRPC service, match cases, then build the final answer.

    query: the user query (None is treated as "").
    Returns {"bns", "crpc", "cases", "crime_provenance", "final"}.
    """
    query = query or ""

    # Step 1: rewrite + intent detection
    print("\n=== TASK-2(b) HYBRID ARCHITECTURE DEMO ===")
    print("\n>> STEP 1: QUERY REWRITE + INTENT")
    rewrite = rewrite_query(query)
    pprint(rewrite)

    # Step 1.5: hybrid crime extraction for long narratives
    crime_prov = None
    if rewrite["intent"].get("is_crime_story") or len(query) > 120:
        crime_prov = hybrid_extract_crimes(query)
        print("\nDetected offences (hybrid):")
        pprint(crime_prov)

    # Composite labels like "Robbery / Theft" never match as a single search
    # key, so they are split into their parts before being used below.
    crime_terms = _crime_search_terms(crime_prov)

    # ------------------------------------
    # BNS Federation (blocking)
    # ------------------------------------
    bns_hits = []

    # 1. Direct BNS section lookup (explicit numbers or IPC->BNS)
    for sec in rewrite["sections"]:
        bns_hits += bns_by_section(sec)
        for mapped in IPC_TO_BNS.get(sec, []):
            bns_hits += bns_by_section(mapped)

    # 2. Keyword search using rewrite keywords and also include rules/llm extracted crime keywords
    if not bns_hits:
        for kw in rewrite["bns_keywords"][:10] + crime_terms:
            try:
                bns_hits += bns_by_keyword(kw)
            except Exception as e:
                # best-effort: one failed keyword must not stop the others
                print(f"BNS keyword search failed for {kw!r}:", e)

    # deduplicate BNS hits by section_id (preserve order)
    bns_hits = _dedupe_by_section(bns_hits)

    print(f"\nBNS hits: {len(bns_hits)}")

    # If too few candidates, expand using embeddings
    if len(bns_hits) <= EMBED_EXPAND_THRESHOLD:
        try:
            expanded = expand_candidates_with_embeddings(query, bns_hits, all_limit=EMBED_EXPAND_LIMIT)
            bns_hits = _dedupe_by_section(expanded)
            print(f"BNS hits after embedding expansion: {len(bns_hits)}")
        except Exception as e:
            print("embedding expansion error:", e)

    # ------------------------------------
    # CRPC Federation (remote)
    # ------------------------------------
    crpc_hits = []
    for kw in rewrite["crpc_keywords"][:6] + crime_terms:
        try:
            found = crpc_search(kw)
        except Exception as e:
            print(f"CRPC search failed for {kw!r}:", e)
            continue
        # only merge list results; extending with a dict would add its keys
        if isinstance(found, list):
            crpc_hits += found
    print(f"CRPC hits: {len(crpc_hits)}")

    # ------------------------------------
    # Case federation
    # ------------------------------------
    cases = case_search(rewrite["bns_keywords"] + crime_terms)
    print(f"Case hits: {len(cases)}")

    # ------------------------------------
    # Final Integration & Synthesis
    # ------------------------------------
    print("\n>> STEP 3: FINAL INTEGRATION & SYNTHESIS")
    final = final_answer(query, bns_hits, crpc_hits, cases, crime_provenance=crime_prov)
    print("\n=== FINAL ANSWER ===")
    print(final)
    print("=====================================\n")

    return {
        "bns": bns_hits,
        "crpc": crpc_hits,
        "cases": cases,
        "crime_provenance": crime_prov,
        "final": final
    }

# ------------------------------------------------------------
# If invoked directly, run a quick self-test
# ------------------------------------------------------------
if __name__ == "__main__":
    # Script entry point: requires the BNS database file, then runs one query.
    if os.path.exists(BNS_DB_PATH):
        try:
            q = input("Enter query: ").strip()
        except Exception:
            # input() failed (e.g. no usable stdin): use a sample query
            q = "punishment for rape and murder"
        out = ask(q)
    else:
        print("ERROR: cannot find BNS DB at", BNS_DB_PATH)
        raise SystemExit(1)
