In [1]:
'''
This notebook evaluates the formatted JSONL results file of each trial and reports accuracy statistics across trials.
'''

import os, psutil, gc
import time 
import json
import pprint

from collections import defaultdict
import random

import numpy as np
np.set_printoptions(precision=4)
from scipy.stats import ttest_rel

In [2]:
from sal.config import Config

from datasets import load_dataset

from core import best_of_n
from utils.load_data import load_data_prm800k
from utils import grader 
from utils import grader2
from utils import parser

In [3]:
# Root directory under which the dataset and model files are looked up.
# It already ends with '/', so every derived path below contains '//'.
base_dir = '/groups/kjun/tnn/datasets/'

# PRM800K math splits
data_dir = f"{base_dir}/prm800k/math_splits"

# GGUF model files for the LLM and the PRM
llm_dir = f"{base_dir}/Llama-3.2-1B-Instruct-GGUF/Llama-3.2-1B-Instruct.Q4_K_M.gguf"
prm_dir = f"{base_dir}/Llama3.1-8B-PRM-Deepseek-Data-GGUF/Llama3.1-8B-PRM-Deepseek-Data.Q4_K_M.gguf"

# Matching tokenizer directories
llm_tokenizer_dir = f"{base_dir}/Llama-3.2-1B-Instruct"
prm_tokenizer_dir = f"{base_dir}/Llama3.1-8B-PRM-Deepseek-Data"

In [4]:
# Load the PRM800K data from data_dir through the project helper.
# NOTE(review): the variable name suggests the result is grouped by difficulty
# level -- confirm against utils.load_data.load_data_prm800k.
data_by_levels = load_data_prm800k(data_dir)


# Disabled: loading previously generated completions.
# ds_completions = load_completions(completions_dir)

# Disabled: loading the per-trial random seeds from random_seeds.txt.
# random_seeds = np.loadtxt("random_seeds.txt").astype("int64")
# random_seeds = [int(seed) for seed in random_seeds]

1: 43
2: 90
3: 105
4: 128
5: 134


In [5]:
import signal

In [29]:
class TimeoutException(Exception):
    """Raised by timeout_handler when the SIGALRM timer expires."""


def timeout_handler(signum, frame):
    """SIGALRM handler: abort whatever is running by raising TimeoutException."""
    raise TimeoutException()


def run_with_timeout(func, q_idx, gt_answer, completions, timeout=2):
    """Call func(gt_answer, completions), giving up after `timeout` seconds.

    The timeout is implemented with SIGALRM, so this only works on Unix and
    only when called from the main thread.

    Args:
        func: Callable invoked as func(gt_answer, completions).
        q_idx: Question index; only used in the timeout message.
        gt_answer: First argument forwarded to func.
        completions: Second argument forwarded to func.
        timeout: Time limit in seconds (int or float). 0 disables the limit.

    Returns:
        Whatever func returns, or False if the time limit was hit.

    Raises:
        Any exception raised by func other than TimeoutException.
    """
    # Remember the handler that was active so it can be put back afterwards
    # instead of leaving timeout_handler installed process-wide.
    previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
    try:
        try:
            # Use a repeating timer rather than a one-shot alarm: if the
            # exception from the first expiry is raised inside a callback whose
            # errors are ignored (e.g. a gc callback), it is dropped and a
            # one-shot alarm would let func run forever. setitimer also accepts
            # fractional seconds.
            signal.setitimer(signal.ITIMER_REAL, timeout, timeout)
            result = func(gt_answer, completions)
        finally:
            # Disarm the timer as soon as the call is over, whatever happened.
            signal.setitimer(signal.ITIMER_REAL, 0)
    except TimeoutException:
        print(f"Timeout: {q_idx}")
        result = False
    finally:
        # signal.signal() returns None when the previous handler was not
        # installed from Python; that value cannot be re-installed.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)
    return result
    
def _evaluate_one_pred(gt_answer, completions):
    one_pred_correct = False
    for cidx, completion in enumerate(completions):
        c_answer = parser.extract_answer(completion, 'math')
        one_pred_correct = grader2.math_equal(c_answer, gt_answer)
        if one_pred_correct:
            # print(f"\n-> q_idx = {q_idx}")
            # print(f"gt_answer = {gt_answer}")
            # print(f"c_answer = {c_answer}")
            break

    return one_pred_correct

def evaluate_correctness_hf(data_dir, level, n=8, limit_budget=False):
    """Compute accuracy metrics for one results file at one difficulty level.

    The file is loaded with the Hugging Face "json" loader and reduced to the
    records whose 'level' equals int(level). For every remaining record the
    ground truth comes from parser.parse_ground_truth(record, 'math') and four
    things are graded with grader2.math_equal:

    * whether at least one of the record's 'completions' is correct
      (run through run_with_timeout with a 1-second limit; a timeout counts
      as incorrect),
    * the answers extracted from 'pred_naive@{n}', 'pred_weighted@{n}' and
      'pred_maj@{n}'.

    Args:
        data_dir: Passed to load_dataset as data_files. Despite the name, the
            caller in this notebook passes the path of a single .jsonl file.
        level: Difficulty level to keep; converted with int().
        n: Selects which 'pred_*@{n}' fields are read and, when limit_budget
            is set, how many completions are considered.
        limit_budget: If truthy, only the first n completions of each record
            are checked; otherwise all of them.

    Returns:
        Tuple (one_pred, pred_naive, pred_weighted, pred_maj) with the mean
        correctness over the selected questions. Each value is NaN when no
        record matches the level.
    """
    dataset = load_dataset("json", data_files=data_dir, split='train')

    # Convert once instead of on every row inside the filter.
    target_level = int(level)
    dataset_by_level = dataset.filter(lambda example: example['level'] == target_level)

    num_questions = len(dataset_by_level)
    one_pred_hits = np.zeros(num_questions)
    pred_naive_hits = np.zeros(num_questions)
    pred_weighted_hits = np.zeros(num_questions)
    pred_maj_hits = np.zeros(num_questions)

    for q_idx, data in enumerate(dataset_by_level):
        completions = data["completions"]
        if limit_budget:
            completions = completions[:n]

        # Only the parsed answer is needed; the chain-of-thought part is unused.
        _, gt_answer = parser.parse_ground_truth(data, 'math')

        pred_naive_answer = parser.extract_answer(data[f"pred_naive@{n}"], 'math')
        pred_weighted_answer = parser.extract_answer(data[f"pred_weighted@{n}"], 'math')
        pred_maj_answer = parser.extract_answer(data[f"pred_maj@{n}"], 'math')

        pred_naive_hits[q_idx] = grader2.math_equal(pred_naive_answer, gt_answer)
        pred_weighted_hits[q_idx] = grader2.math_equal(pred_weighted_answer, gt_answer)
        pred_maj_hits[q_idx] = grader2.math_equal(pred_maj_answer, gt_answer)

        # Checking every completion can be slow, so it is bounded by a timeout.
        one_pred_hits[q_idx] = run_with_timeout(
            _evaluate_one_pred, q_idx, gt_answer, completions, timeout=1)

    return (
        np.mean(one_pred_hits),
        np.mean(pred_naive_hits),
        np.mean(pred_weighted_hits),
        np.mean(pred_maj_hits),
    )


# general params
config = Config()
config.agg_strategy = 'last'
config.n = 8
config.beam_width = 2
config.lookahead = 0
config.num_iterations = 10
config.sort_completed = False

# diverse_select params
config.lam = 10
config.normalize_embeds = True

level = '4'
num_trials = 1

# Identifier of the run to evaluate; it is the prefix of the result files
# "results/{sd_config}--trial-{trial_idx}.jsonl". Alternative identifiers
# (swap the value below to evaluate another run):
#   "bon--n-16--level-4--v01"
#   "beam--n-8--d-40--bw-4--lh-1--dup-True--limit--False--level-4--v01"
#   "sd--n-8--bw-2--d-2--lam-10--True--level-4--v21"
sd_config = "sd--n-8--bw-2--d-10--lam-10--True--level-4--v11"


def _mean_and_stderr(values):
    """Return (mean, standard error of the mean) of per-trial accuracies.

    The standard error uses the sample standard deviation (ddof=1), which is
    undefined for fewer than two values; NaN is returned in that case instead
    of letting numpy divide by zero degrees of freedom and emit warnings.
    """
    values = np.asarray(values, dtype=float)
    mean = values.mean() if values.size else float('nan')
    if values.size < 2:
        return mean, float('nan')
    return mean, values.std(ddof=1) / np.sqrt(values.size)


# Per-trial accuracies, one entry per trial file.
sd_one_pred_correctness = []
sd_pred_naive_correctness = []
sd_pred_weighted_correctness = []
sd_pred_maj_correctness = []
for trial_idx in range(num_trials):
    (one_pred_correctness, pred_naive_correctness,
     pred_weighted_correctness, pred_maj_correctness) = evaluate_correctness_hf(
        f"results/{sd_config}--trial-{trial_idx}.jsonl", level, config.n, limit_budget=False)

    sd_one_pred_correctness.append(one_pred_correctness)
    sd_pred_naive_correctness.append(pred_naive_correctness)
    sd_pred_weighted_correctness.append(pred_weighted_correctness)
    sd_pred_maj_correctness.append(pred_maj_correctness)

# NOTE: the *_std names are kept for compatibility, but the values are
# standard errors of the mean (sample std / sqrt(number of trials)).
sd_one_pred_correctness_mean, sd_one_pred_correctness_std = \
    _mean_and_stderr(sd_one_pred_correctness)
sd_pred_naive_correctness_mean, sd_pred_naive_correctness_std = \
    _mean_and_stderr(sd_pred_naive_correctness)
sd_pred_weighted_correctness_mean, sd_pred_weighted_correctness_std = \
    _mean_and_stderr(sd_pred_weighted_correctness)
sd_pred_maj_correctness_mean, sd_pred_maj_correctness_std = \
    _mean_and_stderr(sd_pred_maj_correctness)

print(sd_one_pred_correctness)
print(sd_pred_weighted_correctness)
print(sd_one_pred_correctness_mean)
print(sd_one_pred_correctness_std)

print(f"one_pred_correctness: {sd_one_pred_correctness_mean:0.4f} (\u00B1{sd_one_pred_correctness_std:0.4f})")
print(f"pred_naive_correctness: {sd_pred_naive_correctness_mean:0.4f} (\u00B1{sd_pred_naive_correctness_std:0.4f})")
print(f"pred_weighted_correctness: {sd_pred_weighted_correctness_mean:0.4f} (\u00B1{sd_pred_weighted_correctness_std:0.4f})")
print(f"pred_maj_correctness: {sd_pred_maj_correctness_mean:0.4f} (\u00B1{sd_pred_maj_correctness_std:0.4f})")

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7f048819b4d0>>
Traceback (most recent call last):
  File "/home/u20/tnguyen9210/micromamba/envs/py311/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 790, in _clean_thread_parent_frames
    active_threads = {thread.ident for thread in threading.enumerate()}
                                                 ^^^^^^^^^^^^^^^^^^^^^
  File "/home/u20/tnguyen9210/micromamba/envs/py311/lib/python3.11/threading.py", line 1510, in enumerate
    return list(_active.values()) + list(_limbo.values())
                ^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_1128413/3443053402.py", line 5, in timeout_handler
TimeoutException: 


Generating train split: 0 examples [00:00, ? examples/s]

Filter:   0%|          | 0/128 [00:00<?, ? examples/s]

[0.5078125, 0.4921875]
[0.296875, 0.2890625]
0.5
0.0078125
one_pred_correctness: 0.5000 (±0.0078)
pred_naive_correctness: 0.3281 (±0.0078)
pred_weighted_correctness: 0.2930 (±0.0039)
pred_maj_correctness: 0.2578 (±0.0078)
