In [62]:
import csv
import json
import os
import pdb
import torch

from tqdm import tqdm

from judges.pair_judge import PairJudge
from utils import (
    DATA_ROOT,
    load_model_tokenizer,
)

# --- Evaluation configuration -------------------------------------------------
NUM_DATA = 100        # number of prompts evaluated per comparison
NUM_RESPONSES = 10    # number of sampled responses per prompt
BATCH_SIZE = 16       # not referenced in the cells shown here

NUM_EPOCHS = 10       # checkpoints "<model>-epoch-1" .. "<model>-epoch-NUM_EPOCHS" per model

NUM_TOP = 2           # best epochs kept per model (ranked by win rate vs REF_MODEL)

# Preference model id (not referenced in the cells shown here).
PREF_MODEL = "vectorzhou/gemma-2-2b-it-preference_dataset_mixture2_and_safe_pku-Preference"

# Comparison models: full model id -> short tag used in the LaTeX snippet names.
# NOTE(review): "dummy/new_MPO" looks like a placeholder id; confirm it resolves
# to real files under eval_results/.
BASE_NAMES = {
    "vectorzhou/gemma-2-2b-it-alpaca-cleaned-SFT-PKU-SafeRLHF-OnlineIPO1-lora-0227213453": "oipoone",
    "vectorzhou/gemma-2-2b-it-alpaca-cleaned-SFT-PKU-SafeRLHF-OnlineIPO2-lora-0227214805": "oipotwo",
    "vectorzhou/gemma-2-2b-it-alpaca-cleaned-SFT-PKU-SafeRLHF-NashMD-lora-0227215018": "nmd",
    "vectorzhou/gemma-2-2b-it-alpaca-cleaned-SFT-PKU-SafeRLHF-NashMDPG-lora-0301154042": "nmdpg",
    "dummy/new_MPO": "mpo",
}
# Reference model (single checkpoint, no "-epoch-N" suffix).
REF_MODEL = "vectorzhou/gemma-2-2b-it-alpaca-cleaned-SFT"

# The model being showcased and its snippet tag.
MAIN_MODEL = "vectorzhou/gemma-2-2b-it-alpaca-cleaned-SFT-PKU-SafeRLHF-Extragradient-lora-0224142549"
MAIN_SHORT = "eg"

# Snippet tags for every comparison model, the reference model appended last.
ALL_MODEL_DICT = {**BASE_NAMES, REF_MODEL: "pitref"}

def get_pref(compare_results, model1, model2, judge_type, num_data=None, num_responses=None):
    """Return the average ``judge_type`` preference of ``model1`` over ``model2``.

    ``compare_results`` is a nested mapping keyed on both levels by short model
    name (the part after the last ``/``). A pair entry is indexed by prompt
    position, and each element has a ``"results"`` list of per-response records
    from which ``record[judge_type]`` is summed. Only one orientation of a pair
    needs to be stored: if ``model1`` -> ``model2`` is missing, the reversed
    pair is used and the average is flipped to ``1 - avg`` (presumably the
    scores are probabilities in [0, 1] -- TODO confirm).

    Args:
        compare_results: nested comparison results (e.g. loaded from compare.json).
        model1, model2: model ids; only the part after the last ``/`` is used.
        judge_type: key of the score field inside each per-response record.
        num_data: number of leading prompts averaged over (default ``NUM_DATA``).
        num_responses: responses per prompt used in the denominator
            (default ``NUM_RESPONSES``).

    Returns:
        The summed scores divided by ``num_data * num_responses``, oriented as
        the preference for ``model1``.

    Raises:
        KeyError: if neither orientation of the pair is present.
    """
    if num_data is None:
        num_data = NUM_DATA
    if num_responses is None:
        num_responses = NUM_RESPONSES

    name1 = model1.split("/")[-1]
    name2 = model2.split("/")[-1]

    # Only a missing pair should trigger the fallback; a bare `except` would
    # also hide unrelated errors (TypeError, KeyboardInterrupt, ...).
    try:
        pair_results = compare_results[name1][name2]
        flip = False
    except KeyError:
        pair_results = compare_results[name2][name1]
        flip = True

    # Sum per prompt first, then across prompts (same order as before).
    total = sum(
        sum(record[judge_type] for record in pair_results[k]["results"])
        for k in range(num_data)
    )
    avg = total / (num_data * num_responses)
    return 1 - avg if flip else avg

def get_all_prefs(compare_results, model1, model2, judge_type):
    """Return per-prompt, per-response ``judge_type`` scores of ``model1`` vs ``model2``.

    ``compare_results`` is a nested mapping keyed on both levels by short model
    name (the part after the last ``/``); a pair entry is an iterable of
    per-prompt items, each with a ``"results"`` list of per-response records.
    If only the reversed pair (``model2`` -> ``model1``) is stored, it is used
    and every score is flipped to ``1 - score``, so the result is always
    oriented as the preference for ``model1``.

    Returns:
        A list with one inner list per prompt, holding one score per response.
        The input structure is not modified.

    Raises:
        KeyError: if neither orientation of the pair is present.
    """
    name1 = model1.split("/")[-1]
    name2 = model2.split("/")[-1]

    # Only a missing pair should trigger the fallback.
    try:
        pair_results = compare_results[name1][name2]
        flip = False
    except KeyError:
        pair_results = compare_results[name2][name1]
        flip = True

    # The result is a list of per-prompt lists, so the flip has to be applied
    # to each individual score, not to the inner lists.
    if flip:
        return [[1 - record[judge_type] for record in item["results"]] for item in pair_results]
    return [[record[judge_type] for record in item["results"]] for item in pair_results]

In [None]:
# Load the pairwise comparison results, then for every fine-tuned model
# (the comparison models plus MAIN_MODEL) keep the NUM_TOP epochs with the
# highest "pref_model" win rate against REF_MODEL, listed in epoch order.
fn = os.path.join("eval_results", "compare.json")
with open(fn) as f:
    compare_results = json.load(f)

all_models = {}
for base_name in [*BASE_NAMES, MAIN_MODEL]:
    win_rates = {
        epoch: get_pref(compare_results, f"{base_name}-epoch-{epoch + 1}", REF_MODEL, "pref_model")
        for epoch in range(NUM_EPOCHS)
    }
    # `sorted` stays stable with reverse=True: equal win rates keep epoch order.
    ranked = sorted(win_rates, key=win_rates.get, reverse=True)
    all_models[base_name] = sorted(ranked[rank] for rank in range(NUM_TOP))

for base_name, picked_ckpts in all_models.items():
    print(base_name, picked_ckpts)


In [64]:
# For every selected MAIN_MODEL checkpoint, build a NUM_DATA x NUM_RESPONSES
# grid. For each comparison model, the grid records the highest preference the
# main checkpoint reaches on that (prompt, response) over the comparison model's
# selected epochs, plus which epoch gave it. REF_MODEL has a single checkpoint
# recorded with epoch None. Entries start at the -1 sentinel.
model2_list = [REF_MODEL, *BASE_NAMES]

all_prefs = {}
for epoch in all_models[MAIN_MODEL]:
    main_name = f"{MAIN_MODEL}-epoch-{epoch + 1}"
    grid = [
        [{m: {"pref": -1, "epoch": -1} for m in model2_list} for _ in range(NUM_RESPONSES)]
        for _ in range(NUM_DATA)
    ]
    all_prefs[epoch] = grid

    for model2 in model2_list:
        candidate_epochs = [None] if model2 == REF_MODEL else all_models[model2]
        for epoch2 in candidate_epochs:
            model2_name = model2 if epoch2 is None else f"{model2}-epoch-{epoch2 + 1}"
            prefs = get_all_prefs(compare_results, main_name, model2_name, "pref_model")

            for i in range(NUM_DATA):
                for j in range(NUM_RESPONSES):
                    # Strict '>' keeps the first epoch that reached the maximum.
                    if prefs[i][j] > grid[i][j][model2]["pref"]:
                        grid[i][j][model2] = {"pref": prefs[i][j], "epoch": epoch2}


In [65]:
# Load the generated responses of every checkpoint used below: the selected
# epochs of each model in `all_models` plus REF_MODEL, which has no
# "-epoch-N" suffix. Files are named after the part of the id after the last "/".
responses = {}

for model in [*all_models, REF_MODEL]:
    if model == REF_MODEL:
        checkpoint_names = [model]
    else:
        checkpoint_names = [f"{model}-epoch-{epoch+1}" for epoch in all_models[model]]

    for model_name in checkpoint_names:
        short_name = model_name.rsplit("/", 1)[-1]
        response_path = os.path.join("eval_results", "generation", f"{short_name}.json")
        with open(response_path) as f:
            responses[model_name] = json.load(f)


In [66]:
def cut(text, maxline=10, maxchar=500):
    """Shorten ``text`` for display in a LaTeX snippet.

    Keeps only the first ``maxline`` lines. If what remains is still longer
    than ``maxchar`` characters, it is cut to ``maxchar`` characters and a
    newline plus ``...`` is appended. No marker is added when only lines
    are dropped.
    """
    lines = text.split("\n")
    kept = lines[:maxline] if len(lines) > maxline else lines
    shortened = "\n".join(kept)
    if len(shortened) <= maxchar:
        return shortened
    return f"{shortened[:maxchar]}\n..."


In [None]:
# Print LaTeX `filecontents*` snippets for showcase examples. Each round does
# the following for every selected MAIN_MODEL checkpoint:
#   1. Pick the unused (prompt, response) pair with the highest preference,
#      averaged over all comparison models.
#   2. Print the prompt.
#   3. Print each comparison model's response, taken from the epoch that gave
#      the best preference on that pair.
#   4. Print the main model's response.
# NOTE(review): snippet names depend only on `epoch`, so snippets printed in
# different rounds reuse the same names and would collide if pasted together.
NUM_ROUNDS = 5  # showcase prompts picked per selected main-model checkpoint


def _print_filecontents(name, body):
    """Print ``body`` wrapped in a LaTeX ``filecontents*`` environment named ``name``."""
    print(f"\\begin{{filecontents*}}{{{name}}}")
    print(body)
    # Backslash escaped explicitly: "\e" is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+).
    print("\\end{filecontents*}")
    print("")


used_data = []

for _ in range(NUM_ROUNDS):
    for epoch in all_models[MAIN_MODEL]:
        max_pref, max_i, max_j = -1, -1, -1
        for i in range(NUM_DATA):
            if i in used_data:
                continue  # each prompt is showcased at most once
            for j in range(NUM_RESPONSES):
                avg_pref = sum(all_prefs[epoch][i][j][m]["pref"] for m in model2_list) / len(model2_list)
                if avg_pref > max_pref:
                    max_pref, max_i, max_j = avg_pref, i, j

        if max_i == -1:
            # Without this guard, index -1 would silently pick the last prompt.
            raise RuntimeError("No unused prompt left to showcase; reduce NUM_ROUNDS.")

        main_model = f"{MAIN_MODEL}-epoch-{epoch+1}"

        _print_filecontents(f"prompt{epoch}", responses[main_model][max_i]['prompt'])

        for model2 in model2_list:
            epoch2 = all_prefs[epoch][max_i][max_j][model2]["epoch"]
            model_name = f"{model2}-epoch-{epoch2+1}" if epoch2 is not None else model2
            _print_filecontents(
                f"{ALL_MODEL_DICT[model2]}{epoch}",
                cut(responses[model_name][max_i]['responses'][max_j]),
            )

        _print_filecontents(
            f"{MAIN_SHORT}{epoch}",
            cut(responses[main_model][max_i]['responses'][max_j]),
        )

        used_data.append(max_i)
    