In [None]:
# Evaluation metrics fetched from each run's history, mapped to the LaTeX
# title of their subplot. Insertion order matters: plot_ratio_points draws
# one panel per key in this order (after the TR_KEYS panels).
KEYS = dict([
    ("eval/mean_log_prob/ftsqts", "$f_{\\text{ts}} q_{\\text{ts}}$"),
    ("eval/mean_log_prob/fwikiqwiki", "$f_{\\text{wiki}} q_{\\text{wiki}}$"),
    ("eval/mean_log_prob/ftsqwiki", "$f_{\\text{ts}} q_{\\text{wiki}}$"),
    ("eval/mean_log_prob/fwikiqts", "$f_{\\text{wiki}} q_{\\text{ts}}$"),
])
# Training metrics fetched from each run's history, mapped to subplot titles.
# These panels are drawn first.
TR_KEYS = dict([
    ("train/ts", "$f_{\\text{ts}}$"),
    ("train/wiki", "$f_{\\text{wiki}}$"),
])

import pandas as pd 
import wandb
api = wandb.Api()
runs = api.runs("[REDACTED]")

# One entry per run, in the order the API yields the runs:
#   history_list    - history restricted to the evaluation metrics in KEYS
#   tr_history_list - history restricted to the training metrics in TR_KEYS
#   name_list       - run name (decoded later by parse_args)
history_list = []
tr_history_list = []
name_list = []
for run in runs:
    history_list.append(run.history(keys=list(KEYS)))
    tr_history_list.append(run.history(keys=list(TR_KEYS)))
    name_list.append(run.name)

# One row per run; the history columns hold whole DataFrames (object dtype).
runs_df = pd.DataFrame(dict(
    history=history_list,
    tr_history=tr_history_list,
    name=name_list,
))

In [40]:
def parse_args(run_name):
    """Extract the integer sweep parameters encoded in a run name.

    The name must contain ``occur<in_occur>_<cross_occur>`` and
    ``dataseed<dataseed>``, with a later ``seed<seed>`` somewhere after the
    ``dataseed`` marker, e.g. ``"occur64_1_dataseed0_seed3"``. The example is
    inferred from the parsing logic — TODO confirm against the sweep naming.

    Args:
        run_name: W&B run name.

    Returns:
        dict with int values for ``in_occur``, ``cross_occur``, ``seed`` and
        ``dataseed``.

    Raises:
        ValueError: a marker is missing, or the text after it is not an int.
        IndexError: ``occur`` is not followed by two ``_``-separated fields.
    """
    def after(text, marker):
        # str.index raises ValueError when the marker is absent; str.find would
        # return -1 and silently slice from a wrong offset, misparsing the name.
        return text[text.index(marker) + len(marker):]

    occur_fields = after(run_name, "occur").split("_")
    in_occur = int(occur_fields[0])
    cross_occur = int(occur_fields[1])

    after_dataseed = after(run_name, "dataseed")
    dataseed = int(after_dataseed.split("_")[0])
    # The training seed is the first "seed" marker *after* "dataseed", so the
    # "seed" inside "dataseed" itself is not matched.
    seed = int(after(after_dataseed, "seed").split("_")[0])

    return {
        "in_occur": in_occur,
        "cross_occur": cross_occur,
        "seed": seed,
        "dataseed": dataseed,
    }

In [None]:
import os
import yaml
import json
import matplotlib.pyplot as plt
import numpy as np
import scipy.optimize as opt

def gather_data_for_plots(runs_df, last_n=10):
    """Reduce every run's histories to one value per (metric, seed, ratio).

    For each run whose name ``parse_args`` can decode:
      * ratio ``r = in_occur / cross_occur`` (``np.inf`` when
        ``cross_occur == 0``);
      * for each key in ``KEYS`` (read from the run's ``history``) and in
        ``TR_KEYS`` (read from its ``tr_history``), the mean of the last
        ``last_n`` logged values. Training values are stored negated.
    Runs whose names ``parse_args`` rejects are skipped.

    NOTE(review): runs that share both ``seed`` and ratio (e.g. differing only
    in ``dataseed``) overwrite each other; the last such row in ``runs_df`` wins.

    Args:
        runs_df: DataFrame with columns ``history``, ``tr_history`` (one
            history DataFrame per run) and ``name`` (run name).
        last_n: number of trailing history rows averaged per metric.

    Returns:
        ``{metric_key: {seed: {ratio: value}}}`` with an entry for every key
        of ``KEYS`` and ``TR_KEYS``.
    """
    metric_sources = (
        ("history", list(KEYS.keys())),
        ("tr_history", list(TR_KEYS.keys())),
    )
    plot_data = {k: {} for k in list(KEYS.keys()) + list(TR_KEYS.keys())}

    for i in range(len(runs_df.index)):
        row = runs_df.iloc[i]

        # Parse once per run (the name does not depend on which history is read).
        try:
            args = parse_args(row["name"])
        except (ValueError, IndexError):
            # parse_args raises these for names without the expected markers or
            # integer fields; skip such runs without hiding unrelated errors.
            continue

        if args["cross_occur"] != 0:
            ratio = args["in_occur"] / args["cross_occur"]
        else:
            ratio = np.inf
        seed = args["seed"]

        for df_key, keys in metric_sources:
            run_data = row[df_key]
            for key in keys:
                val = run_data[key].tail(last_n).mean()
                if df_key == "tr_history":
                    # Training metrics are sign-flipped; presumably they are
                    # losses turned into log-probs for the shared y-axis — TODO confirm.
                    val = -val
                plot_data[key].setdefault(seed, {})[ratio] = val

    return plot_data

def _scatter_seed_points(seed_series):
    """Scatter every seed's (ratio, value) points on the current axes.

    Returns ``(by_ratio, inf_values)``: finite-ratio values grouped by ratio
    (first-seen order), and the values recorded at ratio ``np.inf``.
    """
    by_ratio = {}
    inf_values = []
    for seed_data in seed_series.values():
        ratios = list(seed_data.keys())
        plt.scatter(ratios, [seed_data[r] for r in ratios],
                    s=4**2, color="orange", alpha=0.7, edgecolors="none")
        for ratio, value in seed_data.items():
            if ratio == np.inf:
                inf_values.append(value)
            else:
                by_ratio.setdefault(ratio, []).append(value)
    return by_ratio, inf_values


def _sigmoid_in_log(r, a, b, c):
    """Logistic curve in log-ratio space: a * sigmoid(b * (log(r) - c))."""
    return a / (1 + np.exp(-b * (np.log(r) - c)))


def _plot_sigmoid_fit(ratios, means, initial_guess=(-0.2, 1 / 64, np.log(64))):
    """Fit ``_sigmoid_in_log`` to (ratios, means) and draw the fitted curve.

    ``initial_guess`` is ``(a, b, c)`` for ``curve_fit``. The defaults are the
    starting values used by the original analysis (``c = log(64)`` centres the
    sigmoid at r = 64). If the fit is impossible (too few points) or fails
    (non-convergence, non-finite data), a warning is issued and nothing is drawn.
    """
    import warnings

    x_data = np.asarray(ratios, dtype=float).reshape(-1)
    y_data = np.asarray(means, dtype=float).reshape(-1)
    if x_data.size < len(initial_guess):
        warnings.warn(f"sigmoid fit skipped: need at least {len(initial_guess)} "
                      f"ratios, got {x_data.size}")
        return
    try:
        params, _ = opt.curve_fit(_sigmoid_in_log, x_data, y_data, p0=list(initial_guess))
    except (RuntimeError, ValueError) as exc:
        # RuntimeError: optimizer did not converge; ValueError: NaN/inf in the data.
        warnings.warn(f"sigmoid fit failed: {exc}")
        return

    a_fit, b_fit, c_fit = params
    log_x = np.log(x_data)
    x_curve = np.exp(np.linspace(log_x.min(), log_x.max(), 100))
    y_curve = _sigmoid_in_log(x_curve, a_fit, b_fit, c_fit)
    label = rf"$f(r) = {a_fit:.2f} \sigma({b_fit:.2f} (\log(r) - {c_fit:.2f}))$"
    plt.plot(x_curve, y_curve, color="blue", alpha=0.5, label=label)


def _print_latex_rows(table, metric_keys):
    """Print one LaTeX table row per ratio (ascending).

    Each row lists every metric's across-seed mean to 3 significant digits.
    A metric with no value at that ratio prints as ``--`` instead of raising KeyError.
    """
    for ratio in sorted(table):
        print(f"$r = {ratio}$", end=" ")
        for key in metric_keys:
            cell = table[ratio].get(key)
            print("& --" if cell is None else f"& ${cell:.3g}$", end=" ")
        print(r"\\ \hline")


def plot_ratio_points(plot_data, fit_panels=(5, 6), table_max_ratio=4, out_path="ratio.pdf"):
    """Plot each metric against the in/cross occurrence ratio and print a LaTeX table.

    One subplot per metric: ``TR_KEYS`` first, then ``KEYS``, two per row. Each panel shows:
      * every seed's values (orange dots);
      * the across-seed mean at each finite ratio (red crosses);
      * the mean over ``r = inf`` runs as a dashed line, when any exist;
      * for panels listed in ``fit_panels``, a sigmoid-in-log(r) fit of the means.
    The figure is saved to ``out_path``. Means at ratios ``<= table_max_ratio``
    are then printed as LaTeX table rows.

    Args:
        plot_data: ``{metric_key: {seed: {ratio: value}}}``, as returned by
            ``gather_data_for_plots``.
        fit_panels: 1-based panel indices that get the sigmoid fit. With the
            current key order, (5, 6) are the 3rd and 4th entries of ``KEYS``.
        table_max_ratio: largest ratio included in the printed table.
        out_path: where the figure is saved.
    """
    metric_keys = list(TR_KEYS.keys()) + list(KEYS.keys())
    n_cols = 2
    n_rows = max(1, (len(metric_keys) + n_cols - 1) // n_cols)  # 3 rows for 6 metrics
    plt.figure(figsize=(16, 12))

    table = {}  # {ratio: {metric_key: across-seed mean}}
    for panel, key in enumerate(metric_keys, 1):
        plt.subplot(n_rows, n_cols, panel)

        by_ratio, inf_values = _scatter_seed_points(plot_data[key])
        ratios = list(by_ratio.keys())
        means = [sum(by_ratio[r]) / len(by_ratio[r]) for r in ratios]

        for ratio, mean in zip(ratios, means):
            if ratio <= table_max_ratio:
                table.setdefault(ratio, {})[key] = mean

        plt.scatter(ratios, means, color="red", marker="x", linewidth=2)

        # r = inf runs cannot be placed on the log x-axis; show their mean as a
        # reference line. Guarded: without such runs the mean is undefined.
        if inf_values:
            plt.axhline(y=sum(inf_values) / len(inf_values), linestyle="dashed",
                        color="black", alpha=0.7, label=r"$r = \infty$")

        if panel in fit_panels:
            _plot_sigmoid_fit(ratios, means)

        plt.xscale("log")
        plt.xlabel("Ratio $r$", fontsize=14)
        plt.ylabel("Normalized Log Prob", fontsize=14)
        plt.title(KEYS[key] if key in KEYS else TR_KEYS[key], fontsize=16)
        plt.grid(True)
        plt.tick_params(axis="both", which="major", labelsize=14)

        # Show a legend whenever something labelled was drawn (inf line and/or fit).
        handles, _ = plt.gca().get_legend_handles_labels()
        if handles:
            plt.legend(fontsize=12)

    plt.tight_layout()
    plt.savefig(out_path)

    _print_latex_rows(table, metric_keys)

# Reduce the fetched histories to {metric: {seed: {ratio: value}}}, then draw
# the ratio figure (saved to disk) and print the LaTeX table rows.
plot_data = gather_data_for_plots(runs_df=runs_df)
plot_ratio_points(plot_data=plot_data)
