<a href="https://colab.research.google.com/github/quantexolution/aimo/blob/main/Experiments_with_AIMO_3_%7C_GPT_OSS_120B(with_tools).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# IMPORTANT: SOME KAGGLE DATA SOURCES ARE PRIVATE
# RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES.
import kagglehub
# Authenticate before the downloads in the next cell (private sources need it).
kagglehub.login()


In [None]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.

# Fetch the three inputs this notebook relies on: the competition data,
# the output of the `andreasbis/aimo-3-utils` notebook, and the model weights.
# Each call's return value is stored in a `*_path` variable.
ai_mathematical_olympiad_progress_prize_3_path = kagglehub.competition_download('ai-mathematical-olympiad-progress-prize-3')
andreasbis_aimo_3_utils_path = kagglehub.notebook_output_download('andreasbis/aimo-3-utils')
danielhanchen_gpt_oss_120b_transformers_default_1_path = kagglehub.model_download('danielhanchen/gpt-oss-120b/Transformers/default/1')

print('Data source import complete.')


In [None]:
# Remove packages before the offline wheel install below
# (presumably to avoid version conflicts / import overhead -- TODO confirm).
%pip uninstall --yes 'keras' 'matplotlib' 'scikit-learn' 'tensorflow'

In [None]:
import warnings
# Silence every warning category for the rest of the notebook run.
warnings.simplefilter('ignore')

In [None]:
import os
import sys
import subprocess

In [None]:
def set_env(input_archive, temp_dir):
    """Unpack the wheel archive (first run only) and install packages offline.

    The archive is extracted into ``temp_dir`` only when that directory does
    not exist yet; the pip install from ``{temp_dir}/wheels`` always runs.
    Both subprocesses raise ``subprocess.CalledProcessError`` on failure.
    """
    already_extracted = os.path.exists(temp_dir)

    if not already_extracted:
        os.makedirs(temp_dir, exist_ok=True)

        unpack_command = ['tar', '-xzf', input_archive, '-C', temp_dir]
        subprocess.run(unpack_command, check=True)

    packages = ['unsloth', 'trl', 'vllm', 'openai_harmony']

    # --no-index + --find-links: resolve exclusively from the local wheels.
    install_command = [
        sys.executable, '-m', 'pip', 'install',
        '--no-index',
        '--find-links', f'{temp_dir}/wheels',
        *packages,
    ]
    subprocess.run(install_command, check=True)

In [None]:
# Install the offline wheels shipped in the aimo-3-utils archive.
set_env(
    input_archive='/kaggle/input/aimo-3-utils/wheels.tar.gz',
    temp_dir='/kaggle/tmp/setup'
)

In [None]:
# List the directory that TIKTOKEN_ENCODINGS_BASE is pointed at below (no check=True, so a missing dir does not raise).
subprocess.run(['ls', '/kaggle/tmp/setup/tiktoken_encodings'])

In [None]:
# Process-wide settings; they must be in place before the heavy imports and
# before the vLLM server subprocess is spawned (it inherits this environment).
os.environ.update({
    'TRANSFORMERS_NO_TF': '1',
    'TRANSFORMERS_NO_FLAX': '1',
    'CUDA_VISIBLE_DEVICES': '0',
    'TOKENIZERS_PARALLELISM': 'false',
    'TRITON_PTXAS_PATH': '/usr/local/cuda/bin/ptxas',
    'TIKTOKEN_ENCODINGS_BASE': '/kaggle/tmp/setup/tiktoken_encodings',
})

In [None]:
import gc
import re
import math
import time
import queue
import threading
import contextlib
from typing import Optional
from jupyter_client import KernelManager
from collections import Counter, defaultdict
from concurrent.futures import as_completed, ThreadPoolExecutor

import pandas as pd
import polars as pl

from openai import OpenAI

from openai_harmony import (
    HarmonyEncodingName,
    load_harmony_encoding,
    SystemContent,
    DeveloperContent,
    ReasoningEffort,
    ToolNamespaceConfig,
    Author,
    Message,
    Role,
    TextContent,
    Conversation
)

from transformers import set_seed
import kaggle_evaluation.aimo_3_inference_server

In [None]:
class CFG:
    """Static configuration namespace: prompts, model/server settings, limits and sampling.

    All values are plain class attributes read as ``CFG.<name>``; the class is
    not meant to be instantiated with per-instance state.
    """

    system_prompt_1 = (
        'You are a world-class International Mathematical Olympiad (IMO) competitor. '
        'The final answer must be a non-negative integer between 0 and 99999. '
        'You must place the final integer answer inside \\boxed{}.'
    )

    # Description attached to the `python` tool namespace.
    tool_prompt = (
        'Use this tool to execute Python code. '
        'The environment is a stateful Jupyter notebook. '
        'You must use print() to output results.'
    )

    # Behavioral instructions (go in DEVELOPER role - active commands)
    developer_prompt_1 = (
        'You are a world-class International Mathematical Olympiad (IMO) competitor. '
        'The final answer must be a non-negative integer between 0 and 99999. '
        'Output the probability of your answer being correct inside \\probability{}. '
        'You must place the final integer answer inside \\boxed{}. '
    )

    # developer_prompt_2 = (
    #     'You are an expert mathematician who solves problems systematically. '
    #     'Break down the problem into steps. Verify each step with Python. '
    #     'The final answer must be an integer in [0, 99999] inside \\boxed{}.'
    # )

    # developer_prompt_3 = (
    #     'You are a computational mathematician. Use Python to verify ALL calculations. '
    #     'Never perform arithmetic mentally - always use code. '
    #     'Final answer: integer in [0, 99999] inside \\boxed{}.'
    # )

    # Array for ensemble rotation (indexed by attempt number modulo its length)
    developer_prompts = [
        developer_prompt_1,

    ]

    # === PROMPT ENSEMBLE (Diversity for better coverage) ===
    system_prompts = [
        # Prompt 0: Standard IMO competitor
        (
            'You are a world-class International Mathematical Olympiad (IMO) competitor. '
            'The final answer must be a non-negative integer between 0 and 99999. '
            'Output the probability of your answer being correct inside \\probability{}. '
            'You must place the final integer answer inside \\boxed{}.'
        ),
        # # Prompt 1: Methodical step-by-step
        # (
        #     'You are an expert mathematician who solves problems systematically. '
        #     'Break down the problem into clear steps. Verify each step with Python code. '
        #     'The final answer must be an integer in [0, 99999] inside \\boxed{}.'
        # ),
        # # Prompt 2: Computation-focused (always use code)
        # (
        #     'You are a computational mathematician. Use Python to verify ALL calculations. '
        #     'Never perform arithmetic mentally - always use code to compute and verify. '
        #     'Final answer: integer in [0, 99999] inside \\boxed{}.'
        # ),
    ]

    # Appended to the problem text, rotated per attempt.
    preference_prompts = [
        'Use all tools possible for correct mathematical computations. You have a jupyter notebook available which keeps full memory of all intermediate variables.',
        'Use `sympy` for symbolic computation and verify results numerically.'
    ]

    # === MODEL / SERVER ===
    served_model_name = 'gpt-oss'
    model_path = '/kaggle/input/gpt-oss-120b/transformers/default/1'

    kv_cache_dtype = 'fp8_e4m3'
    dtype = 'auto'

    # === TIME BUDGETS (presumably seconds -- TODO confirm against callers) ===
    high_problem_timeout = 900
    base_problem_timeout = 300

    notebook_limit = 17400
    server_timeout = 180           # Number of one-per-iteration readiness polls

    session_timeout = 960          # Passed as the OpenAI client timeout
    jupyter_timeout = 10           # Default per-execution timeout of a sandbox
    sandbox_timeout = 5            # Wait when borrowing a sandbox from the pool

    # === TOKEN / BATCH LIMITS ===
    stream_interval = 200
    context_tokens = 65536         # Passed to vLLM as --max-model-len
    search_tokens = 128
    buffer_tokens = 512
    batch_size = 256               # Passed to vLLM as --max-num-seqs

    # === EARLY STOPPING ===
    early_stop = 4                 # Votes needed to stop early
    min_samples_before_stop = 4     # Min samples before early stop allowed

    attempts = 8
    workers = 16                   # Thread count for preloading and kernel pool size
    turns = 128                    # Max turns per attempt
    seed = 68

    gpu_memory_utilization = 0.96
    temperature = 0.5
    min_p = 0.01

    # === TEMPERATURE SCHEDULE ===
    base_temperature = 0.75
    temp_low = 0                  # Lower bound for temperature variation
    temp_high = 1.4                # Upper bound for temperature variation


In [None]:
# Seed via transformers.set_seed so runs start from CFG.seed.
set_seed(CFG.seed)

In [None]:
class AIMO3Template:
    """Builds the opening Harmony messages (system, developer, user) for an attempt."""

    def get_system_content(self, system_prompt: str, tool_config: ToolNamespaceConfig) -> SystemContent:
        """Return system content carrying the identity text, HIGH reasoning effort and the tool."""

        content = SystemContent.new()
        content = content.with_model_identity(system_prompt)
        content = content.with_reasoning_effort(reasoning_effort=ReasoningEffort.HIGH)

        return content.with_tools(tool_config)

    def apply_chat_template(
        self,
        system_prompt: str,
        developer_prompt: str,
        user_prompt: str,
        tool_config: ToolNamespaceConfig
    ) -> list[Message]:
        """Return ``[system, developer, user]`` messages in that order.

        Args:
            system_prompt: Identity text placed in the SYSTEM message.
            developer_prompt: Instructions placed in the DEVELOPER message.
            user_prompt: Problem statement placed in the USER message.
            tool_config: Tool namespace advertised in the SYSTEM message.
        """

        # SYSTEM: model identity, reasoning effort and tool declaration.
        opening = Message.from_role_and_content(
            Role.SYSTEM,
            self.get_system_content(system_prompt, tool_config)
        )

        # DEVELOPER: behavioural instructions.
        instructions = Message.from_role_and_content(
            Role.DEVELOPER,
            DeveloperContent.new().with_instructions(developer_prompt)
        )

        # USER: the problem itself.
        question = Message.from_role_and_content(Role.USER, user_prompt)

        return [opening, instructions, question]



In [None]:
class AIMO3Sandbox:

    _port_lock = threading.Lock()
    _next_port = 50000

    @classmethod
    def _get_next_ports(cls, count: int = 5) -> list[int]:

        with cls._port_lock:
            ports = list(range(cls._next_port, cls._next_port + count))
            cls._next_port += count

            return ports

    def __init__(self, timeout: float):

        self._default_timeout = timeout
        self._owns_kernel = False
        self._client = None
        self._km = None

        ports = self._get_next_ports(5)

        env = os.environ.copy()
        env['PYDEVD_DISABLE_FILE_VALIDATION'] = '1'
        env['PYDEVD_WARN_EVALUATION_TIMEOUT'] = '0'
        env['JUPYTER_PLATFORM_DIRS'] = '1'
        env['PYTHONWARNINGS'] = 'ignore'
        env['MPLBACKEND'] = 'Agg'

        self._km = KernelManager()
        self._km.shell_port = ports[0]
        self._km.iopub_port = ports[1]
        self._km.stdin_port = ports[2]
        self._km.hb_port = ports[3]
        self._km.control_port = ports[4]

        self._km.start_kernel(env=env, extra_arguments=['--Application.log_level=CRITICAL'])

        self._client = self._km.blocking_client()
        self._client.start_channels()
        self._client.wait_for_ready(timeout=self._default_timeout)
        self._owns_kernel = True

        # self.execute(
        #     'import math\n'
        #     'import sympy\n'
        #     'import itertools\n'
        #     'import collections\n'
        #     'import numpy as np\n'
        #     'import mpmath\n'
        #     'mpmath.mp.dps = 64\n'
        #     'from fractions import Fraction\n'

        # )
        self.execute(
            'import math, cmath, decimal, fractions, itertools, functools, collections, random, sys\n'
            'from fractions import Fraction\n'
            'from decimal import Decimal\n'
            'from collections import Counter, defaultdict, deque\n'
            'import numpy as np\n'
            'import mpmath\n'
            'mpmath.mp.dps = 64\n'
            'import sympy\n'
            'from sympy import *\n'
            'from sympy.ntheory import *\n'
            'from sympy.ntheory.modular import crt\n'
            'decimal.getcontext().prec = 50\n'
            'sys.setrecursionlimit(10000)\n'
        )


    def _format_error(self, traceback: list[str]) -> str:

        clean_lines = []

        for frame in traceback:
            clean_frame = re.sub(r'\x1b\[[0-9;]*m', '', frame)

            if 'File "' in clean_frame and 'ipython-input' not in clean_frame:
                continue

            clean_lines.append(clean_frame)

        return ''.join(clean_lines)

    def execute(self, code: str, timeout: float | None = None) -> str:

        client = self._client
        effective_timeout = timeout or self._default_timeout

        msg_id = client.execute(
            code,
            store_history=True,
            allow_stdin=False,
            stop_on_error=False
        )

        stdout_parts = []
        stderr_parts = []

        start_time = time.time()

        while True:
            elapsed = time.time() - start_time

            if elapsed > effective_timeout:
                self._km.interrupt_kernel()

                return f'[ERROR] Execution timed out after {effective_timeout} seconds'

            try:
                msg = client.get_iopub_msg(timeout=1.0)

            except queue.Empty:
                continue

            if msg.get('parent_header', {}).get('msg_id') != msg_id:
                continue

            msg_type = msg.get('msg_type')
            content = msg.get('content', {})

            if msg_type == 'stream':
                text = content.get('text', '')

                if content.get('name') == 'stdout':
                    stdout_parts.append(text)

                else:
                    stderr_parts.append(text)

            elif msg_type == 'error':
                traceback_list = content.get('traceback', [])

                stderr_parts.append(self._format_error(traceback_list))

            elif msg_type in {'execute_result', 'display_data'}:
                data = content.get('data', {})
                text = data.get('text/plain')

                if text:
                    stdout_parts.append(text if text.endswith('\n') else f'{text}\n')

            elif msg_type == 'status':
                if content.get('execution_state') == 'idle':
                    break

        stdout = ''.join(stdout_parts)
        stderr = ''.join(stderr_parts)

        if stderr:
            return f'{stdout.rstrip()}\n{stderr}' if stdout else stderr

        return stdout if stdout.strip() else '[WARN] No output. Use print() to see results.'

    def close(self):

        with contextlib.suppress(Exception):
            if self._client:
                self._client.stop_channels()

        if self._owns_kernel and self._km is not None:
            with contextlib.suppress(Exception):
                self._km.shutdown_kernel(now=True)

            with contextlib.suppress(Exception):
                self._km.cleanup_resources()

    def reset(self):

        self.execute('%reset -f')
        self.execute('import gc; gc.collect()')

        self.execute(
            'import math\n'
            'import sympy\n'
            'import itertools\n'
            'import collections\n'
            'import numpy as np\n'
            'import mpmath\n'
            'mpmath.mp.dps = 64\n'

        )

    def __del__(self):

        self.close()

In [None]:
class AIMO3Tool:
    """Expose a Jupyter sandbox to the model as the Harmony ``python`` tool.

    The tool either borrows a sandbox supplied by the caller (never closed
    here) or lazily creates and owns one on first use.
    """

    def __init__(self, local_jupyter_timeout: float, tool_prompt: str, sandbox=None):
        """Store settings; no kernel is started here.

        Args:
            local_jupyter_timeout: Timeout used if a sandbox has to be created.
            tool_prompt: Description advertised for the tool namespace.
            sandbox: Optional existing sandbox to run code in.
        """

        self._local_jupyter_timeout = local_jupyter_timeout
        self._tool_prompt = tool_prompt
        self._jupyter_session = sandbox

        # A sandbox handed in by the caller stays the caller's to close.
        self._owns_session = sandbox is None

        self._execution_lock = threading.Lock()
        self._init_lock = threading.Lock()

    def _ensure_session(self):
        """Create the sandbox on first use; the lock prevents duplicate creation."""

        if self._jupyter_session is None:
            with self._init_lock:
                if self._jupyter_session is None:
                    self._jupyter_session = AIMO3Sandbox(timeout=self._local_jupyter_timeout)

    def _ensure_last_print(self, code: str) -> str:
        """Wrap a trailing bare expression in ``print(...)``.

        Only the final top-level statement is considered, and only when it
        is an expression statement that is not already a ``print`` call.
        Assignments, imports, compound statements, and code that does not
        parse as plain Python (IPython magics, shell escapes, broken code)
        are returned unchanged. Anything after the wrapped expression on
        its last line (trailing comment or semicolon) is dropped.
        """

        import ast  # local import keeps this block self-contained

        try:
            tree = ast.parse(code)

        except (SyntaxError, ValueError):
            return code

        if not tree.body:
            return code

        last_stmt = tree.body[-1]

        if not isinstance(last_stmt, ast.Expr):
            return code

        value = last_stmt.value

        if (
            isinstance(value, ast.Call)
            and isinstance(value.func, ast.Name)
            and value.func.id == 'print'
        ):
            return code

        expression = ast.get_source_segment(code, last_stmt)

        if expression is None:
            return code

        lines = code.split('\n')

        if last_stmt.lineno > len(lines):
            return code

        # col_offset is a UTF-8 byte offset, so slice on the encoded line.
        first_line = lines[last_stmt.lineno - 1].encode('utf-8')
        head = ''.join(f'{line}\n' for line in lines[:last_stmt.lineno - 1])
        head += first_line[:last_stmt.col_offset].decode('utf-8', errors='ignore')

        # Guard against line-separator mismatches between ast and split().
        if not code.startswith(head + expression):
            return code

        wrapped = f'{head}print({expression})'

        try:
            ast.parse(wrapped)

        except (SyntaxError, ValueError):
            return code

        return wrapped

    @property
    def instruction(self) -> str:
        """Description text advertised for the tool."""

        return self._tool_prompt

    @property
    def tool_config(self) -> ToolNamespaceConfig:
        """Tool namespace named ``python`` with the instruction as its description."""

        return ToolNamespaceConfig(
            name='python',
            description=self.instruction,
            tools=[]
        )

    def _make_response(self, output: str, channel: str | None = None) -> Message:
        """Build the tool-authored reply addressed to the assistant."""

        content = TextContent(text=output)
        author = Author(role=Role.TOOL, name='python')
        message = Message(author=author, content=[content]).with_recipient('assistant')

        if channel:
            message = message.with_channel(channel)

        return message

    def process_sync_plus(self, message: Message) -> list[Message]:
        """Execute the code carried by ``message`` and return one tool reply.

        The script is taken from ``message.content[0].text``. Output longer
        than 680 characters is cut and a note with the full length appended.
        """

        self._ensure_session()
        raw_script = message.content[0].text
        final_script = self._ensure_last_print(raw_script)
        MAX_OUTPUT_LEN = 680

        # One execution at a time per tool: the kernel is stateful.
        with self._execution_lock:
            try:
                output = self._jupyter_session.execute(final_script)
                # Output truncation keeps the model context small.
                if len(output) > MAX_OUTPUT_LEN:
                    output = output[:MAX_OUTPUT_LEN] + f"\n... [Output truncated. Total length: {len(output)} chars]"
            except TimeoutError as exc:
                output = f'[ERROR] {exc}'

        return [self._make_response(output, channel=message.channel)]

    def close(self):
        """Release the sandbox reference; close it only if this tool created it."""

        if self._jupyter_session is not None:
            if self._owns_session:
                self._jupyter_session.close()

            self._jupyter_session = None

    def __del__(self):

        self.close()

In [None]:
def get_error_guidance(error_text: str) -> str:
    """Return a short hint for a common Python error, or ``""`` if none applies.

    Matching is a case-insensitive substring search; rules are tried in the
    order listed and the first hit wins.

    Args:
        error_text: Error output or traceback text to classify.

    Returns:
        The guidance sentence for the first matching rule, else ``""``.
    """
    error_lower = error_text.lower()

    # (keywords, guidance). 'timed out' is listed next to 'timeout' because
    # the sandbox reports "Execution timed out after N seconds".
    guidance_rules = (
        (('timeout', 'timed out'), "Computation timed out. Try a more efficient algorithm or reduce search space."),
        (('memory',), "Memory limit exceeded. Use generators, process in batches, or reduce data size."),
        (('killed',), "Process killed (likely memory). Reduce memory usage."),
        (('overflow',), "Numerical overflow. Use modular arithmetic or sympy.Integer for big numbers."),
        (('zerodivision',), "Division by zero. Add a check before dividing."),
        (('index',), "Index out of range. Check array bounds before accessing."),
        (('keyerror',), "Key not found. Verify dictionary keys exist before access."),
        (('typeerror',), "Type mismatch. Check operation validity for data types used."),
        (('syntax',), "Syntax error. Check for missing colons, parentheses, or indentation."),
        (('nameerror',), "Undefined variable. Ensure all variables are defined before use."),
        (('recursion',), "Recursion limit. Use iteration instead or sys.setrecursionlimit()."),
    )

    for keywords, guidance in guidance_rules:
        if any(keyword in error_lower for keyword in keywords):
            return guidance

    return ""

In [None]:
class AIMO3Solver:

    def __init__(self, cfg, port: int = 8000):
        """Bring up the full solving stack: model server, API client and kernel pool.

        The steps below are order-dependent: weights are read once, the vLLM
        server is launched, the client is created, the server is polled until
        it answers, and only then are the Jupyter kernels started.

        Args:
            cfg: Configuration object (attributes as defined on ``CFG``).
            port: Port the vLLM server listens on and the client connects to.

        Raises:
            RuntimeError: Propagated from ``_wait_for_server`` if the server
                exits or does not become ready.
        """

        self.cfg = cfg
        self.port = port
        self.base_url = f'http://0.0.0.0:{port}/v1'
        self.api_key = 'sk-local'
        self.template = AIMO3Template()
        self.encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
        self.stop_token_ids = self.encoding.stop_tokens_for_assistant_actions()

        # Read every weight file once before the server starts loading them.
        self._preload_model_weights()

        self.server_process = self._start_server()

        # The client must exist before _wait_for_server, which polls through it.
        self.client = OpenAI(
            base_url=self.base_url,
            api_key=self.api_key,
            timeout=self.cfg.session_timeout
        )

        self._wait_for_server()
        self._initialize_kernels()

        # NOTE(review): the clock starts after server/kernel start-up, so that
        # time is not counted -- confirm this is intended by the budget logic.
        self.notebook_start_time = time.time()
        self.problems_remaining = 50

    def _get_attempt_config(self, attempt_index: int, problem: str) -> dict:
        """Get configuration for a specific attempt."""

        # Rotate through prompt ensembles
        num_system = len(self.cfg.system_prompts)
        num_dev = len(self.cfg.developer_prompts)
        num_pref = len(self.cfg.preference_prompts)

        system_prompt = self.cfg.system_prompts[attempt_index % num_system]
        pref_prompt = self.cfg.preference_prompts[attempt_index % num_pref]
        developer_prompt = self.cfg.developer_prompts[attempt_index % num_dev]

        # Temperature variation: first half lower, second half higher
        base_temp = self.cfg.base_temperature
        half_attempts = self.cfg.attempts // 2

        if attempt_index < half_attempts:
            # Lower temperature for more deterministic reasoning
            temp_offset = -0.25 * (attempt_index % 3)
            temperature = max(self.cfg.temp_low, base_temp + temp_offset)
        else:
            # Higher temperature for more creative approaches
            temp_offset = 0.25 * ((attempt_index - half_attempts) % 3)
            temperature = min(self.cfg.temp_high, base_temp + temp_offset)

        # Better seed variation (avoid hash collisions)
        # seed = self.cfg.seed + attempt_index * 1000 + hash(problem[:50]) % 1000
        seed = int(math.pow(self.cfg.seed + attempt_index, 2))

        return {
            'system_prompt': system_prompt,
            'developer_prompt': developer_prompt,
            'preference_prompt': pref_prompt,
            'temperature': temperature,
            'seed': seed,
        }


    def _preload_model_weights(self) -> None:

        print(f'Loading model weights from {self.cfg.model_path} into OS Page Cache...')
        start_time = time.time()

        files_to_load = []
        total_size = 0

        for root, _, files in os.walk(self.cfg.model_path):
            for file_name in files:
                file_path = os.path.join(root, file_name)

                if os.path.isfile(file_path):
                    files_to_load.append(file_path)
                    total_size += os.path.getsize(file_path)

        def _read_file(path: str) -> None:

            with open(path, 'rb') as file_object:
                while file_object.read(1024 * 1024 * 1024):
                    pass

        with ThreadPoolExecutor(max_workers=self.cfg.workers) as executor:
            list(executor.map(_read_file, files_to_load))

        elapsed = time.time() - start_time
        print(f'Processed {len(files_to_load)} files ({total_size / 1e9:.2f} GB) in {elapsed:.2f} seconds.\n')

    def _start_server(self) -> subprocess.Popen:
        """Launch the vLLM OpenAI-compatible server as a detached subprocess.

        Server stdout and stderr go to ``vllm_server.log`` in the working
        directory; the open handle is kept on ``self.log_file``.

        Returns:
            The ``Popen`` handle of the server process.
        """

        cfg = self.cfg

        # (flag, value) pairs, emitted in this order on the command line.
        valued_options = (
            ('--seed', str(cfg.seed)),
            ('--model', cfg.model_path),
            ('--served-model-name', cfg.served_model_name),
            ('--tensor-parallel-size', '1'),
            ('--max-num-seqs', str(cfg.batch_size)),
            ('--gpu-memory-utilization', str(cfg.gpu_memory_utilization)),
            ('--host', '0.0.0.0'),
            ('--port', str(self.port)),
            ('--dtype', cfg.dtype),
            ('--kv-cache-dtype', cfg.kv_cache_dtype),
            ('--max-model-len', str(cfg.context_tokens)),
            ('--stream-interval', str(cfg.stream_interval)),
        )

        command = [sys.executable, '-m', 'vllm.entrypoints.openai.api_server']

        for flag, value in valued_options:
            command.extend((flag, value))

        command.append('--enable-prefix-caching')

        self.log_file = open('vllm_server.log', 'w')

        # start_new_session=True runs the server in its own session.
        server = subprocess.Popen(
            command,
            stdout=self.log_file,
            stderr=subprocess.STDOUT,
            start_new_session=True
        )

        return server

    def _wait_for_server(self):

        print('Waiting for vLLM server...')
        start_time = time.time()

        for _ in range(self.cfg.server_timeout):
            return_code = self.server_process.poll()

            if return_code is not None:
                self.log_file.flush()

                with open('vllm_server.log', 'r') as log_file:
                    logs = log_file.read()

                raise RuntimeError(f'Server died with code {return_code}. Full logs:\n{logs}\n')

            try:
                self.client.models.list()
                elapsed = time.time() - start_time
                print(f'Server is ready (took {elapsed:.2f} seconds).\n')

                return

            except Exception:
                time.sleep(1)

        raise RuntimeError('Server failed to start (timeout).\n')

    def _initialize_kernels(self) -> None:
        """Start ``cfg.workers`` sandboxes in parallel and pool them in ``self.sandbox_pool``."""

        worker_count = self.cfg.workers

        print(f'Initializing {worker_count} persistent Jupyter kernels...')
        began = time.time()

        self.sandbox_pool = queue.Queue()

        def _spawn():
            # Each sandbox starts its own kernel with the configured timeout.
            return AIMO3Sandbox(timeout=self.cfg.jupyter_timeout)

        with ThreadPoolExecutor(max_workers=worker_count) as pool:
            pending = [pool.submit(_spawn) for _ in range(worker_count)]

            # Pool each sandbox as soon as it is ready; result() re-raises
            # any start-up failure.
            for finished in as_completed(pending):
                self.sandbox_pool.put(finished.result())

        duration = time.time() - began
        print(f'Kernels initialized in {duration:.2f} seconds.\n')

    def _scan_for_answer(self, text: str) -> int | None:

        pattern = r'\\boxed\s*\{\s*([0-9,]+)\s*\}'
        matches = re.findall(pattern, text)

        if matches:
            try:
                clean_value = matches[-1].replace(',', '')
                value = int(clean_value)

                if 0 <= value <= 99999:
                    return value

            except ValueError:
                pass


        pattern = r'final\s+answer\s+is\s*([0-9,]+)'
        matches = re.findall(pattern, text, re.IGNORECASE)

        if matches:
            try:
                clean_value = matches[-1].replace(',', '')
                value = int(clean_value)

                if 0 <= value <= 99999:
                    return value

            except ValueError:
                pass

        return None


    def _scan_for_probability(self, text: str) -> float:
        """
        Scans for the probability format \probability{0.95}
        Returns 0.0 if not found or invalid.
        """
        if not text:
            return 0.0

        # Pattern looks for \probability{...}
        # It captures content inside the curly braces
        pattern = r"\\probability\{(.*?)\}"
        matches = re.findall(pattern, text)

        if not matches:
            return 0.0

        # We take the LAST probability output, just like we do for boxed answers
        last_match = matches[-1].strip()

        try:
            prob = float(last_match)
            # Clamp between 0 and 1 just in case
            return max(0.0, min(1.0, prob))
        except ValueError:
            return 0.0

    def _compute_mean_entropy(self, logprobs_buffer: list) -> float:

        if not logprobs_buffer:
            return float('inf')

        total_entropy = 0.0
        token_count = 0

        for top_logprobs_dict in logprobs_buffer:

            if not isinstance(top_logprobs_dict, dict):
                continue

            if not top_logprobs_dict:
                continue

            token_entropy = 0.0

            for token_str, log_prob in top_logprobs_dict.items():
                prob = math.exp(log_prob)

                if prob > 0:
                    token_entropy -= prob * math.log2(prob)

            total_entropy += token_entropy
            token_count += 1

        if token_count == 0:
            return float('inf')

        return total_entropy / token_count


    def _process_attempt_v2(
        self,
        problem: str,
        system_prompt: str,
        attempt_index: int,
        stop_event: threading.Event,
        deadline: float
    ) -> dict:
        """Process a single attempt with detailed logging for analysis."""

        attempt_start = time.time()

        if stop_event.is_set() or time.time() > deadline:
            return {
                'Attempt': attempt_index + 1,
                'Answer': None,
                'Python Calls': 0,
                'Python Errors': 0,
                'Response Length': 0,
                'Turns': 0,
                'Time': 0,
                'Exit Reason': 'skipped',
            }

        # Get attempt-specific configuration
        config = self._get_attempt_config(attempt_index, problem)

        print(f"\n{'='*60}")
        print(f"üöÄ ATTEMPT {attempt_index + 1} STARTED")
        print(f"{'='*60}")
        print(f"  Temperature: {config['temperature']:.2f}")
        print(f"  Seed: {config['seed']}")
        print(f"  System prompt: {config['system_prompt'][:80]}...")
        print(f"  Preference: {config['preference_prompt'][:60]}...")
        print(f"{'='*60}\n")

        local_tool = None
        sandbox = None
        python_calls = 0
        python_errors = 0
        total_tokens = 0
        final_answer = None
        final_prob = 0.0
        consecutive_errors = 0
        max_consecutive_errors = 3
        turn_count = 0
        exit_reason = 'unknown'
        search_text = ''

        # Track per-turn data for analysis
        turn_log = []

        try:
            sandbox = self.sandbox_pool.get(timeout=self.cfg.sandbox_timeout)

            local_tool = AIMO3Tool(
                local_jupyter_timeout=self.cfg.jupyter_timeout,
                tool_prompt=self.cfg.tool_prompt,
                sandbox=sandbox
            )

            # Build prompt with preference
            full_problem = f"{problem} {config['preference_prompt']}"

            encoding = self.encoding
            messages = self.template.apply_chat_template(
                config['system_prompt'],
                config['developer_prompt'],
                full_problem,
                local_tool.tool_config
            )

            conversation = Conversation.from_messages(messages)
            waiting_for_summary = False

            for turn in range(self.cfg.turns):
                turn_start = time.time()
                turn_count = turn + 1

                if stop_event.is_set():
                    exit_reason = 'early_stop_signal'
                    print(f"  [Turn {turn_count}] ‚ö° Early stop signal received")
                    break

                if time.time() > deadline:
                    exit_reason = 'timeout'
                    print(f"  [Turn {turn_count}] ‚è∞ Deadline exceeded")
                    break

                prompt_ids = encoding.render_conversation_for_completion(
                    conversation, Role.ASSISTANT
                )
                max_tokens = self.cfg.context_tokens - len(prompt_ids)


                # ------------------------------------------------------------------
                # üß† SMART CONTEXT MANAGEMENT (Summarize -> Prune)
                # ------------------------------------------------------------------

                current_tokens = len(prompt_ids)
                TOKEN_LIMIT = self.cfg.context_tokens  # Adjust to your model (e.g. 32000 for deepseek/qwen)

                # 2. Check if we just received the summary we asked for
                if waiting_for_summary:
                    print(f"  [Turn {turn}] üìâ Summary received. Compressing history...")

                    # The last message is the Model's Summary.
                    # The message before that was our "Please Summarize" request.
                    # We want to keep: System(0), Problem(1), and the Summary(-1).

                    summary_content = conversation.messages[-1].content[0].text
                    print(f" Summary content : {summary_content}")

                    # FIXED: Use TextContent + List for the bridge message
                    # state_msg = Message(
                    #     role=Role.USER,
                    #     content=[
                    #         TextContent(text=(
                    #             f"--- PREVIOUS WORK SUMMARY ---\n"
                    #             f"{summary_content}\n"
                    #             f"-----------------------------\n"
                    #             f"[SYSTEM: Memory cleared. The summary above contains all derived facts. "
                    #             f"All Python variables are still active. Proceed with the next step.]"
                    #         ))
                    #     ]
                    # )

                    state_msg = Message.from_role_and_content(
                        Role.USER,
                        f"--- PREVIOUS WORK SUMMARY ---\n"
                        f"{summary_content}\n"
                        f"-----------------------------\n"
                        f"[SYSTEM: Memory cleared. The summary above contains all derived facts. "
                        f"All Python variables are still active. Proceed with the next step.]"
                    )

                    # Rebuild History: System -> Problem -> Summary State
                    conversation.messages = [
                        conversation.messages[0],
                        conversation.messages[1],
                        state_msg
                    ]

                    waiting_for_summary = False

                    # Re-render tokens since history changed
                    prompt_ids = encoding.render_conversation_for_completion(
                        conversation, Role.ASSISTANT
                    )

                # 3. Trigger Summarization if Context is Full (and we aren't already waiting)
                elif current_tokens > (TOKEN_LIMIT * 1.85):
                    print(f"  [Turn {turn}] üßπ Context full ({current_tokens} tok). Requesting summary...")

                    # # FIXED: Use TextContent + List for the request message
                    # summary_request = Message(
                    #     role=Role.USER,
                    #     content=[
                    #         TextContent(text=(
                    #             "‚ö†Ô∏è SYSTEM ALERT: MEMORY LIMIT REACHED.\n"
                    #             "We must clear the context window.\n\n"
                    #             "Please provide a CONCISE SUMMARY of the current state.\n"
                    #             "1. List all verified variable values (e.g., n=5, k=12).\n"
                    #             "2. State the last derived mathematical fact.\n"
                    #             "3. State the immediate next step.\n"
                    #             "DO NOT output Python code. Just text."
                    #         ))
                    #     ]
                    # )

                    # FIXED: Request summary when context is full
                    summary_request = Message.from_role_and_content(
                        Role.USER,
                        "‚ö†Ô∏è SYSTEM ALERT: MEMORY LIMIT REACHED.\n"
                        "We must clear the context window.\n\n"
                        "Please provide a CONCISE SUMMARY of the current state.\n"
                        "1. List all verified variable values (e.g., n=5, k=12).\n"
                        "2. State the last derived mathematical fact.\n"
                        "3. State the immediate next step.\n"
                        "DO NOT output Python code. Just text."
                    )

                    conversation.messages.append(summary_request)
                    waiting_for_summary = True

                    # We continue the loop so the model can generate the summary in this turn
                    # The NEXT iteration will catch it in step #2 above and prune.

                # ------------------------------------------------------------------


                if max_tokens < self.cfg.buffer_tokens:
                    exit_reason = 'context_full'
                    print(f"  [Turn {turn_count}] üì¶ Context window full ({len(prompt_ids)} tokens used)")
                    break

                print(f"  [Turn {turn_count}] üîÑ Generating... (context: {len(prompt_ids)} tokens, max_new: {max_tokens})")

                stream = self.client.completions.create(
                    model=self.cfg.served_model_name,
                    temperature=config['temperature'],
                    max_tokens=max_tokens,
                    prompt=prompt_ids,
                    seed=config['seed'],
                    stream=True,
                    extra_body={
                        'min_p': self.cfg.min_p,
                        'stop_token_ids': self.stop_token_ids,
                        'return_token_ids': True
                    }
                )

                turn_tokens = 0
                try:
                    token_buffer = []
                    text_chunks = []
                    num_answers_found = 0
                    answers_found = []


                    for chunk in stream:
                        if stop_event.is_set() or time.time() > deadline:
                            break

                        new_tokens = chunk.choices[0].token_ids
                        new_text = chunk.choices[0].text

                        if new_tokens:
                            token_buffer.extend(new_tokens)
                            total_tokens += len(new_tokens)
                            turn_tokens += len(new_tokens)
                            text_chunks.append(new_text)

                        # Early answer detection during streaming
                        if '}' in new_text:
                            search_text = ''.join(text_chunks[-self.cfg.search_tokens:])
                            answer = self._scan_for_answer(search_text)

                            if answer is not None:
                                final_answer = answer
                                final_prob = self._scan_for_probability(search_text)
                                num_answers_found += 1
                                answers_found.append(final_answer)

                                print(f"  [Turn {turn_count}] üéØ Answer found in stream: {answer}, num_answers_found: {num_answers_found}")

                                if num_answers_found > 2:
                                    exit_reason = 'answer_found_streaming'
                                    print(answers_found)
                                    print(f"  [Turn {turn_count}] üéØ Answer found in stream: {answer}")
                                    break

                finally:
                    stream.close()

                turn_time = time.time() - turn_start
                full_response = ''.join(text_chunks)

                # Log turn summary
                turn_info = {
                    'turn': turn_count,
                    'tokens': turn_tokens,
                    'time': turn_time,
                    'response_preview': full_response[:200] if full_response else '',
                }
                turn_log.append(turn_info)

                if final_answer is not None:
                    break

                if not token_buffer:
                    exit_reason = 'empty_response'
                    print(f"  [Turn {turn_count}] ‚ö†Ô∏è Empty response from model")
                    break

                try:
                    new_messages = encoding.parse_messages_from_completion_tokens(
                        token_buffer, Role.ASSISTANT
                    )
                except Exception:
                    # If parsing fails (e.g. due to interruption/EOS),
                    # we just ignore this incomplete turn and stop.
                    print(f"  [Turn {turn}] ‚ö†Ô∏è Parsing failed (incomplete output ignored)")
                    break

                conversation.messages.extend(new_messages)
                # Safety check: ensure we actually got messages back
                if not new_messages:
                    break

                last_message = new_messages[-1]

                # Log what the model is doing
                if last_message.channel == 'final':
                    answer_text = last_message.content[0].text
                    final_answer = self._scan_for_answer(answer_text)
                    final_prob = self._scan_for_probability(search_text)
                    exit_reason = 'final_channel'

                    print(f"  [Turn {turn_count}] üìù Final response ({turn_tokens} tokens, {turn_time:.1f}s)")
                    print(f"  [Turn {turn_count}] üí¨ Content: {answer_text[:300]}{'...' if len(answer_text) > 300 else ''}")

                    if final_answer is not None:
                        print(f"  [Turn {turn_count}] üéØ Answer extracted: {final_answer}")
                    else:
                        print(f"  [Turn {turn_count}] ‚ùå No valid answer found in response")
                    break

                # Handle Python tool calls
                if last_message.recipient == 'python':
                    python_calls += 1
                    code_content = last_message.content[0].text if last_message.content else "N/A"

                    print(f"  [Turn {turn_count}] üêç Python call #{python_calls} ({turn_tokens} tokens, {turn_time:.1f}s)")
                    print(f"  [Turn {turn_count}] üìÑ Code:\n{'‚îÄ'*40}")
                    # Indent code for readability
                    for line in code_content.split('\n')[:15]:  # First 15 lines
                        print(f"    {line}")
                    if code_content.count('\n') > 15:
                        print(f"    ... ({code_content.count(chr(10)) - 15} more lines)")
                    print(f"{'‚îÄ'*40}")

                    tool_responses = local_tool.process_sync_plus(last_message)
                    response_text = tool_responses[0].content[0].text

                    is_error = (
                        response_text.startswith('[ERROR]') or
                        'Traceback' in response_text or
                        'Error:' in response_text or
                        'Exception' in response_text
                    )

                    if is_error:
                        python_errors += 1
                        consecutive_errors += 1

                        print(f"  [Turn {turn_count}] ‚ùå Error #{python_errors} (consecutive: {consecutive_errors})")
                        print(f"  [Turn {turn_count}] üìÑ Error output:\n{'‚îÄ'*40}")
                        for line in response_text.split('\n')[:10]:
                            print(f"    {line}")
                        print(f"{'‚îÄ'*40}")

                        # Add error guidance
                        if consecutive_errors <= max_consecutive_errors:
                            guidance = get_error_guidance(response_text)
                            if guidance:
                                print(f"  [Turn {turn_count}] üí° Guidance: {guidance}")
                                enhanced_response = f"{response_text}\n\nHint: {guidance}"
                                tool_responses[0].content[0].text = enhanced_response

                        if (consecutive_errors >= max_consecutive_errors) and (consecutive_errors < max_consecutive_errors + 2):
                            enhanced_response = f"{response_text}\n\nHint: Take it easy. Reset everything and approach the problem from a fresh perspective."
                            tool_responses[0].content[0].text = enhanced_response

                            # force_msg = Message(
                            #     role=Role.USER,  # Use the Enum if possible, or try 'user'
                            #     content=[
                            #         TextContent(text=(
                            #             "STOP. You have triggered 3 consecutive Python errors. "
                            #             "Take it easy. Reset everything and approach the problem from a fresh perspective."
                            #         ))
                            #     ]
                            # )

                            # FIXED: Suggest fresh approach after consecutive errors
                            force_msg = Message.from_role_and_content(
                                Role.USER,
                                "STOP. You have triggered 3 consecutive Python errors. "
                                "Take it easy. Reset everything and approach the problem from a fresh perspective."
                            )

                            conversation.messages.append(force_msg)
                            print(f"  [Turn {turn_count}] üîÑ Suggesting fresh approach")

                        if consecutive_errors >= max_consecutive_errors + 2:
                            exit_reason = 'too_many_errors'
                            print(f"  [Turn {turn_count}] üõë Too many consecutive errors, stopping")
                            break
                    else:
                        consecutive_errors = 0

                        # Show successful output (truncated)
                        print(f"  [Turn {turn_count}] ‚úÖ Success! Output:")
                        print(f"{'‚îÄ'*40}")
                        output_lines = response_text.split('\n')
                        for line in output_lines[:8]:
                            print(f"    {line[:100]}{'...' if len(line) > 100 else ''}")
                        if len(output_lines) > 8:
                            print(f"    ... ({len(output_lines) - 8} more lines)")
                        print(f"{'‚îÄ'*40}")

                        # Check if output contains a potential answer
                        potential_answer = self._scan_for_answer(response_text)
                        if potential_answer is not None:
                            print(f"  [Turn {turn_count}] üëÄ Potential answer in output: {potential_answer}")

                    conversation.messages.extend(tool_responses)
                else:
                    # Model did something else (reasoning, etc.)
                    print(f"  [Turn {turn_count}] üí≠ Model response ({turn_tokens} tokens, {turn_time:.1f}s)")
                    print(f"  [Turn {turn_count}] Channel: {last_message.channel}, Recipient: {last_message.recipient}")
                    if last_message.content:
                        preview = last_message.content[0].text[:200] if hasattr(last_message.content[0], 'text') else str(last_message.content[0])[:200]
                        print(f"  [Turn {turn_count}] Preview: {preview}...")

            else:
                # Loop completed without break
                exit_reason = 'max_turns'
                print(f"  ‚ö†Ô∏è Reached max turns ({self.cfg.turns})")

        except Exception as exc:
            import traceback
            python_errors += 1
            exit_reason = f'exception:{type(exc).__name__}'

            print(f"\n{'!'*60}")
            print(f"‚ùå EXCEPTION in attempt {attempt_index + 1}")
            print(f"{'!'*60}")
            print(f"Type: {type(exc).__name__}")
            print(f"Message: {exc}")
            print(f"Traceback:")
            traceback.print_exc()
            print(f"{'!'*60}\n")

        finally:
            if local_tool is not None:
                local_tool.close()

            if sandbox is not None:
                sandbox.reset()
                self.sandbox_pool.put(sandbox)

        attempt_time = time.time() - attempt_start

        # Print attempt summary
        print(f"\n{'‚îÄ'*60}")
        print(f"üìä ATTEMPT {attempt_index + 1} SUMMARY")
        print(f"{'‚îÄ'*60}")
        print(f"  Answer: {final_answer if final_answer is not None else 'None'}")
        print(f"  Prob: {final_prob if final_prob is not None else 'None'}")
        print(f"  Exit Reason: {exit_reason}")
        print(f"  Turns: {turn_count}")
        print(f"  Total Tokens: {total_tokens}")
        print(f"  Python Calls: {python_calls}")
        print(f"  Python Errors: {python_errors}")
        print(f"  Time: {attempt_time:.1f}s")
        print(f"  Tokens/sec: {total_tokens/attempt_time:.1f}" if attempt_time > 0 else "  Tokens/sec: N/A")
        print(f"{'‚îÄ'*60}\n")

        return {
            'Attempt': attempt_index + 1,
            'Answer': final_answer,
            'Prob' : final_prob,
            'Python Calls': python_calls,
            'Python Errors': python_errors,
            'Response Length': total_tokens,
            'Turns': turn_count,
            'Time': round(attempt_time, 1),
            'Exit Reason': exit_reason,
            'Temperature': config['temperature'],
            'System prompt': config["system_prompt"],
            'Dev prompt': config["developer_prompt"],
            'Pref prompt': config["preference_prompt"],
            'Final Text': search_text,

        }



    def _process_attempt(
        self,
        problem: str,
        system_prompt: str,
        attempt_index: int,
        stop_event: threading.Event,
        deadline: float
    ) -> dict:
        """Run one multi-turn solving attempt and return its result record.

        Each turn streams a completion from the model. If the model emits a
        Python tool call, the call is executed and its output is appended to
        the conversation before the next turn. The attempt ends when an answer
        is detected, the model writes on the ``final`` channel, the context
        or turn budget is exhausted, the stop event/deadline fires, or five
        tool calls in a row report an error.

        Args:
            problem: Problem statement. The attempt's preference prompt is
                appended to it before the conversation is built.
            system_prompt: Kept for signature compatibility; it is not read.
                The prompts actually used come from ``_get_attempt_config``.
            attempt_index: Zero-based attempt number (reported one-based).
            stop_event: Set by the caller to abort the attempt cooperatively.
            deadline: Absolute ``time.time()`` value after which to stop.

        Returns:
            A dict describing the attempt. If the stop event is already set
            or the deadline has already passed, a reduced record is returned
            (no prompt/temperature fields, ``'Confidence': 0.0``).
        """

        if stop_event.is_set() or time.time() > deadline:
            return {
                'Attempt': attempt_index + 1,
                'Answer': None,
                'Python Calls': 0,
                'Python Errors': 0,
                'Response Length': 0,
                'Confidence': 0.0,
            }

        # Get attempt-specific configuration
        config = self._get_attempt_config(attempt_index, problem)

        local_tool = None
        sandbox = None
        python_calls = 0
        python_errors = 0
        total_tokens = 0
        final_answer = None
        consecutive_errors = 0
        max_consecutive_errors = 3
        search_text = ''

        try:
            sandbox = self.sandbox_pool.get(timeout=self.cfg.sandbox_timeout)

            local_tool = AIMO3Tool(
                local_jupyter_timeout=self.cfg.jupyter_timeout,
                tool_prompt=self.cfg.tool_prompt,
                sandbox=sandbox
            )

            # Build prompt with preference
            full_problem = f"{problem} {config['preference_prompt']}"

            encoding = self.encoding
            messages = self.template.apply_chat_template(
                config['system_prompt'],
                config['developer_prompt'],
                full_problem,
                local_tool.tool_config
            )

            conversation = Conversation.from_messages(messages)

            for turn in range(self.cfg.turns):
                if stop_event.is_set() or time.time() > deadline:
                    break

                prompt_ids = encoding.render_conversation_for_completion(
                    conversation, Role.ASSISTANT
                )
                max_tokens = self.cfg.context_tokens - len(prompt_ids)

                # Stop once the remaining context is smaller than the reserve.
                if max_tokens < self.cfg.buffer_tokens:
                    break

                stream = self.client.completions.create(
                    model=self.cfg.served_model_name,
                    temperature=config['temperature'],  # Use varied temperature
                    max_tokens=max_tokens,
                    prompt=prompt_ids,
                    seed=config['seed'],
                    stream=True,
                    extra_body={
                        'min_p': self.cfg.min_p,
                        'stop_token_ids': self.stop_token_ids,
                        'return_token_ids': True
                    }
                )

                try:
                    token_buffer = []
                    text_chunks = []

                    for chunk in stream:
                        if stop_event.is_set() or time.time() > deadline:
                            break

                        new_tokens = chunk.choices[0].token_ids
                        new_text = chunk.choices[0].text

                        if new_tokens:
                            token_buffer.extend(new_tokens)
                            total_tokens += len(new_tokens)
                            text_chunks.append(new_text)

                        # Early answer detection during streaming: only rescan
                        # the tail of the output when a closing brace arrives.
                        if '}' in new_text:
                            search_text = ''.join(text_chunks[-self.cfg.search_tokens:])
                            answer = self._scan_for_answer(search_text)

                            if answer is not None:
                                final_answer = answer
                                break

                finally:
                    # Always release the HTTP stream, even on early exit.
                    stream.close()

                if final_answer is not None:
                    break

                if not token_buffer:
                    break

                try:
                    new_messages = encoding.parse_messages_from_completion_tokens(
                        token_buffer, Role.ASSISTANT
                    )
                except Exception:
                    # A stream cut short (deadline, stop event, EOS) can leave
                    # tokens that do not parse. That is not a Python tool
                    # failure, so end the attempt here instead of letting the
                    # outer handler count it as one.
                    print(f"  [Turn {turn}] Parsing failed (incomplete output ignored)")
                    break

                # Nothing parsed means there is no last message to inspect.
                if not new_messages:
                    break

                conversation.messages.extend(new_messages)
                last_message = new_messages[-1]

                if last_message.channel == 'final':
                    answer_text = last_message.content[0].text
                    final_answer = self._scan_for_answer(answer_text)
                    break

                # Handle Python tool calls with error recovery
                if last_message.recipient == 'python':
                    python_calls += 1
                    tool_responses = local_tool.process_sync_plus(last_message)
                    response_text = tool_responses[0].content[0].text

                    is_error = (
                        response_text.startswith('[ERROR]') or
                        'Traceback' in response_text or
                        'Error:' in response_text or
                        'Exception' in response_text
                    )

                    if is_error:
                        python_errors += 1
                        consecutive_errors += 1

                        # Add error guidance if we haven't hit max consecutive errors
                        if consecutive_errors <= max_consecutive_errors:
                            guidance = get_error_guidance(response_text)
                            if guidance:
                                enhanced_response = f"{response_text}\n\nHint: {guidance}"
                                tool_responses[0].content[0].text = enhanced_response

                        # From the max-th consecutive error on, replace any
                        # specific hint with a generic "start over" nudge.
                        if (consecutive_errors >= max_consecutive_errors) and (consecutive_errors < max_consecutive_errors + 2):
                            enhanced_response = f"{response_text}\n\nHint: Take it easy. Reset everything and approach the problem from a fresh perspective."
                            tool_responses[0].content[0].text = enhanced_response

                        # Two more failures after the nudge: give up.
                        if consecutive_errors >= max_consecutive_errors + 2:
                            break

                    else:
                        consecutive_errors = 0  # Reset on success

                    conversation.messages.extend(tool_responses)

        except Exception as exc:
            import traceback
            python_errors += 1

            print(f"\n{'='*50}")
            print(f"ERROR in attempt {attempt_index + 1}")
            print(f"{'='*50}")
            print(f"Exception type: {type(exc).__name__}")
            print(f"Exception message: {exc}")
            print(f"Traceback:")
            traceback.print_exc()
            print(f"{'='*50}")
            print(f"State at failure:")
            print(f"  - Python calls: {python_calls}")
            print(f"  - Python errors: {python_errors}")
            print(f"  - Total tokens: {total_tokens}")
            print(f"  - Final answer: {final_answer}")
            print(f"  - Time remaining: {deadline - time.time():.1f}s")
            print(f"{'='*50}\n")

        finally:
            if local_tool is not None:
                local_tool.close()

            # Hand the sandbox back so other attempts can reuse it.
            if sandbox is not None:
                sandbox.reset()
                self.sandbox_pool.put(sandbox)

        ## Config info
        return {
            'Attempt': attempt_index + 1,
            'Response Length': total_tokens,
            'Python Calls': python_calls,
            'Python Errors': python_errors,
            'Answer': final_answer,
            'Temperature': config["temperature"],
            'System prompt': config["system_prompt"],
            'Dev prompt': config["developer_prompt"],
            'Pref prompt': config["preference_prompt"],
            'Final Text': search_text,

        }

    def _select_answer(self, detailed_results: list) -> int:
        """Pick the final answer by majority vote over the attempt records.

        Each record with a non-``None`` ``'Answer'`` contributes one vote to
        that answer and adds its ``'Python Calls'`` to the answer's call
        total. Answers are ranked by votes, then by total Python calls, both
        descending. The ranking table is displayed and the winner is printed.

        Args:
            detailed_results: Attempt records; each must provide ``'Answer'``
                and ``'Python Calls'``.

        Returns:
            The top-ranked answer, or ``0`` when no record carries an answer
            (including an empty ``detailed_results``).
        """

        stats = defaultdict(lambda: {'votes': 0, 'calls': 0})

        for result in detailed_results:
            answer = result['Answer']

            if answer is not None:
                stats[answer]['votes'] += 1
                stats[answer]['calls'] += result['Python Calls']

        # With nothing to rank there is no winner to index; fall back to the
        # same default answer solve_problem_v2 uses for "no valid answers".
        if not stats:
            print('No valid answers to vote on. Returning 0.\n')
            return 0

        sorted_stats = sorted(
            stats.items(),
            key=lambda item: (item[1]['votes'], item[1]['calls']),
            reverse=True
        )

        vote_data = [
            (answer, data['votes'], data['calls'])
            for answer, data in sorted_stats
        ]

        vote_dataframe = pd.DataFrame(vote_data, columns=['Answer', 'Votes', 'Calls'])
        display(vote_dataframe)

        final_answer, winner = sorted_stats[0]
        final_votes = winner['votes']
        final_calls = winner['calls']

        print(f'\nFinal Result: {final_answer} | Votes: {final_votes} | Calls: {final_calls}\n')

        return final_answer

    def solve_problem_v2(self, problem: str) -> int:
        """Solve one problem with parallel attempts, early stopping and a printed report.

        Steps, in order:

        1. Derive a per-problem time budget from the notebook-wide limit,
           keeping ``base_problem_timeout`` in reserve for every other
           remaining problem.
        2. Submit ``self.cfg.attempts`` calls of ``_process_attempt_v2`` to a
           thread pool of ``self.cfg.workers`` workers.
        3. Collect results as they finish and stop early once one answer has
           a clear majority.
        4. Print summary statistics, append the per-attempt records to
           ``all_detailed_results.csv`` and display the results table.
        5. Return the voted answer.

        Args:
            problem: The problem statement.

        Returns:
            The answer chosen by ``_select_answer``, or ``0`` when no attempt
            produced an answer.

        Side effects:
            Decrements ``self.problems_remaining`` (never below zero), appends
            to ``all_detailed_results.csv`` in the working directory, and
            prints/displays a report.
        """

        print(f"\n{'#'*70}")
        print(f"# NEW PROBLEM")
        print(f"{'#'*70}")
        print(f"\n{problem}\n")
        print(f"{'#'*70}\n")

        # Calculate time budget
        elapsed_global = time.time() - self.notebook_start_time
        time_left = self.cfg.notebook_limit - elapsed_global
        problems_left_others = max(0, self.problems_remaining - 1)
        reserved_time = problems_left_others * self.cfg.base_problem_timeout

        # Clamp to [base_problem_timeout, high_problem_timeout]. The floor is
        # applied last, so the budget is never below base_problem_timeout even
        # when less than that is left on the notebook clock.
        budget = time_left - reserved_time
        budget = min(budget, self.cfg.high_problem_timeout)
        budget = max(budget, self.cfg.base_problem_timeout)

        deadline = time.time() + budget

        print(f"‚è±Ô∏è  Time Budget: {budget:.0f}s | Problems Remaining: {self.problems_remaining}")
        print(f"‚è±Ô∏è  Global Elapsed: {elapsed_global:.0f}s | Reserved for others: {reserved_time:.0f}s\n")

        # Prepare tasks: one (system_prompt, attempt_index) pair per attempt.
        tasks = []
        for attempt_index in range(self.cfg.attempts):
            config = self._get_attempt_config(attempt_index, problem)
            tasks.append((config['system_prompt'], attempt_index))

        detailed_results = []
        valid_answers = []

        # Shared flag that running attempts poll so they can abort.
        stop_event = threading.Event()
        executor = ThreadPoolExecutor(max_workers=self.cfg.workers)

        solve_start = time.time()

        try:
            futures = []

            for (system_prompt, attempt_index) in tasks:
                # NOTE(review): the config is fetched a second time here; this
                # assumes _get_attempt_config returns the same values for the
                # same attempt_index -- confirm.
                config = self._get_attempt_config(attempt_index, problem)
                # NOTE(review): full_problem already carries the preference
                # prompt. Confirm _process_attempt_v2 does not append it again
                # (as _process_attempt in this file does).
                full_problem = f"{problem} {config['preference_prompt']}"

                future = executor.submit(
                    self._process_attempt_v2,
                    full_problem,
                    system_prompt,
                    attempt_index,
                    stop_event,
                    deadline
                )
                futures.append(future)

            for future in as_completed(futures):
                try:
                    result = future.result()
                    detailed_results.append(result)

                    if result['Answer'] is not None:
                        valid_answers.append(result['Answer'])

                    # Adaptive early stopping
                    completed = len(detailed_results)

                    if completed >= self.cfg.min_samples_before_stop:
                        counts = Counter(valid_answers).most_common()

                        if counts:
                            top_count = counts[0][1]
                            # The leader needs at least cfg.early_stop votes
                            # and a strict majority of the completed attempts.
                            threshold = max(
                                self.cfg.early_stop,
                                (completed // 2) + 1
                            )

                            if top_count >= threshold:
                                # "Clear" means unopposed, or ahead of the
                                # runner-up by more than one vote.
                                has_clear_winner = (
                                    len(counts) == 1 or
                                    counts[0][1] > counts[1][1] + 1
                                )

                                if has_clear_winner:
                                    print(f"\n‚ö° EARLY STOP: Answer {counts[0][0]} has {top_count}/{completed} votes")
                                    stop_event.set()

                                    # Cancel whatever has not started yet; the
                                    # stop event is what signals attempts that
                                    # are already running.
                                    for f in futures:
                                        f.cancel()
                                    break

                except Exception as exc:
                    # A failed future only loses that attempt; keep collecting.
                    print(f'‚ùå Future failed: {exc}')
                    continue

        finally:
            # Do not block on stragglers; pending futures are cancelled and
            # the problem counter is updated even if collection raised.
            executor.shutdown(wait=False, cancel_futures=True)
            self.problems_remaining = max(0, self.problems_remaining - 1)

        solve_time = time.time() - solve_start

        # ========== ANALYSIS SUMMARY ==========
        print(f"\n{'='*70}")
        print(f"üìä PROBLEM ANALYSIS SUMMARY")
        print(f"{'='*70}")

        if detailed_results:
            results_df = pd.DataFrame(detailed_results)
            # Nullable integer dtype, so attempts without an answer stay as
            # missing values rather than forcing the column to float.
            results_df['Answer'] = results_df['Answer'].astype('Int64')


            # --- Persist per-attempt records ---
            # 1. Tag every row with the first 100 characters of the problem
            #    text so rows from different problems can be told apart.
            results_df['Problem_Preview'] = problem[:100]

            # 2. Define filename (relative to the current working directory)
            output_filename = "all_detailed_results.csv"

            # 3. Check if file exists so we only write the header once
            write_header = not os.path.exists(output_filename)

            # 4. Append to CSV (mode='a'); a write failure is reported but
            #    does not abort solving.
            try:
                results_df.to_csv(output_filename, mode='a', header=write_header, index=False)
                print(f"  üíæ Saved attempt details to {output_filename}")
            except Exception as e:
                print(f"  ‚ö†Ô∏è Could not save to file: {e}")
            # --- End of persistence ---

            # Basic stats
            total_attempts = len(detailed_results)
            successful = sum(1 for r in detailed_results if r['Answer'] is not None)
            total_python_calls = sum(r['Python Calls'] for r in detailed_results)
            total_python_errors = sum(r['Python Errors'] for r in detailed_results)
            total_tokens = sum(r['Response Length'] for r in detailed_results)

            print(f"\nüìà ATTEMPT STATISTICS:")
            print(f"  Attempts completed: {total_attempts}/{self.cfg.attempts}")
            print(f"  Successful (found answer): {successful}/{total_attempts} ({100*successful/total_attempts:.0f}%)")
            print(f"  Total time: {solve_time:.1f}s")
            print(f"  Total tokens: {total_tokens:,}")
            # NOTE(review): unlike the per-attempt summary elsewhere in this
            # file, this division is not guarded against solve_time == 0.
            print(f"  Tokens/second: {total_tokens/solve_time:.0f}")

            print(f"\nüêç PYTHON TOOL USAGE:")
            print(f"  Total calls: {total_python_calls}")
            print(f"  Total errors: {total_python_errors}")
            if total_python_calls > 0:
                print(f"  Error rate: {100*total_python_errors/total_python_calls:.1f}%")
            print(f"  Avg calls/attempt: {total_python_calls/total_attempts:.1f}")

            # Exit reasons analysis; records without the key count as 'unknown'.
            exit_reasons = Counter(r.get('Exit Reason', 'unknown') for r in detailed_results)
            print(f"\nüö™ EXIT REASONS:")
            for reason, count in exit_reasons.most_common():
                print(f"  {reason}: {count}")

            # Temperature analysis: success rate per temperature, bucketed to
            # one decimal place.
            if 'Temperature' in results_df.columns:
                temp_success = {}
                for r in detailed_results:
                    temp = r.get('Temperature', 0)
                    temp_bucket = f"{temp:.1f}"
                    if temp_bucket not in temp_success:
                        temp_success[temp_bucket] = {'total': 0, 'success': 0}
                    temp_success[temp_bucket]['total'] += 1
                    if r['Answer'] is not None:
                        temp_success[temp_bucket]['success'] += 1

                print(f"\nüå°Ô∏è  TEMPERATURE ANALYSIS:")
                for temp, stats in sorted(temp_success.items()):
                    rate = 100 * stats['success'] / stats['total'] if stats['total'] > 0 else 0
                    print(f"  Temp {temp}: {stats['success']}/{stats['total']} success ({rate:.0f}%)")

            # Answer distribution
            print(f"\nüéØ ANSWER DISTRIBUTION:")
            answer_counts = Counter(valid_answers)
            for answer, count in answer_counts.most_common():
                pct = 100 * count / len(valid_answers) if valid_answers else 0
                print(f"  {answer}: {count} votes ({pct:.0f}%)")

            # Display full results table
            print(f"\nüìã DETAILED RESULTS: TABLES:")
            display(results_df)

        print(f"{'='*70}\n")

        # Fall back to 0 when no attempt produced an answer.
        if not valid_answers:
            print('‚ùå No valid answers found. Returning 0.\n')
            return 0

        return self._select_answer(detailed_results)


    def solve_problem(self, problem: str) -> int:
        """Solve a problem with ensemble prompts and adaptive early stopping.

        One attempt per index in ``range(self.cfg.attempts)`` is submitted to a
        thread pool. Results are consumed as they finish; once enough attempts
        have completed and one answer clearly dominates, the remaining attempts
        are signalled to stop and cancelled.

        Args:
            problem: The problem statement. Each attempt receives it with that
                attempt's ``preference_prompt`` appended after a space.

        Returns:
            ``0`` when no attempt produced a non-``None`` ``'Answer'``;
            otherwise whatever ``self._select_answer`` returns for the
            collected attempt results.

        Side effects:
            Prints progress, displays the per-attempt results table when there
            is at least one result, and decrements ``self.problems_remaining``
            (never below zero) once the attempts have been submitted and
            drained, even if that phase raised.
        """

        print(f'\nProblem: {problem}\n')

        # Time budget: whatever is left of the notebook limit after reserving
        # the base timeout for every other remaining problem, capped at the
        # high timeout and then raised to at least the base timeout.
        elapsed_global = time.time() - self.notebook_start_time
        time_left = self.cfg.notebook_limit - elapsed_global
        problems_left_others = max(0, self.problems_remaining - 1)
        reserved_time = problems_left_others * self.cfg.base_problem_timeout

        budget = time_left - reserved_time
        budget = min(budget, self.cfg.high_problem_timeout)
        budget = max(budget, self.cfg.base_problem_timeout)

        deadline = time.time() + budget

        print(f'Budget: {budget:.2f} seconds | Deadline: {deadline:.2f}\n')

        # Build every attempt's configuration exactly once, so the system
        # prompt and the preference prompt of an attempt always come from the
        # same configuration object.
        tasks = []
        for attempt_index in range(self.cfg.attempts):
            config = self._get_attempt_config(attempt_index, problem)
            full_problem = f"{problem} {config['preference_prompt']}"
            tasks.append((full_problem, config['system_prompt'], attempt_index))

        detailed_results = []
        valid_answers = []

        stop_event = threading.Event()
        executor = ThreadPoolExecutor(max_workers=self.cfg.workers)

        try:
            futures = [
                executor.submit(
                    self._process_attempt_v2,
                    full_problem,
                    system_prompt,
                    attempt_index,
                    stop_event,
                    deadline
                )
                for (full_problem, system_prompt, attempt_index) in tasks
            ]

            for future in as_completed(futures):
                try:
                    result = future.result()
                    detailed_results.append(result)

                    if result['Answer'] is not None:
                        valid_answers.append(result['Answer'])

                    # Adaptive early stopping
                    completed = len(detailed_results)

                    if completed >= self.cfg.min_samples_before_stop:
                        counts = Counter(valid_answers).most_common()

                        if counts:
                            top_count = counts[0][1]

                            # Dynamic threshold: the configured minimum, or a
                            # strict majority of completed attempts if larger.
                            threshold = max(
                                self.cfg.early_stop,
                                (completed // 2) + 1
                            )

                            # Also require that no runner-up is within one vote.
                            if top_count >= threshold:
                                has_clear_winner = (
                                    len(counts) == 1 or
                                    counts[0][1] > counts[1][1] + 1
                                )

                                if has_clear_winner:
                                    print(f'Early stop: {counts[0][0]} has '
                                          f'{top_count}/{completed} votes')
                                    stop_event.set()

                                    for f in futures:
                                        f.cancel()
                                    break

                except Exception as exc:
                    # Best effort: a failed attempt must not abort the others.
                    print(f'Future failed: {exc}')

        finally:
            # Do not block on still-running attempts; queued ones are dropped.
            executor.shutdown(wait=False, cancel_futures=True)
            self.problems_remaining = max(0, self.problems_remaining - 1)

        if detailed_results:
            results_df = pd.DataFrame(detailed_results)
            # Nullable integer dtype keeps missing answers as <NA>.
            results_df['Answer'] = results_df['Answer'].astype('Int64')
            display(results_df)

        if not valid_answers:
            print('\nResult: 0\n')
            return 0

        return self._select_answer(detailed_results)



    def __del__(self):
        """Release the resources held by this instance, if they were created.

        Each resource is looked up with ``hasattr`` so a partially constructed
        instance can still be finalized. The steps are chained with
        ``try``/``finally`` so a failure while stopping the server process does
        not prevent the log file from being closed or the sandbox pool from
        being drained; the original exception still propagates afterwards.
        """

        try:
            if hasattr(self, 'server_process'):
                # presumably a subprocess handle - TODO confirm against __init__
                self.server_process.terminate()
                self.server_process.wait()

        finally:
            try:
                if hasattr(self, 'log_file'):
                    self.log_file.close()

            finally:
                if hasattr(self, 'sandbox_pool'):
                    while not self.sandbox_pool.empty():
                        try:
                            sb = self.sandbox_pool.get_nowait()
                            sb.close()

                        except Exception:
                            # Deliberate best effort: one sandbox failing to
                            # close must not stop the rest from being closed.
                            pass

In [None]:
# Build the shared solver from CFG; `predict` below sends every question to this instance.
solver = AIMO3Solver(CFG)

In [None]:
def predict(id_: pl.DataFrame, question: pl.DataFrame, answer: Optional[pl.DataFrame] = None) -> pl.DataFrame:
    """Answer a single question and, when possible, score it against ground truth.

    Args:
        id_: Container whose ``.item(0)`` yields the question id.
        question: Container whose ``.item(0)`` yields the question text.
        answer: Unused by this function.

    Returns:
        A one-row ``pl.DataFrame`` with columns ``'id'`` and ``'answer'``.

    Side effects:
        Records the answer in the module-level ``predictions`` dict. When the
        id is present in ``ground_truth``, updates ``total_count`` and
        ``correct_count`` and prints the running accuracy.
    """
    global correct_count, total_count, predictions

    question_id = id_.item(0)
    question_text = question.item(0)

    print("------")
    print(f"ID: {question_id}")
    print(f"Question: {question_text[:200]}...")

    final_answer = solver.solve_problem_v2(question_text)
    predictions[question_id] = final_answer

    # Only questions with a known answer count toward accuracy, so the
    # denominator is not inflated by unscored questions.
    if question_id in ground_truth:
        total_count += 1
        gt = ground_truth[question_id]
        is_correct = (final_answer == gt)
        if is_correct:
            correct_count += 1
        # Escapes keep the glyphs intact regardless of file encoding:
        # U+2705 check mark, U+274C cross mark, U+1F4CA bar chart.
        status = "\u2705" if is_correct else "\u274c"
        print(f"Answer: {final_answer} | Ground Truth: {gt} | {status}")
        print(f"\U0001F4CA Running Accuracy: {correct_count}/{total_count} ({100*correct_count/total_count:.1f}%)")
    else:
        print(f"Answer: {final_answer}")

    print("------\n")

    return pl.DataFrame({'id': question_id, 'answer': final_answer})


In [None]:
import pandas as pd
import uuid

# Load the existing reference data.
# Point this at wherever the current reference file resides.
df = pd.read_csv(
    "/kaggle/input/ai-mathematical-olympiad-progress-prize-3/reference.csv"
)

# Extra hard problems to append to the reference set. Each entry holds the
# LaTeX problem statement (raw string) and its integer answer.
new_problems = [
    {
        "problem": r"Let $B$ be the set of rectangular boxes with surface area $54$ and volume $23$. Let $r$ be the radius of the smallest sphere that can contain each of the rectangular boxes that are elements of $B$. The value of $r^2$ can be written as $\frac{p}{q}$, where $p$ and $q$ are relatively prime positive integers. Find $p + q$.",
        "answer": 721
    },
    {
        "problem": r"Find the largest prime number $p < 1000$ for which there exists a complex number $z$ satisfying: the real and imaginary parts of $z$ are both integers; $|z| = \sqrt{p}$; and there exists a triangle whose three side lengths are $p$, the real part of $z^3$, and the imaginary part of $z^3$.",
        "answer": 349
    },
    {
        "problem": r"For each positive integer $n$ let $a_n$ be the least positive integer multiple of $23$ such that $a_n \equiv 1 \pmod{2^n}$. Find the number of positive integers $n$ less than or equal to $1000$ that satisfy $a_n = a_{n+1}$.",
        "answer": 363
    },
    {
        "problem": r"Let $x, y, z$ be positive real numbers satisfying $\sqrt{2x - xy} + \sqrt{2y - xy} = 1$, $\sqrt{2y - yz} + \sqrt{2z - yz} = \sqrt{2}$, $\sqrt{2z - zx} + \sqrt{2x - zx} = \sqrt{3}$. Then $[(1-x)(1-y)(1-z)]^2$ can be written as $\frac{n}{m}$ where $m,n$ are coprime positive integers. Find $m+n$.",
        "answer": 33
    },
    {
        "problem": r"Let $S$ be the set of positive integers $k$ such that the two parabolas $y = x^2 - k$ and $x = 2(y-20)^2 - k$ intersect in four distinct points, and these four points lie on a circle with radius at most $21$. Find the sum of the least element of $S$ and the greatest element of $S$.",
        "answer": 285
    },
    {
        "problem": r"Let $ABC$ be an acute triangle with circumcircle $\omega$, and let $H$ be the intersection of the altitudes. Suppose the tangent to the circumcircle of $\triangle HBC$ at $H$ intersects $\omega$ at $X$ and $Y$ with $HA=3, HX=2, HY=6$. The area of $\triangle ABC$ is $m\sqrt{n}$, where $n$ is square-free. Find $m+n$.",
        "answer": 58
    },
    {
        "problem": r"Let $\triangle ABC$ be an acute scalene triangle with circumcircle $\omega$. The tangents to $\omega$ at $B$ and $C$ intersect at $T$. Let $X$ and $Y$ be the projections of $T$ onto lines $AB$ and $AC$, respectively. Suppose $BT=CT=16$, $BC=22$, and $TX^2+TY^2+XY^2=1143$. Find $XY^2$.",
        "answer": 717
    },
    {
        "problem": r"The area of the smallest equilateral triangle with one vertex on each of the sides of the right triangle with side lengths $2\sqrt{3}, 5, \sqrt{37}$ is $\frac{m\sqrt{p}}{n}$, where $m,n$ are coprime and $p$ is square-free. Find $m+n+p$.",
        "answer": 145
    },
    {
        "problem": r"Tetrahedron $ABCD$ has $AD=BC=28, AC=BD=44, AB=CD=52$. For any point $X$ in space, define $f(X) = AX+BX+CX+DX$. The least possible value of $f(X)$ can be expressed as $m\sqrt{n}$, where $m,n$ are integers and $n$ is square-free. Find $m+n$.",
        "answer": 682
    },
    {
        "problem": r"For each integer $n \geq 2$, let $A(n)$ be the area of the region in the coordinate plane defined by the inequalities $1 \le x < n$ and $0 \le y \le x \lfloor \sqrt{x} \rfloor$. Find the number of values of $n$ with $2 \le n \le 1000$ for which $A(n)$ is an integer.",
        "answer": 483
    }
]

# Build rows for the new problems, giving each a random 6-character hex id.
# Ids are re-drawn on collision (with existing ids or with each other) so the
# id -> answer mapping built later cannot silently drop a problem.
used_ids = {str(value) for value in df["id"]} if "id" in df.columns else set()
new_rows = []
for entry in new_problems:
    new_id = uuid.uuid4().hex[:6]
    while new_id in used_ids:
        new_id = uuid.uuid4().hex[:6]
    used_ids.add(new_id)
    new_rows.append([new_id, entry['problem'], entry['answer']])

df_new = pd.DataFrame(new_rows, columns=['id', 'problem', 'answer'])

# Concatenate the original and new DataFrames
df = pd.concat([df, df_new], ignore_index=True)

# Save the combined DataFrame to a CSV in the working directory
df.to_csv("reference.csv", index=False)
# df_new.to_csv("reference2.csv", index=False)

print(f"Successfully appended {len(new_problems)} problems. New shape:", df.shape)

In [None]:
# Keep the ground-truth answers for accuracy reporting, then publish an
# answer-free copy of the questions as the gateway input file.

# Map each question id to its answer; empty when the frame has no "answer" column.
if "answer" in df.columns:
    ground_truth = dict(zip(df["id"], df["answer"]))
else:
    ground_truth = {}

# Overwrite reference.csv with the questions only (no answers).
df.drop(columns="answer", errors="ignore").to_csv("reference.csv", index=False)

# Running state updated by `predict`: answers by id, plus accuracy counters.
predictions = {}
correct_count = 0
total_count = 0




In [None]:
# Show the combined reference table (original plus appended problems) as this cell's output.
df

In [None]:
import kaggle_evaluation.aimo_3_inference_server

# Wrap `predict` in the competition's inference server.
inference_server = kaggle_evaluation.aimo_3_inference_server.AIMO3InferenceServer(predict)

if os.getenv("KAGGLE_IS_COMPETITION_RERUN"):
    # Competition rerun: serve predictions to the hosted gateway.
    inference_server.serve()
else:
    # Local run: feed the answer-free reference file through a local gateway.
    inference_server.run_local_gateway(("reference.csv",))
    # inference_server.run_local_gateway(("/kaggle/input/ai-mathematical-olympiad-progress-prize-3/test.csv",))

    # Print final accuracy summary
    if ground_truth and total_count > 0:
        print("\n" + "=" * 50)
        # U+1F4CA bar chart, written as an escape so it survives re-encoding.
        print("\U0001F4CA FINAL ACCURACY SUMMARY")
        print("=" * 50)
        print(f"Correct: {correct_count}/{total_count}")
        print(f"Accuracy: {100*correct_count/total_count:.1f}%")
        print("=" * 50)

        # Show details for every prediction that has a ground-truth answer.
        print("\nDetails:")
        for qid, pred in predictions.items():
            if qid in ground_truth:
                gt = ground_truth[qid]
                # U+2705 check mark / U+274C cross mark.
                status = "\u2705" if pred == gt else "\u274c"
                print(f"  {qid}: pred={pred}, gt={gt} {status}")