# In [None]:
import os, re, json, time, ast, requests
from functools import lru_cache
from typing import Dict, List, Optional, Tuple

import pandas as pd
from tqdm.auto import tqdm

import google.generativeai as genai
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers import logging as hf_logging
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient
from datasets import load_dataset

# Credentials come from Kaggle's secret store (kaggle_secrets is imported above),
# so this cell presumably expects to run inside a Kaggle notebook.
user_secrets = UserSecretsClient()
HF_TOKEN = user_secrets.get_secret("HF_TOKEN")
SERPER_API_KEY = user_secrets.get_secret("SERPER_API_KEY")
GEMINI_API_KEY = user_secrets.get_secret("GEMINI_API_KEY")

# Authenticate with the Hugging Face Hub before downloading the SLM
# (presumably required because the Llama checkpoint is gated — TODO confirm).
login(token=HF_TOKEN)

# Local small language model used for query decomposition and hint extraction.
model_name = "meta-llama/Llama-3.2-3B-Instruct"
DEVICE_MAP = "auto"
# Remote Gemini model that produces the final answer letter.
llm_name = "gemini-2.5-flash-lite"

# Evaluation settings.
SPLIT_NAME = "murder_mysteries"  # MuSR split to evaluate
NUM_EXAMPLES = 100  # stop after this many non-skipped examples
PRINT_FIRST_N_DEBUG = 3  # examples run with verbose=True until this many are used
SLEEP_BETWEEN_GEMINI_CALLS = 5.0  # seconds slept after every successful Gemini call
WEB_TIMEOUT = 10  # per-request timeout (seconds) for the Serper search request
web_results = 5  # upper bound on snippets returned by a single web_search call

genai.configure(api_key=GEMINI_API_KEY)
gemini_model = genai.GenerativeModel(llm_name)
print("Gemini model initialised.")

print("Loading local SLM (this can take a bit)...")
# Load the SLM with 4-bit bitsandbytes quantization to reduce memory use.
bnb_config = BitsAndBytesConfig(load_in_4bit=True)
slm_tokenizer = AutoTokenizer.from_pretrained(model_name)
slm_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=DEVICE_MAP,
    quantization_config=bnb_config,
)
# Only errors from transformers are logged from here on.
hf_logging.set_verbosity_error()
# Reuse EOS as the pad token and tell the model config about it; presumably the
# tokenizer ships without a pad token and generate() would otherwise warn —
# TODO confirm.
slm_tokenizer.pad_token = slm_tokenizer.eos_token
slm_model.config.pad_token_id = slm_tokenizer.pad_token_id
print("Local SLM loaded.")


def slm_generate(prompt: str, max_new_tokens: int = 200) -> str:
    """Run the local SLM on ``prompt`` and return only the newly generated text.

    For decoder-only models ``generate`` returns the prompt token ids followed
    by the continuation. Decoding ``out_ids[0]`` wholesale would hand the
    prompt back to callers, whose regex / bullet parsers would then pick up
    template text from the prompt (e.g. "WHY: reasoning or explanation to
    focus on") instead of the model's answer. Only the tokens after the prompt
    are decoded here.

    Args:
        prompt: Raw text prompt (no chat template is applied).
        max_new_tokens: Generation budget for the continuation.

    Returns:
        The decoded continuation with special tokens removed, whitespace-stripped.
    """
    inputs = slm_tokenizer(prompt, return_tensors="pt").to(slm_model.device)
    out_ids = slm_model.generate(**inputs, max_new_tokens=max_new_tokens)
    prompt_len = inputs["input_ids"].shape[-1]
    new_tokens = out_ids[0][prompt_len:]
    text = slm_tokenizer.decode(new_tokens, skip_special_tokens=True)
    return text.strip()


def web_search(query: str, num_results: int = 2):
    """Query the Serper Google-search API and return up to ``num_results`` snippets.

    ``num_results`` is clamped to the module-level ``web_results`` cap.

    Returns:
        A list of ``{"snippet": str, "url": str}`` dicts taken from the
        "organic" results that have a non-empty snippet. If no API key is set,
        the request fails, or nothing usable comes back, a single placeholder
        entry ``"General information about: <query>"`` is returned instead
        (so for ``num_results <= 0`` the result is an empty list).
    """
    num_results = min(num_results, web_results)
    results = []
    if SERPER_API_KEY:
        url = "https://google.serper.dev/search"
        headers = {"X-API-KEY": SERPER_API_KEY, "Content-Type": "application/json"}
        payload = {"q": query}
        try:
            r = requests.post(url, headers=headers, json=payload, timeout=WEB_TIMEOUT)
            # Report HTTP errors (invalid key, exhausted quota, rate limiting)
            # instead of silently parsing an error body with no "organic" field.
            r.raise_for_status()
            data = r.json()
            for item in (data.get("organic") or [])[:num_results]:
                snippet = (item.get("snippet") or "").replace("\n", " ")
                link = item.get("link") or ""
                if snippet:
                    results.append({"snippet": snippet, "url": link})
        except Exception as e:
            # Best-effort lookup: any network/HTTP/JSON failure falls back to
            # the placeholder below, but the cause is printed.
            print("Serper error:", e)
    if not results:
        results.append({"snippet": f"General information about: {query}", "url": ""})
    return results[:num_results]


@lru_cache(maxsize=4096)
def web_search_cached(q: str, num_results: int = 2):
    """Memoised front-end to ``web_search``.

    Each hit is flattened to a ``(snippet, url)`` pair and the whole result is
    returned as a tuple, presumably so the cached value is immutable and shared
    safely between callers. Note that fallback placeholders produced on a
    failed search are cached too.
    """
    pairs = []
    for hit in web_search(q, num_results=num_results):
        pairs.append((hit["snippet"], hit["url"]))
    return tuple(pairs)


def call_gemini_with_cooldown(prompt: str, max_retries: int = 3):
    """Call ``gemini_model.generate_content`` with retries and a post-call cooldown.

    After every successful call the function sleeps
    ``SLEEP_BETWEEN_GEMINI_CALLS`` seconds to pace requests. When a call
    raises, it waits for the delay suggested in the error text ("retry in Ns"
    or "seconds: N"), defaulting to 60 s, and tries again. At most
    ``max_retries`` attempts are made in total, and there is no wait after the
    final failed attempt.

    Returns:
        The response object from ``generate_content``.

    Raises:
        RuntimeError: if every attempt fails (chained to the last exception).
    """
    last_exc = None
    for attempt in range(1, max_retries + 1):
        try:
            resp = gemini_model.generate_content(prompt)
        except Exception as e:
            last_exc = e
        else:
            time.sleep(SLEEP_BETWEEN_GEMINI_CALLS)
            return resp

        if attempt >= max_retries:
            # No retry left: raise now rather than sleeping (often 60 s) first.
            print(f"[Gemini rate/HTTP error] Attempt {attempt}/{max_retries} -> giving up.")
            break

        msg = str(last_exc)
        wait_s = 60.0
        for pattern in (r"retry in ([0-9.]+)s", r"seconds:\s*([0-9]+)"):
            m = re.search(pattern, msg, flags=re.I)
            if not m:
                continue
            try:
                wait_s = float(m.group(1))
            except ValueError:
                # e.g. "retry in .s": ignore the malformed hint.
                continue
            break
        print(f"[Gemini rate/HTTP error] Attempt {attempt}/{max_retries} -> sleeping {wait_s:.1f}s...")
        time.sleep(wait_s)
    raise RuntimeError("Gemini generate_content failed after retries.") from last_exc


def _looks_bad(q: str) -> bool:
    q = (q or "").strip().lower()
    if not q:
        return True
    if q in {"<why>", "<what>", "<how>"}:
        return True
    if len(re.sub(r"[^a-z]", "", q)) < 5:
        return True
    return False


def _simple_fallback_decomp(full_q: str) -> Dict[str, str]:
    base = re.sub(r"\s*\?$", "", full_q).strip()
    if not base:
        base = full_q.strip()
    return {
        "WHY": f"Why is the following question challenging to answer correctly: {base}?",
        "WHAT": f"What key facts from the narrative are needed to answer: {base}?",
        "HOW": f"How can reasoning over the narrative help answer: {base}?",
    }


def slm_decompose_queries(full_question: str) -> Dict[str, str]:
    """Ask the local SLM to split a MuSR item into WHY / WHAT / HOW sub-queries.

    Args:
        full_question: The narrative followed by the question text.

    Returns:
        A dict with keys "WHY", "WHAT", "HOW" (in that order). If any of the
        three cannot be parsed from the model output, or ``_looks_bad``
        rejects it, the generic templates from ``_simple_fallback_decomp``
        are returned instead.
    """
    prompt = (
        "You are a query decomposition assistant for MuSR-style multi-step reasoning tasks.\n"
        "Given a narrative and a question, rewrite it into exactly three concise sub-queries:\n"
        "WHY: reasoning or explanation to focus on\n"
        "WHAT: key entities, events, or definitions needed\n"
        "HOW: steps or logical process to connect facts\n\n"
        "Return EXACTLY:\nWHY: <text>\nWHAT: <text>\nHOW: <text>\n\n"
        f"Input (narrative + question): {full_question}\n\nDecomposed queries:"
    )
    text = slm_generate(prompt, max_new_tokens=192)
    # Parse only what follows the final cue. If the generator echoes the
    # prompt, the template lines above ("WHY: reasoning or explanation ...",
    # "WHY: <text>") would otherwise be captured instead of the model's answer.
    # Without an echo, rpartition leaves the text unchanged.
    completion = text.rpartition("Decomposed queries:")[2]
    queries: Dict[str, str] = {}
    for tag in ("WHY", "WHAT", "HOW"):
        match = re.search(rf"{tag}:\s*(.+)", completion, flags=re.I)
        # Also drop markdown emphasis left around the value, e.g. "**WHY:** ...".
        queries[tag] = match.group(1).strip().strip("*").strip() if match else ""
    if any(_looks_bad(value) for value in queries.values()):
        return _simple_fallback_decomp(full_question)
    return queries


def slm_build_hints(full_question: str, tagged_snippets: List[Dict[str, str]]) -> List[str]:
    """Ask the local SLM to distil web snippets into short factual bullet hints.

    Args:
        full_question: The narrative followed by the question text.
        tagged_snippets: Dicts with at least a "snippet" key; their text is
            listed in the prompt as "[i] <snippet>".

    Returns:
        Up to 8 hint strings parsed from lines starting with "-", "*", "•" or
        "N." followed by whitespace. If none are found, the raw snippets are
        used instead (also truncated to 8).
    """
    snippet_lines = "".join(
        f"[{i}] {s['snippet']}\n" for i, s in enumerate(tagged_snippets, 1)
    )
    prompt = (
        "You are helping with a MuSR-style multi-step reasoning benchmark.\n"
        "Given the narrative and question, and web snippets tagged WHY/WHAT/HOW,\n"
        "extract 4-8 short, verifiable bullet points that help decide the answer.\n"
        "Do NOT answer the question, and do NOT mention any option letter.\n\n"
        f"Narrative + question:\n{full_question}\n\nWeb snippets:\n"
        + snippet_lines
        + "\nWrite 4-8 bullets. Start each line with '- ' only:\n"
    )
    text = slm_generate(prompt, max_new_tokens=256)
    # Parse only the text after the final instruction, so that bullet-like
    # lines in an echoed prompt (e.g. inside the narrative) are not harvested.
    # The cue is plain text, and without an echo rpartition leaves the text unchanged.
    completion = text.rpartition("Write 4-8 bullets.")[2]
    bullet = re.compile(r"^(?:-|\*|•|\d+\.)\s+")
    facts: List[str] = []
    for line in completion.splitlines():
        ln = line.strip()
        m = bullet.match(ln)
        if m:
            fact = ln[m.end():].strip()
            if fact:
                facts.append(fact)
    if not facts:
        facts = [s["snippet"] for s in tagged_snippets]
    return facts[:8]


def build_musr_instruction_prompt(
    narrative: str,
    question: str,
    options_dict: Dict[str, str],
    hints: List[str] = None,
) -> str:
    """Assemble the instruction-style MuSR prompt sent to Gemini.

    The prompt contains, in order:
    - the fixed task instructions;
    - the narrative and the question;
    - the options, one per line as "<label>. <text>" in sorted label order;
    - an optional helper-hints section, included only when ``hints`` is non-empty;
    - a closing cue asking for a single letter.
    """
    header = (
        "You are solving a multi-step reasoning multiple-choice question from the MuSR benchmark.\n"
        "You are given a narrative, a question about it, and several answer options.\n"
        "Your task is to choose the single best answer option.\n\n"
        "Instructions:\n"
        "  i) Carefully read the narrative and understand the key events and relationships.\n"
        "  ii) Read the question and determine exactly what is being asked.\n"
        "  iii) Consider the logical consequences of the narrative to evaluate each option.\n"
        "  iv) You may reason step-by-step INTERNALLY, but you MUST NOT show your reasoning.\n"
        "  v) Only output the final answer option as a single capital letter (for example A, B, C).\n\n"
    )
    pieces = [
        header,
        "Narrative:\n",
        f"{narrative}\n\n",
        f"Question: {question}\n\n",
        "Options:\n",
    ]
    pieces.extend(f"{label}. {options_dict[label]}\n" for label in sorted(options_dict))
    if hints:
        pieces.append("\nOptional factual hints from a smaller helper model (may be incomplete or noisy):\n")
        pieces.extend(f"- {hint}\n" for hint in hints)
    pieces.append("\nFinal answer (ONLY write a single letter corresponding to the best option, such as A or B):")
    return "".join(pieces)


def gemini_musr_answer(narrative, question, options_dict, hints,):
    """Ask Gemini to choose an option for one MuSR item and parse its letter.

    Args:
        narrative: The story text.
        question: The question about the narrative.
        options_dict: Mapping from single-letter label (e.g. "A") to option text.
        hints: Helper bullets forwarded to the prompt (None or empty to omit).

    Returns:
        ``(pred_label, raw_text)``. ``pred_label`` is the first stand-alone
        option letter in Gemini's reply. If none is found, including for an
        empty or blocked reply, it falls back to the alphabetically first label.
    """
    prompt = build_musr_instruction_prompt(narrative, question, options_dict, hints=hints)
    resp = call_gemini_with_cooldown(prompt)
    try:
        out_text = (resp.text or "").strip()
    except (AttributeError, ValueError):
        # google-generativeai's ``.text`` accessor raises ValueError when the
        # response has no valid Part (e.g. a safety-blocked candidate).
        # getattr(resp, "text", "") would only catch AttributeError, so such a
        # reply would abort the whole evaluation loop.
        out_text = ""
    labels = sorted(options_dict.keys())
    match = re.search(r"\b[" + "".join(labels) + r"]\b", out_text)
    if match:
        return match.group(0), out_text
    # No option letter found. A follow-up "any capital letter" search is not
    # needed: any such letter that is a valid label would have matched above.
    print(f"[Gemini] no option letter in output {out_text!r}; defaulting to {labels[0]}")
    return labels[0], out_text


def parse_choices(raw) -> List[str]:
    """Normalise a MuSR ``choices`` field into a list of strings.

    - Lists and tuples are converted element-wise with ``str``.
    - Strings are stripped and decoded first as JSON (a list is required),
      then as a Python literal (a list or tuple is required).
    - A string that decodes to neither becomes a one-element list with the
      stripped string.
    - Any other value becomes ``[str(raw)]``.
    """
    if isinstance(raw, str):
        text = raw.strip()
        decoders = ((json.loads, list), (ast.literal_eval, (list, tuple)))
        for decode, accepted in decoders:
            try:
                parsed = decode(text)
                if isinstance(parsed, accepted):
                    return [str(item) for item in parsed]
            except Exception:
                pass
        return [text]
    if isinstance(raw, (list, tuple)):
        return [str(item) for item in raw]
    return [str(raw)]


def prepare_musr_example(
    row,
    min_choices: int = 2,
    max_choices: int = 4,
) -> Tuple[Optional[str], Optional[str], Optional[Dict[str, str]], Optional[str]]:
    """Validate one MuSR row and convert it into prompt-ready pieces.

    Args:
        row: A mapping-like record (e.g. a pandas Series from ``df.iloc[i]``)
            that should contain ``narrative``, ``question``, ``choices`` and
            ``answer_index``.
        min_choices: Smallest accepted number of parsed choices (inclusive).
        max_choices: Largest accepted number of parsed choices (inclusive),
            capped at 26 because options are labelled with single capital
            letters. The defaults keep the original 2-4 window.

    Returns:
        ``(narrative, question, options_dict, gold_label)``, where
        ``options_dict`` maps "A", "B", ... to choice text and ``gold_label``
        is the letter at ``answer_index``. Returns ``(None, None, None, None)``
        if a field is missing, the number of choices is out of range, or the
        answer index is not an integer within range.
    """
    invalid = (None, None, None, None)
    required = ("narrative", "question", "choices", "answer_index")
    if any(key not in row for key in required):
        return invalid
    narrative = str(row["narrative"])
    question = str(row["question"])
    choices = parse_choices(row["choices"])
    k = len(choices)
    if k < min_choices or k > min(max_choices, 26):
        return invalid
    try:
        ans_idx = int(row["answer_index"])
    except (TypeError, ValueError, OverflowError):
        return invalid
    if not 0 <= ans_idx < k:
        return invalid
    labels = [chr(ord("A") + i) for i in range(k)]
    options_dict = dict(zip(labels, choices))
    gold = labels[ans_idx]
    return narrative, question, options_dict, gold


def predict_musr_for_example(row, verbose: bool = False):
    """Run the full pipeline on one MuSR row.

    Pipeline: SLM decomposition -> web search per sub-query -> SLM hints ->
    Gemini answer.

    Args:
        row: A dataset record accepted by ``prepare_musr_example``.
        verbose: If True, print the intermediate artefacts (sub-queries,
            hints, raw Gemini output) and the gold/predicted labels.

    Returns:
        ``(pred_label, gold, sub_queries, hints, gemini_raw)``. Rows rejected
        by ``prepare_musr_example`` yield ``(None, None, None, [], None)``.
    """
    narrative, question, options_dict, gold = prepare_musr_example(row)
    if options_dict is None:
        return None, None, None, [], None
    full_q = narrative + "\n\nQuestion: " + question
    sub_queries: Dict[str, str] = slm_decompose_queries(full_q)
    tagged = []
    for tag in ["WHY", "WHAT", "HOW"]:
        for snip, url in web_search_cached(sub_queries[tag], num_results=2):
            tagged.append({"snippet": f"[{tag}] {snip}", "url": url})
    hints = slm_build_hints(full_q, tagged)
    pred_label, gemini_raw = gemini_musr_answer(
        narrative,
        question,
        options_dict,
        hints=hints if hints else None,
    )
    if verbose:
        print("\n--- MuSR debug ---")
        print("Question:", question)
        print("Options:", json.dumps(options_dict, ensure_ascii=False))
        print("Sub-queries:", json.dumps(sub_queries, ensure_ascii=False))
        print("Hints:")
        for h in hints:
            print("  -", h)
        print("Gemini raw output:", repr(gemini_raw))
        print(f"gold={gold} | pred={pred_label} | correct={pred_label == gold}")
    return pred_label, gold, sub_queries, hints, gemini_raw


dataset = load_dataset("TAUR-Lab/MuSR")
ds = dataset[SPLIT_NAME]
df_all = ds.to_pandas()

print(f"MuSR split '{SPLIT_NAME}' loaded.")
print("Columns:", list(df_all.columns))

# One CSV record per evaluated example, plus parallel prediction/label lists.
rows_out: List[Dict] = []
preds: List[str] = []
golds: List[str] = []
correct_flags: List[int] = []

print("\n=== Running MuSR — our_method, instruction ===")

used = 0  # examples actually evaluated (counts towards NUM_EXAMPLES)
skipped = 0  # rows for which predict_musr_for_example returned no prediction
finished = False

# Each example runs two SLM generations, three web searches and a
# rate-limited Gemini call, so a long run can fail or be interrupted part-way.
# The finally clause always reports and saves the completed examples; any
# exception still propagates afterwards.
try:
    for i in tqdm(range(len(df_all)), desc="musr_our_method_instruction"):
        if used >= NUM_EXAMPLES:
            break
        r = df_all.iloc[i]
        pred_label, gold, sub_queries, hints, gemini_raw = predict_musr_for_example(
            r,
            verbose=(used < PRINT_FIRST_N_DEBUG),
        )
        if pred_label is None:
            skipped += 1
            continue
        used += 1
        preds.append(pred_label)
        golds.append(gold)
        correct = int(pred_label == gold)
        correct_flags.append(correct)
        rows_out.append({
            "idx_in_split": int(i),
            "variant": "our_method",
            "split": SPLIT_NAME,
            "prompt_mode": "instruction",
            "narrative": r["narrative"],
            "question": r["question"],
            "choices_raw": json.dumps(parse_choices(r["choices"]), ensure_ascii=False),
            "gold_label": gold,
            "pred_label": pred_label,
            "correct": correct,
            "gemini_output": gemini_raw,
            "sub_queries": json.dumps(sub_queries, ensure_ascii=False) if sub_queries else "",
            "hints": json.dumps(hints, ensure_ascii=False) if hints else "",
        })
        # The first PRINT_FIRST_N_DEBUG examples ran with verbose=True; the
        # rest get a one-line summary.
        if used > PRINT_FIRST_N_DEBUG:
            print(f"Example {used}: gold={gold} | pred={pred_label} | correct={bool(correct)}")
    finished = True
finally:
    if not finished:
        print(f"\n[warning] Run stopped early; reporting and saving the {len(preds)} completed examples.")
    acc = sum(correct_flags) / max(1, len(preds))
    print(f"\nPipeline accuracy (MuSR, our_method, instruction): {acc:.3f} on {len(preds)} examples.")
    print(f"Skipped examples (invalid choices or answer_index): {skipped}")

    df_out = pd.DataFrame(rows_out)
    out_path = f"musr_{SPLIT_NAME}_instruction_our_method.csv"
    df_out.to_csv(out_path, index=False)
    print("Saved detailed results to:", out_path)