In [12]:
from utils_eval import compute_pairwise_metrics, extract_score
import json, os, numpy as np, pandas as pd
from IPython.display import display

eval_fn = "data/finetune_PRGS_test.json"

with open(eval_fn) as f:
    data = json.load(f)

id2data = {d["id"]: d for d in data}

models = set([])
for fn in os.listdir("data/preds"):
    with open(f"data/preds/{fn}") as f:
        model_name = fn.replace(".jsonl", "").replace("preds_", "")
        models.add(model_name)
        for line in f:
            d = json.loads(line)
            id2data[d["id"]]["pred_" + model_name] = d["output"]

data_pairwise = [d for d in data if d["sample_type"] == "pairwise"]
data_reward = [d for d in data if d["sample_type"] == "reward"]
data_gold = [d for d in data if d["sample_type"] == "pairwise-gold"]
data_silver = [d for d in data if d["sample_type"] == "pairwise-silver"]

# print(len(data), len(data_pairwise), len(data_reward), len(data_gold), len(data_silver))
print(f"All: {len(data)}, Pairwise: {len(data_pairwise)}, Reward: {len(data_reward)}, Gold: {len(data_gold)}, Silver: {len(data_silver)}")

results, N_samples = [], []
for model in models:
    if model.endswith("-b") or model.endswith("-c"):
        continue
    N_pairwise = len([d for d in data_pairwise if "pred_" + model in d])
    N_silver = len([d for d in data_silver if "pred_" + model in d])
    N_gold = len([d for d in data_gold if "pred_" + model in d])
    N_reward = len([d for d in data_reward if "pred_" + model in d])

    N_samples.append({"model": model, "N_pairwise": N_pairwise, "N_silver": N_silver, "N_gold": N_gold, "N_reward": N_reward})

    N_errors = 0

    pref1, acc, err = compute_pairwise_metrics(data_pairwise, model)
    pref1_silver, acc_silver, err_silver = compute_pairwise_metrics(data_silver, model)
    pref1_gold, acc_gold, err_gold = compute_pairwise_metrics(data_gold, model)

    N_errors += err + err_silver + err_gold

    y_true = [d["zscore"] for d in data_reward]
    y_pred = []
    for d in data_reward:
        pred, err = extract_score(d, "pred_" + model)
        y_pred.append(pred)
        N_errors += err

    abs_err = np.abs(np.array(y_true) - np.array(y_pred))
    corr = np.corrcoef(y_true, y_pred)[0, 1]
    avg_R = np.mean(y_pred)
    results.append({"model": model, "Acc_P": acc, "Acc_S": acc_silver, "Acc_G": acc_gold, "MAE_R": abs_err.mean(), "Corr_R": corr, "Avg_R": avg_R})

display(pd.DataFrame(results).sort_values(by="Acc_P", ascending=False).set_index("model").round(2))
display(pd.DataFrame(N_samples).sort_values(by="N_pairwise", ascending=False).set_index("model").round(2))

KeyError: 'pairwise-P2-0'

# Subedits Evaluation

In [3]:
from utils_eval import compute_pairwise_metrics
import json, os, pandas as pd

sub_datasets = {}

N_keeps = [1, 2, 3, 4, 5, 6, 7, "all"]

models = set([])
for N_keep in N_keeps:
    eval_fn = f"data/subedits_P{N_keep}_test.json"
    with open(eval_fn) as f:
        sub_datasets[N_keep] = json.load(f)
    id2data = {d["id"]: d for d in sub_datasets[N_keep]}

    for fn in os.listdir(f"data/preds"):
        with open(f"data/preds/{fn}") as f:
            model_name = fn.replace(".jsonl", "").replace("preds_", "")
            lines = list(f.readlines())
            for line in lines:
                d = json.loads(line)
                if d.get("input_fn", "") == eval_fn:
                    original_sample = id2data[d["id"]]
                    original_sample["pred_" + model_name] = d["output"]
                    models.add(model_name)

results = []
for model in models:
    result_row = {"model": model}
    for N_keep in N_keeps:
        pref1, acc, err = compute_pairwise_metrics(sub_datasets[N_keep], model)
        result_row[f"Acc_P{N_keep}"] = acc
    results.append(result_row)

display(pd.DataFrame(results).sort_values(by="Acc_P2", ascending=False).set_index("model").round(2))

Unnamed: 0_level_0,Acc_P1,Acc_P2,Acc_P3,Acc_P4,Acc_P5,Acc_P6,Acc_P7,Acc_Pall
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
lamp-gpt-4o-mini-P,86.05,97.18,98.53,98.97,100.0,100.0,100.0,99.53
lamp-gem-1p5-flash-p-c,82.79,89.2,94.61,94.33,95.38,97.37,100.0,99.53
baseline,48.37,50.7,51.96,55.67,52.02,48.03,52.63,53.02
