In [None]:
# SECURITY: never hard-code API keys in source files or notebooks. Provide
# GEMINI_API_KEY through the process environment (or a secrets manager) before
# starting the kernel. Any key that was previously committed in this cell must
# be treated as leaked: revoke it and issue a new one.
import os

if not os.environ.get("GEMINI_API_KEY"):
    print("GEMINI_API_KEY is not set; LLM calls will use the offline fallback.")


In [None]:
#!/usr/bin/env python3
"""
FINAL TASK-2(b) HYBRID QUERY FEDERATION SYSTEM - UPDATED
- Backwards compatible with your original script
- Adds Hybrid Crime/Intent extraction (Rule + LLM fallback)
- Adds unified `ask()` entrypoint
- Adds provenance for detections
- Uses Gemini-style API wrapper if configured, otherwise safe simulated LLM
- Keeps original BNS/CRPC/CASES logic and remote CRPC call
"""

import json
import logging
import os
import re
import sqlite3
import time
from pprint import pprint

import google.generativeai as genai
import requests


# ------------------------------------------------------------
# CONFIGURATION (edit GEMINI_API_KEY if available)
# ------------------------------------------------------------
# Base URL of the remote CRPC search service queried by crpc_search().
CRPC_SERVER = "http://192.168.226.115:5000"
# SQLite database holding the `bns_sections` table.
BNS_DB_PATH = "bns.db"
# JSON file holding the case-law records scanned by case_search().
CASES_JSON = "legal_cases.json"
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")  # optional: set in environment

# Load BNS DB; sqlite3.Row lets result rows be converted with dict(row).
BNS = sqlite3.connect(BNS_DB_PATH)
BNS.row_factory = sqlite3.Row

# Load JSON legal cases. A missing, unreadable or malformed file is not fatal:
# the pipeline simply runs without case-law records.
try:
    # Context manager closes the handle promptly; JSON text is UTF-8.
    with open(CASES_JSON, "r", encoding="utf-8") as _cases_file:
        LEGAL_CASES = json.load(_cases_file)
except (OSError, ValueError):
    # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
    LEGAL_CASES = []

# ------------------------------------------------------------
# IPC → BNS OFFICIAL MAPPING (IMPORTANT!!)
# (kept unchanged)
# ------------------------------------------------------------
# Maps an IPC section number (as a string, the way it is extracted from the
# query text) to the BNS section numbers that are looked up in its place.
# NOTE(review): values are reproduced unchanged from the original table;
# confirm them against the official IPC-to-BNS correspondence before relying
# on them.
IPC_TO_BNS = {
    "420": [316, 317, 318, 319],
    "417": [316],
    "415": [315],
    "376": [63, 64, 65],
    "302": [103],
    "304": [104],
    "307": [109],
    "354": [73],
}

# ------------------------------------------------------------
# Helper: simple safe Gemini-like wrapper (fallback to local simulation)
# ------------------------------------------------------------
def call_gemini(prompt, max_tokens=512, timeout=5.0):
    """
    Send ``prompt`` to Gemini, or answer from the offline fallback.

    The live model ("models/gemini-2.5-pro") is used only when the
    module-level ``GEMINI_API_KEY`` is non-empty. If no key is configured, or
    the live call raises for any reason, a deterministic local fallback
    produces the reply instead, so callers always receive a usable answer
    rather than an error message embedded in their output.

    Args:
        prompt: Text sent verbatim to the model. ``None`` is treated as "".
        max_tokens: Accepted for backwards compatibility; not currently
            forwarded to the model.
        timeout: Accepted for backwards compatibility; not currently
            forwarded to the model.

    Returns:
        str: The model's text. In fallback mode this is a JSON document of the
        form ``{"crimes": [{"crime": ..., "confidence": 0.8}, ...]}`` for
        extraction prompts, otherwise a canned one-line answer.
    """
    # ---------------------- REAL GEMINI CALL ----------------------
    if GEMINI_API_KEY:
        try:
            genai.configure(api_key=GEMINI_API_KEY)
            model = genai.GenerativeModel("models/gemini-2.5-pro")

            # The prompt is passed to the model exactly as received.
            response = model.generate_content(prompt)

            if hasattr(response, "text"):
                return response.text
            return str(response)
        except Exception as exc:
            # Best-effort boundary: any SDK/network failure is reported in the
            # log and the offline fallback below answers instead.
            logging.getLogger(__name__).warning(
                "Gemini call failed (%s); using offline fallback.", exc
            )

    # ---------------------- FALLBACK MODE -------------------------
    prompt_l = (prompt or "").lower()

    # Extraction prompts get a JSON reply built from a keyword table.
    if "extract all crimes" in prompt_l or "identify all crimes" in prompt_l or '"crimes"' in prompt_l:
        keyword_to_crime = (
            ("kidnap", "Kidnapping"),
            ("abduct", "Kidnapping"),
            ("rape", "Rape"),
            ("sexual assault", "Rape"),
            ("murder", "Murder"),
            ("kill", "Murder"),
            ("strangle", "Murder"),
            ("threat", "Criminal Intimidation"),
            ("extort", "Extortion"),
            ("money", "Extortion"),
            ("cut", "Destruction of Evidence"),
            ("dismember", "Destruction of Evidence"),
            ("rob", "Robbery"),
            ("steal", "Theft"),
        )
        detected = []
        for keyword, crime in keyword_to_crime:
            # Substring scan over the whole prompt; each crime is listed once,
            # in table order.
            if keyword in prompt_l and crime not in detected:
                detected.append(crime)

        return json.dumps({"crimes": [{"crime": c, "confidence": 0.8} for c in detected]})

    # Generic canned replies; long prompts are treated as narratives.
    if len(prompt_l) > 300:
        return "LLM summary (fallback): Narrative contains multiple offences; refer to relevant sections."
    return "LLM answer (fallback): Refer to the fetched sections from BNS/CRPC and case-law for details."


# ------------------------------------------------------------
# 1. QUERY REWRITE (LLM-DEMO + deterministic rules)
#    - preserved your original rewrite_query, but added intent detection & provenance
# ------------------------------------------------------------
def rewrite_query(q):
    """
    Derive search keywords, section numbers and intent flags from a query.

    Trigger terms are matched case-insensitively at the START of a word, so
    "rob" matches "robbery"/"robbed" but not "problem", and "kill" matches
    "killed" but not "skill".

    Args:
        q: The user's query text. ``None`` is treated as "".

    Returns:
        dict with keys:
            "bns_keywords":  de-duplicated keywords for the BNS search,
            "crpc_keywords": de-duplicated keywords for the CRPC search,
            "sections":      de-duplicated 2-4 digit numbers found in the
                             query, as strings (with or without a leading
                             "section"),
            "raw_query":     the query as given ("" for ``None``),
            "intent":        dict of booleans: is_crime_story,
                             is_section_lookup, is_definition, is_comparison.
    """
    raw = q or ""
    text = raw.lower()

    def mentions(*terms):
        # \b anchors each term to a word start; prefixes such as "deceiv" or
        # "intimidat" still match their inflected forms.
        return any(re.search(r"\b" + re.escape(term), text) for term in terms)

    # (trigger terms, BNS keywords to add, CRPC keywords to add)
    rules = (
        (("rape", "sexual"),
         ["rape", "sexual", "woman"], ["rape"]),
        # "defraud" is listed explicitly because word-start matching would
        # otherwise no longer see the "fraud" inside it.
        (("cheat", "fraud", "defraud", "dishonest", "deceiv"),
         ["cheat", "fraud", "dishonest"], ["cheating", "420"]),
        (("murder", "kill", "homicide", "strangle", "stab"),
         ["murder", "kill", "death", "homicide"], ["murder", "302"]),
        (("kidnap", "abduct", "abduction"),
         ["kidnap", "abduct"], ["kidnapping"]),
        (("threat", "intimidat", "blackmail", "extort", "demand money"),
         ["threat", "intimidation", "extortion"], ["intimidation", "extortion"]),
        (("rob", "robbery", "steal", "theft"),
         ["robbery", "theft"], ["robbery", "theft"]),
    )

    bns_kw, crpc_kw = [], []
    for triggers, bns_terms, crpc_terms in rules:
        if mentions(*triggers):
            bns_kw += bns_terms
            crpc_kw += crpc_terms

    # Every standalone 2-4 digit number counts as a section reference, whether
    # or not it is preceded by the word "section" (e.g. "section 302", "420").
    sections = re.findall(r"\b(?:section\s+)?(\d{2,4})\b", text)

    intent = {
        "is_crime_story": mentions("kidnap", "murder", "rape", "threat", "extort", "rob", "stalk", "abduct"),
        "is_section_lookup": bool(re.search(r"\bsection\s+\d{2,4}\b", text)),
        "is_definition": mentions("what is", "explain", "define", "meaning of"),
        "is_comparison": mentions("compare", "difference", "vs", "versus"),
    }

    # dict.fromkeys removes duplicates while keeping first-seen order.
    return {
        "bns_keywords": list(dict.fromkeys(bns_kw)),
        "crpc_keywords": list(dict.fromkeys(crpc_kw)),
        "sections": list(dict.fromkeys(sections)),
        "raw_query": raw,
        "intent": intent,
    }

# ------------------------------------------------------------
# 2. BNS SEARCH (SECTION & KEYWORD) - unchanged queries, defensive code added
# ------------------------------------------------------------
def bns_by_section(s):
    """
    Fetch the BNS rows whose ``section`` column equals ``s``.

    Args:
        s: Section number as an int or anything ``int()`` accepts.

    Returns:
        list of dicts with keys "section_id", "title" and "text"; an empty
        list when ``s`` cannot be converted to an integer or nothing matches.
    """
    cursor = BNS.cursor()
    try:
        section_number = int(s)
    except Exception:
        # Not a usable section number: report "no hits" rather than fail.
        return []
    lookup_sql = """
        SELECT section AS section_id, section__name AS title, description AS text
        FROM bns_sections
        WHERE section = ?
    """
    cursor.execute(lookup_sql, (section_number,))
    rows = cursor.fetchall()
    return list(map(dict, rows))

def bns_by_keyword(kw):
    """
    Fetch the BNS rows whose name or description contains ``kw``.

    The comparison is case-insensitive: the keyword is lower-cased here and
    the columns are lower-cased in SQL.

    Args:
        kw: Keyword string to look for.

    Returns:
        list of dicts with keys "section_id", "title" and "text".
    """
    like_pattern = "%" + kw.lower() + "%"
    cursor = BNS.cursor()
    search_sql = """
        SELECT section AS section_id, section__name AS title, description AS text
        FROM bns_sections
        WHERE LOWER(section__name) LIKE ?
           OR LOWER(description) LIKE ?
    """
    # The same pattern is bound to both placeholders (name and description).
    cursor.execute(search_sql, (like_pattern, like_pattern))
    return list(map(dict, cursor.fetchall()))

# ------------------------------------------------------------
# 3. CRPC SEARCH (REMOTE SERVER) - preserved but added provenance + retry
# ------------------------------------------------------------
def crpc_search(kw):
    """
    Query the remote CRPC service for ``kw``.

    The request is attempted with a 1.5 s timeout and, if that attempt raises,
    once more with a 3.0 s timeout. This is a best-effort lookup: failures are
    logged and reported as "no hits" instead of being raised.

    Args:
        kw: Search keyword. Falsy values (``None``, "") short-circuit to [].

    Returns:
        The decoded JSON body of a successful response (presumably a list of
        section records; the schema is defined by the remote service), or []
        when the keyword is empty, the response status is not OK, or both
        attempts fail.
    """
    if not kw:
        return []

    log = logging.getLogger(__name__)
    url = f"{CRPC_SERVER}/search_crpc"
    params = {"q": kw, "limit": 10}
    last_error = None

    # Second entry is the single retry with a longer timeout.
    for timeout in (1.5, 3.0):
        try:
            response = requests.get(url, params=params, timeout=timeout)
            if not response.ok:
                # A non-OK status is a definitive answer from the server, so
                # it is not retried.
                log.warning("CRPC search for %r returned HTTP %s", kw, response.status_code)
                return []
            return response.json()
        except (requests.RequestException, ValueError) as exc:
            # RequestException: connection/timeout errors; ValueError: body
            # was not valid JSON.
            last_error = exc

    log.warning("CRPC search for %r failed after retry: %s", kw, last_error)
    return []

# ------------------------------------------------------------
# 4. CASE SEARCH - preserved with provenance
# ------------------------------------------------------------
def case_search(keywords):
    """
    Select the loaded legal cases that mention any of the given keywords.

    Each case is serialised with ``json.dumps`` and lower-cased, then checked
    for a substring match against the first six keywords only.

    Args:
        keywords: Sequence of keyword strings.

    Returns:
        list of matching case records; when nothing matches, the first two
        loaded cases; an empty list when no cases are loaded at all.
    """
    if not LEGAL_CASES:
        return []

    def mentions_keyword(case):
        blob = json.dumps(case).lower()
        return any(term.lower() in blob for term in keywords[:6])

    matched = [case for case in LEGAL_CASES if mentions_keyword(case)]
    # No match: fall back to the first two cases instead of returning nothing.
    return matched or LEGAL_CASES[:2]

# ------------------------------------------------------------
# 5. BEST BNS SECTION SELECTION (KEPT + improved scoring + provenance)
# ------------------------------------------------------------
def choose_best_bns(query, bns_hits):
    """
    Pick the BNS hit that best matches the query.

    Each candidate is scored as 2 points per occurrence of a priority term in
    its title + text, plus 10 points when its section number appears in the
    query. Sections 1-5 are treated as generic and skipped unless they are the
    only hits. Ties are resolved in favour of the earliest hit.

    Args:
        query: The user's query text. ``None`` is treated as "".
        bns_hits: List of dicts with "section_id", "title" and "text" keys.
            Missing or ``None`` title/text values are treated as empty.

    Returns:
        The winning hit dict, or ``None`` when ``bns_hits`` is empty. The
        returned dict is the same object that was passed in, with its score
        recorded under the "_score" key.
    """
    q = (query or "").lower()
    if not bns_hits:
        return None

    # Priority terms: a fixed set for the main crime families, otherwise the
    # first six words of the query longer than three characters.
    if "rape" in q:
        priority = ["rape", "sexual"]
    elif "cheat" in q or "fraud" in q:
        priority = ["cheat", "fraud", "dishonest"]
    elif "murder" in q or "kill" in q:
        priority = ["murder", "kill", "death"]
    else:
        priority = [w for w in re.findall(r"\w+", q) if len(w) > 3][:6]

    generic_ids = {1, 2, 3, 4, 5}
    candidates = [s for s in bns_hits if s.get("section_id") not in generic_ids] or bns_hits

    # Section numbers mentioned in the query; computed once, not per candidate.
    mentioned_sections = set(re.findall(r"\b(\d{2,4})\b", q))

    def score(section):
        # `or ""` guards against NULL columns, which arrive as None.
        haystack = f"{section.get('title') or ''} {section.get('text') or ''}".lower()
        total = sum(haystack.count(term) * 2 for term in priority)
        if str(section.get("section_id")) in mentioned_sections:
            total += 10
        return total

    scored = [(score(section), section) for section in candidates]
    # max() returns the first of equally scored candidates, matching a stable
    # descending sort.
    best_score, best = max(scored, key=lambda pair: pair[0])
    # Attach score for provenance.
    best["_score"] = best_score
    return best

# ------------------------------------------------------------
# 6. LLM-BASED (or simulated) CRIME EXTRACTION + MERGE WITH RULES
# ------------------------------------------------------------
def rule_based_extract_crimes(text):
    """
    Detect offence categories in a narrative using fixed trigger terms.

    Terms are matched case-insensitively at the START of a word, so "rob"
    matches "robbed"/"robbery" but not "problem", and "rape" does not match
    "grape".

    Args:
        text: Narrative text. ``None`` is treated as "".

    Returns:
        list of category labels, each at most once, in the fixed order of the
        rule table below.
    """
    lowered = (text or "").lower()

    def mentions(term):
        return re.search(r"\b" + re.escape(term), lowered) is not None

    # (label, trigger terms)
    rules = (
        ("Kidnapping", ("kidnap", "abduct", "abduction")),
        ("Rape", ("rape", "sexual assault", "sexual")),
        ("Murder", ("murder", "kill", "homicide", "strangle", "stab")),
        ("Criminal Intimidation / Extortion", ("threat", "intimidation", "blackmail", "extort")),
        ("Destruction of Evidence / Dismemberment", ("dismember", "cut body", "chop", "body parts")),
        ("Robbery / Theft", ("rob", "steal", "theft")),
    )
    return [label for label, triggers in rules if any(mentions(t) for t in triggers)]

def llm_extract_crimes_full(text):
    """
    Ask the LLM wrapper to list the crimes in a narrative.

    The reply is expected to contain a JSON object shaped like
    ``{"crimes": [{"crime": "...", ...}, ...]}``. Because models often wrap
    JSON in markdown fences or surrounding prose, the reply is parsed first as
    is and then, if that fails, from its first "{" to its last "}".

    Args:
        text: Narrative text. ``None`` is treated as "".

    Returns:
        list of crime names (strings) in reply order without duplicates. When
        no usable JSON object with a list under "crimes" can be recovered, the
        result of ``rule_based_extract_crimes(text)`` is returned instead.
    """
    prompt = (
        "Extract all crimes from the following narrative and return JSON: "
        '{"crimes":[{"crime":"...","details":"..."}]} '
        "Narrative:\n\n" + (text or "")
    )
    reply = call_gemini(prompt)

    if isinstance(reply, str):
        candidates = [reply.strip()]
        start, end = reply.find("{"), reply.rfind("}")
        if start != -1 and end > start:
            # Strips ```json fences and any prose around the object.
            candidates.append(reply[start:end + 1])

        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except ValueError:
                continue
            if not isinstance(parsed, dict):
                continue
            entries = parsed.get("crimes", [])
            if not isinstance(entries, list):
                continue
            # Keep only well-formed entries; names must be strings because
            # they are later used as search keywords.
            crimes = [
                entry["crime"]
                for entry in entries
                if isinstance(entry, dict)
                and isinstance(entry.get("crime"), str)
                and entry["crime"]
            ]
            return list(dict.fromkeys(crimes))

    # No parseable reply: fall back to the rule-based extraction.
    return rule_based_extract_crimes(text)

def hybrid_extract_crimes(text):
    """
    Run both crime extractors on ``text`` and report each one's findings.

    Returns:
        dict with "rules" (rule-based result), "llm" (LLM-based result) and
        "merged" (rules followed by llm, duplicates removed, order kept).
    """
    from_rules = rule_based_extract_crimes(text)
    from_llm = llm_extract_crimes_full(text)
    combined = list(dict.fromkeys([*from_rules, *from_llm]))
    return {"rules": from_rules, "llm": from_llm, "merged": combined}

# ------------------------------------------------------------
# 7. FINAL SYNTHESIS (LLM-DEMO + deterministic punishment detection)
# ------------------------------------------------------------
def final_answer(query, bns_hits, crpc_hits, cases, crime_provenance=None):
    """
    Assemble the final text answer from the retrieved material.

    The answer has a deterministic part (best BNS section with the first line
    of its text that looks like a punishment clause, the first CRPC hit, the
    number of cases, and optionally the crime-extraction provenance) followed
    by an "LLM Synthesis" block produced by ``call_gemini``.

    Args:
        query: The user's query text (``None`` is treated as "" in the prompt).
        bns_hits: BNS hit dicts, ranked via ``choose_best_bns``.
        crpc_hits: Hits from the CRPC service; only the first is shown.
        cases: Matching case records; only their count is shown.
        crime_provenance: Optional dict with "rules", "llm", "merged" lists.

    Returns:
        str: The answer, one entry per line.
    """
    best_section = choose_best_bns(query, bns_hits)
    report = ["Final Legal Summary:"]

    if not best_section:
        report.append("\n• No relevant BNS section found.")
    else:
        report.append(
            f"\n• Relevant BNS Section {best_section['section_id']}: {best_section['title']}"
        )
        # Heuristic: the first line mentioning any of these markers is shown
        # as the punishment.
        penalty_markers = ["punish", "life", "imprison", "death", "fine"]
        section_lines = (best_section.get("text", "") or "").split("\n")
        first_penalty = next(
            (
                ln.strip()
                for ln in section_lines
                if any(marker in ln.lower() for marker in penalty_markers)
            ),
            None,
        )
        if first_penalty is not None:
            report.append(f"  Punishment: {first_penalty}")

    if crpc_hits:
        lead = crpc_hits[0]
        if not isinstance(lead, dict):
            # Unknown payload shape: show a truncated raw rendering.
            report.append(f"• Relevant CRPC hit (raw): {str(lead)[:120]}")
        else:
            # Accept either naming scheme for the id and title fields.
            section_ref = lead.get("section_id", lead.get("section", "N/A"))
            heading = lead.get("title", lead.get("section__name", "CRPC match"))
            report.append(f"• Relevant CRPC Section {section_ref}: {heading}")

    if cases:
        report.append(f"• {len(cases)} related case(s) identified.")

    if crime_provenance:
        report.append("\nDetected Offences (provenance):")
        for source in ("rules", "llm", "merged"):
            report.append(f" - {source}: {crime_provenance.get(source, [])}")

    report.append("\n(End of deterministic summary.)")

    # The LLM explanation is appended as a separate, non-deterministic block.
    instructions = (
        "You are a legal assistant that MUST answer ONLY using Indian law. "
        "Do NOT use U.S., UK, or any foreign legal frameworks. "
        "Base all reasoning strictly on Bharatiya Nyaya Sanhita (BNS 2023), "
        "Bharatiya Nagarik Suraksha Sanhita (BNSS 2023), "
        "Bharatiya Sakshya Adhiniyam (BSA 2023), the POCSO Act 2012, "
        "and relevant Indian constitutional or statutory provisions. "
        "If the crime existed under IPC before BNS, give the corresponding IPC "
        "section as historical reference ONLY when helpful. "
        "Now provide concise legal reasoning and mapping for this query: "
    )
    llm_prompt = instructions + (query or "")
    synthesis = call_gemini(llm_prompt)
    report.extend(["\nLLM Synthesis:", str(synthesis)])

    return "\n".join(report)

# ------------------------------------------------------------
# 8. MAIN PIPELINE - unified ask() entrypoint (backwards compatible)
# ------------------------------------------------------------
def run_pipeline(query):
    """Backwards-compatible entry point; delegates to ``ask`` unchanged."""
    result = ask(query)
    return result

def ask(query):
    """
    Run the full pipeline for one query and print a step-by-step trace.

    Steps: rewrite the query and detect intent; optionally extract crimes
    (for crime stories or queries longer than 120 characters); collect BNS
    hits (explicit section numbers and their IPC->BNS mappings first, keyword
    search only if those found nothing); collect CRPC hits from the remote
    service; collect matching cases; synthesise the final answer.

    Args:
        query: The user's query text. ``None`` is tolerated and treated as an
            empty query.

    Returns:
        dict with keys "bns" (de-duplicated BNS hits), "crpc" (CRPC hits),
        "cases" (case records), "crime_provenance" (dict or ``None``) and
        "final" (the answer text).
    """
    log = logging.getLogger(__name__)

    # Step 1: rewrite + intent detection
    print("\n=== TASK-2(b) HYBRID ARCHITECTURE DEMO ===")
    print("\n>> STEP 1: QUERY REWRITE + INTENT")
    rewrite = rewrite_query(query)
    pprint(rewrite)

    # Step 1.5: hybrid crime extraction for crime stories and long narratives.
    # `query or ""` keeps a None query from crashing len().
    crime_prov = None
    if rewrite["intent"].get("is_crime_story") or len(query or "") > 120:
        crime_prov = hybrid_extract_crimes(query)
        print("\nDetected offences (hybrid):")
        pprint(crime_prov)
    merged_crimes = crime_prov.get("merged", []) if crime_prov else []

    # ------------------------------------
    # BNS Federation
    # ------------------------------------
    bns_hits = []

    # 1. Direct lookup of every number found in the query, plus the BNS
    #    sections mapped from it when it is a known IPC section number.
    for sec in rewrite["sections"]:
        bns_hits += bns_by_section(sec)
        for mapped in IPC_TO_BNS.get(sec, ()):
            bns_hits += bns_by_section(mapped)

    # 2. Keyword search only when the direct lookup found nothing.
    if not bns_hits:
        for kw in rewrite["bns_keywords"][:10] + merged_crimes[:6]:
            try:
                bns_hits += bns_by_keyword(kw)
            except Exception as exc:
                # Best effort: one failing keyword must not abort the query.
                log.warning("BNS keyword search for %r failed: %s", kw, exc)

    # De-duplicate by section_id, keeping the first occurrence and its order.
    first_by_section = {}
    for hit in bns_hits:
        first_by_section.setdefault(hit.get("section_id"), hit)
    bns_hits = list(first_by_section.values())

    print(f"\nBNS hits: {len(bns_hits)}")

    # ------------------------------------
    # CRPC Federation (remote)
    # ------------------------------------
    crpc_hits = []
    for kw in rewrite["crpc_keywords"][:6] + merged_crimes[:6]:
        try:
            crpc_hits += crpc_search(kw)
        except Exception as exc:
            # Best effort: e.g. the service returned a non-iterable payload.
            log.warning("CRPC search for %r failed: %s", kw, exc)
    print(f"CRPC hits: {len(crpc_hits)}")

    # ------------------------------------
    # Case federation
    # ------------------------------------
    cases = case_search(rewrite["bns_keywords"] + merged_crimes)
    print(f"Case hits: {len(cases)}")

    # ------------------------------------
    # Final Integration & Synthesis
    # ------------------------------------
    print("\n>> STEP 3: FINAL INTEGRATION & SYNTHESIS")
    final = final_answer(query, bns_hits, crpc_hits, cases, crime_provenance=crime_prov)
    print("\n=== FINAL ANSWER ===")
    print(final)
    print("=====================================\n")

    return {
        "bns": bns_hits,
        "crpc": crpc_hits,
        "cases": cases,
        "crime_provenance": crime_prov,
        "final": final,
    }

# Script entry point: read one query from standard input and run it through
# the pipeline. The result is kept in `out` for interactive inspection.
if __name__ == "__main__":
    user_query = input("Enter query: ").strip()
    out = ask(user_query)
