# AIMO3 Baseline Notebook
## AI Mathematical Olympiad – Progress Prize 3

### Quick Start Guide
1. **RUN_MODE Selection**:
   - `"local_ref"`: Debug mode - runs evaluation on `reference.csv` (10 problems)
   - `"submit"` (default): Kaggle submission mode - uses `kaggle_evaluation` API

2. **Model Setup (Kaggle)**:
   - Add your model as a Kaggle Dataset/Model input
   - Update `MODEL_PATH` in CONFIG to point to `/kaggle/input/your-model-name`
   - Or use Kaggle's built-in models

3. **Telemetry**:
   - Logs saved to `/kaggle/working/aimo3_telemetry.jsonl`
   - Download after submission for analysis

## CELL A — CONFIG (Constants)

In [11]:
# ============================================================
# AIMO3 NOTEBOOK — NEW SOLVER (from scratch) — PART 1/3
# Cells: A → D
# ============================================================

# ============================================
# CELL A — CONFIG
# ============================================

import os
# Every knob below can be overridden through an environment variable of the
# name passed to os.getenv; the second argument is the default.

# Modes:
# - "submit": run Kaggle inference server
# - "local_ref": run small local debug on reference.csv
RUN_MODE = os.getenv("RUN_MODE", "submit")

# Model (you said: start fresh with 8B on H100)
# MODEL_ID is what the tokenizer is loaded from; it defaults to MODEL_PATH.
MODEL_PATH = os.getenv("MODEL_PATH", "/kaggle/input/qwen-3/transformers/8b-fp8/1")
MODEL_ID = os.getenv("MODEL_ID", MODEL_PATH)

# Generation budget (start conservative; dynamic escalation happens later)
K_BASE = int(os.getenv("K_BASE", "4"))                 # base samples per stage (batch)
TEMPERATURE_BASE = float(os.getenv("TEMPERATURE", "0.6"))
TOP_P = float(os.getenv("TOP_P", "0.95"))
TOP_K = int(os.getenv("TOP_K", "50"))

# Global ceiling on generated tokens; the solver clamps its per-stage requests to it.
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "2048"))

# Token budgets by stage (short → longer)
MAX_NEW_TOKENS_DIRECT = int(os.getenv("MAX_NEW_TOKENS_DIRECT", "256"))
MAX_NEW_TOKENS_CODE = int(os.getenv("MAX_NEW_TOKENS_CODE", "768"))
MAX_NEW_TOKENS_HEAVY = int(os.getenv("MAX_NEW_TOKENS_HEAVY", "1400"))

# Agent retries (per stage)
AGENT_MAX_RETRIES = int(os.getenv("AGENT_MAX_RETRIES", "2"))

# Python execution timeout (seconds) — will be dynamically adjusted later too
EXEC_TIMEOUT_SEC_BASE = int(os.getenv("EXEC_TIMEOUT_SEC", "10"))
# The solver reads EXEC_TIMEOUT_SEC, so keep it tied to the env-configurable
# base value instead of a separate hard-coded constant.
EXEC_TIMEOUT_SEC = EXEC_TIMEOUT_SEC_BASE

# Overall logging
VERBOSE = int(os.getenv("VERBOSE", "1")) == 1

# Paths
REF_CSV_PATH = os.getenv(
    "REF_CSV_PATH",
    "/kaggle/input/ai-mathematical-olympiad-progress-prize-3/reference.csv",
)

print("============================================================")
print("AIMO3 NOTEBOOK — NEW SOLVER (Transformers) — PART 1/3")
print("============================================================")
print(f"RUN_MODE: {RUN_MODE}")
print(f"MODEL_PATH: {MODEL_PATH}")
print(f"K_BASE={K_BASE}, TEMP={TEMPERATURE_BASE}, TOP_P={TOP_P}, TOP_K={TOP_K}")
print(f"MAX_NEW_TOKENS: direct={MAX_NEW_TOKENS_DIRECT}, code={MAX_NEW_TOKENS_CODE}, heavy={MAX_NEW_TOKENS_HEAVY}")
print(f"AGENT_MAX_RETRIES={AGENT_MAX_RETRIES}, EXEC_TIMEOUT_SEC_BASE={EXEC_TIMEOUT_SEC_BASE}")
print("============================================================")




AIMO3 NOTEBOOK — NEW SOLVER (Transformers) — PART 1/3
RUN_MODE: submit
MODEL_PATH: /kaggle/input/qwen-3/transformers/8b-fp8/1
K_BASE=4, TEMP=0.6, TOP_P=0.95, TOP_K=50
MAX_NEW_TOKENS: direct=256, code=768, heavy=1400
AGENT_MAX_RETRIES=2, EXEC_TIMEOUT_SEC_BASE=10


## CELL B — IMPORTS + SEED CONTROL

In [12]:
# ============================================
# CELL B — IMPORTS + UTILS
# ============================================

import re
import time
import json
import math
import random
import tempfile
import subprocess
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Speed knobs: permit TF32 for float32 matmuls and cuDNN kernels.
# (Reproducibility is handled separately by set_seed below.)
torch.set_float32_matmul_precision("high")
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

def set_seed(seed: int = 42) -> None:
    """Seed Python's, NumPy's and PyTorch's (CPU + all CUDA devices) RNGs."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


# Seed once at import time so the notebook starts from a known RNG state.
set_seed(42)

def vprint(*args, **kwargs):
    """Forward to the built-in ``print`` only when VERBOSE is truthy."""
    if not VERBOSE:
        return
    print(*args, **kwargs)

def now_s() -> float:
    """Return the current wall-clock time in seconds (``time.time()``)."""
    current = time.time()
    return current

def safe_int(x: Any) -> Optional[int]:
    """
    Convert ``x`` to an int when it is unambiguously an integer.

    Accepts Python / NumPy integers and strings that consist solely of an
    optionally signed digit run (surrounding whitespace ignored).
    Booleans and everything else yield ``None``; no exception escapes.
    """
    result: Optional[int] = None
    try:
        if not isinstance(x, bool):
            if isinstance(x, (int, np.integer)):
                result = int(x)
            else:
                text = str(x).strip()
                if re.fullmatch(r"[+-]?\d+", text) is not None:
                    result = int(text)
    except Exception:
        result = None
    return result

def clamp_abs_int(x: int, limit: int = 10**18) -> int:
    """Clamp ``x`` to ``limit`` from above and to ``-limit`` from below."""
    if x > limit:
        clamped = limit
    elif x < -limit:
        clamped = -limit
    else:
        clamped = x
    return clamped




## CELL C — LLM BACKEND (Transformers)

In [13]:
# ============================================
# CELL C — LLM BACKEND (Transformers)
# ============================================

# Module-level cache for the LLM backend. All three start as None; load_llm()
# assigns them on its first call and returns the cached values afterwards.
_TOKENIZER: Optional[AutoTokenizer] = None
_MODEL: Optional[AutoModelForCausalLM] = None
_DEVICE: Optional[torch.device] = None

def load_llm() -> Tuple[AutoTokenizer, AutoModelForCausalLM, torch.device]:
    """
    Return ``(tokenizer, model, device)``, loading them on the first call only.

    The tokenizer is read from MODEL_ID and the weights from MODEL_PATH.
    Results are cached in the module-level ``_TOKENIZER`` / ``_MODEL`` /
    ``_DEVICE`` globals, so later calls are cheap.
    """
    global _TOKENIZER, _MODEL, _DEVICE
    cached = (_TOKENIZER, _MODEL, _DEVICE)
    if all(item is not None for item in cached):
        return cached

    vprint("[LLM] Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, use_fast=True)

    vprint("[LLM] Loading model...")
    # torch_dtype="auto" defers to whatever precision the checkpoint was packaged with.
    model_kwargs = dict(
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    llm = AutoModelForCausalLM.from_pretrained(MODEL_PATH, **model_kwargs)
    llm.eval()

    target = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    _TOKENIZER, _MODEL, _DEVICE = tokenizer, llm, target

    first_param = next(llm.parameters(), None)
    vprint("[LLM] Loaded. device:", target, "| dtype:", getattr(first_param, "dtype", None))
    return tokenizer, llm, target

def _apply_chat_template(tokenizer: AutoTokenizer, messages: List[Dict[str, str]]) -> str:
    """
    Render a chat transcript into a single prompt string.

    Uses ``tokenizer.apply_chat_template`` (untokenized, with the generation
    prompt appended) when the tokenizer offers it and the call succeeds.
    Otherwise builds a plain ``ROLE:`` / content transcript that ends with an
    open ``ASSISTANT:`` turn.
    """
    template_fn = getattr(tokenizer, "apply_chat_template", None)
    if template_fn is not None:
        try:
            return template_fn(messages, tokenize=False, add_generation_prompt=True)
        except Exception:
            # Any template failure falls through to the generic transcript below.
            pass

    # Generic fallback (works with most causal LMs, less optimal)
    turns = [
        f"{msg.get('role', 'user').upper()}:\n{msg.get('content', '')}\n"
        for msg in messages
    ]
    turns.append("ASSISTANT:\n")
    return "\n".join(turns)

@torch.inference_mode()
def llm_generate_texts(
    batch_messages: List[List[Dict[str, str]]],
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
) -> List[str]:
    """
    Batch generation for higher GPU utilization.

    Args:
        batch_messages: list of chat message lists (len = batch size).
        max_new_tokens: generation cap per sequence.
        temperature: sampling temperature; values <= 1e-6 switch to greedy decoding.
        top_p / top_k: nucleus / top-k settings, only used when sampling.

    Returns:
        The decoded, stripped assistant continuations (len = batch size).
        An empty batch returns ``[]`` without touching the model.
    """
    if not batch_messages:
        return []

    tokenizer, model, _ = load_llm()

    # The prompt is removed below by slicing every row at the same column, which
    # is only correct when padding sits on the LEFT of each prompt. Enforce it
    # here instead of relying on the tokenizer's default.
    tokenizer.padding_side = "left"
    # padding=True needs a pad token; fall back to EOS when none is defined.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    prompts = [_apply_chat_template(tokenizer, msgs) for msgs in batch_messages]
    enc = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    enc = {k: v.to(model.device) for k, v in enc.items()}

    do_sample = temperature > 1e-6
    gen = model.generate(
        **enc,
        do_sample=do_sample,
        # Sampling knobs are passed as None for greedy decoding.
        temperature=temperature if do_sample else None,
        top_p=top_p if do_sample else None,
        top_k=top_k if do_sample else None,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
    )

    # With left padding every prompt ends at the same column, so everything
    # from that column onward is newly generated text.
    prompt_width = enc["input_ids"].shape[1]
    outs = tokenizer.batch_decode(gen[:, prompt_width:], skip_special_tokens=True)
    return [o.strip() for o in outs]




## CELL D — PROMPT FACTORY (Diverse Archetypes)

In [14]:
# ============================================
# CELL D — PROMPT FACTORY (diverse archetypes)
# ============================================

# System prompt used by build_messages() below.
# NOTE: SYSTEM_SOLVER is assigned again in CELL G; once that cell has run,
# the later value is the one any code reading this global will see.
SYSTEM_SOLVER = (
    "You are a competition mathematician.\n"
    "Return FINAL_ANSWER: <integer> when you are done.\n"
    "If you use Python, put code inside a single ```python ... ``` block and ensure it prints ONE integer.\n"
    "Do NOT include multiple candidate answers.\n"
)

# Four solving "archetypes": diversity comes from different instructions,
# not just from different random seeds. Keys are the names accepted by
# build_messages(); values are inserted verbatim under the STYLE: heading.
ARCHETYPES = {
    "direct": (
        "Solve succinctly. If the problem is computational, derive a clean formula and compute.\n"
        "End with FINAL_ANSWER: <integer>.\n"
    ),
    "sympy_first": (
        "Use Python + sympy early. Convert the problem to equations / constraints, solve via sympy.\n"
        "Print only the final integer.\n"
    ),
    "bruteforce_constraints": (
        "Prefer brute force / search under constraints if feasible.\n"
        "Write Python to search intelligently (pruning).\n"
        "Print only the final integer.\n"
    ),
    "number_theory_style": (
        "Use modular arithmetic / invariants / parity / divisibility reasoning.\n"
        "If any computation is needed, write Python.\n"
        "Print only the final integer.\n"
    ),
}

def build_messages(problem: str, archetype: str, extra_hint: str = "") -> List[Dict[str, str]]:
    """
    Build the system + user chat messages for one problem.

    An unknown ``archetype`` falls back to ``"direct"``. A non-blank
    ``extra_hint`` is appended (stripped) under an EXTRA_HINT heading.
    """
    style_text = ARCHETYPES.get(archetype, ARCHETYPES["direct"])
    sections = [
        f"PROBLEM:\n{problem}\n\n",
        f"STYLE:\n{style_text}\n",
    ]
    hint = extra_hint.strip()
    if hint:
        sections.append(f"\nEXTRA_HINT:\n{hint}\n")
    return [
        {"role": "system", "content": SYSTEM_SOLVER},
        {"role": "user", "content": "".join(sections)},
    ]


## CELL E — ANSWER EXTRACTION (Strict)

In [15]:
# ============================================================
# CELL E — PARSING / EXTRACTION / VOTING UTILS
# ============================================================

import re
import math
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

# Explicit answer markers, tried in this order by extract_final_answer():
# "FINAL_ANSWER: n" / "FINAL_ANSWER = n", LaTeX "\boxed{n}", then "Answer: n".
# Each pattern captures an optionally negative integer in group 1.
FINAL_PATTERNS = [
    re.compile(r"FINAL_ANSWER\s*[:=]\s*(-?\d+)", re.IGNORECASE),
    re.compile(r"\\boxed\{\s*(-?\d+)\s*\}"),
    re.compile(r"(?<!\w)Answer\s*[:=]\s*(-?\d+)", re.IGNORECASE),
]

# A line that contains nothing but one (optionally negative) integer.
INT_ONLY_LINE = re.compile(r"^\s*(-?\d+)\s*$")

# Fenced code blocks, most specific first; group 1 is the fence body.
CODE_BLOCK_PATTERNS = [
    re.compile(r"```python\s*(.*?)```", re.DOTALL | re.IGNORECASE),
    re.compile(r"```py\s*(.*?)```", re.DOTALL | re.IGNORECASE),
    re.compile(r"```\s*(.*?)```", re.DOTALL),  # fallback (dangerous; use last)
]

def strip_think(text: str) -> str:
    """Drop closed ``<think>`` and ``<analysis>`` spans (case-insensitive, multi-line)."""
    span_flags = re.DOTALL | re.IGNORECASE
    for pattern in (r"<think>.*?</think>", r"<analysis>.*?</analysis>"):
        text = re.sub(pattern, "", text, flags=span_flags)
    return text

def extract_final_answer(text: str) -> Optional[int]:
    """
    Pull the final integer answer out of model text (thinking spans removed).

    Lookup order:
      1. the last match of the first FINAL_PATTERNS entry that matches at all;
      2. the last non-blank line, if it is just an integer;
      3. the last integer within the final 300 characters.
    Returns ``None`` when nothing is found or ``text`` is empty.
    """
    if not text:
        return None
    cleaned = strip_think(text)

    # 1) explicit markers — keep only the final occurrence of a pattern
    for pattern in FINAL_PATTERNS:
        last_hit = None
        for last_hit in pattern.finditer(cleaned):
            pass
        if last_hit is not None:
            return int(last_hit.group(1))

    # 2) integer-only last line
    nonblank = [ln.strip() for ln in cleaned.strip().splitlines() if ln.strip()]
    last_line_match = INT_ONLY_LINE.match(nonblank[-1]) if nonblank else None
    if last_line_match:
        return int(last_line_match.group(1))

    # 3) best effort: last integer near the end of the text
    trailing_numbers = re.findall(r"(-?\d+)", cleaned[-300:])
    if not trailing_numbers:
        return None
    try:
        return int(trailing_numbers[-1])
    except Exception:
        return None

def extract_python_code(text: str) -> Optional[str]:
    """
    Return the body of the last fenced code block in ``text``, or ``None``.

    CODE_BLOCK_PATTERNS are tried in order; the first pattern whose last
    block is non-empty wins. Stray backticks are trimmed, and results
    shorter than 8 characters are rejected.
    """
    if not text:
        return None
    cleaned = strip_think(text)

    chosen = None
    for pattern in CODE_BLOCK_PATTERNS:
        bodies = pattern.findall(cleaned)
        if bodies:
            # the last block is often the final corrected version
            chosen = bodies[-1].strip()
            if chosen:
                break

    if not chosen:
        return None

    # trim leftover backticks / whitespace around the body
    chosen = chosen.strip().strip("`").strip()
    if len(chosen) < 8:
        return None
    return chosen

def parse_int_from_stdout(stdout: str) -> Optional[int]:
    """
    Read the integer a program printed.

    Prefers the last non-blank line that is only an integer; otherwise
    defers to extract_final_answer() on the whole output.
    """
    if not stdout:
        return None
    nonblank = [ln.strip() for ln in stdout.strip().splitlines() if ln.strip()]
    if not nonblank:
        return None

    # Scan from the bottom for the first integer-only line.
    hit = next((m for m in map(INT_ONLY_LINE.match, reversed(nonblank)) if m), None)
    if hit is not None:
        return int(hit.group(1))

    # No clean integer line: fall back to the marker / tail heuristics.
    return extract_final_answer(stdout)

@dataclass
class Candidate:
    """One proposed answer together with its provenance and voting weight."""
    answer: int  # the proposed integer answer
    source: str  # "direct", "boxed", "code", "verifier", etc.
    weight: float  # amount this candidate adds to its answer's vote total
    meta: Dict[str, Any] = field(default_factory=dict)  # free-form debug info (e.g. raw text tails)

def weighted_vote(cands: List[Candidate]) -> Tuple[Optional[int], Dict[int, float]]:
    """
    Pick the answer with the largest total weight.

    Ties are broken by the largest single weight behind an answer, then by
    smaller absolute value. Returns ``(best_answer, totals_by_answer)``,
    or ``(None, {})`` for an empty list.
    """
    if not cands:
        return None, {}

    totals: Dict[int, float] = {}
    peak: Dict[int, float] = {}
    for cand in cands:
        w = float(cand.weight)
        totals[cand.answer] = totals.get(cand.answer, 0.0) + w
        peak[cand.answer] = max(peak.get(cand.answer, 0.0), w)

    def rank(item):
        answer, total = item
        return (total, peak.get(answer, 0.0), -abs(answer))

    winner = max(totals.items(), key=rank)[0]
    return winner, totals

def consensus_strength(score_map: Dict[int, float]) -> float:
    """
    Measure how clearly the top answer leads: ``top / (top + second)``.

    Returns 0.0 for an empty map, 1.0 when only one answer exists, and 0.0
    when the two leading scores do not sum to a positive number.
    """
    if not score_map:
        return 0.0
    ranked = sorted(score_map.values(), reverse=True)
    if len(ranked) < 2:
        return 1.0
    leader, runner_up = ranked[:2]
    total = leader + runner_up
    if total > 0:
        return float(leader / total)
    return 0.0


## CELL F — PYTHON EXECUTOR (Sandboxed Subprocess) + TIME BUDGET

In [16]:
# ============================================================
# CELL F — PYTHON EXECUTOR (SAFE-ISH) + TIME BUDGET
# ============================================================

import os
import sys
import tempfile
import subprocess
from pathlib import Path

def _executor_env() -> Dict[str, str]:
    env = dict(os.environ)
    # reduce noisy behaviors
    env["PYTHONHASHSEED"] = "0"
    env["PYTHONWARNINGS"] = "ignore"
    env["OMP_NUM_THREADS"] = "1"
    env["OPENBLAS_NUM_THREADS"] = "1"
    env["MKL_NUM_THREADS"] = "1"
    env["NUMEXPR_NUM_THREADS"] = "1"
    return env

def run_python_sandbox(code: str, timeout_sec: int = 10) -> Tuple[bool, str, str]:
    """
    Run ``code`` as ``main.py`` inside a fresh temporary directory.

    A prelude with common imports (math, itertools, functools, fractions,
    optional sympy as ``sp`` / numpy as ``np``) is prepended to the script.

    Args:
        code: Python source to execute.
        timeout_sec: wall-clock limit for the subprocess, in seconds.

    Returns:
        ``(success, stdout, stderr)``. ``success`` is True only when the
        process exits with code 0. Empty code, a blocked token, a timeout or
        a launcher error yield ``(False, "", <reason>)``.
    """
    if not code or not code.strip():
        return False, "", "Empty code"

    # Hard guard: prevent obvious shell / file deletion attempts (not perfect)
    dangerous = ["os.system(", "subprocess.", "shutil.rmtree", "rm -rf", "pip install", "wget ", "curl "]
    lowered = code.lower()
    for d in dangerous:
        if d in lowered:
            return False, "", f"Blocked dangerous token: {d}"

    with tempfile.TemporaryDirectory() as td:
        td_path = Path(td)
        script_path = td_path / "main.py"

        prelude = (
            "import sys\n"
            "import math\n"
            "import itertools\n"
            "import functools\n"
            "import fractions\n"
            "from math import *\n"
            # The star import above rebinds ``pow`` to math.pow, which is
            # float-only and rejects the 3-argument modular form. Restore the
            # built-in so pow(a, b, m) and exact integer powers keep working.
            "from builtins import pow\n"
            "try:\n"
            "    import sympy as sp\n"
            "except Exception:\n"
            "    sp = None\n"
            "try:\n"
            "    import numpy as np\n"
            "except Exception:\n"
            "    np = None\n"
            "sys.setrecursionlimit(10**6)\n"
        )

        # The user code is run as-is after the prelude; no print wrapper is added.
        script_path.write_text(prelude + "\n\n" + code + "\n", encoding="utf-8")

        try:
            proc = subprocess.run(
                [sys.executable, str(script_path)],
                cwd=str(td_path),
                capture_output=True,
                text=True,
                timeout=int(timeout_sec),
                env=_executor_env(),
            )
            ok = (proc.returncode == 0)
            return ok, (proc.stdout or ""), (proc.stderr or "")
        except subprocess.TimeoutExpired:
            return False, "", "Execution timed out"
        except Exception as e:
            return False, "", f"Executor error: {e}"

class DeadlineManager:
    """Wall-clock budget tracker that starts counting when it is constructed."""

    def __init__(self, total_budget_sec: int = 5 * 60 * 60 - 60, hard_stop_margin_sec: int = 20):
        self.t0 = time.time()
        self.total_budget_sec = int(total_budget_sec)
        self.hard_stop_margin_sec = int(hard_stop_margin_sec)

    def elapsed(self) -> float:
        """Seconds since construction."""
        now = time.time()
        return now - self.t0

    def remaining(self) -> float:
        """Seconds left in the total budget, never below 0.0."""
        left = self.total_budget_sec - self.elapsed()
        return left if left > 0.0 else 0.0

    def can_continue(self) -> bool:
        """True while more than the hard-stop margin remains."""
        return self.hard_stop_margin_sec < self.remaining()

    def per_problem_budget(self, problems_left: int, min_sec: int = 10, max_sec: int = 180) -> int:
        """Even share of the remaining time per problem, bounded to [min_sec, max_sec]."""
        if problems_left <= 0:
            return min_sec
        share = self.remaining() / float(problems_left)
        bounded = max(min_sec, min(max_sec, share))
        return int(bounded)


## CELL G — PROMPTS (Diverse Archetypes) + GENERATION HELPERS

In [17]:
# ============================================================
# CELL G — PROMPTS (DIVERSE ARCHETYPES) + GENERATION HELPERS
# ============================================================

# Requires llm_generate_texts from CELL C.

# System prompt for direct (reasoning-only) attempts; used by generate_n_direct().
# NOTE: this rebinds the SYSTEM_SOLVER name that CELL D also defines.
SYSTEM_SOLVER = (
    "You are a math contest solver. "
    "You MUST output the final integer answer in the format: FINAL_ANSWER: <integer> "
    "Do not include any other numbers on the last line."
)

# System prompt for code-writing attempts; used by generate_n_code() and the
# code-repair retry in build_candidates_from_code().
SYSTEM_CODE = (
    "You are a math contest solver that uses Python to compute exactly. "
    "Return ONLY a Python code block ```python ... ```.\n"
    "Rules:\n"
    "- The code must print ONLY one line: the integer answer.\n"
    "- No extra prints.\n"
    "- Use sympy/numpy if helpful.\n"
)

# System prompt for the verifier; maybe_verify() looks for the literal
# "VERDICT: OK" / "VERDICT: FAIL" strings requested here.
SYSTEM_VERIFY = (
    "You are a strict verifier. "
    "Given a problem and a proposed integer answer, check for constraint/boundary traps. "
    "Return either:\n"
    "VERDICT: OK\n"
    "or\n"
    "VERDICT: FAIL\n"
    "If FAIL and you are confident, also provide: FINAL_ANSWER: <integer> on a new line."
)

def user_prompt_direct(problem: str, style_hint: str = "") -> str:
    """Build the user turn for a direct attempt, with an optional style hint line."""
    if style_hint:
        hint_block = f"\nStyle hint: {style_hint}\n"
    else:
        hint_block = "\n"
    pieces = [
        f"Problem:\n{problem}\n",
        hint_block,
        "Solve carefully. Put the final result as the last line exactly like:\n",
        "FINAL_ANSWER: 123\n",
    ]
    return "".join(pieces)

def user_prompt_code(problem: str, approach_hint: str = "") -> str:
    """Build the user turn for a code attempt, with an optional approach hint line."""
    if approach_hint:
        hint_block = f"\nApproach hint: {approach_hint}\n"
    else:
        hint_block = "\n"
    pieces = [
        f"Problem:\n{problem}\n",
        hint_block,
        "Write Python that computes the exact integer answer and prints ONLY that integer.\n",
    ]
    return "".join(pieces)

def user_prompt_verify(problem: str, answer: int) -> str:
    """Build the user turn asking the verifier to check ``answer`` for ``problem``."""
    pieces = [
        f"Problem:\n{problem}\n\n",
        f"Proposed answer: {answer}\n\n",
        "Check quickly for missed constraints (integer, positivity, bounds, divisibility, off-by-one, etc.).",
    ]
    return "".join(pieces)

# Style hints cycled through (index modulo length) by generate_n_direct().
ARCHETYPES_DIRECT = [
    "Algebra / simplify first",
    "Number theory / modular constraints",
    "Brute force reasoning (small search if implied)",
    "Invariant / parity / extremal thinking",
]

# Approach hints cycled through (index modulo length) by generate_n_code().
ARCHETYPES_CODE = [
    "Use sympy solve/simplify; then print integer",
    "Try brute force/search over integer constraints",
    "Translate to equations and compute exactly",
    "If geometry: coordinate bash / compute exactly",
]

def _batch_messages(system: str, user: str, n: int) -> List[List[Dict[str, str]]]:
    return [[{"role": "system", "content": system}, {"role": "user", "content": user}] for _ in range(int(n))]

def generate_n_direct(problem: str, n: int, temperature: float, max_new_tokens: int) -> List[str]:
    """
    Generate ``n`` direct attempts in one batch.

    Each attempt gets the next ARCHETYPES_DIRECT hint (cycling), so the batch
    is diversified by instruction as well as by sampling.
    """
    conversations = [
        [
            {"role": "system", "content": SYSTEM_SOLVER},
            {
                "role": "user",
                "content": user_prompt_direct(problem, ARCHETYPES_DIRECT[idx % len(ARCHETYPES_DIRECT)]),
            },
        ]
        for idx in range(n)
    ]
    sampling = dict(
        temperature=float(temperature),
        top_p=float(TOP_P),
        top_k=int(TOP_K),
        max_new_tokens=int(max_new_tokens),
    )
    return llm_generate_texts(conversations, **sampling)

def generate_n_code(problem: str, n: int, temperature: float, max_new_tokens: int) -> List[str]:
    """
    Generate ``n`` code-writing attempts in one batch.

    Each attempt gets the next ARCHETYPES_CODE hint (cycling) and the
    SYSTEM_CODE system prompt.
    """
    conversations = [
        [
            {"role": "system", "content": SYSTEM_CODE},
            {
                "role": "user",
                "content": user_prompt_code(problem, ARCHETYPES_CODE[idx % len(ARCHETYPES_CODE)]),
            },
        ]
        for idx in range(n)
    ]
    sampling = dict(
        temperature=float(temperature),
        top_p=float(TOP_P),
        top_k=int(TOP_K),
        max_new_tokens=int(max_new_tokens),
    )
    return llm_generate_texts(conversations, **sampling)

def generate_verify(problem: str, answer: int, temperature: float = 0.0, max_new_tokens: int = 256) -> str:
    """Ask the verifier about ``answer`` and return its raw reply ("" if none came back)."""
    conversation = [
        {"role": "system", "content": SYSTEM_VERIFY},
        {"role": "user", "content": user_prompt_verify(problem, answer)},
    ]
    replies = llm_generate_texts(
        [conversation],
        temperature=float(temperature),
        top_p=1.0,
        top_k=0,
        max_new_tokens=int(max_new_tokens),
    )
    if not replies:
        return ""
    return replies[0]


## CELL H — SOLVER (Two-Stage + Weighted Voting + Retry)

In [18]:
# ============================================================
# CELL H — SOLVER (TWO-STAGE + WEIGHTED VOTING + RETRY)
# ============================================================

def build_candidates_from_direct(outputs: List[str]) -> List[Candidate]:
    """
    Turn direct-attempt texts into vote candidates.

    Texts without an extractable answer are skipped. Weight is 0.9 when the
    text has an explicit FINAL_ANSWER marker, 0.85 when it only has a
    ``\\boxed{}`` integer, and 0.6 otherwise.
    """
    found: List[Candidate] = []
    for text in outputs:
        answer = extract_final_answer(text)
        if answer is None:
            continue

        has_final = re.search(r"FINAL_ANSWER\s*[:=]\s*-?\d+", text, re.IGNORECASE) is not None
        has_boxed = re.search(r"\\boxed\{\s*-?\d+\s*\}", text) is not None
        if has_final:
            weight = 0.9
        elif has_boxed:
            weight = 0.85
        else:
            weight = 0.6

        found.append(
            Candidate(answer=int(answer), source="direct", weight=float(weight), meta={"raw": text[-500:]})
        )
    return found

def build_candidates_from_code(outputs: List[str], exec_timeout_sec: int, max_retries: int) -> List[Candidate]:
    """
    Execute the code found in each model output and collect answer candidates.

    For every output:
      * no code block -> fall back to any answer stated in the text
        (source "direct_fallback", weight 0.55);
      * code block -> run it; on failure ask the model to repair it using the
        error text and run the repaired version, up to ``max_retries``
        executions in total.

    Args:
        outputs: raw model completions.
        exec_timeout_sec: timeout passed to the sandbox for each run.
        max_retries: total number of execution attempts per output. Values
            below 1 are treated as 1 so the code is always run at least once.

    Returns:
        Candidates with source "code_exec" (weight 2.0), "code_exec_loose"
        (weight 1.0) or "direct_fallback" (weight 0.55).
    """
    # range(0) would skip execution entirely and silently discard every
    # code answer, so guarantee a single attempt.
    total_attempts = max(1, int(max_retries))

    cands: List[Candidate] = []
    for raw in outputs:
        code = extract_python_code(raw)
        if not code:
            # fallback: if it still gave FINAL_ANSWER
            ans = extract_final_answer(raw)
            if ans is not None:
                cands.append(Candidate(answer=int(ans), source="direct_fallback", weight=0.55, meta={"raw": raw[-500:]}))
            continue

        cur_code = code
        for attempt in range(total_attempts):
            ok, out, err = run_python_sandbox(cur_code, timeout_sec=exec_timeout_sec)
            if ok:
                ans = parse_int_from_stdout(out)
                if ans is not None:
                    cands.append(Candidate(
                        answer=int(ans),
                        source="code_exec",
                        weight=2.0,
                        meta={"stdout_tail": out[-300:], "attempt": attempt, "raw": raw[-400:]},
                    ))
                else:
                    # ran but output not parseable => low weight (likely noisy prints)
                    a2 = extract_final_answer(out)
                    if a2 is not None:
                        cands.append(Candidate(answer=int(a2), source="code_exec_loose", weight=1.0, meta={"stdout_tail": out[-300:]}))
                break

            # Last attempt failed: nothing left to try for this output.
            if attempt >= total_attempts - 1:
                break

            # Ask the model to fix the code using the error feedback (single-shot in-place)
            last_err = (err or "").strip()
            fix_msgs = [[
                {"role": "system", "content": SYSTEM_CODE},
                {"role": "user", "content": (
                    "Fix the Python code below. It failed with an error.\n"
                    "Return ONLY a corrected ```python``` code block that prints ONLY the integer answer.\n\n"
                    f"ERROR:\n{last_err}\n\n"
                    f"CODE:\n```python\n{cur_code}\n```"
                )},
            ]]
            fixed = llm_generate_texts(
                fix_msgs,
                temperature=0.2,
                top_p=float(TOP_P),
                top_k=int(TOP_K),
                max_new_tokens=1024,
            )
            new_code = extract_python_code(fixed[0]) if fixed else None
            if not new_code:
                # The repair produced no usable code block; give up on this output.
                break
            cur_code = new_code

    return cands

def maybe_verify(problem: str, top_answer: int, budget_ok: bool) -> List[Candidate]:
    """
    Run the verifier on ``top_answer`` and convert its verdict into candidates.

    Returns a "verifier_ok" candidate (weight 1.2) for ``top_answer`` on
    "VERDICT: OK", a "verifier_fix" candidate (weight 1.0) when the verdict
    is FAIL and a different answer is offered, and ``[]`` otherwise
    (including when ``budget_ok`` is false).
    """
    if not budget_ok:
        return []

    verdict_text = strip_think(
        generate_verify(problem, top_answer, temperature=0.0, max_new_tokens=256)
    )
    excerpt = verdict_text[-400:]

    if "VERDICT: OK" in verdict_text:
        return [Candidate(answer=int(top_answer), source="verifier_ok", weight=1.2, meta={"raw": excerpt})]

    if "VERDICT: FAIL" not in verdict_text:
        return []

    alt = extract_final_answer(verdict_text)
    if alt is None or int(alt) == int(top_answer):
        return []
    return [Candidate(answer=int(alt), source="verifier_fix", weight=1.0, meta={"raw": excerpt})]

def solve_one(problem: str, dm: Optional[DeadlineManager] = None, problems_left: int = 50) -> int:
    """
    Solve one problem and return an integer answer (0 if nothing was parseable).

    Pipeline:
      A. a small batch of direct attempts; if the weighted vote is decisive
         (consensus >= 0.72) verify once and return;
      B. otherwise run code-writing attempts and execute them;
      C. if consensus is still below 0.68, sample more direct attempts at a
         higher temperature;
      D. verify the current leader when time allows, then return the vote winner.

    Args:
        problem: problem statement text.
        dm: shared deadline tracker; a fresh DeadlineManager is created when None.
        problems_left: number of problems still to solve, used to size this
            problem's time share.

    NOTE(review): token and timeout limits here come from MAX_NEW_TOKENS and
    EXEC_TIMEOUT_SEC; the stage-specific MAX_NEW_TOKENS_DIRECT/CODE/HEAVY and
    EXEC_TIMEOUT_SEC_BASE values from CELL A are not read by this function.
    """
    if dm is None:
        dm = DeadlineManager()

    # Dynamic per-problem budget (soft; we still rely on token + executor limits).
    # It is only checked between stages, so a stage that has started runs to completion.
    per_budget = dm.per_problem_budget(problems_left=problems_left, min_sec=12, max_sec=180)

    t_start = time.time()
    cands: List[Candidate] = []

    # -----------------------
    # STAGE A: cheap direct
    # -----------------------
    # Between 2 and 4 samples, each capped at 768 new tokens.
    K1 = max(2, min(K_BASE, 4))
    direct_outs = generate_n_direct(problem, n=K1, temperature=float(TEMPERATURE_BASE), max_new_tokens=min(768, int(MAX_NEW_TOKENS)))
    cands += build_candidates_from_direct(direct_outs)

    best, score_map = weighted_vote(cands)
    strength = consensus_strength(score_map)

    # If strong enough, verify quickly then return (early exit; stages B-D are skipped)
    if best is not None and strength >= 0.72:
        cands += maybe_verify(problem, best, budget_ok=(dm.can_continue() and (time.time() - t_start) < per_budget))
        best2, _ = weighted_vote(cands)
        return int(best2 if best2 is not None else best)

    # -----------------------
    # STAGE B: code (TIR)
    # -----------------------
    # Spend time only if within per_budget
    if (time.time() - t_start) < per_budget and dm.can_continue():
        # Fewer but heavier attempts; code stage weights dominate when successful.
        # Between 1 and 3 samples at a fixed low temperature of 0.2.
        K2 = max(1, min(K_BASE, 3))
        code_outs = generate_n_code(problem, n=K2, temperature=0.2, max_new_tokens=int(MAX_NEW_TOKENS))
        cands += build_candidates_from_code(code_outs, exec_timeout_sec=int(EXEC_TIMEOUT_SEC), max_retries=int(AGENT_MAX_RETRIES))

    # Re-vote over everything collected so far (direct + code).
    best, score_map = weighted_vote(cands)
    strength = consensus_strength(score_map)

    # If still weak, expand direct a bit (diverse) within budget:
    # 2-6 extra samples, temperature raised by 0.2 (capped at 0.9), up to 1024 new tokens.
    if best is not None and strength < 0.68 and (time.time() - t_start) < per_budget and dm.can_continue():
        K3 = max(2, min(K_BASE, 6))
        direct_outs2 = generate_n_direct(problem, n=K3, temperature=min(0.9, float(TEMPERATURE_BASE) + 0.2), max_new_tokens=min(1024, int(MAX_NEW_TOKENS)))
        cands += build_candidates_from_direct(direct_outs2)
        best, score_map = weighted_vote(cands)
        strength = consensus_strength(score_map)

    # Optional verify if we have *any* answer and budget allows
    if best is not None and (time.time() - t_start) < per_budget and dm.can_continue():
        cands += maybe_verify(problem, best, budget_ok=True)
        best2, _ = weighted_vote(cands)
        return int(best2 if best2 is not None else best)

    # Fallback: if nothing parseable, return 0
    if best is None:
        return 0
    return int(best)


## CELL I — TOOLING (Parsing / Safe Exec / Voting / Budget)

In [19]:
# ============================================================
# CELL I — TOOLING: PARSING / SAFE EXEC / VOTING / BUDGET
# ============================================================

import re, sys, os, time, math, tempfile, subprocess
from collections import Counter, defaultdict

# "FINAL_ANSWER: <int>" marker (case-insensitive); group 1 is the signed integer.
FINAL_ANSWER_RE = re.compile(r"FINAL_ANSWER\s*:\s*([-+]?\d+)", re.IGNORECASE)
# LaTeX \boxed{...}; group 1 is everything up to the first closing brace.
BOXED_RE = re.compile(r"\\boxed\{([^}]*)\}")
# Any optionally signed run of digits.
INT_RE = re.compile(r"[-+]?\d+")

def _strip_think(text: str) -> str:
    # remove <think>...</think> if present
    return re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL | re.IGNORECASE)

def extract_final_answer_int(text: str):
    """
    Return the integer after the first ``FINAL_ANSWER:`` marker in ``text``.

    Returns ``None`` when ``text`` is empty, no marker is present, or the
    digits cannot be converted to an int.
    """
    if not text:
        return None
    m = FINAL_ANSWER_RE.search(text)
    if m is None:
        return None
    try:
        return int(m.group(1))
    except ValueError:
        # int() can still reject a matched digit run, e.g. one longer than the
        # interpreter's int-from-string digit limit.
        return None

def extract_boxed_int(text: str):
    """
    Return the last integer inside the first ``\\boxed{...}`` in ``text``.

    Returns ``None`` when ``text`` is empty, there is no ``\\boxed{}``, the
    box holds no integer, or the digits cannot be converted to an int.
    """
    if not text:
        return None
    m = BOXED_RE.search(text)
    if not m:
        return None
    ints = INT_RE.findall(m.group(1))
    if not ints:
        return None
    try:
        return int(ints[-1])
    except ValueError:
        # int() can still reject a matched digit run, e.g. one longer than the
        # interpreter's int-from-string digit limit.
        return None

def extract_last_int(text: str):
    """
    Return the last optionally signed integer that appears in ``text``.

    Returns ``None`` when ``text`` is empty, contains no integer, or the
    digits cannot be converted to an int.
    """
    if not text:
        return None
    ints = INT_RE.findall(text)
    if not ints:
        return None
    try:
        return int(ints[-1])
    except ValueError:
        # int() can still reject a matched digit run, e.g. one longer than the
        # interpreter's int-from-string digit limit.
        return None

def extract_python_blocks(text: str):
    """
    Return the non-empty fenced code blocks of ``text``, in order.

    ```python fences (any letter case) are preferred; generic ``` fences are
    only consulted when no python fence exists at all.
    """
    if not text:
        return []

    bodies = [
        body.strip()
        for body in re.findall(r"```python\s*(.*?)```", text, flags=re.DOTALL | re.IGNORECASE)
    ]
    if not bodies:
        # No python-tagged fence anywhere: accept untagged / other fences.
        bodies = [
            body.strip()
            for body in re.findall(r"```\s*(.*?)```", text, flags=re.DOTALL)
        ]
    return [body for body in bodies if body]

# Regexes for constructs sanitize_python() refuses to run: os / sys / subprocess
# imports, open / eval / exec calls, __import__, and filesystem or network modules.
# A line matching any of them is commented out rather than executed.
_FORBIDDEN_PATTERNS = [
    r"\bimport\s+os\b", r"\bimport\s+sys\b", r"\bimport\s+subprocess\b",
    r"\bfrom\s+os\b", r"\bfrom\s+sys\b", r"\bfrom\s+subprocess\b",
    r"\bopen\s*\(", r"\beval\s*\(", r"\bexec\s*\(",
    r"\b__import__\b", r"\bpathlib\b", r"\bshutil\b",
    r"\bsocket\b", r"\brequests\b", r"\burllib\b"
]

def sanitize_python(code: str) -> str:
    """
    Soft-sanitize: comment out forbidden lines, keep the rest.
    (We don't hard-reject because LLMs often import os by accident.)

    Lines are re-joined with "\\n"; empty input is returned unchanged.
    """
    if not code:
        return code

    def _is_blocked(line: str) -> bool:
        stripped = line.strip()
        return any(re.search(pattern, stripped) for pattern in _FORBIDDEN_PATTERNS)

    return "\n".join(
        "# [blocked] " + line if _is_blocked(line) else line
        for line in code.splitlines()
    )

def execute_python_sandbox(code: str, timeout_sec: int = 10):
    """
    Run `code` as a script in a child interpreter and capture its output.

    The code is first passed through sanitize_python(), written to a temporary
    .py file, and executed with sys.executable under a wall-clock timeout.

    Args:
        code: Python source to run.
        timeout_sec: seconds before the child process is abandoned.

    Returns:
        (ok, stdout, stderr) where `ok` is True only when the child exited with
        return code 0. Streams are stripped. Failures are reported through the
        stderr slot as "empty_code", "timeout" or "exec_error: <detail>".

    NOTE(review): apart from sanitize_python() and the timeout, the child runs
    with the parent's environment and working directory; this is not a real
    sandbox.
    """
    if not code or not code.strip():
        return (False, "", "empty_code")

    code = sanitize_python(code)

    tmp_path = None
    try:
        # The interpreter reads source files as UTF-8 by default, so write the
        # file as UTF-8 rather than with the locale encoding. Creating and
        # writing the file inside the try keeps I/O failures on the
        # "exec_error" path and guarantees cleanup below.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".py", delete=False, encoding="utf-8"
        ) as f:
            tmp_path = f.name
            f.write(code)

        r = subprocess.run(
            [sys.executable, tmp_path],
            capture_output=True,
            text=True,
            timeout=timeout_sec
        )
        ok = (r.returncode == 0)
        return (ok, (r.stdout or "").strip(), (r.stderr or "").strip())
    except subprocess.TimeoutExpired:
        return (False, "", "timeout")
    except Exception as e:
        return (False, "", f"exec_error: {e}")
    finally:
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                # Best-effort cleanup: a leftover temp file is not worth failing for.
                pass

def parse_int_from_stdout(stdout: str):
    """
    Parse the integer answer printed by executed code.

    Contract: the last integer on the LAST non-empty line wins. If that line
    holds no integer (or there is no non-empty line), fall back to the last
    integer anywhere in `stdout`. Returns None for empty input or when
    extract_last_int finds nothing.
    """
    if not stdout:
        return None

    # Walk the lines backwards and stop at the first one with visible content.
    final_line = next(
        (ln.strip() for ln in reversed(stdout.splitlines()) if ln.strip()),
        None,
    )
    if final_line is not None:
        value = extract_last_int(final_line)
        if value is not None:
            return value

    return extract_last_int(stdout)

def weighted_vote(cands):
    """
    Sum candidate weights per answer and pick the heaviest answer.

    cands: list of dicts shaped like
      { "answer": int|None, "weight": float, "meta": ... }
    A missing "weight" counts as 0.0; candidates whose answer is None are
    ignored.

    Returns a tuple (best_answer, scoremap):
      - (None, {}) when no candidate carries an answer;
      - otherwise the answer with the highest total weight (ties go to the
        answer seen first) and a plain dict answer -> total weight.
    """
    totals = {}
    for cand in cands:
        answer = cand.get("answer", None)
        # The weight is converted before the None check, so a malformed weight
        # raises even on a candidate that would be skipped.
        weight = float(cand.get("weight", 0.0))
        if answer is not None:
            totals[answer] = totals.get(answer, 0.0) + weight

    if not totals:
        return None, {}

    winner = max(totals, key=totals.get)
    return winner, dict(totals)

def make_time_budget(num_problems: int, total_seconds: int = 5*60*60, safety_margin: float = 0.10):
    """
    Split the run time evenly across problems.

    `safety_margin` is the fraction of `total_seconds` held back; the rest is
    divided (floor division) by `num_problems`, which is treated as at least 1.
    The result is never lower than 15 seconds.
    """
    usable_seconds = int(total_seconds * (1.0 - safety_margin))
    share = usable_seconds // max(1, num_problems)
    if share > 15:
        return share
    return 15


## CELL J — AGENTIC SOLVER

In [20]:
# ============================================================
# CELL J — AGENTIC SOLVER (8B-friendly) — NEW PIPELINE
# ============================================================

def _truncate_problem(text: str, max_chars: int = 12000) -> str:
    text = text or ""
    text = text.strip()
    if len(text) <= max_chars:
        return text
    # keep end part too (many problems place constraints late)
    head = text[: int(max_chars*0.70)]
    tail = text[-int(max_chars*0.30):]
    return head + "\n...\n" + tail

def build_prompt_fast(problem: str) -> str:
    """
    Build the direct-answer prompt.

    The model is told to reply with a single "FINAL_ANSWER: <integer>" line so
    the parser does not pick up stray numbers such as "2, 3, 20...".
    """
    instructions = "\n".join([
        "You are solving a math contest problem.",
        "Return ONLY one line in this exact format:",
        "FINAL_ANSWER: <integer>",
        "No other text.",
    ])
    return f"{instructions}\n\nPROBLEM:\n{problem}\n"

def build_prompt_tir(problem: str, feedback: str = "") -> str:
    """
    Build the tool-integrated-reasoning prompt (solve by writing Python).

    Args:
        problem: problem statement to embed.
        feedback: optional error summary from the previous attempt; when
            non-empty it is inserted as a FEEDBACK_FROM_PREVIOUS_ATTEMPT
            section just before the problem.

    Returns:
        The full prompt text.
    """
    # Built outside the main f-string: backslashes inside an f-string
    # replacement field are a SyntaxError before Python 3.12, and on 3.12+
    # '\\n' there would emit a literal backslash-n instead of a newline.
    if feedback:
        feedback_section = "FEEDBACK_FROM_PREVIOUS_ATTEMPT:\n" + feedback + "\n"
    else:
        feedback_section = ""

    return (
        "You are an expert mathematician using Python as a calculator.\n"
        "You MUST solve by writing Python code.\n"
        "Rules:\n"
        "1) Provide exactly one python code block fenced as ```python ... ```.\n"
        "2) The code MUST print ONLY the integer answer (one line, no debug).\n"
        "3) Avoid file/network/system operations.\n"
        "4) After the code block, output exactly one line: FINAL_ANSWER: <integer>\n"
        "   (the integer must match what your code prints)\n"
        f"{feedback_section}\n"
        f"PROBLEM:\n{problem}\n"
    )

class AgenticSolver8B:
    """
    Two-stage solver built around weighted voting.

    Stage A asks for a direct one-line answer; Stage B asks for Python code,
    executes it, and retries with error feedback. Every parsed answer from both
    stages is pooled and resolved with weighted_vote().

    Depends on helpers defined elsewhere in the notebook and not visible in
    this cell: llm_generate_texts, _strip_think, extract_final_answer_int.
    """

    def __init__(
        self,
        k_base: int = 6,
        temp: float = 0.6,
        top_p: float = 0.95,
        top_k: int = 50,
        max_new_tokens: int = 2048,
        max_retries: int = 3,
        exec_timeout_sec: int = 10,
    ):
        """
        Store the sampling and retry settings, coercing each to int/float.

        Args:
            k_base: samples generated per Stage B attempt.
            temp: sampling temperature.
            top_p: nucleus sampling threshold passed to the generator.
            top_k: top-k sampling cutoff passed to the generator.
            max_new_tokens: generation length cap for Stage B.
            max_retries: number of Stage B attempts.
            exec_timeout_sec: timeout for each executed code block.
        """
        self.k_base = int(k_base)
        self.temp = float(temp)
        self.top_p = float(top_p)
        self.top_k = int(top_k)
        self.max_new_tokens = int(max_new_tokens)
        self.max_retries = int(max_retries)
        self.exec_timeout_sec = int(exec_timeout_sec)

    @staticmethod
    def _parse_text_answer(text):
        """Parse an answer from text: FINAL_ANSWER line, then boxed value, then last integer."""
        a = extract_final_answer_int(text)
        if a is None:
            a = extract_boxed_int(text)
        if a is None:
            a = extract_last_int(text)
        return a

    def solve_one(self, problem_text: str, budget_sec: int = 120):
        """
        Solve one problem.

        Args:
            problem_text: raw problem statement (truncated by _truncate_problem).
            budget_sec: time budget for this problem. It never aborts the
                solve; it only lowers the Stage B sample count to 2 once fewer
                than 25 seconds remain.

        Returns:
            (answer, info): `answer` is an int (0 when nothing could be
            parsed); `info` is a dict with a "stage" label and, except for
            "fail_all", the "scoremap" produced by weighted_vote().
        """
        # Start the clock before Stage A so that its generation time is charged
        # against budget_sec as well.
        start = time.time()
        problem = _truncate_problem(problem_text)

        candidates = []

        # ---------- Stage A: FAST DIRECT (cheap) ----------
        # Use small K early; if consensus -> done
        k_fast = min(2, self.k_base)
        fast_prompt = build_prompt_fast(problem)
        fast_prompts = [fast_prompt] * max(1, k_fast)

        fast_texts = llm_generate_texts(
            fast_prompts,
            max_new_tokens=min(256, self.max_new_tokens),
            temperature=max(0.2, min(self.temp, 0.7)),
            top_p=self.top_p,
            top_k=self.top_k,
        )

        for t in fast_texts:
            a = self._parse_text_answer(_strip_think(t))
            candidates.append({"answer": a, "weight": 0.8, "meta": {"stage": "fast"}})

        best_a, scoremap = weighted_vote(candidates)
        if best_a is not None:
            # Accept only when at least two fast samples agree on the winner.
            counts = Counter(c["answer"] for c in candidates if c["answer"] is not None)
            if counts[best_a] >= 2:
                return int(best_a), {"stage": "fast_consensus", "scoremap": scoremap}

        # ---------- Stage B: TIR (python) with retries ----------
        feedback = ""

        for attempt in range(self.max_retries):
            # Budget-aware: if running out, reduce K
            elapsed = time.time() - start
            remaining = max(5.0, budget_sec - elapsed)
            if remaining < 25:
                k = 2
            else:
                k = self.k_base

            tir_prompt = build_prompt_tir(problem, feedback=feedback)
            tir_prompts = [tir_prompt] * max(1, k)

            texts = llm_generate_texts(
                tir_prompts,
                max_new_tokens=self.max_new_tokens,
                temperature=self.temp,
                top_p=self.top_p,
                top_k=self.top_k,
            )

            new_errs = []
            for t in texts:
                raw = _strip_think(t)

                # 1) Try python execution first (highest weight).
                #    Only the first extracted block is executed.
                blocks = extract_python_blocks(raw)
                exec_ans = None
                if blocks:
                    ok, out, err = execute_python_sandbox(blocks[0], timeout_sec=self.exec_timeout_sec)
                    if ok:
                        exec_ans = parse_int_from_stdout(out)
                        if exec_ans is not None:
                            candidates.append({"answer": int(exec_ans), "weight": 3.5, "meta": {"stage": "tir_exec_ok"}})
                        else:
                            candidates.append({"answer": None, "weight": 0.0, "meta": {"stage": "tir_exec_no_int", "out": out[:200]}})
                            new_errs.append("python_ran_but_no_integer_output")
                    else:
                        candidates.append({"answer": None, "weight": 0.0, "meta": {"stage": "tir_exec_fail", "err": err[:200]}})
                        new_errs.append(err[:200] if err else "python_exec_failed")

                # 2) Fallback: no executed answer, so parse the model's text
                #    instead (lower weight than an executed result).
                if exec_ans is None:
                    a = self._parse_text_answer(raw)
                    candidates.append({"answer": a, "weight": 0.6, "meta": {"stage": "tir_text_fallback"}})

            # The vote covers every candidate so far, Stage A included.
            best, scoremap = weighted_vote(candidates)
            if best is not None:
                # Accept when the winner is unopposed or scores at least 1.5x
                # the runner-up.
                sorted_scores = sorted(scoremap.items(), key=lambda kv: kv[1], reverse=True)
                if len(sorted_scores) == 1 or sorted_scores[0][1] >= 1.5 * sorted_scores[1][1]:
                    return int(best), {"stage": f"tir_vote_attempt_{attempt}", "scoremap": scoremap}

            # Feed at most three error messages from this attempt into the next prompt.
            feedback = " ; ".join(new_errs[:3])  # keep short

        # Last fallback: take whatever leads the vote, even without a clear margin.
        best, scoremap = weighted_vote(candidates)
        if best is not None:
            return int(best), {"stage": "final_vote_fallback", "scoremap": scoremap}
        return 0, {"stage": "fail_all"}

# Module-level solver instance, configured from the CELL A globals
# (K_BASE, TEMPERATURE_BASE, TOP_P, TOP_K, MAX_NEW_TOKENS, AGENT_MAX_RETRIES, EXEC_TIMEOUT_SEC).
solver = AgenticSolver8B(**{
    "k_base": K_BASE,
    "temp": TEMPERATURE_BASE,
    "top_p": TOP_P,
    "top_k": TOP_K,
    "max_new_tokens": MAX_NEW_TOKENS,
    "max_retries": AGENT_MAX_RETRIES,
    "exec_timeout_sec": EXEC_TIMEOUT_SEC,
})


## CELL K — SUBMISSION GLUE (Kaggle Evaluation API)

In [21]:
# ================================
# CELL K — SUBMISSION GLUE (Kaggle Evaluation API)
# ================================

import os
import sys
import pandas as pd
from typing import Union

# Put the first directory that contains a `kaggle_evaluation` package at the
# front of sys.path so the import works in both notebook and competition runs.
for p in ("/kaggle/input/kaggle-evaluation", "/kaggle/input", ".", ".."):
    if not os.path.exists(os.path.join(p, "kaggle_evaluation")):
        continue
    sys.path.insert(0, p)
    break

def predict(test_input: Union[pd.DataFrame, dict, pd.Series]) -> pd.DataFrame:
    """
    Kaggle prediction endpoint.

    Args:
        test_input: a DataFrame, a single-row dict, or a single-row Series
            containing at least the fields `id` and `problem`.

    Returns:
        A DataFrame with columns `id` (str) and `answer` (int), one row per
        input row. The columns are present even when the input has no rows.

    Raises:
        ValueError: for an unsupported input type or when `id`/`problem` is
            missing.

    NOTE(review): confirm that the evaluation gateway really calls predict()
    with a single frame-like argument; that cannot be verified from this cell.
    """
    if isinstance(test_input, dict):
        test_df = pd.DataFrame([test_input])
    elif isinstance(test_input, pd.Series):
        test_df = pd.DataFrame([test_input.to_dict()])
    elif isinstance(test_input, pd.DataFrame):
        test_df = test_input
    else:
        raise ValueError(f"predict() expects DataFrame/dict/Series, got {type(test_input)}")

    required = {"id", "problem"}
    missing = required - set(test_df.columns)
    if missing:
        raise ValueError(f"Input missing required columns: {missing}")

    # GLOBAL_CONFIG is optional; a missing or falsy value becomes an empty dict.
    cfg = globals().get("GLOBAL_CONFIG", {}) or {}

    rows = []
    # Iterate the two columns directly: iterrows() builds a Series per row and
    # may upcast values across columns.
    for raw_id, raw_problem in zip(test_df["id"], test_df["problem"]):
        pid = str(raw_id)
        prob = str(raw_problem)

        # solve_problem + log_telemetry must be defined in earlier cells (A–J)
        ans, telemetry = solve_problem(pid, prob, config=cfg)
        try:
            log_telemetry(telemetry)
        except Exception:
            # Telemetry is deliberately best-effort: it must never cost an answer.
            pass

        rows.append({"id": pid, "answer": int(ans)})

    # Explicit columns keep the output schema intact for an empty input.
    return pd.DataFrame(rows, columns=["id", "answer"])

def setup_and_serve():
    """
    Start the Kaggle inference server around predict().

    - Competition rerun (KAGGLE_IS_COMPETITION_RERUN == "1"): serve().
    - Otherwise: run_local_gateway() when the server offers it, falling back
      to serve() if it is missing or raises.

    Any failure (import error, model loading, server start-up) is printed
    rather than raised, so calling this in local_ref/debug mode is harmless.
    """
    try:
        from kaggle_evaluation.aimo_3_inference_server import AIMO3InferenceServer

        # Load the model once up front; load_model must come from an earlier cell.
        load_model()

        server = AIMO3InferenceServer(predict)

        is_rerun = os.getenv("KAGGLE_IS_COMPETITION_RERUN") == "1"
        if is_rerun:
            print("Starting inference server (competition mode)...")
            server.serve()
            return

        print("Attempting local gateway (debug mode)...")
        if not hasattr(server, "run_local_gateway"):
            print("run_local_gateway not available, using serve()...")
            server.serve()
            return

        try:
            server.run_local_gateway()
        except Exception as gateway_err:
            print(f"run_local_gateway failed: {gateway_err}")
            print("Falling back to serve()...")
            server.serve()

    except Exception as err:
        print(f"Could not start kaggle_evaluation server: {err}")
        print("If you're in local_ref/debug mode, this is OK.")

# Status banner shown once the cell has finished defining the glue functions.
print(
    "Submission glue initialized.",
    "Call setup_and_serve() to start the Kaggle evaluation server.",
    sep="\n",
)


Submission glue initialized.
Call setup_and_serve() to start the Kaggle evaluation server.
