In [None]:
# Offline install: remove the preinstalled torch, then install/upgrade vLLM from
# the local wheel directory (--no-index: no network access).
# NOTE(review): presumably /kaggle/input/vllm-whl also provides the torch build
# vLLM depends on — confirm.
!pip uninstall -y torch -q
!pip install --no-index --find-links=/kaggle/input/vllm-whl -U vllm -q
# keep data in float16 to avoid OOM
# Patch the installed transformers Llama implementation in place: drop every
# line containing `logits = logits.float()` so logits are not upcast to float32.
# This must run before the llama module is imported (the next cell imports
# LlamaForSequenceClassification). Re-running is a no-op once the line is gone.
# NOTE(review): the path hard-codes python3.10 under /opt/conda — confirm it
# matches the kernel's interpreter. If that line were ever the only statement
# of an indented block, deleting it would leave a syntax error in the module.
file_path = '/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py'
with open(file_path, 'r') as file:
    file_contents = file.readlines()
file_contents = [line for line in file_contents if "logits = logits.float()" not in line]
with open(file_path, 'w') as file:
    file.writelines(file_contents)

In [None]:
from vllm import LLM, SamplingParams
import numpy as np
from transformers import LlamaForSequenceClassification
import torch
# Disable PyTorch's memory-efficient scaled-dot-product-attention backend for this process.
torch.backends.cuda.enable_mem_efficient_sdp(False)

# Policy model served by vLLM: fp16 weights, eager mode (no CUDA graphs), up to
# 99% of its GPU's memory, 4 GiB CPU swap, a 2048-token context and an fp8 KV cache.
llm = LLM(model="/kaggle/input/policy-model",
          dtype='half',
          enforce_eager=True,
          gpu_memory_utilization=0.99,
          swap_space=4,
          max_model_len=2048,
          kv_cache_dtype="fp8_e5m2",
          tensor_parallel_size=1)

tokenizer = llm.get_tokenizer()

# The PRM reuses the policy tokenizer.
# NOTE(review): assumes /kaggle/input/prm-code shares the policy model's vocabulary.
prm_tokenizer = tokenizer
# Process reward model on the second GPU: a Llama backbone with a single-output head.
# ignore_mismatched_sizes=True tolerates a checkpoint head of a different shape;
# the head is overwritten from model_score8_code.pth just below.
prm_model = LlamaForSequenceClassification.from_pretrained('/kaggle/input/prm-code',
                                                           num_labels=1,
                                                           device_map="cuda:1",
                                                           torch_dtype="auto",
                                                           ignore_mismatched_sizes=True,
                                                           ).eval()

# The backbone alone; eval_prm takes its last-token hidden state and applies the head itself.
base_model = prm_model.model
# map_location="cpu": without it torch.load restores tensors onto the device they
# were saved from (possibly cuda:0, which vLLM has almost fully reserved).
# load_state_dict then copies the weights into the head's cuda:1 parameters.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
prm_model.score.load_state_dict(torch.load('/kaggle/input/prm-code/model_score8_code.pth',
                                           map_location="cpu"))

In [None]:
import aimo
# Competition evaluation API: iter_test yields (test, sample_submission) pairs
# that the search loop answers one at a time via env.predict.
env = aimo.make_env()
iter_test = env.iter_test()

In [None]:
def logit2prob(x):
    """Map a raw PRM logit (scalar or array) to a probability via the logistic sigmoid.

    Computed as ``exp(-logaddexp(0, -x))``, which equals ``1 / (1 + exp(-x))``
    but never overflows ``np.exp`` (the naive form warns for x < about -709).
    """
    return np.exp(-np.logaddexp(0.0, -x))


def eval_prm(candidates):
    """Score each candidate text with the process reward model (PRM).

    Each string is tokenized with ``prm_tokenizer`` and moved to ``cuda:1``. It is
    run through the PRM backbone ``base_model``; the final token's hidden state
    goes through ``prm_model.score`` and is squashed with ``logit2prob``.

    Args:
        candidates: sequence of full prompt + partial-solution strings.

    Returns:
        List of probabilities, one per candidate, in input order ([] for no input).
    """
    probs = []
    for text in candidates:
        input_ids = prm_tokenizer.encode(text, return_tensors="pt").to("cuda:1")
        with torch.no_grad():
            # 1,l,d -> 1,d : keep only the last token's hidden state.
            hidden_states = base_model(input_ids)[0][:, -1]
            logits = prm_model.score(hidden_states)[0]
        probs.append(logit2prob(logits.item()))
    return probs

In [None]:
# Stop strings for one generation step: EOS, the start of an "output" block
# (presumably so the model does not invent execution results), or a code fence
# closed right after ")".
# NOTE(review): "```output\n" and "``````output" can never fire before "```output" does.
stop_words = [tokenizer.eos_token,"```output","```Output","```output\n","```Output\n","```\nOutput" , ")\n```" , "``````output","``````Output"]
# stop_words.append("\n")
# Each search level extends every node by at most 180 sampled tokens.
# include_stop_str_in_output=True keeps the matched stop string in the node text.
# NOTE(review): temperature=1 with no seed makes runs non-reproducible.
sampling_params = SamplingParams(temperature=1,
                                 max_tokens=180,
                                #  min_tokens=32,
                                 stop=stop_words,
                                 include_stop_str_in_output=True
                                )

def gen_prompt_codeIn1(problem):
    """Prompt variant 1: ask for a step-by-step sympy plan followed by a full script.

    The text is ``Problem: <problem>``, a blank line, and the instructions. It ends
    with ``Approach:`` so the model continues by writing its plan.
    """
    instructions = "\n".join([
        "To accomplish this, first determine a sympy-based approach for solving the problem by listing each step to take and what functions need to be called in each step. Be clear so even an idiot can follow your instructions, and your final answer should be integer, not expression, list, tuple or dictionary!",
        "Write the entire script covering all the steps (use comments and document it well) and print the final result.",
        "Approach:",
    ])
    return f"Problem: {problem}\n\n" + instructions

def gen_prompt_codeIn2(problem):
    """Prompt variant 2: ask for a commented Python solution that prints an integer."""
    header = "Problem: {}\n\n".format(problem)
    request = (
        "You are an expert at solving math problem. Analyze this problem and think step by step to develop a python solution. "
        "Your solution should include reasoning steps in Python comments, explaining your thought process and the mathematical principles you applied. "
        "print the final output, as an integer not other python object such as list or tuple."
    )
    return header + request

# Tree-search hyper-parameters (consumed by get_next_node, agg_code and the search loop).
n = 1 # beams: max unfinished nodes kept per level in get_next_node
n_sol = 1  # the search loop stops once this many finished paths are collected
samples = 21  # continuations sampled per kept node at each level
max_depth = 24  # the search loop runs only while current_level < max_depth
max_pct = 0.88  # keep only nodes scoring > max_pct * best score of their level
timeout = 7  # seconds each generated script may run in agg_code
len_limit = 49  # agg_code skips extracted code shorter than this many characters


# all_prompts / total_answers are not referenced elsewhere in the visible notebook;
# total_paths only appears in a commented-out debug line at the end.
all_prompts = []
total_paths = []
total_answers = []

def is_integer(num):
    """Return True if ``num`` is an integral number.

    Accepts ``int`` and any other ``numbers.Integral`` (e.g. NumPy integers), and
    floats with an integral value (``5.0``). Rejects ``bool`` even though it
    subclasses ``int``: a script that prints ``True`` must not be read as answer 1.
    Everything else (strings, lists, NaN/inf, ...) returns False.
    """
    import numbers

    if isinstance(num, bool):
        return False
    if isinstance(num, float):
        return num.is_integer()
    return isinstance(num, numbers.Integral)
    
def is_between_0_and_999(num):
    """Return True when ``num`` lies in the closed range [0, 999] (the accepted answer range)."""
    lower_ok = 0 <= num
    return lower_ok and num <= 999

import re
def extract_number(text):
    """Pull the final answer out of free-form model text.

    Patterns are tried in priority order: ``The answer is ... \\boxed{X}``, then
    ``The answer is $N$``, then ``The answer is N``. The first capture group of the
    first pattern that matches is returned as a string; otherwise ``'parse err'``.
    """
    candidate_patterns = (
        r'[Tt]he answer is.*\\boxed\{(.*?)\}',
        r"[Tt]he answer is[:\s]*\$([0-9]+)\$",
        r"[Tt]he answer is[:\s]*([0-9]+)",
    )
    searches = (re.search(pattern, text) for pattern in candidate_patterns)
    first_hit = next((hit for hit in searches if hit is not None), None)
    return 'parse err' if first_hit is None else first_hit.group(1)

def repeat_elements(lst, k):
    """Return a new list in which each item of ``lst`` appears ``k`` times in a row.

    Example: ``repeat_elements(['a', 'b'], 2) == ['a', 'a', 'b', 'b']``.
    """
    repeated = []
    for item in lst:
        repeated.extend([item] * k)
    return repeated

def filter_input(batch_response,current_level_node):
    """Build the PRM inputs for one question from a batch of generated continuations.

    ``batch_response[i]`` is the generation for parent ``current_level_node[i]``;
    only ``outputs[0].text`` is read. Each continuation is appended to its parent
    unless that text already occurs in the parent; the empty string always does,
    so empty generations are dropped too. Duplicates are removed, keeping the
    first occurrence.

    Returns:
        List of unique ``parent + continuation`` strings in first-seen order.
    """
    prm_inputs = []
    for candidate, parent in zip(batch_response, current_level_node):
        continuation = candidate.outputs[0].text
        if continuation not in parent:
            prm_inputs.append(parent + continuation)
    # Order-preserving de-duplication in O(n): dict keys keep insertion order.
    return list(dict.fromkeys(prm_inputs))

def IsFinished(node):
    """Report whether ``node`` contains a complete ``print(...)`` call.

    The prompts ask the model to print its final result, so the search treats
    the first printed value as marking a finished solution.
    """
    return re.search(r'print\(([^)]*)\)', node) is not None

def get_next_node(prm_inputs,prm_scores, completed=None, beam_width=None, pct=None):
    """Choose which candidates to expand at the next search level.

    Candidates are ranked by PRM score, best first. Only candidates scoring
    strictly above ``pct * best_score`` survive:
      * finished ones (``IsFinished``) are appended to ``completed`` as
        ``(score, node)`` tuples (mutated in place);
      * unfinished ones are returned, at most ``beam_width`` of them.

    Args:
        prm_inputs: candidate node texts.
        prm_scores: PRM scores aligned with ``prm_inputs``.
        completed: list receiving finished paths; defaults to the module-level
            ``completed_paths`` used by the search loop.
        beam_width: max unfinished nodes to return; defaults to the global ``n``.
        pct: relative score threshold; defaults to the global ``max_pct``.

    Returns:
        Unfinished node texts to expand next ([] when there is nothing to rank).
    """
    ranked = sorted(zip(prm_inputs, prm_scores), key=lambda pair: pair[1], reverse=True)
    if not ranked:
        return []
    if completed is None:
        completed = completed_paths
    if beam_width is None:
        beam_width = n
    if pct is None:
        pct = max_pct

    threshold = ranked[0][1] * pct
    next_level_nodes = []
    for node, score in ranked:
        if IsFinished(node):
            if score > threshold:
                completed.append((score, node))
        elif len(next_level_nodes) < beam_width and score > threshold:
            # not finished: keep expanding the best few
            next_level_nodes.append(node)
    return next_level_nodes

def repl(match):
    """``re.sub`` callback for ``symbols(...)`` calls: add ``real=True`` when absent.

    If the matched call text already mentions ``real`` anywhere, it is returned
    unchanged. Otherwise ``, real=True`` is inserted before the closing ``)``.
    """
    call_text = match.group()
    without_paren = call_text[:-1]
    closing = ')' if "real" in call_text else ', real=True)'
    return without_paren + closing
    
# Regexes used to normalize generated code before de-duplication.
# NOTE(review): these are purely lexical — a '#' or triple quote inside a string
# literal is treated as a comment/docstring too.
single_line_comment_pattern = re.compile(r'(?<!\\)#.*')
multi_line_comment_pattern = re.compile(r'(\'\'\'|\"\"\")(.*?)(\'\'\'|\"\"\")', flags=re.DOTALL)
trailing_whitespace_pattern = re.compile(r'[ \t]+$', flags=re.MULTILINE)
multiple_blank_lines_pattern = re.compile(r'\n\s*\n')

def remove_python_comments(code):
    """Return ``code`` with comments, triple-quoted strings and blank-line noise removed.

    Steps, in order:
      1. drop ``#`` comments to end of line;
      2. drop triple-quoted blocks (docstrings);
      3. strip trailing spaces/tabs from every line;
      4. collapse each run of blank lines into a single newline.
    """
    cleanup_steps = (
        (single_line_comment_pattern, ''),
        (multi_line_comment_pattern, ''),
        (trailing_whitespace_pattern, ''),
        (multiple_blank_lines_pattern, '\n'),
    )
    for pattern, replacement in cleanup_steps:
        code = pattern.sub(replacement, code)
    return code

import subprocess
import sys
def _extract_code(node):
    """Return the first fenced code block of ``node`` prefixed with a sympy import, or None."""
    # The stop strings can end generation just before the closing ")" of
    # print(result); restore it so the snippet still parses.
    if node[-12:] == "print(result":
        node += ")"
    fence_parts = node.split('```')
    if len(fence_parts) < 2:
        return None
    # [7:] skips the fence's language tag, presumably "python\n" (7 characters).
    # NOTE(review): other tags (e.g. "py\n") would lose real code characters.
    return "from sympy import *\n" + fence_parts[1][7:]


def _run_code(code):
    """Run ``code`` as code.py in a fresh interpreter; return stdout, or None on timeout/stderr."""
    # Explicit UTF-8 so non-ASCII model output cannot fail to encode.
    with open('code.py', 'w', encoding='utf-8') as fout:
        fout.write(code)
    try:
        process = subprocess.run([sys.executable, 'code.py'],
                                 stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                                 timeout=timeout)
    except subprocess.TimeoutExpired:
        return None
    if process.stderr:  # any stderr output counts as a failed run
        return None
    # errors='replace' so undecodable bytes cannot raise out of agg_code.
    return process.stdout.decode('utf8', errors='replace')


def agg_code(paths):
    """Return the answer of the best-scoring finished path whose script runs cleanly.

    Args:
        paths: iterable of ``(score, node)`` tuples; falsy entries are ignored.

    Paths are tried from highest to lowest score. For each one, the code is
    extracted from its first ``` fence. Snippets shorter than ``len_limit``
    characters are skipped, as are snippets whose comment-stripped form was
    already tried. ``real=True`` is added to ``symbols(...)`` calls (via
    ``repl``), and the script runs for at most ``timeout`` seconds. Its stdout
    is parsed as a Python literal.

    Returns:
        The first parsed value that is an integer in [0, 999], as ``int``;
        37 if no path produces one.
    """
    import ast

    ranked = sorted((p for p in paths if p), key=lambda x: x[0], reverse=True)
    seen_code = set()
    for path in ranked:  # path is (score, node)
        code = _extract_code(path[1])
        if code is None:
            continue
        if len(code) < len_limit:  # ignore very short answers
            continue
        clean_code = remove_python_comments(code)
        if clean_code in seen_code:
            continue
        seen_code.add(clean_code)
        code = re.sub(r"symbols\([^)]+\)", repl, code)
        stdout = _run_code(code)
        if stdout is None:
            continue
        try:
            # literal_eval, not eval: stdout comes from model-written code and must
            # not be evaluated as arbitrary Python in this kernel's namespace.
            answer = ast.literal_eval(stdout)
            if is_integer(answer) and is_between_0_and_999(answer):
                return int(answer)
        except Exception:
            continue
    return 37

# Main loop: for each problem, run a PRM-guided tree search from two prompt
# variants, then execute the best finished code path to obtain the answer.
for test, sample_submission in iter_test:
    problem = test['problem'].values[0]
    base_prompt1 = tokenizer.apply_chat_template([{"role": "user","content": gen_prompt_codeIn1(problem)}],tokenize=False)
    base_prompt2 = tokenizer.apply_chat_template([{"role": "user","content": gen_prompt_codeIn2(problem)}],tokenize=False)
    current_level = 1
    current_level_nodes = [base_prompt1,base_prompt2]
    # Module-level on purpose: get_next_node appends finished (score, node) pairs here.
    completed_paths = []
    completed_path_splits = []  # not used by the visible code
    try:
        while (len(completed_paths) < n_sol) and (current_level < max_depth) and current_level_nodes:
            current_level_nodes = repeat_elements(current_level_nodes, samples)
            batch_responses = llm.generate(current_level_nodes, sampling_params)
            prm_inputs = filter_input(batch_responses, current_level_nodes)
            prm_scores = eval_prm(prm_inputs)
            current_level_nodes = get_next_node(prm_inputs, prm_scores)
            # Advance the depth counter so max_depth actually bounds the search.
            current_level += 1
        sample_submission['answer'] = agg_code(completed_paths)
    except Exception as exc:
        # Best effort: always submit something, but record why the search failed.
        print(f"search failed, submitting fallback answer 37: {exc!r}", file=sys.stderr)
        sample_submission['answer'] = 37
    env.predict(sample_submission)

In [None]:
# total_paths
# len(set(current_level_nodes)),len(current_level_nodes),len(set(prm_inputs)),len(prm_inputs)