<a href="https://colab.research.google.com/github/quantexolution/aimo/blob/main/%5B41_50%5D_AIMO_3_Weighted_entropy.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Authenticate with Kaggle before anything else: some of the data sources
# downloaded in the next cell are private and need a logged-in session.
import kagglehub

kagglehub.login()


In [None]:
# Download the Kaggle data sources used by this notebook (run this cell once;
# it can be deleted afterwards).
# NOTE: this environment is not Kaggle's Python image, so some libraries the
# notebook relies on may be missing.
# NOTE(review): later cells read hard-coded /kaggle/input/... paths rather than
# the *_path variables below — confirm they point at the same data here.

ai_mathematical_olympiad_progress_prize_3_path = kagglehub.competition_download(
    'ai-mathematical-olympiad-progress-prize-3')
andreasbis_aimo_3_utils_path = kagglehub.notebook_output_download(
    'andreasbis/aimo-3-utils')
danielhanchen_gpt_oss_120b_transformers_default_1_path = kagglehub.model_download(
    'danielhanchen/gpt-oss-120b/Transformers/default/1')

print('Data source import complete.')


In [None]:
# Remove packages that nothing in this notebook imports directly.
# NOTE(review): presumably done to avoid version/import conflicts with the
# offline vLLM/transformers install below — confirm.
%pip uninstall --yes 'keras' 'matplotlib' 'scikit-learn' 'tensorflow'

In [None]:
import warnings
warnings.simplefilter('ignore')  # silence warnings for the rest of the notebook

import os
import sys
import subprocess
import gc
import re
import math
import time
import queue
import threading
import contextlib

In [None]:
def set_env(archive, tmp):
    """Unpack the offline wheel archive (once) and pip-install the runtime packages from it.

    Args:
        archive: path to a ``.tar.gz`` that contains a ``wheels/`` directory
            (later cells also expect ``tiktoken_encodings/`` in it — TODO confirm).
        tmp: directory to extract into; reused on subsequent calls.

    Raises:
        subprocess.CalledProcessError: if extraction or installation fails.
    """
    marker = os.path.join(tmp, '.extracted')
    # Gate on a completion marker rather than on ``tmp`` merely existing: after an
    # interrupted or failed ``tar`` the directory exists but is incomplete, and the
    # old check would then skip extraction on every rerun.
    if not os.path.exists(marker):
        os.makedirs(tmp, exist_ok=True)
        subprocess.run(['tar', '-xzf', archive, '-C', tmp], check=True)
        with open(marker, 'w'):
            pass
    # --no-index: install strictly from the extracted wheels, never from PyPI.
    subprocess.run([sys.executable, '-m', 'pip', 'install', '--no-index', '--find-links', f'{tmp}/wheels',
                    'unsloth', 'trl', 'vllm', 'openai_harmony'], check=True)

set_env('/kaggle/input/aimo-3-utils/wheels.tar.gz', '/kaggle/tmp/setup')

In [None]:
# Sanity check: list the tiktoken encoding files (presumably unpacked from the
# wheel archive into /kaggle/tmp/setup).
subprocess.run(args=['ls', os.path.join('/kaggle/tmp/setup', 'tiktoken_encodings')])

In [None]:
# Environment for the libraries imported in the next cell. The vLLM server and
# the Jupyter sandbox kernels started later inherit os.environ as well.
os.environ.update({
    'TRANSFORMERS_NO_TF': '1',
    'TRANSFORMERS_NO_FLAX': '1',
    'CUDA_VISIBLE_DEVICES': '0',
    'TOKENIZERS_PARALLELISM': 'false',
    'TRITON_PTXAS_PATH': '/usr/local/cuda/bin/ptxas',
    'TIKTOKEN_ENCODINGS_BASE': '/kaggle/tmp/setup/tiktoken_encodings',
})

In [None]:
from typing import Optional
from jupyter_client import KernelManager
from collections import Counter, defaultdict
from concurrent.futures import as_completed, ThreadPoolExecutor
import pandas as pd, polars as pl
from openai import OpenAI
from openai_harmony import (HarmonyEncodingName, load_harmony_encoding, SystemContent, ReasoningEffort,
                             ToolNamespaceConfig, Author, Message, Role, TextContent, Conversation)
from transformers import set_seed
import kaggle_evaluation.aimo_3_inference_server

In [None]:
class CFG:
    """Static settings for the solver: prompts, model/server, time limits and sampling."""

    # --- prompts -------------------------------------------------------------
    system_prompt = (
        'You are a world-class International Mathematical Olympiad (IMO) competitor. '
        'The final answer must be a non-negative integer between 0 and 99999. '
        'You must place the final integer answer inside \\boxed{}.'
    )
    tool_prompt = (
        'Use this tool to execute Python code. The environment is a stateful Jupyter notebook. '
        'You must use print() to output results.'
    )
    preference_prompt = 'You have access to `math`, `numpy` and `sympy` to solve the problem.'

    # --- model / vLLM server --------------------------------------------------
    served_model_name = 'gpt-oss'
    model_path = '/kaggle/input/gpt-oss-120b/transformers/default/1'
    kv_cache_dtype = 'fp8_e4m3'
    dtype = 'auto'
    gpu_memory_utilization = 0.96
    batch_size = 256          # passed to vLLM as --max-num-seqs
    context_tokens = 65536    # --max-model-len; also caps prompt + completion per turn
    stream_interval = 200     # --stream-interval

    # --- time limits (seconds) ------------------------------------------------
    high_problem_timeout = 900
    base_problem_timeout = 270
    notebook_limit = 17400
    server_timeout = 180      # number of readiness polls while waiting for vLLM
    session_timeout = 960     # OpenAI client request timeout
    jupyter_timeout = 6       # default timeout of each sandbox kernel
    sandbox_timeout = 3       # wait for a free sandbox from the pool

    # --- generation / search --------------------------------------------------
    buffer_tokens = 512       # stop a conversation when fewer tokens than this remain
    search_tokens = 32        # trailing stream chunks scanned for an answer
    top_logprobs = 5
    early_stop = 4            # votes for one answer that end a problem early
    attempts = 8
    workers = 16
    turns = 128
    temperature = 1.0
    min_p = 0.02
    seed = 42

In [None]:
# Seed the global RNGs through transformers' helper, using the configured seed.
set_seed(seed=CFG.seed)

In [None]:
class AIMO3Template:
    """Builds the opening Harmony messages (system + user) for one attempt."""

    def get_system_content(self, prompt, tool_cfg):
        """Return system content: *prompt* as model identity, HIGH reasoning effort, *tool_cfg* as tools."""
        content = SystemContent.new()
        content = content.with_model_identity(prompt)
        content = content.with_reasoning_effort(reasoning_effort=ReasoningEffort.HIGH)
        return content.with_tools(tool_cfg)

    def apply_chat_template(self, sys_prompt, usr_prompt, tool_cfg):
        """Return ``[system_message, user_message]`` for ``Conversation.from_messages``."""
        system_message = Message.from_role_and_content(
            Role.SYSTEM, self.get_system_content(sys_prompt, tool_cfg))
        user_message = Message.from_role_and_content(Role.USER, usr_prompt)
        return [system_message, user_message]

In [None]:
class AIMO3Sandbox:
    """Dedicated IPython kernel that executes model-generated Python code.

    Every instance launches its own kernel on five newly reserved ports, then runs
    a prelude that imports ``math, numpy, sympy, mpmath, itertools, collections``
    and sets ``mpmath.mp.dps = 64``.
    """

    _port_lock, _next_port = threading.Lock(), 50000
    # ANSI colour escape sequences embedded in IPython tracebacks.
    _ANSI_ESCAPE = re.compile(r'\x1b\[[0-9;]*m')
    # Executed after start-up and after every reset().
    _PRELUDE = 'import math, numpy, sympy, mpmath, itertools, collections\nmpmath.mp.dps = 64\n'

    @classmethod
    def _get_next_ports(cls, count=5):
        """Reserve and return *count* consecutive port numbers.

        The shared counter only moves forward (under a class-level lock), so the
        same port is never handed out twice within this process.
        """
        with cls._port_lock:
            ports = list(range(cls._next_port, cls._next_port + count))
            cls._next_port += count
            return ports

    def __init__(self, timeout):
        """Start a kernel and block until it is ready.

        Args:
            timeout: seconds to wait for the kernel to become ready; also the
                default per-call limit used by :meth:`execute`.

        Raises:
            Any exception raised while starting or connecting to the kernel;
            the kernel is shut down before the exception propagates.
        """
        self._default_timeout, self._owns_kernel, self._client, self._km = timeout, False, None, None
        ports = self._get_next_ports(5)
        env = os.environ.copy()
        env.update({'PYDEVD_DISABLE_FILE_VALIDATION': '1', 'PYDEVD_WARN_EVALUATION_TIMEOUT': '0',
                    'JUPYTER_PLATFORM_DIRS': '1', 'PYTHONWARNINGS': 'ignore', 'MPLBACKEND': 'Agg'})
        self._km = KernelManager()
        self._km.shell_port, self._km.iopub_port, self._km.stdin_port, self._km.hb_port, self._km.control_port = ports
        self._km.start_kernel(env=env, extra_arguments=['--Application.log_level=CRITICAL'])
        # Claim ownership as soon as the kernel process exists so that close() shuts
        # it down if connecting, waiting for readiness or the prelude fails below.
        self._owns_kernel = True
        try:
            self._client = self._km.blocking_client()
            self._client.start_channels()
            self._client.wait_for_ready(timeout=self._default_timeout)
            self.execute(self._PRELUDE)
        except BaseException:
            self.close()
            raise

    def _format_error(self, tb):
        """Join traceback lines, dropping ANSI colour codes and non-cell file frames.

        Lines containing ``File "`` are kept only when they also mention
        ``ipython-input`` (presumably the executed cell); all other lines are kept.
        """
        return ''.join(self._ANSI_ESCAPE.sub('', f) for f in tb
                       if 'File "' not in f or 'ipython-input' in f)

    def execute(self, code, timeout=None):
        """Run *code* in the kernel and return its textual output.

        Only iopub messages belonging to this request are collected: stdout,
        stderr, filtered tracebacks and ``text/plain`` results/displays, until the
        kernel reports ``idle``.

        Args:
            code: source to execute (IPython syntax such as ``%reset`` is accepted).
            timeout: seconds before the kernel is interrupted; a falsy value
                means the instance default.

        Returns:
            stdout (right-stripped) followed by stderr when both exist, otherwise
            whichever is non-empty, otherwise a ``[WARN]`` hint; on timeout an
            ``[ERROR] Execution timed out ...`` message.
        """
        effective_timeout = timeout or self._default_timeout
        msg_id = self._client.execute(code, store_history=True, allow_stdin=False, stop_on_error=False)
        stdout, stderr, start = [], [], time.time()
        while True:
            if time.time() - start > effective_timeout:
                self._km.interrupt_kernel()
                return f'[ERROR] Execution timed out after {effective_timeout} seconds'
            try:
                msg = self._client.get_iopub_msg(timeout=1.0)
            except queue.Empty:
                continue
            # Skip output left over from earlier (e.g. interrupted) requests.
            if msg.get('parent_header', {}).get('msg_id') != msg_id: continue
            mt, c = msg.get('msg_type'), msg.get('content', {})
            if mt == 'stream':
                (stdout if c.get('name') == 'stdout' else stderr).append(c.get('text', ''))
            elif mt == 'error':
                stderr.append(self._format_error(c.get('traceback', [])))
            elif mt in {'execute_result', 'display_data'}:
                if txt := c.get('data', {}).get('text/plain'):
                    stdout.append(txt if txt.endswith('\n') else f'{txt}\n')
            elif mt == 'status' and c.get('execution_state') == 'idle':
                break
        out, err = ''.join(stdout), ''.join(stderr)
        return f'{out.rstrip()}\n{err}' if err and out else (err or out or '[WARN] No output. Use print() to see results.')

    def close(self):
        """Stop the client channels and shut down the kernel if this instance started it.

        Every step is best-effort and the method is safe to call repeatedly (it is
        also invoked from ``__del__``, possibly on a partially built object).
        """
        client, self._client = getattr(self, '_client', None), None
        if client:
            with contextlib.suppress(Exception):
                client.stop_channels()
        km = getattr(self, '_km', None)
        if getattr(self, '_owns_kernel', False) and km:
            self._owns_kernel = False  # never shut the same kernel down twice
            with contextlib.suppress(Exception): km.shutdown_kernel(now=True)
            with contextlib.suppress(Exception): km.cleanup_resources()

    def reset(self):
        """Clear the kernel namespace (``%reset -f``) and re-run the import prelude."""
        self.execute('%reset -f\n' + self._PRELUDE)

    def __del__(self):
        self.close()

In [None]:
class AIMO3Tool:
    """Harmony ``python`` tool that runs assistant code in an :class:`AIMO3Sandbox`."""

    def __init__(self, timeout, prompt, sandbox=None):
        """
        Args:
            timeout: timeout for a sandbox this tool creates lazily (unused when
                *sandbox* is given).
            prompt: tool description exposed to the model via :attr:`instruction`.
            sandbox: existing sandbox to run code in; ``None`` creates one on first use.
        """
        self._local_jupyter_timeout, self._tool_prompt, self._jupyter_session = timeout, prompt, sandbox
        # NOTE(review): _owns_session is recorded but never consulted in this class;
        # a lazily created sandbox is not closed here.
        self._owns_session, self._execution_lock, self._init_lock = sandbox is None, threading.Lock(), threading.Lock()

    def _ensure_session(self):
        """Create the sandbox on first use (checked again under ``_init_lock``)."""
        if self._jupyter_session is None:
            with self._init_lock:
                if self._jupyter_session is None:
                    self._jupyter_session = AIMO3Sandbox(timeout=self._local_jupyter_timeout)

    def _ensure_last_print(self, code):
        """Wrap a trailing top-level expression in ``print(...)`` so its value is shown.

        Only a genuine expression statement that sits entirely on the last line and
        does not already mention ``print`` is wrapped; the stripped code is then
        returned. Anything else (assignments, indented block bodies, multi-line
        expressions, comments, magics, unparsable code) returns *code* unchanged —
        blindly wrapping the last line turned e.g. ``x = 1`` into ``print(x = 1)``.
        """
        import ast  # only this helper needs it

        stripped = code.strip()
        lines = stripped.split('\n')
        last = lines[-1].strip()
        if not last or last.startswith('#') or 'print' in last:
            return code
        try:
            tree = ast.parse(stripped)
        except (SyntaxError, ValueError, RecursionError):
            return code
        if not tree.body:
            return code
        node = tree.body[-1]
        if not isinstance(node, ast.Expr) or node.lineno != len(lines) or node.end_lineno != len(lines):
            return code
        # ast column offsets are UTF-8 byte offsets, so splice the encoded line.
        raw = lines[-1].encode('utf-8')
        lines[-1] = (raw[:node.col_offset] + b'print(' + raw[node.col_offset:node.end_col_offset]
                     + b')' + raw[node.end_col_offset:]).decode('utf-8')
        return '\n'.join(lines)

    @property
    def instruction(self): return self._tool_prompt

    @property
    def tool_config(self): return ToolNamespaceConfig(name='python', description=self.instruction, tools=[])

    def _make_response(self, output, channel=None):
        """Wrap *output* in a ``python`` tool message addressed to the assistant."""
        msg = Message(author=Author(role=Role.TOOL, name='python'),
                      content=[TextContent(text=output)]).with_recipient('assistant')
        return msg.with_channel(channel) if channel else msg

    def process_sync_plus(self, message):
        """Execute the code in *message* and return ``[reply_message]``.

        Executions on the same tool are serialised by ``_execution_lock``; the reply
        uses the same channel as *message*.
        """
        self._ensure_session()
        final_script = self._ensure_last_print(message.content[0].text)
        with self._execution_lock:
            try:
                output = self._jupyter_session.execute(final_script)
            except TimeoutError as exc:
                output = f'[ERROR] {exc}'
        return [self._make_response(output, channel=message.channel)]

In [None]:
class AIMO3Solver:
    """Solve a problem with parallel gpt-oss attempts served by vLLM, then vote.

    Construction is heavy: it warms the OS page cache with the model files,
    launches a vLLM OpenAI-compatible server subprocess on *port*, waits for it,
    and starts ``cfg.workers`` Jupyter sandboxes shared by all attempts.
    """

    # Answer patterns, tried in order; within a pattern the last match wins.
    _ANSWER_PATTERNS = (r'\\boxed\s*\{\s*([0-9,]+)\s*\}', r'final\s+answer\s+is\s*([0-9,]+)')
    # Read size used while warming the page cache.
    _PRELOAD_CHUNK_BYTES = 16 * 1024 * 1024

    def __init__(self, cfg, port=8000):
        """
        Args:
            cfg: configuration object exposing the attributes of ``CFG``.
            port: local port for the vLLM server.

        Raises:
            RuntimeError: if the vLLM server exits or is not ready within
                ``cfg.server_timeout`` polls.
        """
        self.cfg, self.port = cfg, port
        self.base_url, self.api_key = f'http://0.0.0.0:{port}/v1', 'sk-local'
        self.template, self.encoding = AIMO3Template(), load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
        self.stop_token_ids = self.encoding.stop_tokens_for_assistant_actions()
        self._preload_model_weights()
        self.server_process = self._start_server()
        self.client = OpenAI(base_url=self.base_url, api_key=self.api_key, timeout=self.cfg.session_timeout)
        self._wait_for_server()
        self._initialize_kernels()
        # NOTE(review): the time budget is measured from here, i.e. after model,
        # server and kernel start-up; 50 is presumably the number of test problems.
        self.notebook_start_time, self.problems_remaining = time.time(), 50

    def _preload_model_weights(self):
        """Read every file under ``cfg.model_path`` once so the OS page cache holds it."""
        print(f'Loading model weights from {self.cfg.model_path} into OS Page Cache...')
        start, files, total = time.time(), [], 0
        for root, _, fnames in os.walk(self.cfg.model_path):
            for fn in fnames:
                fp = os.path.join(root, fn)
                if os.path.isfile(fp):
                    files.append(fp)
                    total += os.path.getsize(fp)
        with ThreadPoolExecutor(max_workers=self.cfg.workers) as ex:
            list(ex.map(self._drain_file, files))
        print(f'Processed {len(files)} files ({total/1e9:.2f} GB) in {time.time()-start:.2f} seconds.\n')

    @classmethod
    def _drain_file(cls, path):
        """Read *path* to EOF into a reused buffer and return the number of bytes read.

        Streaming in fixed chunks keeps at most one chunk per worker in memory,
        instead of whole multi-GB shards, and the file is closed deterministically.
        """
        size, buf = 0, bytearray(cls._PRELOAD_CHUNK_BYTES)
        with open(path, 'rb', buffering=0) as fh:
            while n := fh.readinto(buf):
                size += n
        return size

    def _start_server(self):
        """Launch the vLLM OpenAI API server in its own session, logging to ``vllm_server.log``."""
        cmd = [sys.executable, '-m', 'vllm.entrypoints.openai.api_server', '--seed', str(self.cfg.seed),
               '--model', self.cfg.model_path, '--served-model-name', self.cfg.served_model_name,
               '--tensor-parallel-size', '1', '--max-num-seqs', str(self.cfg.batch_size),
               '--gpu-memory-utilization', str(self.cfg.gpu_memory_utilization), '--host', '0.0.0.0',
               '--port', str(self.port), '--dtype', self.cfg.dtype, '--kv-cache-dtype', self.cfg.kv_cache_dtype,
               '--max-model-len', str(self.cfg.context_tokens), '--stream-interval', str(self.cfg.stream_interval),
               '--async-scheduling', '--disable-log-stats', '--enable-prefix-caching']
        self.log_file = open('vllm_server.log', 'w')
        return subprocess.Popen(cmd, stdout=self.log_file, stderr=subprocess.STDOUT, start_new_session=True)

    def _wait_for_server(self):
        """Poll the server's model list (sleeping 1 s between failures) until it answers.

        Raises:
            RuntimeError: if the server process exits (message includes its log) or
                does not answer within ``cfg.server_timeout`` polls.
        """
        print('Waiting for vLLM server...')
        start = time.time()
        for _ in range(self.cfg.server_timeout):
            if (rc := self.server_process.poll()) is not None:
                self.log_file.flush()
                with open('vllm_server.log') as fh:
                    logs = fh.read()
                raise RuntimeError(f'Server died with code {rc}. Full logs:\n{logs}\n')
            try:
                self.client.models.list()
                print(f'Server is ready (took {time.time()-start:.2f} seconds).\n')
                return
            except Exception:
                time.sleep(1)
        raise RuntimeError('Server failed to start (timeout).\n')

    def _initialize_kernels(self):
        """Start ``cfg.workers`` sandboxes in parallel and put them in ``sandbox_pool``."""
        print(f'Initializing {self.cfg.workers} persistent Jupyter kernels...')
        start = time.time()
        self.sandbox_pool = queue.Queue()
        with ThreadPoolExecutor(max_workers=self.cfg.workers) as ex:
            for future in as_completed([ex.submit(lambda: AIMO3Sandbox(timeout=self.cfg.jupyter_timeout))
                                       for _ in range(self.cfg.workers)]):
                self.sandbox_pool.put(future.result())
        print(f'Kernels initialized in {time.time()-start:.2f} seconds.\n')

    def _scan_for_answer(self, text):
        """Return the last boxed (or "final answer is") integer in *text* within 0..99999, else None.

        Thousands separators are removed. If a pattern's last match is unusable, the
        next pattern is tried. Note that ``0`` is a valid answer: compare with ``None``.
        """
        for pattern in self._ANSWER_PATTERNS:
            if matches := re.findall(pattern, text, re.IGNORECASE):
                try:
                    val = int(matches[-1].replace(',', ''))
                    if 0 <= val <= 99999: return val
                except ValueError: pass
        return None

    def _compute_mean_entropy(self, logprobs):
        """Mean Shannon entropy in bits over per-token top-k logprob dicts; ``inf`` if none usable.

        Each usable element is a non-empty ``{token: logprob}`` dict; other elements
        are skipped. Entropy is computed over the returned top-k only (no
        renormalisation), and probabilities that underflow to 0 contribute nothing.
        """
        total, count = 0.0, 0
        for top_lp in logprobs or ():
            if not (isinstance(top_lp, dict) and top_lp):
                continue
            ent = 0.0
            for lp in top_lp.values():
                p = math.exp(lp)
                if p > 0:
                    ent -= p * math.log2(p)
            total += ent
            count += 1
        return total/count if count else float('inf')

    def _attempt_result(self, idx, ans=None, py_calls=0, py_errs=0, total_toks=0, logprobs=()):
        """Build one attempt's result record with a fixed key (column) order."""
        return {'Attempt': idx+1, 'Response Length': total_toks, 'Python Calls': py_calls,
                'Python Errors': py_errs, 'Entropy': self._compute_mean_entropy(logprobs), 'Answer': ans}

    def _process_attempt(self, problem, sys_prompt, idx, stop_evt, deadline):
        """Run one multi-turn sampled attempt, executing the model's Python tool calls.

        Args:
            problem: user prompt text.
            sys_prompt: system prompt text.
            idx: 0-based attempt index (also used to derive the sampling seed).
            stop_evt: event that, once set, makes the attempt stop as soon as possible.
            deadline: absolute ``time.time()`` after which the attempt stops.

        Returns:
            dict with keys ``Attempt``, ``Response Length``, ``Python Calls``,
            ``Python Errors``, ``Entropy`` and ``Answer`` (None if no answer found).
        """
        def out_of_time():
            return stop_evt.is_set() or time.time() > deadline

        if out_of_time():
            return self._attempt_result(idx)
        sandbox, py_calls, py_errs, total_toks, ans, logprobs = None, 0, 0, 0, None, []
        seed = (self.cfg.seed + idx) ** 2
        try:
            sandbox = self.sandbox_pool.get(timeout=self.cfg.sandbox_timeout)
            local_tool = AIMO3Tool(self.cfg.jupyter_timeout, self.cfg.tool_prompt, sandbox)
            conv = Conversation.from_messages(self.template.apply_chat_template(
                sys_prompt, problem, local_tool.tool_config))
            for _ in range(self.cfg.turns):
                if out_of_time(): break
                prompt_ids = self.encoding.render_conversation_for_completion(conv, Role.ASSISTANT)
                if (max_toks := self.cfg.context_tokens - len(prompt_ids)) < self.cfg.buffer_tokens: break
                stream = self.client.completions.create(model=self.cfg.served_model_name,
                    temperature=self.cfg.temperature, logprobs=self.cfg.top_logprobs, max_tokens=max_toks,
                    prompt=prompt_ids, seed=seed, stream=True, extra_body={
                        'min_p': self.cfg.min_p, 'stop_token_ids': self.stop_token_ids, 'return_token_ids': True})
                tok_buf, txt_chunks = [], []
                try:
                    for chunk in stream:
                        if out_of_time(): break
                        choice = chunk.choices[0]
                        if new_toks := choice.token_ids:
                            tok_buf.extend(new_toks)
                            total_toks += len(new_toks)
                            txt_chunks.append(choice.text)
                            if (clp := choice.logprobs) and clp.top_logprobs:
                                logprobs.extend(clp.top_logprobs)
                        # Stop streaming as soon as an answer appears in the recent text.
                        # Compare with None: 0 is a valid answer and is falsy.
                        if '}' in choice.text:
                            found = self._scan_for_answer(''.join(txt_chunks[-self.cfg.search_tokens:]))
                            if found is not None:
                                ans = found
                                break
                finally:
                    stream.close()
                if ans is not None or not tok_buf: break
                new_msgs = self.encoding.parse_messages_from_completion_tokens(tok_buf, Role.ASSISTANT)
                if not new_msgs: break
                conv.messages.extend(new_msgs)
                last = new_msgs[-1]
                if last.channel == 'final':
                    ans = self._scan_for_answer(last.content[0].text)
                    break
                if last.recipient == 'python':
                    py_calls += 1
                    resp = local_tool.process_sync_plus(last)
                    if any(x in resp[0].content[0].text for x in ('[ERROR]', 'Traceback', 'Error:')):
                        py_errs += 1
                    conv.messages.extend(resp)
        except Exception:
            # Best-effort: any failure (no free sandbox, HTTP/stream error, parse error)
            # ends this attempt and is counted as an error.
            py_errs += 1
        finally:
            if sandbox is not None:
                # A failing reset must neither discard this attempt's result (an
                # exception here would replace the return value) nor leak the sandbox.
                with contextlib.suppress(Exception):
                    sandbox.reset()
                self.sandbox_pool.put(sandbox)
        return self._attempt_result(idx, ans, py_calls, py_errs, total_toks, logprobs)

    def _select_answer(self, results):
        """Display the vote table and return the answer with the largest summed 1/entropy weight.

        Returns 0 when no result carries an answer.
        """
        ans_weights, ans_votes = defaultdict(float), defaultdict(int)
        for r in results:
            if (a := r['Answer']) is not None:
                ans_weights[a] += 1.0/max(r['Entropy'], 1e-9)
                ans_votes[a] += 1
        scored = sorted([{'answer': a, 'votes': ans_votes[a], 'score': w}
                        for a, w in ans_weights.items()], key=lambda x: x['score'], reverse=True)
        display(pd.DataFrame([(s['answer'], s['votes'], s['score']) for s in scored],
                             columns=['Answer', 'Votes', 'Score']).round({'Score': 3}))
        final = scored[0]['answer'] if scored else 0
        print(f'\nFinal Answer: {final}\n')
        return final

    def solve_problem(self, problem):
        """Run ``cfg.attempts`` attempts in parallel and return the selected answer (0 if none).

        The per-problem budget is clamped to ``[base_problem_timeout,
        high_problem_timeout]`` while reserving ``base_problem_timeout`` for each
        remaining problem. Attempts are stopped early once any answer collects
        ``cfg.early_stop`` votes.
        """
        print(f'\nProblem: {problem}\n')
        user_input = f'{problem} {self.cfg.preference_prompt}'
        time_left = self.cfg.notebook_limit - (time.time() - self.notebook_start_time)
        budget = max(self.cfg.base_problem_timeout,
                     min(time_left - max(0, self.problems_remaining-1)*self.cfg.base_problem_timeout,
                         self.cfg.high_problem_timeout))
        deadline = time.time() + budget
        print(f'Budget: {budget:.2f} seconds | Deadline: {deadline:.2f}\n')
        results, valid, stop_evt = [], [], threading.Event()
        with ThreadPoolExecutor(max_workers=self.cfg.workers) as ex:
            futures = [ex.submit(self._process_attempt, user_input, self.cfg.system_prompt, i, stop_evt, deadline)
                       for i in range(self.cfg.attempts)]
            for future in as_completed(futures):
                try:
                    if (r := future.result())['Answer'] is not None:
                        valid.append(r['Answer'])
                    results.append(r)
                    if (cnts := Counter(valid).most_common(1)) and cnts[0][1] >= self.cfg.early_stop:
                        stop_evt.set()
                        for f in futures: f.cancel()
                        break
                except Exception as exc:
                    print(f'Future failed: {exc}')
        self.problems_remaining = max(0, self.problems_remaining - 1)
        if results:
            df = pd.DataFrame(results)
            df['Entropy'] = df['Entropy'].round(3)
            df['Answer'] = df['Answer'].astype('Int64')
            display(df)
        return self._select_answer(results) if valid else 0

    def __del__(self):
        """Best-effort teardown: stop the server, close its log file, close pooled sandboxes."""
        if hasattr(self, 'server_process'):
            self.server_process.terminate()
            try:
                # Bounded wait so a server that ignores terminate() cannot hang teardown.
                self.server_process.wait(timeout=60)
            except subprocess.TimeoutExpired:
                self.server_process.kill()
                self.server_process.wait()
        if hasattr(self, 'log_file'): self.log_file.close()
        if hasattr(self, 'sandbox_pool'):
            while not self.sandbox_pool.empty():
                with contextlib.suppress(Exception): self.sandbox_pool.get_nowait().close()

In [None]:
# Heavy start-up: warms the page cache, launches vLLM, waits for it and starts
# the sandbox kernel pool.
solver = AIMO3Solver(cfg=CFG)

In [None]:
def predict(id_: pl.DataFrame, question: pl.DataFrame, answer: Optional[pl.DataFrame] = None) -> pl.DataFrame:
    """Inference-server callback: solve one problem and return its ``id``/``answer`` row.

    Only the first item of *id_* and *question* is used; *answer* is accepted for
    the server's calling convention but ignored. Automatic garbage collection is
    paused while solving (presumably to avoid GC pauses in the worker threads —
    TODO confirm) and is always re-enabled afterwards, even if the solver raises.
    """
    gc.disable()
    try:
        final_answer = solver.solve_problem(question.item(0))
    finally:
        gc.enable()
        gc.collect()
    return pl.DataFrame({'id': id_.item(0), 'answer': final_answer})

In [None]:
# Register `predict` with the competition inference server. serve() is used when
# KAGGLE_IS_COMPETITION_RERUN is set; otherwise a local gateway runs on test.csv.
inference_server = kaggle_evaluation.aimo_3_inference_server.AIMO3InferenceServer(predict)
if not os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
    inference_server.run_local_gateway(
        ('/kaggle/input/ai-mathematical-olympiad-progress-prize-3/test.csv',))
else:
    inference_server.serve()