In [None]:
# credits:
# https://www.kaggle.com/code/bsmit1659/aimo-vllm-accelerated-tot-sc-deepseekmath

In [None]:
!pip uninstall -y torch -q
!pip install --no-index --find-links=/kaggle/input/vllm-whl -U vllm -q

# keep data in float16 to avoid OOM
file_path = '/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py'
with open(file_path, 'r') as file:
    file_contents = file.readlines()
file_contents = [line for line in file_contents if "logits = logits.float()" not in line]
with open(file_path, 'w') as file:
    file.writelines(file_contents)

In [None]:
import ast
import gc
import math
import re
import subprocess
import sys
from collections import Counter, defaultdict

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from vllm import LLM, SamplingParams

# Policy model: DeepSeek-Math served by vLLM in fp16 with an fp8 KV cache, eager
# execution and a 2048-token context window.
llm = LLM(
    model="/kaggle/input/deepseek-math",
    tensor_parallel_size=1,
    dtype='half',
    kv_cache_dtype="fp8_e5m2",
    max_model_len=2048,
    enforce_eager=True,
    gpu_memory_utilization=0.99,
    swap_space=4,
)
tokenizer = llm.get_tokenizer()

# Process reward model (PRM) vocabulary: after a step ending in `step_tag`, the PRM's
# preference between `good_token` and `bad_token` is read off as the step score.
good_token = '+'
bad_token = '-'
step_tag = 'ки'

prm_path = '/kaggle/input/math-shepherd-mistral-7b-prm'
prm_tokenizer = AutoTokenizer.from_pretrained(prm_path)
# Slice off the first id (presumably the prepended BOS token); author observed [648, 387].
prm_candidate_tokens = prm_tokenizer.encode(f"{good_token} {bad_token}")[1:]
# Id of the step tag; author observed 12902 (not referenced by the visible code).
step_tag_id = prm_tokenizer.encode(f"{step_tag}")[-1]
prm_model = AutoModelForCausalLM.from_pretrained(
    prm_path,
    torch_dtype=torch.float16,
    # presumably keeps GPU 0 as empty as possible so vLLM can use it
    device_map="balanced_low_0",
).eval()

In [None]:
# Kaggle AIMO competition API: `iter_test` yields (test, sample_submission) pairs that
# the search loop consumes; each filled-in submission is returned via `env.predict`.
import aimo
env = aimo.make_env()
iter_test = env.iter_test()

In [None]:
def eval_prm(candidates, device="cuda:1"):
    """Score each candidate reasoning path with the process reward model (PRM).

    Each candidate is a PRM-formatted string; in this notebook it ends with the
    ``' ки'`` step tag appended by ``prm_prompt``. The PRM logits at the final
    position are restricted to ``prm_candidate_tokens`` (good token, bad token)
    and softmax-normalised; the probability of the good token is the score.

    Args:
        candidates: list of PRM-formatted strings, scored one at a time.
        device: device the input ids are moved to. The default ``"cuda:1"`` keeps the
            original placement (the PRM is loaded with ``device_map="balanced_low_0"``).

    Returns:
        list[float]: one probability in [0, 1] per candidate, in input order.
    """
    good_probs = []
    for candidate in candidates:
        input_ids = prm_tokenizer.encode(candidate, return_tensors="pt").to(device)
        with torch.no_grad():
            # Only the last position is needed. Indexing it directly (instead of
            # softmaxing every position, then `.squeeze()` + `[-1]`) also works for a
            # one-token input, where squeeze produced a 0-d tensor that cannot be indexed.
            last_logits = prm_model(input_ids).logits[0, -1, prm_candidate_tokens]
            # Upcasting two logits costs nothing and makes the softmax more precise.
            probs = last_logits.float().softmax(dim=-1)
        good_probs.append(probs[0].item())
    return good_probs

In [None]:
# Each generation call is one reasoning step: stop at a newline (or at EOS).
stop_words = [tokenizer.eos_token if tokenizer is not None and tokenizer.eos_token is not None else '</s>']
stop_words.append("\n")

sampling_params = SamplingParams(temperature=1,
                                 max_tokens=256,
                                 min_tokens=32,
                                 stop=stop_words)

# Appended to every problem; fixes the typo "interger" in the text the model reads.
cot_instruction = "\nYou are an expert at mathematical reasoning. Please reason step by step, and put your final answer within \\boxed{}. The answer should be an integer between 0 and 999."


# Tree-search hyper-parameters.
n = 5  # beams: max unfinished nodes kept per level, and completed paths to collect
samples = 7  # continuations sampled per kept node at each level
max_depth = 24  # search stops once current_level reaches this
overlap_threshold = 0.6  # candidates whose step overlap with a kept path is >= this are dropped
# The lists below are not referenced by the visible code (total_paths only by
# commented-out debug lines).
all_prompts = []
total_paths = []
total_answers = []

def is_integer(num):
    """Return True if ``num`` is an int, or a float with no fractional part.

    Any other type (str, complex, tuple, None, ...) yields False. ``bool`` is a
    subclass of ``int`` and therefore counts as an integer; inf and nan do not.
    """
    if isinstance(num, int):
        return True
    return isinstance(num, float) and num.is_integer()
    
def is_between_0_and_999(num):
    """Return True when ``num`` lies in the closed interval [0, 999] (nan is outside)."""
    lower_bound, upper_bound = 0, 999
    return num >= lower_bound and num <= upper_bound

def prm_prompt(text, current_level):
    """Wrap one generated step in the PRM format: ``"Step <level>:<text> ки"``."""
    return "Step " + str(current_level) + ":" + text + " ки"

def remove_prm_prompt(text):
    """Strip the PRM markers added by ``prm_prompt`` from ``text``.

    Every ``"Step <digits>:"`` header is deleted, including any the model wrote
    itself. Then every ``" ки"`` step tag is deleted.
    """
    step_header = re.compile(r"Step \d+:")
    return step_header.sub("", text).replace(" ки", "")

import re
def extract_number(text):
    """Return the raw answer text following "The answer is", or ``'parse err'``.

    Patterns are tried in priority order, and the first one that matches anywhere in
    ``text`` wins:
      1. ``The answer is ... \\boxed{X}``  (X up to the first closing brace)
      2. ``The answer is $123$``
      3. ``The answer is 123``
    The captured text is returned unparsed (a string).
    """
    answer_patterns = (
        r'The answer is.*\\boxed\{(.*?)\}',
        r"The answer is[:\s]*\$([0-9]+)\$",
        r"The answer is[:\s]*([0-9]+)",
    )
    matches = (re.search(pattern, text) for pattern in answer_patterns)
    first_match = next((m for m in matches if m), None)
    return first_match.group(1) if first_match else 'parse err'

def tot_agg(completed_paths, depth_weight=0.00108, default_answer=37):
    """Pick the final answer from the completed reasoning paths.

    Each entry is ``(answer, score, cumulative_tokens, level)``. Paths are ranked by
    ``score + depth_weight * level**2``, a small bonus for paths that finished deeper
    in the search. The best path's answer is returned, with ties going to the
    earliest entry.

    Args:
        completed_paths: list of tuples as above; may be empty.
        depth_weight: weight of the squared-depth bonus.
        default_answer: returned unchanged when ``completed_paths`` is empty.

    Returns:
        The selected answer as ``int``. Answers can arrive as integral floats
        (e.g. ``5.0`` from evaluating ``10/2``), while the task asks for an integer.
    """
    if not completed_paths:
        return default_answer
    best = max(completed_paths, key=lambda path: path[1] + path[3] ** 2 * depth_weight)
    return int(best[0])

def get_overlap(nodes_split,node_list):
    """Return the highest positional overlap between ``node_list`` and any entry of ``nodes_split``.

    For each previous step list, the overlap is the number of positions holding
    equal elements, divided by the longer list's length. The result is
    ``float("-inf")`` when ``nodes_split`` is empty.
    """
    return max(
        (
            sum(a == b for a, b in zip(previous, node_list)) / max(len(previous), len(node_list))
            for previous in nodes_split
        ),
        default=float("-inf"),
    )

# Token budget per sequence; keep in sync with `max_model_len` passed to LLM().
MAX_TOTAL_TOKENS = 2048


def _safe_eval_answer(expr):
    """Evaluate an extracted answer that is plain arithmetic on numeric literals.

    Allowed: int/float literals, unary + and -, binary + - * /, and parentheses.
    This replaces a bare ``eval`` on model-generated text. That ``eval`` could read
    notebook globals (``\\boxed{n}`` evaluated to the beam width 5) or hang the
    kernel (``9**9**9**9``).

    Raises:
        SyntaxError, ValueError, ZeroDivisionError, OverflowError: if ``expr`` is
        not such an expression or cannot be evaluated.
    """
    def _walk(node):
        if isinstance(node, ast.Expression):
            return _walk(node.body)
        # exact type check: rejects bools and complex literals
        if isinstance(node, ast.Constant) and type(node.value) in (int, float):
            return node.value
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.UAdd, ast.USub)):
            operand = _walk(node.operand)
            return -operand if isinstance(node.op, ast.USub) else operand
        if isinstance(node, ast.BinOp) and isinstance(node.op, (ast.Add, ast.Sub, ast.Mult, ast.Div)):
            left, right = _walk(node.left), _walk(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            if isinstance(node.op, ast.Sub):
                return left - right
            if isinstance(node.op, ast.Mult):
                return left * right
            return left / right
        raise ValueError(f"unsupported answer expression: {ast.dump(node)}")

    # eval() tolerated leading blanks; ast.parse does not, hence strip().
    return _walk(ast.parse(expr.strip(), mode="eval"))


for test, sample_submission in iter_test:
    problem = test['problem']
    # NOTE(review): if the API yields `test` as a one-row DataFrame, test['problem'] is
    # a Series that the chat template would render via its repr. Unwrap it to the
    # plain value; this is a no-op for a str. Confirm against the aimo API.
    if hasattr(problem, "iloc"):
        problem = problem.iloc[0]

    messages = [{"role": "user", "content": problem + cot_instruction}]
    base_prompt = tokenizer.apply_chat_template(messages, tokenize=False)

    current_level = 1
    current_level_nodes = [base_prompt]
    current_scores = [float('inf')]  # neutral start for the running-min path score
    completed_paths = []        # (answer, score, cumulative_tokens, level), consumed by tot_agg
    completed_path_splits = []  # step lists of completed paths, for overlap de-duplication
    try:
        while len(completed_paths) < n and current_level < max_depth and current_level_nodes:
            # Expand: sample `samples` one-step continuations of every node. The
            # policy model sees the text with the PRM step markers stripped.
            parents = current_level_nodes * samples
            parent_scores = current_scores * samples
            batch_responses = llm.generate([remove_prm_prompt(t) for t in current_level_nodes] * samples,
                                           sampling_params)

            prm_inputs = []
            cumulative_lens = []
            for candidate, parent in zip(batch_responses, parents):
                completion = candidate.outputs[0]
                prm_inputs.append(parent + prm_prompt(completion.text, current_level))
                cumulative_lens.append(len(candidate.prompt_token_ids) + len(completion.token_ids))

            # De-duplicate identical candidates, keeping the first occurrence. A dict
            # keeps insertion order, so indices stay ascending.
            first_index = {}
            for idx, text in enumerate(prm_inputs):
                first_index.setdefault(text, idx)
            unique_indices = list(first_index.values())
            prm_inputs = [prm_inputs[i] for i in unique_indices]
            current_scores = [parent_scores[i] for i in unique_indices]
            cumulative_lens = [cumulative_lens[i] for i in unique_indices]

            # Score the new step; a path's score is the minimum over all of its steps.
            prm_scores = eval_prm(prm_inputs)
            prm_scores = [min(old, new) for old, new in zip(current_scores, prm_scores)]

            # Prune: walk candidates best-first. Keep up to `n` sufficiently distinct
            # unfinished nodes, and record each distinct finished path with a valid answer.
            next_level_nodes = []
            next_scores = []
            nodes_split = []
            combined = sorted(zip(prm_inputs, prm_scores, cumulative_lens),
                              key=lambda item: item[1], reverse=True)
            for node, score, len_ in combined:
                answer = extract_number(remove_prm_prompt(node))
                node_list = node.split(" ки")
                if answer == 'parse err':  # not finished yet
                    if len_ > MAX_TOTAL_TOKENS:  # token budget exhausted; cannot extend
                        continue
                    if get_overlap(nodes_split, node_list) < overlap_threshold:
                        next_level_nodes.append(node)
                        next_scores.append(score)
                        nodes_split.append(node_list)
                        if len(next_level_nodes) == n:
                            break
                else:  # finished: an answer string was extracted
                    if get_overlap(completed_path_splits, node_list) >= overlap_threshold:
                        continue
                    try:
                        value = _safe_eval_answer(answer)
                    except Exception:  # not plain arithmetic (LaTeX, words, ...)
                        continue
                    if is_integer(value) and is_between_0_and_999(value):
                        completed_paths.append((value, score, len_, current_level))
                        completed_path_splits.append(node_list)
            # An empty next level (all finished, too long, or duplicates) ends the search.
            current_scores, current_level_nodes = next_scores, next_level_nodes
            current_level += 1
    except Exception as exc:
        # Best effort: keep the paths completed before the failure instead of
        # discarding them. tot_agg falls back to its default when there are none.
        print(f"tree search aborted: {exc!r}", file=sys.stderr)

    sample_submission['answer'] = tot_agg(completed_paths)
    env.predict(sample_submission)

In [None]:
# total_paths
# len(set(current_level_nodes)),len(current_level_nodes),len(set(prm_inputs)),len(prm_inputs)