In [1]:
# ============================================================
# AIMO3 CLEAN SUBMISSION — Qwen-3 30B-A3B Thinking (H100)
# - Setup / warm libs (optional) + cache model files into OS page cache
# - vLLM (Python API) when it loads, otherwise a Hugging Face transformers fallback; chat-template prompts
# - TIR with <tool:python> blocks
# - Jupyter kernel pool (stateful) + wall-clock timeout + interrupt
# - Dynamic time manager
# - Upgraded answer selection: weighted vote + verify-on-uncertainty + deterministic selector
#
# NOTE:
# - Kaggle AIMO3 predict signature uses pl.Series and returns pl.DataFrame/pd.DataFrame
# ============================================================

from __future__ import annotations

import os
import re
import time
import math
import json
import queue
import threading
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import contextmanager

# =========================
# CELL 1/13 — CONFIG
# =========================
from __future__ import annotations

import os, re, time, queue, threading
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from contextlib import contextmanager
from concurrent.futures import ThreadPoolExecutor, as_completed

# Every tunable below can be overridden through an environment variable of the same name.

# Model weights location and the global seed.
MODEL_PATH = "/kaggle/input/qwen-3/transformers/30b-a3b-thinking-2507-fp8/1"
SEED = int(os.environ.get("SEED", "42"))

# --- Warm-up / page-cache priming ---
WARMUP_USR_LIB = os.environ.get("WARMUP_USR_LIB", "0") == "1"
CACHE_MODEL_FILES = os.environ.get("CACHE_MODEL_FILES", "1") == "1"
CACHE_CHUNK_MB = int(os.environ.get("CACHE_CHUNK_MB", "512"))
CACHE_WORKERS = int(os.environ.get("CACHE_WORKERS", "8"))
CACHE_EXTS = (".safetensors", ".bin", ".pt")

# --- Wall-clock budget (default: 5 hours minus a 5 minute margin) ---
HARD_WALL_SECONDS = int(os.environ.get("HARD_WALL_SECONDS", str(5 * 60 * 60 - 5 * 60)))
TOTAL_QUESTIONS = int(os.environ.get("TOTAL_QUESTIONS", "110"))   # override if needed
MIN_BUDGET_S = float(os.environ.get("MIN_BUDGET_S", "12"))
MAX_BUDGET_S = float(os.environ.get("MAX_BUDGET_S", "420"))

# --- Python tool loop ---
KERNEL_POOL_SIZE = int(os.environ.get("KERNEL_POOL_SIZE", "8"))
TOOL_TIMEOUT_S = float(os.environ.get("TOOL_TIMEOUT_S", "4.0"))
MAX_TURNS = int(os.environ.get("MAX_TURNS", "8"))

# --- Adaptive sampling / answer selection thresholds ---
STAGE1_BATCH = int(os.environ.get("STAGE1_BATCH", "4"))
STAGE2_BATCH = int(os.environ.get("STAGE2_BATCH", "4"))
CONFIDENT_RATIO = float(os.environ.get("CONFIDENT_RATIO", "0.78"))
VERIFY_RATIO = float(os.environ.get("VERIFY_RATIO", "0.66"))
VERIFY_TOP_N = int(os.environ.get("VERIFY_TOP_N", "3"))

# --- Generation backend ---
MAX_MODEL_LEN = int(os.environ.get("MAX_MODEL_LEN", "16384"))
DTYPE = os.environ.get("DTYPE", "bfloat16")
GPU_MEM_UTIL = float(os.environ.get("GPU_MEM_UTIL", "0.92"))

os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
# =========================
# CELL 2/13 — WARMUP HELPERS
# =========================
def warmup_usr_lib() -> None:
    """Read every file under /kaggle/usr/lib once through a parallel `cat` to /dev/null.

    The shell pipeline's exit status is ignored (``check=False``).
    """
    from subprocess import run

    run(
        "find /kaggle/usr/lib -type f -print0 | xargs -0 -P 32 -n 500 cat > /dev/null",
        shell=True,
        check=False,
    )

def cache_model(path: str, exts=CACHE_EXTS, num_workers: int = 8, chunk_mb: int = 256) -> None:
    """Best-effort sequential read of the weight files under ``path``.

    Files are read in parallel threads and the bytes are discarded. The warm-up
    is an optimisation only, so an unreadable file is reported and skipped
    instead of aborting the notebook.

    Args:
        path: Directory to walk; a missing directory is a no-op.
        exts: Filename suffixes (tuple accepted by ``str.endswith``) to include.
        num_workers: Requested thread count; capped by the CPU count and by 16.
        chunk_mb: Read size per call in MiB; values below 1 are raised to 1.
    """
    import os, multiprocessing

    # read(0) would end the loop immediately and read(<0) would pull a whole
    # file into memory at once, so never let the chunk drop below 1 MiB.
    chunk = max(1, int(chunk_mb)) * 1024 * 1024

    def warmup_file(fpath: str) -> int:
        """Read ``fpath`` to the end and return the number of bytes seen."""
        total = 0
        try:
            with open(fpath, "rb") as f:
                while True:
                    b = f.read(chunk)
                    if not b:
                        break
                    total += len(b)
        except OSError as exc:
            print(f"[cache_model] skipped {fpath}: {exc}")
        return total

    def safe_size(fpath: str) -> int:
        """File size used only for ordering; unreadable entries sort last."""
        try:
            return os.path.getsize(fpath)
        except OSError:
            return 0

    if not os.path.isdir(path):
        return

    files = [
        os.path.join(root, name)
        for root, _, names in os.walk(path)
        for name in names
        if name.endswith(exts)
    ]
    if not files:
        return

    try:
        cpu = multiprocessing.cpu_count()
    except Exception:
        cpu = 4
    num_workers = max(1, min(num_workers, cpu, 16))
    # Largest files first so the longest reads start earliest.
    files.sort(key=safe_size, reverse=True)

    t0 = time.time()
    total = 0
    with ThreadPoolExecutor(max_workers=num_workers) as ex:
        futs = [ex.submit(warmup_file, f) for f in files]
        for fut in as_completed(futs):
            total += fut.result()

    if not os.getenv("KAGGLE_IS_COMPETITION_RERUN"):
        gb = total / 1024**3
        print(f"[cache_model] warmed ~{gb:.2f} GB in {time.time()-t0:.1f}s")

# Optional library warm-up; runs only when WARMUP_USR_LIB is enabled in the config.
if WARMUP_USR_LIB:
    warmup_usr_lib()

# Pre-read the model weight files using the configured worker count and chunk size.
if CACHE_MODEL_FILES:
    cache_model(MODEL_PATH, num_workers=CACHE_WORKERS, chunk_mb=CACHE_CHUNK_MB)


[cache_model] warmed ~29.03 GB in 64.9s


In [3]:
# =========================
# CELL 3/13 — IMPORTS (FIX vLLM missing: no crash)
# =========================
import numpy as np
import pandas as pd

try:
    import polars as pl
except Exception:
    pl = None

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from jupyter_client import KernelManager

# Seed through transformers.set_seed using the configured SEED.
set_seed(SEED)
# Permit TF32 for CUDA matmuls and for cuDNN.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True


2025-12-31 09:53:41.647063: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1767174821.662449     635 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1767174821.667005     635 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1767174821.678832     635 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1767174821.678850     635 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1767174821.678852     635 computation_placer.cc:177] computation placer alr

In [4]:
# =========================
# CELL 4/13 — PARSERS + UTILS
# =========================
# A complete <think>...</think> span; DOTALL lets the reasoning cover several lines.
_THINK_RE = re.compile(r"<think>.*?</think>", re.DOTALL)
# Contents of \boxed{...}; the capture stops at the first "}", so nested braces are cut short.
_BOXED_RE = re.compile(r"\\boxed\{([^}]*)\}")
# Body of a <tool:python>...</tool:python> block, without the surrounding whitespace.
_TOOL_RE = re.compile(r"<tool:python>\s*(.*?)\s*</tool:python>", re.DOTALL)

def strip_think(text: str) -> str:
    """Remove model reasoning and return the stripped visible remainder.

    Complete ``<think>...</think>`` spans are deleted first. If a closing tag
    is still present afterwards (an opening tag was never emitted), everything
    up to and including the first ``</think>`` is dropped as well.
    """
    visible = _THINK_RE.sub("", text)
    _, closer, after = visible.partition("</think>")
    if closer:
        visible = after
    return visible.strip()

def parse_boxed_int(text: str) -> Optional[int]:
    """Extract the final ``\\boxed{...}`` integer from a model reply.

    The LAST boxed expression is used: when a reply contains several boxes the
    earlier ones are intermediate values and the final one is the answer.

    Returns:
        The integer if the last box holds a plain (optionally signed) integer
        in [0, 99999]; otherwise ``None``.
    """
    text = strip_think(text)
    boxed = _BOXED_RE.findall(text)
    if not boxed:
        return None
    raw = boxed[-1].strip()
    if not re.fullmatch(r"[+-]?\d+", raw):
        return None
    v = int(raw)
    return v if 0 <= v <= 99999 else None

def parse_tool_code(text: str) -> Optional[str]:
    """Return the code inside the first ``<tool:python>`` block, or ``None`` if there is none.

    Reasoning is removed with ``strip_think`` before searching.
    """
    found = _TOOL_RE.search(strip_think(text))
    if found is None:
        return None
    return found.group(1).strip()

def fallback_last_int(text: str) -> Optional[int]:
    """Last-resort parser: take the final signed integer literal in the visible text.

    Returns the value only when it lies in [0, 99999]; otherwise ``None``.
    """
    tokens = re.findall(r"[-+]?\d+", strip_think(text))
    if len(tokens) == 0:
        return None
    try:
        value = int(tokens[-1])
    except Exception:
        return None
    if value < 0 or value > 99999:
        return None
    return value

def mod100000(x: int) -> int:
    """Reduce ``x`` (coerced with ``int``) to its remainder modulo 100000."""
    _, remainder = divmod(int(x), 100000)
    return remainder

def clamp(x: float, lo: float, hi: float) -> float:
    """Limit ``x`` to the range [lo, hi] and return the result as a float.

    The lower bound is applied first and the upper bound second, so ``hi``
    wins if the bounds are inverted.
    """
    floored = max(lo, x)
    capped = min(hi, floored)
    return float(capped)

def trim_history(hist: List[Tuple[str, str]], max_items: int = 10) -> List[Tuple[str, str]]:
    """Keep only the most recent ``max_items`` history entries.

    A history that already fits is returned as the same list object; a longer
    one is returned as a new list holding its tail.
    """
    if len(hist) <= max_items:
        return hist
    return hist[-max_items:]


In [5]:
# =========================
# CELL 5/13 — DIFFICULTY ROUTER
# =========================
@dataclass(frozen=True)
class Plan:
    """Immutable per-problem sampling plan returned by ``route_problem``.

    The field order is part of the interface: ``route_problem`` builds plans
    positionally.
    """
    tag: str                 # category label, e.g. "GEO", "NT", "ALG"
    budget_weight: float     # multiplier passed to TimeManager.budget
    stage1_max_k: int        # max samples created in stage 1
    stage2_max_k: int        # max samples created in stage 2
    stage1_max_tokens: int   # generation cap per turn in stage 1
    stage2_max_tokens: int   # generation cap per turn in stage 2
    temp1: float             # sampling temperature, stage 1
    temp2: float             # sampling temperature, stage 2
    top_p1: float            # nucleus top-p, stage 1
    top_p2: float            # nucleus top-p, stage 2

def route_problem(problem: str) -> Plan:
    """Choose a sampling ``Plan`` from keyword hits in the problem statement.

    Matching is a case-insensitive substring test. Categories are tried in a
    fixed priority order (geometry, number theory, functional equations,
    probability, combinatorics) and the first one with any hit wins; with no
    hit, or an empty/``None`` problem, the default algebra plan is returned.
    """
    text = (problem or "").lower()

    # (keywords, positional Plan arguments), listed in priority order.
    routes = (
        # Geometry
        (
            ("triangle", "circle", "radius", "angle", "tangent", "perpendicular", "circum", "inscribed"),
            ("GEO", 1.15, 6, 18, 900, 1700, 0.55, 0.75, 0.92, 0.90),
        ),
        # Number theory
        (
            ("mod", "congruen", "prime", "gcd", "lcm", "divis", "remainder", "coprime", "valuation", "phi("),
            ("NT", 1.25, 6, 22, 950, 1900, 0.55, 0.78, 0.92, 0.90),
        ),
        # Functional equations
        (
            ("f(", "functional", "for all real", "for all integers", "for all x", "for all n", "satisfies"),
            ("FUNC", 1.20, 6, 20, 950, 1900, 0.55, 0.78, 0.92, 0.90),
        ),
        # Probability / expected value
        (
            ("probability", "expected", "random", "uniform", "dice", "coin", "distribution"),
            ("PROB", 1.10, 6, 18, 900, 1700, 0.55, 0.75, 0.92, 0.90),
        ),
        # Combinatorics
        (
            ("ways", "choose", "arrangements", "permutation", "combination", "graph", "color", "pigeonhole", "invariant"),
            ("COMB", 1.15, 6, 20, 900, 1800, 0.55, 0.78, 0.92, 0.90),
        ),
    )

    for keywords, plan_args in routes:
        for keyword in keywords:
            if keyword in text:
                return Plan(*plan_args)

    # Default algebra
    return Plan("ALG", 1.00, 6, 18, 850, 1600, 0.50, 0.72, 0.92, 0.90)


In [6]:
# =========================
# CELL 6/13 — PROMPTS
# =========================
# System prompt for the tool-integrated solving loop: the model must reply with
# either a <tool:python> block or a single \boxed{...} line.
SYSTEM_TIR = (
    "You are an olympiad math solver.\n"
    "You MUST follow this protocol:\n\n"
    "If you need computation, output exactly:\n"
    "<tool:python>\n"
    "# python code\n"
    "</tool:python>\n\n"
    "If you are ready to answer, output exactly ONE line:\n"
    "\\boxed{NONNEGATIVE_INTEGER}\n\n"
    "Rules:\n"
    "- Output NOTHING else outside the tool block.\n"
    "- Final answer must be an integer in [0, 99999].\n"
    "- Prefer verifying with python when possible.\n"
)

# System prompt for the verifier: tool blocks are allowed, and the last line
# must be PASS, FAIL or UNKNOWN.
SYSTEM_VERIFY = (
    "You are a strict verifier.\n"
    "Given a problem and proposed integer answer A, DISPROVE it quickly.\n"
    "Use python checks when possible:\n"
    "- parity constraints\n"
    "- modular constraints (including mod 100000 if relevant)\n"
    "- substitution / brute force small cases / random tests\n\n"
    "Protocol:\n"
    "- You may output <tool:python>...</tool:python> blocks.\n"
    "- Then output EXACTLY one final line: PASS or FAIL or UNKNOWN\n"
    "- No extra text.\n"
)

# System prompt for the final selector, which must answer with one boxed integer.
SYSTEM_SELECT = (
    "You are a selector.\n"
    "Pick the most reliable candidate answer based on evidence.\n"
    "Output EXACTLY one line: \\boxed{NONNEGATIVE_INTEGER}\n"
    "No extra text.\n"
)

# Strategy hints; one is attached to each sample, cycling through the list.
HINTS = [
    "Tool-first: explore small cases in python, infer pattern, verify, then output boxed.",
    "Proof-first: derive symbolic structure, then minimal python verification, output boxed.",
    "Number-theory: use modular constraints/parity/gcd; python to test; output boxed.",
    "Comb/Prob: use invariants or counting; python to validate small n; output boxed.",
]


In [7]:
# =========================
# CELL 7/13 — PROMPT BUILDER (chat template)
# =========================
class PromptBuilder:
    """Turns (system, user, history) into a single prompt string via the tokenizer's chat template."""

    def __init__(self, model_path: str):
        self.tok = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    def render(self, system: str, user: str, history: List[Tuple[str, str]]) -> str:
        """Build the prompt text for one generation call.

        Args:
            system: System message content.
            user: First user message content.
            history: ``(role, content)`` pairs appended after the user message.

        Returns:
            The chat-template rendering with a generation prompt appended, or a
            plain ``[ROLE]``-tagged transcript if applying the template raises.
        """
        conversation = [
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ]
        conversation.extend({"role": role, "content": content} for role, content in history)
        try:
            return self.tok.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        except Exception:
            # Template unavailable or rejected the messages: fall back to a tagged transcript.
            sections = [f"[SYSTEM]\n{system}\n", f"[USER]\n{user}\n"]
            sections.extend(f"[{role.upper()}]\n{content}\n" for role, content in history)
            sections.append("[ASSISTANT]\n")
            return "\n".join(sections)


In [8]:
# =========================
# CELL 8/13 — PYTHON TOOL (KernelPool) [FIXED for Kaggle]
# - Remove/avoid colab_kernel_launcher arg parsing issues
# - Force child kernels to run via ipykernel_launcher
# =========================
import sys
import queue
import threading
from contextlib import contextmanager
from typing import List, Tuple

from jupyter_client import KernelManager

class PythonKernel:
    """One child IPython kernel used as a stateful Python tool.

    Code is sent with ``execute``; output is collected from the iopub channel
    until the kernel reports idle or the wall-clock timeout fires, at which
    point the kernel is interrupted. A per-instance lock serialises calls.
    """

    # ANSI escape sequences (colours etc.) embedded in IPython tracebacks.
    _ANSI_RE = re.compile(r"\x1b\[[0-9;?]*[ -/]*[@-~]")
    # The prelude imports sympy, which can take several seconds in a cold
    # process. Timing out would interrupt the import with KeyboardInterrupt,
    # which the prelude's `except Exception` does not catch, leaving `sp`
    # undefined — so give the prelude a generous allowance.
    _PRELUDE_TIMEOUT_S = 30.0

    def __init__(self) -> None:
        # Do NOT rely on the notebook's kernelspec (may route via colab_kernel_launcher)
        self.km = KernelManager()
        # Force a standard ipykernel process
        # NOTE(review): some jupyter_client releases may no longer honour
        # `kernel_cmd` — confirm that the child really starts via ipykernel_launcher.
        self.km.kernel_cmd = [sys.executable, "-m", "ipykernel_launcher", "-f", "{connection_file}"]

        self.kc = None
        self._lock = threading.Lock()

    def start(self) -> None:
        """Launch the kernel process, open its channels and run the import prelude."""
        self.km.start_kernel()  # <- NO extra_arguments like "-q"
        self.kc = self.km.client()
        self.kc.start_channels()

        # Wait until kernel is ready (important for stability); a slow start is
        # tolerated because the prelude below also waits for the kernel.
        try:
            self.kc.wait_for_ready(timeout=5)
        except Exception:
            pass

        # Lightweight prelude (sympy optional but useful)
        self.execute(
            "import math, itertools, random\n"
            "try:\n"
            "    import sympy as sp\n"
            "except Exception:\n"
            "    sp = None\n",
            timeout_s=self._PRELUDE_TIMEOUT_S
        )

    def is_alive(self) -> bool:
        """Return whether the kernel process is alive; any failure counts as dead."""
        try:
            return self.km.is_alive()
        except Exception:
            return False

    def interrupt(self) -> None:
        """Best-effort interrupt of the running cell; errors are ignored."""
        try:
            self.km.interrupt_kernel()
        except Exception:
            pass

    def shutdown(self) -> None:
        """Stop the client channels (if any) and shut the kernel down immediately."""
        try:
            if self.kc:
                self.kc.stop_channels()
        finally:
            try:
                self.km.shutdown_kernel(now=True)
            except Exception:
                pass

    def execute(self, code: str, timeout_s: float) -> str:
        """Run ``code`` in the kernel and return its textual output.

        Args:
            code: Source to execute.
            timeout_s: Wall-clock limit in seconds; on expiry the kernel is
                interrupted and ``[PYTHON_TIMEOUT]`` is appended to the output.

        Returns:
            Stream text, ``text/plain`` results and the last three traceback
            lines (ANSI codes removed), truncated to 2000 characters. A run
            that produced nothing returns a hint to use ``print()``.
        """
        assert self.kc is not None, "Kernel not started"
        with self._lock:
            msg_id = self.kc.execute(code, store_history=False, allow_stdin=False)
            # Monotonic clock: the timeout must not be affected by wall-clock adjustments.
            deadline = time.monotonic() + timeout_s
            chunks: List[str] = []
            last_stream = None

            def add_block(block: str) -> None:
                """Append a standalone output item on its own line."""
                if not block:
                    return
                if chunks and not chunks[-1].endswith("\n"):
                    chunks.append("\n")
                chunks.append(block + "\n")

            while True:
                if time.monotonic() >= deadline:
                    self.interrupt()
                    add_block("[PYTHON_TIMEOUT]")
                    break
                try:
                    msg = self.kc.get_iopub_msg(timeout=0.2)
                except Exception:
                    # Normally queue.Empty: nothing arrived yet, keep polling until the deadline.
                    continue

                # Ignore traffic belonging to other (e.g. previously interrupted) requests.
                if msg.get("parent_header", {}).get("msg_id") != msg_id:
                    continue

                msg_type = msg.get("msg_type", "")
                content = msg.get("content", {})

                if msg_type == "stream":
                    t = content.get("text", "")
                    if t:
                        # Stream messages are arbitrary slices of the output, so they are
                        # concatenated verbatim; a line break is inserted only when
                        # switching between stdout and stderr in the middle of a line.
                        name = content.get("name")
                        if chunks and name != last_stream and not chunks[-1].endswith("\n"):
                            chunks.append("\n")
                        chunks.append(t)
                        last_stream = name
                elif msg_type in ("display_data", "execute_result"):
                    data = content.get("data", {})
                    if "text/plain" in data:
                        add_block(str(data["text/plain"]))
                elif msg_type == "error":
                    tb = content.get("traceback", [])
                    tail = "\n".join(tb[-3:]) if tb else "[PYTHON_ERROR]"
                    add_block(self._ANSI_RE.sub("", tail))
                elif msg_type == "status" and content.get("execution_state") == "idle":
                    break

            text = "".join(chunks).strip()
            if len(text) > 2000:
                text = text[:2000] + "\n[...TRUNCATED...]"
            return text if text else "[WARN] No output. Use print()."

class KernelPool:
    """Fixed-size pool of started ``PythonKernel`` objects handed out through a queue.

    ``self.all`` always lists the kernels currently owned by the pool, so
    ``close`` shuts down replacements as well as the originals.
    """

    def __init__(self, size: int):
        self.q: "queue.Queue[PythonKernel]" = queue.Queue()
        self.all: List[PythonKernel] = []
        # Guards self.all, which acquire() may edit from several worker threads.
        self._all_lock = threading.Lock()
        size = max(1, int(size))

        for _ in range(size):
            k = self._spawn()
            self.q.put(k)
            self.all.append(k)

    @staticmethod
    def _spawn() -> PythonKernel:
        """Start a new kernel, retrying once; the second failure propagates."""
        k = PythonKernel()
        try:
            k.start()
        except Exception:
            # If a kernel fails to start, shutdown and retry once
            try:
                k.shutdown()
            except Exception:
                pass
            k = PythonKernel()
            k.start()
        return k

    def _replace(self, dead: PythonKernel) -> PythonKernel:
        """Shut down ``dead``, start a substitute and record the swap in ``self.all``."""
        try:
            dead.shutdown()
        except Exception:
            pass
        fresh = self._spawn()
        with self._all_lock:
            try:
                self.all[self.all.index(dead)] = fresh
            except ValueError:
                self.all.append(fresh)
        return fresh

    @contextmanager
    def acquire(self) -> PythonKernel:
        """Context manager yielding a live kernel and returning it to the pool on exit.

        A dead kernel is replaced before being yielded. If the replacement
        cannot be started the exception propagates and the dead kernel goes
        back into the queue, so the pool keeps its size and a later call can
        retry the replacement.
        """
        k = self.q.get()
        try:
            if not k.is_alive():
                k = self._replace(k)
            yield k
        finally:
            self.q.put(k)

    def close(self) -> None:
        """Empty the queue and shut down every kernel owned by the pool."""
        while not self.q.empty():
            try:
                self.q.get_nowait()
            except Exception:
                break
        with self._all_lock:
            owned = list(self.all)
        for k in owned:
            # One kernel failing to stop must not keep the others running.
            try:
                k.shutdown()
            except Exception:
                pass


In [9]:
# =========================
# CELL 9/13 — BACKEND (vLLM if available else HF) ✅ robust on Kaggle
# - vLLM optional
# - HF uses sdpa by default; only enable flash_attention_2 if flash_attn exists
# =========================
import importlib.util

class BackendBase:
    """Interface shared by the generation backends."""

    def generate(self, prompts: List[str], *, temperature: float, top_p: float, max_tokens: int) -> List[str]:
        """Return one completion string per prompt, in the same order."""
        raise NotImplementedError

class HFBackend(BackendBase):
    """Generation through Hugging Face transformers, used when vLLM is unavailable."""

    def __init__(self, model_path: str, max_model_len: int):
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        # Batched generation with a decoder-only model needs left padding so that
        # every prompt ends at the same column and new tokens follow it directly.
        self.tokenizer.padding_side = "left"
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Prefer PyTorch SDPA (native, no extra package). Only use FA2 if flash_attn is installed.
        has_flash_attn = importlib.util.find_spec("flash_attn") is not None
        attn_impl = "flash_attention_2" if has_flash_attn else "sdpa"

        base_kwargs = dict(
            dtype=torch.bfloat16,          # use dtype (torch_dtype is deprecated)
            device_map="auto",
            trust_remote_code=True,
        )

        # Try with chosen attention impl; if anything blows up, retry safely without forcing it.
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                attn_implementation=attn_impl,
                **base_kwargs,
            )
        except Exception:
            # Fall back to the library's default attention implementation.
            self.model = AutoModelForCausalLM.from_pretrained(model_path, **base_kwargs)

        self.model.eval()
        self.max_model_len = int(max_model_len)

    @torch.inference_mode()
    def generate(self, prompts: List[str], *, temperature: float, top_p: float, max_tokens: int) -> List[str]:
        """Generate completions in mini-batches of 8 prompts.

        Sampling is used when ``temperature`` is above 1e-6, greedy decoding
        otherwise. Only the newly generated tokens are decoded and returned.
        """
        outs: List[str] = []
        bs = 8
        do_sample = temperature > 1e-6

        for i in range(0, len(prompts), bs):
            batch = prompts[i : i + bs]
            enc = self.tokenizer(
                batch,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.max_model_len,
            ).to(self.model.device)

            gen = self.model.generate(
                **enc,
                do_sample=do_sample,
                temperature=float(max(1e-6, temperature)),
                top_p=float(top_p),
                max_new_tokens=int(max_tokens),
                use_cache=True,
                pad_token_id=self.tokenizer.eos_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

            # The returned rows are [padded prompt | new tokens]; the new tokens of
            # every row begin at the padded prompt width, not at that row's count
            # of non-pad tokens.
            prompt_width = enc["input_ids"].shape[1]
            for j in range(gen.shape[0]):
                tail = gen[j, prompt_width:]
                outs.append(self.tokenizer.decode(tail, skip_special_tokens=True))

        return outs

# Backend selection: try vLLM first and fall back to the transformers backend
# if importing or initialising vLLM fails for any reason.
_backend_name = "hf"
_backend: BackendBase

try:
    from vllm import LLM, SamplingParams  # type: ignore

    class VLLMBackend(BackendBase):
        """Generation through the vLLM Python API on a single GPU."""

        def __init__(self, model_path: str):
            self.llm = LLM(
                model=model_path,
                dtype=DTYPE,
                max_model_len=MAX_MODEL_LEN,
                gpu_memory_utilization=GPU_MEM_UTIL,
                tensor_parallel_size=1,
                trust_remote_code=True,
            )

        def generate(self, prompts: List[str], *, temperature: float, top_p: float, max_tokens: int) -> List[str]:
            """Return the first completion's text for each prompt."""
            sp = SamplingParams(
                temperature=float(temperature),
                top_p=float(top_p),
                max_tokens=int(max_tokens),
            )
            out = self.llm.generate(prompts, sp)
            return [o.outputs[0].text for o in out]

    _backend = VLLMBackend(MODEL_PATH)
    _backend_name = "vllm"

except Exception as exc:
    # Say why vLLM was not used: silently dropping to the transformers backend
    # hides a large slowdown that is otherwise hard to diagnose.
    print(f"[backend] vLLM unavailable ({type(exc).__name__}: {exc}); falling back to HF transformers")
    _backend = HFBackend(MODEL_PATH, max_model_len=MAX_MODEL_LEN)
    _backend_name = "hf"


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [10]:
# =========================
# CELL 10/13 — TIME MANAGER (difficulty-weighted)
# =========================
class TimeManager:
    """Tracks the global deadline and hands out a per-question time budget."""

    def __init__(self, hard_wall_s: int, total_questions: int):
        now = time.time()
        self.start = now
        self.deadline = now + int(hard_wall_s)
        self.total = max(1, int(total_questions))
        self.done = 0

    def remaining(self) -> float:
        """Seconds left until the deadline, never negative."""
        left = self.deadline - time.time()
        return max(0.0, left)

    def budget(self, weight: float) -> float:
        """Budget for the next question.

        The remaining time is split evenly over the questions not yet marked
        done (at least one), scaled by ``weight`` and clamped to
        [MIN_BUDGET_S, MAX_BUDGET_S].
        """
        time_left = self.remaining()
        questions_left = max(1, self.total - self.done)
        even_share = time_left / questions_left
        return clamp(even_share * float(weight), MIN_BUDGET_S, MAX_BUDGET_S)

    def mark_done(self) -> None:
        """Count one more question as finished."""
        self.done = self.done + 1


In [11]:
# =========================
# CELL 11/13 — CANDIDATES + WEIGHTED VOTE
# =========================
@dataclass
class Candidate:
    """One sampled solution attempt, as scored by ``cand_weight`` / ``weighted_vote``."""
    answer: Optional[int]   # parsed answer, or None if nothing could be parsed
    raw: str                # last visible model output for this sample
    tool_calls: int         # number of tool blocks the sample issued
    tool_errors: int        # number of tool runs flagged as failed
    elapsed: float          # seconds since the sampling run started when the batch finished
    stage: int              # sampling stage that produced the candidate
    verified: Optional[bool] = None  # PASS->True, FAIL->False, UNKNOWN->None

def cand_weight(c: Candidate) -> float:
    """Voting weight of one candidate; never negative.

    Starts from 1.0, adds 0.9 for tool use without errors, subtracts 0.75 per
    tool error, adds a speed bonus of at most 0.35 that fades with elapsed
    time, and moves by 2.25 up or down for a PASS or FAIL verdict.
    Candidates without an answer weigh 0.
    """
    if c.answer is None:
        return 0.0

    weight = 1.0
    used_tools_cleanly = c.tool_calls > 0 and c.tool_errors == 0
    if used_tools_cleanly:
        weight = weight + 0.9
    weight = weight - 0.75 * c.tool_errors

    speed_bonus = 0.35 - 0.015 * c.elapsed
    weight = weight + max(0.0, speed_bonus)

    if c.verified is True:
        weight = weight + 2.25
    if c.verified is False:
        weight = weight - 2.25
    return max(0.0, weight)

def weighted_vote(cands: List[Candidate]) -> Tuple[Optional[int], float, Dict[int, float]]:
    """Aggregate candidates into a winner by summed ``cand_weight``.

    Returns:
        ``(best, ratio, scores)`` where ``scores`` maps each answer to its
        total weight, ``best`` is the top-scoring answer (the first one seen
        wins a tie) and ``ratio`` is its share of all weight. If no answered
        candidate carries weight the result is ``(None, 0.0, {})``.
    """
    tally: Dict[int, float] = {}
    mass = 0.0
    for cand in cands:
        if cand.answer is None:
            continue
        weight = cand_weight(cand)
        mass += weight
        tally[cand.answer] = tally.get(cand.answer, 0.0) + weight

    if not tally or mass <= 1e-9:
        return None, 0.0, {}

    winner = max(tally, key=tally.get)
    return winner, tally[winner] / mass, tally


In [12]:
# =========================
# CELL 12/13 — TIR ENGINE (Adaptive K) + VERIFIER + SELECTOR
# =========================
@dataclass
class TIRState:
    """Mutable state of one sample inside the tool-integrated reasoning loop."""
    history: List[Tuple[str, str]]   # (role, content) turns after the initial user message
    hint: str                        # strategy hint appended to the problem text
    done: bool = False               # set once a boxed answer has been parsed
    answer: Optional[int] = None     # the parsed boxed answer, if any
    tool_calls: int = 0              # tool blocks issued so far
    tool_errors: int = 0             # tool runs flagged as timeout/error
    raw_last: str = ""               # most recent visible model output

class TIRBatchEngine:
    """Runs batches of tool-integrated reasoning samples and collects candidates."""

    def __init__(self, backend: BackendBase, pb: PromptBuilder, pool: KernelPool):
        self.backend = backend
        self.pb = pb
        self.pool = pool

    def _turn(self, problem: str, states: List[TIRState], *, temperature: float, top_p: float, max_tokens: int, time_left_s: float) -> None:
        """Advance every unfinished state by one model turn.

        Each reply is handled in this order: a boxed integer finishes the
        state; a tool block is executed and its output appended to the
        history; anything else gets a reminder of the required format.
        """
        active = [i for i, s in enumerate(states) if not s.done]
        if not active:
            return

        prompts: List[str] = []
        for i in active:
            s = states[i]
            user = f"{problem}\n\nHint: {s.hint}"
            prompts.append(self.pb.render(SYSTEM_TIR, user, s.history))

        outs = self.backend.generate(prompts, temperature=temperature, top_p=top_p, max_tokens=max_tokens)

        tool_jobs: List[Tuple[int, str]] = []
        for idx, out in zip(active, outs):
            out = strip_think(out)
            st = states[idx]
            st.raw_last = out

            ans = parse_boxed_int(out)
            if ans is not None:
                st.done = True
                st.answer = ans
                continue

            code = parse_tool_code(out)
            if code:
                st.tool_calls += 1
                tool_jobs.append((idx, code))
                continue

            # Neither an answer nor a tool call: keep a shortened copy and restate the protocol.
            st.history.append(("assistant", out[:800]))
            st.history.append(("user", "Output ONLY either a <tool:python> block OR one line \\boxed{integer}."))
            st.history = trim_history(st.history)

        if not tool_jobs:
            return

        def run_one(job: Tuple[int, str]) -> Tuple[int, str, str]:
            i, code = job
            with self.pool.acquire() as k:
                # Never wait longer than the tool timeout; allow at least one second.
                py_out = k.execute(code, timeout_s=min(TOOL_TIMEOUT_S, max(1.0, time_left_s)))
            return i, code, py_out

        with ThreadPoolExecutor(max_workers=min(len(tool_jobs), KERNEL_POOL_SIZE)) as ex:
            futs = [ex.submit(run_one, j) for j in tool_jobs]
            for fut in as_completed(futs):
                i, code, py_out = fut.result()
                st = states[i]
                if ("PYTHON_TIMEOUT" in py_out) or ("Traceback" in py_out) or ("Error" in py_out):
                    st.tool_errors += 1
                st.history.append(("assistant", f"<tool:python>\n{code}\n</tool:python>"))
                st.history.append(("user", f"Python output:\n{py_out}"))
                st.history = trim_history(st.history)

    @staticmethod
    def _salvage_answer(raw_last: str) -> Optional[int]:
        """Recover an answer from the last output of a sample that never finished.

        A last output that is a tool block yields ``None``: the integers in it
        are Python source, not an answer. Otherwise a boxed value is preferred
        and the last integer in the text is the fallback. The ``is None``
        checks keep a legitimate answer of 0 from being discarded as falsy.
        """
        if parse_tool_code(raw_last) is not None:
            return None
        ans = parse_boxed_int(raw_last)
        if ans is None:
            ans = fallback_last_int(raw_last)
        return ans

    def run_progressive(
        self,
        problem: str,
        *,
        max_k: int,
        batch_k: int,
        budget_s: float,
        stage: int,
        max_tokens: int,
        temperature: float,
        top_p: float,
        early_stop_ratio: float,
    ) -> List[Candidate]:
        """Sample in batches until ``max_k`` samples exist, time runs out or the vote is decisive.

        Args:
            problem: Problem statement.
            max_k: Upper bound on the number of samples created.
            batch_k: Samples created per batch.
            budget_s: Wall-clock budget for the whole call, in seconds.
            stage: Stage label stored on each candidate.
            max_tokens, temperature, top_p: Generation settings per turn.
            early_stop_ratio: Stop once the leading answer's vote share reaches this.

        Returns:
            One ``Candidate`` per created sample; ``answer`` is ``None`` for
            samples that produced nothing usable.
        """
        t0 = time.time()
        cands: List[Candidate] = []
        created = 0

        while created < max_k and (time.time() - t0) < budget_s:
            add = min(batch_k, max_k - created)
            states = [
                TIRState(history=[], hint=HINTS[(created + i) % len(HINTS)])
                for i in range(add)
            ]
            created += add

            for _ in range(MAX_TURNS):
                if time.time() - t0 >= budget_s:
                    break
                self._turn(
                    problem, states,
                    temperature=temperature, top_p=top_p,
                    max_tokens=max_tokens,
                    time_left_s=budget_s - (time.time() - t0),
                )
                if all(s.done for s in states):
                    break

            elapsed = time.time() - t0
            for s in states:
                ans = s.answer
                if ans is None:
                    ans = self._salvage_answer(s.raw_last)
                cands.append(Candidate(
                    answer=ans,
                    raw=s.raw_last,
                    tool_calls=s.tool_calls,
                    tool_errors=s.tool_errors,
                    elapsed=elapsed,
                    stage=stage,
                ))

            best, ratio, _ = weighted_vote(cands)
            if best is not None and ratio >= early_stop_ratio:
                break

        return cands

class Verifier:
    """Asks the model to try to disprove a proposed answer, with access to the Python tool."""

    def __init__(self, backend: BackendBase, pb: PromptBuilder, pool: KernelPool):
        self.backend = backend
        self.pb = pb
        self.pool = pool

    def verify(self, problem: str, answer: int, budget_s: float) -> Optional[bool]:
        """Return True for PASS, False for FAIL, None for UNKNOWN or no verdict.

        Up to six greedy model turns are made. A turn containing a tool block
        runs the code and feeds the output back; otherwise the last line of the
        reply is read as the verdict, and a reply without a verdict triggers a
        reminder. Exhausting the turns or the time budget yields ``None``.
        """
        started = time.time()
        transcript: List[Tuple[str, str]] = []
        request = f"Problem:\n{problem}\n\nProposed answer A = {answer}\n"
        verdicts = {"PASS": True, "FAIL": False, "UNKNOWN": None}

        for _ in range(6):
            if time.time() - started >= budget_s:
                return None

            prompt = self.pb.render(SYSTEM_VERIFY, request, transcript)
            replies = self.backend.generate([prompt], temperature=0.0, top_p=1.0, max_tokens=650)
            reply = strip_think(replies[0])

            code = parse_tool_code(reply)
            if code:
                with self.pool.acquire() as kernel:
                    seconds_left = budget_s - (time.time() - started)
                    py_out = kernel.execute(code, timeout_s=min(TOOL_TIMEOUT_S, max(1.0, seconds_left)))
                transcript.append(("assistant", f"<tool:python>\n{code}\n</tool:python>"))
                transcript.append(("user", f"Python output:\n{py_out}"))
                transcript = trim_history(transcript, 10)
                continue

            body = reply.strip()
            final_line = body.splitlines()[-1].strip().upper() if body else ""
            if final_line in verdicts:
                return verdicts[final_line]

            transcript.append(("assistant", reply[:800]))
            transcript.append(("user", "Return ONLY one final line: PASS or FAIL or UNKNOWN."))
            transcript = trim_history(transcript, 10)

        return None

class Selector:
    """Greedy single-shot tie-breaker that asks the model to pick among scored candidates."""

    def __init__(self, backend: BackendBase, pb: PromptBuilder):
        self.backend = backend
        self.pb = pb

    def select(self, problem: str, scores: Dict[int, float]) -> Optional[int]:
        """Ask the model to choose one of the candidate answers.

        Args:
            problem: Problem statement.
            scores: Mapping answer -> aggregated vote weight; the eight
                highest-scoring entries are shown to the model.

        Returns:
            The chosen answer if it is one of the keys of ``scores``; ``None``
            when ``scores`` is empty, nothing could be parsed, or the reply
            names a value that is not a candidate.
        """
        if not scores:
            return None
        items = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:8]
        evidence = "\n".join([f"- {a}: score={s:.3f}" for a, s in items])
        prompt = self.pb.render(SYSTEM_SELECT, f"Problem:\n{problem}\n\nCandidate scores:\n{evidence}\n", [])
        out = strip_think(self.backend.generate([prompt], temperature=0.0, top_p=1.0, max_tokens=220)[0])

        # Explicit None test: a boxed 0 is a valid choice and must not fall
        # through to the last-integer heuristic the way `a or b` would.
        choice = parse_boxed_int(out)
        if choice is None:
            choice = fallback_last_int(out)

        # The selector may only pick among the candidates it was shown; any other
        # number (e.g. one scraped from stray text) is rejected.
        return choice if choice in scores else None


In [13]:
# =========================
# CELL 13/13 — SOLVER + KAGGLE HOOK
# =========================
class AIMO3Solver:
    """End-to-end solver: routing, two sampling stages, optional verification and selection."""

    def __init__(self):
        self.pb = PromptBuilder(MODEL_PATH)
        self.pool = KernelPool(KERNEL_POOL_SIZE)
        self.backend = _backend
        self.tm = TimeManager(HARD_WALL_SECONDS, TOTAL_QUESTIONS)
        self.engine = TIRBatchEngine(self.backend, self.pb, self.pool)
        self.verifier = Verifier(self.backend, self.pb, self.pool)
        self.selector = Selector(self.backend, self.pb)

        if not os.getenv("KAGGLE_IS_COMPETITION_RERUN"):
            print(f"[solver] backend = {_backend_name}")

    def close(self):
        """Shut down the Python kernel pool."""
        self.pool.close()

    def solve_problem(self, problem: str) -> int:
        """Solve one problem and return an integer in [0, 99999].

        The question is counted as done on every exit path, including the
        out-of-time shortcut, so later budgets are computed from the true
        number of questions left.
        """
        try:
            return self._solve(problem)
        finally:
            self.tm.mark_done()

    def _solve(self, problem: str) -> int:
        """Pipeline body of ``solve_problem`` (without the done-bookkeeping)."""
        plan = route_problem(problem)
        rem = self.tm.remaining()
        if rem < 5.0:
            # Essentially out of time: answer 0 without sampling.
            return 0

        budget = min(self.tm.budget(plan.budget_weight), rem)

        # Stage 1 (adaptive K + early stop)
        c1 = self.engine.run_progressive(
            problem,
            max_k=plan.stage1_max_k,
            batch_k=STAGE1_BATCH,
            budget_s=0.38 * budget,
            stage=1,
            max_tokens=plan.stage1_max_tokens,
            temperature=plan.temp1,
            top_p=plan.top_p1,
            early_stop_ratio=CONFIDENT_RATIO,
        )
        best, ratio, scores = weighted_vote(c1)
        if best is not None and ratio >= CONFIDENT_RATIO:
            return mod100000(best)

        # Stage 2 (adaptive K + early stop)
        c2 = self.engine.run_progressive(
            problem,
            max_k=plan.stage2_max_k,
            batch_k=STAGE2_BATCH,
            budget_s=0.50 * budget,
            stage=2,
            max_tokens=plan.stage2_max_tokens,
            temperature=plan.temp2,
            top_p=plan.top_p2,
            early_stop_ratio=CONFIDENT_RATIO,
        )

        all_c = c1 + c2
        best, ratio, scores = weighted_vote(all_c)

        # Repair if nothing parsed
        if best is None:
            c3 = self.engine.run_progressive(
                problem,
                max_k=1,
                batch_k=1,
                budget_s=0.10 * budget,
                stage=2,
                max_tokens=420,
                temperature=0.0,
                top_p=1.0,
                early_stop_ratio=1.0,
            )
            ans = c3[0].answer if c3 else None
            return mod100000(ans if ans is not None else 0)

        # Verifier-on-uncertainty
        if ratio < VERIFY_RATIO and budget >= 25.0:
            top = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:VERIFY_TOP_N]
            per_verify_budget = 0.10 * budget / max(1, len(top))

            for ans, _ in top:
                verdict = self.verifier.verify(problem, ans, budget_s=per_verify_budget)
                for c in all_c:
                    if c.answer == ans:
                        c.verified = verdict

            new_best, new_ratio, new_scores = weighted_vote(all_c)
            # If verification refuted every candidate the re-vote has no winner.
            # Keep the pre-verification result in that case: passing None on to
            # mod100000 would raise and take the whole run down.
            if new_best is not None:
                best, ratio, scores = new_best, new_ratio, new_scores

        # Deterministic selector if still not confident
        final = best
        if ratio < 0.80 and (0.07 * budget) >= 3.0:
            sel = self.selector.select(problem, scores)
            # Only accept a selection that is one of the voted candidates.
            if sel is not None and sel in scores:
                final = sel

        return mod100000(final)

# Module-level solver instance used by predict(); constructing it builds the kernel pool.
solver = AIMO3Solver()

def _first_item(value):
    """Return the first element of a Series/sequence, or ``value`` itself if it is a scalar.

    Strings and bytes count as scalars: they have a length, but indexing them
    would reduce an id or a problem statement to its first character. Objects
    exposing ``iloc`` (pandas) are indexed by position, since ``obj[0]`` is a
    label lookup there.
    """
    if isinstance(value, (str, bytes)):
        return value
    iloc = getattr(value, "iloc", None)
    if iloc is not None:
        return iloc[0]
    if hasattr(value, "__len__"):
        return value[0]
    return value

def predict(id_: "pl.Series", problem: "pl.Series"):
    """Inference hook: solve a single problem and return a one-row frame.

    Args:
        id_: Problem id as a polars Series, or any Series/sequence/scalar.
        problem: Problem text in the same kind of container as ``id_``.

    Returns:
        A polars DataFrame when polars is available and ``id_`` is a polars
        Series, otherwise a pandas DataFrame; columns are ``id`` and ``answer``.
    """
    if pl is not None and isinstance(id_, pl.Series):
        pid = id_.item(0)
        prob = problem.item(0)
        ans = solver.solve_problem(prob)
        return pl.DataFrame({"id": [pid], "answer": [ans]})
    else:
        pid = _first_item(id_)
        prob = _first_item(problem)
        ans = solver.solve_problem(prob)
        return pd.DataFrame({"id": [pid], "answer": [ans]})

# Register predict() with the competition's inference server.
import kaggle_evaluation.aimo_3_inference_server as aimo3
inference_server = aimo3.AIMO3InferenceServer(predict)

# serve() is only called when KAGGLE_IS_COMPETITION_RERUN is set; in an
# interactive session the notebook just reports that the solver is loaded.
if os.getenv("KAGGLE_IS_COMPETITION_RERUN"):
    inference_server.serve()
else:
    print("Dev mode: solver loaded.")


0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to di

[solver] backend = hf
Dev mode: solver loaded.


Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/usr/local/lib/python3.12/dist-packages/colab_kernel_launcher.py", line 37, in <module>
    ColabKernelApp.launch_instance()
  File "/usr/local/lib/python3.12/dist-packages/traitlets/config/application.py", line 992, in launch_instance
    app.start()
  File "/usr/local/lib/python3.12/dist-packages/ipykernel/kernelapp.py", line 712, in start
    self.io_loop.start()
  File "/usr/local/lib/python3.12/dist-packages/tornado/platform/asyncio.py", line 211, in start
    self.asyncio_loop.run_forever()
  File "/usr/lib/python3.12/asyncio/base_events.py", line 645, in run_forever
    self._run_once()
  File "/usr/lib/python3.12/asyncio/base_events.py", line 1984, in _run_once
    handle = self._ready.popleft()
             ^^^^^^^^^^^^^^^^^^^^^
IndexError: pop from an empty deque
