In [1]:
from utils_eval import compute_pairwise_metrics, extract_score
import json, os, numpy as np, pandas as pd
from IPython.display import display

# Evaluation config: the test set, and the directory of per-model prediction
# files ("preds_<model>.jsonl", one JSON record per line).
eval_fn = "data/lamp_PRGS_test.json"
PREDS_DIR = "data/preds"

with open(eval_fn) as f:
    data = json.load(f)

id2data = {d["id"]: d for d in data}

# Attach each model's output to the matching test sample under "pred_<model>".
# PREDS_DIR also holds predictions for other test sets, each tagged with an
# "input_fn", so only records aimed at this file are kept. Records without an
# "input_fn" are treated as belonging to this test set. A model is registered
# only once it contributes at least one prediction.
models = set()
for fn in sorted(os.listdir(PREDS_DIR)):
    if not fn.endswith(".jsonl"):
        continue  # skip stray files / folders (e.g. .ipynb_checkpoints)
    model_name = fn.replace(".jsonl", "").replace("preds_", "")
    with open(os.path.join(PREDS_DIR, fn)) as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank (e.g. trailing) lines
            d = json.loads(line)
            if d.get("input_fn", eval_fn) != eval_fn:
                continue  # prediction for a different test set
            id2data[d["id"]]["pred_" + model_name] = d["output"]
            models.add(model_name)

# Group test samples by "sample_type". Types starting with "pairwise" are scored
# by accuracy; every other type is scored as a scalar prediction of "zscore".
sample_types = {}
for d in data:
    sample_types.setdefault(d["sample_type"], []).append(d)

results, N_samples = [], []
total_N_samples = {"model": "Total"}
for sample_type in sorted(sample_types):
    total_N_samples[sample_type] = len(sample_types[sample_type])
N_samples.append(total_N_samples)

# Iterate models in a fixed order so both tables render deterministically.
for model in sorted(models):
    pred_key = "pred_" + model
    N_samples_row = {"model": model}
    result_row = {"model": model}
    for sample_type in sorted(sample_types):
        # Score only the samples this model actually has a prediction for;
        # the N_samples table below reports how many that was.
        model_samples = [d for d in sample_types[sample_type] if pred_key in d]
        N_samples_row[sample_type] = len(model_samples)
        if not model_samples:
            continue  # no predictions of this type -> cell left as NaN
        if sample_type.startswith("pairwise"):
            _, acc, _ = compute_pairwise_metrics(model_samples, model)
            result_row[sample_type] = acc
        else:
            y_true = np.asarray([d["zscore"] for d in model_samples], dtype=float)
            # NOTE(review): assumes extract_score always yields a number; if it
            # returns None on a parse failure this conversion raises — confirm.
            y_pred = np.asarray(
                [extract_score(d, pred_key)[0] for d in model_samples], dtype=float
            )
            abs_err = np.abs(y_true - y_pred)
            # Pearson r is undefined for a single point.
            corr = np.corrcoef(y_true, y_pred)[0, 1] if len(model_samples) > 1 else np.nan
            result_row[sample_type+"_MAE_R"] = abs_err.mean()
            result_row[sample_type+"_Corr_R"] = corr
            result_row[sample_type+"_Avg_R"] = y_pred.mean()
    N_samples.append(N_samples_row)
    results.append(result_row)

display(pd.DataFrame(results).sort_values(by="pairwise-gold", ascending=False).set_index("model").round(2))
display(pd.DataFrame(N_samples).set_index("model").round(2))

model,pairwise,pairwise-P1,pairwise-P2,pairwise-P3,pairwise-P4,pairwise-P5,pairwise-P6,pairwise-P7,pairwise-gold,pairwise-silver,reward_MAE_R,reward_Corr_R,reward_Avg_R
lamp-4o-p,100.0,92.09,96.74,98.56,100.0,100.0,100.0,99.28,73.55,99.82,1.5,0.38,5.74
lamp-4o-mini-p,99.01,88.84,93.49,95.69,97.99,98.91,100.0,99.28,72.72,99.91,3.27,0.28,2.51
lamp-gem-1p5-flash-p-b,47.03,50.23,54.88,49.28,51.76,41.53,52.2,47.83,50.58,51.96,2.47,-0.07,7.22
lamp-gem-1p5-flash-p-a,50.25,52.09,49.77,52.15,51.76,43.72,47.8,53.62,50.41,49.11,2.55,-0.02,7.36
lamp-gem-1p5-flash-p-c,55.2,52.56,50.23,62.68,51.76,50.27,42.14,47.1,49.17,52.5,2.57,-0.06,7.41
gemini-1.5-flash,22.77,26.05,28.37,20.57,18.09,20.77,18.87,18.84,38.56,16.88,3.88,-0.19,8.83
gpt-4o-2024-08-06,17.82,23.72,21.86,15.31,14.57,13.66,15.72,10.14,37.98,6.88,3.5,-0.05,8.45
gpt-4o-mini,13.12,31.63,19.53,12.92,10.55,9.84,5.03,7.25,34.25,4.29,3.49,-0.11,8.44


model,pairwise,pairwise-P1,pairwise-P2,pairwise-P3,pairwise-P4,pairwise-P5,pairwise-P6,pairwise-P7,pairwise-gold,pairwise-silver,reward
Total,404,215,215,209,199,183,159,138,1206,1120,430
gpt-4o-mini,404,215,215,209,199,183,159,138,1206,1120,430
lamp-gem-1p5-flash-p-b,404,215,215,209,199,183,159,138,1206,1120,430
lamp-4o-p,404,215,215,209,199,183,159,138,1206,1120,430
lamp-gem-1p5-flash-p-c,404,215,215,209,199,183,159,138,1206,1120,430
gemini-1.5-flash,404,215,215,209,199,183,159,138,1206,1120,430
lamp-gem-1p5-flash-p-a,404,215,215,209,199,183,159,138,1206,1120,430
gpt-4o-2024-08-06,404,215,215,209,199,183,159,138,1206,1120,430
lamp-4o-mini-p,404,215,215,209,199,183,159,138,1206,1120,430


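# Subedits Evaluation

Pairwise accuracy of each model on the sub-edit test sets `data/subedits_P{N}_test.json`, for N = 1–7 and `all`.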

In [3]:
from utils_eval import compute_pairwise_metrics
import json, os, pandas as pd

from IPython.display import display  # make this cell independent of the one above

# Directory of per-model prediction files ("preds_<model>.jsonl"); each record's
# "input_fn" names the test file it was produced for.
PREDS_DIR = "data/preds"

sub_datasets = {}

N_keeps = [1, 2, 3, 4, 5, 6, 7, "all"]

# Load every sub-edit test set and index its samples by id, keyed by the test
# file path so prediction records can be routed via their "input_fn".
subedit_fns = {N_keep: f"data/subedits_P{N_keep}_test.json" for N_keep in N_keeps}
fn_to_samples = {}
for N_keep, subedit_fn in subedit_fns.items():
    with open(subedit_fn) as f:
        sub_datasets[N_keep] = json.load(f)
    fn_to_samples[subedit_fn] = {d["id"]: d for d in sub_datasets[N_keep]}

# Read each prediction file exactly once and attach every record to the sample
# it targets, under "pred_<model>". A model is registered only if it has at
# least one sub-edit prediction.
models = set()
for fn in sorted(os.listdir(PREDS_DIR)):
    if not fn.endswith(".jsonl"):
        continue  # skip stray files / folders (e.g. .ipynb_checkpoints)
    model_name = fn.replace(".jsonl", "").replace("preds_", "")
    with open(os.path.join(PREDS_DIR, fn)) as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank (e.g. trailing) lines
            d = json.loads(line)
            target_samples = fn_to_samples.get(d.get("input_fn", ""))
            if target_samples is None:
                continue  # prediction for some other test set
            target_samples[d["id"]]["pred_" + model_name] = d["output"]
            models.add(model_name)

acc_columns = [f"Acc_P{N_keep}" for N_keep in N_keeps]
results = []
for model in sorted(models):
    pred_key = "pred_" + model
    result_row = {"model": model}
    for N_keep in N_keeps:
        # Score only samples this model has a prediction for, matching how the
        # main evaluation scores each sample type.
        model_samples = [d for d in sub_datasets[N_keep] if pred_key in d]
        if model_samples:
            _, acc, _ = compute_pairwise_metrics(model_samples, model)
            result_row[f"Acc_P{N_keep}"] = acc
        # otherwise: no predictions for this subset -> left as NaN
    results.append(result_row)

# Explicit columns keep a stable order and guarantee "Acc_P2" exists for sorting.
display(
    pd.DataFrame(results, columns=["model"] + acc_columns)
    .sort_values(by="Acc_P2", ascending=False)
    .set_index("model")
    .round(2)
)

model,Acc_P1,Acc_P2,Acc_P3,Acc_P4,Acc_P5,Acc_P6,Acc_P7,Acc_Pall
lamp-gpt-4o-mini-P,86.05,97.18,98.53,98.97,100.0,100.0,100.0,99.53
lamp-gem-1p5-flash-p-c,82.79,89.2,94.61,94.33,95.38,97.37,100.0,99.53
baseline,48.37,50.7,51.96,55.67,52.02,48.03,52.63,53.02
