In [14]:
from utils_eval import compute_pairwise_metrics, extract_score
import json, os, numpy as np, pandas as pd
from IPython.display import display

def evaluate_models(eval_fn, preds_dir="data/preds"):
    """Score every model's predictions on one evaluation file and display the results.

    The evaluation samples are read from ``eval_fn`` (a JSON list of dicts that
    carry at least ``id`` and ``sample_type``; non-pairwise samples also need
    ``zscore``).  Each ``*.jsonl`` file in ``preds_dir`` holds one model's
    predictions, one JSON object per line with ``id``, ``input_fn`` and
    ``output``.  Only predictions whose ``input_fn`` equals ``eval_fn`` are used.

    Two tables are shown with ``display``:

    * a metrics table, one row per model, sorted by the ``pairwise`` column
      (descending) when that column exists;
    * a sample-count table with a leading "Total" row followed by the number of
      samples each model has predictions for.

    Args:
        eval_fn: Path of the evaluation JSON file.  It is compared verbatim
            against each prediction's ``input_fn`` field.
        preds_dir: Directory containing the per-model prediction files.

    Returns:
        None.  The function prints a header line and displays two DataFrames.

    Raises:
        KeyError: If a matching prediction refers to an ``id`` that is not in
            the evaluation file.
    """
    with open(eval_fn, encoding="utf-8") as f:
        data = json.load(f)

    # The values are the very same dict objects as in `data`, so attaching a
    # prediction through this index also makes it visible when iterating `data`.
    id2data = {d["id"]: d for d in data}

    models = set()
    # os.listdir returns entries in arbitrary order; sort so that, if two files
    # ever map to the same model name, which one wins is reproducible.
    for fn in sorted(os.listdir(preds_dir)):
        # Ignore anything that is not a predictions file (sub-directories,
        # editor/OS artefacts) instead of crashing while parsing it.
        if not fn.endswith(".jsonl"):
            continue
        model_name = fn.replace(".jsonl", "").replace("preds_", "")
        if "gem-1p5" in model_name:
            continue
        models.add(model_name)
        pred_key = "pred_" + model_name
        with open(os.path.join(preds_dir, fn), encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    # Tolerate blank / trailing empty lines in the JSONL file.
                    continue
                d = json.loads(line)
                if d["input_fn"] != eval_fn:
                    continue
                id2data[d["id"]][pred_key] = d["output"]

    # Group the evaluation samples by their sample type.
    sample_types = {}
    for d in data:
        sample_types.setdefault(d["sample_type"], []).append(d)
    type_names = sorted(sample_types)

    results = []
    total_N_samples = {"model": "Total"}
    for sample_type in type_names:
        total_N_samples[sample_type] = len(sample_types[sample_type])
    N_samples = [total_N_samples]

    # Remove models that don't have any samples annotated.  Sorting makes the
    # row order deterministic (set iteration order changes between runs).
    models = sorted(
        model for model in models if any("pred_" + model in d for d in data)
    )

    for model in models:
        pred_key = "pred_" + model
        N_samples_row = {"model": model}
        result_row = {"model": model}
        for sample_type in type_names:
            model_samples = [d for d in sample_types[sample_type] if pred_key in d]
            if not model_samples:
                continue
            N_samples_row[sample_type] = len(model_samples)
            if sample_type.startswith("pairwise"):
                # Only the second returned value is reported for pairwise types.
                _, acc, _ = compute_pairwise_metrics(model_samples, model)
                result_row[sample_type] = acc
            else:
                y_true = np.array([d["zscore"] for d in model_samples])
                # extract_score returns (score, error flag); only the score is
                # used.  NOTE(review): assumes the score is always numeric,
                # even when extraction fails — confirm in utils_eval.
                y_pred = np.array(
                    [extract_score(d, pred_key)[0] for d in model_samples]
                )
                result_row[sample_type + "_MAE_R"] = np.abs(y_true - y_pred).mean()
                result_row[sample_type + "_Corr_R"] = np.corrcoef(y_true, y_pred)[0, 1]
                result_row[sample_type + "_Avg_R"] = np.mean(y_pred)
        N_samples.append(N_samples_row)
        results.append(result_row)

    # Use the evaluation file name as a header above the tables.
    print(eval_fn.center(80, "-"))

    results_df = pd.DataFrame(results)
    # Guard against the two cases that used to raise KeyError: no model with
    # predictions (empty frame, no "model" column) and an evaluation file that
    # has no sample type named exactly "pairwise".
    if "pairwise" in results_df.columns:
        results_df = results_df.sort_values(by="pairwise", ascending=False)
    if "model" in results_df.columns:
        results_df = results_df.set_index("model")
    display(results_df.round(2))
    display(pd.DataFrame(N_samples).set_index("model").round(2))

# Evaluate both test files, one after the other, in this fixed order.
for _eval_path in ("data/lamp_PRGS_test.json", "data/lamp_PR_editor_test.json"):
    evaluate_models(eval_fn=_eval_path)

----------------------------data/lamp_PRGS_test.json----------------------------


Unnamed: 0_level_0,pairwise,pairwise-P1,pairwise-P2,pairwise-P3,pairwise-P4,pairwise-P5,pairwise-P6,pairwise-P7,pairwise-gold,pairwise-silver,reward_MAE_R,reward_Corr_R,reward_Avg_R
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
lamp-4o-p,100.0,92.09,96.74,98.56,100.0,100.0,100.0,99.28,73.55,99.82,1.5,0.38,5.74
lamp-4o-mini-p,99.01,88.84,93.49,95.69,97.99,98.91,100.0,99.28,72.72,99.91,3.27,0.28,2.51
llama_3.2_1b_results_parallel,92.57,77.21,76.28,80.86,80.9,90.16,91.19,93.48,69.15,98.39,1.34,0.43,5.05
llama_3.2_1b_results_cosine,89.6,71.63,76.28,77.99,78.89,84.7,93.08,86.96,69.98,98.75,1.4,0.42,5.27
llama_3.2_1b_results_constant,86.14,70.7,77.21,75.12,76.38,81.42,84.28,82.61,65.34,95.71,1.38,0.37,4.89
gemini-1.5-flash,22.77,26.05,28.37,20.57,18.09,20.77,18.87,18.84,38.56,16.88,3.88,-0.19,8.83
gpt-4o-2024-08-06,17.82,23.72,21.86,15.31,14.57,13.66,15.72,10.14,37.98,6.88,3.5,-0.05,8.45
gpt-4o-mini,13.12,31.63,19.53,12.92,10.55,9.84,5.03,7.25,34.25,4.29,3.49,-0.11,8.44


Unnamed: 0_level_0,pairwise,pairwise-P1,pairwise-P2,pairwise-P3,pairwise-P4,pairwise-P5,pairwise-P6,pairwise-P7,pairwise-gold,pairwise-silver,reward
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
Total,404,215,215,209,199,183,159,138,1206,1120,430
llama_3.2_1b_results_parallel,404,215,215,209,199,183,159,138,1206,1120,430
lamp-4o-p,404,215,215,209,199,183,159,138,1206,1120,430
lamp-4o-mini-p,404,215,215,209,199,183,159,138,1206,1120,430
gpt-4o-2024-08-06,404,215,215,209,199,183,159,138,1206,1120,430
gpt-4o-mini,404,215,215,209,199,183,159,138,1206,1120,430
llama_3.2_1b_results_cosine,404,215,215,209,199,183,159,138,1206,1120,430
gemini-1.5-flash,404,215,215,209,199,183,159,138,1206,1120,430
llama_3.2_1b_results_constant,404,215,215,209,199,183,159,138,1206,1120,430


-------------------------data/lamp_PR_editor_test.json--------------------------


Unnamed: 0_level_0,pairwise,pairwise-P1,pairwise-P2,pairwise-P3,pairwise-P4,pairwise-P5,pairwise-P6,pairwise-P7,reward_MAE_R,reward_Corr_R,reward_Avg_R
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
lamp-4o-mini-p-editor,98.64,84.47,92.83,97.16,98.8,98.83,99.6,99.46,3.61,-0.42,3.9
gpt-4o-mini-2024-07-18,9.43,32.62,18.44,11.6,6.01,4.69,4.38,3.76,3.26,-0.07,8.48


Unnamed: 0_level_0,pairwise,pairwise-P1,pairwise-P2,pairwise-P3,pairwise-P4,pairwise-P5,pairwise-P6,pairwise-P7,reward
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
Total,880,515,488,457,416,341,251,186,1030
lamp-4o-mini-p-editor,880,515,488,457,416,341,251,186,1030
gpt-4o-mini-2024-07-18,880,515,488,457,416,341,251,186,1030
