In [20]:
from sklearn.metrics import accuracy_score
import json, os, numpy as np, pandas as pd
from IPython.display import display

eval_fn = "data/finetune_PRGS_test.json"

with open(eval_fn) as f:
    data = json.load(f)

id2data = {d["id"]: d for d in data}

models = set([])
for fn in os.listdir("data/preds"):
    with open(f"data/preds/{fn}") as f:
        model_name = fn.replace(".jsonl", "").replace("preds_", "")
        models.add(model_name)
        for line in f:
            d = json.loads(line)
            id2data[d["id"]]["pred_" + model_name] = d["output"]

data_pairwise = [d for d in data if d["sample_type"] == "pairwise"]
data_reward = [d for d in data if d["sample_type"] == "reward"]
data_gold = [d for d in data if d["sample_type"] == "pairwise-gold"]
data_silver = [d for d in data if d["sample_type"] == "pairwise-silver"]

# print(len(data), len(data_pairwise), len(data_reward), len(data_gold), len(data_silver))
print(f"All: {len(data)}, Pairwise: {len(data_pairwise)}, Reward: {len(data_reward)}, Gold: {len(data_gold)}, Silver: {len(data_silver)}")

def extract_preference(d, pred_key):
    try:
        return int(d[pred_key]["preference"]), 0
    except:
        return 0, 1

def extract_score(d, pred_key):
    try:
        return d[pred_key]["score"], 0
    except:
        return 0, 1

def compute_pairwise_metrics(data):
    err = 0
    y_true = [int(d["reference_preference"]) for d in data]

    y_pred = []
    for d in data:
        pred, err = extract_preference(d, "pred_" + model)
        y_pred.append(pred)
        err += err

    pref1 = 100.0 * len([p for p in y_pred if p == 1]) / len(y_pred)
    acc = 100.0 * accuracy_score(y_true, y_pred)
    return pref1, acc, err

results, N_samples = [], []
for model in models:
    N_pairwise = len([d for d in data_pairwise if "pred_" + model in d])
    N_silver = len([d for d in data_silver if "pred_" + model in d])
    N_gold = len([d for d in data_gold if "pred_" + model in d])
    N_reward = len([d for d in data_reward if "pred_" + model in d])

    N_samples.append({"model": model, "N_pairwise": N_pairwise, "N_silver": N_silver, "N_gold": N_gold, "N_reward": N_reward})

    N_errors = 0

    pref1, acc, err = compute_pairwise_metrics(data_pairwise)
    pref1_silver, acc_silver, err_silver = compute_pairwise_metrics(data_silver)
    pref1_gold, acc_gold, err_gold = compute_pairwise_metrics(data_gold)

    N_errors += err + err_silver + err_gold

    y_true = [d["zscore"] for d in data_reward]
    y_pred = []
    for d in data_reward:
        pred, err = extract_score(d, "pred_" + model)
        y_pred.append(pred)
        N_errors += err

    abs_err = np.abs(np.array(y_true) - np.array(y_pred))
    corr = np.corrcoef(y_true, y_pred)[0, 1]
    avg_R = np.mean(y_pred)
    results.append({"model": model, "Acc_P": acc, "Acc_S": acc_silver, "Acc_G": acc_gold, "MAE_R": abs_err.mean(), "Corr_R": corr, "Avg_R": avg_R})

display(pd.DataFrame(results).sort_values(by="Acc_P", ascending=False).set_index("model").round(2))
display(pd.DataFrame(N_samples).sort_values(by="N_pairwise", ascending=False).set_index("model").round(2))

All: 3160, Pairwise: 404, Reward: 430, Gold: 1206, Silver: 1120


  c /= stddev[:, None]
  c /= stddev[None, :]


Unnamed: 0_level_0,Acc_P,Acc_S,Acc_G,MAE_R,Corr_R,Avg_R
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
lamp-gem-1p5-flash-p-c,99.26,100.0,73.13,2.63,-0.46,3.54
lamp-gem-1p5-flash-p,99.26,100.0,73.88,2.14,-0.18,3.75
lamp-gem-1p5-flash-p-b,99.26,100.0,73.96,2.32,-0.43,3.94
lamp-gem-1p5-flash-pr-c,75.99,85.36,64.34,3.71,0.09,2.12
lamp-gem-1p5-flash-pr-b,50.5,51.52,51.99,3.44,0.1,2.59
baseline,50.0,50.0,50.0,1.46,,5.0
lamp-gem-1p5-flash-pr,49.75,56.96,50.75,1.73,0.11,4.88
lamp-gem-1p5-flash-r-b,45.54,58.48,50.83,1.72,0.07,4.22
lamp-gem-1p5-flash-r,43.07,55.98,49.17,1.83,0.07,4.14
lamp-gem-1p5-flash-r-c,35.15,46.88,49.0,1.92,-0.07,4.53


Unnamed: 0_level_0,N_pairwise,N_silver,N_gold,N_reward
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
lamp-gem-1p5-flash-p-c,404,1120,1206,430
gpt-4o-mini,404,1120,1206,430
lamp-gem-1p5-flash-r,404,1120,1206,430
lamp-gem-1p5-flash-p,404,1120,1206,430
lamp-gem-1p5-flash-pr-b,404,1120,1206,430
lamp-gem-1p5-flash-r-c,404,1120,1206,430
lamp-gem-1p5-flash-p-b,404,1120,1206,430
baseline,404,1120,1206,430
gemini-1.5-flash,404,1120,1206,430
lamp-gem-1p5-flash-pr-c,404,1120,1206,430
