In [None]:
import argparse
import json
import logging
import pickle
import wandb

import matplotlib.pyplot as plt
import seaborn as sns
import glob
import pandas as pd
import os
import numpy as np
import copy

In [None]:
import matplotlib

# Publication styling for every figure in this notebook: high-DPI previews,
# large fonts, and fonts embedded as Type 42 (TrueType) in PDF/PS output.
matplotlib.rcParams.update(
    {
        "figure.dpi": 150,
        "font.size": 20,
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
    }
)

In [None]:
# Public W&B API client shared by the notebook; load_groups() queries runs through it.
api = wandb.Api()

In [None]:
def load_groups(group_and_keys, relabel_dict, x_range, extra_filter, api_client=None):
    """Download W&B runs for several groups and resample each onto a shared x grid.

    For every ``(group, x_key, y_key)`` triple, all runs of the
    ``resl-mixppo/stabilized-rl`` project in ``group`` (not tagged
    ``exclude-from-paper`` and matching ``extra_filter``) are fetched.
    Each run's ``(x_key, y_key)`` history is linearly interpolated onto
    1000 evenly spaced points spanning ``x_range``. Runs therefore share x
    values and can be aggregated point-wise, e.g. for confidence bands.

    Args:
        group_and_keys: iterable of ``(group, x_key, y_key)`` tuples.
        relabel_dict: maps group names, metric keys and the literal key
            ``"group"`` to display names; unmapped names are used verbatim.
        x_range: ``(start, stop)`` of the interpolation grid.
        extra_filter: additional filter dict combined with ``$and``;
            ``None`` adds no extra constraint.
        api_client: object exposing ``runs(path=..., filters=...)``;
            defaults to the module-level ``api``.

    Returns:
        Long-format DataFrame with the (relabelled) x, y and group columns
        plus ``"run"``, 1000 rows per successfully loaded run. When no run
        could be loaded, an empty DataFrame with those columns.
    """
    client = api if api_client is None else api_client
    # Materialize once: the triples are iterated again to name columns if nothing loads.
    group_and_keys = list(group_and_keys)
    x_vals = np.linspace(x_range[0], x_range[1], 1000)
    group_column = relabel_dict.get("group", "group")
    all_interp_data = []
    for group, x_key, y_key in group_and_keys:
        conditions = [
            {"group": group},
            {"$not": {"tags": "exclude-from-paper"}},
        ]
        if extra_filter is not None:
            conditions.append(extra_filter)
        runs = client.runs(
            path="resl-mixppo/stabilized-rl",
            filters={"$and": conditions},
        )
        print(f"Got {len(runs)} runs for group {group}")
        x_label = relabel_dict.get(x_key, x_key)
        y_label = relabel_dict.get(y_key, y_key)
        group_label = relabel_dict.get(group, group)
        for r in runs:
            h = pd.DataFrame(r.scan_history(keys=[x_key, y_key]))
            if x_key not in h or y_key not in h:
                print("Could not get keys in run", r)
                print(h)
                continue
            # np.interp does not validate its inputs and returns nonsense unless
            # the sample points are increasing and finite: drop incomplete rows
            # and sort by x (stable, so duplicate x values keep logging order).
            h = h[[x_key, y_key]].dropna().sort_values(x_key, kind="stable")
            if h.empty:
                print("No valid samples in run", r)
                continue
            x_max = h[x_key].max()
            if x_max < 0.99 * x_range[1]:
                # Beyond x_max np.interp repeats the last y value, so this run's
                # tail is padded rather than measured.
                print("Maximum x value of run is", x_max)
            interp_y = np.interp(
                x_vals,
                h[x_key].to_numpy(dtype=float),
                h[y_key].to_numpy(dtype=float),
            )
            all_interp_data.append(
                pd.DataFrame(
                    {
                        x_label: x_vals,
                        y_label: interp_y,
                        group_column: group_label,
                        "run": str(r),
                    }
                )
            )
    if not all_interp_data:
        # pd.concat([]) raises; return an empty frame with the expected columns instead.
        value_columns = dict.fromkeys(
            relabel_dict.get(key, key)
            for _, x_key, y_key in group_and_keys
            for key in (x_key, y_key)
        )
        return pd.DataFrame(columns=[*value_columns, group_column, "run"])
    return pd.concat(all_interp_data, ignore_index=True)

In [None]:
env = "HalfCheetah-v2"

# Single-step ablation groups, each tracked as mean episode reward against
# environment steps. The alternative "-512-5" groups (xppo, no-reset,
# one-phase, mean-kl-target, no-historic, second-loop-vf) can be swapped in
# here; their display names are already present in `relabels`.
group_and_keys = [
    (group, "global_step", "rollout/ep_rew_mean")
    for group in (
        "xppo_single_step",
        "no-reset_single_step",
        "one-phase_single_step",
        "mean-kl-target_single_step",
    )
]

# Raw W&B names -> labels shown in legends and on axes.
relabels = {
    # single-step ablation groups
    "xppo_single_step": "xPPO",
    "no-reset_single_step": "no-reset",
    "one-phase_single_step": "one-phase",
    "mean-kl-target_single_step": "mean-kl-target",
    # "-512-5" ablation groups
    "xppo-512-5": "xPPO",
    "no-reset-512-5": "no-reset",
    "one-phase-512-5": "one-phase",
    "mean-kl-target-512-5": "mean-kl-target",
    "no-historic-512-5": "no-historic",
    "second-loop-vf-512-5": "second-loop-vf",
    # baseline
    "baseline_ppo": "PPO-clip",
    # metric keys and the group column
    "global_step": "Total Environment Steps",
    "rollout/ep_rew_mean": "Average Episode Reward",
    "group": "Algorithm",
}

all_data = load_groups(
    group_and_keys,
    relabel_dict=relabels,
    x_range=(0, 3e6),
    extra_filter={
        "$and": [
            {"config.env": env},
            # Keep runs tagged "paper" plus any finished or still-running run.
            {"$or": [{"tags": {"$in": ["paper"]}}, {"state": "finished"}, {"state": "running"}]},
        ]
    },
)

In [None]:
# Draw on a fresh figure. With non-inline backends an earlier figure can stay
# "current", and this plot (and its PDF) would be layered onto it.
fig, ax = plt.subplots()
sns.lineplot(
    data=all_data,
    x="Total Environment Steps",
    y="Average Episode Reward",
    hue="Algorithm",
    ci=95,  # NOTE(review): seaborn >= 0.12 deprecates `ci`; there use errorbar=("ci", 95).
    style="Algorithm",
    ax=ax,
)
ax.legend(loc="lower right", fontsize=12)
fig.tight_layout()
fig.savefig(f"xppo_ablations_{env}.pdf")

In [None]:
env = "Hopper-v2"
# Reuses `group_and_keys` and `relabels` from the HalfCheetah cell.
all_data = load_groups(
    group_and_keys,
    relabels,
    (0, 3e6),
    {
        "$and": [
            {"config.env": env},
            {"$or": [{"tags": {"$in": ['paper']}}, {"state": "finished"}, {"state": "running"}]},
        ]
    },
)
# Draw on a fresh figure. With non-inline backends the previous cell's figure
# stays "current", and this plot (and its PDF) would be layered onto it.
fig, ax = plt.subplots()
sns.lineplot(
    data=all_data,
    x="Total Environment Steps",
    y="Average Episode Reward",
    hue="Algorithm",
    ci=95,  # NOTE(review): seaborn >= 0.12 deprecates `ci`; there use errorbar=("ci", 95).
    style="Algorithm",
    ax=ax,
)
ax.legend(loc="lower right", fontsize=12)
fig.tight_layout()
fig.savefig(f"xppo_ablations_{env}.pdf")

In [None]:
env = "Walker2d-v2"
# Reuses `group_and_keys` and `relabels` from the HalfCheetah cell.
all_data = load_groups(
    group_and_keys,
    relabels,
    (0, 3e6),
    {
        "$and": [
            {"config.env": env},
            {"$or": [{"tags": {"$in": ['paper']}}, {"state": "finished"}, {"state": "running"}]},
        ]
    },
)
# Draw on a fresh figure. With non-inline backends the previous cell's figure
# stays "current", and this plot (and its PDF) would be layered onto it.
fig, ax = plt.subplots()
sns.lineplot(
    data=all_data,
    x="Total Environment Steps",
    y="Average Episode Reward",
    hue="Algorithm",
    ci=95,  # NOTE(review): seaborn >= 0.12 deprecates `ci`; there use errorbar=("ci", 95).
    style="Algorithm",
    ax=ax,
)
ax.legend(loc="lower right", fontsize=12)
fig.tight_layout()
fig.savefig(f"xppo_ablations_{env}.pdf")

In [None]:
# Legend placement (`loc`) and `fontsize` options used by the plotting cells are
# documented under matplotlib.pyplot.legend / matplotlib.axes.Axes.legend.

In [None]:
env = "Walker2d-v2"
# Dedicated name so the reward cells above (which read `group_and_keys`) still
# load episode rewards if they are re-run after this cell.
# NOTE: `all_data` is still overwritten here, so re-run a reward-loading cell
# before re-plotting rewards.
kl_group_and_keys = [
    (group, "global_step", "train/final_kl_div")
    for group in (
        "xppo_single_step",
        "no-reset_single_step",
        "one-phase_single_step",
        "mean-kl-target_single_step",
    )
]
all_data = load_groups(
    kl_group_and_keys,
    {
        "xppo_single_step": "xPPO",
        "no-reset_single_step": "no-reset",
        "one-phase_single_step": "one-phase",
        "mean-kl-target_single_step": "mean-kl-target",
        "train/final_kl_div": "Mean KL Divergence",
        "group": "Ablation",
        "global_step": "Total Environment Steps",
    },
    (0, 3e6),
    {
        "$and": [
            {"config.env": env},
            {"$or": [{"tags": {"$in": ['paper']}}, {"state": "finished"}, {"state": "running"}]},
        ]
    },
)

In [None]:
# NOTE: this changes the global font size for every figure created afterwards,
# including figures from earlier cells if they are re-run.
matplotlib.rcParams["font.size"] = 14
# Draw on a fresh figure. With non-inline backends the previous cell's figure
# stays "current", and this plot (and its PDF) would be layered onto it.
fig, ax = plt.subplots()
sns.lineplot(
    data=all_data,
    x="Total Environment Steps",
    y="Mean KL Divergence",
    hue="Ablation",
    ci=95,  # NOTE(review): seaborn >= 0.12 deprecates `ci`; there use errorbar=("ci", 95).
    style="Ablation",
    ax=ax,
)
ax.legend(loc="upper right")
fig.tight_layout()
fig.savefig(f"ablations_kl_divergence_{env}.pdf")