In [1]:
import os, psutil, gc
import time 
import json
import pprint

from collections import defaultdict
import random

import numpy as np
from scipy.stats import ttest_rel

In [2]:
from sal.config import Config

from core import best_of_n
from utils.load_data import load_data_prm800k
from utils import grader 

In [3]:
# Root directory holding the datasets, model weights and tokenizers.
# NOTE(review): absolute, machine-specific path; adjust when running elsewhere.
base_dir = '/groups/kjun/tnn/datasets/'

# Dataset path. os.path.join is used so the trailing slash on base_dir does
# not produce a doubled separator ("datasets//prm800k/...").
data_dir = os.path.join(base_dir, "prm800k", "math_splits")

# LLM and PRM weight files (GGUF).
llm_dir = os.path.join(
    base_dir, "Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct.Q4_K_M.gguf"
)
prm_dir = os.path.join(
    base_dir,
    "Llama3.1-8B-PRM-Deepseek-Data-GGUF",
    "Llama3.1-8B-PRM-Deepseek-Data.Q4_K_M.gguf",
)

# Tokenizer directories matching the two models above.
llm_tokenizer_dir = os.path.join(base_dir, "Llama-3.2-1B-Instruct")
prm_tokenizer_dir = os.path.join(base_dir, "Llama3.1-8B-PRM-Deepseek-Data")

In [4]:
# Load the PRM800K math split found under `data_dir`.
# The result is indexed later in this notebook by a difficulty level given as
# a string (e.g. data_by_levels['4']).
# NOTE(review): the "level: count" lines in this cell's output are presumably
# printed by the loader itself — confirm in utils.load_data.
data_by_levels = load_data_prm800k(data_dir)


# Disabled: loading previously generated completions. Neither
# `load_completions` nor `completions_dir` is defined in the cells shown here.
# ds_completions = load_completions(completions_dir)

# Disabled: loading a fixed list of random seeds from a text file.
# random_seeds = np.loadtxt("random_seeds.txt").astype("int64")
# random_seeds = [int(seed) for seed in random_seeds]

1: 43
2: 90
3: 105
4: 128
5: 134


In [5]:
# Which difficulty level to evaluate, and how many trials / budgets to use.
level = '4'
num_trials, num_budgets = 1, 8

# Evaluate every question of the chosen level
# (override with a small number, e.g. 1, for a quick smoke run).
num_questions = len(data_by_levels[level])
print(f"num_questions = {num_questions}")

num_questions = 128


In [57]:
def evaluate_correctness(data_dir, data_by_levels, num_trials, num_budgets=None):
    """Grade stored completions of a results file against the ground truth.

    The file is read as JSONL: every non-blank line is one trial, a JSON
    object with ``'completions'`` and ``'agg_scores'`` entries that are both
    indexed first by question and then by completion.

    Args:
        data_dir: Path of the JSONL results *file* (despite the name).
        data_by_levels: Sequence of question records; ``record['answer']`` is
            used as the ground-truth answer of the question at that index.
        num_trials: Maximum number of trials (non-blank lines) to evaluate.
        num_budgets: If given, only the first ``num_budgets`` completions and
            scores of each question are considered; ``None`` uses all of them.

    Returns:
        A 6-tuple of flat lists holding one entry per (trial, question), in
        file order:

        * ``all_correctness``  - True if *any* considered completion is graded
          correct.
        * ``best_correctness`` - grading result of the highest-scored
          completion.
        * ``best_completions`` - the highest-scored completion itself.
        * ``best_prm_scores``  - the score of that completion.
        * ``pred_answers``     - the answer extracted from that completion.
        * ``gt_answers``       - the ground-truth answer.
    """
    all_correctness = []
    best_correctness = []
    best_completions = []
    best_prm_scores = []
    pred_answers = []
    gt_answers = []

    trials_read = 0
    with open(data_dir, 'r', encoding='utf-8') as fin:
        for line in fin:
            # Enforce the trial limit. The counter must advance inside this
            # loop, otherwise every line of the file would be evaluated.
            if trials_read >= num_trials:
                break
            # Tolerate blank lines (e.g. a stray empty line in the JSONL file).
            if not line.strip():
                continue

            trial_data = json.loads(line)
            trials_read += 1

            for q_idx in range(len(data_by_levels)):
                # Slicing with None keeps the whole list, so one expression
                # covers both the limited and the unlimited budget case.
                completions = trial_data['completions'][q_idx][:num_budgets]
                scores = trial_data['agg_scores'][q_idx][:num_budgets]
                gt_answer = data_by_levels[q_idx]['answer']

                # Completion with the highest aggregated score.
                best_idx = int(np.argmax(scores))
                best_completion = completions[best_idx]
                pred_answer = grader.extract_last_boxed_answer(best_completion)

                # True when at least one considered completion is correct.
                is_correct = any(
                    grader.grade_answer(
                        grader.extract_last_boxed_answer(completion), gt_answer
                    )
                    for completion in completions
                )
                best_is_correct = grader.grade_answer(pred_answer, gt_answer)

                all_correctness.append(is_correct)
                best_correctness.append(best_is_correct)
                best_completions.append(best_completion)
                pred_answers.append(pred_answer)
                gt_answers.append(gt_answer)
                best_prm_scores.append(scores[best_idx])

    return all_correctness, best_correctness, best_completions, best_prm_scores, pred_answers, gt_answers


# general params
config = Config()
config.agg_strategy = 'last'
config.n = 8
config.beam_width = 4
config.lookahead = 0
config.num_iterations = 1
config.sort_completed = False

# diverse_select params
config.lam = 10
config.normalize_embeds = True

# Result files. Both names are derived from `level` so the two methods are
# always compared on the same difficulty level as the loaded questions.
# NOTE(review): the BoN file name says n16; only the first `config.n`
# completions per question are evaluated below.
bon_dir = f"results/scores_bon_prm800k_level{level}_n16_v11.jsonl"
sd_dir = f"results/scores_sd_prm800k_level{level}_n{config.n}_bw{config.beam_width}_depth{config.num_iterations}_lam{config.lam}_v11.jsonl"

# Best-of-N, truncated to the same budget (config.n) as the SD run.
(
    bon_correctness,
    bon_best_correctness,
    bon_best_completions,
    bon_best_prm_scores,
    bon_pred_answers,
    bon_gt_answer,
) = evaluate_correctness(bon_dir, data_by_levels[level], num_trials, config.n)

# SD run: every stored completion is used.
(
    sd_correctness,
    sd_best_correctness,
    sd_best_completions,
    sd_best_prm_scores,
    sd_pred_answers,
    sd_gt_answer,
) = evaluate_correctness(sd_dir, data_by_levels[level], num_trials)

# Accuracy of the highest-scored completion for each method.
print(f"bon_score = {np.mean(bon_best_correctness)}")
print(f"sd_score = {np.mean(sd_best_correctness)}")

# Questions where BoN's "any completion correct" flag differs from the
# correctness of its highest-scored completion.
num_differences = np.sum(np.array(bon_correctness) != np.array(bon_best_correctness))
print(num_differences)

# Paired t-test on the "any completion correct" indicators. Kept under its own
# names so it is no longer silently overwritten by the test below.
oracle_t_stat, oracle_p_value = ttest_rel(
    np.array(bon_correctness).astype(int), np.array(sd_correctness).astype(int)
)

# Paired t-test on the correctness of the highest-scored completions
# (this is the result that is reported).
t_stat, p_value = ttest_rel(
    np.array(bon_best_correctness).astype(int), np.array(sd_best_correctness).astype(int)
)
print(t_stat)
print(p_value)

bon_score = 0.328125
sd_score = 0.3359375
32
-0.21740424873680092
0.8282422121612354
