In [1]:
# ============================================================
# AIMO3 TOP-ORIENTED SUBMISSION — Qwen-3 30B-A3B Thinking (H100)
# FIXES:
# - proper left-padding setup downstream
# - separate TOOL_POOL_SIZE (python workers) vs TOOL_THREAD_WORKERS (threads)
# - add GEN_BATCH_SIZE for GPU utilization
# ============================================================

from __future__ import annotations

import os, re, time, math, json, random, glob
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed

# Location of the model checkpoint and the global random seed.
# Every knob below except MODEL_PATH and CACHE_EXTS can be overridden through
# an environment variable of the same name.
MODEL_PATH = "/kaggle/input/qwen-3/transformers/30b-a3b-thinking-2507-fp8/1"
SEED = int(os.getenv("SEED", "42"))

# Warm/cache: optional pre-reading of library files and model weight files.
WARMUP_USR_LIB = os.getenv("WARMUP_USR_LIB", "0") == "1"
CACHE_MODEL_FILES = os.getenv("CACHE_MODEL_FILES", "1") == "1"
CACHE_CHUNK_MB = int(os.getenv("CACHE_CHUNK_MB", "512"))
CACHE_WORKERS = int(os.getenv("CACHE_WORKERS", "8"))
CACHE_EXTS = (".safetensors", ".bin", ".pt")

# Time (default 4h55m, same style as the V22 notebook, to avoid dying right at the deadline)
HARD_WALL_SECONDS = int(os.getenv("HARD_WALL_SECONDS", str((4 * 60 + 55) * 60)))
TOTAL_QUESTIONS = int(os.getenv("TOTAL_QUESTIONS", "110"))
# Per-question budget is clamped to [MIN_BUDGET_S, MAX_BUDGET_S] seconds.
MIN_BUDGET_S = float(os.getenv("MIN_BUDGET_S", "10"))
MAX_BUDGET_S = float(os.getenv("MAX_BUDGET_S", "420"))

# Tool loop (CPU)
TOOL_POOL_SIZE = int(os.getenv("TOOL_POOL_SIZE", "12"))  # number of python subprocess workers
# Threads used to dispatch tool jobs; defaults to at most 12 and never more than the pool size.
TOOL_THREAD_WORKERS = int(os.getenv("TOOL_THREAD_WORKERS", str(min(12, TOOL_POOL_SIZE))))
TOOL_TIMEOUT_S = float(os.getenv("TOOL_TIMEOUT_S", "6.0"))
MAX_TURNS = int(os.getenv("MAX_TURNS", "10"))

# Sampling (accuracy vs time)
STAGE1_BATCH = int(os.getenv("STAGE1_BATCH", "2"))
STAGE2_BATCH = int(os.getenv("STAGE2_BATCH", "3"))
# Vote-share thresholds: stop sampling early at CONFIDENT_RATIO, run the
# verifier when the share is below VERIFY_RATIO (on the top VERIFY_TOP_N answers).
CONFIDENT_RATIO = float(os.getenv("CONFIDENT_RATIO", "0.78"))
VERIFY_RATIO = float(os.getenv("VERIFY_RATIO", "0.66"))
VERIFY_TOP_N = int(os.getenv("VERIFY_TOP_N", "3"))

# Generation (GPU)
MAX_MODEL_LEN = int(os.getenv("MAX_MODEL_LEN", "12288"))   # 16k is OK, but 12k is usually faster
DTYPE = os.getenv("DTYPE", "bfloat16")
GPU_MEM_UTIL = float(os.getenv("GPU_MEM_UTIL", "0.92"))

# Batch size for the HF backend (to push GPU utilization up)
GEN_BATCH_SIZE = int(os.getenv("GEN_BATCH_SIZE", "8"))  # H100 + 30B fp8 usually handles 6-10

# Disable tokenizers' internal parallelism via its environment switch.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


In [2]:
# =========================
# CELL 2/13 — WARMUP HELPERS
# =========================
def warmup_usr_lib() -> None:
    """Read every regular file under /kaggle/usr/lib once and discard the bytes.

    The work is delegated to a shell pipeline (`find | xargs cat`) running up to
    32 `cat` processes in parallel. Failures are ignored (`check=False`).
    """
    from subprocess import run as _run_shell

    pipeline = "find /kaggle/usr/lib -type f -print0 | xargs -0 -P 32 -n 500 cat > /dev/null"
    _run_shell(pipeline, shell=True, check=False)

def cache_model(path: str, exts=CACHE_EXTS, num_workers: int = 8, chunk_mb: int = 256) -> None:
    """Read all weight files below `path` once, in parallel, discarding the data.

    Args:
        path: Directory to walk. Nothing happens if it is not a directory.
        exts: File-name suffix (or tuple of suffixes) selecting which files to read.
        num_workers: Requested thread count; capped by the CPU count and by 16.
        chunk_mb: Size of each read call in MiB.

    Prints a one-line summary unless KAGGLE_IS_COMPETITION_RERUN is set.
    """
    import multiprocessing
    from concurrent.futures import ThreadPoolExecutor, as_completed

    chunk_bytes = chunk_mb * 1024 * 1024

    def _drain(fpath: str) -> int:
        # Stream the file in fixed-size chunks and report how many bytes were read.
        consumed = 0
        with open(fpath, "rb") as fh:
            for block in iter(lambda: fh.read(chunk_bytes), b""):
                consumed += len(block)
        return consumed

    if not os.path.isdir(path):
        return

    matched = []
    for root, _dirs, names in os.walk(path):
        matched.extend(os.path.join(root, nm) for nm in names if nm.endswith(exts))
    if not matched:
        return

    try:
        n_cpu = multiprocessing.cpu_count()
    except Exception:
        n_cpu = 4
    pool_size = max(1, min(num_workers, n_cpu, 16))

    # Largest files first so the longest reads start earliest.
    matched = sorted(matched, key=os.path.getsize, reverse=True)

    started = time.time()
    with ThreadPoolExecutor(max_workers=pool_size) as pool:
        pending = [pool.submit(_drain, f) for f in matched]
        warmed_bytes = sum(done.result() for done in as_completed(pending))

    if not os.getenv("KAGGLE_IS_COMPETITION_RERUN"):
        gb = warmed_bytes / 1024**3
        print(f"[cache_model] warmed ~{gb:.2f} GB in {time.time()-started:.1f}s")

# Run the optional warmups at cell execution time; each one is gated by its
# config flag (WARMUP_USR_LIB / CACHE_MODEL_FILES).
if WARMUP_USR_LIB:
    warmup_usr_lib()
if CACHE_MODEL_FILES:
    cache_model(MODEL_PATH, num_workers=CACHE_WORKERS, chunk_mb=CACHE_CHUNK_MB)




[cache_model] warmed ~29.03 GB in 77.7s


In [3]:
# =========================
# CELL 3/13 — IMPORTS
# =========================
import numpy as np
import pandas as pd

try:
    import polars as pl
except Exception:
    pl = None

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
# Apply SEED through transformers' set_seed helper, Python's `random`, and NumPy.
set_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

# Allow TF32 for CUDA matmul and cuDNN operations.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True




2025-12-31 13:52:54.467482: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1767189174.585913     105 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1767189174.619138     105 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1767189174.913344     105 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1767189174.913368     105 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1767189174.913371     105 computation_placer.cc:177] computation placer alr

In [4]:
# =========================
# CELL 4/13 — PARSERS + UTILS
# =========================
# Matches a complete <think>...</think> span (non-greedy; DOTALL lets it cross newlines).
_THINK_RE = re.compile(r"<think>.*?</think>", re.DOTALL)
# Captures the text inside \boxed{...} up to the first closing brace (nested braces are not handled).
_BOXED_RE = re.compile(r"\\boxed\{([^}]*)\}")
# Captures the body of a <tool:python>...</tool:python> block without surrounding whitespace.
_TOOL_RE = re.compile(r"<tool:python>\s*(.*?)\s*</tool:python>", re.DOTALL)

def strip_think(text: str) -> str:
    """Remove reasoning spans from model output and return the trimmed remainder.

    Complete <think>...</think> spans are deleted first. If a lone closing
    </think> tag is still present, only the text after its first occurrence is kept.
    """
    visible = _THINK_RE.sub("", text)
    _, closer, tail = visible.partition("</think>")
    if closer:
        visible = tail
    return visible.strip()

def parse_boxed_int(text: str) -> Optional[int]:
    """Extract the integer from the first \\boxed{...} in `text` (after stripping reasoning).

    Returns None when there is no boxed expression, when its content is not a
    plain (optionally signed) integer, or when the value lies outside [0, 99999].
    """
    hit = _BOXED_RE.search(strip_think(text))
    if hit is None:
        return None

    inner = hit.group(1).strip()
    if re.fullmatch(r"[+-]?\d+", inner) is None:
        return None

    value = int(inner)
    if value < 0 or value > 99999:
        return None
    return value

def parse_tool_code(text: str) -> Optional[str]:
    """Return the code of the first <tool:python> block in `text`, or None if absent.

    Reasoning spans are stripped before searching; the returned code is trimmed.
    """
    found = _TOOL_RE.search(strip_think(text))
    if found is None:
        return None
    return found.group(1).strip()

def fallback_last_int(text: str) -> Optional[int]:
    """Return the last (optionally signed) integer in `text` if it lies in [0, 99999].

    Reasoning spans are stripped first. Returns None when no integer is found,
    when conversion fails, or when the value is out of range.
    """
    tokens = re.findall(r"[-+]?\d+", strip_think(text))
    if len(tokens) == 0:
        return None
    try:
        value = int(tokens[-1])
    except Exception:
        return None
    if 0 <= value <= 99999:
        return value
    return None

def mod100000(x: int) -> int:
    """Coerce `x` to int and reduce it modulo 100000 (result is in [0, 99999])."""
    _, residue = divmod(int(x), 100000)
    return residue

def clamp(x: float, lo: float, hi: float) -> float:
    """Limit `x` to the interval [lo, hi] and return the result as a float.

    The lower bound is applied first, then the upper bound, so `hi` wins if lo > hi.
    """
    floored = x if x > lo else lo
    capped = floored if floored < hi else hi
    return float(capped)

def trim_history(hist: List[Tuple[str, str]], max_items: int = 10) -> List[Tuple[str, str]]:
    """Keep at most the last `max_items` entries of `hist`.

    The same list object is returned when it is already short enough;
    otherwise a new list holding the tail is returned.
    """
    if len(hist) <= max_items:
        return hist
    return hist[-max_items:]




In [5]:
# =========================
# CELL 5/13 — DIFFICULTY ROUTER
# =========================
# Immutable per-problem sampling plan produced by route_problem().
# Fields are filled positionally there, so their order is part of the interface.
@dataclass(frozen=True)
class Plan:
    tag: str                 # short category label (e.g. "GEO", "NT", "ALG")
    budget_weight: float     # multiplier passed to TimeManager.budget()
    stage1_max_k: int        # max candidates sampled in stage 1
    stage2_max_k: int        # max candidates sampled in stage 2
    stage1_max_tokens: int   # max_tokens per generation in stage 1
    stage2_max_tokens: int   # max_tokens per generation in stage 2
    temp1: float             # sampling temperature for stage 1
    temp2: float             # sampling temperature for stage 2
    top_p1: float            # top-p for stage 1
    top_p2: float            # top-p for stage 2

def route_problem(problem: str) -> Plan:
    """Choose a sampling Plan from keyword matches in the problem statement.

    The statement is lower-cased and checked against keyword groups in a fixed
    order (GEO, NT, FUNC, PROB, COMB); the first group with any substring hit
    wins. With no hit, the ALG plan is returned. None is treated as empty text.
    """
    lowered = (problem or "").lower()

    # (keywords, positional Plan arguments), evaluated top to bottom.
    routing_table = (
        (
            ["triangle","circle","radius","angle","tangent","perpendicular","circum","inscribed"],
            ("GEO", 1.20, 6, 14, 950, 1800, 0.55, 0.75, 0.92, 0.90),
        ),
        (
            ["mod","congruen","prime","gcd","lcm","divis","remainder","coprime","valuation","phi("],
            ("NT", 1.30, 6, 16, 980, 1950, 0.55, 0.78, 0.92, 0.90),
        ),
        (
            ["f(","functional","for all real","for all integers","for all x","for all n","satisfies"],
            ("FUNC", 1.25, 6, 16, 980, 1950, 0.55, 0.78, 0.92, 0.90),
        ),
        (
            ["probability","expected","random","uniform","dice","coin","distribution"],
            ("PROB", 1.15, 6, 14, 950, 1800, 0.55, 0.75, 0.92, 0.90),
        ),
        (
            ["ways","choose","arrangements","permutation","combination","graph","color","pigeonhole","invariant"],
            ("COMB", 1.20, 6, 16, 950, 1900, 0.55, 0.78, 0.92, 0.90),
        ),
    )

    for keywords, plan_args in routing_table:
        for kw in keywords:
            if kw in lowered:
                return Plan(*plan_args)

    return Plan("ALG", 1.00, 6, 14, 900, 1700, 0.50, 0.72, 0.92, 0.90)




In [6]:
# =========================
# CELL 6/13 — PROMPTS
# =========================
# System prompt for the tool-integrated solver: the model must reply with either
# a <tool:python> block or a single \boxed{...} line.
SYSTEM_TIR = (
    "You are an olympiad math solver.\n"
    "You MUST follow this protocol:\n\n"
    "If you need computation, output exactly:\n"
    "<tool:python>\n"
    "# python code\n"
    "</tool:python>\n\n"
    "If you are ready to answer, output exactly ONE line:\n"
    "\\boxed{NONNEGATIVE_INTEGER}\n\n"
    "Rules:\n"
    "- Output NOTHING else outside the tool block.\n"
    "- Final answer must be an integer in [0, 99999].\n"
    "- Prefer verifying with python when possible.\n"
)

# System prompt for the verifier: it may run python checks and must end with
# one of PASS / FAIL / UNKNOWN.
SYSTEM_VERIFY = (
    "You are a strict verifier.\n"
    "Given a problem and proposed integer answer A, DISPROVE it quickly.\n"
    "Use python checks when possible:\n"
    "- parity constraints\n"
    "- modular constraints\n"
    "- substitution / brute force small cases / random tests\n\n"
    "Protocol:\n"
    "- You may output <tool:python>...</tool:python> blocks.\n"
    "- Then output EXACTLY one final line: PASS or FAIL or UNKNOWN\n"
    "- No extra text.\n"
)

# System prompt for the selector: it must answer with a single \boxed{...} line.
SYSTEM_SELECT = (
    "You are a selector.\n"
    "Pick the most reliable candidate answer based on evidence.\n"
    "Output EXACTLY one line: \\boxed{NONNEGATIVE_INTEGER}\n"
    "No extra text.\n"
)

# Strategy hints appended to the user prompt; candidates cycle through them
# round-robin (see TIRBatchEngine.run_progressive).
HINTS = [
    "Tool-first: explore small cases in python, infer pattern, verify, then output boxed.",
    "Proof-first: derive symbolic structure, then minimal python verification, output boxed.",
    "Number-theory: use modular constraints/parity/gcd; python to test; output boxed.",
    "Comb/Prob: use invariants or counting; python to validate small n; output boxed.",
]




In [7]:
# =========================
# CELL 7/13 — PROMPT BUILDER (chat template)
# =========================
class PromptBuilder:
    """Turns (system, user, history) into a single prompt string for the model."""

    def __init__(self, tok):
        # `tok` must offer apply_chat_template(); failures there are tolerated.
        self.tok = tok

    def render(self, system: str, user: str, history: List[Tuple[str, str]]) -> str:
        """Render a prompt ending with the assistant generation marker.

        The tokenizer's chat template is tried first. If it raises, a plain
        bracketed layout ([SYSTEM] / [USER] / [ROLE] ... [ASSISTANT]) is used.
        `history` holds (role, content) pairs that follow the system and user turns.
        """
        msgs = [{"role": "system", "content": system}, {"role": "user", "content": user}]
        msgs.extend({"role": role, "content": content} for role, content in history)

        try:
            return self.tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
        except Exception:
            pass

        # Template unavailable or failed: fall back to the plain-text layout.
        sections = [f"[SYSTEM]\n{system}\n", f"[USER]\n{user}\n"]
        sections += [f"[{role.upper()}]\n{content}\n" for role, content in history]
        sections.append("[ASSISTANT]\n")
        return "\n".join(sections)




In [8]:
# =========================
# CELL 8/13 — PYTHON TOOL POOL (Subprocess, Kaggle-stable) [FIXED]
# - One ToolPool only
# - Has .run(wid, code, timeout_s) -> (out, ok) to match engine
# =========================
import os, sys, json, queue, uuid, subprocess, threading
from contextlib import contextmanager
from typing import List, Dict, Optional, Tuple

# Source of the tool worker, launched via `python -u -c` by SubprocessToolWorker.start().
# Protocol: one JSON request per stdin line ({"id", "code"}); the code is exec'd in a
# namespace `G` that persists across requests, with stdout/stderr captured. One JSON
# response line ({"id", "ok", "out"}) is written per request. If nothing was printed,
# the worker falls back to a variable named __result__/result/ans/_ in `G`; output
# longer than 2000 characters is truncated.
# NOTE(review): only `Exception` is caught around exec, so SystemExit raised by tool
# code would presumably terminate the worker — confirm whether that is intended.
_WORKER_SRC = r"""
import sys, json, traceback, io, contextlib, math, random, itertools
try:
    import sympy as sp
except Exception:
    sp = None

G = {"math": math, "random": random, "itertools": itertools, "sp": sp}

def handle(req):
    code = req.get("code", "")
    out_io = io.StringIO()
    ok = True
    with contextlib.redirect_stdout(out_io), contextlib.redirect_stderr(out_io):
        try:
            exec(compile(code, "<tool>", "exec"), G, G)
        except Exception:
            ok = False
            traceback.print_exc(limit=3)

    txt = out_io.getvalue().strip()
    if not txt:
        for k in ("__result__", "result", "ans", "_"):
            if k in G:
                try:
                    txt = str(G[k])
                    break
                except Exception:
                    pass
    if not txt:
        txt = "[WARN] No output. Use print()."
    if len(txt) > 2000:
        txt = txt[:2000] + "\n[...TRUNCATED...]"
    return {"id": req.get("id"), "ok": ok, "out": txt}

for line in sys.stdin:
    line = line.strip()
    if not line:
        continue
    try:
        req = json.loads(line)
    except Exception:
        continue
    resp = handle(req)
    sys.stdout.write(json.dumps(resp, ensure_ascii=False) + "\n")
    sys.stdout.flush()
"""

class SubprocessToolWorker:
    """One persistent python subprocess that executes tool code on request.

    Requests and responses are single JSON lines exchanged over the child's
    stdin/stdout (stderr is merged into stdout). A daemon reader thread routes
    each response to the queue registered for its request id. The child is
    restarted when it has died, when a write fails, or when a request times out.
    """

    def __init__(self):
        self.proc: Optional[subprocess.Popen] = None
        # request id -> single-slot queue the waiting caller blocks on
        self._pending: Dict[str, "queue.Queue[dict]"] = {}
        # serialises writes to the child's stdin
        self._lock = threading.Lock()
        self._reader_thread: Optional[threading.Thread] = None
        self.start()

    def start(self):
        """(Re)launch the worker process and its reader thread."""
        self.stop()
        env = dict(os.environ)
        env["PYTHONUNBUFFERED"] = "1"
        proc = subprocess.Popen(
            [sys.executable, "-u", "-c", _WORKER_SRC],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
            env=env,
        )
        self.proc = proc

        def _reader(proc: subprocess.Popen = proc) -> None:
            # Bind to *this* process: self.proc may already point at a newer
            # process (or None) by the time the thread gets scheduled.
            stream = proc.stdout
            if stream is None:
                return
            try:
                for line in stream:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        msg = json.loads(line)
                    except Exception:
                        continue
                    # stderr shares this pipe, so a stray line may parse as
                    # JSON without being a response object; skip it instead
                    # of letting the reader thread die.
                    if not isinstance(msg, dict):
                        continue
                    rid = msg.get("id")
                    if not rid:
                        continue
                    q = self._pending.pop(rid, None)
                    if q is not None:
                        q.put(msg)
            except (ValueError, OSError):
                # pipe torn down while reading; nothing left to deliver
                pass
            finally:
                # The reader owns the stdout pipe; close it at EOF so restarts
                # do not leak a file descriptor each time.
                try:
                    stream.close()
                except Exception:
                    pass

        self._reader_thread = threading.Thread(target=_reader, daemon=True)
        self._reader_thread.start()

    def is_alive(self) -> bool:
        """True while a worker process exists and has not exited."""
        return self.proc is not None and (self.proc.poll() is None)

    def stop(self):
        """Kill the worker (if any), release its stdin pipe, and drop pending requests."""
        proc = self.proc
        if proc is not None:
            try:
                proc.kill()
            except Exception:
                pass
            try:
                proc.wait(timeout=1)
            except Exception:
                pass
            if proc.stdin is not None:
                # close() may try to flush into a dead pipe; ignore that.
                try:
                    proc.stdin.close()
                except Exception:
                    pass
        self.proc = None
        self._pending.clear()

    def execute(self, code: str, timeout_s: float) -> Tuple[str, bool]:
        """Run `code` in the worker and wait up to `timeout_s` seconds.

        Returns:
            (output_text, ok). On a failed write the text is
            "[PYTHON_ERROR] worker write failed"; on timeout it is
            "[PYTHON_TIMEOUT]". In both cases ok is False and the worker is restarted.
        """
        if not self.is_alive():
            self.start()
        assert self.proc is not None and self.proc.stdin is not None

        rid = uuid.uuid4().hex
        q: "queue.Queue[dict]" = queue.Queue(maxsize=1)
        self._pending[rid] = q

        payload = {"id": rid, "code": code}
        with self._lock:
            try:
                self.proc.stdin.write(json.dumps(payload, ensure_ascii=False) + "\n")
                self.proc.stdin.flush()
            except Exception:
                self.start()
                return "[PYTHON_ERROR] worker write failed", False

        try:
            msg = q.get(timeout=timeout_s)
        except queue.Empty:
            # Forget this request explicitly, then replace the (possibly hung) worker.
            self._pending.pop(rid, None)
            self.start()
            return "[PYTHON_TIMEOUT]", False

        out = (msg.get("out", "") or "").strip() or "[WARN] No output. Use print()."
        ok = bool(msg.get("ok", False))
        return out, ok

class ToolPool:
    """Fixed-size collection of SubprocessToolWorker instances addressed by index."""

    def __init__(self, size: int):
        # At least one worker is always created.
        self.size = max(1, int(size))
        pool = []
        for _ in range(self.size):
            pool.append(SubprocessToolWorker())
        self.workers: List[SubprocessToolWorker] = pool

    def run(self, wid: int, code: str, timeout_s: float) -> Tuple[str, bool]:
        """Execute `code` on worker `wid` (taken modulo the pool size); returns (out, ok)."""
        slot = int(wid) % self.size
        return self.workers[slot].execute(code, timeout_s)

    def close(self):
        """Stop every worker in the pool."""
        for worker in self.workers:
            worker.stop()


In [14]:
# =========================
# CELL 9/13 — BACKEND (HF) + tokenizer  [REPLACE THIS CELL]
# FIX:
# - left padding for decoder-only
# - remove_invalid_values + InfNanRemoveLogitsProcessor
# - renormalize_logits
# - safe sampling clamp
# =========================
import os, importlib.util
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# tokenizer
# Shared tokenizer. Padding and truncation both happen on the left, so batched
# prompts end at the same position; when truncated, the start of a prompt is dropped.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
tokenizer.padding_side = "left"
tokenizer.truncation_side = "left"
# Reuse EOS as the pad token when the tokenizer defines none.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token

class BackendBase:
    """Interface for text-generation backends; subclasses implement generate()."""

    def generate(self, prompts: List[str], *, temperature: float, top_p: float, max_tokens: int) -> List[str]:
        """Return one completion per prompt. Must be overridden by a subclass."""
        raise NotImplementedError()

def _dtype_from_str(s: str):
    s = (s or "").lower()
    if "bf16" in s or "bfloat" in s:
        return torch.bfloat16
    if "fp16" in s or "float16" in s:
        return torch.float16
    return torch.bfloat16

# logits sanitizers (robust across transformers versions)
def _get_logits_sanitizer():
    try:
        from transformers.generation.logits_process import LogitsProcessorList, InfNanRemoveLogitsProcessor
        return LogitsProcessorList([InfNanRemoveLogitsProcessor()])
    except Exception:
        return None

_SANITIZER = _get_logits_sanitizer()

class HFBackend(BackendBase):
    """Hugging Face `transformers` backend with batched generation and OOM back-off.

    Uses the module-level `tokenizer`. Prompts are processed in batches of
    `self.bs` (initially GEN_BATCH_SIZE); a CUDA out-of-memory error halves the
    batch size and retries the same prompts.
    """

    def __init__(self, model_path: str, max_model_len: int):
        """Load the model from `model_path`; prompts are truncated to `max_model_len` tokens."""
        self.tokenizer = tokenizer

        # stability-first: sdpa tends to be safer than flash for some fp8 stacks
        # if you insist speed: set USE_FLASH_ATTN=1
        use_flash = os.getenv("USE_FLASH_ATTN", "0") == "1"
        has_flash = importlib.util.find_spec("flash_attn") is not None
        attn_impl = "flash_attention_2" if (use_flash and has_flash) else "sdpa"

        base_kwargs = dict(
            dtype=_dtype_from_str(DTYPE),
            device_map="auto",
            trust_remote_code=True,
            low_cpu_mem_usage=True,
        )

        # First try with the requested attention implementation; on any
        # failure, reload with the library default.
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                attn_implementation=attn_impl,
                **base_kwargs,
            )
        except Exception:
            self.model = AutoModelForCausalLM.from_pretrained(model_path, **base_kwargs)

        self.model.eval()
        self.max_model_len = int(max_model_len)

        self.eos_id = self.tokenizer.eos_token_id
        self.pad_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.eos_id

        # Current batch size; may shrink at runtime after an OOM.
        self.bs = int(GEN_BATCH_SIZE)

        # generation config safety knobs
        try:
            self.model.generation_config.remove_invalid_values = True
            self.model.generation_config.renormalize_logits = True
        except Exception:
            pass

    @torch.inference_mode()
    def generate(self, prompts: List[str], *, temperature: float, top_p: float, max_tokens: int) -> List[str]:
        """Generate one completion per prompt, in order.

        Args:
            prompts: Fully rendered prompt strings.
            temperature: Values below 1e-6 (or None) select greedy decoding;
                otherwise the value is clamped to [0.05, 1.2].
            top_p: Clamped to [0.05, 1.0]; None means 1.0.
            max_tokens: Maximum number of new tokens per prompt.

        Returns:
            Decoded continuations (special tokens skipped). A prompt that runs
            out of memory even at batch size 1 yields an empty string.

        Raises:
            RuntimeError: Any generation error other than out-of-memory.
        """
        outs: List[str] = []

        # clamp sampling params to avoid weirdness
        t = float(temperature) if temperature is not None else 0.0
        p = float(top_p) if top_p is not None else 1.0
        t = 0.0 if t < 1e-6 else min(1.2, max(0.05, t))
        p = min(1.0, max(0.05, p))
        do_sample = t > 1e-6

        i = 0
        while i < len(prompts):
            bs = max(1, int(self.bs))
            batch = prompts[i:i+bs]

            enc = self.tokenizer(
                batch,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.max_model_len,
            )
            enc = {k: v.to(self.model.device) for k, v in enc.items()}
            # With left padding every row's prompt ends at this column, so the
            # generated tokens are everything from here on.
            input_len = int(enc["input_ids"].shape[1])

            gen_kwargs = dict(
                **enc,
                max_new_tokens=int(max_tokens),
                use_cache=True,
                pad_token_id=self.pad_id,
                eos_token_id=self.eos_id,
                do_sample=do_sample,
                temperature=(t if do_sample else None),
                top_p=(p if do_sample else None),
                remove_invalid_values=True,   # key fix
                renormalize_logits=True,     # key fix
            )
            if _SANITIZER is not None:
                gen_kwargs["logits_processor"] = _SANITIZER

            try:
                gen = self.model.generate(**gen_kwargs)
            except RuntimeError as e:
                msg = str(e).lower()
                if "out of memory" in msg or "cuda out of memory" in msg:
                    torch.cuda.empty_cache()
                    if bs <= 1:
                        # The batch cannot shrink any further; retrying would
                        # loop forever. Give up on this single prompt and
                        # keep the output list aligned with `prompts`.
                        if not os.getenv("KAGGLE_IS_COMPETITION_RERUN"):
                            print("[hf] OOM at batch size 1 -> skipping prompt")
                        outs.append("")
                        i += 1
                        continue
                    self.bs = max(1, self.bs // 2)
                    if not os.getenv("KAGGLE_IS_COMPETITION_RERUN"):
                        print(f"[hf] OOM -> backoff GEN_BATCH_SIZE to {self.bs}")
                    # retry the same slice with the smaller batch size
                    continue
                raise

            for j in range(gen.shape[0]):
                tail = gen[j, input_len:]
                outs.append(self.tokenizer.decode(tail, skip_special_tokens=True))

            i += bs

        return outs

# Load the model once at cell execution; the solver reuses this module-level backend.
_backend = HFBackend(MODEL_PATH, max_model_len=MAX_MODEL_LEN)
_backend_name = "hf"
# Note: the `attn=` field prints the raw USE_FLASH_ATTN env value, not the implementation actually chosen.
print(f"[backend] loaded {_backend_name} | attn={os.getenv('USE_FLASH_ATTN','0')} | GEN_BATCH_SIZE={GEN_BATCH_SIZE} | MAX_MODEL_LEN={MAX_MODEL_LEN}")


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

[backend] loaded hf | attn=0 | GEN_BATCH_SIZE=8 | MAX_MODEL_LEN=12288


In [15]:
# =========================
# CELL 10/13 — TIME MANAGER
# =========================
class TimeManager:
    """Tracks a wall-clock deadline and hands out per-question time budgets."""

    def __init__(self, hard_wall_s: int, total_questions: int):
        # The clock starts at construction; the deadline is fixed from then on.
        now = time.time()
        self.start = now
        self.deadline = now + int(hard_wall_s)
        self.total = max(1, int(total_questions))
        self.done = 0

    def remaining(self) -> float:
        """Seconds left until the deadline, never negative."""
        left = self.deadline - time.time()
        return left if left > 0.0 else 0.0

    def budget(self, weight: float) -> float:
        """Budget for the next question.

        The remaining time is split evenly over the questions not yet marked
        done, scaled by `weight`, and clamped to [MIN_BUDGET_S, MAX_BUDGET_S].
        """
        questions_left = max(1, self.total - self.done)
        per_question = self.remaining() / questions_left
        return clamp(per_question * float(weight), MIN_BUDGET_S, MAX_BUDGET_S)

    def mark_done(self) -> None:
        """Record that one more question has been finished."""
        self.done += 1




In [16]:
# =========================
# CELL 11/13 — CANDIDATES + WEIGHTED VOTE
# =========================
# One sampled solution attempt, as scored by cand_weight() / weighted_vote().
@dataclass
class Candidate:
    answer: Optional[int]    # parsed answer, or None if nothing usable was produced
    raw: str                 # last model output for this attempt (reasoning stripped)
    tool_calls: int          # number of python tool invocations
    tool_errors: int         # tool invocations that failed, timed out, or printed a traceback
    elapsed: float           # seconds since the stage started when the candidate was recorded
    stage: int               # sampling stage that produced it (1 or 2)
    verified: Optional[bool] = None  # PASS->True, FAIL->False, UNKNOWN->None

def cand_weight(c: Candidate) -> float:
    """Voting weight of a candidate (never negative).

    Starts at 1.0; +0.9 for tool use without errors; -0.75 per tool error;
    a speed bonus of up to 0.35 that decays with elapsed time; +/-2.25 for a
    PASS/FAIL verdict. Candidates without an answer weigh 0.
    """
    if c.answer is None:
        return 0.0

    weight = 1.0

    clean_tool_use = c.tool_calls > 0 and c.tool_errors == 0
    if clean_tool_use:
        weight += 0.9
    weight -= 0.75 * c.tool_errors

    speed_bonus = 0.35 - 0.015 * c.elapsed
    if speed_bonus > 0.0:
        weight += speed_bonus

    if c.verified is True:
        weight += 2.25
    elif c.verified is False:
        weight -= 2.25

    return weight if weight > 0.0 else 0.0

def weighted_vote(cands: List[Candidate]) -> Tuple[Optional[int], float, Dict[int, float]]:
    """Aggregate candidates into a weighted vote.

    Returns:
        (best_answer, share, scores): `scores` maps each answer to its summed
        weight and `share` is the winner's fraction of the total weight. When
        no candidate has an answer, or the total weight is (almost) zero, the
        result is (None, 0.0, {}). Ties go to the answer seen first.
    """
    tally: Dict[int, float] = {}
    mass = 0.0
    for cand in (c for c in cands if c.answer is not None):
        weight = cand_weight(cand)
        mass += weight
        tally[cand.answer] = tally.get(cand.answer, 0.0) + weight

    if len(tally) == 0 or mass <= 1e-9:
        return None, 0.0, {}

    winner = max(tally, key=tally.get)
    return winner, tally[winner] / mass, tally




In [17]:
# =========================
# CELL 12/13 — TIR ENGINE + VERIFIER + SELECTOR
# =========================
# Mutable per-candidate conversation state used by TIRBatchEngine.
@dataclass
class TIRState:
    history: List[Tuple[str, str]]  # (role, content) turns after the initial user message
    hint: str                       # strategy hint appended to the problem text
    worker_id: int                  # tool-pool worker used for this candidate's python calls
    done: bool = False              # set once a boxed answer has been parsed
    answer: Optional[int] = None    # parsed boxed answer, if any
    tool_calls: int = 0             # number of tool blocks executed
    tool_errors: int = 0            # tool runs that failed, timed out, or printed a traceback
    raw_last: str = ""              # most recent model output (reasoning stripped)

class TIRBatchEngine:
    """Tool-integrated reasoning loop that advances several candidates per turn.

    Each turn generates one reply for every unfinished candidate in a single
    backend call, then runs any requested python tool code in parallel and
    feeds the output back into that candidate's history.
    """

    def __init__(self, backend: BackendBase, pb: PromptBuilder, tools: ToolPool):
        self.backend = backend
        self.pb = pb
        self.tools = tools

    def _turn(self, problem: str, states: List[TIRState], *, temperature: float, top_p: float, max_tokens: int, time_left_s: float) -> None:
        """Advance every unfinished state by one model reply (mutates `states` in place).

        A reply containing a valid boxed integer finishes the state; a reply
        containing a tool block schedules a python run; any other reply gets a
        reminder message appended to the history.
        """
        active = [i for i, s in enumerate(states) if not s.done]
        if not active:
            return

        prompts: List[str] = []
        for i in active:
            s = states[i]
            user = f"{problem}\n\nHint: {s.hint}"
            prompts.append(self.pb.render(SYSTEM_TIR, user, s.history))

        outs = self.backend.generate(prompts, temperature=temperature, top_p=top_p, max_tokens=max_tokens)

        tool_jobs: List[Tuple[int, int, str]] = []  # (state_idx, worker_id, code)
        for idx, out in zip(active, outs):
            out = strip_think(out)
            st = states[idx]
            st.raw_last = out

            # A boxed answer takes precedence over a tool block in the same reply.
            ans = parse_boxed_int(out)
            if ans is not None:
                st.done = True
                st.answer = ans
                continue

            code = parse_tool_code(out)
            if code:
                st.tool_calls += 1
                tool_jobs.append((idx, st.worker_id, code))
                continue

            # Neither answer nor tool call: keep a truncated copy of the reply
            # and remind the model of the required output format.
            st.history.append(("assistant", out[:800]))
            st.history.append(("user", "Output ONLY either a <tool:python> block OR one line \\boxed{integer}."))
            st.history = trim_history(st.history)

        if not tool_jobs:
            return

        def run_one(job: Tuple[int, int, str]) -> Tuple[int, str, str, bool]:
            i, wid, code = job
            # Tool timeout: at most TOOL_TIMEOUT_S, at least 1s, bounded by the
            # time that was left when this turn started.
            py_out, ok = self.tools.run(wid, code, timeout_s=min(TOOL_TIMEOUT_S, max(1.0, time_left_s)))
            return i, code, py_out, ok

        with ThreadPoolExecutor(max_workers=min(len(tool_jobs), TOOL_THREAD_WORKERS)) as ex:
            futs = [ex.submit(run_one, j) for j in tool_jobs]
            for fut in as_completed(futs):
                i, code, py_out, ok = fut.result()
                st = states[i]
                if (not ok) or ("PYTHON_TIMEOUT" in py_out) or ("Traceback" in py_out):
                    st.tool_errors += 1
                # Record the tool call and its output as an assistant/user exchange.
                st.history.append(("assistant", f"<tool:python>\n{code}\n</tool:python>"))
                st.history.append(("user", f"Python output:\n{py_out}"))
                st.history = trim_history(st.history)

    def run_progressive(
        self,
        problem: str,
        *,
        max_k: int,
        batch_k: int,
        budget_s: float,
        stage: int,
        max_tokens: int,
        temperature: float,
        top_p: float,
        early_stop_ratio: float,
    ) -> List[Candidate]:
        """Sample candidates in batches until confident, out of budget, or at `max_k`.

        Each batch of up to `batch_k` fresh states runs for at most MAX_TURNS
        turns. After every batch the accumulated candidates are voted on, and
        sampling stops once the leading answer's share reaches `early_stop_ratio`.

        Returns:
            All candidates produced by this call, each tagged with `stage`.
        """
        t0 = time.time()
        cands: List[Candidate] = []
        created = 0

        while created < max_k and (time.time() - t0) < budget_s:
            add = min(batch_k, max_k - created)
            # Hints and tool workers are assigned round-robin by candidate index.
            states = [
                TIRState(history=[], hint=HINTS[(created + i) % len(HINTS)], worker_id=(created + i) % TOOL_POOL_SIZE)
                for i in range(add)
            ]
            created += add

            for _ in range(MAX_TURNS):
                if time.time() - t0 >= budget_s:
                    break
                self._turn(
                    problem, states,
                    temperature=temperature, top_p=top_p,
                    max_tokens=max_tokens,
                    time_left_s=budget_s - (time.time() - t0),
                )
                if all(s.done for s in states):
                    break

            # Measured from the start of this call, so later batches report larger values.
            elapsed = time.time() - t0
            for s in states:
                ans = s.answer
                if ans is None:
                    # Unfinished state: fall back to the last integer in its final reply.
                    ans = parse_boxed_int(s.raw_last) or fallback_last_int(s.raw_last)
                cands.append(Candidate(
                    answer=ans,
                    raw=s.raw_last,
                    tool_calls=s.tool_calls,
                    tool_errors=s.tool_errors,
                    elapsed=elapsed,
                    stage=stage,
                ))

            best, ratio, _ = weighted_vote(cands)
            if best is not None and ratio >= early_stop_ratio:
                break

        return cands

class Verifier:
    """Asks the model to disprove a proposed answer, optionally using python checks."""

    def __init__(self, backend: BackendBase, pb: PromptBuilder, tools: ToolPool):
        self.backend = backend
        self.pb = pb
        self.tools = tools

    def verify(self, problem: str, answer: int, budget_s: float) -> Optional[bool]:
        """Run up to six greedy verifier turns within `budget_s` seconds.

        Returns:
            True for PASS, False for FAIL, and None for UNKNOWN, for an
            exhausted budget, or when no verdict appears within six turns.
        """
        started = time.time()
        history: List[Tuple[str, str]] = []
        user = f"Problem:\n{problem}\n\nProposed answer A = {answer}\n"
        verdicts = {"PASS": True, "FAIL": False, "UNKNOWN": None}

        rounds = 0
        while rounds < 6:
            rounds += 1
            if time.time() - started >= budget_s:
                return None

            prompt = self.pb.render(SYSTEM_VERIFY, user, history)
            reply = self.backend.generate([prompt], temperature=0.0, top_p=1.0, max_tokens=650)[0]
            out = strip_think(reply)

            code = parse_tool_code(out)
            if code:
                # Tool request: run it on worker 0 and feed the output back.
                time_left = budget_s - (time.time() - started)
                py_out, _ok = self.tools.run(0, code, timeout_s=min(TOOL_TIMEOUT_S, max(1.0, time_left)))
                history.append(("assistant", f"<tool:python>\n{code}\n</tool:python>"))
                history.append(("user", f"Python output:\n{py_out}"))
                history = trim_history(history, 10)
                continue

            # The verdict is read from the last line of the reply only.
            stripped = out.strip()
            last = stripped.splitlines()[-1].strip().upper() if stripped else ""
            if last in verdicts:
                return verdicts[last]

            history.append(("assistant", out[:800]))
            history.append(("user", "Return ONLY one final line: PASS or FAIL or UNKNOWN."))
            history = trim_history(history, 10)

        return None

class Selector:
    """Asks the model to choose a final answer from the scored candidates."""

    def __init__(self, backend: BackendBase, pb: PromptBuilder):
        self.backend = backend
        self.pb = pb

    def select(self, problem: str, scores: Dict[int, float]) -> Optional[int]:
        """Return the model's pick among the (up to 8) highest-scoring answers.

        Args:
            problem: Problem statement shown to the model.
            scores: Mapping answer -> accumulated vote weight.

        Returns:
            The boxed integer from the reply; if there is none, the last
            in-range integer in the reply; None if `scores` is empty or
            nothing can be parsed.
        """
        if not scores:
            return None
        items = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:8]
        evidence = "\n".join([f"- {a}: score={s:.3f}" for a, s in items])
        prompt = self.pb.render(SYSTEM_SELECT, f"Problem:\n{problem}\n\nCandidate scores:\n{evidence}\n", [])
        out = strip_think(self.backend.generate([prompt], temperature=0.0, top_p=1.0, max_tokens=220)[0])

        # Compare against None explicitly: a boxed answer of 0 is valid and
        # must not fall through to the last-integer heuristic.
        picked = parse_boxed_int(out)
        if picked is None:
            picked = fallback_last_int(out)
        return picked




In [18]:
# =========================
# CELL 13/13 — SOLVER + DEV TEST + KAGGLE HOOK
# =========================
class AIMO3Solver:
    """End-to-end solver: routing, two sampling stages, verification, selection."""

    def __init__(self):
        self.pb = PromptBuilder(tokenizer)
        self.tools = ToolPool(TOOL_POOL_SIZE)
        self.backend = _backend
        self.tm = TimeManager(HARD_WALL_SECONDS, TOTAL_QUESTIONS)
        self.engine = TIRBatchEngine(self.backend, self.pb, self.tools)
        self.verifier = Verifier(self.backend, self.pb, self.tools)
        self.selector = Selector(self.backend, self.pb)

        if not os.getenv("KAGGLE_IS_COMPETITION_RERUN"):
            print(f"[solver] backend = {_backend_name}")

    def close(self):
        """Shut down the python tool workers."""
        self.tools.close()

    def solve_problem(self, problem: str) -> int:
        """Solve one problem and return an integer in [0, 99999].

        Pipeline: stage-1 sampling (38% of the budget) with early exit when
        confident, stage-2 sampling (50%), verification of the top answers
        when the vote share is below VERIFY_RATIO, and a selector pass when
        the share is still below 0.80. Returns 0 when fewer than 5 seconds
        remain or when no candidate produced an answer.
        """
        plan = route_problem(problem)
        rem = self.tm.remaining()
        if rem < 5.0:
            return 0

        budget = min(self.tm.budget(plan.budget_weight), rem)

        # Stage 1
        c1 = self.engine.run_progressive(
            problem,
            max_k=plan.stage1_max_k,
            batch_k=STAGE1_BATCH,
            budget_s=0.38 * budget,
            stage=1,
            max_tokens=plan.stage1_max_tokens,
            temperature=plan.temp1,
            top_p=plan.top_p1,
            early_stop_ratio=CONFIDENT_RATIO,
        )
        best, ratio, scores = weighted_vote(c1)
        if best is not None and ratio >= CONFIDENT_RATIO:
            self.tm.mark_done()
            return mod100000(best)

        # Stage 2
        c2 = self.engine.run_progressive(
            problem,
            max_k=plan.stage2_max_k,
            batch_k=STAGE2_BATCH,
            budget_s=0.50 * budget,
            stage=2,
            max_tokens=plan.stage2_max_tokens,
            temperature=plan.temp2,
            top_p=plan.top_p2,
            early_stop_ratio=CONFIDENT_RATIO,
        )

        all_c = c1 + c2
        best, ratio, scores = weighted_vote(all_c)

        if best is None:
            self.tm.mark_done()
            return 0

        # Verifier-on-uncertainty
        if ratio < VERIFY_RATIO and budget >= 25.0:
            top = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:VERIFY_TOP_N]
            per_verify_budget = 0.10 * budget / max(1, len(top))

            for ans, _ in top:
                verdict = self.verifier.verify(problem, ans, budget_s=per_verify_budget)
                for c in all_c:
                    if c.answer == ans:
                        c.verified = verdict

            # A FAIL verdict can push a candidate's weight down to 0. If that
            # happens to every candidate, the re-vote has no winner; keep the
            # pre-verification result in that case instead of carrying None
            # into mod100000() below.
            revote_best, revote_ratio, revote_scores = weighted_vote(all_c)
            if revote_best is not None:
                best, ratio, scores = revote_best, revote_ratio, revote_scores

        # Deterministic selector if still not confident
        final = best
        if ratio < 0.80 and (0.07 * budget) >= 3.0:
            sel = self.selector.select(problem, scores)
            if sel is not None:
                final = sel

        self.tm.mark_done()
        return mod100000(final)

solver = AIMO3Solver()

def predict(id_: "pl.Series", problem: "pl.Series"):
    """Solve one problem and return a single-row frame holding its answer.

    Args:
        id_: Problem identifier. May be a polars Series, a pandas Series,
            another indexable container, or a bare scalar (including a str).
        problem: Problem statement, passed in the same style as ``id_``.

    Returns:
        A ``pl.DataFrame`` when ``pl`` is not None and ``id_`` is a
        ``pl.Series``; otherwise a ``pd.DataFrame``. In both cases the frame
        has exactly one row with the columns ``id`` and ``answer``.
    """

    def _first(value):
        """Return the first element of a container, or the value itself if scalar."""
        # str/bytes expose __len__, but a bare string here is the whole
        # id/problem text, not a sequence to take the first character of.
        if isinstance(value, (str, bytes)):
            return value
        # Positional access: label-based value[0] raises KeyError on a pandas
        # Series whose index does not contain the label 0.
        if isinstance(value, pd.Series):
            return value.iloc[0]
        return value[0] if hasattr(value, "__len__") else value

    if pl is not None and isinstance(id_, pl.Series):
        pid = id_.item(0)
        prob = problem.item(0)
        ans = solver.solve_problem(prob)
        return pl.DataFrame({"id": [pid], "answer": [ans]})

    # Fallback path for pandas / plain Python inputs.
    pid = _first(id_)
    prob = _first(problem)
    ans = solver.solve_problem(prob)
    return pd.DataFrame({"id": [pid], "answer": [ans]})

# ---- DEV EVAL ----
def _find_comp_file(fname: str) -> Optional[str]:
    hits = glob.glob(f"/kaggle/input/*/{fname}")
    return hits[0] if hits else None

def dev_eval(n: int = 30):
    """Run the solver over the first ``n`` rows of reference.csv and print progress.

    The file is located with ``_find_comp_file("reference.csv")``. Each row's
    ``problem`` is passed to ``solver.solve_problem``. When the file has an
    ``answer`` column, the integer value of each prediction is compared with
    the integer value of that same row's ``answer`` and accuracy is printed;
    otherwise only timing is reported.

    Args:
        n: Maximum number of rows to evaluate; clamped to the number of rows
            in the file.

    Returns:
        None. Results are only printed. The function returns early (after
        printing a message) when the file is missing, lacks a ``problem``
        column, or there is nothing to evaluate.
    """
    ref_path = _find_comp_file("reference.csv")
    if not ref_path:
        print("[dev_eval] reference.csv not found in /kaggle/input/*/")
        return

    df = pd.read_csv(ref_path)
    if "problem" not in df.columns:
        print("[dev_eval] reference.csv missing 'problem' column")
        return

    has_gt = "answer" in df.columns

    n = min(n, len(df))
    if n <= 0:
        # Without this guard the FINAL accuracy line divides by zero, and a
        # negative n would make iloc[:n] silently drop rows from the end.
        print("[dev_eval] nothing to evaluate (n <= 0 or reference.csv has no rows)")
        return
    sub = df.iloc[:n]

    t0 = time.time()
    correct = 0
    done = 0

    for _, row in sub.iterrows():
        ans = solver.solve_problem(row["problem"])
        done += 1
        # Compare against this row's own answer; an id-keyed lookup would pick
        # the wrong ground truth if ids were ever duplicated.
        if has_gt and int(ans) == int(row["answer"]):
            correct += 1

        # Progress report every 5 problems.
        if done % 5 == 0:
            elapsed = time.time() - t0
            if has_gt:
                print(f"[dev_eval] {done}/{n}  elapsed={elapsed:.1f}s  acc={100*correct/done:.1f}%")
            else:
                print(f"[dev_eval] {done}/{n}  elapsed={elapsed:.1f}s")

    elapsed = time.time() - t0
    if has_gt:
        print(f"[dev_eval] FINAL: {correct}/{n} = {100*correct/n:.1f}%  | time={elapsed:.1f}s")
    else:
        print(f"[dev_eval] FINAL: done {n} problems | time={elapsed:.1f}s | (no ground truth in reference.csv)")

# ---- Kaggle Inference Server ----
import kaggle_evaluation.aimo_3_inference_server as aimo3
inference_server = aimo3.AIMO3InferenceServer(predict)

if os.getenv("KAGGLE_IS_COMPETITION_RERUN"):
    # Competition rerun: hand control to the inference server.
    inference_server.serve()
else:
    # Interactive / dev run: evaluate locally on reference.csv instead.
    print("[dev] solver loaded. Running dev_eval() on reference.csv ...")
    dev_eval(n=int(os.getenv("DEV_N", "30")))

    # Optional: push the first rows through the local gateway when
    # DEV_RUN_GATEWAY=1.
    if os.getenv("DEV_RUN_GATEWAY", "0") == "1":
        ref_path = _find_comp_file("reference.csv")
        if ref_path is None:
            # _find_comp_file returns None when nothing matches; reading None
            # with pd.read_csv would raise instead of reporting the problem.
            print("[dev] reference.csv not found in /kaggle/input/*/; skipping run_local_gateway")
        else:
            df = pd.read_csv(ref_path)
            tmp = df[["id", "problem"]].head(int(os.getenv("DEV_GATEWAY_N", "10")))
            tmp_path = "ref_input_head.csv"
            tmp.to_csv(tmp_path, index=False)
            print(f"[dev] run_local_gateway on {tmp_path}")
            inference_server.run_local_gateway((tmp_path,))


A custom logits processor of type <class 'transformers.generation.logits_process.InfNanRemoveLogitsProcessor'> has been passed to `.generate()`, but it was also created in `.generate()`, given its parameterization. The custom <class 'transformers.generation.logits_process.InfNanRemoveLogitsProcessor'> will take precedence. Please check the docstring of <class 'transformers.generation.logits_process.InfNanRemoveLogitsProcessor'> to see related `.generate()` flags.


[solver] backend = hf
[dev] solver loaded. Running dev_eval() on reference.csv ...


ValueError: Expected a cuda device, but got: cpu