In [1]:
import os, psutil, gc
import time 
import json
import pprint

from collections import defaultdict
import random
import numpy as np

In [2]:
import torch 
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams, PoolingParams

from sal.config import Config

INFO 03-23 10:58:11 [__init__.py:256] Automatically detected platform cuda.


In [3]:
# Print the GPU topology matrix: pairwise link type between GPUs plus each
# GPU's CPU and NUMA affinity (shell command run through IPython's `!`).
!nvidia-smi topo -m

	[4mGPU0	GPU1	GPU2	GPU3	CPU Affinity	NUMA Affinity	GPU NUMA ID[0m
GPU0	 X 	NODE	SYS	SYS	1-24	0		N/A
GPU1	NODE	 X 	SYS	SYS	1-24	0		N/A
GPU2	SYS	SYS	 X 	NODE	49-72	1		N/A
GPU3	SYS	SYS	NODE	 X 	49-72	1		N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks


In [3]:
# Root directory holding the datasets, models and tokenizers used below.
base_path = '/groups/kjun/tnn/datasets/'

# Every path is built with os.path.join so that the trailing slash on
# `base_path` does not produce a doubled separator ("datasets//prm800k/...").

# dataset path
dataset_path = os.path.join(base_path, "prm800k", "math_splits")

# llm and prm path (GGUF weight files)
llm_path = os.path.join(
    base_path, "Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct.Q4_K_M.gguf"
)
prm_path = os.path.join(
    base_path,
    "Llama3.1-8B-PRM-Deepseek-Data-GGUF",
    "Llama3.1-8B-PRM-Deepseek-Data.Q4_K_M.gguf",
)

# tokenizer directories matching the two models above
llm_tokenizer_path = os.path.join(base_path, "Llama-3.2-1B-Instruct")
prm_tokenizer_path = os.path.join(base_path, "Llama3.1-8B-PRM-Deepseek-Data")

In [7]:
# Bucket every problem of the test split by its difficulty level; the keys are
# the `level` field rendered as a string.
data_by_levels = defaultdict(list)
with open(f"{dataset_path}/test.jsonl", 'r', encoding='utf-8') as filein:
    for raw_line in filein:
        if not raw_line.strip():
            continue  # ignore blank lines in the JSONL file
        record = json.loads(raw_line)
        data_by_levels[str(record['level'])].append(record)

# Report how many problems were found for each of the levels 1 through 5.
for level_key in map(str, range(1, 6)):
    print(f"{level_key}: {len(data_by_levels[level_key])}")

# Optional: load pre-generated seeds instead of using fixed offsets.
# random_seeds = np.loadtxt("random_seeds.txt").astype("int64")
# random_seeds = [int(seed) for seed in random_seeds]

1: 43
2: 90
3: 105
4: 128
5: 134


In [3]:
# Convert a saved results file from one JSON array into JSON Lines
# (one top-level array element per line).
level = '1'
json_path = f"results/generate_bon_prm800k_level{level}_v11.json"
jsonl_path = f"results/generate_bon_prm800k_level{level}_v11.jsonl"

with open(json_path, 'r', encoding='utf-8') as src:
    entries = json.load(src)

with open(jsonl_path, 'w', encoding='utf-8') as dst:
    # Serialise each element on its own line, using the default JSON settings.
    dst.writelines(json.dumps(entry) + '\n' for entry in entries)

In [10]:
# Run `num_trials` repetitions of the selected best-of-n test method over all
# problems of one difficulty level, checkpointing the accumulated results to a
# JSON file after every trial.

# general params
config = Config()
config.n = 128  # presumably the number of candidates per prompt — confirm against sal.config

level = '1'
num_questions = len(data_by_levels[level])
# num_questions = 2
num_trials = 10
print(f"num_trials = {num_trials}")
print(f"num_questions = {num_questions}")

# Pick the method variant. The version tag is also used in the results file
# name, so a v12 run cannot overwrite the results of a v11 run.
method_number = 1
if method_number == 1:
    test_method = test_best_of_n_v11
    method_version = "v11"
else:
    test_method = test_best_of_n_v12
    method_version = "v12"
print(test_method)

batch_of_prompts = [data_by_levels[level][q_idx]['problem'] for q_idx in range(num_questions)]

result_filename = f"results/generate_bon_prm800k_level{level}_{method_version}.json"
# Make sure the output directory exists before the first (expensive) trial
# finishes, so the checkpoint write cannot fail on a missing directory.
os.makedirs(os.path.dirname(result_filename), exist_ok=True)

all_results = []
start_time = time.time()
for trial_idx in range(num_trials):
    # test_method(batch_of_prompts, config, llm_vllm, random_seeds[trial_idx])
    # Each trial gets its own deterministic seed: 10000, 10001, ...
    results = test_method(batch_of_prompts, config, llm_vllm, 10000+trial_idx)
    all_results.append(results)

    # Checkpoint: rewrite the whole file with everything collected so far, so
    # an interrupted run still leaves the completed trials on disk.
    with open(result_filename, 'w', encoding = 'utf-8') as fout:
        json.dump(all_results, fout, ensure_ascii=True, indent=4)

    # compute the time (reported on every 10th trial, starting with trial 0)
    if trial_idx % 10 == 0:
        total_time = time.time() - start_time
        time_per_trial = total_time/(trial_idx+1)
        time_per_question = time_per_trial/num_questions
        print(f"trial {trial_idx}")
        print(f"it takes {time_per_question:0.4f}s per question")
        print(f"it takes {time_per_trial:0.4f}s per trial")

total_time = time.time() - start_time
print(f"it takes {total_time:0.4f}s in total")

num_trials = 10
num_questions = 43
<function test_best_of_n_v11 at 0x7f8cce59f880>
trial 0
it takes 8.7686s per question
it takes 377.0496s per trial
it takes 3768.7243s in total
