In [1]:
'''
Evaluate the performance of search algorithms.
Collect per-prompt correctness scores across all prompts and trials, and compute the overall statistics (mean and standard error).
'''

import os, psutil, gc
import time 
import json
import pprint

from collections import defaultdict
import random

import numpy as np
np.set_printoptions(precision=4)
from scipy.stats import ttest_rel

In [2]:
from sal.config import Config

from datasets import load_dataset

# from core import best_of_n
from utils.load_data import load_data_prm800k
from utils import grader 
from utils import grader2
from utils import parser

In [3]:
# Root directory holding the datasets and model files. Defaults to the original
# cluster location; set the DATASETS_DIR environment variable to run elsewhere.
base_dir = os.environ.get("DATASETS_DIR", '/groups/kjun/tnn/datasets/')

# dataset path
# (os.path.join avoids the '//' produced by concatenating base_dir, which ends
# in '/', with components that start with '/')
data_dir = os.path.join(base_dir, "prm800k", "math_splits")

# llm and prm paths (GGUF files)
llm_dir = os.path.join(base_dir, "Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct.Q4_K_M.gguf")
prm_dir = os.path.join(base_dir, "Llama3.1-8B-PRM-Deepseek-Data-GGUF", "Llama3.1-8B-PRM-Deepseek-Data.Q4_K_M.gguf")

# tokenizer directories for the same models
llm_tokenizer_dir = os.path.join(base_dir, "Llama-3.2-1B-Instruct")
prm_tokenizer_dir = os.path.join(base_dir, "Llama3.1-8B-PRM-Deepseek-Data")

In [4]:
# Load the PRM800K MATH split data from `data_dir`.
# NOTE(review): judging by the name and the captured cell output (level -> count),
# this is presumably a mapping from difficulty level to questions — confirm
# against utils.load_data. It is not read anywhere else in the visible cells.
data_by_levels = load_data_prm800k(data_dir)


# ds_completions = load_completions(completions_dir)

# load random_seeds (disabled)
# random_seeds = np.loadtxt("random_seeds.txt").astype("int64")
# random_seeds = [int(seed) for seed in random_seeds]

1: 43
2: 90
3: 105
4: 128
5: 134


In [5]:
import signal

In [12]:
class TimeoutException(Exception):
    """Raised by `timeout_handler` when a SIGALRM-based time limit expires."""


def timeout_handler(signum, frame):
    """SIGALRM handler that aborts the interrupted computation by raising TimeoutException."""
    raise TimeoutException()


def run_with_timeout(fn_extract_answer, fn_grade, completion, gt_answer, timeout=2, data_name='math'):
    """Extract an answer from `completion` and grade it, giving up after `timeout` seconds.

    The limit is enforced with SIGALRM, so this only works on POSIX systems and
    only when called from the main thread (signal.signal raises otherwise).
    The previously installed SIGALRM handler is restored before returning.

    Args:
        fn_extract_answer: called as ``fn_extract_answer(completion, data_name)``.
        fn_grade: called as ``fn_grade(extracted_answer, gt_answer)``.
        completion: the text to extract an answer from.
        gt_answer: the ground-truth answer passed through to ``fn_grade``.
        timeout: wall-clock limit in seconds (int or float); 0 disables the limit.
        data_name: dataset name passed to ``fn_extract_answer`` (default 'math').

    Returns:
        ``(c_answer, result)``, or ``(None, None)`` if the time limit expired.

    Raises:
        Any exception other than TimeoutException raised by the two callables.
    """
    previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
    # setitimer (unlike signal.alarm) accepts fractional seconds.
    signal.setitimer(signal.ITIMER_REAL, timeout)
    try:
        c_answer = fn_extract_answer(completion, data_name)
        result = fn_grade(c_answer, gt_answer)
    except TimeoutException:
        print(f"Timeout: {completion}")
        c_answer = None
        result = None
    finally:
        # Cancel the pending timer first, then put back whatever handler was there
        # so this helper leaves no global signal state behind.
        signal.setitimer(signal.ITIMER_REAL, 0)
        if previous_handler is not None:  # None => previous handler not installed from Python
            signal.signal(signal.SIGALRM, previous_handler)
    return c_answer, result

def _any_completion_correct(completions, gt_answer):
    """Return True as soon as one completion's extracted answer grades equal to gt_answer (timeouts count as wrong)."""
    for completion in completions:
        _, is_correct = run_with_timeout(parser.extract_answer, grader2.math_equal, completion, gt_answer)
        if is_correct is True:
            return True
    return False


def _grade_prediction(prediction, gt_answer):
    """Grade one aggregated prediction against gt_answer under the same timeout guard (timeout counts as wrong)."""
    _, is_correct = run_with_timeout(parser.extract_answer, grader2.math_equal, prediction, gt_answer)
    return bool(is_correct)


def evaluate_correctness_hf(data_dir, level, n=8, limit_budget=False, pred_budget=8):
    """Grade every question of one difficulty level in a results file.

    Args:
        data_dir: path to a JSON/JSONL results file (loaded with the HF "json" builder).
            Each record must have 'level', 'completions' and the
            'pred_naive@K' / 'pred_weighted@K' / 'pred_maj@K' fields (K = pred_budget).
        level: difficulty level to keep; compared as ``int(level)``.
        n: number of leading completions used for the pass@n metric.
        limit_budget: if True, the pass1b metric also uses only the first n
            completions (it is then identical to pass@n); otherwise all completions.
        pred_budget: the K in the 'pred_*@K' field names (default 8, as before).

    Returns:
        Five float arrays of length = number of questions at `level`, each entry
        1.0 (correct) or 0.0, in this order:
        passn   - any of the first n completions is correct,
        pass1b  - any completion in the chosen budget is correct,
        naive1b / weighted1b / maj1b - the corresponding 'pred_*@K' prediction is correct.
        Every grading runs under run_with_timeout; a timeout counts as incorrect.
    """
    dataset = load_dataset("json", data_files=data_dir, split='train')
    target_level = int(level)
    dataset_by_level = dataset.filter(lambda example: example['level'] == target_level)

    num_questions = len(dataset_by_level)
    passn_correctness = np.zeros(num_questions)
    pass1b_correctness = np.zeros(num_questions)
    naive1b_correctness = np.zeros(num_questions)
    weighted1b_correctness = np.zeros(num_questions)
    maj1b_correctness = np.zeros(num_questions)

    for q_idx, data in enumerate(dataset_by_level):
        _, gt_answer = parser.parse_ground_truth(data, 'math')

        passn_correct = _any_completion_correct(data["completions"][:n], gt_answer)
        if limit_budget:
            # Same completion set as pass@n, so reuse that result instead of re-grading.
            pass1b_correct = passn_correct
        else:
            pass1b_correct = _any_completion_correct(data["completions"], gt_answer)

        passn_correctness[q_idx] = passn_correct
        pass1b_correctness[q_idx] = pass1b_correct
        naive1b_correctness[q_idx] = _grade_prediction(data[f"pred_naive@{pred_budget}"], gt_answer)
        weighted1b_correctness[q_idx] = _grade_prediction(data[f"pred_weighted@{pred_budget}"], gt_answer)
        maj1b_correctness[q_idx] = _grade_prediction(data[f"pred_maj@{pred_budget}"], gt_answer)

    return passn_correctness, pass1b_correctness, naive1b_correctness, weighted1b_correctness, maj1b_correctness


# general params
# NOTE(review): only config.n is read by the evaluation below; the other fields
# presumably document the search settings of the evaluated run.
config = Config()
config.agg_strategy = 'last'
config.n = 8
config.beam_width = 2
config.lookahead = 0
config.num_iterations = 10
config.sort_completed = False

# diverse_select params
config.lam = 10
config.normalize_embeds = True

level = '4'
num_trials = 4

# Result-file prefix; trial k is read from results/{sd_config}--trial-{k}.jsonl.
# Previously evaluated config, kept for reference:
# sd_config = "bob--n-8--d-40--level-4--v11"
sd_config = "mcts--n-8--d-40--lam-10--dalpha-10.0--dbeta-1.0--cpuct-0-2--ppl-True--normalize-True--level-4--v51"

# Collect the per-question 0/1 arrays of every trial, keyed by metric, in the
# order evaluate_correctness_hf returns them.
metric_names = ("passn", "pass1b", "naive1b", "weighted1b", "maj1b")
per_trial = {name: [] for name in metric_names}
for trial_idx in range(num_trials):
    trial_results = evaluate_correctness_hf(
        f"results/{sd_config}--trial-{trial_idx}.jsonl", level, config.n, limit_budget=False)
    for name, correctness in zip(metric_names, trial_results):
        per_trial[name].append(correctness)

# Pool all (trial, question) pairs into one flat array per metric.
sd_passn_correctness = np.concatenate(per_trial["passn"])
sd_pass1b_correctness = np.concatenate(per_trial["pass1b"])
sd_naive1b_correctness = np.concatenate(per_trial["naive1b"])
sd_weighted1b_correctness = np.concatenate(per_trial["weighted1b"])
sd_maj1b_correctness = np.concatenate(per_trial["maj1b"])

sd_passn_correctness_mean = np.mean(sd_passn_correctness)
sd_pass1b_correctness_mean = np.mean(sd_pass1b_correctness)
sd_naive1b_correctness_mean = np.mean(sd_naive1b_correctness)
sd_weighted1b_correctness_mean = np.mean(sd_weighted1b_correctness)
sd_maj1b_correctness_mean = np.mean(sd_maj1b_correctness)

# Standard error of the mean over all pooled (trial, question) samples. The
# denominator uses the actual sample count (num_trials * questions at `level`)
# instead of a hard-coded 128, which was only valid for level 4.
sd_passn_correctness_std = np.std(sd_passn_correctness, ddof=1) / np.sqrt(sd_passn_correctness.size)
sd_pass1b_correctness_std = np.std(sd_pass1b_correctness, ddof=1) / np.sqrt(sd_pass1b_correctness.size)
sd_naive1b_correctness_std = np.std(sd_naive1b_correctness, ddof=1) / np.sqrt(sd_naive1b_correctness.size)
sd_weighted1b_correctness_std = np.std(sd_weighted1b_correctness, ddof=1) / np.sqrt(sd_weighted1b_correctness.size)
sd_maj1b_correctness_std = np.std(sd_maj1b_correctness, ddof=1) / np.sqrt(sd_maj1b_correctness.size)

print(f"passn_correctness: {sd_passn_correctness_mean:0.4f} (\u00B1{sd_passn_correctness_std:0.4f})")
print(f"pass1b_correctness: {sd_pass1b_correctness_mean:0.4f} (\u00B1{sd_pass1b_correctness_std:0.4f})")
print(f"naive1b_correctness: {sd_naive1b_correctness_mean:0.4f} (\u00B1{sd_naive1b_correctness_std:0.4f})")
print(f"weighted1b_correctness: {sd_weighted1b_correctness_mean:0.4f} (\u00B1{sd_weighted1b_correctness_std:0.4f})")
print(f"maj1b_correctness: {sd_maj1b_correctness_mean:0.4f} (\u00B1{sd_maj1b_correctness_std:0.4f})")

Generating train split: 0 examples [00:00, ? examples/s]

Filter:   0%|          | 0/128 [00:00<?, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

Filter:   0%|          | 0/128 [00:00<?, ? examples/s]



Generating train split: 0 examples [00:00, ? examples/s]

Filter:   0%|          | 0/128 [00:00<?, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

Filter:   0%|          | 0/128 [00:00<?, ? examples/s]

passn_correctness: 0.5332 (±0.0221)
pass1b_correctness: 0.5586 (±0.0220)
naive1b_correctness: 0.4336 (±0.0219)
weighted1b_correctness: 0.4414 (±0.0220)
maj1b_correctness: 0.4355 (±0.0219)
