In [None]:
import os
import re
import json
import time
import random
from typing import List, Dict, Tuple

import pandas as pd
from tqdm.auto import tqdm

import google.generativeai as genai
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient
from datasets import load_dataset

# --- Credentials -------------------------------------------------------------
# Both secrets are read from the Kaggle notebook secret store.
user_secrets = UserSecretsClient()
HF_TOKEN = user_secrets.get_secret("HF_TOKEN")
GEMINI_API_KEY = user_secrets.get_secret("GEMINI_API_KEY")

# Explicit checks rather than `assert`: assertions are removed when Python runs
# with -O, which would let an empty credential reach login()/genai.configure().
if not HF_TOKEN:
    raise RuntimeError("Set HF_TOKEN in Kaggle secrets.")
if not GEMINI_API_KEY:
    raise RuntimeError("Set GEMINI_API_KEY in Kaggle secrets.")

login(token=HF_TOKEN)

# --- Experiment configuration ------------------------------------------------
# Model identifier handed to genai.GenerativeModel below.
llm_name = "gemini-2.5-flash-lite"

# Number of leading dataset rows that the evaluation loop processes.
NUM_CLAIMS = 100
# The first N claims of each run are printed verbosely.
PRINT_FIRST_N_DEBUG = 3
# Seconds slept after every successful Gemini call.
BASE_SLEEP_BETWEEN_CALLS = 5.0
# Default number of attempts made by call_gemini_with_cooldown.
GEMINI_MAX_RETRIES = 3

# Prompt styles and method variants iterated by the evaluation loop.
PROMPT_MODES = ["instruction", "cot"]
VARIANTS = ["our_method"]

# The four label strings that predictions are normalised to.
CANONICAL_LABELS = [
    "Supported",
    "Refuted",
    "Not Enough Evidence",
    "Conflicting Evidence/Cherrypicking",
]

# --- Gemini client -----------------------------------------------------------
genai.configure(api_key=GEMINI_API_KEY)
gemini_model = genai.GenerativeModel(llm_name)
print("Gemini model initialised for AVeriTeC.")


def _retry_delay_from_error(message: str, default: float = 60.0) -> float:
    """Return the wait (in seconds) suggested by an error message, else *default*."""
    for pattern in (r"retry in ([0-9.]+)s", r"seconds:\s*([0-9]+)"):
        found = re.search(pattern, message, flags=re.I)
        if not found:
            continue
        try:
            return float(found.group(1))
        except ValueError:
            # e.g. "retry in ...s": the character class matched, but it is not
            # a number. Fall through to the next pattern / the default.
            continue
    return default


def call_gemini_with_cooldown(prompt: str, max_retries: int = GEMINI_MAX_RETRIES):
    """Call ``gemini_model.generate_content`` with retries and a fixed cooldown.

    After a successful call the function sleeps ``BASE_SLEEP_BETWEEN_CALLS``
    seconds before returning, so consecutive calls are spaced out. After a
    failed call it sleeps for the delay found in the error message (60 s when
    none can be parsed) and tries again.

    Args:
        prompt: Text sent to the model.
        max_retries: Total number of attempts before giving up.

    Returns:
        The response object returned by ``generate_content``.

    Raises:
        RuntimeError: When every attempt failed; the last underlying exception
            is attached as ``__cause__``.
    """
    last_exc = None
    for attempt in range(1, max_retries + 1):
        try:
            resp = gemini_model.generate_content(prompt)
        except Exception as e:
            # Broad on purpose: this file imports no SDK-specific exception
            # types, so every failure is treated as retryable.
            last_exc = e
            if attempt == max_retries:
                # No further attempt will be made, so waiting here would only
                # delay the error by up to a minute.
                print(f"[Gemini rate/HTTP error] Attempt {attempt}/{max_retries} → giving up.")
                break
            wait_s = _retry_delay_from_error(str(e))
            print(f"[Gemini rate/HTTP error] Attempt {attempt}/{max_retries} → sleeping {wait_s:.1f}s...")
            time.sleep(wait_s)
        else:
            time.sleep(BASE_SLEEP_BETWEEN_CALLS)
            return resp
    raise RuntimeError(f"Gemini generate_content failed after {max_retries} retries.") from last_exc


def _normalize_label_token(raw: str) -> str:
    if not raw:
        return "Not Enough Evidence"
    raw = raw.strip()
    raw = re.sub(r"[.\s]+$", "", raw)
    low = raw.lower()
    if low in {"1", "supported"}:
        return "Supported"
    if low in {"2", "refuted"}:
        return "Refuted"
    if low in {"3"} or "not enough evidence" in low or "insufficient evidence" in low:
        return "Not Enough Evidence"
    if low in {"4"} or "conflicting" in low or "cherry" in low:
        return "Conflicting Evidence/Cherrypicking"
    for lab in CANONICAL_LABELS:
        if lab.lower() == low:
            return lab
    for lab in CANONICAL_LABELS:
        if lab.lower() in low:
            return lab
    return "Not Enough Evidence"


def extract_label_from_output(text: str) -> str:
    """Pull the predicted label out of the model's raw text.

    The first ``Final label: ...`` occurrence (case-insensitive) wins. If
    there is none, the last non-blank line is used, keeping only the part
    after ``Label:`` when that marker appears on it. The chosen fragment is
    passed to ``_normalize_label_token``. Empty or whitespace-only text
    yields "Not Enough Evidence".
    """
    default_label = "Not Enough Evidence"
    if not text:
        return default_label

    tagged = re.search(r"Final label:\s*([^\n]+)", text, flags=re.I)
    if tagged is not None:
        return _normalize_label_token(tagged.group(1).strip())

    non_blank = [piece for piece in map(str.strip, text.splitlines()) if piece]
    if not non_blank:
        return default_label

    last_line = non_blank[-1]
    labelled = re.search(r"Label:\s*([^\n]+)", last_line, flags=re.I)
    token = labelled.group(1).strip() if labelled else last_line
    return _normalize_label_token(token)


def _build_averitec_base_desc() -> str:
    return (
        "You are a careful fact-checking assistant for real-world political and news claims.\n"
        "Your task is to assign EXACTLY ONE veracity label to the claim, choosing from:\n"
        "1. Supported\n"
        "2. Refuted\n"
        "3. Not Enough Evidence\n"
        "4. Conflicting Evidence/Cherrypicking\n\n"
        "Definitions:\n"
        "i) Supported: The claim is true or well-supported by reliable evidence.\n"
        "ii) Refuted: The claim is false or clearly contradicted by reliable evidence.\n"
        "iii) Not Enough Evidence: Available information is insufficient or inconclusive to decide.\n"
        "iv) Conflicting Evidence/Cherrypicking: There is both supporting and refuting evidence,\n"
        "  or the claim selectively highlights only part of the evidence (cherry-picking).\n\n"
    )


def _build_hints_block(hints: List[str]) -> str:
    if not hints:
        return ""
    text = (
        "You are also given high-quality bullet-point hints that summarise key factual\n"
        "evidence collected about this claim (they may still be incomplete, but are usually\n"
        "very informative). Give these hints significant weight when deciding the label.\n\n"
        "Hints:\n"
    )
    for h in hints:
        text += f"- {h}\n"
    text += "\n"
    return text


def build_averitec_instruction_prompt(claim: str, hints: List[str] = None) -> str:
    """Build the answer-only prompt for *claim*.

    Layout: shared task description, optional hints section, the claim, then
    an instruction to reply with a single ``Final label: ...`` line and no
    visible reasoning. ``hints=None`` is treated as "no hints".
    """
    answer_format = (
        "You may reason internally, but DO NOT show your reasoning.\n"
        "Respond with exactly one line in the format:\n"
        "Final label: <one of Supported / Refuted / Not Enough Evidence / Conflicting Evidence/Cherrypicking>\n"
    )
    sections = [
        _build_averitec_base_desc(),
        _build_hints_block(hints or []),
        "Claim to label:\n",
        claim,
        "\n\n",
        answer_format,
    ]
    return "".join(sections)


def build_averitec_cot_prompt(claim: str, hints: List[str] = None) -> str:
    """Build the chain-of-thought prompt for *claim*.

    Layout: shared task description, optional hints section, the claim, then
    an instruction to reason step by step and finish with a
    ``Final label: ...`` line. ``hints=None`` is treated as "no hints".
    """
    reasoning_format = (
        "Think step by step about whether the claim is supported, refuted, lacks evidence,\n"
        "or has conflicting evidence. Briefly explain your reasoning.\n"
        "Then, on a new line, write exactly:\n"
        "Final label: <one of Supported / Refuted / Not Enough Evidence / Conflicting Evidence/Cherrypicking>\n"
    )
    sections = [
        _build_averitec_base_desc(),
        _build_hints_block(hints or []),
        "Claim to label:\n",
        claim,
        "\n\n",
        reasoning_format,
    ]
    return "".join(sections)


def gemini_veracity_answer(
    claim: str,
    hints: List[str] = None,
    mode: str = "instruction",
) -> Tuple[str, str]:
    """Ask Gemini for a veracity label for *claim*.

    Args:
        claim: The claim text to classify.
        hints: Optional hint strings inserted into the prompt.
        mode: ``"cot"`` (case-insensitive) selects the chain-of-thought
            prompt; any other value selects the answer-only prompt.

    Returns:
        ``(label, raw_text)`` where *label* is the canonical label extracted
        from the reply and *raw_text* is the stripped reply text. When the
        response carries no readable text, *raw_text* is ``""`` and the label
        falls back to whatever ``extract_label_from_output`` returns for empty
        input.

    Raises:
        RuntimeError: Propagated from ``call_gemini_with_cooldown`` when all
            attempts fail.
    """
    mode = mode.lower()
    if mode == "cot":
        prompt = build_averitec_cot_prompt(claim, hints=hints)
    else:
        prompt = build_averitec_instruction_prompt(claim, hints=hints)

    resp = call_gemini_with_cooldown(prompt)
    try:
        out_text = (getattr(resp, "text", "") or "").strip()
    except ValueError:
        # getattr's default only covers AttributeError. The SDK's `text`
        # accessor is documented to raise ValueError when the response holds
        # no usable part (e.g. a blocked or empty candidate). Treat that as an
        # empty reply so one claim cannot abort the whole evaluation run.
        out_text = ""

    label_pred = extract_label_from_output(out_text)
    return label_pred, out_text


def decompositions_to_hints(decomp_field, max_hints: int = 8) -> List[str]:
    """Turn a raw ``decomposition`` value into a short list of hint strings.

    Accepted inputs:

    * a string, split on runs of newlines and semicolons;
    * a list or tuple of strings;
    * any object exposing ``tolist()``, e.g. a NumPy array. List-valued cells
      can arrive in that form after a dataset is converted to a DataFrame, and
      a plain ``isinstance(..., list)`` check would silently drop them.

    Non-string items and blank entries are skipped. Any other input
    (``None``, NaN, numbers, ...) produces an empty list.

    Args:
        decomp_field: The raw value to convert.
        max_hints: Maximum number of hints to return.

    Returns:
        At most *max_hints* stripped, non-empty strings in their original order.
    """
    if isinstance(decomp_field, str):
        candidates = re.split(r"[\n;]+", decomp_field)
    else:
        if hasattr(decomp_field, "tolist"):
            # Array-like container: convert to plain Python objects first.
            decomp_field = decomp_field.tolist()
        candidates = decomp_field if isinstance(decomp_field, (list, tuple)) else []

    hints = [
        item.strip()
        for item in candidates
        if isinstance(item, str) and item.strip()
    ]
    return hints[:max_hints]


def predict_claim_for_row(
    row,
    variant: str = "our_method",
    mode: str = "instruction",
    verbose: bool = False,
):
    """Predict the veracity label for one dataset row.

    Reads ``row["claim"]`` and ``row["label"]``. For the ``"our_method"``
    variant it also converts ``row.get("decomposition", [])`` into hints; for
    any other variant no hints are used. With ``verbose=True`` a short report
    (variant, mode, claim, gold and predicted labels, up to five hints) is
    printed.

    Returns:
        ``(predicted_label, raw_model_text, hints)``.
    """
    claim_text = str(row["claim"])
    gold_label = str(row["label"]).strip()

    uses_hints = variant == "our_method"
    hints: List[str] = (
        decompositions_to_hints(row.get("decomposition", [])) if uses_hints else []
    )

    # An empty hint list is passed on as None.
    label_pred, gemini_raw = gemini_veracity_answer(
        claim_text,
        hints=hints or None,
        mode=mode,
    )

    if verbose:
        report = [
            "\n" + "=" * 70,
            f"Variant      : {variant}",
            f"Prompt mode  : {mode}",
            f"Claim        : {claim_text}",
            f"Gold label   : {gold_label}",
            f"Pred label   : {label_pred}",
        ]
        if uses_hints:
            report.append("\nHints from decomposition (first few):")
            report.extend(f"  - {h}" for h in hints[:5])
        for line in report:
            print(line)

    return label_pred, gemini_raw, hints


# --- Load data ---------------------------------------------------------------
dataset = load_dataset("AlbertHatsuki/AveriTeC-decomposed")
df_all = dataset["train"].to_pandas()

required_cols = ["id", "claim", "label", "decomposition"]
assert all(col in df_all.columns for col in required_cols), \
    f"Unexpected columns: {list(df_all.columns)}"

sub_df = df_all.head(NUM_CLAIMS).reset_index(drop=True)
print(f"Loaded AVeriTeC-decomposed train split — using first {len(sub_df)} claims.")

# --- Run every (prompt mode, variant) combination ----------------------------
# One dict per (claim, variant, mode) prediction; becomes the CSV further down.
rows_out = []

for mode in PROMPT_MODES:
    print(f"\n================ PROMPT_MODE = {mode} ================")
    for variant in VARIANTS:
        n_seen = 0
        n_correct = 0
        print(f"\n=== Running AVeriTeC — variant={variant}, mode={mode} ===")
        for i in tqdm(range(len(sub_df)), desc=f"{variant}-{mode}"):
            r = sub_df.iloc[i]
            gold = str(r["label"]).strip()
            # Only the first PRINT_FIRST_N_DEBUG claims get the verbose report.
            label_pred, gemini_raw, hints = predict_claim_for_row(
                r,
                variant=variant,
                mode=mode,
                verbose=i < PRINT_FIRST_N_DEBUG,
            )
            is_correct = label_pred == gold
            n_seen += 1
            n_correct += int(is_correct)
            record = {
                "id": int(r["id"]),
                "variant": variant,
                "prompt_mode": mode,
                "claim": r["claim"],
                "label_gold": gold,
                "label_pred": label_pred,
                "correct": int(is_correct),
                "gemini_output": gemini_raw,
                "hints_used": json.dumps(hints, ensure_ascii=False) if hints else "",
            }
            rows_out.append(record)
            # Later claims get a one-line status instead of the verbose report.
            if i >= PRINT_FIRST_N_DEBUG:
                print(f"Claim {i+1}: gold={gold} | pred={label_pred} | correct={is_correct}")
        acc = n_correct / max(1, n_seen)
        print(f"\nVariant={variant}, mode={mode} — accuracy: {acc:.3f} on {n_seen} claims.")

# --- Persist and summarise ---------------------------------------------------
df_out = pd.DataFrame(rows_out)
out_path = "averitec_cot.csv"
df_out.to_csv(out_path, index=False)
print("\nSaved detailed results to:", out_path)

print("\nSummary (accuracy by variant & mode):")
for mode in PROMPT_MODES:
    for variant in VARIANTS:
        selector = (df_out["variant"] == variant) & (df_out["prompt_mode"] == mode)
        vrows = df_out[selector]
        if not vrows.empty:
            acc = vrows["correct"].mean()
            print(f"  mode={mode:11s} | variant={variant:10s} → accuracy {acc:.3f} on {len(vrows)} claims")