In [12]:
from sklearn.metrics import accuracy_score
import json, os, numpy as np, pandas as pd

eval_fn = "data/finetune_PRGS_test.json"
preds_dir = "data/preds"

# Evaluation records: a list of dicts, each with at least an "id" key.
with open(eval_fn) as f:
    data = json.load(f)

# Index by id so predictions can be attached to the matching record in place
# (the dicts in `id2data` are the very same objects as those in `data`).
id2data = {d["id"]: d for d in data}

models = set()
# Sorted for a reproducible load order. Entries that are not .jsonl files
# (sub-directories, editor/OS artefacts) are skipped instead of crashing open().
for fn in sorted(os.listdir(preds_dir)):
    if not fn.endswith(".jsonl"):
        continue
    with open(os.path.join(preds_dir, fn)) as f:
        model_name = fn.replace(".jsonl", "").replace("preds_", "")
        models.add(model_name)
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines such as a trailing newline
            d = json.loads(line)
            # Stored under "pred_<model>" so several models can share one record.
            id2data[d["id"]]["pred_" + model_name] = d["output"]

# Split by sample type. `.get` is used because records may lack a
# "sample_type" key; such records end up in none of the subsets.
data_pairwise = [d for d in data if d.get("sample_type") == "pairwise"]
data_silver = [d for d in data if d.get("sample_type") == "pairwise-silver"]
data_gold = [d for d in data if d.get("sample_type") == "pairwise-gold"]
data_reward = [d for d in data if d.get("sample_type") == "reward"]

print(len(data), len(data_pairwise), len(data_reward))

def extract_preference(d, pred_key):
    """Read the predicted preference stored under ``d[pred_key]``.

    Args:
        d: Record dict. ``d[pred_key]`` is expected to be a mapping with a
            ``"preference"`` entry that ``int`` can convert.
        pred_key: Key under which the model output was stored on the record.

    Returns:
        ``(preference, 0)`` on success, or ``(0, 1)`` when the prediction is
        missing or cannot be parsed. The second element is an error flag.
    """
    try:
        return int(d[pred_key]["preference"]), 0
    except Exception:
        # A missing or malformed output counts as one error with preference 0.
        # Catching Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate.
        return 0, 1

def extract_score(d, pred_key):
    """Read the predicted numeric score stored under ``d[pred_key]``.

    Args:
        d: Record dict. ``d[pred_key]`` is expected to be a mapping with a
            ``"score"`` entry that ``float`` can convert.
        pred_key: Key under which the model output was stored on the record.

    Returns:
        ``(score, 0)`` on success, with ``score`` as a float, or ``(0, 1)``
        when the prediction is missing or not numeric. The second element is
        an error flag.
    """
    try:
        # Coerce to float so a None or non-numeric score is reported as an
        # error here rather than breaking array arithmetic in the caller.
        return float(d[pred_key]["score"]), 0
    except Exception:
        # Catching Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate.
        return 0, 1

def compute_pairwise_metrics(data, pred_key=None):
    """Score one model's pairwise-preference predictions against the references.

    Args:
        data: List of record dicts, each with a ``"reference_preference"``
            entry that ``int`` can convert.
        pred_key: Key holding the model output on each record. When omitted it
            defaults to ``"pred_" + model``, where ``model`` is the
            module-level variable set by the calling loop.

    Returns:
        ``(pref1, acc, n_errors)``: the percentage of predictions equal to 1,
        the accuracy in percent, and the number of records whose prediction
        was missing or unparseable. Those records are scored with the
        preference 0 returned by ``extract_preference``. For empty ``data``
        both percentages are NaN and ``n_errors`` is 0.
    """
    if pred_key is None:
        pred_key = "pred_" + model  # falls back to the global loop variable

    if len(data) == 0:
        # Nothing to score; avoid dividing by zero below.
        return float("nan"), float("nan"), 0

    y_true = [int(d["reference_preference"]) for d in data]

    y_pred = []
    n_errors = 0
    for d in data:
        pred, failed = extract_preference(d, pred_key)
        y_pred.append(pred)
        # The per-record flag and the running total need distinct names:
        # sharing one name reset the total on every iteration.
        n_errors += failed

    pref1 = 100.0 * sum(1 for p in y_pred if p == 1) / len(y_pred)
    acc = 100.0 * accuracy_score(y_true, y_pred)
    return pref1, acc, n_errors

results = []
# `models` is a set, so iterate in sorted order to keep runs reproducible.
for model in sorted(models):
    # NOTE: the loop variable must stay named `model`;
    # compute_pairwise_metrics reads it as a global.
    pred_key = "pred_" + model

    # How many records of each subset carry a prediction from this model.
    N_pairwise = len([d for d in data_pairwise if pred_key in d])
    N_silver = len([d for d in data_silver if pred_key in d])
    N_gold = len([d for d in data_gold if pred_key in d])
    N_reward = len([d for d in data_reward if pred_key in d])

    N_errors = 0

    pref1, acc, err = compute_pairwise_metrics(data_pairwise)
    pref1_silver, acc_silver, err_silver = compute_pairwise_metrics(data_silver)
    pref1_gold, acc_gold, err_gold = compute_pairwise_metrics(data_gold)

    N_errors += err + err_silver + err_gold

    # Reward task: compare predicted scores with the reference z-scores.
    y_true = [d["zscore"] for d in data_reward]
    y_pred = []
    for d in data_reward:
        pred, err = extract_score(d, pred_key)
        y_pred.append(pred)
        N_errors += err

    y_true_arr = np.asarray(y_true)
    y_pred_arr = np.asarray(y_pred)
    abs_err = np.abs(y_true_arr - y_pred_arr)
    mean_abs_err = abs_err.mean() if abs_err.size else np.nan

    # Pearson correlation is undefined with fewer than two points or when
    # either side is constant. Report NaN directly instead of letting
    # np.corrcoef divide by a zero standard deviation and emit RuntimeWarnings.
    if (
        y_true_arr.size > 1
        and y_true_arr.min() != y_true_arr.max()
        and y_pred_arr.min() != y_pred_arr.max()
    ):
        corr = np.corrcoef(y_true_arr, y_pred_arr)[0, 1]
    else:
        corr = np.nan

    # N_reward and N_errors are computed above but are not part of the table.
    results.append({
        "model": model,
        "N_pref": N_pairwise,
        "Pref_basic_acc": acc,
        "Pref_basic_choice1": pref1,
        "N_pref_silver": N_silver,
        "Pref_silver_acc": acc_silver,
        "Pref_silver_choice1": pref1_silver,
        "N_pref_gold": N_gold,
        "Pref_gold_acc": acc_gold,
        "Pref_gold_choice1": pref1_gold,
        "Rew_abs_err": mean_abs_err,
        "Rew_corr": corr,
    })

pd.DataFrame(results).sort_values(by="Pref_basic_acc", ascending=False).set_index("model").round(2)

2600 404 430


  c /= stddev[:, None]
  c /= stddev[None, :]


Unnamed: 0_level_0,N_pref,Pref_basic_acc,Pref_basic_choice1,N_pref_silver,Pref_silver_acc,Pref_silver_choice1,N_pref_gold,Pref_gold_acc,Pref_gold_choice1,Rew_abs_err,Rew_corr
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
lamp-gem-1p5-flash-p,404,99.5,49.5,560,100.0,50.0,1206,73.71,48.59,2.29,-0.26
lamp-gem-1p5-flash-s,404,79.7,44.06,560,100.0,50.0,1206,69.15,47.51,2.71,0.2
baseline,404,50.0,100.0,560,50.0,100.0,1206,50.0,100.0,1.46,
lamp-gem-1p5-flash-pr,404,49.75,48.27,560,54.64,45.18,1206,50.33,48.26,1.93,0.17
lamp-gem-1p5-flash-ps,184,45.3,24.75,258,46.07,23.04,511,31.18,23.63,3.55,0.16
lamp-gem-1p5-flash-prs,171,42.33,23.02,237,42.32,21.07,476,29.02,20.73,3.41,0.11
lamp-gem-1p5-flash-rs,173,37.62,18.81,239,42.68,21.25,477,26.37,16.09,3.44,0.08
lamp-gem-1p5-flash-r,404,33.17,56.44,560,32.68,46.25,1206,42.12,53.15,1.88,-0.02


In [3]:
from collections import Counter

# Tally how many records carry each key (a quick schema-coverage check).
Counter(key for record in data for key in record)

Counter({'original_id': 2600,
         'split': 1394,
         'source': 1394,
         'type': 834,
         'sample_type': 1997,
         'paragraph': 430,
         'zscore': 430,
         'text_input': 2600,
         'output': 2600,
         'id': 2600,
         'pred_lamp-gem-1p5-flash-p': 2600,
         'pred_lamp-gem-1p5-flash-r': 2507,
         'pred_lamp-gem-1p5-flash-prs': 599,
         'pred_lamp-gem-1p5-flash-ps': 667,
         'pred_lamp-gem-1p5-flash-pr': 2600,
         'pred_lamp-gem-1p5-flash-s': 2600,
         'pred_lamp-gem-1p5-flash-rs': 626,
         'paragraph1': 2170,
         'paragraph2': 2170,
         'reference_preference': 2170})

In [4]:
# Distribution of sample types; records without a "sample_type" key are tallied under -1.
Counter(record.get("sample_type", -1) for record in data)

Counter({'reward': 430,
         'pairwise-gold': 603,
         'pairwise': 404,
         'pairwise-silver': 560,
         -1: 603})