In [None]:
# Read previously fetched results from the local pickle caches instead of querying the experiment store?
load_aunle = load_sunle = True
# When a load_* flag is False, write the freshly fetched results back to the pickle caches.
save_res = False

In [None]:
from experiments_utils.fetching import ResultsManager
from sbibm import get_results

import pandas as pd
import jax.numpy as jnp
import pickle

In [None]:
# sbibm labels simulation budgets with strings ("10³", ...); this maps them back to integers.
num_simulations_str_map = {"10⁵": 100000, "10⁴": 10000, "10³": 1000}
sbibm_results = get_results()
# The sequential sbibm baselines run 10 rounds; every other algorithm runs a single round.
sbibm_results['num_rounds'] = [
    1 if algo not in ("SNRE", "SNPE", "SNLE") else 10
    for algo in sbibm_results.algorithm.values
]
# Runtime ("RT") of every sbibm run, indexed by its configuration, with integer budgets.
sbr = (
    sbibm_results[["task", "num_simulations", "num_observation", "algorithm", "RT"]]
    .assign(num_simulations=lambda frame: [num_simulations_str_map[label] for label in frame.num_simulations.values])
    .set_index(["task", "num_simulations", "num_observation", "algorithm"])
)

### SUNLE

In [None]:
# S-UNLE results: either load the pickled caches, or fetch evaluation metrics and per-run timings
# from the "icml" experiment store (optionally refreshing the caches).
# NOTE(review): pickle.load can execute arbitrary code — only load cache files you produced yourself.
if load_sunle:
    with open("df_sunle_cached.pkl", "rb") as f:
        df_sunle_eval = pickle.load(f)
    with open("df_time_sunle_cached.pkl", "rb") as f:
        df_sunle_time = pickle.load(f)
else:
    r = ResultsManager("icml")
    sunle_tasks = ("lotka_volterra", "slcp", "gaussian_linear_uniform", "two_moons")

    # PERFORMANCE: one evaluation table per task, stacked into a single frame.
    rs = []
    for t in sunle_tasks:
        print(t)
        eval_result = r.fetch_evaluation_results(
            experience_name="sunle",
            task=t
        )
        rs.append(eval_result)
    df_sunle_eval = pd.concat(rs)
    if save_res:
        with open("df_sunle_cached.pkl", "wb") as f:
            pickle.dump(df_sunle_eval, f)

    # TIME: training / simulation / inference times of each run, summed over its rounds.
    rs = {}
    for t in sunle_tasks:
        # Budgets given per round: 10 rounds of 100 / 1000 / 10000 simulations.
        for ns in ((100,) * 10, (1000,) * 10, (10000,) * 10):
            # NOTE(review): sbibm numbers observations 1..10; range(1, 10) used to skip #10.
            # Runs that do not exist just fail to fetch and are skipped below.
            for no in range(1, 11):
                for random_seed in (1, 2, 3):
                    print(t, ns, no, random_seed)
                    try:
                        result = r.fetch_one_result(
                            experience_name="sunle",
                            task=t,
                            random_seed=random_seed,
                            num_observation=no,
                            num_samples=ns,
                        )
                    except Exception as e:
                        # Best effort: report the missing run and keep going.
                        print(e)
                        continue

                    # Materialize once; it is iterated three times below.
                    rounds = list(result.result.train_results.single_round_results)
                    # The seed is part of the key so every seed of a configuration is kept; keying on
                    # (t, ns, no) alone let each seed overwrite the previous one.
                    rs[(t, ns, no, random_seed)] = {
                        'training': sum(rnd.train_results.time for rnd in rounds),
                        'simulation': sum(rnd.simulation_time for rnd in rounds),
                        'inference': sum(rnd.inference_time for rnd in rounds),
                    }
    if not rs:
        raise RuntimeError("No S-UNLE timing results could be fetched.")
    df = pd.DataFrame(rs).T
    df.index.names = ["task", "num_simulations", "num_observation", "random_seed"]
    df = df.reset_index()
    # One round -> 'AUNLE', several rounds -> 'SUNLE'.
    df['algorithm'] = ['AUNLE' if len(ns) == 1 else 'SUNLE' for ns in df.num_simulations.values]
    # Collapse the per-round budget tuple into the total number of simulations.
    df['num_simulations'] = [sum(ns) for ns in df.num_simulations.values]
    df = df.set_index([c for c in df.columns.values if c not in ("simulation", "inference", "training")])
    # Convert to minutes (presumably the stored times are in seconds — confirm).
    df[['training', 'inference', 'simulation']] = df[['training', 'inference', 'simulation']] / 60

    df_sunle_time = df
    if save_res:
        with open("df_time_sunle_cached.pkl", "wb") as f:
            pickle.dump(df_sunle_time, f)

### AUNLE

In [None]:
# A-UNLE results: either load the pickled caches, or fetch evaluation metrics and per-run timings
# from the "icml" experiment store (optionally refreshing the caches).
# NOTE(review): pickle.load can execute arbitrary code — only load cache files you produced yourself.
if load_aunle:
    with open("df_aunle_cached.pkl", "rb") as f:
        df_aunle_eval = pickle.load(f)
    with open("df_time_aunle_cached.pkl", "rb") as f:
        df_aunle_time = pickle.load(f)
else:
    r = ResultsManager("icml")
    aunle_tasks = ("lotka_volterra", "slcp", "gaussian_linear_uniform", "two_moons")

    # PERFORMANCE: one evaluation table per task, stacked into a single frame.
    rs = []
    for t in aunle_tasks:
        print(t)
        eval_result = r.fetch_evaluation_results(
            experience_name="aunle",
            task=t
        )
        rs.append(eval_result)
    df_aunle_eval = pd.concat(rs)
    if save_res:
        with open("df_aunle_cached.pkl", "wb") as f:
            pickle.dump(df_aunle_eval, f)

    # TIME: training / simulation / inference times of each run, summed over its rounds.
    rs = {}
    for t in aunle_tasks:
        # Single-round budgets of 1000 / 10000 / 100000 simulations.
        for ns in ((1000,), (10000,), (100000,)):
            # NOTE(review): sbibm numbers observations 1..10; range(1, 10) used to skip #10.
            # Runs that do not exist just fail to fetch and are skipped below.
            for no in range(1, 11):
                for random_seed in (1, 2, 3):
                    print(t, ns, no, random_seed)
                    try:
                        result = r.fetch_one_result(
                            experience_name="aunle",
                            task=t,
                            random_seed=random_seed,
                            num_observation=no,
                            num_samples=ns,
                        )
                    except Exception as e:
                        # Best effort: report the missing run and keep going.
                        print(e)
                        continue

                    # Materialize once; it is iterated three times below.
                    rounds = list(result.result.train_results.single_round_results)
                    # The seed is part of the key so every seed of a configuration is kept; keying on
                    # (t, ns, no) alone let each seed overwrite the previous one.
                    rs[(t, ns, no, random_seed)] = {
                        'training': sum(rnd.train_results.time for rnd in rounds),
                        'simulation': sum(rnd.simulation_time for rnd in rounds),
                        'inference': sum(rnd.inference_time for rnd in rounds),
                    }
    if not rs:
        raise RuntimeError("No A-UNLE timing results could be fetched.")
    df = pd.DataFrame(rs).T
    df.index.names = ["task", "num_simulations", "num_observation", "random_seed"]
    df = df.reset_index()
    # One round -> 'AUNLE', several rounds -> 'SUNLE'.
    df['algorithm'] = ['AUNLE' if len(ns) == 1 else 'SUNLE' for ns in df.num_simulations.values]
    # Collapse the per-round budget tuple into the total number of simulations.
    df['num_simulations'] = [sum(ns) for ns in df.num_simulations.values]
    df = df.set_index([c for c in df.columns.values if c not in ("simulation", "inference", "training")])
    # Convert to minutes (presumably the stored times are in seconds — confirm).
    df[['training', 'inference', 'simulation']] = df[['training', 'inference', 'simulation']] / 60

    df_aunle_time = df
    if save_res:
        with open("df_time_aunle_cached.pkl", "wb") as f:
            pickle.dump(df_aunle_time, f)

### SMNLE

In [None]:
# Raw SM-NLE results, unpickled from a local file.
# NOTE(review): pickle.load can execute arbitrary code — only load files you trust.
with open('results_sm.pkl', mode='rb') as sm_results_file:
    smnle_results_ssm = pickle.load(sm_results_file)

In [None]:
smnle_results_ssm = dict(smnle_results_ssm)
# Each key becomes the outer index of the concatenated tables, split into these named levels.
_ssm_key_levels = ("algorithm", "task", "num_observation", "lr", "num_samples")

# Entry 0 of every value feeds the evaluation table...
_ssm_eval_frames = {run_key: pd.DataFrame(run_value[0]) for run_key, run_value in smnle_results_ssm.items()}
df_ssm_eval = pd.concat(_ssm_eval_frames, names=_ssm_key_levels)

# ...and entry 1 feeds the timing table (one row per run).
_ssm_time_frames = {run_key: pd.DataFrame(run_value[1], index=[0]) for run_key, run_value in smnle_results_ssm.items()}
df_ssm_time = pd.concat(_ssm_time_frames, names=_ssm_key_levels)

# Move the innermost row level into the columns and add every timing column into one 'RT' value per run.
df_ssm_time = df_ssm_time.unstack(level=-1).sum(axis=1).to_frame('RT')

del _ssm_key_levels, _ssm_eval_frames, _ssm_time_frames

In [None]:
# Reshape the SM-NLE evaluation table into the layout used for the other methods and relabel the
# two SM-NLE variants as "SMNLE(SM)" / "SMNLE(SSM)".
_ssm_stage = df_ssm_eval.unstack(level='algorithm').stack(level=0)
_ssm_stage = _ssm_stage.rename(columns={'SM': 'SMNLE(SM)', 'SSM': 'SMNLE(SSM)'})
_ssm_stage = _ssm_stage.stack().unstack(level=-2)
# Four adjacent swaps, (-2,-1) then (-3,-2), (-4,-3), (-5,-4): the last index level moves four places forward.
for _upper in (-1, -2, -3, -4):
    _ssm_stage = _ssm_stage.swaplevel(_upper - 1, _upper)
df = _ssm_stage.sort_index()
del _ssm_stage, _upper

In [None]:
# Rename the budget level to match the other methods' "num_simulations" index level.
df = df.rename_axis(index={"num_samples": "num_simulations"})
# Relabel the budgets with sbibm-style strings.
# NOTE(review): set_levels relabels positionally — this assumes the level holds exactly
# [1000, 10000, 100000] in that order; confirm before adding new budgets.
df.index = df.index.set_levels(["10³", "10⁴", "10⁵"], level="num_simulations")

In [None]:
# GroupBy over the (task, budget, algorithm) index levels. It is not yet averaged; the next cell
# iterates over it to pick a learning rate per group.
avg_results = df.groupby(level=["task", "num_simulations", "algorithm"])

In [None]:
def _rows_with_best_lr(group):
    """Keep the rows of ``group`` whose learning rate has the lowest mean MMD within that group.

    The ``lr`` index level is retained in the returned rows.
    NOTE(review): the lr is selected by MMD even when the figure below reports C2ST — intended?
    """
    mmd_by_lr = group.mmd.groupby(level="lr").mean()
    chosen_lr = mmd_by_lr.index[mmd_by_lr.argmin()]
    return group.xs(chosen_lr, level="lr", drop_level=False)


best_results = pd.concat(
    [_rows_with_best_lr(group) for _, group in avg_results]
).rename(columns={"c2st": "C2ST", "mmd": "MMD"})

In [None]:
# Per-(task, budget, algorithm) mean and standard deviation of the best-lr SM-NLE metrics.
_best_results_groups = best_results.groupby(["task", "num_simulations", "algorithm"])
best_results_mean = _best_results_groups.mean()
# Must be .std(): calling .mean() here made the "std" table a copy of the means.
best_results_std = _best_results_groups.std()
del _best_results_groups

In [None]:
# Copy of the best-lr SM-NLE results, tagged as single-round runs (num_rounds = 1).
best_smnle_results = best_results.assign(num_rounds=1)

### Plots

In [None]:
def get_metadata(eres):
    """Return the columns of ``eres`` (index levels included) that take more than one value.

    ``eres`` is flattened with ``reset_index()``, and only the columns with at least two distinct
    values are kept. If every such column is one of the expected evaluation fields they are
    returned unchanged; otherwise the ``max_iter`` column is left out of the result.

    Args:
        eres: evaluation-results DataFrame.

    Returns:
        DataFrame restricted to the non-constant columns, in their original order
        (without ``max_iter`` in the fallback case).
    """
    flat = eres.reset_index()
    num_unique_fields = flat.apply(lambda x: len(x.unique()))
    nonunique_fields = list(num_unique_fields[num_unique_fields > 1].index.values)
    expected_fields = {
        'num_samples', 'num_observation', 'mmd', 'c2st',
        'random_seed', 'learning_rate', 'task', 'max_iter',
    }
    if set(nonunique_fields) <= expected_fields:
        return flat[nonunique_fields]
    # list.remove() returns None (and raises if 'max_iter' is absent), so the previous
    # `nonunique_fields = nonunique_fields.remove(...)` always broke this branch.
    kept_fields = [c for c in nonunique_fields if c != 'max_iter']
    return flat[kept_fields]

In [None]:
def format_unique_fielded_df(df, ebm_model_type, task=None):
    """Tag a flattened UNLE evaluation frame so it can be concatenated with the sbibm results.

    Args:
        df: frame with a ``num_samples`` column of per-round budget tuples plus ``c2st``/``mmd``
            metric columns. It is not modified.
        ebm_model_type: ``"likelihood"`` (labelled ``"SUNLE"``) or ``"joint_tilted"``
            (labelled ``"AUNLE"``); any other value fails the assertion.
        task: if given, written into a ``task`` column.

    Returns:
        A copy with ``algorithm`` (and optionally ``task``) columns, ``num_simulations`` as an
        sbibm-style string of the total budget, ``num_rounds`` as the number of rounds,
        ``num_samples`` dropped and the metrics renamed to ``C2ST``/``MMD``.
    """
    if ebm_model_type != "likelihood":
        assert ebm_model_type == "joint_tilted"
    simulation_budget_labels = {100000: "10⁵", 10000: "10⁴", 1000: "10³"}

    formatted = df.copy()
    formatted['algorithm'] = "SUNLE" if ebm_model_type == "likelihood" else "AUNLE"
    if task is not None:
        formatted['task'] = task

    per_round_budgets = formatted.num_samples
    formatted['num_simulations'] = per_round_budgets.apply(
        lambda budgets: simulation_budget_labels[sum(budgets)]
    )
    formatted['num_rounds'] = per_round_budgets.apply(len)

    return (
        formatted
        .drop(columns="num_samples")
        .rename(columns={"c2st": "C2ST", "mmd": "MMD"})
    )

In [None]:
# A-UNLE rows first, then S-UNLE rows, each reduced to its varying fields and tagged.
all_res = pd.concat([
    format_unique_fielded_df(get_metadata(eval_df), model_type)
    for eval_df, model_type in ((df_aunle_eval, "joint_tilted"), (df_sunle_eval, "likelihood"))
])

In [None]:
sbibm_results = get_results()
# The sequential sbibm baselines run 10 rounds; every other algorithm runs a single round.
sbibm_results['num_rounds'] = [
    10 if algo in ("SNRE", "SNPE", "SNLE") else 1
    for algo in sbibm_results['algorithm'].to_numpy()
]
# Only the columns shared with the UNLE / SM-NLE result tables.
_shared_columns = ['task', 'num_simulations', "num_rounds", 'algorithm', 'num_observation', 'MMD', 'C2ST']
sbibm_results_mini = sbibm_results[_shared_columns]
del _shared_columns

In [None]:
# Stack the sbibm baselines on top of the UNLE results (row-wise).
all_results = pd.concat((sbibm_results_mini, all_res))

In [None]:
# Append the best-lr SM-NLE rows, flattened, without their learning-rate column.
_smnle_rows = best_smnle_results.reset_index().drop(columns='lr')
all_results = pd.concat([all_results, _smnle_rows], axis=0)
del _smnle_rows

In [None]:
# Sanity check: which algorithm labels made it into the combined table.
all_results["algorithm"].unique()

In [None]:
# Index on every descriptive column so only the MMD / C2ST metrics remain as data columns.
ar = all_results.set_index(all_results.columns.drop(["MMD", "C2ST"], errors="ignore").tolist())

In [None]:
import seaborn as sns
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable

In [None]:
pretty_names = {
    "two_moons": "Two Moons",
    "slcp": "SLCP",
    "gaussian_linear_uniform": "Gaussian Linear Uniform",
    "lotka_volterra": "Lotka Volterra"
}
# Metric shown on every performance panel ("C2ST" or "MMD"); also used as the y-axis label.
metric = "C2ST"
limits_metric = {
    "MMD": (0, 1),
    "C2ST": (0.5, 1.1)
}

errorbar_kws = dict(
    linewidth=4,
    elinewidth=3,
    markersize=12,
    capsize=7,
    marker="o"
)


fontsize=26
tk_fontsize=20
time_fontsize=20
# Top of the runtime strip next to each panel (minutes); larger runtimes fall outside the strip.
total_max_rt = 50


def _metric_by_budget(task, algorithm):
    """Group ``ar``'s ``metric`` column for one (task, algorithm) pair by simulation budget."""
    return (
        ar.xs(task, level="task")
        .xs(algorithm, level="algorithm")
        .groupby("num_simulations")[metric]
    )


def _plot_metric(ax, task, algorithm, label, color):
    """Draw the per-budget mean of ``metric`` for ``algorithm`` on ``ax`` with ±1 std error bars."""
    grouped = _metric_by_budget(task, algorithm)
    means = grouped.mean()
    ax.errorbar(means.index, means, grouped.std(), label=label, color=color, **errorbar_kws)


def _style_metric_axes(ax):
    """Apply the shared panel style: horizontal grid lines and thick bottom/left spines."""
    ax.grid(axis="y")
    for side in ("bottom", "left"):
        ax.spines[side].set_linewidth(4)


def _add_runtime_strip(ax, runtimes, colors):
    """Append a narrow axis to the right of ``ax`` marking each runtime as a colored line.

    ``runtimes`` and ``colors`` are parallel sequences; ``colors[i]`` should match the color of the
    same algorithm's curve in ``ax``. The strip spans 0 .. ``total_max_rt`` + 5 on its y-axis.
    """
    with sns.axes_style("whitegrid"):
        ax_rt = make_axes_locatable(ax).append_axes('right', size='10%', pad=0.3)
        ax_rt.set_xlim(0, 2)
        ax_rt.set_ylim(0, total_max_rt + 5)

        ticks = np.arange(0, total_max_rt + 5, 5)
        ax_rt.set_yticks(ticks)
        # Only the two ends of the strip get a label.
        ax_rt.set_yticklabels(["0"] + [""] * len(ticks[1:-1]) + [str(int(ticks[-1]))], fontsize=time_fontsize)
        ax_rt.set_xticks([])

        ax_rt.eventplot(
            positions=np.array(runtimes).reshape(-1, 1),
            linewidths=8,
            lineoffsets=[1] * len(runtimes),
            colors=list(colors),
            linelengths=2,
            orientation='vertical',
        )

        ax_rt.set_ylabel("time [m] ", fontsize=time_fontsize, labelpad=-20)
        ax_rt.yaxis.tick_right()
        ax_rt.yaxis.set_label_position("right")
        ax_rt.yaxis.set_tick_params(length=0, labelbottom=False)


with mpl.rc_context(fname='.matplotlibrc'):
    mpl.rc('font', family='DejaVu Sans')
    mpl.rc("text", usetex=False)

    tasks = ("two_moons", "slcp", "lotka_volterra", "gaussian_linear_uniform")
    nrows = 2
    f, axss = plt.subplots(ncols=len(tasks), nrows=nrows, figsize=(6 * len(tasks), 3.5 * nrows))
    # Set once for the whole figure (it used to be re-applied on every loop iteration).
    f.subplots_adjust(wspace=0.5, hspace=0.8)

    # Row 0: A-UNLE vs. SMNLE vs. NLE.
    axs = axss[0]
    for t_no, task in enumerate(tasks):
        ax = axs[t_no]
        _style_metric_axes(ax)
        # Plot order fixes the legend order; the first entry is bolded below.
        _plot_metric(ax, task, "AUNLE", "A-UNLE \n(Ours)", "firebrick")
        _plot_metric(ax, task, "SMNLE(SSM)", "SMNLE", "goldenrod")
        _plot_metric(ax, task, "NLE", "NLE", "royalblue")
        ax.set_ylim(*limits_metric[metric])
        ax.set_title(pretty_names[task], fontsize=fontsize, pad=20)
        ax.tick_params(axis='both', labelsize=tk_fontsize)
        ax.get_xaxis().set_ticks([])

        nle_rt = sbr['RT'].xs(task, level="task").xs("NLE", level="algorithm").mean()
        aunle_rt = df_aunle_time.sum(axis=1).xs(task, level="task").xs("AUNLE", level="algorithm").mean()
        smnle_rt = df_ssm_time['RT'].xs(task, level="task").xs("SSM", level="algorithm").mean()
        # The firebrick marker must be A-UNLE's runtime to match this row's firebrick curve
        # (S-UNLE's runtime was plotted here by mistake).
        _add_runtime_strip(ax, [nle_rt, aunle_rt, smnle_rt], ["royalblue", "firebrick", "goldenrod"])

    l = axs[0].legend(fontsize=fontsize, bbox_to_anchor=(-0.2, 0.5), bbox_transform=axs[0].transAxes, loc="center right")
    axs[0].set_ylabel(metric, size=20)
    l.texts[0].set_weight("bold")

    # Row 1: S-UNLE vs. SNLE.
    axs = axss[1]
    for t_no, task in enumerate(tasks):
        ax = axs[t_no]
        _style_metric_axes(ax)
        _plot_metric(ax, task, "SUNLE", "S-UNLE \n(Ours)", "firebrick")
        _plot_metric(ax, task, "SNLE", "SNLE", "royalblue")
        ax.set_ylim(*limits_metric[metric])
        ax.tick_params(axis='both', labelsize=tk_fontsize)
        # Budgets are string categories ("10³", ...), so matplotlib already labels these ticks;
        # the former manual set_ticklabels calls (done twice) were redundant.
        ax.set_xlabel("Num. Simulations", size=20)

        snle_rt = sbr['RT'].xs(task, level="task").xs("SNLE", level="algorithm").mean()
        sunle_rt = df_sunle_time.sum(axis=1).xs(task, level="task").xs("SUNLE", level="algorithm").mean()
        _add_runtime_strip(ax, [snle_rt, sunle_rt], ["royalblue", "firebrick"])

    # Attach the legend to the row's first panel, as in row 0.
    l = axs[0].legend(fontsize=fontsize, bbox_to_anchor=(-0.2, 0.5), bbox_transform=axs[0].transAxes, loc="center right")
    axs[0].set_ylabel(metric, size=20)
    l.texts[0].set_weight("bold")

f.savefig("figure-2.pdf", dpi=300, bbox_inches='tight')