In [1]:
'''
This method evaluates the formatted jsonl file for a single trial. 
'''

import os, psutil, gc
import time 
import json
import pprint

from collections import defaultdict
import random

import numpy as np
np.set_printoptions(precision=4)
from scipy.stats import ttest_rel

In [2]:
from sal.config import Config

from datasets import load_dataset

from core import best_of_n
from utils.load_data import load_data_prm800k
from utils import grader 

In [3]:
# base_dir
base_dir = '/groups/kjun/tnn/datasets/'

# dataset path
data_dir = base_dir + "/prm800k/math_splits"

# llm and prm path
llm_dir = base_dir + "/Llama-3.2-1B-Instruct-GGUF/Llama-3.2-1B-Instruct.Q4_K_M.gguf"
prm_dir = base_dir + "/Llama3.1-8B-PRM-Deepseek-Data-GGUF/Llama3.1-8B-PRM-Deepseek-Data.Q4_K_M.gguf"

llm_tokenizer_dir = base_dir + "/Llama-3.2-1B-Instruct"
prm_tokenizer_dir = base_dir + "/Llama3.1-8B-PRM-Deepseek-Data"

In [4]:
#  load data 
data_by_levels = load_data_prm800k(data_dir)


# ds_completions = load_completions(completions_dir)

# load random_seeds     
# random_seeds = np.loadtxt("random_seeds.txt").astype("int64")
# random_seeds = [int(seed) for seed in random_seeds]

1: 43
2: 90
3: 105
4: 128
5: 134


In [18]:
def evaluate_correctness_hf(data_dir, level, num_budgets=None):

    # dataset = load_dataset(dataset_name, name=config_name, split=dataset_split, cache_dir=data_dir)
    dataset = load_dataset("json", data_files = data_dir, split='train')
    dataset_by_level = dataset.filter(lambda example: example['level'] == int(level))

    all_correctness = np.zeros((1, len(dataset_by_level)))
    for q_idx, data in enumerate(dataset_by_level):
        if num_budgets is not None:
            completions = data["completions"][:num_budgets]
        else:
            completions = data["completions"]

        gt_answer = data['answer']
        
        for cidx, completion in enumerate(completions):
            c_answer = grader.extract_last_boxed_answer(completion)
            is_correct = grader.grade_answer(c_answer, gt_answer)
            if is_correct:
                # print(f"\n-> q_idx = {q_idx}")
                # print(f"gt_answer = {gt_answer}")
                # print(f"c_answer = {c_answer}")
                break

        if is_correct == False:
            print(f"\n-> q_idx = {q_idx}")
            print(f"gt_answer = {gt_answer}")
            for completion in completions:
                print(completion[-20:])
        all_correctness[0,q_idx] = is_correct

    return all_correctness

# general params
config = Config()
config.agg_strategy = 'last'
config.n = 8
config.beam_width = 2
config.lookahead = 0
config.num_iterations = 10
config.sort_completed = False

# diverse_select params
config.lam = 10
config.normalize_embeds = True

level = '4'
num_questions = len(data_by_levels[level])
# num_questions = 1
num_trials = 1
print(f"num_questions = {num_questions}")

bon_dir = "results/generate_bon--n-16--level-4--v11.jsonl" 
sd_dir = "results/generate_sd--n-16--bw-2--depth-2--lam-10--True--seed-0--level-4--v21.jsonl"
sd_dir = "results/generate_sd--n-16--bw-2--depth-40--lam-10--True--seed-0--level-4--v11.jsonl"
sd_dir = "results/beam--n-16--d-40--bw-4--lh-1--level-4--v11--seed-0.jsonl"
sd_dir = "results/beam--n-8--d-40--bw-4--lh-1--dup-True--level-4--v01--trial-1.jsonl"
# sd_dir = f"results/generate_sd_prm800k_level{level}_n{config.n}_bw{config.beam_width}_depth{config.num_iterations}_lam{config.lam}_{config.normalize_embeds}_v11.jsonl"

# sd_dir = f"results/generate_beam_prm800k_level4_n8_bw2_depth40_v11.jsonl"
print(sd_dir)

# bon_correctness = \
#     evaluate_correctness(bon_dir, data_by_levels[level], num_trials, config.n)
sd_correctness = \
    evaluate_correctness_hf(sd_dir, level, None)

# print(bon_correctness)
# print(sd_correctness)
# trial_bon_correctness = np.mean(bon_correctness, axis=1)
trial_sd_correctness = np.mean(sd_correctness, axis=1)
# print(trial_bon_correctness)
print(trial_sd_correctness)

# mean_bon_correctness = np.mean(trial_bon_correctness)
# std_bon_correctness = np.std(trial_bon_correctness, ddof=1)
# error_bon_correctness = std_bon_correctness/np.sqrt(num_trials)

mean_sd_correctness = np.mean(trial_sd_correctness)
std_sd_correctness = np.std(trial_sd_correctness, ddof=1)
error_sd_correctness = std_sd_correctness/np.sqrt(num_trials)

# print(f"bon_score = {mean_bon_correctness:0.4f} (\u00B1{error_bon_correctness:0.2f})")
print(f"sd_score = {mean_sd_correctness:0.4f} (\u00B1{error_sd_correctness:0.2f})")
# print(f"bon_score_per_question = {np.mean(bon_correctness_per_question):0.4f}")
# print(f"sd_score_per_question = {np.mean(sd_correctness_per_question):0.4f}")
# print(sd_correctness_per_question[:10])
# print(bon_correctness_per_question[:10])


# print(bon_correctness)
# print(bon_best_correctness)
# num_differences = np.sum(np.array(bon_correctness) != np.array(bon_best_correctness))
# print(num_differences)
# print(np.sum(sd_correctness != sd_best_correctness))

# t_stat, p_value = ttest_rel(bon_correctness.flatten().astype(int), sd_correctness.flatten().astype(int))
# print(f"t_stat = {t_stat:0.4f}")
# print(f"p_value = {p_value:0.4f}")

# t_stat, p_value = ttest_rel(trial_bon_correctness, trial_sd_correctness)
# # t_stat, p_value = ttest_rel(np.array(bon_best_correctness).astype(int), np.array(sd_best_correctness).astype(int))
# print(f"t_stat = {t_stat:0.4f}")
# print(f"p_value = {p_value:0.4f}")


num_questions = 0
results/beam--n-8--d-40--bw-4--lh-1--dup-True--level-4--v01--trial-1.jsonl

-> q_idx = 0
gt_answer = 90^\circ
is: $\boxed{119.47}$
is: $\boxed{119.47}$
is: $\boxed{119.47}$
is: $\boxed{119.47}$
er is: $\boxed{111}$
er is: $\boxed{111}$
er is: $\boxed{111}$
is: $\boxed{111.81}$

-> q_idx = 1
gt_answer = \pi
oxed{\frac{\pi}{2}}$
oxed{\frac{\pi}{2}}$
oxed{\frac{\pi}{2}}$
oxed{\frac{\pi}{2}}$
oxed{\frac{\pi}{2}}$
final answer is π/2.
oxed{\frac{\pi}{2}}$
oxed{\frac{\pi}{2}}$

-> q_idx = 4
gt_answer = \frac{243}{625}
= \frac{2205}{567}$.
{\frac{2205}{5625}}$
xed{\frac{176}{45}}$
xed{\frac{176}{45}}$
xed{\frac{176}{45}}$
xed{\frac{176}{45}}$
ed{\frac{146}{375}}$
ed{\frac{146}{375}}$

-> q_idx = 6
gt_answer = 17
swer is: $\boxed{5}$
swer is: $\boxed{5}$
swer is: $\boxed{5}$
swer is: $\boxed{5}$
swer is: $\boxed{5}$
swer is: $\boxed{5}$
swer is: $\boxed{5}$
swer is: $\boxed{5}$

-> q_idx = 8
gt_answer = 3
swer is: $\boxed{4}$
swer is: $\boxed{4}$
swer is: $\boxed{4}$
swer is: