In [147]:
import sys

In [148]:
import glob
import json
import os

import pandas as pd
import wandb

sys.path.append("../.")
import direct.openai_ranking

In [149]:
# W&B project and run tag whose evaluation files are pulled by get_eval_data.
project = "wm-debug-imdb"
tag = "xmas-sweep2"
# Dataset name handed to the oracle and embedded in the output CSV file names.
dataset = "imdb"

# Rows drawn per (acquire_pairs_function, seed) group before querying the oracle.
n = 1024
# NOTE(review): the driver cells loop over their own `m` values and rebind this
# name, so this default looks unused — confirm before relying on it.
m = 256

In [150]:
def get_eval_data(project, tag, filename, download_root="/tmp"):
    """Collect one evaluation JSON file from every W&B run carrying ``tag``.

    Runs are selected with the W&B ``tags`` filter and re-checked client-side.
    For each matching run, the file named ``filename`` is downloaded into
    ``download_root``, parsed as JSON and turned into a DataFrame. The frame is
    annotated with the run's ``acquire_pairs_function`` (read from
    ``config["exp5"]["value"]``), its integer ``seed`` and its ``run_name``.

    Args:
        project: W&B project path passed to ``wandb.Api().runs``.
        tag: Run tag to select.
        filename: Exact name of the run file to load.
        download_root: Directory the file is downloaded into (default ``"/tmp"``).
            Each run's copy overwrites the previous one. That is harmless
            because each copy is parsed immediately after download.

    Returns:
        A single DataFrame with the rows of all matching runs, indexed 0..N-1.

    Raises:
        ValueError: If no matching run has a file called ``filename``.
    """
    dfs = []
    for r in wandb.Api().runs(project, filters=dict(tags=tag)):
        if tag not in r.tags:
            continue
        cfg = json.loads(r.json_config)
        for f in r.files():
            if f.name != filename:
                continue
            print("loading", f.name, "from", r.name)
            downloaded = f.download(download_root, replace=True)
            # NOTE(review): recent wandb versions return an open handle to the
            # downloaded file. Close it so it is not leaked.
            if hasattr(downloaded, "close"):
                downloaded.close()
            path = os.path.join(download_root, f.name)
            with open(path) as fh:
                eval_data = json.load(fh)

            df = pd.DataFrame(eval_data)
            df["acquire_pairs_function"] = cfg["exp5"]["value"]["acquire_pairs_function"]
            df["seed"] = int(cfg["seed"]["value"])
            df["run_name"] = r.name
            dfs.append(df)

    if not dfs:
        raise ValueError(
            f"no run in {project!r} tagged {tag!r} has a file named {filename!r}"
        )
    # Every per-run frame starts at index 0; re-index so labels are unique.
    return pd.concat(dfs, ignore_index=True)

In [151]:
def submit_to_oracle(m, sample_size=None, random_state=None):
    """Ask the GPT-4 oracle to judge a per-seed sample of post-training evaluation pairs.

    The function works in four steps:

    1. Load ``evaluation_m{m}_post_training_T0.25.json`` for the module-level
       ``project``/``tag`` via ``get_eval_data``.
    2. Keep only rows whose ``acquire_pairs_function`` is
       ``"HIGH_ENTROPY_AND_CERTAINTY"``.
    3. Draw ``sample_size`` rows per (acquire_pairs_function, seed) group.
    4. Send the ``prompts``/``completions``/``vs_completions`` triples to
       ``direct.openai_ranking.get_preference_batch`` (model ``gpt-4-1106-preview``).

    Args:
        m: Value interpolated into the evaluation file name.
        sample_size: Rows sampled per group. Defaults to the module-level ``n``.
        random_state: Seed or RNG forwarded to pandas sampling so the drawn
            pairs can be reproduced. ``None`` (default) samples unseeded, as before.

    Returns:
        The sampled rows plus a boolean ``win`` column. ``win`` is True where
        the oracle's ``preferred`` value is 0. That presumably means the oracle
        preferred ``completion_a`` (the ``completions`` column); confirm this
        against ``get_preference_batch``.

    Raises:
        ValueError: In any of these cases:
            - no row uses the selected acquisition function;
            - a group has fewer than ``sample_size`` rows (pandas sampling
              without replacement);
            - the oracle returns a different number of judgements than pairs
              submitted.
    """
    if sample_size is None:
        sample_size = n  # module-level config value
    eval_df = get_eval_data(project, tag, f"evaluation_m{m}_post_training_T0.25.json")
    eval_df = eval_df[eval_df.acquire_pairs_function == "HIGH_ENTROPY_AND_CERTAINTY"]
    if eval_df.empty:
        raise ValueError(f"no HIGH_ENTROPY_AND_CERTAINTY rows found for m={m}")

    # GroupBy.sample draws per group like groupby().apply(lambda x: x.sample(...)).
    # Unlike apply, it does not run a function over the grouping columns
    # (deprecated in pandas >= 2.2).
    sample_df = eval_df.groupby(["acquire_pairs_function", "seed"]).sample(
        n=sample_size, random_state=random_state
    )
    batch = (
        sample_df[["prompts", "completions", "vs_completions"]]
        .rename(columns={"completions": "completion_a", "vs_completions": "completion_b", "prompts": "prompt"})
        .to_dict("records")
    )

    oracle_response = direct.openai_ranking.get_preference_batch(
        batch, "gpt-4-1106-preview", None, 10, dataset, provider="openai"
    )
    cost = sum(r["cost"] for r in oracle_response)
    print(f"That cost ~ {cost} USD")
    # Judgements are matched to rows by position, so the counts must agree.
    if len(oracle_response) != len(batch):
        raise ValueError(
            f"oracle returned {len(oracle_response)} judgements for {len(batch)} pairs"
        )
    return sample_df.assign(win=[r["preferred"] == 0 for r in oracle_response])


In [1]:
# for m in [128, 256, 512, 768]:
#     s = submit_to_oracle(m)
#     s.to_csv(f"../results/post-eval-winrate-{dataset}-m{m}-{tag}.csv", index=False)

In [None]:
# Query the oracle once per `m` value and save each result table as CSV.
# Every submit_to_oracle call is billed by the OpenAI API, so a result that is
# already on disk is reloaded instead of re-queried. Delete the CSV to force a
# fresh query.
results_dir = "../results"
os.makedirs(results_dir, exist_ok=True)  # to_csv does not create missing directories
for m in [128, 256, 512]:
    out_path = os.path.join(results_dir, f"post-eval-winrate-{dataset}-m{m}-{tag}.csv")
    if os.path.exists(out_path):
        print("skipping m =", m, "- already saved at", out_path)
        s = pd.read_csv(out_path)  # keep `s` bound to this m's results, as before
        continue
    s = submit_to_oracle(m)
    s.to_csv(out_path, index=False)