In [None]:
import argparse
import json
import logging
import pickle
import wandb

import matplotlib.pyplot as plt
import seaborn as sns
import glob
import pandas as pd
import os
import numpy as np
import copy
from pprint import pprint

In [None]:
import matplotlib

# Global figure defaults for every plot in this notebook. Font type 42 makes
# PDF/PS backends embed TrueType fonts rather than Type 3 fonts.
matplotlib.rcParams.update({
    "figure.dpi": 150,
    "font.size": 10,
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
})

In [None]:
# W&B public-API client, created once per kernel; `load_groups` queries runs through this global.
api = wandb.Api()

In [None]:
def _resample_run(run, x_key, y_key, x_vals, min_final_x):
    """Interpolate one run's ``y_key`` history onto the grid ``x_vals``.

    Returns the interpolated y values as a float array, or ``None`` (after
    printing the reason) when the run is missing either key, has no non-null
    rows, or its largest logged x is below ``min_final_x``.
    """
    history = pd.DataFrame(run.scan_history(keys=[x_key, y_key]))
    if x_key not in history.columns or y_key not in history.columns:
        print("Could not get keys in run", run)
        print(history)
        return None
    # Drop partially-logged rows and sort by x: np.interp does not check that its
    # sample x coordinates are increasing and returns nonsense if they are not.
    history = history.dropna(subset=[x_key, y_key]).sort_values(x_key)
    if history.empty:
        print("No non-null", x_key, "/", y_key, "values in run", run)
        return None
    final_x = history[x_key].max()
    if final_x < min_final_x:
        print("Maximum x value of run", str(run), "is", final_x)
        return None
    return np.interp(
        x_vals,
        history[x_key].to_numpy(dtype=float),
        history[y_key].to_numpy(dtype=float),
    )


def load_groups(
    group_and_keys,
    relabel_dict,
    x_range,
    extra_filter,
    include_configs=None,
    n_points=1000,
    project="resl-mixppo/stabilized-rl",
    wandb_api=None,
):
    """Fetch W&B runs for several groups and resample each onto a shared x grid.

    For each ``(group, x_key, y_key, extra_cond)`` entry, query ``project`` for runs
    in ``group`` that match ``extra_cond`` and ``extra_filter`` and are not tagged
    ``exclude-from-paper``. Each run's ``y_key`` history is linearly interpolated
    onto ``n_points`` evenly spaced x values spanning ``x_range``, so runs can be
    aggregated and plotted together.

    Args:
        group_and_keys: iterable of ``(group, x_key, y_key, extra_cond)`` tuples;
            ``extra_cond`` is a list of W&B filter dicts AND-ed into the query.
        relabel_dict: maps metric keys, ``"group"`` and group names to display
            labels; anything absent keeps its raw name.
        x_range: ``(x_min, x_max)`` of the interpolation grid. Runs whose largest
            logged x is below ``0.99 * x_max`` are skipped as incomplete.
        extra_filter: extra W&B filter dict AND-ed into every query.
        include_configs: run-config keys copied into constant columns
            (``None`` when a run's config lacks the key).
        n_points: number of grid points.
        project: W&B ``entity/project`` path to query.
        wandb_api: client exposing ``runs(path=..., filters=...)``; defaults to
            the notebook-level ``api`` object.

    Returns:
        Long-format DataFrame with one row per (run, grid point): the relabelled
        x and y keys, the relabelled ``"group"`` column, ``"run"``, and one column
        per ``include_configs`` entry.

    Raises:
        ValueError: if no run matched the filters and covered the x range.
    """
    if include_configs is None:
        include_configs = []
    # Only fall back to the global client when none was injected.
    client = api if wandb_api is None else wandb_api
    # Shared grid for every run, so seaborn can aggregate across runs per x value.
    x_vals = np.linspace(x_range[0], x_range[1], n_points)
    # Runs that stop before 99% of the requested range are treated as unfinished.
    min_final_x = 0.99 * x_range[1]
    group_label = relabel_dict.get("group", "group")

    all_interp_data = []
    for group, x_key, y_key, extra_cond in group_and_keys:
        total_filters = {
            "$and": [
                *extra_cond,
                {"group": group},
                {"$not": {"tags": "exclude-from-paper"}},
                extra_filter,
            ]
        }
        pprint(total_filters)
        runs = client.runs(path=project, filters=total_filters)
        print(f"Got {len(runs)} runs for group {group}")
        for run in runs:
            interp_y = _resample_run(run, x_key, y_key, x_vals, min_final_x)
            if interp_y is None:
                continue
            df = pd.DataFrame(
                {
                    relabel_dict.get(x_key, x_key): x_vals,
                    relabel_dict.get(y_key, y_key): interp_y,
                    group_label: relabel_dict.get(group, group),
                    "run": str(run),
                }
            )
            for inc_cfg in include_configs:
                df[relabel_dict.get(inc_cfg, inc_cfg)] = run.config.get(inc_cfg, None)
            all_interp_data.append(df)
    if not all_interp_data:
        raise ValueError(
            "load_groups: no runs matched the filters and covered the requested x range"
        )
    return pd.concat(all_interp_data, ignore_index=True)

In [None]:

# Display labels for W&B group names, metric keys and the "group" column.
# Several groups share a label, so their runs are plotted as one algorithm.
relabels = {
    "fixpo-512-5": "xPPO",
    "baseline_ppo": "PPO-clip",
    "xppo10m-512-5": "xPPO",
    "xppo_single_step": "xPPO",
    "baseline_ppo_10m": "PPO-clip",
    "global_step": "Total Environment Steps",
    "rollout/ep_rew_mean": "Average Episode Reward",
    "group": "Algorithm",
    "fixpo-tianshou-mujoco": "FixPo",
    "ppo-tianshou-mujoco": "PPO-clip",
    "test/reward": "Average Episode Reward",
    "evaluation_reward/mean": "Average Episode Reward",
    "trust_region_layers": "KL Proj.",
    "trust-region-layers-papi": "KL Proj.",
    "trust-region-layers-kl-metaworld-logged": "KL Proj.",
    "test/success_rate": "Success Rate",
    "evaluation_reward/success_rate": "Success Rate",
    "fixpo-tianshou-metaworld": "FixPo",
    "ppo-tianshou-metaworld": "PPO-clip",
}

# Run-config constraints each group's runs must satisfy.
fixpo_configs = {'fixup_loop': 1, 'fixup_every_repeat': 1, 'eps_kl': 0.5, 'target_coeff': 3, "kl_target_stat": 'max', 'init_beta': None}
ppo_configs = {}
trl = {'proj_type': 'kl'}


def config_filters(config):
    """Turn ``{key: value}`` run-config constraints into W&B ``config.<key>`` filter dicts."""
    return [{f"config.{k}": v} for k, v in config.items()]


# (group, x metric, y metric, extra filters); uncomment entries to add them to the plot.
group_and_keys = [
    # ("fixpo-tianshou-metaworld", "global_step", "test/success_rate", config_filters(fixpo_configs)),
    # ("ppo-tianshou-metaworld", "global_step", "test/success_rate", config_filters(ppo_configs)),
    ("trust-region-layers-kl-metaworld-logged", "global_step", "evaluation_reward/success_rate", config_filters(trl)),
]

all_data = load_groups(
    group_and_keys,
    relabels,
    (0, 10_000_000),  # 10M environment steps
    {
        "$and": [
            {"$or": [{"tags": {"$in": ['paper']}}, {"state": "finished"}, {"state": "running"}]},
        ]
    },
)

In [None]:
# Sanity check: number of distinct runs that survived filtering per algorithm,
# then a peek at the long-format data instead of dumping the whole frame.
display(all_data.groupby("Algorithm")["run"].nunique().rename("n_runs"))
all_data.head()

In [None]:
# Mean success rate vs. environment steps per algorithm; seaborn aggregates the
# runs sharing each x value and draws a 95% confidence band across them.
fig, ax = plt.subplots()
sns.lineplot(
    data=all_data,
    x="Total Environment Steps",
    y="Success Rate",
    hue="Algorithm",
    style="Algorithm",
    ci=95,  # NOTE: deprecated in seaborn>=0.12; equivalent there: errorbar=("ci", 95)
    palette="viridis",
    ax=ax,
)
ax.legend(loc="lower right")
fig.tight_layout()
fig.savefig("fixPo_vs_ppo_MT50.pdf")