In [1]:
import os, psutil, gc
import json
import time
import copy
import pprint
from collections import defaultdict

import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm

In [2]:

import torch 
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams


INFO 03-22 21:57:23 [__init__.py:256] Automatically detected platform cuda.


In [3]:
# from sal.models.reward_models import RLHFFlow

from sal.search.utils import Beam, build_conv, generate_k_steps, last
from sal.config import Config
from sal.models.reward_models import PRM
from sal.utils.score import aggregate_scores

from reward_models import RLHFFlow

In [4]:
# Root directory that holds the dataset, the model weights and the tokenizers.
base_path = '/groups/kjun/tnn/datasets/'

# All derived paths are built with os.path.join rather than string
# concatenation: base_path already ends with '/', so `base_path + "/..."`
# produced paths containing a doubled separator ('//').

# dataset path
dataset_path = os.path.join(base_path, "prm800k", "math_splits")

# llm and prm path (GGUF-quantized weight files)
llm_path = os.path.join(
    base_path, "Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct.Q4_K_M.gguf"
)
prm_path = os.path.join(
    base_path,
    "Llama3.1-8B-PRM-Deepseek-Data-GGUF",
    "Llama3.1-8B-PRM-Deepseek-Data.Q4_K_M.gguf",
)

# Model directories. Despite the "tokenizer" in their names, these are the
# paths the notebook passes as the model itself (vLLM `model=` and
# RLHFFlow `model_path=`), and as the tokenizer for the GGUF alternative.
llm_tokenizer_path = os.path.join(base_path, "Llama-3.2-1B-Instruct")
prm_tokenizer_path = os.path.join(base_path, "Llama3.1-8B-PRM-Deepseek-Data")

In [5]:
# Build the vLLM generation engine from the local checkpoint directory.
# (A previously tried setting was gpu_memory_utilization=0.2.)
llm_engine_kwargs = dict(
    model=llm_tokenizer_path,
    gpu_memory_utilization=0.5,   # fraction of GPU memory vLLM may use (50%)
    max_model_len=10000,
    # enable_prefix_caching=True,  # Optimize repeated prefix computations
    dtype="float16",
    seed=123,
)
llm_vllm = LLM(**llm_engine_kwargs)

# Alternative: load the GGUF-quantized weights with the original tokenizer.
# llm_vllm = LLM(
#     model = llm_path,
#     tokenizer = llm_tokenizer_path,
#     tensor_parallel_size=1,
#     gpu_memory_utilization = 0.2,  # Utilize 20% of GPU memory
#     max_model_len = 10000,
#     dtype = "float16",
#     seed = 123)

# Run the garbage collector, release PyTorch's cached CUDA blocks, then print
# the memory currently allocated by PyTorch on device 0 (bytes -> GiB).
gc.collect()
torch.cuda.empty_cache()
print('#--- memory:', torch.cuda.memory_allocated(0)/(1024**3))

INFO 03-22 21:57:43 [config.py:583] This model supports multiple tasks: {'classify', 'generate', 'score', 'reward', 'embed'}. Defaulting to 'generate'.
INFO 03-22 21:57:43 [llm_engine.py:241] Initializing a V0 LLM engine (v0.8.1) with config: model='/groups/kjun/tnn/datasets//Llama-3.2-1B-Instruct', speculative_config=None, tokenizer='/groups/kjun/tnn/datasets//Llama-3.2-1B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=10000, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=Fal

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


INFO 03-22 21:57:48 [loader.py:429] Loading weights took 1.32 seconds
INFO 03-22 21:57:48 [model_runner.py:1146] Model loading took 2.3185 GB and 1.605521 seconds
INFO 03-22 21:57:49 [worker.py:267] Memory profiling takes 0.76 seconds
INFO 03-22 21:57:49 [worker.py:267] the current vLLM instance can use total_gpu_memory (31.73GiB) x gpu_memory_utilization (0.50) = 15.87GiB
INFO 03-22 21:57:49 [worker.py:267] model weights take 2.32GiB; non_torch_memory takes 0.09GiB; PyTorch activation peak memory takes 1.21GiB; the rest of the memory reserved for KV Cache is 12.25GiB.
INFO 03-22 21:57:49 [executor_base.py:111] # cuda blocks: 25088, # CPU blocks: 8192
INFO 03-22 21:57:49 [executor_base.py:116] Maximum concurrency for 10000 tokens per request: 40.14x
INFO 03-22 21:57:51 [model_runner.py:1442] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.

Capturing CUDA graph shapes: 100%|██████████| 35/35 [00:16<00:00,  2.16it/s]

INFO 03-22 21:58:07 [model_runner.py:1570] Graph capturing finished in 16 secs, took 0.13 GiB
INFO 03-22 21:58:07 [llm_engine.py:447] init engine (profile, create kv cache, warmup model) took 18.96 seconds





#--- memory: 14.584884643554688


In [6]:
# del(llm_vllm)
# Collect garbage and release cached CUDA blocks before measuring.
gc.collect()
torch.cuda.empty_cache()
# Memory allocated by PyTorch on device 0, converted from bytes to GiB.
allocated_gib_dev0 = torch.cuda.memory_allocated(0) / (1024**3)
print('#--- memory:', allocated_gib_dev0)

#--- memory: 14.584884643554688


In [7]:
# Load the process reward model onto CUDA device index 1.
prm = RLHFFlow(
    model_path=prm_tokenizer_path,
    device_map='cuda:1',
)

# Collect garbage and release cached CUDA blocks, then print the memory
# allocated by PyTorch on device 1 (bytes -> GiB).
gc.collect()
torch.cuda.empty_cache()
allocated_gib_dev1 = torch.cuda.memory_allocated(1) / (1024**3)
print('#--- memory:', allocated_gib_dev1)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

#--- memory: 14.95752763748169


In [8]:
# Group the test problems by their 'level' field; keys are stored as strings.
data_by_levels = defaultdict(list)
with open(f"{dataset_path}/test.jsonl", 'r', encoding='utf-8') as jsonl_file:
    # The file holds one JSON object per line; whitespace-only lines are skipped.
    for raw_line in jsonl_file:
        if not raw_line.strip():
            continue
        record = json.loads(raw_line)
        data_by_levels[str(record['level'])].append(record)

# Print the number of problems for each of the levels "1" .. "5".
for level_key in map(str, range(1, 6)):
    n_problems = len(data_by_levels[level_key])
    print(f"{level_key}: {n_problems}")

# Optional: fixed per-trial seeds loaded from a file (currently unused).
# random_seeds = np.loadtxt("random_seeds.txt").astype("int64")
# random_seeds = [int(seed) for seed in random_seeds]

1: 43
2: 90
3: 105
4: 128
5: 134


In [9]:
def best_of_n(batch_of_questions, config: Config, llm_vllm: LLM, prm: PRM, random_seed):
    """Sample ``config.n`` completions per question and score them with the PRM.

    Each question becomes the user turn of a two-message conversation
    (system prompt + question). The conversations are rendered with the
    engine's tokenizer chat template, ``config.n`` completions are sampled per
    prompt in a single ``generate`` call, and every completion is scored by
    ``prm`` and reduced to one number with ``aggregate_scores``.

    Args:
        batch_of_questions: sequence of question strings.
        config: supplies ``system_prompt``, ``custom_chat_template``,
            ``temperature``, ``max_tokens``, ``top_p``, ``n`` and
            ``agg_strategy``.
        llm_vllm: vLLM engine used for generation; its tokenizer is used to
            apply the chat template.
        prm: reward model; ``prm.score(questions, completions)`` is called
            once for the whole batch.
        random_seed: seed passed to ``SamplingParams``.

    Returns:
        A dict with three lists, each indexed as ``[question][completion]``:
        ``"completions"`` (generated text), ``"completion_ntokens"`` (number
        of generated token ids) and ``"agg_scores"`` (aggregated PRM score).
        For an empty batch all three lists are empty.

    Raises:
        AssertionError: if the number of responses still differs from the
            number of questions after one retry.
        ValueError: if a response does not hold exactly ``config.n``
            completions.
    """
    num_questions = len(batch_of_questions)

    # Nothing to generate or score: return the empty result structure instead
    # of handing an empty batch to the tokenizer, the engine and the PRM.
    if num_questions == 0:
        return {"completions": [], "completion_ntokens": [], "agg_scores": []}

    convs = [
        [
            {"role": "system", "content": config.system_prompt},
            {"role": "user", "content": prompt},
        ]
        for prompt in batch_of_questions
    ]

    tokenizer = llm_vllm.get_tokenizer()

    # TODO: set the augmented template from a file
    # NOTE: this assigns the template on the tokenizer object returned by the
    # engine, so the change is not local to this call.
    if config.custom_chat_template is not None:
        tokenizer.chat_template = config.custom_chat_template

    templated_convs = tokenizer.apply_chat_template(
        convs, add_generation_prompt=True, tokenize=False,
    )

    # The n completions per prompt are requested through SamplingParams(n=...)
    # rather than by duplicating each prompt n times. No stop string is set,
    # so each completion is a full solution rather than a single step.
    sampling_params = SamplingParams(
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        top_p=config.top_p,
        n=config.n,  # generate n outputs
        # NOTE(review): best_of equals n here, which looks redundant; confirm
        # against the installed vLLM version before removing.
        best_of=config.n,
        seed=random_seed,
    )

    def _generate():
        # One batched generation call over all templated prompts.
        return llm_vllm.generate(
            templated_convs,
            sampling_params=sampling_params,
            use_tqdm=False,
        )

    responses = _generate()

    # Re-generate once if the number of responses does not match the number
    # of questions, then fail loudly if it still does not match.
    if len(responses) != num_questions:
        responses = _generate()
        assert len(responses) == len(batch_of_questions), \
            f"Generated {len(responses)} responses instead of {len(batch_of_questions)}"

    # Collect the completions (and their token counts) per question.
    completions = []
    completion_ntokens = []
    for response in responses:
        if len(response.outputs) != config.n:
            raise ValueError(f"Generated {len(response.outputs)} completions instead of {config.n}")
        completions.append([output.text for output in response.outputs])
        completion_ntokens.append([len(output.token_ids) for output in response.outputs])

    # Compute the scores of completions and aggregate each one to a scalar.
    scores = prm.score(batch_of_questions, completions)
    agg_scores = [
        [aggregate_scores(s, config.agg_strategy) for s in per_question]
        for per_question in scores
    ]

    return {
        "completions": completions,
        "completion_ntokens": completion_ntokens,
        "agg_scores": agg_scores,
    }

In [10]:
# general params
config = Config()
config.agg_strategy = 'last'
config.n = 4                  # num of generations in BoN 

config.lookahead = 0
config.num_iterations = 10
config.sort_completed = False

# diverse_select params
config.lam = 10
config.normalize_embeddings = True

In [None]:
# Run `num_trials` independent best-of-n trials over every question of one
# difficulty level and dump all per-trial results to a JSON file.
level = '1'
num_questions = len(data_by_levels[level])
# num_questions = 10
num_trials = 50
config.n = 128  # overrides the n assigned when `config` was first set up
print(f"num_trials = {num_trials}")
print(f"num_questions = {num_questions}")

# data_by_levels is a defaultdict, so an unknown level silently yields an
# empty list; fail here instead of deep inside generation or on the
# per-question division below.
if num_questions == 0:
    raise ValueError(f"no questions found for level {level!r}")

batch_of_questions = [data_by_levels[level][q_idx]['problem'] for q_idx in range(num_questions)]
# batch_of_answers = [data_by_levels[level][q_idx]['answer'] for q_idx in range(num_questions)]

# Create the output directory before the (long) loop so that a missing
# directory cannot make the final write fail after all trials have finished.
result_filename = f"results/run_best_of_n_prm800k_level{level}_v21.json"
os.makedirs(os.path.dirname(result_filename), exist_ok=True)

all_results = []
total_time = 0.0  # stays defined even when num_trials == 0
start_time = time.time()
for trial_idx in range(num_trials):
    # Each trial uses its own sampling seed: 10000, 10001, ...
    results = best_of_n(batch_of_questions, config, llm_vllm, prm, 10000+trial_idx)
    # results["gt_answers"] = batch_of_answers
    all_results.append(results)

    # compute the time
    # NOTE: time_per_trial is the running average over the trials finished so
    # far, not the duration of the latest trial alone.
    total_time = time.time() - start_time
    time_per_trial = total_time/(trial_idx+1)
    time_per_question = time_per_trial/num_questions
    if trial_idx % 5 == 0:
        print(f"trial {trial_idx}")
        print(f"it takes {time_per_question:0.4f}s per question")
        print(f"it takes {time_per_trial:0.4f}s for this trial")

print(f"it takes {total_time:0.4f}s in total")

# The file is only written, so plain 'w' is sufficient.
with open(result_filename, 'w', encoding='utf-8') as fout:
    json.dump(all_results, fout, ensure_ascii=True, indent=4)

num_trials = 50
num_questions = 43


In [None]:
# total_score = 0
# correct_idxes = []
# for q_idx in range(num_questions):
#     print(f"question {q_idx}")
#     # print(f"question: {data_by_levels['4'][q_idx]['problem']}")
#     best_completion = results['best_completions'][q_idx]
#     print(f"best completion: {best_completion}")
#     pred_answer = extract_last_boxed_answer(best_completion)
#     gt_answer = data_by_levels['4'][q_idx]['answer']
#     is_correct = grader.grade_answer(pred_answer, gt_answer)
#     print(f"pred answer: {pred_answer}")
#     print(f"gt answer: {gt_answer}")
#     print(f"is correct: {is_correct}")
#     print(f"all scores = {results['all_scores'][q_idx]}")
#     print(f"best score = {results['best_scores'][q_idx]}")
#     if is_correct:
#         correct_idxes.append(q_idx)

# num_corrects = len(correct_idxes)
# acc = num_corrects/num_questions
# print(f"num correct answers = {num_corrects}")
# print(f"acc = {acc:0.4f}")

In [None]:
# def extract_last_boxed_answer(text):
#     """
#     Extracts the content inside the last \\boxed{...} in the given text, 
#     handling nested braces properly.
#     """
#     # Find the starting index of the last '\\boxed{'
#     boxed_start = text.rfind('\\boxed{')
#     if boxed_start == -1:
#         return None  # No \\boxed{ found
    
#     # Start after the opening '{'
#     start_index = boxed_start + len('\\boxed{')
#     brace_count = 1  # We've seen the opening '{'
#     content = ''
    
#     # Iterate through the text to find the matching closing brace
#     for i in range(start_index, len(text)):
#         char = text[i]
#         if char == '{':
#             brace_count += 1
#         elif char == '}':
#             brace_count -= 1
#             if brace_count == 0:
#                 return content.strip()  # Return content when braces balance
#         content += char
    
#     return None  # No matching closing brace found