In [None]:
import argparse
import json
import logging
import pickle
import wandb
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import glob
import pandas as pd
import os
import numpy as np
import copy

In [None]:
import matplotlib

# Publication-style figure defaults applied to every plot in this notebook.
# Font type 42 (TrueType) is selected for PDF and PostScript output so text in the
# saved figures is embedded as TrueType rather than Type 3 fonts.
matplotlib.rcParams.update({
    "figure.dpi": 150,
    "font.size": 14,
    "figure.figsize": (7.5, 4.8),
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
})

In [None]:
# Shared W&B public-API client; load_groups queries runs through `api.runs`.
# NOTE(review): timeout is presumably in seconds — a generous value for slow
# `scan_history` calls over long runs; confirm against the wandb docs.
api = wandb.Api(timeout=200)

In [None]:
SMOOTH_WINDOW = 30


def smooth(to_smooth, window_size=SMOOTH_WINDOW):
    """Trailing moving average with a cumulative-mean warm-up.

    Element ``i`` of the result is the mean of
    ``to_smooth[max(0, i - window_size + 1) : i + 1]``: the last ``window_size``
    values, or every value seen so far while fewer than ``window_size`` exist.

    Args:
        to_smooth: 1-D sequence of numbers (list, ndarray or pandas Series; a
            Series is read positionally, its index is ignored).
        window_size: number of trailing samples to average; must be >= 1.

    Returns:
        ``np.ndarray`` of floats, same length as ``to_smooth``.

    Raises:
        ValueError: if ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    values = np.asarray(to_smooth, dtype=float)
    # The slice start must use the caller's window_size, not the module-level
    # SMOOTH_WINDOW; otherwise the window changes size after the warm-up.
    return np.array(
        [values[max(0, idx - window_size + 1) : idx + 1].mean() for idx in range(len(values))]
    )

In [None]:
SMOOTH_WINDOW = 10

def smooth(to_smooth, window_size=SMOOTH_WINDOW):
    """Trailing moving average with a cumulative-mean warm-up.

    Element ``i`` of the result is the mean of
    ``to_smooth[max(0, i - window_size + 1) : i + 1]``: the last ``window_size``
    values, or every value seen so far while fewer than ``window_size`` exist.

    Args:
        to_smooth: 1-D sequence of numbers (list, ndarray or pandas Series; a
            Series is read positionally, its index is ignored).
        window_size: number of trailing samples to average; must be >= 1.

    Returns:
        ``np.ndarray`` of floats, same length as ``to_smooth``.

    Raises:
        ValueError: if ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    values = np.asarray(to_smooth, dtype=float)
    # The slice start must use the caller's window_size, not the module-level
    # SMOOTH_WINDOW; e.g. smooth(x, 30) must average 30 samples throughout,
    # not 30 during warm-up and SMOOTH_WINDOW afterwards.
    return np.array(
        [values[max(0, idx - window_size + 1) : idx + 1].mean() for idx in range(len(values))]
    )

def load_groups(group_and_keys, relabel_dict, x_range, extra_filter):
    """Fetch W&B runs for several groups and resample each onto a shared x grid.

    For every ``(group, x_key, y_key)`` triple, queries the
    ``resl-mixppo/stabilized-rl`` project for runs in ``group`` that are not
    tagged ``exclude-from-paper`` and that also match ``extra_filter``. Each
    run's ``y_key`` history is linearly interpolated onto 1000 evenly spaced
    points spanning ``x_range``. Runs in the ``basline_stbl_ppo`` group are
    smoothed (30-sample window) before interpolation.

    Args:
        group_and_keys: iterable of ``(group, x_key, y_key)`` tuples.
        relabel_dict: maps raw names (metric keys, group names, env names and
            the literal column names "group"/"env") to display labels; names
            not present are used verbatim.
        x_range: ``(start, stop)`` of the shared x grid.
        extra_filter: additional W&B (MongoDB-style) filter AND-ed into the query.

    Returns:
        Long-format DataFrame with one row per (run, grid point) and columns
        for x, y, algorithm (group), run name and environment, all relabeled.

    Raises:
        ValueError: if no run yielded usable data (instead of the opaque
            "No objects to concatenate" from ``pd.concat``).
    """
    # One grid shared by every group and run, so curves can be aggregated per x.
    x_vals = np.linspace(x_range[0], x_range[1], 1000)
    all_interp_data = []
    for group, x_key, y_key in group_and_keys:
        runs = api.runs(
            path="resl-mixppo/stabilized-rl",
            filters={
                "$and": [
                    {"group": group},
                    {"$not": {"tags": "exclude-from-paper"}},
                    extra_filter,
                ]
            },
        )
        print(f"Got {len(runs)} runs for group {group}")
        for r in tqdm(runs):
            # Full (unsampled) history restricted to the two keys of interest.
            h = pd.DataFrame(r.scan_history(keys=[x_key, y_key]))
            try:
                max_x = np.max(h[x_key])
                if max_x < 0.66 * x_range[1]:
                    # Short runs are reported but kept: np.interp repeats the
                    # last logged y value for grid points beyond max_x.
                    print("Maximum x value of run", str(r), "is", max_x)
                y_vals = h[y_key]
                if group == "basline_stbl_ppo":
                    y_vals = smooth(y_vals, 30)
                # NOTE(review): np.interp assumes x_key increases over the
                # history (true for a step counter) — confirm for other keys.
                interp_y = np.interp(x_vals, h[x_key], y_vals)
                env = r.config["env"]
            except KeyError:
                # Metric column missing (an empty history has no columns at
                # all) or config["env"] missing: skip this run, keep the rest.
                print("Could not get keys in run", r)
                continue
            all_interp_data.append(
                pd.DataFrame.from_dict(
                    {
                        relabel_dict.get(x_key, x_key): x_vals,
                        relabel_dict.get(y_key, y_key): interp_y,
                        relabel_dict.get("group", "group"): relabel_dict.get(
                            group, group
                        ),
                        "run": str(r),
                        relabel_dict.get("env", "env"): relabel_dict.get(
                            env, env
                        ),
                    }
                )
            )
    if not all_interp_data:
        raise ValueError(
            "No usable runs found for groups "
            f"{[group for group, _, _ in group_and_keys]}"
        )
    return pd.concat(all_interp_data, ignore_index=True)

In [None]:
# Display labels used by load_groups for column names, algorithm (group) names
# and environment names. Anything missing from this dict is shown verbatim.
relabels = {
    # W&B group name -> legend label
    "xppo-512-5": "xPPO",
    "baseline_ppo": "PPO-clip",
    "xppo10m-512-5": "xPPO",
    "basline_ppo_10m": "PPO-clip",
    "xppo_single_step_4096": "xPPO $|D_{h}|=32000$",
    "basline_stbl_ppo": "PPO-clip 4096",
    "xppo_single_step_large_historic": "xPPO $|D_{h}|=128000$",
    "xppo_single_step_no_historic": "xPPO $|D_{h}|=4096$",
    # Metric key / column name -> axis label
    "global_step": "Total Environment Steps",
    "rollout/ep_rew_mean": "Average Episode Reward",
    "group": "Algorithm",
    "env": "Environment",
    "rollout/SuccessRate": "Average Success Rate",
    "train/std": "Action Distribution STD",
    # Raw run.config["env"] value -> task name. Every entry of `envs` has one.
    "<SawyerPushEnvV2 instance>": "push",
    "<SawyerWindowCloseEnvV2 instance>": "window-close",
    "<SawyerDoorEnvV2 instance>": "door-open",
    "<SawyerReachEnvV2 instance>": "reach",
    "<SawyerButtonPressTopdownEnvV2 instance>": "button-press-topdown",
    "<SawyerPickPlaceEnvV2 instance>": "pick-place",
    "<SawyerWindowOpenEnvV2 instance>": "window-open",
    "<SawyerDrawerOpenEnvV2 instance>": "drawer-open",
    "<SawyerPegInsertionSideEnvV2 instance>": "peg-insert-side",
    "<SawyerDrawerCloseEnvV2 instance>": "drawer-close",
}
# Raw run.config["env"] values of all tasks.
envs = [
    "<SawyerPushEnvV2 instance>",
    "<SawyerWindowCloseEnvV2 instance>",
    "<SawyerDoorEnvV2 instance>",
    "<SawyerReachEnvV2 instance>",
    "<SawyerButtonPressTopdownEnvV2 instance>",
    "<SawyerPickPlaceEnvV2 instance>",
    "<SawyerWindowOpenEnvV2 instance>",
    "<SawyerDrawerOpenEnvV2 instance>",
    "<SawyerPegInsertionSideEnvV2 instance>",
    "<SawyerDrawerCloseEnvV2 instance>",
]

## Pick-Place

In [None]:

# Pick-Place: xPPO historical-buffer variants vs. the PPO-clip baseline.
env = "PickPlace-V2"  # used in the saved PDF filename
config_env = "<SawyerPickPlaceEnvV2 instance>"  # value of run.config["env"]
total_steps = 7e6
group_and_keys = [
    ("xppo_single_step_4096", "global_step", "rollout/SuccessRate"),
    # "xppo_single_step_large_historic" is left out of this task's plot.
    ("xppo_single_step_no_historic", "global_step", "rollout/SuccessRate"),
    ("basline_stbl_ppo", "global_step", "rollout/SuccessRate"),
]

# Baseline runs are selected by tag. Don't add "basline_stbl_ppo" to
# group_filter: that group mixes the 50k batch-size runs with default-size ones.
tag_filter = {"tags": {"$in": ["stbl_ppo_pick-place_baseline"]}}
# xPPO runs are selected by group so runs that are not tagged yet are included.
group_filter = {
    "group": {
        "$in": [
            "xppo_single_step_4096",
            "xppo_single_step_large_historic",
            "xppo_single_step_no_historic",
        ]
    }
}
env_filter = {"config.env": config_env}
or_filter = {"$or": [tag_filter, group_filter]}
# The env restriction covers both branches, so tagged baseline runs from other
# environments cannot leak into this plot.
extra_filter = {"$and": [env_filter, or_filter]}

# NOTE(review): defined but not applied here; AND it into extra_filter to keep
# only finished/running runs.
state_filter = {"$or": [{"state": "finished"}, {"state": "running"}]}

all_data = load_groups(
    group_and_keys,
    relabels,
    (0, total_steps),
    extra_filter,
)

In [None]:
# Success rate per algorithm; seaborn aggregates the runs sharing each x value
# and draws a 95% confidence band. An explicit figure guarantees that curves
# from a previously open figure never end up in this PDF.
fig, ax = plt.subplots()
sns.lineplot(
    data=all_data,
    x="Total Environment Steps",
    y="Average Success Rate",
    hue="Algorithm",
    # NOTE(review): seaborn >= 0.12 deprecates `ci` in favour of
    # errorbar=("ci", 95); kept for compatibility with older installs.
    ci=95,
    ax=ax,
)
ax.legend(loc="upper left")
fig.tight_layout()
fig.savefig(f"historical_buffer_sweep_{env}.pdf")
plt.show()


In [None]:
# Re-smooth the "PPO-clip 4096" curves run by run into a separate frame.
# A single `.loc[mask, column]` assignment writes to the frame itself (chained
# indexing would write to a temporary copy), and grouping by "run" keeps the
# moving average from bleeding across the boundary between concatenated runs.
# all_data itself stays unchanged, so this cell is safe to re-run.
# NOTE(review): load_groups already smooths "basline_stbl_ppo" runs before
# interpolation, so this is a second smoothing pass — confirm it is wanted.
baseline_mask = all_data["Algorithm"] == "PPO-clip 4096"
all_data_smoothed = all_data.copy()
all_data_smoothed.loc[baseline_mask, "Average Success Rate"] = (
    all_data.loc[baseline_mask]
    .groupby("run")["Average Success Rate"]
    .transform(smooth)
)

## Window Open

In [None]:

# Window Open: xPPO historical-buffer variants vs. the PPO-clip baseline.
env = "WindowOpen"  # used in the saved PDF filename
config_env = "<SawyerWindowOpenEnvV2 instance>"  # value of run.config["env"]
total_steps = 2e6
group_and_keys = [
    ("xppo_single_step_4096", "global_step", "rollout/SuccessRate"),
    ("xppo_single_step_large_historic", "global_step", "rollout/SuccessRate"),
    ("xppo_single_step_no_historic", "global_step", "rollout/SuccessRate"),
    ("basline_stbl_ppo", "global_step", "rollout/SuccessRate"),
]

# Baseline runs are selected by tag.
# NOTE(review): the tag name refers to pick-place. With the env restriction
# below, only tagged baseline runs of THIS environment are kept; confirm the
# tag is also applied to the window-open baseline runs.
tag_filter = {"tags": {"$in": ["stbl_ppo_pick-place_baseline"]}}
# xPPO runs are selected by group so runs that are not tagged yet are included.
group_filter = {
    "group": {
        "$in": [
            "xppo_single_step_4096",
            "xppo_single_step_large_historic",
            "xppo_single_step_no_historic",
        ]
    }
}
env_filter = {"config.env": config_env}
or_filter = {"$or": [tag_filter, group_filter]}
# The env restriction covers both branches, so tagged baseline runs from other
# environments cannot leak into this plot.
extra_filter = {"$and": [env_filter, or_filter]}

# NOTE(review): defined but not applied here; AND it into extra_filter to keep
# only finished/running runs.
state_filter = {"$or": [{"state": "finished"}, {"state": "running"}]}

all_data = load_groups(
    group_and_keys,
    relabels,
    (0, total_steps),
    extra_filter,
)

In [None]:
# Success rate per algorithm; seaborn aggregates the runs sharing each x value
# and draws a 95% confidence band. An explicit figure guarantees that curves
# from a previously open figure never end up in this PDF.
fig, ax = plt.subplots()
sns.lineplot(
    data=all_data,
    x="Total Environment Steps",
    y="Average Success Rate",
    hue="Algorithm",
    # NOTE(review): seaborn >= 0.12 deprecates `ci` in favour of
    # errorbar=("ci", 95); kept for compatibility with older installs.
    ci=95,
    ax=ax,
)
ax.legend(loc="lower right")
fig.tight_layout()
fig.savefig(f"historical_buffer_sweep_{env}.pdf")
plt.show()


## Reach

In [None]:

# Reach: xPPO historical-buffer variants vs. the PPO-clip baseline.
env = "Reach"  # used in the saved PDF filename
config_env = "<SawyerReachEnvV2 instance>"  # value of run.config["env"]
total_steps = 3e6
group_and_keys = [
    ("xppo_single_step_4096", "global_step", "rollout/SuccessRate"),
    ("xppo_single_step_large_historic", "global_step", "rollout/SuccessRate"),
    ("xppo_single_step_no_historic", "global_step", "rollout/SuccessRate"),
    ("basline_stbl_ppo", "global_step", "rollout/SuccessRate"),
]

# Baseline runs are selected by tag.
# NOTE(review): the tag name refers to pick-place. With the env restriction
# below, only tagged baseline runs of THIS environment are kept; confirm the
# tag is also applied to the reach baseline runs.
tag_filter = {"tags": {"$in": ["stbl_ppo_pick-place_baseline"]}}
# xPPO runs are selected by group so runs that are not tagged yet are included.
group_filter = {
    "group": {
        "$in": [
            "xppo_single_step_4096",
            "xppo_single_step_large_historic",
            "xppo_single_step_no_historic",
        ]
    }
}
env_filter = {"config.env": config_env}
or_filter = {"$or": [tag_filter, group_filter]}
# Only finished or still-running runs are used for this task.
state_filter = {"$or": [{"state": "finished"}, {"state": "running"}]}
# The env restriction covers both branches, so tagged baseline runs from other
# environments cannot leak into this plot.
extra_filter = {"$and": [state_filter, env_filter, or_filter]}

all_data = load_groups(
    group_and_keys,
    relabels,
    (0, total_steps),
    extra_filter,
)

In [None]:
# Success rate per algorithm; seaborn aggregates the runs sharing each x value
# and draws a 95% confidence band. An explicit figure guarantees that curves
# from a previously open figure never end up in this PDF.
fig, ax = plt.subplots()
sns.lineplot(
    data=all_data,
    x="Total Environment Steps",
    y="Average Success Rate",
    hue="Algorithm",
    # NOTE(review): seaborn >= 0.12 deprecates `ci` in favour of
    # errorbar=("ci", 95); kept for compatibility with older installs.
    ci=95,
    ax=ax,
)
ax.legend(loc="lower right")
fig.tight_layout()
fig.savefig(f"historical_buffer_sweep_{env}.pdf")
plt.show()
