# AIMO3 Baseline Notebook
## AI Mathematical Olympiad – Progress Prize 3

### Quick Start Guide
1. **RUN_MODE Selection**:
   - `"local_ref"`: Debug mode - runs evaluation on `reference.csv` (10 problems)
   - `"submit_auto"`: Kaggle submission mode - uses `kaggle_evaluation` API

2. **Model Setup (Kaggle)**:
   - Add your model as a Kaggle Dataset/Model input
   - Update `MODEL_PATH` in CONFIG to point to `/kaggle/input/your-model-name`
   - Or use Kaggle's built-in models

3. **Telemetry**:
   - Logs saved to `/kaggle/working/aimo3_telemetry.jsonl`
   - Download after submission for analysis

## CELL A — CONFIG (Constants)

In [None]:
# ================================
# CELL A — CONFIG (Constants)
# ================================

import os

# ----- RUN MODE -----
# "local_ref": Run evaluation on reference.csv (10 problems) - for debugging
# "submit_auto": Kaggle submission mode - uses kaggle_evaluation API
RUN_MODE = "local_ref"  # Change to "submit_auto" for Kaggle submission

# ----- TIME BUDGET -----
TIME_BUDGET_SEC_PER_PROBLEM = 120  # seconds per problem

# ----- GENERATION PARAMS -----
K_BASE = 4              # Base number of candidates for easy problems
K_MAX_HARD = 8          # Max candidates for hard problems
TEMPERATURE_BASE = 0.3  # Temperature for stable generation
TEMPERATURE_HARD = 0.7  # Temperature for exploration
MAX_NEW_TOKENS = 2048   # Max tokens per generation

# ----- VOTING & EARLY STOP -----
VOTING_THRESHOLD = 0.6  # Stop early if top answer >= this fraction

# ----- PATHS -----
# /kaggle/working is only present inside a Kaggle session. When the notebook is
# run anywhere else, write next to the notebook instead so that the directory
# creation below does not fail on a root-level path that cannot be created.
# On Kaggle the resulting values are identical to the previous hard-coded ones.
WORK_DIR = "/kaggle/working" if os.path.isdir("/kaggle/working") else "."
LOG_PATH = os.path.join(WORK_DIR, "aimo3_telemetry.jsonl")
CACHE_DIR = os.path.join(WORK_DIR, "cache")

# ----- MODEL CONFIG -----
MODEL_PATH = None  # Set to model path when available
MODEL_ID = "Qwen/Qwen2.5-Math-1.5B-Instruct"  # Fallback model ID

# ----- PROMPT STYLE -----
PROMPT_STYLE = "tir"  # "tir", "concise", or "explore"

# ----- MODE POLICY -----
MODE_POLICY = "stable"  # "stable" or "diverse"

# ----- DATA PATHS -----
# Default to files in the current directory; switch to the competition dataset
# only when its reference.csv is actually present.
_KAGGLE_DATA_DIR = "/kaggle/input/ai-mathematical-olympiad-progress-prize-3"
REFERENCE_CSV_PATH = "reference.csv"
TEST_CSV_PATH = "test.csv"
if os.path.exists(os.path.join(_KAGGLE_DATA_DIR, "reference.csv")):
    REFERENCE_CSV_PATH = os.path.join(_KAGGLE_DATA_DIR, "reference.csv")
    TEST_CSV_PATH = os.path.join(_KAGGLE_DATA_DIR, "test.csv")

# Create working directories (skip empty names, e.g. a bare log file name)
for _work_dir in (os.path.dirname(LOG_PATH), CACHE_DIR):
    if _work_dir:
        os.makedirs(_work_dir, exist_ok=True)

print(f"RUN_MODE: {RUN_MODE}")
print(f"REFERENCE_CSV_PATH: {REFERENCE_CSV_PATH}")
print(f"LOG_PATH: {LOG_PATH}")

## CELL B — IMPORTS + SEED CONTROL

In [None]:
# ================================
# CELL B — IMPORTS + SEED CONTROL
# ================================

import os
import re
import sys
import time
import json
import math
import random
import hashlib
import warnings
from typing import Optional, Tuple, List, Dict, Any
from collections import Counter
from contextlib import redirect_stdout, redirect_stderr
import io

import pandas as pd

warnings.filterwarnings("ignore")

# ----- SEED CONTROL -----
# Seed used as-is outside a competition rerun, and as the offset base inside one.
BASE_SEED = 42
# True whenever the KAGGLE_IS_COMPETITION_RERUN environment variable is set,
# regardless of its value (even an empty string counts).
IS_KAGGLE_RERUN = os.getenv("KAGGLE_IS_COMPETITION_RERUN") is not None

def get_seed_for_run():
    """Return the seed to use for this session.

    Outside a competition rerun the fixed ``BASE_SEED`` is returned. During a
    rerun the seed is ``BASE_SEED`` plus the current Unix time (whole seconds)
    modulo 100000, so it differs from session to session.
    """
    if not IS_KAGGLE_RERUN:
        return BASE_SEED

    clock_offset = int(time.time()) % 100000
    return BASE_SEED + clock_offset

def set_seed(seed: int):
    """Seed every random number generator that is available.

    Always seeds the standard-library ``random`` module. NumPy and PyTorch
    (including all CUDA devices when CUDA is available) are seeded only if the
    package can be imported; a missing package is silently skipped.
    """
    random.seed(seed)

    # NumPy is optional.
    try:
        import numpy
        numpy.random.seed(seed)
    except ImportError:
        pass

    # PyTorch is optional; CUDA generators are seeded only when a GPU is usable.
    try:
        import torch
        torch.manual_seed(seed)
        cuda_usable = torch.cuda.is_available()
        if cuda_usable:
            torch.cuda.manual_seed_all(seed)
    except ImportError:
        pass

# Resolve the seed once when this cell runs and apply it immediately.
CURRENT_SEED = get_seed_for_run()
set_seed(CURRENT_SEED)

print(f"IS_KAGGLE_RERUN: {IS_KAGGLE_RERUN}")
print(f"CURRENT_SEED: {CURRENT_SEED}")

## CELL C — LAZY MODEL LOADER

In [None]:
# ================================
# CELL C — LAZY MODEL LOADER
# ================================

# Process-wide cache filled by load_model(). "loaded" becomes True after the
# first load attempt, whether or not it succeeded ("model" stays None on failure).
_model_cache = {"model": None, "tokenizer": None, "device": None, "loaded": False}

def get_device():
    """Pick the compute device name: "cuda", then "mps", otherwise "cpu".

    "cpu" is also returned when PyTorch cannot be imported.
    """
    chosen = "cpu"
    try:
        import torch
        if torch.cuda.is_available():
            chosen = "cuda"
        elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
            chosen = "mps"
    except ImportError:
        pass
    return chosen

def load_model():
    """Load the model and tokenizer once and return ``(model, tokenizer, device)``.

    Results are stored in ``_model_cache``; later calls return the cached
    values without loading again. The model is read from ``MODEL_PATH`` when
    that path exists (with ``local_files_only=True``), otherwise from
    ``MODEL_ID``. On CUDA the weights use float16 with ``device_map="auto"``;
    on any other device they use float32 and are moved with ``.to(device)``.

    Any exception during loading is caught: a warning is printed, the cache is
    marked as loaded with ``model`` and ``tokenizer`` set to ``None`` and the
    device set to "cpu", so the failed load is not retried.
    """
    cache = _model_cache

    if cache["loaded"]:
        return cache["model"], cache["tokenizer"], cache["device"]

    print("Loading model...")
    started_at = time.time()

    try:
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        device = get_device()
        cache["device"] = device
        print(f"Using device: {device}")

        # Only a path that exists on disk counts as a local copy.
        use_local = MODEL_PATH is not None and os.path.exists(MODEL_PATH)
        model_source = MODEL_PATH if use_local else MODEL_ID
        print(f"Loading from: {model_source}")

        cache["tokenizer"] = AutoTokenizer.from_pretrained(
            model_source,
            trust_remote_code=True,
            local_files_only=use_local,
        )

        on_cuda = device == "cuda"
        cache["model"] = AutoModelForCausalLM.from_pretrained(
            model_source,
            torch_dtype=torch.float16 if on_cuda else torch.float32,
            device_map="auto" if on_cuda else None,
            trust_remote_code=True,
            local_files_only=use_local,
        )

        # With device_map="auto" placement is already handled on CUDA.
        if not on_cuda:
            cache["model"] = cache["model"].to(device)

        cache["loaded"] = True
        elapsed = time.time() - started_at
        print(f"Model loaded in {elapsed:.2f}s")

    except Exception as e:
        print(f"Warning: Could not load model: {e}")
        print("Using rule-based solver...")
        cache["model"] = None
        cache["tokenizer"] = None
        cache["device"] = "cpu"
        cache["loaded"] = True

    return cache["model"], cache["tokenizer"], cache["device"]

def is_model_available():
    """Report whether a model is (or is expected to be) usable.

    After a load attempt this reflects whether a model object was obtained.
    Before any attempt it is True only when ``MODEL_PATH`` is set and exists.
    """
    if _model_cache["loaded"]:
        return _model_cache["model"] is not None
    return bool(MODEL_PATH and os.path.exists(MODEL_PATH))

# Cell-load confirmation only; nothing is loaded until load_model() is called.
print("Model loader initialized (lazy loading)")

## CELL D — TOOL-INTEGRATED REASONING + SAFE PYTHON EXECUTOR

In [None]:
# ================================
# CELL D — TIR-lite + SAFE PYTHON EXECUTOR
# ================================

# Module names whose import statements strip_imports() removes from generated code.
ALLOWED_MODULES = {
    "math",
    "fractions",
    "itertools",
    "functools",
    "collections",
    "decimal",
    "numbers",
    "cmath",
    "random",
    "statistics",
}

# sympy is optional: allow it only when it is installed.
try:
    import sympy
except ImportError:
    pass
else:
    ALLOWED_MODULES.add("sympy")

def create_safe_globals():
    """Create a fresh globals dict for executing generated code.

    The dict contains a restricted ``__builtins__`` mapping (no ``open``,
    ``__import__``, ``eval`` or ``exec``) plus pre-loaded modules, so code can
    use them without an import statement.

    Every module named in ``ALLOWED_MODULES`` is pre-loaded here, including
    ``cmath`` and ``numbers``: import lines for those modules are removed
    before execution, so a module that is allowed but not pre-loaded would
    fail with ``NameError``. Common exception classes and ``iter``/``next``
    are exposed as well, because ``try/except`` blocks and manual iteration
    in generated code would otherwise fail on the missing names.

    Returns:
        A new dict on every call; callers may mutate it freely.
    """
    import cmath
    import collections
    import decimal
    import fractions
    import functools
    import itertools
    import math
    import numbers
    import random
    import statistics

    safe_builtins = {
        "abs": abs, "all": all, "any": any, "bin": bin, "bool": bool, "chr": chr,
        "dict": dict, "divmod": divmod, "enumerate": enumerate, "filter": filter,
        "float": float, "frozenset": frozenset, "hex": hex, "int": int,
        "isinstance": isinstance, "len": len, "list": list, "map": map, "max": max,
        "min": min, "oct": oct, "ord": ord, "pow": pow, "print": print, "range": range,
        "repr": repr, "reversed": reversed, "round": round, "set": set, "slice": slice,
        "sorted": sorted, "str": str, "sum": sum, "tuple": tuple, "type": type, "zip": zip,
        "True": True, "False": False, "None": None, "complex": complex,
        "iter": iter, "next": next,
    }

    # Exception classes needed for `try/except` and `raise` in generated code.
    for exc_type in (
        Exception, ArithmeticError, AssertionError, IndexError, KeyError,
        OverflowError, RecursionError, StopIteration, TypeError, ValueError,
        ZeroDivisionError,
    ):
        safe_builtins[exc_type.__name__] = exc_type

    safe_globals = {
        "__builtins__": safe_builtins,
        # Pre-loaded modules (no import needed)
        "math": math,
        "cmath": cmath,
        "numbers": numbers,
        "fractions": fractions,
        "Fraction": fractions.Fraction,
        "itertools": itertools,
        "functools": functools,
        "collections": collections,
        "decimal": decimal,
        "Decimal": decimal.Decimal,
        "random": random,
        "statistics": statistics,
    }

    # sympy is optional and only exposed when installed.
    try:
        import sympy
        safe_globals["sympy"] = sympy
    except ImportError:
        pass

    return safe_globals

_IMPORT_IDENT = r"[A-Za-z_][A-Za-z0-9_]*"
_IMPORT_DOTTED = rf"{_IMPORT_IDENT}(?:\.{_IMPORT_IDENT})*"
# `import a, b.c as d  # comment`
_IMPORT_LINE_RE = re.compile(r"^(\s*)import\s+(.+?)\s*(?:#.*)?$")
# `from a.b import c, d as e  # comment`
_FROM_LINE_RE = re.compile(rf"^(\s*)from\s+({_IMPORT_DOTTED})\s+import\s+(.+?)\s*(?:#.*)?$")
# One item of an import list: `name`, `pkg.mod`, `*`, each optionally `as alias`
_IMPORT_ITEM_RE = re.compile(rf"^({_IMPORT_DOTTED}|\*)(?:\s+as\s+({_IMPORT_IDENT}))?$")


def _parse_import_items(spec):
    """Split an import list into ``(name, alias_or_None)`` pairs; None if malformed."""
    spec = spec.strip()
    if spec.startswith("(") and spec.endswith(")"):
        spec = spec[1:-1]
    items = []
    for part in spec.split(","):
        part = part.strip()
        if not part:
            continue  # tolerate a trailing comma
        match = _IMPORT_ITEM_RE.match(part)
        if match is None:
            return None
        items.append((match.group(1), match.group(2)))
    return items or None


def _rewrite_import_line(line, allowed):
    """Turn one import line for allowed modules into plain name bindings."""
    from_match = _FROM_LINE_RE.match(line)
    if from_match:
        indent, module, spec = from_match.groups()
        items = _parse_import_items(spec)
        if (
            items is None
            or module.split(".")[0] not in allowed
            or any("." in name for name, _ in items)
        ):
            return line
        bindings = [
            f"{alias or name} = {module}.{name}"
            for name, alias in items
            if name != "*"
        ]
        return indent + ("; ".join(bindings) or "pass")

    import_match = _IMPORT_LINE_RE.match(line)
    if import_match:
        indent, spec = import_match.groups()
        items = _parse_import_items(spec)
        if items is None or any(
            name == "*" or name.split(".")[0] not in allowed for name, _ in items
        ):
            return line
        bindings = [
            f"{alias} = {name}" for name, alias in items if alias and alias != name
        ]
        return indent + ("; ".join(bindings) or "pass")

    return line


def strip_imports(code: str, allowed_modules: Optional[set] = None) -> str:
    """Replace import statements for pre-loaded modules with equivalent bindings.

    The executor has no ``__import__``, so imports cannot run; the allowed
    modules are already present in the execution namespace under their own
    names. Each single-line import whose top-level module is allowed is
    rewritten in place so the names it would have bound still exist:

    * ``import math``                -> ``pass``
    * ``import sympy as sp``         -> ``sp = sympy``
    * ``from math import sqrt as s`` -> ``s = math.sqrt``

    The module name is matched exactly (not as a substring), the line count
    and indentation are preserved, and a block whose only statement was an
    import stays syntactically valid thanks to ``pass``.

    Lines that are left untouched: imports of modules that are not allowed,
    import lists mixing allowed and other modules, and anything that is not a
    simple one-line import (e.g. a parenthesised list spanning several lines).
    ``from module import *`` is replaced by ``pass`` and binds nothing.

    Args:
        code: Source text to rewrite.
        allowed_modules: Top-level module names to treat as pre-loaded.
            Defaults to the module-level ``ALLOWED_MODULES``.

    Returns:
        The rewritten source text.
    """
    allowed = ALLOWED_MODULES if allowed_modules is None else allowed_modules
    return "\n".join(_rewrite_import_line(line, allowed) for line in code.split("\n"))

def run_python(code: str, timeout_sec: float = 10.0) -> Tuple[bool, str]:
    """Execute generated Python code with restricted builtins and a time limit.

    Import statements are first processed by ``strip_imports``; the code then
    runs with the namespace from ``create_safe_globals``. stdout and stderr
    are captured. After a successful run, the first of the variables
    ``result``, ``answer``, ``ans``, ``final``, ``output`` that the code
    defined is appended to the captured output as ``name = value``.

    NOTE(review): this executes model-generated text with ``exec``. Removing
    builtins limits accidents but is not a security boundary.

    Args:
        code: Source text to execute.
        timeout_sec: Wall-clock limit in seconds; fractional values are
            honoured. A value <= 0 disables the limit.

    Returns:
        ``(True, output)`` on success (a placeholder message when there is no
        output), ``(False, "Timeout: ...")`` when the limit is hit, or
        ``(False, "Error: <Type>: <message>")`` for any other exception.
    """
    import signal

    # Strip import statements for pre-loaded modules
    code = strip_imports(code)

    output_capture = io.StringIO()
    # A single dict serves as both globals and locals. With two separate dicts,
    # functions defined by the code land in the locals dict but resolve names
    # through the globals dict, so helpers calling each other (or themselves)
    # fail with NameError.
    namespace = create_safe_globals()

    def timeout_handler(signum, frame):
        raise TimeoutError("Code execution timed out")

    # The alarm-based limit needs SIGALRM and the main thread. Where that is
    # not available (e.g. called from a worker thread, or on Windows) the code
    # runs without a time limit instead of failing before it starts.
    timer_armed = False
    previous_handler = None
    try:
        previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
    except (AttributeError, ValueError):
        pass
    else:
        timer_armed = True
        if timeout_sec > 0:
            # setitimer accepts fractions; alarm(int(0.5)) would disable the timer.
            signal.setitimer(signal.ITIMER_REAL, float(timeout_sec))

    try:
        with redirect_stdout(output_capture), redirect_stderr(output_capture):
            exec(code, namespace)

        output = output_capture.getvalue()

        # Capture result variables
        for var_name in ["result", "answer", "ans", "final", "output"]:
            if var_name in namespace:
                val = namespace[var_name]
                if output:
                    output += f"\n{var_name} = {val}"
                else:
                    output = f"{var_name} = {val}"
                break

        return True, output if output else "Execution completed (no output)"

    except TimeoutError as e:
        return False, f"Timeout: {str(e)}"
    except Exception as e:
        return False, f"Error: {type(e).__name__}: {str(e)}"
    finally:
        if timer_armed:
            signal.setitimer(signal.ITIMER_REAL, 0)
            # signal.signal() returns None for a handler not installed from
            # Python; that value cannot be passed back, so use the default.
            signal.signal(
                signal.SIGALRM,
                previous_handler if previous_handler is not None else signal.SIG_DFL,
            )

def parse_python_block(text: str) -> Optional[str]:
    """Return the body of the first fenced code block in ``text``, or None.

    Fences are tried in priority order over the whole text: ```python, then
    ```py, then a bare ```. The language tag is matched case-insensitively
    and the returned body is stripped of surrounding whitespace.
    """
    fence_patterns = (
        r"```python\s*\n(.*?)```",
        r"```py\s*\n(.*?)```",
        r"```\s*\n(.*?)```",
    )
    flags = re.DOTALL | re.IGNORECASE
    hits = (re.search(pattern, text, flags) for pattern in fence_patterns)
    first_hit = next((hit for hit in hits if hit), None)
    if first_hit is None:
        return None
    return first_hit.group(1).strip()

def get_tir_prompt(problem: str) -> str:
    """Build the Tool-Integrated Reasoning prompt for ``problem``.

    The prompt lists the solving rules (Python code blocks allowed, final line
    ``ANSWER: <integer>`` in the range 0-99999), then the problem text, and
    ends with ``SOLUTION:`` for the model to continue.
    """
    prompt_lines = [
        "You are a mathematical problem solver. Solve the following problem step by step.",
        "",
        "RULES:",
        "1. Think through the problem carefully.",
        "2. If you need to compute something, write Python code in a ```python ... ``` block.",
        "3. After your reasoning, provide your final answer on a new line as: ANSWER: <integer>",
        "4. The answer must be an integer between 0 and 99999.",
        "5. If the problem asks for a remainder when divided by some number, compute that remainder.",
        "",
        "PROBLEM:",
        f"{problem}",
        "",
        "SOLUTION:",
    ]
    return "\n".join(prompt_lines)

def get_concise_prompt(problem: str) -> str:
    """Build the short prompt that asks only for the final integer answer."""
    instruction = "Solve this math problem. Give only the final integer answer (0-99999)."
    return "\n\n".join([instruction, f"Problem: {problem}", "ANSWER:"])

def get_explore_prompt(problem: str) -> str:
    """Build the exploration prompt, which asks for several solution approaches.

    The problem text comes first, followed by numbered instructions (Python
    code blocks allowed, final line ``ANSWER: <integer>``) and a closing cue.
    """
    prompt_lines = [
        "You are an expert mathematician. Carefully analyze this problem and explore multiple approaches.",
        "",
        "Problem:",
        f"{problem}",
        "",
        "Instructions:",
        "1. Identify the key mathematical concepts involved.",
        "2. Consider multiple solution approaches.",
        "3. Use Python code (```python ... ```) for complex calculations.",
        "4. Verify your answer if possible.",
        "5. End with: ANSWER: <integer> (must be 0-99999)",
        "",
        "Let's solve this step by step:",
    ]
    return "\n".join(prompt_lines)

# Cell-load confirmation; the set includes "sympy" only when it could be imported.
print("Safe Python executor initialized")
print(f"Available modules: {ALLOWED_MODULES}")

## CELL E — ANSWER EXTRACTION + VALIDATION

In [None]:
# ================================
# CELL E — ANSWER EXTRACTION + VALIDATION
# ================================

def extract_answer(text: str) -> Tuple[Optional[int], str]:
    """Extract an integer answer from model output.

    Three strategies are tried in order, and the first that yields an integer
    wins:

    1. ``"ANSWER"``: an explicit marker such as ``ANSWER: 42``,
       ``The answer is 42`` or ``Final answer: 42`` (case-insensitive).
    2. ``"BOXED"``: a LaTeX ``\\boxed{42}``.
    3. ``"LASTINT"``: the last standalone integer in the text.

    Integers written with thousands separators ("12,345") are read as a whole
    number in all three strategies rather than truncated at the comma.

    Args:
        text: Raw or post-processed model output.

    Returns:
        ``(value, method)`` where ``method`` is one of the names above,
        ``(None, "empty")`` for empty input, or ``(None, "none")`` when no
        integer is found.
    """
    if not text:
        return None, "empty"

    text = text.strip()

    # Plain digits, or groups of three separated by commas ("1,234,567").
    # The lookahead rejects malformed groupings such as "12,3456".
    number = r"(\d{1,3}(?:,\d{3})+(?!\d)|\d+)"

    def first_integer(patterns, flags=0):
        """Return the integer from the first pattern that matches, else None."""
        for pattern in patterns:
            match = re.search(pattern, text, flags)
            if match:
                try:
                    return int(match.group(1).replace(",", ""))
                except ValueError:
                    # e.g. more digits than int() is willing to parse
                    continue
        return None

    # Matching is case-insensitive, so one spelling per phrase is enough.
    answer_patterns = [
        r"answer\s*:\s*" + number,
        r"the answer is\s*:\s*" + number,
        r"the answer is\s+" + number,
        r"final answer\s*:\s*" + number,
    ]
    value = first_integer(answer_patterns, re.IGNORECASE)
    if value is not None:
        return value, "ANSWER"

    # Accept both a doubled (escaped) and a single backslash before "boxed".
    boxed_patterns = [
        r"\\\\boxed\{" + number + r"\}",
        r"\\boxed\{" + number + r"\}",
    ]
    value = first_integer(boxed_patterns)
    if value is not None:
        return value, "BOXED"

    integers = re.findall(r"\b" + number + r"\b", text)
    if integers:
        try:
            return int(integers[-1].replace(",", "")), "LASTINT"
        except ValueError:
            pass

    return None, "none"

def validate_answer(answer: Optional[int]) -> Tuple[bool, int]:
    """Check that ``answer`` is an integer in the range [0, 99999].

    Non-int values are converted with ``int()``. Values that cannot be
    converted, including ``nan`` and infinities, are reported as invalid
    instead of raising.

    Returns:
        ``(True, answer)`` when the value is in range; ``(False, clamped)``
        with the value clamped into [0, 99999] when it is out of range;
        ``(False, 0)`` for ``None`` or unconvertible input.
    """
    if answer is None:
        return False, 0

    if not isinstance(answer, int):
        try:
            answer = int(answer)
        except (ValueError, TypeError, OverflowError):
            # OverflowError: int(float("inf")) / int(float("-inf"))
            return False, 0

    if 0 <= answer <= 99999:
        return True, answer

    return False, max(0, min(99999, answer))

def safe_extract_answer(text: str) -> Tuple[int, Dict[str, Any]]:
    """Extract and validate an answer, always returning an in-range integer.

    Returns:
        ``(answer, metadata)``. ``metadata`` records the raw extracted value,
        the extraction method, whether it was valid, and whether the fallback
        was used. When the extracted value is missing or out of range the
        answer is 0 and ``metadata["fallback_value"]`` is set to 0.
    """
    raw_answer, method = extract_answer(text)
    is_valid, checked_answer = validate_answer(raw_answer)

    metadata = {
        "raw_answer": raw_answer,
        "method": method,
        "is_valid": is_valid,
        "fallback_used": not is_valid,
    }

    if is_valid:
        return checked_answer, metadata

    # Invalid or missing answer: always fall back to 0 (not the clamped value).
    metadata["fallback_value"] = 0
    return 0, metadata

# Cell-load confirmation.
print("Answer extraction functions initialized")

## CELL F — CANDIDATE GENERATION + SELF-CONSISTENCY

In [None]:
# ================================
# CELL F — CANDIDATE GENERATION + SELF-CONSISTENCY
# ================================

def generate_one(problem: str, temperature: float = 0.3, max_new_tokens: int = MAX_NEW_TOKENS, prompt_style: str = "tir") -> Tuple[str, Dict[str, Any]]:
    """Generate one model response for ``problem``.

    Args:
        problem: Problem statement inserted into the prompt.
        temperature: Sampling temperature; a value <= 0 selects greedy decoding.
        max_new_tokens: Upper bound on generated tokens.
        prompt_style: "tir" or "concise"; any other value selects the
            exploration prompt.

    Returns:
        ``(response, meta)``. ``response`` holds only the newly generated
        text. It is the empty string when no model is loaded or when
        generation raises; in the latter case ``meta["error"]`` holds the
        exception message.
    """
    model, tokenizer, device = load_model()

    if prompt_style == "tir":
        prompt = get_tir_prompt(problem)
    elif prompt_style == "concise":
        prompt = get_concise_prompt(problem)
    else:
        prompt = get_explore_prompt(problem)

    meta = {"prompt_style": prompt_style, "temperature": temperature, "max_new_tokens": max_new_tokens}

    if model is None:
        return "", meta

    try:
        import torch
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        prompt_token_count = inputs["input_ids"].shape[1]

        generation_kwargs = {
            "max_new_tokens": max_new_tokens,
            "pad_token_id": tokenizer.eos_token_id,
        }
        if temperature > 0:
            generation_kwargs.update(do_sample=True, temperature=temperature, top_p=0.95)
        else:
            # Greedy decoding: sampling parameters would be ignored anyway.
            generation_kwargs["do_sample"] = False

        with torch.no_grad():
            outputs = model.generate(**inputs, **generation_kwargs)

        # The returned sequence starts with the prompt tokens. Dropping them by
        # token count is reliable, whereas slicing the decoded string by
        # len(prompt) breaks whenever decoding does not reproduce the prompt
        # text exactly, which would leak the prompt into the response.
        new_tokens = outputs[0][prompt_token_count:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        return response, meta

    except Exception as e:
        meta["error"] = str(e)
        return "", meta

def execute_code_in_response(response: str) -> str:
    """Run the first code block found in ``response`` and append its output.

    The response is returned unchanged when it contains no code block or when
    execution fails; otherwise a ``[Code Output]`` section is appended.
    """
    code = parse_python_block(response)
    if not code:
        return response

    succeeded, output = run_python(code)
    if not succeeded:
        return response

    return response + f"\n\n[Code Output]\n{output}"

def generate_candidates(problem: str, k: int = K_BASE, temperature_schedule: List[float] = None) -> List[Dict[str, Any]]:
    """Generate ``k`` candidate answers for ``problem``.

    Args:
        problem: Problem statement.
        k: Number of candidates to generate.
        temperature_schedule: Per-candidate temperatures. Defaults to
            ``TEMPERATURE_BASE`` for every candidate; positions beyond the end
            of a shorter schedule also use ``TEMPERATURE_BASE``.

    Returns:
        One dict per candidate with keys ``answer``, ``raw_text``,
        ``processed_text`` (raw text plus any code output) and ``metadata``
        (generation and extraction metadata merged).
    """
    schedule = temperature_schedule
    if schedule is None:
        schedule = [TEMPERATURE_BASE] * k

    def sample(position: int) -> Dict[str, Any]:
        temp = schedule[position] if position < len(schedule) else TEMPERATURE_BASE
        raw_text, gen_meta = generate_one(problem, temperature=temp, prompt_style=PROMPT_STYLE)
        processed_text = execute_code_in_response(raw_text)
        answer, answer_meta = safe_extract_answer(processed_text)
        return {
            "answer": answer,
            "raw_text": raw_text,
            "processed_text": processed_text,
            "metadata": {**gen_meta, **answer_meta},
        }

    return [sample(position) for position in range(k)]

def vote_candidates(candidates: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Majority-vote over candidate answers.

    Candidates whose ``metadata["is_valid"]`` is False (no usable answer could
    be extracted, so their ``answer`` is only the fallback 0) do not vote as
    long as at least one valid candidate exists; a parse failure is not a vote
    for 0. Candidates without that metadata are treated as valid. If no
    candidate is valid, all of them are counted.

    Args:
        candidates: Dicts with an ``answer`` key and optional ``metadata``.

    Returns:
        Dict with ``top_answer``, ``top_count``, ``total`` (number of
        candidates, voting or not), ``votes_counted`` (answers that actually
        voted), ``vote_margin`` (lead of the top answer over the runner-up as
        a fraction of ``total``), ``entropy`` (bits, over the counted votes)
        and ``answer_counts`` (up to five most common answers).
    """
    if not candidates:
        return {
            "top_answer": 0, "top_count": 0, "total": 0, "votes_counted": 0,
            "vote_margin": 0.0, "entropy": 0.0, "answer_counts": {},
        }

    valid_answers = [
        c["answer"]
        for c in candidates
        if (c.get("metadata") or {}).get("is_valid", True)
    ]
    answers = valid_answers or [c["answer"] for c in candidates]

    counter = Counter(answers)
    total = len(candidates)
    votes_counted = len(answers)

    most_common = counter.most_common()
    top_answer, top_count = most_common[0]

    second_count = most_common[1][1] if len(most_common) > 1 else 0
    vote_margin = (top_count - second_count) / total

    entropy = 0.0
    for count in counter.values():
        p = count / votes_counted
        entropy -= p * math.log2(p)

    return {
        "top_answer": top_answer, "top_count": top_count, "total": total,
        "votes_counted": votes_counted,
        "vote_margin": vote_margin, "entropy": entropy,
        "answer_counts": dict(counter.most_common(5)),
    }

def should_early_stop(vote_result: Dict[str, Any], threshold: float = VOTING_THRESHOLD) -> bool:
    """Return True when the top answer's share of ``total`` reaches ``threshold``.

    Always False when ``vote_result["total"]`` is 0.
    """
    total = vote_result["total"]
    return total != 0 and vote_result["top_count"] / total >= threshold

# Cell-load confirmation.
print("Candidate generation functions initialized")

## CELL G — VERIFIER

In [None]:
# ================================
# CELL G — VERIFIER (Rule-based + Optional LLM)
# ================================

def rule_verifier(problem: str, answer: int) -> Dict[str, Any]:
    """Run cheap rule-based sanity checks on ``answer``.

    Checks:
        * ``range_check``: the answer must lie in [0, 99999]. Only this check
          decides ``passed`` and ``reason``.
        * ``mod_check``: when the problem mentions a remainder/modulo and a
          modulus literal is found, records whether the answer is below it.

    NOTE(review): the ``mod_check`` outcome is only recorded in ``checks``; it
    never changes ``passed``. Confirm that this advisory behaviour is intended.

    Returns:
        Dict with ``passed``, ``reason`` and ``checks`` (a list of
        ``(name, ok, message)`` tuples).
    """
    in_range = 0 <= answer <= 99999
    if in_range:
        reason = "OK"
        checks = [("range_check", True, "In valid range")]
    else:
        reason = f"Answer {answer} out of valid range [0, 99999]"
        checks = [("range_check", False, reason)]

    problem_lower = problem.lower()
    mentions_modulus = any(
        keyword in problem_lower for keyword in ("remainder", "modulo", "mod ")
    )
    if mentions_modulus:
        mod_patterns = [r"divided by\s+(\d+)", r"modulo\s+(\d+)", r"mod\s+(\d+)", r"\(mod\s*(\d+)\)"]
        # Only the first pattern that matches is used.
        hits = (re.search(pattern, problem_lower) for pattern in mod_patterns)
        hit = next((m for m in hits if m), None)
        if hit is not None:
            mod_val = int(hit.group(1))
            exceeds_modulus = answer >= mod_val and mod_val < 100000
            if exceeds_modulus:
                checks.append(("mod_check", False, f"Answer {answer} >= modulo {mod_val}"))
            else:
                checks.append(("mod_check", True, f"Answer {answer} < modulo {mod_val}"))

    return {"passed": in_range, "reason": reason, "checks": checks}

def llm_verifier(problem: str, answer: int) -> Dict[str, Any]:
    """Ask the model whether ``answer`` is plausible for ``problem``.

    The check passes unless the model's reply contains "INVALID"
    (case-insensitive). It also passes when no model is loaded or when
    generation raises, so verifier problems never reject an answer.

    Returns:
        Dict with ``passed``, ``reason`` and ``response`` (the first 200
        characters of the model reply, or an empty string).
    """
    model, tokenizer, device = load_model()

    if model is None:
        return {"passed": True, "reason": "LLM not available", "response": ""}

    prompt = f"""Given this math problem and proposed answer, quickly check if the answer could be correct.
If you find a clear error or contradiction, say INVALID. Otherwise say VALID.

Problem: {problem}

Proposed Answer: {answer}

Verification (VALID or INVALID):"""

    try:
        import torch
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        prompt_token_count = inputs["input_ids"].shape[1]

        with torch.no_grad():
            # Greedy decoding; a temperature would be ignored with do_sample=False.
            outputs = model.generate(**inputs, max_new_tokens=100, do_sample=False, pad_token_id=tokenizer.eos_token_id)

        # Decode only the generated tokens. The prompt itself contains the word
        # "INVALID", so any prompt text leaking into the reply (which slicing
        # the decoded string by len(prompt) allows when decoding does not
        # reproduce the prompt exactly) would force a failure.
        response = tokenizer.decode(outputs[0][prompt_token_count:], skip_special_tokens=True).strip()
        passed = "INVALID" not in response.upper()

        return {"passed": passed, "reason": "VALID" if passed else "INVALID found", "response": response[:200]}

    except Exception as e:
        return {"passed": True, "reason": f"Error: {str(e)}", "response": ""}

def verify_answer(problem: str, answer: int, use_llm: bool = False) -> Dict[str, Any]:
    """Combine the rule-based verifier with the optional LLM verifier.

    The LLM verifier runs only when ``use_llm`` is set and the rule check
    passed. ``final_passed`` is True only if every verifier that ran passed.

    Returns:
        Dict with ``rule_verifier`` (its result dict), ``llm_verifier`` (its
        result dict, or None when it did not run) and ``final_passed``.
    """
    rule_result = rule_verifier(problem, answer)
    rule_ok = rule_result["passed"]

    result = {
        "rule_verifier": rule_result,
        "llm_verifier": None,
        "final_passed": True if rule_ok else False,
    }

    if use_llm and rule_ok:
        llm_result = llm_verifier(problem, answer)
        result["llm_verifier"] = llm_result
        if not llm_result["passed"]:
            result["final_passed"] = False

    return result

# Cell-load confirmation.
print("Verifier functions initialized")

## CELL H — SOLVER ORCHESTRATOR

In [None]:
# ================================
# CELL H — SOLVER ORCHESTRATOR
# ================================

def solve_problem(problem_id: str, problem_text: str) -> Tuple[int, Dict[str, Any]]:
    """Solve one problem with sampling, voting and verification.

    Steps:
        1. Choose the mode: "HARD" (``K_MAX_HARD`` candidates at
           ``TEMPERATURE_HARD``) only when ``MODE_POLICY`` is "diverse" during
           a competition rerun, otherwise "EASY" (``K_BASE`` candidates at
           ``TEMPERATURE_BASE``).
        2. Generate candidates with the configured ``PROMPT_STYLE``, stopping
           once 80% of ``TIME_BUDGET_SEC_PER_PROBLEM`` has elapsed.
        3. Majority-vote, then verify the winner (the LLM verifier is used
           only in HARD mode with a vote margin below 0.3).
        4. If verification fails, switch to the first alternative answer that
           passes the rule-based check; otherwise keep the winner.
        5. Validate the range; an out-of-range answer becomes 0.

    Any exception is caught: the answer becomes 0 and the message is stored
    in ``telemetry["error"]``.

    Returns:
        ``(final_answer, telemetry)`` where ``final_answer`` is an int in
        [0, 99999] and ``telemetry`` describes the run.
    """
    start_time = time.time()

    telemetry = {
        "id": problem_id, "elapsed_sec": 0, "k_used": 0, "candidates": [],
        "chosen_answer": 0, "vote_margin": 0.0, "verifier_used": False,
        "verifier_pass": True, "tool_calls_count": 0, "parse_method_used": "none",
        "mode": "EASY", "temperature_schedule": [], "seed": CURRENT_SEED,
        "run_policy": MODE_POLICY, "is_rerun": IS_KAGGLE_RERUN,
    }

    try:
        if MODE_POLICY == "diverse" and IS_KAGGLE_RERUN:
            mode, k = "HARD", K_MAX_HARD
            temp_schedule = [TEMPERATURE_HARD] * k
        else:
            mode, k = "EASY", K_BASE
            temp_schedule = [TEMPERATURE_BASE] * k

        telemetry["mode"] = mode
        telemetry["temperature_schedule"] = temp_schedule

        candidates = []
        for i in range(k):
            # Keep 20% of the budget for voting and verification.
            if time.time() - start_time > TIME_BUDGET_SEC_PER_PROBLEM * 0.8:
                break

            # Honour the configured prompt style; without the argument,
            # generate_one always falls back to its "tir" default.
            raw_text, meta = generate_one(
                problem_text, temperature=temp_schedule[i], prompt_style=PROMPT_STYLE
            )
            processed_text = execute_code_in_response(raw_text)
            answer, answer_meta = safe_extract_answer(processed_text)

            # Only the first 500 characters of the raw text are kept.
            candidates.append({"answer": answer, "raw_text": raw_text[:500], "metadata": {**meta, **answer_meta}})

            if parse_python_block(raw_text):
                telemetry["tool_calls_count"] += 1

        telemetry["k_used"] = len(candidates)

        vote_result = vote_candidates(candidates)
        telemetry["vote_margin"] = vote_result["vote_margin"]
        telemetry["candidates"] = list(vote_result["answer_counts"].items())

        chosen_answer = vote_result["top_answer"]

        # Record how the winning answer was parsed (first candidate that has it).
        for c in candidates:
            if c["answer"] == chosen_answer:
                telemetry["parse_method_used"] = c["metadata"].get("method", "none")
                break

        use_llm_verifier = vote_result["vote_margin"] < 0.3 and mode == "HARD"
        verification = verify_answer(problem_text, chosen_answer, use_llm=use_llm_verifier)

        telemetry["verifier_used"] = True
        telemetry["verifier_pass"] = verification["final_passed"]

        if not verification["final_passed"]:
            # Try the other voted answers, most common first.
            for answer in vote_result["answer_counts"]:
                if answer != chosen_answer:
                    alt_verify = verify_answer(problem_text, answer, use_llm=False)
                    if alt_verify["final_passed"]:
                        chosen_answer = answer
                        break

        is_valid, final_answer = validate_answer(chosen_answer)
        if not is_valid:
            telemetry["fallback_used"] = True
            final_answer = 0

        telemetry["chosen_answer"] = final_answer

    except Exception as e:
        telemetry["error"] = str(e)
        final_answer = 0
        telemetry["chosen_answer"] = final_answer
        telemetry["fallback_used"] = True

    telemetry["elapsed_sec"] = time.time() - start_time

    return final_answer, telemetry

# Cell-load confirmation.
print("Solver orchestrator initialized")

## CELL I — TELEMETRY LOGGER

In [None]:
# ================================
# CELL I — TELEMETRY LOGGER
# ================================

def append_jsonl(filepath: str, data: Dict[str, Any]):
    """Append ``data`` to ``filepath`` as one JSON line.

    The parent directory is created when the path has one. Values that JSON
    cannot encode natively are written as their ``str()`` representation.
    """
    parent_dir = os.path.dirname(filepath)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(filepath, "a") as log_file:
        record = json.dumps(data, default=str)
        log_file.write(record + "\n")
        log_file.flush()

def read_telemetry(filepath: str) -> List[Dict[str, Any]]:
    """Read a JSON-lines file into a list of parsed records.

    Blank lines and lines that are not valid JSON are skipped. A missing file
    yields an empty list.
    """
    if not os.path.exists(filepath):
        return []

    records = []
    with open(filepath, "r") as log_file:
        for raw_line in log_file:
            payload = raw_line.strip()
            if not payload:
                continue
            try:
                records.append(json.loads(payload))
            except json.JSONDecodeError:
                # Skip malformed lines
                continue
    return records

def log_telemetry(telemetry: Dict[str, Any]):
    """Append one telemetry record to the JSON-lines log at ``LOG_PATH``."""
    append_jsonl(LOG_PATH, telemetry)

def print_telemetry_summary(telemetry_list: List[Dict[str, Any]]):
    """Print aggregate statistics for a list of telemetry records.

    Missing fields are treated as 0 / "none" / False. The verifier pass rate
    counts only records where the verifier actually ran, so the numerator can
    never exceed the denominator.
    """
    if not telemetry_list:
        print("No telemetry data available")
        return

    n = len(telemetry_list)
    total_time = sum(t.get("elapsed_sec", 0) for t in telemetry_list)
    avg_time = total_time / n

    parse_methods = Counter(t.get("parse_method_used", "none") for t in telemetry_list)
    parse_fail_rate = parse_methods.get("none", 0) / n

    k_values = [t.get("k_used", 0) for t in telemetry_list]
    avg_k = sum(k_values) / n

    verified = [t for t in telemetry_list if t.get("verifier_used", False)]
    verifier_used = len(verified)
    # Records where the verifier never ran must not count as passes; the
    # default value of "verifier_pass" is True even when no check happened.
    verifier_pass = sum(1 for t in verified if t.get("verifier_pass", True))

    fallback_used = sum(1 for t in telemetry_list if t.get("fallback_used", False))

    print("\n" + "="*50)
    print("TELEMETRY SUMMARY")
    print("="*50)
    print(f"Total problems: {n}")
    print(f"Total time: {total_time:.2f}s")
    print(f"Avg time per problem: {avg_time:.2f}s")
    print(f"Parse fail rate: {parse_fail_rate:.2%}")
    print(f"Avg k_used: {avg_k:.1f}")
    print(f"K distribution: {Counter(k_values)}")
    print(f"Parse methods: {dict(parse_methods)}")
    print(f"Verifier usage: {verifier_used}/{n}")
    print(f"Verifier pass rate: {verifier_pass}/{verifier_used if verifier_used > 0 else 1}")
    print(f"Fallback used: {fallback_used}/{n}")
    print("="*50)

# Cell-load confirmation, including where telemetry records will be appended.
print("Telemetry logger initialized")
print(f"Log path: {LOG_PATH}")

## CELL J — LOCAL HARNESS (Reference CSV Regression)

In [None]:
# ================================
# CELL J — LOCAL HARNESS (Reference CSV Regression)
# ================================

def run_reference_eval(csv_path: str = None, limit: int = None) -> Dict[str, Any]:
    """Evaluate the solver on a reference CSV and print a report.

    The CSV is read with columns ``id``, ``problem`` and ``answer``. Every
    problem is solved with ``solve_problem``; its telemetry is extended with
    the expected answer and correctness, then logged.

    Args:
        csv_path: CSV to evaluate; defaults to ``REFERENCE_CSV_PATH``.
        limit: When truthy, only the first ``limit`` rows are evaluated.

    Returns:
        ``{"accuracy", "correct", "total", "results"}`` on success, or
        ``{"error": "File not found", "accuracy": 0.0}`` when the CSV is
        missing.
    """
    csv_path = csv_path or REFERENCE_CSV_PATH

    if not os.path.exists(csv_path):
        print(f"Reference CSV not found: {csv_path}")
        return {"error": "File not found", "accuracy": 0.0}

    divider = "=" * 60
    print(f"\nRunning reference evaluation on: {csv_path}")
    print(divider)

    df = pd.read_csv(csv_path)
    if limit:
        df = df.head(limit)

    n_problems = len(df)
    print(f"Evaluating {n_problems} problems...\n")

    results = []
    telemetry_list = []

    for idx, row in df.iterrows():
        problem_id = str(row["id"])
        problem_text = row["problem"]
        expected_answer = int(row["answer"])

        print(f"[{idx+1}/{n_problems}] Problem {problem_id}...")

        predicted_answer, telemetry = solve_problem(problem_id, problem_text)
        is_correct = (predicted_answer == expected_answer)

        # Enrich the telemetry record before it is written to the log.
        telemetry["expected_answer"] = expected_answer
        telemetry["is_correct"] = is_correct
        log_telemetry(telemetry)
        telemetry_list.append(telemetry)

        status = "Y" if is_correct else "X"
        print(f"  {status} Predicted: {predicted_answer}, Expected: {expected_answer} ({telemetry['elapsed_sec']:.2f}s)")

        results.append({
            "id": problem_id,
            "predicted": predicted_answer,
            "expected": expected_answer,
            "correct": is_correct,
            "elapsed_sec": telemetry["elapsed_sec"],
        })

    correct = sum(1 for entry in results if entry["correct"])
    accuracy = correct / n_problems if n_problems > 0 else 0.0

    print("\n" + divider)
    print(f"ACCURACY: {correct}/{n_problems} = {accuracy:.2%}")
    print(divider)

    print_telemetry_summary(telemetry_list)

    return {"accuracy": accuracy, "correct": correct, "total": n_problems, "results": results}

# Cell-load confirmation.
print("Local harness initialized")

## CELL K — SUBMISSION GLUE (Kaggle Evaluation API)

In [None]:
# ================================
# CELL K — SUBMISSION GLUE (Kaggle Evaluation API)
# ================================

import sys
import os

# Candidate parent directories that may contain the `kaggle_evaluation` package,
# in priority order.
kaggle_eval_paths = ["/kaggle/input/kaggle-evaluation", "/kaggle/input", ".", ".."]

# Put the first directory that contains the package at the front of sys.path so
# that `import kaggle_evaluation` resolves to it; later candidates are ignored.
for path in kaggle_eval_paths:
    if os.path.exists(os.path.join(path, "kaggle_evaluation")):
        sys.path.insert(0, path)
        break

def predict(test_df: pd.DataFrame) -> pd.DataFrame:
    """Solve every problem in ``test_df`` and return the answers.

    Reads the ``id`` and ``problem`` columns, logs one telemetry record per
    row, and returns a DataFrame with columns ``id`` and ``answer``.

    NOTE(review): this function is handed to ``AIMO3InferenceServer``; confirm
    that the server calls it with a pandas DataFrame of this shape.
    """
    def solve_row(row) -> Dict[str, Any]:
        problem_id = str(row["id"])
        answer, telemetry = solve_problem(problem_id, str(row["problem"]))
        log_telemetry(telemetry)
        return {"id": problem_id, "answer": int(answer)}

    return pd.DataFrame([solve_row(row) for _, row in test_df.iterrows()])

def setup_and_serve():
    """Start the Kaggle inference server, or its local gateway test.

    The server is served when the ``KAGGLE_IS_COMPETITION_RERUN`` environment
    variable is set to a non-empty value; otherwise the local gateway is run.
    If an ``ImportError`` occurs (e.g. ``kaggle_evaluation`` is missing), a
    notice is printed and nothing is started.
    """
    try:
        from kaggle_evaluation.aimo_3_inference_server import AIMO3InferenceServer

        server = AIMO3InferenceServer(predict)
        in_competition_rerun = os.getenv("KAGGLE_IS_COMPETITION_RERUN")

        if not in_competition_rerun:
            print("Running local gateway test...")
            server.run_local_gateway()
        else:
            print("Starting inference server (competition mode)...")
            server.serve()

    except ImportError as e:
        print(f"kaggle_evaluation not available: {e}")
        print("Running in local-only mode")

# Cell-load confirmation; the server is not started by this cell itself.
print("Submission glue initialized")
print("Use setup_and_serve() to start the server")

## SELF TEST — Unit Tests

In [None]:
# ================================
# SELF TEST 1: Schema/Predict Unit Test
# ================================

def test_predict_schema():
    """Check that ``predict`` returns the expected schema for one tiny problem.

    Verifies the ``id`` and ``answer`` columns, the row count, and that the
    answer is an integer in [0, 99999].

    Returns:
        True when every assertion holds.

    Raises:
        AssertionError: If any check fails.
    """
    import numbers

    print("\nTEST 1: Schema/Predict Unit Test")
    print("-"*40)

    test_df = pd.DataFrame({"id": ["test001"], "problem": ["What is $1+1$?"]})
    result_df = predict(test_df)

    assert "id" in result_df.columns, "Missing 'id' column"
    assert "answer" in result_df.columns, "Missing 'answer' column"
    assert len(result_df) == 1, f"Expected 1 row, got {len(result_df)}"

    answer = result_df["answer"].iloc[0]
    # Values read from a pandas column are NumPy integers (e.g. numpy.int64),
    # which are not instances of the built-in int; numbers.Integral covers both.
    assert isinstance(answer, numbers.Integral), f"Answer should be int, got {type(answer)}"
    assert 0 <= answer <= 99999, f"Answer {answer} out of range [0, 99999]"

    print(f"OK Output schema correct")
    print(f"OK Answer: {answer} (valid int in [0, 99999])")
    print("TEST 1 PASSED\n")
    return True

# Self-test 1 runs only in local debug mode. Failures are reported, not raised,
# so the notebook keeps executing.
if RUN_MODE == "local_ref":
    try:
        test_predict_schema()
    except AssertionError as e:
        print(f"TEST 1 FAILED: {e}")
    except Exception as e:
        print(f"TEST 1 ERROR: {e}")

In [None]:
# ================================
# SELF TEST 2: Reference Evaluation
# ================================

def test_reference_eval():
    """Run the reference evaluation on two problems and check the summary.

    Calls ``run_reference_eval(limit=2)`` and asserts that the returned
    mapping contains ``accuracy`` and ``total`` and that ``total`` equals 2.

    Returns:
        True when every check passes.

    Raises:
        AssertionError: if a key is missing or the problem count is not 2.
    """
    print("\nTEST 2: Reference Evaluation")
    print("-" * 40)

    summary = run_reference_eval(limit=2)

    # Both summary fields must be present before their values are inspected.
    for required_key in ("accuracy", "total"):
        assert required_key in summary, f"Missing '{required_key}' in result"
    assert summary["total"] == 2, f"Expected 2 problems, got {summary['total']}"

    print(f"OK Accuracy: {summary['accuracy']:.2%}")
    print("OK No crashes during evaluation")
    print("TEST 2 PASSED\n")
    return True

# TEST 2 only runs in local debug mode and only when the reference CSV exists.
if RUN_MODE == "local_ref":
    try:
        # REFERENCE_CSV_PATH may still be None (its CONFIG default);
        # os.path.exists(None) raises TypeError, which would surface as a
        # misleading "TEST 2 ERROR" instead of the intended skip message.
        if REFERENCE_CSV_PATH and os.path.exists(REFERENCE_CSV_PATH):
            test_reference_eval()
        else:
            print("Skipping TEST 2: reference.csv not found")
    except AssertionError as e:
        print(f"TEST 2 FAILED: {e}")
    except Exception as e:
        print(f"TEST 2 ERROR: {e}")

In [None]:
# ================================
# SELF TEST 3: Tool Executor
# ================================

def test_tool_executor():
    """Exercise ``run_python`` with a handful of small snippets.

    Each snippet must report success; the first three must also produce
    output containing an expected token. The last call passes
    ``timeout_sec=1`` with trivial code and only checks that it succeeds.

    Returns:
        True when every check passes.

    Raises:
        AssertionError: if a snippet fails or its output lacks the token.
    """
    print("\nTEST 3: Tool Executor")
    print("-" * 40)

    # (code snippet, token expected in the output, line printed on success)
    output_checks = (
        # Plain arithmetic.
        ("result = 2 + 3", "5", "OK Simple arithmetic: 2+3 = 5"),
        # Uses `math` without importing it; the executor is expected to
        # provide it already.
        ("result = math.factorial(5)", "120", "OK Math module: factorial(5) = 120"),
        # Carries an explicit import line, which the executor is expected
        # to tolerate (strip).
        (
            "import math\nresult = math.sqrt(16)",
            "4",
            "OK Import stripping works: sqrt(16) = 4",
        ),
    )
    for snippet, token, success_line in output_checks:
        succeeded, captured = run_python(snippet)
        assert succeeded, f"Execution failed: {captured}"
        assert token in captured, f"Expected '{token}' in output, got: {captured}"
        print(success_line)

    # Trivial code under a one-second limit must still succeed.
    succeeded, _ = run_python("x = 1", timeout_sec=1)
    assert succeeded, "Simple code should not timeout"
    print("OK Timeout handling works")

    print("TEST 3 PASSED\n")
    return True

# Run the tool-executor self-test in local debug mode only; report problems
# instead of raising so later cells still execute.
if RUN_MODE == "local_ref":
    try:
        test_tool_executor()
    except AssertionError as failed_check:
        # One of the test's own assertions did not hold.
        print(f"TEST 3 FAILED: {failed_check}")
    except Exception as crash:
        # Anything else means the test could not run to completion.
        print(f"TEST 3 ERROR: {crash}")

In [None]:
# ================================
# SELF TEST 4: Answer Extraction
# ================================

def test_answer_extraction():
    """Check ``extract_answer`` against a fixed table of sample texts.

    For every sample, both the extracted integer and the reported extraction
    method must match the expected values.

    Returns:
        True when every sample matches.

    Raises:
        AssertionError: on the first sample whose answer or method differs.
    """
    print("\nTEST 4: Answer Extraction")
    print("-" * 40)

    # (input text, expected answer, expected extraction method)
    expectations = [
        ("ANSWER: 42", 42, "ANSWER"),
        ("The answer is 123", 123, "LASTINT"),
        ("After calculation, we get 456. ANSWER: 456", 456, "ANSWER"),
        ("Result: 789", 789, "LASTINT"),
    ]

    for sample, want_value, want_method in expectations:
        got_value, got_method = extract_answer(sample)
        assert got_value == want_value, f"Expected {want_value}, got {got_value} for '{sample}'"
        assert got_method == want_method, f"Expected method {want_method}, got {got_method}"
        # Only the first 30 characters of the sample are echoed.
        print(f"OK '{sample[:30]}...' -> {got_value} ({got_method})")

    print("TEST 4 PASSED\n")
    return True

# Run the answer-extraction self-test in local debug mode only; report
# problems instead of raising so later cells still execute.
if RUN_MODE == "local_ref":
    try:
        test_answer_extraction()
    except AssertionError as failed_check:
        # One of the test's own assertions did not hold.
        print(f"TEST 4 FAILED: {failed_check}")
    except Exception as crash:
        # Anything else means the test could not run to completion.
        print(f"TEST 4 ERROR: {crash}")

## MAIN EXECUTION

In [None]:
# ================================
# MAIN EXECUTION
# ================================

# Entry point: print a run banner, then dispatch on RUN_MODE.
if __name__ == "__main__":
    print("\n" + "="*60)
    print("AIMO3 BASELINE NOTEBOOK")
    print("="*60)
    print(f"RUN_MODE: {RUN_MODE}")
    print(f"CURRENT_SEED: {CURRENT_SEED}")
    print(f"IS_KAGGLE_RERUN: {IS_KAGGLE_RERUN}")
    print("="*60 + "\n")

    if RUN_MODE == "local_ref":
        print("Running in LOCAL/DEBUG mode...")
        print("Evaluating reference.csv...\n")

        # Start each debug run with a fresh telemetry log.
        if os.path.exists(LOG_PATH):
            os.remove(LOG_PATH)

        # REFERENCE_CSV_PATH may still be None (its CONFIG default);
        # os.path.exists(None) raises TypeError, so test for a value first
        # and fall through to the "not found" message instead of crashing.
        if REFERENCE_CSV_PATH and os.path.exists(REFERENCE_CSV_PATH):
            result = run_reference_eval()
            print(f"\nFinal Accuracy: {result['accuracy']:.2%}")
        else:
            print(f"Reference CSV not found: {REFERENCE_CSV_PATH}")
            print("Running self-tests only...")

    elif RUN_MODE == "submit_auto":
        print("Running in SUBMISSION mode...")
        print("Starting inference server...\n")

        setup_and_serve()

    else:
        print(f"Unknown RUN_MODE: {RUN_MODE}")
        print("Valid options: 'local_ref', 'submit_auto'")

print("\nNotebook execution complete.")