In [None]:
import csv
import json
import os
import pdb
import torch

from tqdm import tqdm

from judges.pair_judge import PairJudge
from utils import (
    DATA_ROOT,
    load_model_tokenizer,
)

# Sizes used when averaging judge scores: get_pref reads NUM_DATA entries per
# model pair and divides the total by NUM_DATA * NUM_RESPONSES.
NUM_DATA = 100
NUM_RESPONSES = 10
BATCH_SIZE = 16

# Checkpoints are addressed as "<base name>-epoch-<1..NUM_EPOCHS>".
NUM_EPOCHS = 10

# Number of best-scoring checkpoints kept per base model.
NUM_TOP = 2

PREF_MODEL = "vectorzhou/gemma-2-2b-it-preference_dataset_mixture2_and_safe_pku-Preference"

# Base model name -> LaTeX macro printed for it in the generated tables.
BASE_NAMES = {
    "vectorzhou/gemma-2-2b-it-alpaca-cleaned-SFT-PKU-SafeRLHF-OnlineIPO1-lora-0227213453": "\\oipoone",
    "vectorzhou/gemma-2-2b-it-alpaca-cleaned-SFT-PKU-SafeRLHF-OnlineIPO2-lora-0227214805": "\\oipotwo",
    "vectorzhou/gemma-2-2b-it-alpaca-cleaned-SFT-PKU-SafeRLHF-NashMD-lora-0227215018": "\\nmd",
    "vectorzhou/gemma-2-2b-it-alpaca-cleaned-SFT-PKU-SafeRLHF-NashMDPG-lora-0301154042": "\\nmdpg",
    "dummy/new_MPO": "\\mpo",
    "vectorzhou/gemma-2-2b-it-alpaca-cleaned-SFT-PKU-SafeRLHF-Extragradient-lora-0224142549": "\\eg",
}
REF_MODEL = "vectorzhou/gemma-2-2b-it-alpaca-cleaned-SFT"

# Same mapping as BASE_NAMES with the reference model appended as the last entry.
ALL_MODEL_DICT = {**BASE_NAMES, REF_MODEL: "$\\pi_\\tref$"}

def get_pref(compare_results, model1, model2, judge_type):
    """Return the mean `judge_type` score of `model1` versus `model2`.

    Args:
        compare_results: Nested mapping indexed as
            ``compare_results[name_a][name_b][k]["results"]``, where each
            element of ``"results"`` is a mapping containing `judge_type`.
        model1: Name of the first model; only the part after the last "/"
            is used for the lookup.
        model2: Name of the second model, shortened the same way.
        judge_type: Key read from every result entry (e.g. "pref_model").

    Returns:
        The sum of all `judge_type` values over the first NUM_DATA entries,
        divided by ``NUM_DATA * NUM_RESPONSES``. If the pair is only stored
        as (model2, model1), the value is returned as ``1 - average``.

    Raises:
        KeyError: If the pair is stored in neither order.
    """
    model1 = model1.split("/")[-1]
    model2 = model2.split("/")[-1]

    # A pair may be stored under either ordering. Only a missing key triggers
    # the reversed lookup; any other error (wrong type, interrupt, ...) is a
    # real problem and must propagate instead of being silently retried.
    try:
        pair_results = compare_results[model1][model2]
        flip = False
    except KeyError:
        try:
            pair_results = compare_results[model2][model1]
        except KeyError:
            raise KeyError(
                f"no comparison recorded between {model1!r} and {model2!r}"
            ) from None
        flip = True

    # NOTE(review): the denominator assumes every entry holds exactly
    # NUM_RESPONSES results -- confirm against the code that writes the file.
    total = sum(
        sum(x[judge_type] for x in pair_results[k]["results"])
        for k in range(NUM_DATA)
    )
    avg = total / (NUM_DATA * NUM_RESPONSES)

    # Scores are stored from the first-listed model's side, so mirror them
    # when the pair was found in reversed order.
    return 1 - avg if flip else avg

In [None]:
# Load the pairwise comparison results from eval_results/compare.json.
fn = os.path.join("eval_results", "compare.json")
with open(fn) as f:
    compare_results = json.load(f)

# base model name -> ascending list of the 0-based epochs picked for it
all_models = {}

for base_name in BASE_NAMES:
    # Score of every epoch checkpoint against the reference model.
    win_rates = {
        epoch: get_pref(compare_results, f"{base_name}-epoch-{epoch + 1}", REF_MODEL, "pref_model")
        for epoch in range(NUM_EPOCHS)
    }
    # Stable descending sort: on equal scores the earlier epoch ranks first.
    ranked_epochs = sorted(win_rates, key=win_rates.get, reverse=True)
    # Keep the NUM_TOP best epochs, listed in chronological order.
    all_models[base_name] = sorted(ranked_epochs[rank] for rank in range(NUM_TOP))

for name, top_epochs in all_models.items():
    print(name, top_epochs)


In [None]:
# Emit a LaTeX tabular comparing every picked checkpoint (rows) against the
# reference model and against every other picked checkpoint (columns).
# Backslashes in the LaTeX commands are written as "\\" throughout: sequences
# such as "\m" are invalid string escapes and raise a SyntaxWarning on recent
# Python versions.

# Column order: the reference model first, then each base model.
model2_list = [REF_MODEL] + list(all_models.keys())

# write first line
print("\\begin{tabular}{" + "cc|c" + ("|" + "c" * NUM_TOP) * (len(model2_list) - 1) + "}")
print("\\toprule")
# The reference-model header cell spans both header rows.
print("ALG &  & \\multirow{2}{*}{$\\pi_\\tref$} ", end="")
for i in range(1, len(model2_list)):
    model = model2_list[i]
    # No trailing vertical rule after the last column group.
    sep = "|" if i != len(model2_list) - 1 else ""
    print(f"& \\multicolumn{{{NUM_TOP}}}{{c{sep}}}{{{ALL_MODEL_DICT[model]}}} ", end="")
print("\\\\")
# Second header row: the (1-based) epoch number of each picked checkpoint.
print(" & Ep & ", end="")
for model in model2_list[1:]:
    for epoch in all_models[model]:
        print(f"& {epoch + 1} ", end="")
print("\\\\")
print("\\hhline{" + "=" * (2 + 1 + NUM_TOP * (len(model2_list) - 1)) + "}")

for i, model1 in enumerate(all_models.keys()):
    if i > 0:
        print("\\hline")
    # The algorithm name is printed once and spans the NUM_TOP rows of its checkpoints.
    print(f"\\multirow{{{NUM_TOP}}}{{*}}{{{BASE_NAMES[model1]}}}", end="")
    for epoch in all_models[model1]:
        print(f" & {epoch + 1}", end="")
        for model2 in model2_list:
            # The reference model has a single column with no epoch suffix.
            if model2 == REF_MODEL:
                epochs = [None]
            else:
                epochs = all_models[model2]

            for epoch2 in epochs:
                # A model is not compared against its own checkpoints: leave the cell empty.
                if model1 == model2:
                    print(" &", end="")
                    continue

                model2_name = f"{model2}-epoch-{epoch2 + 1}" if epoch2 is not None else model2

                avg = get_pref(compare_results, f"{model1}-epoch-{epoch + 1}", model2_name, "pref_model")

                # Highlight cells where the row model scores above 50%.
                if avg > 0.5:
                    bfl, bfr = "\\red{\\boldsymbol{", "}}"
                else:
                    bfl, bfr = "", ""

                # "%" must be escaped for LaTeX.
                print(" & $" + f"{bfl}{avg:.1%}{bfr}$".replace("%", "\\%"), end="")
        print("\\\\")

print("\\bottomrule")
print("\\end{tabular}")

In [None]:
# Emit a LaTeX tabular with one row per epoch and, side by side, one
# (ALG, Ep, score-vs-reference) column triple per base model.
# Backslashes in the LaTeX commands are written as "\\" throughout: sequences
# such as "\h" and "\m" are invalid string escapes and raise a SyntaxWarning
# on recent Python versions.

# write first line
print("\\begin{tabular}{" + "cc|c||" * (len(BASE_NAMES) - 1) + "cc|c}")
print("\\toprule")
for i in range(len(BASE_NAMES)):
    print(("& " if i > 0 else "") + "ALG & Ep & $\\pi_\\tref$ ", end="")
print("\\\\")
print("\\hhline{" + "=" * (3 * len(BASE_NAMES)) + "}")


for epoch in range(NUM_EPOCHS):
    for i, model1 in enumerate(BASE_NAMES):
        if i > 0:
            print(" & ", end="")
        # The algorithm name is printed on the first row only and spans all epoch rows.
        if epoch == 0:
            print(f"\\multirow{{{NUM_EPOCHS}}}{{*}}{{{BASE_NAMES[model1]}}}", end="")
        print(f" & {epoch + 1}", end="")

        avg = get_pref(compare_results, f"{model1}-epoch-{epoch + 1}", REF_MODEL, "pref_model")

        # Highlight the epochs that were picked as top checkpoints for this model.
        if epoch in all_models[model1]:
            bfl, bfr = "\\red{\\boldsymbol{", "}}"
        else:
            bfl, bfr = "", ""

        # "%" must be escaped for LaTeX.
        print(" & $" + f"{bfl}{avg:.1%}{bfr}$".replace("%", "\\%"), end="")
    print("\\\\")

print("\\bottomrule")
print("\\end{tabular}")