In [None]:
# import gc
# del llm
# gc.collect()
# torch.cuda.empty_cache()

In [10]:
try:
    from vllm import LLM, SamplingParams
    LOCAL = True
    MODEL_PATH = "deepseek-ai/deepseek-math-7b-rl"
    from functions import *
    dtype = 'auto'
    gpu_memory_utilization = 0.95

except:
    %pip uninstall -y torch -q
    %pip install --no-index --find-links=/kaggle/input/vllm-whl -U vllm -q
    from vllm import LLM, SamplingParams
    LOCAL = False
    MODEL_PATH = "/kaggle/input/deepseek-math"
    dtype = 'half'
    gpu_memory_utilization = 0.99
    from functions_math import *


import torch
import pandas as pd
import subprocess
import sys
import gc
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
# Instantiate the vLLM engine with the model path, dtype and memory fraction
# selected in the environment-setup cell.
llm = LLM(
    model=MODEL_PATH,
    tensor_parallel_size=1,
    dtype=dtype,
    kv_cache_dtype="fp8_e5m2",
    max_model_len=2048,
    gpu_memory_utilization=gpu_memory_utilization,
    swap_space=8,
    enforce_eager=True,
)
tokenizer = llm.get_tokenizer()

# Stop strings: the tokenizer's EOS token when it is available (otherwise the
# literal '</s>'), followed by a newline.
stop_words = ['</s>', "\n"]
if tokenizer is not None and tokenizer.eos_token is not None:
    stop_words[0] = tokenizer.eos_token

# Sampling configuration shared by every generation call in this notebook.
sampling_params = SamplingParams(
    temperature=1,
    min_tokens=32,
    max_tokens=256,
    stop=stop_words,
)

# Instruction text that process_inputs appends to every problem statement
# before the chat template is applied.
# NOTE(review): "interger" is a typo for "integer"; left untouched here because
# editing the prompt changes what the model receives.
cot_instruction = "\nYou are an expert at mathematical reasoning. Please reason step by step, and put your final answer within \\boxed{}. The answer should be an interger between 0 and 999."


# Search hyperparameters.
# `n`, `n_sol` and `max_pct` are not referenced in the visible cells; their
# exact roles should be documented where they are consumed — TODO confirm.
n = 1 # beams
n_sol = 4  # presumably the number of solutions to collect per problem — TODO confirm
samples = 16  # each current node is repeated this many times before generation
max_depth = 24  # the search loop runs while current_level < max_depth
max_pct = 0.66  # purpose not visible in this part of the notebook — TODO document

In [11]:
import json
# Load the problem set from JSON into a DataFrame and expose the question text
# under the 'problem' column, to have a consistent format as in Kaggle.
with open('../Data/AMC/aime_normal.json', 'r') as file:
    data = pd.DataFrame(json.load(file)).rename(columns={'question': 'problem'})

In [14]:
def process_inputs(inputs):
    """Turn raw problem statements into chat-formatted prompts.

    Args:
        inputs: list of str, one problem statement per element.

    Returns:
        A list, in the same order as ``inputs``, holding for each problem the
        result of ``tokenizer.apply_chat_template(..., tokenize=False)`` applied
        to a single user message made of the problem followed by
        ``cot_instruction``.
    """
    def as_prompt(problem):
        # One user turn: the problem text plus the shared reasoning instruction.
        user_turn = {"role": "user", "content": problem + cot_instruction}
        return tokenizer.apply_chat_template([user_turn], tokenize=False)

    return [as_prompt(problem) for problem in inputs]

In [5]:
def logit2prob(x):
    """Map a logit ``x`` to a probability with the logistic function 1 / (1 + e^-x)."""
    return 1 / (1 + np.exp(-x))
def eval_prm(candidates):
    """Score candidate texts with the reward model.

    Args:
        candidates: sequence of str to be scored.

    Returns:
        A list with one float per candidate, in input order: the reward-model
        logit for the candidate passed through ``logit2prob``.

    Relies on the module-level ``prm_tokenizer``, ``base_model`` and
    ``prm_model``; inputs are moved to the hard-coded device ``"cuda:1"``.
    """
    scores = []
    for text in candidates:
        token_ids = prm_tokenizer.encode(text, return_tensors="pt").to("cuda:1")
        with torch.no_grad():
            # Keep only the last position of the base model's first output
            # (1,l,d -> 1,d per the original author's note).
            last_hidden = base_model(token_ids)[0][:, -1]
            logit = prm_model.score(last_hidden)[0]
        scores.append(logit2prob(logit.item()))
    return scores

In [3]:
def repeat_elements(lst, k):
    """Return a new list in which every element of ``lst`` appears ``k`` times in a row.

    Order is preserved: ``[a, b]`` with ``k=2`` gives ``[a, a, b, b]``.
    """
    repeated = []
    for element in lst:
        repeated.extend(element for _ in range(k))
    return repeated

def flatten(nested_list):
    """Concatenate the sub-lists of ``nested_list`` into one flat list.

    Returns:
        A tuple ``(flat, sizes)`` where ``flat`` holds every item in order and
        ``sizes[i]`` is the length of the i-th sub-list, so the nesting can be
        rebuilt later with ``unflatten``.
    """
    flat, sizes = [], []
    for group in nested_list:
        sizes.append(len(group))
        flat.extend(group)
    return flat, sizes

def unflatten(flat_list, lengths):
    """Split ``flat_list`` into consecutive chunks whose sizes are given by ``lengths``.

    Returns:
        A list of slices of ``flat_list``; the i-th slice has ``lengths[i]``
        items (fewer if ``flat_list`` runs out). Items beyond ``sum(lengths)``
        are not included.
    """
    chunks = []
    start = 0
    for size in lengths:
        stop = start + size
        chunks.append(flat_list[start:stop])
        start = stop
    return chunks

def filter_input(batch_response,current_level_node):
    """Build the de-duplicated continuations for one group of sampled responses.

    Args:
        batch_response: generation results for one group; each element must
            expose the generated text as ``element.outputs[0].text``.
        current_level_node: the parent prompt (str) for each response, aligned
            element-by-element with ``batch_response``.

    Returns:
        A tuple ``(prm_inputs, count)``: ``prm_inputs`` is the list of
        ``parent + generated_text`` strings, keeping only responses whose text
        does not already occur inside the parent, with duplicates removed
        (first occurrence wins, original order kept); ``count`` is its length.
    """
    extended = (
        parent + candidate.outputs[0].text
        for candidate, parent in zip(batch_response, current_level_node)
        # Drop generations that merely repeat something already in the parent
        # (this also drops empty generations, since '' is in every string).
        if candidate.outputs[0].text not in parent
    )
    # dict.fromkeys keeps first-seen order, giving the same result as the
    # previous list.index-based dedup in linear instead of quadratic time.
    prm_inputs = list(dict.fromkeys(extended))
    return prm_inputs,len(prm_inputs)

def filter_inputs(batch_responses,current_level_nodes,lengths):
    """Apply ``filter_input`` group by group to a flattened batch.

    Args:
        batch_responses: flat list of generation results.
        current_level_nodes: flat list of parent prompts, aligned with
            ``batch_responses``.
        lengths: size of each consecutive group inside the two flat lists.

    Returns:
        A tuple ``(prm_inputs, new_lengths)``: the surviving continuations of
        all groups concatenated in order, and the number that survived in each
        group.
    """
    grouped_responses = unflatten(batch_responses, lengths)
    grouped_parents = unflatten(current_level_nodes, lengths)

    prm_inputs, new_lengths = [], []
    for responses, parents in zip(grouped_responses, grouped_parents):
        kept, kept_count = filter_input(responses, parents)
        prm_inputs.extend(kept)
        new_lengths.append(kept_count)
    return prm_inputs, new_lengths


def get_next_nodes(prm_inputs,prm_scores,lengths,completed_paths):
    """Choose the nodes to expand at the next search level (NOT IMPLEMENTED YET).

    The body is a placeholder, so the function currently returns None. The
    search loop in this notebook unpacks the result into three values
    (current_level_nodes, completed_paths, lengths), which raises a TypeError
    until this function is implemented.

    Args:
        prm_inputs: flat list of candidate texts that were scored.
        prm_scores: one score per entry of ``prm_inputs``.
        lengths: per-group sizes describing how ``prm_inputs`` is partitioned.
        completed_paths: one list per problem, passed in by the search loop.
    """
    # for completed_paths, next_level_nodes would be removed
    # TODO: implement the selection and return the three values the caller expects.
    pass

# Level-by-level search over all problems at once.
# Start with one chat-formatted prompt per problem.
current_level_nodes = process_inputs(data.problem.tolist())
# lengths[i] is the number of entries of the flat node list belonging to group i
# (initially one node per problem).
lengths = [1] * len(current_level_nodes)
current_level = 1
# One (initially empty) list per problem; filled by get_next_nodes.
completed_paths = [[] for _ in current_level_nodes]

# Continue until the depth limit is reached or there are no nodes left to expand.
while (current_level < max_depth) and (current_level_nodes):
    # everything at this level is flattened
    # Repeat every node `samples` times and scale the group sizes to match.
    current_level_nodes = repeat_elements(current_level_nodes,samples)
    lengths = [l*samples for l in lengths]

    # Generate one continuation per (repeated) node, then build the
    # parent+continuation texts per group; `lengths` becomes the per-group
    # count returned by filter_inputs.
    batch_responses = llm.generate(current_level_nodes, sampling_params)
    prm_inputs,lengths = filter_inputs(batch_responses,current_level_nodes,lengths)

    # Score every surviving candidate.
    prm_scores = eval_prm(prm_inputs)

    # NOTE(review): get_next_nodes is currently a stub that returns None, so
    # this unpacking raises TypeError on the first iteration.
    current_level_nodes,completed_paths,lengths = get_next_nodes(prm_inputs,prm_scores,lengths,completed_paths)
    current_level += 1


True