# In [None]:
import os
import re
import json
import time
import requests
import pandas as pd
from tqdm.auto import tqdm

import google.generativeai as genai
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from kaggle_secrets import UserSecretsClient
from huggingface_hub import login

# --- Credentials (Kaggle secrets) -------------------------------------------
user_secrets = UserSecretsClient()
# Log in to the Hugging Face Hub with the HF_TOKEN secret (presumably needed
# so from_pretrained() below can download the model if it is gated).
login(token=user_secrets.get_secret("HF_TOKEN"))

SERPER_API_KEY = user_secrets.get_secret("SERPER_API_KEY")
GEMINI_API_KEY = user_secrets.get_secret("GEMINI_API_KEY")

# --- Experiment configuration -----------------------------------------------
# Evaluation CSV; columns are read positionally later
# (question, four options, answer key).
path = "/kaggle/input/arc-challenge/ai2_arc_100.csv"

# NOTE(review): "..." is a placeholder -- set a real Hugging Face model id
# before running, otherwise from_pretrained() below cannot load anything.
model_name = "..."
DEVICE_MAP = "auto"  # passed to from_pretrained(device_map=...)
llm_name = "gemini-2.5-flash-lite"  # Gemini model that picks the final answer
web_results = 5  # upper bound on snippets returned by one web_search() call
web_timeout = 10  # timeout (seconds) for the Serper HTTP request

total_eval_datapoints = 100  # max number of CSV rows to evaluate
SLEEP_BETWEEN_GEMINI_CALLS = 5  # cooldown (seconds) after each successful Gemini call
GEMINI_MAX_RETRIES = 3  # attempts per Gemini request before giving up

# --- Remote LLM (Gemini) ----------------------------------------------------
genai.configure(api_key=GEMINI_API_KEY)
genai_model = genai.GenerativeModel(llm_name)
print("Gemini model initialised.")

# --- Local small language model (SLM), loaded with 4-bit quantisation --------
print("Loading local SLM (this can take a bit)...")
bnb_config = BitsAndBytesConfig(load_in_4bit=True)
slm_tokenizer = AutoTokenizer.from_pretrained(model_name)
slm_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=DEVICE_MAP,
    quantization_config=bnb_config,
)
print("Local SLM loaded.")


def slm_generate(prompt: str, max_new_tokens: int = 64) -> str:
    """Run the local SLM on *prompt* and return only the newly generated text.

    Args:
        prompt: Full text prompt fed to ``slm_model``.
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        The decoded continuation (special tokens removed, whitespace stripped).
        The prompt itself is NOT included.
    """
    inputs = slm_tokenizer(prompt, return_tensors="pt").to(slm_model.device)
    out_ids = slm_model.generate(**inputs, max_new_tokens=max_new_tokens)
    # A causal LM's ``generate`` returns the prompt tokens followed by the
    # continuation. Drop the prompt so callers never parse their own
    # instructions (e.g. template lines like "WHY: <why_question>") as output.
    prompt_len = inputs["input_ids"].shape[-1]
    new_ids = out_ids[0][prompt_len:]
    text = slm_tokenizer.decode(new_ids, skip_special_tokens=True)
    return text.strip()


def web_search(query: str, num_results: int = 2):
    """Fetch up to *num_results* web snippets for *query* via the Serper API.

    Args:
        query: Free-text search query.
        num_results: Requested number of snippets; capped at ``web_results``.

    Returns:
        A list of ``{"snippet": str, "url": str}`` dicts, at most
        ``num_results`` long. When Serper is not configured, the request
        fails, or no hit carries a snippet, a single placeholder snippet
        (``"General information about: <query>"``) with an empty URL is used
        instead (still truncated to ``num_results``).
    """
    num_results = min(num_results, web_results)
    results = []

    if SERPER_API_KEY:
        url = "https://google.serper.dev/search"
        headers = {"X-API-KEY": SERPER_API_KEY, "Content-Type": "application/json"}
        payload = {"q": query}
        try:
            r = requests.post(url, headers=headers, json=payload, timeout=web_timeout)
            # Surface HTTP failures (bad key, exhausted quota, 5xx) in the log
            # instead of silently parsing an error body as "no results".
            r.raise_for_status()
            data = r.json()
            organic = data.get("organic") if isinstance(data, dict) else None
            # Filter before truncating so hits without a snippet do not
            # shrink the number of usable results.
            for item in organic or []:
                if not isinstance(item, dict):
                    continue
                snippet = (item.get("snippet") or "").replace("\n", " ")
                link = item.get("link") or ""
                if snippet:
                    results.append({"snippet": snippet, "url": link})
        except Exception as e:
            # Best-effort retrieval: any failure falls back to the placeholder.
            print("Serper error:", e)

    if not results:
        results.append({"snippet": f"General information about: {query}", "url": ""})

    return results[:num_results]


def _retry_delay_seconds(message: str, default: float = 60.0) -> float:
    """Return the retry delay (seconds) suggested by a Gemini error message.

    Recognises ``"retry in 12.3s"`` and ``"seconds: 12"`` forms and falls
    back to *default* when neither is present.
    """
    # Require at least one digit so a stray "." can never reach float().
    m = re.search(r"retry in ([0-9]+(?:\.[0-9]+)?)s", message, flags=re.I)
    if m:
        return float(m.group(1))
    m = re.search(r"seconds:\s*([0-9]+)", message, flags=re.I)
    if m:
        return float(m.group(1))
    return default


def call_gemini_with_cooldown(prompt: str, max_retries: int = GEMINI_MAX_RETRIES):
    """Call ``genai_model.generate_content`` with retries and a post-call cooldown.

    After a successful call, sleeps ``SLEEP_BETWEEN_GEMINI_CALLS`` seconds to
    stay under rate limits. After a failed attempt (other than the last),
    sleeps for the delay suggested in the error message (default 60 s).

    Args:
        prompt: Prompt text sent to Gemini.
        max_retries: Total number of attempts.

    Returns:
        The response object returned by ``generate_content``.

    Raises:
        RuntimeError: If ``genai_model`` is None, or if every attempt failed
            (chained to the last underlying exception).
    """
    if genai_model is None:
        raise RuntimeError("Gemini model not initialised.")

    last_exc = None
    for attempt in range(1, max_retries + 1):
        try:
            resp = genai_model.generate_content(prompt)
        except Exception as e:
            last_exc = e
            if attempt == max_retries:
                # No point sleeping before giving up.
                print(f"[Gemini rate/HTTP error] Attempt {attempt}/{max_retries} -> giving up.")
                break
            wait_s = _retry_delay_seconds(str(e))
            print(f"[Gemini rate/HTTP error] Attempt {attempt}/{max_retries} -> sleeping {wait_s:.1f}s...")
            time.sleep(wait_s)
            continue
        time.sleep(SLEEP_BETWEEN_GEMINI_CALLS)
        return resp

    raise RuntimeError(f"Gemini generate_content failed after {max_retries} retries.") from last_exc


def slm_decompose_queries(question: str):
    """Ask the local SLM to split *question* into WHY / WHAT / HOW sub-queries.

    Args:
        question: The original multiple-choice question text.

    Returns:
        Dict with keys ``"WHY"``, ``"WHAT"`` and ``"HOW"`` (in that order).
        Any sub-query the SLM output does not provide falls back to
        *question* itself.
    """
    prompt = (
        "You are a query decomposition assistant for science exam questions.\n"
        "Given a multiple-choice question, rewrite it into exactly three concise\n"
        "sub-questions:\n"
        "1) a 'WHY' question\n"
        "2) a 'WHAT' question\n"
        "3) a 'HOW' question\n\n"
        "Return them in the following format exactly:\n"
        "WHY: <why_question>\n"
        "WHAT: <what_question>\n"
        "HOW: <how_question>\n\n"
        f"Original question: {question}\n\n"
        "Decomposed queries:"
    )

    text = slm_generate(prompt, max_new_tokens=128)

    # If the generator echoes the prompt, the format template above
    # ("WHY: <why_question>", ...) would be matched first. Parse only what
    # follows the prompt's final marker when it is present.
    _, sep, tail = text.partition("Decomposed queries:")
    generated = tail if sep else text

    parsed = {"WHY": question, "WHAT": question, "HOW": question}
    for tag in parsed:
        # \b avoids matches inside words ("somewhat:", "show:"); [ \t]* plus
        # \S keeps an empty "WHY:" line from swallowing the next line.
        m = re.search(rf"\b{tag}:[ \t]*(\S.*)", generated, flags=re.IGNORECASE)
        if m:
            parsed[tag] = m.group(1).strip()

    return parsed


def slm_build_hints(question: str, tagged_snippets):
    """Have the local SLM condense retrieved snippets into short factual hints.

    Args:
        question: The original multiple-choice question text.
        tagged_snippets: Sequence of dicts with at least a ``"snippet"`` key.

    Returns:
        Facts parsed from SLM output lines that start with a Roman-numeral
        marker (``i)``, ``ii)``, ...) or a ``-`` bullet. If no such line
        yields text, the raw snippet texts are returned instead.
    """
    header = (
        "You are given a science question and some web snippets that were retrieved\n"
        "using different sub-queries (WHY / WHAT / HOW).\n"
        "Your job is to extract 4-8 short factual points that are helpful for answering the question.\n"
        "Write them as an enumerated list using i), ii), iii), iv), ...\n"
        "Each point must be on its own line starting with i), ii), etc.\n"
        "IMPORTANT: Do NOT answer the question yourself. Only write the factual points.\n\n"
        f"Question: {question}\n\n"
        "Web snippets (with tags):\n"
    )
    numbered_snippets = "".join(
        f"[{idx}] {entry['snippet']}\n" for idx, entry in enumerate(tagged_snippets, 1)
    )
    footer = (
        "\nNow write 4-8 factual points, each on a new line, starting with i), ii), iii), etc.\n"
        "Do not include anything else.\n"
    )

    raw_output = slm_generate(header + numbered_snippets + footer, max_new_tokens=192)

    roman_item = re.compile(r"(?i)^([ivx]+)\)\s*(.+)")

    def to_fact(line):
        """Return the fact carried by a stripped *line*, or None if it is not a list item."""
        hit = roman_item.match(line)
        if hit is not None:
            return hit.group(2).strip()
        if line.startswith("-"):
            return line.lstrip("-").strip()
        return None

    parsed = [
        fact
        for fact in (to_fact(ln.strip()) for ln in raw_output.splitlines())
        if fact
    ]
    return parsed if parsed else [entry["snippet"] for entry in tagged_snippets]


def _roman_numeral(n: int) -> str:
    """Return the lowercase Roman numeral for a positive integer *n*."""
    parts = []
    for value, symbol in (
        (1000, "m"), (900, "cm"), (500, "d"), (400, "cd"),
        (100, "c"), (90, "xc"), (50, "l"), (40, "xl"),
        (10, "x"), (9, "ix"), (5, "v"), (4, "iv"), (1, "i"),
    ):
        count, n = divmod(n, value)
        parts.append(symbol * count)
    return "".join(parts)


def build_instruction_prompt(question, options_dict, facts):
    """Build the Gemini prompt for answering one multiple-choice question.

    Args:
        question: Question text.
        options_dict: Mapping with keys ``"A"``, ``"B"``, ``"C"``, ``"D"``
            holding the option texts.
        facts: Sequence of hint strings (may be empty).

    Returns:
        The prompt string. Hints are listed as i), ii), iii), ...; with no
        hints, a single "(none provided ...)" entry is written instead. The
        prompt ends with an instruction to reply with one letter.

    Raises:
        KeyError: If *options_dict* lacks one of the keys A-D.
    """
    text = (
        "You are a highly capable science exam solver.\n"
        "Your goal is to choose the single best answer to the multiple-choice question.\n\n"
        "You are given:\n"
        "  i) the question\n"
        "  ii) four answer options (A, B, C, D)\n"
        "  iii) OPTIONAL factual hints produced by a smaller helper model from web search\n\n"
        "Use the following strategy:\n"
        "  1. Carefully read the question and options.\n"
        "  2. Use your own scientific knowledge and reasoning to understand what is being asked.\n"
        "  3. Look at the hints only as supplemental context: they may be helpful, but they can also\n"
        "     be incomplete, noisy, or partially irrelevant.\n"
        "  4. If the hints are consistent with the question and your own reasoning, you may use them\n"
        "     to support or refine your choice. If the hints conflict with the question or seem\n"
        "     unhelpful, ignore them and answer independently.\n"
        "  5. Silently eliminate clearly wrong options and select the single best remaining option.\n\n"
        "Important formatting instruction:\n"
        "  i) Think through the problem internally, but DO NOT show your reasoning.\n"
        "  ii) Return ONLY a single capital letter: A, B, C, or D, with no explanation.\n\n"
        f"Question: {question}\n\n"
        "Options:\n"
        f"A. {options_dict['A']}\n"
        f"B. {options_dict['B']}\n"
        f"C. {options_dict['C']}\n"
        f"D. {options_dict['D']}\n\n"
    )

    if facts:
        text += "Optional factual hints (from a smaller helper model):\n"
        # Number hints i), ii), iii), ... rather than repeating "i)" for each.
        for number, fact in enumerate(facts, start=1):
            text += f"{_roman_numeral(number)}) {fact}\n"
    else:
        text += (
            "Optional factual hints:\n"
            "i) (none provided; rely on your own knowledge)\n"
        )

    text += "\nAnswer (ONLY one letter A, B, C, or D):"
    return text


def gemini_mcq_answer(question, options_dict, facts):
    """Ask Gemini for the answer letter to one multiple-choice question.

    Args:
        question: Question text.
        options_dict: Mapping ``"A"``..``"D"`` -> option text.
        facts: Hint strings passed through to the prompt.

    Returns:
        One of ``"A"``, ``"B"``, ``"C"``, ``"D"``. Falls back to ``"A"`` when
        the model is unavailable, the call fails, or no letter can be parsed.
    """
    if genai_model is None:
        return "A"

    text = build_instruction_prompt(question, options_dict, facts)

    try:
        resp = call_gemini_with_cooldown(text)
        out_text = resp.text.strip()
    except Exception as e:
        print("Gemini error while answering MCQ after retries:", e)
        out_text = ""

    # Prefer a standalone letter so replies like "Answer: C" yield "C", not
    # the "A" of "Answer". Fall back to any capital A-D otherwise.
    m = re.search(r"\b[ABCD]\b", out_text) or re.search(r"[ABCD]", out_text)
    if m:
        return m.group(0)
    return "A"


def _print_row_trace(question, options, sub_queries, tagged_snippets, facts, final_pred):
    """Print the verbose per-question trace emitted by ``predict_mcq_for_row``."""
    print("\n" + "=" * 80)
    print("Variant: our_method | Prompt: instruction")
    print(f"Question: {question}\n")
    print("Options:")
    for label in "ABCD":
        print(f"{label}. {options[label]}")
    print("\nDecomposed sub-queries:")
    for tag in ("WHY", "WHAT", "HOW"):
        print(f"  {tag}: {sub_queries[tag]}")
    print("\nTagged web snippets:")
    for position, item in enumerate(tagged_snippets, start=1):
        print(f"  [{position}] {item['snippet'][:160]}")
    print("\nSLM factual points (i), ii), ... style):")
    for fact in facts:
        print(f"  - {fact}")
    print("\nLLM (Gemini) final answer:", final_pred)


def predict_mcq_for_row(row, q_col, opt_cols, verbose: bool = False):
    """Run the full decompose -> search -> hint -> Gemini pipeline on one row.

    Args:
        row: A DataFrame row (anything indexable by column name).
        q_col: Column holding the question text.
        opt_cols: The four option columns, in A, B, C, D order.
        verbose: When True, print a detailed trace of every pipeline stage.

    Returns:
        Tuple ``(final_pred, ans_llm, sub_queries, facts)``. ``final_pred``
        currently equals the Gemini answer ``ans_llm``.
    """
    question = str(row[q_col])
    options = {label: str(row[opt_cols[i]]) for i, label in enumerate("ABCD")}

    sub_queries = slm_decompose_queries(question)
    # Two snippets per sub-query, each prefixed with the sub-query's tag.
    tagged_snippets = [
        {"snippet": f"[{tag}] {hit['snippet']}", "url": hit["url"]}
        for tag in ("WHY", "WHAT", "HOW")
        for hit in web_search(sub_queries[tag], num_results=2)
    ]

    facts = slm_build_hints(question, tagged_snippets)
    ans_llm = gemini_mcq_answer(question, options, facts)
    final_pred = ans_llm

    if verbose:
        _print_row_trace(question, options, sub_queries, tagged_snippets, facts, final_pred)

    return final_pred, ans_llm, sub_queries, facts


# Evaluate the pipeline on the first `total_eval_datapoints` rows of the CSV.
# Columns are taken positionally: question, four options (A-D), answer key.
df = pd.read_csv(path)
cols = list(df.columns)
q_col = cols[0]
opt_cols = cols[1:5]
ans_col = cols[5]

print("Columns:", cols)
print("Question column:", q_col)
print("Option columns:", opt_cols)
print("Answer column:", ans_col)

total = min(total_eval_datapoints, len(df))

# Number of leading rows that get the full verbose trace. Honour a value set
# in an earlier cell; otherwise default to 3 so the loop below does not hit
# a NameError.
PRINT_FIRST_N_DEBUG = globals().get("PRINT_FIRST_N_DEBUG", 3)

# ARC answer keys are numeric ("1"-"4") for some items -- assumed possible in
# this CSV as well. Map them onto the letters the model answers with; letter
# keys pass through unchanged.
_GOLD_NUMERIC_TO_LETTER = {"1": "A", "2": "B", "3": "C", "4": "D"}

predictions = []
gold_labels = []
correct_flags = []
llm_answers = []
all_sub_queries = []
all_facts = []

for idx in tqdm(range(total), desc="ARC-100 (our_method, instruction)"):
    row = df.iloc[idx]
    gold = str(row[ans_col]).strip().upper()
    gold = _GOLD_NUMERIC_TO_LETTER.get(gold, gold)
    gold_labels.append(gold)

    verbose = idx < PRINT_FIRST_N_DEBUG

    pred, ans_llm, sub_queries, facts = predict_mcq_for_row(
        row,
        q_col,
        opt_cols,
        verbose=verbose,
    )

    predictions.append(pred)
    llm_answers.append(ans_llm)
    all_sub_queries.append(sub_queries)
    all_facts.append(facts)

    is_corr = int(pred == gold)
    correct_flags.append(is_corr)

    if not verbose:
        print(f"Q{idx + 1}: gold={gold}, pred={pred}, LLM={ans_llm}")

correct = sum(correct_flags)
# Guard against an empty CSV / total_eval_datapoints == 0.
accuracy = correct / total if total else 0.0
print(f"\nFinished {total} questions.")
print(f"Pipeline accuracy (our_method, instruction): {accuracy:.3f}")

eval_df = df.iloc[:total].copy()
eval_df["gold"] = gold_labels
eval_df["pred"] = predictions
eval_df["correct"] = correct_flags
eval_df["llm_answer"] = llm_answers
eval_df["sub_queries"] = [json.dumps(q) for q in all_sub_queries]
eval_df["facts"] = [json.dumps(f) for f in all_facts]

save_path = "arc_instruction.csv"
eval_df.to_csv(save_path, index=False)
print(f"Saved detailed results to: {save_path}")