In [1]:
import argparse
import json
import logging
import pickle
import wandb

import matplotlib.pyplot as plt
import seaborn as sns
import glob
import pandas as pd
import os
import numpy as np
import copy
from pprint import pprint

In [7]:
import matplotlib
matplotlib.rcParams.update(
    {
        "figure.dpi": 150,
        "font.size": 14,
    }
)
matplotlib.rcParams["pdf.fonttype"] = 42
matplotlib.rcParams["ps.fonttype"] = 42

In [3]:
api = wandb.Api()

In [4]:
def load_groups(group_and_keys, relabel_dict, x_range, extra_filter):
    all_interp_data = []
    for group, x_key, y_key in group_and_keys:
        total_filters = {
            "$and": [
                {"group": group},
                {"$not": {"tags": "exclude-from-paper"}},
                extra_filter,
            ]
        }
        pprint(total_filters)
        runs = api.runs(
            path="resl-mixppo/stabilized-rl",
            filters=total_filters,
        )
        print(f"Got {len(runs)} runs for group {group}")
        x_vals = np.linspace(x_range[0], x_range[1], 1000)
        for r in runs:
            # h = r.history(samples=2000, keys=[x_key, y_key])
            h = pd.DataFrame(r.scan_history(keys=[x_key, y_key]))
            try:
                if np.max(h[x_key]) < 0.99 * x_range[1]:
                    print("Maximum x value of run", str(r), "is", np.max(h[x_key]))
                    continue
                interp_y = np.interp(x_vals, h[x_key], h[y_key])
            except KeyError:
                print("Could not get keys in run", r)
                print(h)
            else:
                all_interp_data.append(
                    pd.DataFrame.from_dict(
                        {
                            relabel_dict.get(x_key, x_key): x_vals,
                            relabel_dict.get(y_key, y_key): interp_y,
                            relabel_dict.get("group", "group"): relabel_dict.get(
                                group, group
                            ),
                            "run": str(r),
                        }
                    )
                )
    return pd.concat(all_interp_data, ignore_index=True)

In [5]:
env = "HalfCheetah-v2"
group_and_keys = [
    ("xppo_single_step", "global_step", "rollout/ep_rew_mean"),
    ("baseline_ppo", "global_step", "rollout/ep_rew_mean"),
]
relabels = {
    "xppo-512-5": "xPPO",
    "baseline_ppo": "PPO-clip",
    "xppo10m-512-5": "xPPO",
    "xppo_single_step": "xPPO",
    "baseline_ppo_10m": "PPO-clip",
    "global_step": "Total Environment Steps",
    "rollout/ep_rew_mean": "Average Episode Reward",
    "group": "Algorithm",
}
all_data = load_groups(
    group_and_keys,
    relabels,
    (0, 3e6),
    {
        "$and": [
            {"config.env": env},
            {"$or": [{"tags": {"$in": ['paper']}}, {"state": "finished"}, {"state": "running"}]},
        ]
    },
)

{'$and': [{'group': 'xppo_single_step'},
          {'$not': {'tags': 'exclude-from-paper'}},
          {'$and': [{'config.env': 'HalfCheetah-v2'},
                    {'$or': [{'tags': {'$in': ['paper']}},
                             {'state': 'finished'},
                             {'state': 'running'}]}]}]}
Got 19 runs for group xppo_single_step
Maximum x value of run <Run resl-mixppo/stabilized-rl/3skaxz59 (crashed)> is 1937408.0
Maximum x value of run <Run resl-mixppo/stabilized-rl/upbbwt40 (crashed)> is 1945600.0
Maximum x value of run <Run resl-mixppo/stabilized-rl/ofnwff0r (crashed)> is 1970176.0
{'$and': [{'group': 'baseline_ppo'},
          {'$not': {'tags': 'exclude-from-paper'}},
          {'$and': [{'config.env': 'HalfCheetah-v2'},
                    {'$or': [{'tags': {'$in': ['paper']}},
                             {'state': 'finished'},
                             {'state': 'running'}]}]}]}
Got 8 runs for group baseline_ppo


In [None]:
sns.lineplot(
    data=all_data,
    x="Total Environment Steps",
    y="Average Episode Reward",
    hue="Algorithm",
    ci=95,
    style="Algorithm",
    palette="viridis",
)
plt.legend(loc="lower right")
plt.tight_layout()
plt.savefig(f"xppo_vs_ppo_{env}.pdf")

In [None]:
env = "Hopper-v2"
group_and_keys = [
    ("xppo_single_step", "global_step", "rollout/ep_rew_mean"),
    ("baseline_ppo_10m", "global_step", "rollout/ep_rew_mean"),
]
all_data = load_groups(
    group_and_keys,
    relabels,
    (0, 10e6),
    {
        "$and": [
            {"config.env": env},
            {"$or": [{"tags": {"$in": ['paper']}}, {"state": "finished"}, {"state": "running"}]},
        ]
    },
)
sns.lineplot(
    data=all_data,
    x="Total Environment Steps",
    y="Average Episode Reward",
    hue="Algorithm",
    ci=95,
    style="Algorithm",
    palette="viridis",
)
plt.legend(loc="lower right")
plt.tight_layout()
plt.savefig(f"xppo_vs_ppo_{env}.pdf")

In [None]:
env = "Walker2d-v2"
group_and_keys = [
    ("xppo_single_step", "global_step", "rollout/ep_rew_mean"),
    ("baseline_ppo_10m", "global_step", "rollout/ep_rew_mean"),
]
all_data = load_groups(
    group_and_keys,
    relabels,
    (0, 10e6),
    {
        "$and": [
            {"config.env": env},
            {"$or": [{"tags": {"$in": ['paper']}}, {"state": "finished"}, {"state": "running"}]},
        ]
    },
)
sns.lineplot(
    data=all_data,
    x="Total Environment Steps",
    y="Average Episode Reward",
    hue="Algorithm",
    ci=95,
    style="Algorithm",
    palette="viridis",
)
plt.legend(loc="lower right")
plt.tight_layout()
plt.savefig(f"xppo_vs_ppo_{env}.pdf")

In [None]:
all_data[all_data["Total Environment Steps"] == 0]

In [None]:
list(
    api.runs(
        path="resl-mixppo/stabilized-rl",
        filters={
            "$and": [
                {"group": "baseline_ppo_10m"},
                {"config.env": "Walker2d-v2"},
                {"tags": "paper"},
            ]
        },
    )
)

In [None]:
list(
    api.runs(
        path="resl-mixppo/stabilized-rl",
        filters={
            "$and": [
                {"group": "baseline_ppo_10m"},
                {"$not": {"tags": "exclude-from-paper"}},
                {
                    "$and": [
                        {"config.env": "Walker2d-v2"},
                        {
                            "$or": [
                                {"tags": "paper"},
                                {"state": "finished"},
                                {"state": "running"},
                            ]
                        },
                    ]
                },
            ]
        },
    )
)

In [None]:
env = "Walker2d-v2"
group_and_keys = [
    ("xppo_single_step", "global_step", "train/mean_second_penalty_loops"),
    ("xppo_single_step", "global_step", "train/max_second_penalty_loops"),
]
all_data = load_groups(
    group_and_keys,
    relabels,
    (0, 10e6),
    {
        "$and": [
            {"config.env": env},
            {"$or": [{"tags": {"$in": ['paper']}}, {"state": "finished"}, {"state": "running"}]},
        ]
    },
)

In [None]:
data_fixup = pd.DataFrame.from_dict({
    "Total Environment Steps": all_data["Total Environment Steps"],
    "Second Phase Loop Iterations": np.where(np.isnan(all_data["train/mean_second_penalty_loops"]),
                                             all_data["train/max_second_penalty_loops"] - 1,
                                             all_data["train/mean_second_penalty_loops"] - 1), # We log each attempt to loop, including the last one
#     "Second Phase Loop Iterations": all_data["train/max_second_penalty_loops"] - 1,
    "Iterations": np.where(np.isnan(all_data["train/mean_second_penalty_loops"]), "Max", "Mean"),
#     "Mean Iterations": ~np.isnan(all_data["train/mean_second_penalty_loops"]),
#     "Max Iterations": ~np.isnan(all_data["train/max_second_penalty_loops"]),
})

In [None]:
matplotlib.rcParams["font.size"] = 14
sns.lineplot(
    data=data_fixup,
    x="Total Environment Steps",
    y="Second Phase Loop Iterations",
    hue="Iterations",
    ci=95,
    style="Iterations",
    palette="viridis",
)
plt.legend(loc="upper right")
plt.tight_layout()
plt.savefig(f"mean_and_max_second_loops_{env}.pdf")

In [10]:
env = "Walker2d-v2"
group_and_keys = [
    ("xppo_single_step", "global_step", "train/std"),
    ("baseline_ppo_10m", "global_step", "train/std"),
]
relabels = {
    "xppo-512-5": "FixPO",
    "baseline_ppo": "PPO-clip",
    "xppo10m-512-5": "FixPO",
    "xppo_single_step": "FixPO",
    "baseline_ppo_10m": "PPO-clip",
    "global_step": "Total Environment Steps",
    "train/std": "Action Distribution Standard Deviation",
    "group": "Algorithm",
}
all_data = load_groups(
    group_and_keys,
    relabels,
    (0, 10e6),
    {
        "$and": [
            {"config.env": env},
            {"$or": [{"tags": {"$in": ['paper']}}, {"state": "finished"}, {"state": "running"}]},
        ]
    },
)
sns.lineplot(
    data=all_data,
    x="Total Environment Steps",
    y="Action Distribution Std Dev.",
    hue="Algorithm",
    ci=95,
    style="Algorithm",
    palette="viridis",
)
plt.legend(loc="upper right")
plt.tight_layout()
plt.savefig(f"xppo_vs_ppo_std_dev_{env}.pdf")

{'$and': [{'group': 'xppo_single_step'},
          {'$not': {'tags': 'exclude-from-paper'}},
          {'$and': [{'config.env': 'Walker2d-v2'},
                    {'$or': [{'tags': {'$in': ['paper']}},
                             {'state': 'finished'},
                             {'state': 'running'}]}]}]}
Got 20 runs for group xppo_single_step
Maximum x value of run <Run resl-mixppo/stabilized-rl/hfsts37x (crashed)> is 8052736.0
Maximum x value of run <Run resl-mixppo/stabilized-rl/yaw4m35y (crashed)> is 9744384.0
Maximum x value of run <Run resl-mixppo/stabilized-rl/ghkheu6s (crashed)> is 1904640.0
Maximum x value of run <Run resl-mixppo/stabilized-rl/zc7n2hta (crashed)> is 1929216.0
Maximum x value of run <Run resl-mixppo/stabilized-rl/cvuu9tur (crashed)> is 1937408.0
Maximum x value of run <Run resl-mixppo/stabilized-rl/4lr0eij5 (crashed)> is 5066752.0
Maximum x value of run <Run resl-mixppo/stabilized-rl/xe208b9l (crashed)> is 1961984.0
{'$and': [{'group': 'baseline_ppo_10m'},



The `ci` parameter is deprecated. Use `errorbar=('ci', 95)` for the same effect.

  sns.lineplot(


ValueError: Could not interpret value `Action Distribution Std Dev.` for parameter `y`

In [None]:
sns.lineplot(
    data=all_data,
    x="Total Environment Steps",
    y="Action Distribution Std. Dev.",
    hue="Algorithm",
    ci=95,
    style="Algorithm",
    palette="viridis",
)
plt.legend(loc="upper right")
plt.tight_layout()
plt.savefig(f"xppo_vs_ppo_std_dev_{env}.pdf")