In [1]:
'''
Evaluate the formatted results jsonl files of one search configuration (one file per trial)
and report the mean accuracy (± standard error) across trials. Questions within a file are graded in parallel with multiprocessing.
'''

import os, psutil, gc
from collections import defaultdict

import json
import time 
import pprint
import signal
import threading

import random
import numpy as np
np.set_printoptions(precision=4)

from scipy.stats import ttest_rel

import multiprocessing as mp


In [2]:
from sal.config import Config

from datasets import load_dataset

from core import best_of_n
from utils.load_data import load_data_prm800k
from utils import grader 
from utils import grader2
from utils import parser

In [3]:
# Root directory holding the datasets and model files. Override it with the
# DATASETS_DIR environment variable when running outside the original cluster.
base_dir = os.environ.get("DATASETS_DIR", '/groups/kjun/tnn/datasets/')

# dataset path (os.path.join avoids the doubled "/" that string concatenation
# produced when base_dir already ends with a separator)
data_dir = os.path.join(base_dir, "prm800k", "math_splits")

# llm and prm model file paths (.gguf files)
llm_dir = os.path.join(base_dir, "Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct.Q4_K_M.gguf")
prm_dir = os.path.join(base_dir, "Llama3.1-8B-PRM-Deepseek-Data-GGUF", "Llama3.1-8B-PRM-Deepseek-Data.Q4_K_M.gguf")

# tokenizer directories for the models above
llm_tokenizer_dir = os.path.join(base_dir, "Llama-3.2-1B-Instruct")
prm_tokenizer_dir = os.path.join(base_dir, "Llama3.1-8B-PRM-Deepseek-Data")

In [4]:
#  load data 
# data_by_levels = load_data_prm800k(data_dir)

# ds_completions = load_completions(completions_dir)

# load random_seeds     
# random_seeds = np.loadtxt("random_seeds.txt").astype("int64")
# random_seeds = [int(seed) for seed in random_seeds]

In [6]:
class TimeoutException(Exception):
    """Raised by `timeout_handler` when a grading deadline expires."""
    pass


def timeout_handler(signum, frame):
    """SIGALRM handler that aborts the interrupted computation.

    Args:
        signum: signal number (unused).
        frame: interrupted stack frame (unused).

    Raises:
        TimeoutException: always.
    """
    raise TimeoutException()


def run_with_timeout(fn_extract_answer, fn_grade, completion, gt_answer, timeout=2):
    """Extract an answer from `completion` and grade it, giving up after `timeout` seconds.

    Calls ``fn_extract_answer(completion, 'math')`` and then
    ``fn_grade(c_answer, gt_answer)``; the deadline covers both calls. It is
    enforced with a SIGALRM interval timer, which Python only handles in the
    main thread. When SIGALRM is unavailable (e.g. Windows), when called off the
    main thread, or when `timeout` is None or <= 0, the calls run without a limit.

    Args:
        fn_extract_answer: callable(text, data_name) -> extracted answer.
        fn_grade: callable(answer, gt_answer) -> grading result.
        completion: model output to grade.
        gt_answer: ground-truth answer passed through to `fn_grade`.
        timeout: time limit in seconds (fractional values allowed).

    Returns:
        (c_answer, result); both are None if the deadline was hit.
    """
    use_alarm = (
        timeout is not None
        and timeout > 0
        and hasattr(signal, "SIGALRM")
        and threading.current_thread() is threading.main_thread()
    )
    if not use_alarm:
        c_answer = fn_extract_answer(completion, 'math')
        return c_answer, fn_grade(c_answer, gt_answer)

    # NOTE(review): if fn_grade installs its own SIGALRM handler/timer, the two
    # will interfere — confirm the grader does not.
    previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
    # setitimer (unlike signal.alarm) accepts fractional seconds.
    signal.setitimer(signal.ITIMER_REAL, timeout)
    try:
        c_answer = fn_extract_answer(completion, 'math')
        result = fn_grade(c_answer, gt_answer)
    except TimeoutException:
        print(f"Timeout ({timeout}s) grading completion: {str(completion)[:200]!r}")
        c_answer, result = None, None
    finally:
        # Disarm the timer before restoring the handler so a late alarm cannot
        # trigger the default SIGALRM action, which terminates the process.
        signal.setitimer(signal.ITIMER_REAL, 0)
        # signal.signal returns None when the old handler was not set from Python.
        signal.signal(
            signal.SIGALRM,
            previous_handler if previous_handler is not None else signal.SIG_DFL,
        )
    return c_answer, result

def _grade(text, gt_answer):
    """Return True iff the answer extracted from `text` is graded equal to `gt_answer`.

    Extraction and grading go through `run_with_timeout`. A timed-out (None) or
    falsy grading result counts as incorrect. `bool()` is used instead of
    `is True` so that truthy non-`bool` results (e.g. numpy booleans) are not
    silently treated as wrong.
    """
    _, result = run_with_timeout(parser.extract_answer, grader2.math_equal, text, gt_answer)
    return bool(result)


def _evaluate(data, n):
    """Grade one question's completions and aggregated predictions.

    Args:
        data: one record of a results jsonl file. It must contain "completions"
            and the f"pred_naive@{n}", f"pred_weighted@{n}" and f"pred_maj@{n}"
            fields, plus whatever `parser.parse_ground_truth` reads.
        n: number of leading completions that count toward pass@n.

    Returns:
        (passn_correct, passp_correct, naive_correct, weighted_correct, maj_correct):
        - passn_correct: any of the first `n` completions is correct.
        - passp_correct: any completion at all is correct.
        - the last three: the corresponding aggregated prediction is correct.
    """
    completions = data["completions"]
    _, gt_answer = parser.parse_ground_truth(data, 'math')

    naive_correct = _grade(data[f"pred_naive@{n}"], gt_answer)
    weighted_correct = _grade(data[f"pred_weighted@{n}"], gt_answer)
    maj_correct = _grade(data[f"pred_maj@{n}"], gt_answer)

    # Grade each completion at most once, stopping at the first correct one.
    # pass@n only needs to know whether that index falls within the first n,
    # so the first n completions are not re-graded for the all-completions pass.
    first_correct_idx = next(
        (idx for idx, completion in enumerate(completions) if _grade(completion, gt_answer)),
        None,
    )
    passp_correct = first_correct_idx is not None
    passn_correct = passp_correct and first_correct_idx < n

    return (passn_correct, passp_correct, naive_correct, weighted_correct, maj_correct)

def evaluate_data(data_dir, level, n=8, processes=None):
    """Evaluate one trial's results file for a single difficulty level.

    Questions are graded in parallel with `_evaluate` in a multiprocessing pool.

    Args:
        data_dir: path to the results jsonl file. Despite the name, this is a
            file path passed to `load_dataset("json", data_files=...)`.
        level: difficulty level to keep. It is compared as `int(level)` against
            each record's "level" field, so '4' and 4 are equivalent.
        n: number of completions used for pass@n and the `pred_*@{n}` fields.
        processes: worker count for the pool. None lets `multiprocessing.Pool`
            use `os.cpu_count()`, which on shared cluster nodes may exceed the
            CPUs actually allocated to the job.

    Returns:
        Mean accuracy over the level's questions, in this order:
        (pass@n, pass over all completions, naive, weighted, majority).

    Raises:
        ValueError: if the file contains no questions at `level`.
    """
    dataset = load_dataset("json", data_files=data_dir, split='train')
    target_level = int(level)
    dataset_by_level = dataset.filter(lambda example: example['level'] == target_level)

    tasks = [(data, n) for data in dataset_by_level]
    if not tasks:
        # Without this guard, zip(*[]) below fails with an opaque unpacking error.
        raise ValueError(f"No level-{level} questions found in {data_dir}")

    with mp.Pool(processes=processes) as pool:
        pool_results = pool.starmap(_evaluate, tasks)

    # Transpose the per-question 5-tuples into one column per metric, then average.
    return tuple(np.mean(per_question) for per_question in zip(*pool_results))


# general params. Only `config.n` is read by the evaluation below; the other
# fields are set but not used in this cell.
config = Config()
config.agg_strategy = 'last'
config.n = 8
config.beam_width = 2
config.lookahead = 0
config.num_iterations = 10
config.sort_completed = False

# diverse_select params
config.lam = 10
config.normalize_embeds = True

level = '4'
num_trials = 5

# Prefix of the per-trial result files results/{sd_config}--trial-{i}.jsonl.
# Alternative configurations (swap one in to evaluate it instead):
#   "bob--n-8--d-40--level-4--v11"
#   "beam--n-8--d-40--bw-4--lh-1--dup-True--limit--False--level-4--v01"
#   "sda--n-8--bw-2--d-2--lam-10--True--level-4--v21"
#   "sdp--n-8--bw-4--d-40--lam-10--True--dalpha-0--dbeta-1.0--ppl-True--level-4--v11"
sd_config = "sdp--n-8--bw-4--d-40--lam-10--True--dalpha-1.0--dbeta-0--ppl-False--level-4--v11"

# Per-trial accuracies, one list entry per trial for each metric.
sd_passn_correctness = []
sd_passp_correctness = []
sd_naive_correctness = []
sd_weighted_correctness = []
sd_maj_correctness = []
for trial_idx in range(num_trials):
    passn_correctness, passp_correctness, naive_correctness, weighted_correctness, maj_correctness = \
        evaluate_data(f"results/{sd_config}--trial-{trial_idx}.jsonl", level, config.n)

    sd_passn_correctness.append(passn_correctness)
    sd_passp_correctness.append(passp_correctness)
    sd_naive_correctness.append(naive_correctness)
    sd_weighted_correctness.append(weighted_correctness)
    sd_maj_correctness.append(maj_correctness)


def _mean_and_sem(per_trial_scores):
    """Return (mean, standard error of the mean) of per-trial scores (sample std, ddof=1)."""
    return (
        np.mean(per_trial_scores),
        np.std(per_trial_scores, ddof=1) / np.sqrt(len(per_trial_scores)),
    )


# NOTE: the *_std variables hold the standard error of the mean, not the std.
sd_passn_correctness_mean, sd_passn_correctness_std = _mean_and_sem(sd_passn_correctness)
sd_passp_correctness_mean, sd_passp_correctness_std = _mean_and_sem(sd_passp_correctness)
sd_naive_correctness_mean, sd_naive_correctness_std = _mean_and_sem(sd_naive_correctness)
sd_weighted_correctness_mean, sd_weighted_correctness_std = _mean_and_sem(sd_weighted_correctness)
sd_maj_correctness_mean, sd_maj_correctness_std = _mean_and_sem(sd_maj_correctness)

print(f"passn_correctness: {sd_passn_correctness_mean:0.4f} (\u00B1{sd_passn_correctness_std:0.4f})")
print(f"passp_correctness: {sd_passp_correctness_mean:0.4f} (\u00B1{sd_passp_correctness_std:0.4f})")
print(f"naive_correctness: {sd_naive_correctness_mean:0.4f} (\u00B1{sd_naive_correctness_std:0.4f})")
print(f"weighted_correctness: {sd_weighted_correctness_mean:0.4f} (\u00B1{sd_weighted_correctness_std:0.4f})")
print(f"maj_correctness: {sd_maj_correctness_mean:0.4f} (\u00B1{sd_maj_correctness_std:0.4f})")

Process ForkPoolWorker-410:
Process ForkPoolWorker-398:
Process ForkPoolWorker-412:
Process ForkPoolWorker-422:
Process ForkPoolWorker-434:
Process ForkPoolWorker-414:
Process ForkPoolWorker-402:
Process ForkPoolWorker-438:
Process ForkPoolWorker-400:
Process ForkPoolWorker-409:
Process ForkPoolWorker-474:
Process ForkPoolWorker-470:
Process ForkPoolWorker-406:
Process ForkPoolWorker-424:
Process ForkPoolWorker-426:
Process ForkPoolWorker-416:
Process ForkPoolWorker-449:
Process ForkPoolWorker-437:
Process ForkPoolWorker-441:
Process ForkPoolWorker-448:
Process ForkPoolWorker-447:
Process ForkPoolWorker-407:
Process ForkPoolWorker-432:
Process ForkPoolWorker-392:
Process ForkPoolWorker-479:
Process ForkPoolWorker-417:
Process ForkPoolWorker-443:
Process ForkPoolWorker-430:
Process ForkPoolWorker-420:
Process ForkPoolWorker-387:
Process ForkPoolWorker-413:
Process ForkPoolWorker-386:
Process ForkPoolWorker-460:
Process ForkPoolWorker-429:
Process ForkPoolWorker-472:
Process ForkPoolWork

KeyboardInterrupt: 