In [1]:
from __future__ import annotations
import os
import json
import torch
import argparse
from tqdm import tqdm
from datetime import datetime
from omegaconf import OmegaConf
from rstar_deepthink.agents import BS, MCTS
from rstar_deepthink.solver import Solver
from rstar_deepthink.config import BaseConfig


In [2]:
# Process-wide runtime settings for this notebook session.
# Export TOKENIZERS_PARALLELISM=false to this process and any child processes
# (presumably consumed by the tokenizer library - confirm).
os.environ.update(TOKENIZERS_PARALLELISM="false")
# Fix the number of CPU threads PyTorch may use to 12.
torch.set_num_threads(12)

def load_qaf(filename: str):
    """Load a question-and-answer file and return its parsed content.

    Two formats are supported, selected by file extension:

    * ``.json``  - the whole file is parsed as one JSON document. If the
      document is an object with an ``"example"`` key, the value stored under
      that key is returned; otherwise the parsed document is returned as is.
    * ``.jsonl`` - every non-blank line is parsed as its own JSON value and
      the values are returned as a list, in file order.

    Files are read as UTF-8, matching how results are written back out.

    Args:
        filename: Path to the ``.json`` or ``.jsonl`` file.

    Returns:
        The parsed data (a list for ``.jsonl``; whatever the document holds
        for ``.json``).

    Raises:
        ValueError: If ``filename`` has neither a ``.json`` nor a ``.jsonl``
            extension.
    """
    if filename.endswith(".json"):
        with open(filename, "r", encoding="utf-8") as f:
            data = json.load(f)
        # Only unwrap when the root is a mapping: a root *list* that merely
        # contains the string "example" must not be indexed by that string.
        if isinstance(data, dict) and "example" in data:
            data = data["example"]
    elif filename.endswith(".jsonl"):
        with open(filename, "r", encoding="utf-8") as f:
            # Skip blank lines (e.g. trailing newlines) instead of letting
            # json.loads fail on an empty document.
            data = [json.loads(line) for line in f if line.strip()]
    else:
        raise ValueError(f"Unrecognized file format: {filename}")
    return data

def batch(iterable, n=-1):
    """Yield consecutive slices of ``iterable`` containing at most ``n`` items.

    Args:
        iterable: A sized, sliceable sequence (it must support ``len()`` and
            slicing, e.g. a list).
        n: Maximum number of items per slice. A value ``<= 0`` means "one
            slice holding everything".

    Yields:
        Slices of ``iterable`` in order; the last one may be shorter than
        ``n``. Nothing is yielded for an empty input.
    """
    total = len(iterable)
    if total == 0:
        # Without this guard, n <= 0 would turn into a step of 0 and
        # range(0, 0, 0) raises ValueError.
        return
    if n <= 0:
        n = total
    for start in range(0, total, n):
        # Slicing clamps at the end of the sequence, so no explicit min() is needed.
        yield iterable[start:start + n]

def parse_args(argv=None):
    """Parse the options that configure an evaluation run.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, exactly as before.
            Passing ``[]`` returns the defaults, which allows the function to
            be called where the process arguments are not meant for this
            parser.

    Returns:
        argparse.Namespace with the attributes ``custom_cfg``, ``qaf``,
        ``model_dir``, ``reward_model_dir`` and ``save_in_model``.
    """
    base_dir = '/groups/kjun/tnn/datasets/'

    # NOTE: base_dir already ends with "/", so these defaults contain "//".
    # Kept unchanged so the default values stay identical to earlier runs.
    llm_tokenizer_dir = base_dir + "/Llama-3.2-1B-Instruct"
    prm_tokenizer_dir = base_dir + "/Llama3.1-8B-PRM-Deepseek-Data"

    parser = argparse.ArgumentParser()
    parser.add_argument('--custom_cfg', type=str, default="config/sft_eval_mcts.yaml")
    parser.add_argument("--qaf", type=str, default="eval_data/math500_test.json", help="question and answer file")
    parser.add_argument('--model_dir', type=str, default=llm_tokenizer_dir)
    parser.add_argument('--reward_model_dir', type=str, default=prm_tokenizer_dir)
    parser.add_argument('--save_in_model', type=str, default="results/")
    return parser.parse_args(argv)

In [3]:
# args = parse_args()
# The run arguments are set by hand below instead of being parsed from the command line.

base_dir = '/groups/kjun/tnn/datasets/'

# dataset path
data_dir = base_dir + "/prm800k/math_splits"

llm_tokenizer_dir = base_dir + "/Llama-3.2-1B-Instruct"
# Reward-model location. Alternatives that were tried are kept as comments;
# only the last uncommented assignment ever took effect.
# prm_tokenizer_dir = base_dir + "/Llama3.1-8B-PRM-Deepseek-Data"
# prm_tokenizer_dir = base_dir + "/Llama-3.2-1B-RM-GSM8k"
prm_tokenizer_dir = "results/Llama3.1-8B-PRM-Deepseek-Data"

args = argparse.Namespace()
args.custom_cfg = "config/sft_eval_mcts.yaml"
args.qaf = "eval_data/math500_test.json"
args.model_dir = llm_tokenizer_dir
args.reward_model_dir = prm_tokenizer_dir
args.save_in_model = "results/"

# Start from the structured defaults, overlay the YAML file, then resolve
# interpolations by round-tripping through YAML.
config = OmegaConf.structured(BaseConfig)
if args.custom_cfg:
    custom_config = OmegaConf.load(args.custom_cfg)
    config = OmegaConf.merge(config, custom_config)
config = OmegaConf.create(OmegaConf.to_yaml(config, resolve=True))
# Explicit arguments win over whatever the config files say.
if args.model_dir:
    config.model_dir = args.model_dir
if args.reward_model_dir:
    config.reward_model_dir = args.reward_model_dir
print(config)

llm_version = os.path.basename(config.model_dir.rstrip("/"))

data = load_qaf(args.qaf)
# Keep only the record at position 100 for this single-problem run.
data = data[100:101]
solver = Solver(config=config)

# init agent
if config.mode == "mcts":
    agent = MCTS
elif config.mode == "bs":
    agent = BS
else:
    raise NotImplementedError(f"Unsupported mode: {config.mode!r}")
if args.reward_model_dir:
    # Same basename logic as for the LLM, so a trailing "/" does not yield an empty suffix.
    llm_version += "." + os.path.basename(args.reward_model_dir.rstrip("/"))

print(llm_version)

timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
saved_jsonl_file = f"{args.qaf}.{config.mode}.{llm_version}.{timestamp}.jsonl"

if args.save_in_model:
    if args.save_in_model.endswith(("/", os.sep)) or os.path.isdir(args.save_in_model):
        # A directory was given: keep the descriptive, timestamped file name
        # inside it instead of producing a hidden "<dir>/.jsonl" file.
        saved_jsonl_file = os.path.join(args.save_in_model, os.path.basename(saved_jsonl_file))
    else:
        # A file prefix was given: the output is "<prefix>.jsonl".
        saved_jsonl_file = args.save_in_model + '.jsonl'
    saved_jsonl_file_dir = os.path.dirname(saved_jsonl_file)
    if saved_jsonl_file_dir:
        # os.makedirs("") raises, so only create a directory when there is one.
        os.makedirs(saved_jsonl_file_dir, exist_ok=True)
    


{'mode': 'mcts', 'model_dir': '/groups/kjun/tnn/datasets//Llama-3.2-1B-Instruct', 'reward_model_dir': 'results/Llama3.1-8B-PRM-Deepseek-Data', 'few_shot_path': './rstar_deepthink/few_shots/few_shots.json', 'prompt_path': './rstar_deepthink/few_shots/sft_prompt.json', 'num_few_shot': 0, 'prompt_wrap': 'rstar', 'result_unwrap': 'rstar', 'step_delim': '\n', 'temperature': 1.0, 'top_p': 1.0, 'top_k': -1, 'use_beam_search': False, 'best_of': 32, 'max_tokens': 2048, 'seed': None, 'swap_space': 12, 'n_generate_sample': 16, 'stop': ['<end_of_step>', '<end_of_code>', '<end_of_answer>'], 'step_beam_width': 1, 'max_depth': 16, 'iterations': 2, 'positive_reward': 1.0, 'negative_reward': -1.0, 'errors_threshold': 1, 'need_value_func': True, 'update_leaf_value': True, 'c_puct': 2.0, 'is_sampling': False, 'prune': False, 'batch_size': 8000, 'max_model_len': 4096, 'terminal_sample': False, 'llm_gpu_memory_utilization': 0.5, 'tp': 1, 'save_intermediate_rollouts': True}
INFO 06-04 13:09:33 config.py:227

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


INFO 06-04 13:09:44 model_runner.py:1099] Loading model weights took 2.3185 GB
INFO 06-04 13:09:44 config.py:510] This model supports multiple tasks: {'score', 'classify', 'embed', 'reward', 'generate'}. Defaulting to 'generate'.
INFO 06-04 13:09:44 llm_engine.py:234] Initializing an LLM engine (v0.6.6.post1) with config: model='/groups/kjun/tnn/datasets//Llama-3.2-1B-Instruct', speculative_config=None, tokenizer='/groups/kjun/tnn/datasets//Llama-3.2-1B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(o

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


INFO 06-04 13:09:46 model_runner.py:1099] Loading model weights took 2.3029 GB
INFO 06-04 13:09:46 worker.py:241] Memory profiling takes 0.45 seconds
INFO 06-04 13:09:46 worker.py:241] the current vLLM instance can use total_gpu_memory (31.73GiB) x gpu_memory_utilization (0.50) = 15.87GiB
INFO 06-04 13:09:46 worker.py:241] model weights take 2.30GiB; non_torch_memory takes 0.02GiB; PyTorch activation peak memory takes 1.19GiB; the rest of the memory reserved for KV Cache is 12.35GiB.
INFO 06-04 13:09:46 gpu_executor.py:76] # GPU blocks: 25300, # CPU blocks: 24576
INFO 06-04 13:09:46 gpu_executor.py:80] Maximum concurrency for 4096 tokens per request: 98.83x
INFO 06-04 13:09:52 llm_engine.py:431] init engine (profile, create kv cache, warmup model) took 5.89 seconds
Llama-3.2-1B-Instruct.Llama3.1-8B-PRM-Deepseek-Data


In [4]:
# NOTE(review): `stop` is not defined in any visible cell, so evaluating it raises
# NameError (see this cell's output). Presumably a deliberate barrier that halts
# "Run All" before the cells below - confirm, or remove this cell.
stop

NameError: name 'stop' is not defined

In [5]:
# Show the records currently held in `data` (the slice selected during setup).
print(data)

[{'index': 100, 'question': 'A hexagon is inscribed in a circle: [asy]\npair pA, pB, pC, pD, pE, pF, pO;\npO = (0, 0);\npA = pO + dir(-10);\npB = pO + dir(60);\npC = pO + dir(130);\npD = pO + dir(170);\npE = pO + dir(-160);\npF = pO + dir(-80);\ndraw(pA--pB--pC--pD--pE--pF--pA);\nlabel("$105^\\circ$", pF, N * 2);\nlabel("$110^\\circ$", pB, SW * 1.5);\nlabel("$\\alpha$", pD, E);\ndraw(circle(pO, 1));\n[/asy] What is the measure of $\\alpha$, in degrees?', 'answer': '145^\\circ'}]


In [7]:
print(solver.max_agent_steps)
# data = data[:1]
# Reload the dataset and keep only the record at position 100.
data = load_qaf(args.qaf)
data = data[100:101]
# Overwrite that record's question and answer with a hand-written problem.
# The record's other fields (e.g. 'index') keep their original values.
# Raw strings keep the LaTeX backslashes literal: "\s" in a normal string
# literal is an invalid escape sequence and triggers a warning.
data[0]['question'] = r"Simplify $\sqrt{242}$."
data[0]['answer'] = r"11\sqrt2"

2


In [8]:
# Show `data` again to check the question/answer override from the previous cell.
print(data)

[{'index': 100, 'question': 'Simplify $\\sqrt{242}$.', 'answer': '11\\sqrt2'}]


In [9]:
# Solve the dataset batch by batch, appending one JSON line per record to
# `saved_jsonl_file` and flushing after every record.
with open(saved_jsonl_file, "a+", encoding='utf-8') as writer:
    for cur_data in tqdm(batch(data, config.batch_size), desc="Main Processing"):
        # Build one agent per record in the current batch.
        agents = []
        for d in cur_data:
            agents.append(
                agent(config=config, question=d["question"], ground_truth=str(d["answer"]))
            )
        # The solver's result is indexed by question text below.
        jsonlines = solver.solve(agents, saved_jsonl_file, cur_data)
        for d in cur_data:
            question = d["question"]
            # Attach the solver result to the record before serialising it.
            d["rstar"] = jsonlines[question]
            serialized = json.dumps(d, ensure_ascii=False)
            writer.write(serialized + '\n')
            writer.flush()

Main Processing: 0it [00:00, ?it/s]

-----------------Current Rollout:  0 -----------------
-----------------Current Step:  0 -----------------

-> current_agents
[MCTSNode(state={'text': '', 'extra_info': 'question: Simplify $\\sqrt{242}$.', 'action': '', 'action_input': '', 'final_answer': ''}, additional_state_keys=['action', 'action_input', 'final_answer'], parent=None, children=[], depth=0, is_terminal=False, reward=None, value=0, tag='0', consecutive_errors=0, c_puct=2.0, inited=False)]
current_nodes
[MCTSNode(state={'text': '', 'extra_info': 'question: Simplify $\\sqrt{242}$.', 'action': '', 'action_input': '', 'final_answer': ''}, additional_state_keys=['action', 'action_input', 'final_answer'], parent=None, children=[], depth=0, is_terminal=False, reward=None, value=0, tag='0', consecutive_errors=0, c_puct=2.0, inited=False)]
candidate_nodes
[MCTSNode(state={'text': '', 'extra_info': 'question: Simplify $\\sqrt{242}$.', 'action': '', 'action_input': '', 'final_answer': ''}, additional_state_keys=['action', 'actio


Processed prompts:   0%|          | 0/32 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s][A
Processed prompts:   3%|▎         | 1/32 [00:19<10:05, 19.53s/it, est. speed input: 1.69 toks/s, output: 478.60 toks/s][A


current_nodes
[MCTSNode(state={'text': '', 'extra_info': 'question: Simplify $\\sqrt{242}$.', 'action': '', 'action_input': '', 'final_answer': ''}, additional_state_keys=['action', 'action_input', 'final_answer'], parent=None, children=[MCTSNode(state={'text': " Here's the step-by-step solution for the given expression: \\begin{align*}\n\\sqrt{242} &amp;= \\sqrt{2\\cdot121} \\\\\n&amp;= \\sqrt{2}\\cdot\\sqrt{121} \\\\\n&amp;= \\sqrt{2} \\cdot \\sqrt{11^2} \\\\\n&amp;= \\sqrt{2}\\cdot121\n\\end{align*}As we can see, we have repeated the square root value, which we know from earlier. This is the base for some easy-to-simplify expressions. The expression can be demonstrated as \\begin{align*}\n\\sqrt{2} \\cdot 121 &amp;= \\sqrt{242} \\\\\n0.5 \\z230 &amp;= \\sqrt{242} \n\\end{align*}This means that $\\sqrt{242} = 11\\sqrt{2}$. Therefore, we conclude that $242 = 121 \\cdot 2$, and since $\\sqrt{121} = 11$, it follows that $\\sqrt{242} = 11\\sqrt{2}$. This is our final solution! \\end{alig


Processed prompts:   0%|          | 0/11 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s][A
Processed prompts: 100%|██████████| 11/11 [00:00<00:00, 61.50it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s][A


TypeError: unsupported operand type(s) for /: 'int' and 'Equality'
TypeError: unsupported operand type(s) for /: 'int' and 'Equality'
TypeError: unsupported operand type(s) for /: 'int' and 'Equality'
TypeError: unsupported operand type(s) for /: 'int' and 'Equality'
TypeError: unsupported operand type(s) for /: 'int' and 'Equality'
TypeError: unsupported operand type(s) for /: 'int' and 'Equality'
TypeError: unsupported operand type(s) for /: 'int' and 'Equality'
TypeError: unsupported operand type(s) for /: 'int' and 'Equality'
TypeError: unsupported operand type(s) for /: 'int' and 'Equality'
TypeError: unsupported operand type(s) for /: 'int' and 'Equality'
TypeError: unsupported operand type(s) for /: 'int' and 'Equality'
TypeError: unsupported operand type(s) for /: 'int' and 'Equality'
TypeError: unsupported operand type(s) for /: 'int' and 'Equality'
TypeError: unsupported operand type(s) for /: 'int' and 'Equality'
TypeError: unsupported operand type(s) for /: 'int' and 'Equal

Main Processing: 1it [00:37, 37.07s/it]


-> invalid_agents

-> expanded_agents
-----------------Current Rollout:  0 -----------------
-----------------Current Step:  1 -----------------

-> current_agents
[]

-> prompts
[]

-> valid_agents

-> invalid_agents
[]

-> expanded_agents
-----------------Current Rollout:  1 -----------------
-----------------Current Step:  0 -----------------

-> current_agents
[]

-> prompts
[]

-> valid_agents

-> invalid_agents
[]

-> expanded_agents





In [None]:
# Inspect the `terminal_sample` setting carried by the first agent of the most recently built batch.
print(agents[0].config.terminal_sample)