In [None]:
import os
import re
import json
import time
import requests
from typing import List, Dict, Tuple

import pandas as pd
from tqdm.auto import tqdm
from functools import lru_cache

import google.generativeai as genai
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers import logging as hf_logging
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient
from datasets import load_dataset

# Read the three API credentials through kaggle_secrets.UserSecretsClient.
user_secrets = UserSecretsClient()
HF_TOKEN = user_secrets.get_secret("HF_TOKEN")
SERPER_API_KEY = user_secrets.get_secret("SERPER_API_KEY")
GEMINI_API_KEY = user_secrets.get_secret("GEMINI_API_KEY")

# Authenticate with the Hugging Face Hub before the tokenizer/model are loaded below.
login(token=HF_TOKEN)

model_name = "meta-llama/Llama-3.2-3B-Instruct"  # local helper model (the "SLM")
DEVICE_MAP = "auto"  # forwarded as device_map= when the SLM is loaded
llm_name = "gemini-2.5-flash-lite"  # Gemini model that produces the final answer

DATASET_NAME = "Rowan/hellaswag"
SPLIT_NAME = "validation"
NUM_EXAMPLES = 100  # only the first NUM_EXAMPLES rows of the split are evaluated
PRINT_FIRST_N_DEBUG = 3  # rows whose index is below this get the verbose dump
BASE_SLEEP_BETWEEN_CALLS = 5.0  # seconds slept after each successful Gemini call
WEB_TIMEOUT = 10  # timeout passed to requests.post in web_search
web_results = 5  # upper bound applied to num_results inside web_search

PROMPT_MODE = "instruction"  # only echoed into log lines and the output rows
VARIANTS = ["our_method"]  # variants iterated by the evaluation loop

# Fail early if the Gemini key is missing, then build the shared Gemini client.
assert GEMINI_API_KEY, "Set GEMINI_API_KEY in Kaggle secrets."
genai.configure(api_key=GEMINI_API_KEY)
gemini_model = genai.GenerativeModel(llm_name)
print("Gemini model initialised.")

print("Loading local SLM (this can take a bit)...")
# Request 4-bit loading of the model weights through the bitsandbytes config.
bnb_config = BitsAndBytesConfig(load_in_4bit=True)
slm_tokenizer = AutoTokenizer.from_pretrained(model_name)
slm_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=DEVICE_MAP,
    quantization_config=bnb_config,
)
# From here on, restrict transformers logging to error-level messages.
hf_logging.set_verbosity_error()
# Reuse the end-of-sequence token as the padding token, on both the tokenizer
# and the model config.
slm_tokenizer.pad_token = slm_tokenizer.eos_token
slm_model.config.pad_token_id = slm_tokenizer.pad_token_id
print("Local SLM loaded.")


def call_gemini_with_cooldown(prompt: str, max_retries: int = 3):
    """Call Gemini with pacing between calls and wait-then-retry on failure.

    Args:
        prompt: Text passed to ``gemini_model.generate_content``.
        max_retries: Maximum number of attempts before giving up.

    Returns:
        The response object returned by ``generate_content``.

    Raises:
        RuntimeError: If every attempt raised; chained to the last exception.
    """
    # Patterns used to pull a server-suggested delay out of the error message.
    delay_patterns = (r"retry in ([0-9.]+)s", r"seconds:\s*([0-9]+)")
    last_exc = None
    for attempt in range(1, max_retries + 1):
        try:
            resp = gemini_model.generate_content(prompt)
        except Exception as e:
            last_exc = e
            if attempt == max_retries:
                # No further attempt follows, so waiting here would only
                # delay the error by up to a minute.
                break
            msg = str(e)
            wait_s = 60.0  # default when the message carries no usable hint
            for pattern in delay_patterns:
                m = re.search(pattern, msg, flags=re.I)
                if m:
                    try:
                        wait_s = float(m.group(1))
                    except ValueError:
                        # e.g. a capture of "." -- keep the default delay
                        pass
                    break
            print(f"[Rate limit] Attempt {attempt}/{max_retries} → sleeping {wait_s:.1f}s...")
            time.sleep(wait_s)
        else:
            # Fixed pause after every successful call to stay under the quota.
            time.sleep(BASE_SLEEP_BETWEEN_CALLS)
            return resp
    raise RuntimeError("Gemini generate_content failed after retries.") from last_exc


def slm_generate(prompt: str, max_new_tokens: int = 200) -> str:
    """Generate a completion for ``prompt`` with the local SLM.

    Args:
        prompt: Text fed to the model.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        Only the newly generated text, stripped of surrounding whitespace.
    """
    inputs = slm_tokenizer(prompt, return_tensors="pt").to(slm_model.device)
    out_ids = slm_model.generate(**inputs, max_new_tokens=max_new_tokens)
    # A causal LM returns prompt tokens followed by the continuation. Drop the
    # prompt part so callers that parse the output (e.g. with "WHY:" regexes)
    # never match text that came from their own instructions.
    prompt_len = inputs["input_ids"].shape[1]
    completion_ids = out_ids[0][prompt_len:]
    text = slm_tokenizer.decode(completion_ids, skip_special_tokens=True)
    return text.strip()


def web_search(query: str, num_results: int = 2):
    """Fetch up to ``num_results`` organic search snippets from Serper.

    Best-effort: any failure is printed and the function falls back to a
    single placeholder entry instead of raising.

    Args:
        query: Search query text.
        num_results: Desired number of results; clamped to ``[0, web_results]``.

    Returns:
        A list of ``{"snippet": str, "url": str}`` dicts. When nothing usable
        was retrieved it holds one placeholder snippet built from the query
        (or is empty when ``num_results`` is 0 or less).
    """
    # Clamp at zero as well: a negative value would otherwise slice from the end.
    num_results = max(0, min(num_results, web_results))
    results = []
    if SERPER_API_KEY:
        url = "https://google.serper.dev/search"
        headers = {"X-API-KEY": SERPER_API_KEY, "Content-Type": "application/json"}
        payload = {"q": query}
        try:
            r = requests.post(url, headers=headers, json=payload, timeout=WEB_TIMEOUT)
            # Surface HTTP errors (bad key, quota, 5xx) in the log instead of
            # silently treating the error body as "no results".
            r.raise_for_status()
            data = r.json()
            for item in (data.get("organic") or [])[:num_results]:
                snippet = (item.get("snippet") or "").replace("\n", " ")
                link = item.get("link") or ""
                if snippet:
                    results.append({"snippet": snippet, "url": link})
        except Exception as e:
            # Deliberately broad: search is optional, the pipeline continues.
            print("Serper error:", e)
    if not results:
        results.append({"snippet": f"General information about: {query}", "url": ""})
    return results[:num_results]


@lru_cache(maxsize=4096)
def web_search_cached(q: str, num_results: int = 2):
    """Memoised wrapper around ``web_search``.

    The result is converted to a tuple of ``(snippet, url)`` pairs, keyed in
    the cache by the call arguments.
    """
    pairs = []
    for hit in web_search(q, num_results=num_results):
        pairs.append((hit["snippet"], hit["url"]))
    return tuple(pairs)


def _looks_bad(q: str) -> bool:
    q = (q or "").strip().lower()
    if not q:
        return True
    if q in {"<why>", "<what>", "<how>"}:
        return True
    if len(re.sub(r"[^a-z]", "", q)) < 5:
        return True
    return False


def _simple_fallback_decomp(full_q: str) -> Dict[str, str]:
    base = re.sub(r"\s*\?$", "", full_q).strip()
    if not base:
        base = full_q.strip()
    return {
        "WHY": f"Why is the following context challenging for a commonsense completion task: {base}?",
        "WHAT": f"What key real-world facts or commonsense knowledge are needed to complete: {base}?",
        "HOW": f"How could everyday experience and physics help decide the most plausible ending for: {base}?",
    }


def slm_decompose_queries(full_question: str) -> Dict[str, str]:
    """Ask the local SLM to split a question into WHY/WHAT/HOW sub-queries.

    Args:
        full_question: Context, candidate endings and task question as text.

    Returns:
        A dict with keys ``"WHY"``, ``"WHAT"`` and ``"HOW"``. If any parsed
        sub-query fails ``_looks_bad``, the template-based result of
        ``_simple_fallback_decomp`` is returned instead.
    """
    prompt = (
        "You are a query decomposition assistant for HellaSwag-style commonsense "
        "narrative completion tasks.\n"
        "Given a context and candidate endings, rewrite it into exactly three concise "
        "sub-queries:\n"
        "WHY: what high-level commonsense reasoning is needed\n"
        "WHAT: key entities, actions, or world knowledge required\n"
        "HOW: steps or logical process to compare plausibility of endings\n\n"
        "Return EXACTLY:\nWHY: <text>\nWHAT: <text>\nHOW: <text>\n\n"
        f"Input (context + instructions): {full_question}\n\nDecomposed queries:"
    )
    text = slm_generate(prompt, max_new_tokens=192)
    # If the generated text still begins with the prompt, parse only what
    # follows the final prompt marker. Otherwise the instruction lines above
    # ("WHY: what high-level ...") would be picked up as the model's answer.
    if text.startswith("You are a query decomposition assistant"):
        text = text.partition("Decomposed queries:")[2]
    parsed: Dict[str, str] = {}
    for tag in ("WHY", "WHAT", "HOW"):
        m = re.search(rf"{tag}:\s*(.+)", text, flags=re.I)
        parsed[tag] = m.group(1).strip() if m else ""
    if any(_looks_bad(sub_query) for sub_query in parsed.values()):
        return _simple_fallback_decomp(full_question)
    return parsed


def slm_build_hints(full_question: str, tagged_snippets: List[Dict[str, str]]) -> List[str]:
    """Condense web snippets into at most eight short hint strings via the SLM.

    Args:
        full_question: Context, candidate endings and task question as text.
        tagged_snippets: Dicts whose ``"snippet"`` value is listed in the prompt.

    Returns:
        Up to eight bullet texts parsed from the SLM output; when no bullet
        line is found, the raw snippet texts (also capped at eight).
    """
    header = (
        "You are helping with a HellaSwag-style commonsense reasoning benchmark.\n"
        "Given the context and instructions, and web snippets tagged WHY/WHAT/HOW,\n"
        "extract 4–8 short, verifiable bullet points that capture relevant real-world\n"
        "knowledge (for example typical sequences of actions, physical constraints, social norms).\n"
        "Do NOT answer the question and do NOT mention any option letter.\n\n"
        f"Context and instructions:\n{full_question}\n\nWeb snippets:\n"
    )
    snippet_block = "".join(
        f"[{idx}] {entry['snippet']}\n" for idx, entry in enumerate(tagged_snippets, 1)
    )
    prompt = header + snippet_block + "\nWrite 4–8 bullets. Start each line with '- ' only:\n"
    generated = slm_generate(prompt, max_new_tokens=256)

    facts: List[str] = []
    for raw_line in generated.splitlines():
        candidate = raw_line.strip()
        # Rewrite "*", "•" and "1." style markers to the canonical "- " prefix.
        if re.match(r"^(-|\*|•|\d+\.)\s+", candidate):
            candidate = re.sub(r"^(\*|•|\d+\.)\s+", "- ", candidate)
        if not candidate.startswith("- "):
            continue
        fact = candidate[2:].strip()
        if fact:
            facts.append(fact)

    if facts:
        return facts[:8]
    # Nothing parseable came back: fall back to the snippets themselves.
    return [entry["snippet"] for entry in tagged_snippets][:8]


def build_hellaswag_instruction_prompt(
    context: str,
    options_dict: Dict[str, str],
    hints: List[str] = None,
) -> str:
    """Assemble the multiple-choice prompt sent to the answering model.

    Args:
        context: The narrative context to be continued.
        options_dict: Candidate endings keyed by ``"A"``..``"D"``.
        hints: Optional hint strings; an extra section is added when non-empty.

    Returns:
        The full prompt text, ending with the final-answer cue.

    Raises:
        KeyError: If ``options_dict`` lacks one of the four option labels.
    """
    header = (
        "You are solving a commonsense narrative completion multiple-choice question "
        "from the HellaSwag benchmark.\n"
        "You are given a short context and several possible endings.\n"
        "Your task is to choose the single most plausible, coherent ending that continues "
        "the context in a natural way.\n\n"
        "Instructions:\n"
        "  • Carefully read the context and understand what is happening.\n"
        "  • Read each candidate ending and check whether it is grammatically correct,\n"
        "    logically consistent, and likely given the context.\n"
        "  • Prefer endings that match everyday physical and social commonsense.\n"
        "  • You may reason step-by-step INTERNALLY, but you MUST NOT show your reasoning.\n"
        "  • Only output the final answer option as a single capital letter: A, B, C, or D.\n\n"
        f"Context:\n{context}\n\n"
        "Candidate endings:\n"
    )
    option_lines = "".join(f"{label}. {options_dict[label]}\n" for label in "ABCD")
    hint_section = ""
    if hints:
        hint_section = (
            "\nOptional factual or world-knowledge hints from a smaller helper model (may be incomplete or noisy):\n"
            + "".join(f"- {hint}\n" for hint in hints)
        )
    footer = "\nFinal answer (ONLY one letter A, B, C, or D):"
    return header + option_lines + hint_section + footer


def gemini_hellaswag_answer(
    context: str,
    options_dict: Dict[str, str],
    hints: List[str] = None,
) -> Tuple[str, str]:
    """Ask Gemini for the best ending and parse the chosen option letter.

    Args:
        context: The narrative context to be continued.
        options_dict: Candidate endings keyed by ``"A"``..``"D"``.
        hints: Optional hint strings forwarded to the prompt builder.

    Returns:
        ``(label, raw_text)`` where ``label`` is the first standalone capital
        letter A-D found in the reply, or ``"A"`` when none is found, and
        ``raw_text`` is the stripped reply (``""`` if no text was available).
    """
    prompt = build_hellaswag_instruction_prompt(context, options_dict, hints=hints)
    resp = call_gemini_with_cooldown(prompt)
    try:
        raw_text = getattr(resp, "text", "")
    except ValueError:
        # The ``text`` accessor raises ValueError when the response carries no
        # valid part (e.g. blocked output); treat that as an empty reply so
        # one such example does not abort the whole run.
        raw_text = ""
    out_text = (raw_text or "").strip()
    m = re.search(r"\b[A-D]\b", out_text)
    if m:
        return m.group(0), out_text
    # NOTE: unparseable replies default to "A"; ``out_text`` is returned so
    # such cases can still be identified afterwards.
    return "A", out_text


def prepare_hellaswag_example(row) -> Tuple[str, Dict[str, str], str]:
    """Turn one dataset row into ``(context, options, gold_label)``.

    Args:
        row: Mapping-like object providing ``"ctx"``, ``"endings"`` and
            ``"label"``.

    Returns:
        The context as a string, a dict mapping ``"A"``..``"D"`` to the first
        four endings (padded with empty strings when fewer are present), and
        the gold letter selected by the integer value of ``"label"``.
    """
    context = str(row["ctx"])
    # Pad with blanks, then cut to exactly four candidates.
    four_endings = (list(row["endings"]) + [""] * 4)[:4]
    options_dict = {label: str(ending) for label, ending in zip("ABCD", four_endings)}
    gold = "ABCD"[int(row["label"])]
    return context, options_dict, gold


def predict_hellaswag_for_row(
    row,
    variant: str = "our_method",
    verbose: bool = False,
):
    """Run the full prediction pipeline for one dataset row.

    For the ``"our_method"`` variant the question is decomposed by the SLM,
    each sub-query is searched on the web, and the snippets are condensed into
    hints that accompany the Gemini prompt. Any other variant queries Gemini
    without hints.

    Args:
        row: Dataset row accepted by ``prepare_hellaswag_example``.
        variant: Pipeline variant name.
        verbose: When true, print a detailed dump of the example.

    Returns:
        ``(pred_label, gold, sub_queries, hints, gemini_raw)``.
    """
    context, options_dict, gold = prepare_hellaswag_example(row)
    option_labels = ("A", "B", "C", "D")
    uses_retrieval = variant == "our_method"

    endings_block = "".join(f"{lab}. {options_dict[lab]}\n" for lab in option_labels)
    full_q = (
        f"Context: {context}\n\nCandidate endings:\n{endings_block}\n"
        "Question: Which ending is the most plausible, coherent continuation?"
    )

    sub_queries: Dict[str, str] = {}
    hints: List[str] = []
    if uses_retrieval:
        sub_queries = slm_decompose_queries(full_q)
        tagged = []
        for tag in ("WHY", "WHAT", "HOW"):
            hits = web_search_cached(sub_queries[tag], num_results=2)
            # Prefix each snippet with the sub-query type it was retrieved for.
            tagged.extend({"snippet": f"[{tag}] {snip}", "url": url} for snip, url in hits)
        hints = slm_build_hints(full_q, tagged)

    pred_label, gemini_raw = gemini_hellaswag_answer(
        context,
        options_dict,
        hints=hints or None,
    )

    if not verbose:
        return pred_label, gold, sub_queries, hints, gemini_raw

    print(f"\n{'=' * 70}")
    print(f"Variant   : {variant}")
    print(f"Context   :\n{context[:600]}...\n")
    print("Candidate endings:")
    for lab in option_labels:
        print(f"  {lab}. {options_dict[lab]}")
    print(f"\nGold label: {gold}")
    print(f"Pred label: {pred_label}")
    if uses_retrieval:
        print("\nSub-queries:")
        for name, query in sub_queries.items():
            print(f"  {name}: {query}")
        print("\nSample hints:")
        for hint in hints[:5]:
            print(f"  - {hint}")
    return pred_label, gold, sub_queries, hints, gemini_raw


# Load the benchmark, pick the configured split and convert it to a DataFrame.
dataset = load_dataset(DATASET_NAME)
ds_split = dataset[SPLIT_NAME]
df_all = ds_split.to_pandas()
# Keep only the first NUM_EXAMPLES rows; the index is reset so it runs from 0.
df_split = df_all.iloc[:NUM_EXAMPLES].reset_index(drop=True)

print(f"HellaSwag split '{SPLIT_NAME}' loaded — using first {len(df_split)} examples.")
print("Columns:", list(df_split.columns))

# Accumulators shared by all variants: one entry is appended per processed row.
rows_out: List[Dict] = []
preds: List[str] = []
golds: List[str] = []
correct_flags: List[int] = []

# Evaluate every configured variant on every selected row.
for variant in VARIANTS:
    print(f"\n=== Running HellaSwag — variant={variant}, prompt_mode={PROMPT_MODE} ===")
    for i in tqdm(range(len(df_split)), desc=variant):
        r = df_split.iloc[i]
        # Only the first PRINT_FIRST_N_DEBUG rows get the verbose dump.
        pred_label, gold, sub_queries, hints, gemini_raw = predict_hellaswag_for_row(
            r,
            variant=variant,
            verbose=(i < PRINT_FIRST_N_DEBUG),
        )
        preds.append(pred_label)
        golds.append(gold)
        correct = int(pred_label == gold)
        correct_flags.append(correct)
        # Same pad/truncate-to-four rule as prepare_hellaswag_example, repeated
        # here so the raw endings can be written to the output row.
        endings_list = list(r["endings"])
        if len(endings_list) != 4:
            endings_list = (endings_list + [""] * 4)[:4]
        rows_out.append({
            "idx_in_split": int(i),
            "variant": variant,
            "split": SPLIT_NAME,
            "prompt_mode": PROMPT_MODE,
            "context": r["ctx"],
            "ending0": str(endings_list[0]),
            "ending1": str(endings_list[1]),
            "ending2": str(endings_list[2]),
            "ending3": str(endings_list[3]),
            "gold_label": gold,
            "pred_label": pred_label,
            "correct": correct,
            "gemini_output": gemini_raw,
            # Sub-queries and hints are stored as JSON text, or "" when empty.
            "sub_queries": json.dumps(sub_queries, ensure_ascii=False) if sub_queries else "",
            "hints": json.dumps(hints, ensure_ascii=False) if hints else "",
        })
        # Rows past the verbose window get a one-line status instead.
        if i >= PRINT_FIRST_N_DEBUG:
            print(f"Example {i+1}: gold={gold} | pred={pred_label} | correct={bool(correct)}")

# Accuracy over every processed row, pooled across all variants. The label
# therefore lists every configured variant, not just the first one.
acc = sum(correct_flags) / max(1, len(preds))
print(f"\nVariant={', '.join(VARIANTS)}, mode={PROMPT_MODE} — accuracy: {acc:.3f} on {len(preds)} examples.")

# Persist the per-example rows. The file name is derived from the active
# configuration so a different mode or variant set does not reuse a
# misleading hard-coded name.
df_out = pd.DataFrame(rows_out)
out_path = f"hellaswag_{SPLIT_NAME}_{PROMPT_MODE}_{'_'.join(VARIANTS)}.csv"
df_out.to_csv(out_path, index=False)
print("\nSaved:", out_path)

# Per-variant accuracy, recomputed from the stored rows.
print("\nSummary (accuracy by variant):")
for variant in VARIANTS:
    variant_preds = [row["pred_label"] for row in rows_out if row["variant"] == variant]
    variant_golds = [row["gold_label"] for row in rows_out if row["variant"] == variant]
    if not variant_preds:
        continue
    acc = sum(p == g for p, g in zip(variant_preds, variant_golds)) / len(variant_preds)
    print(f"  {variant:10s} — accuracy {acc:.3f} on {len(variant_preds)} examples")