# Plotting code for CS267 IRL project

In [None]:
%matplotlib inline

import os.path as osp

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import yaml

# Global seaborn theme for every figure in this notebook.
# Set PLOT_CONTEXT = "paper" when producing figures for the final report.
PLOT_CONTEXT = "poster"
sns.set(style="darkgrid", context=PLOT_CONTEXT)

## Plots for small-scale neural network training

In [None]:
small_data = pd.read_csv("./small-expts/all.csv")

# Bars are grouped by strategy ("type"), except that every run with at least
# one GPU is collapsed into a single "gpu" group.
small_data["group"] = small_data["type"]
has_gpu = small_data["gpus"] > 0
small_data.loc[has_gpu, "group"] = "gpu"

# Pin the category orders once, across ALL facets. Without explicit
# order/hue_order each facet's barplot infers them from its own subset, so a
# facet that lacks some strategy or core count shifts bar positions and hue
# levels relative to the other panels and to the shared legend.
group_order = small_data["group"].dropna().unique().tolist()
cores_order = np.sort(small_data["cores"].dropna().unique()).tolist()

grid = (
    sns.FacetGrid(small_data, col="batch_size", row="hiddens",
                  sharex=False, sharey=False, height=4,
                  legend_out=True, aspect=1.8)
    .map_dataframe(sns.barplot, x="group", y="samples_per_s", hue="cores",
                   order=group_order, hue_order=cores_order)
    .set_axis_labels("Strategy", "Samples/s")
    .set_titles("$b={col_name}$, $h=({row_name})$")
    .add_legend(title="Cores")
)
plt.savefig(osp.expanduser('~/small-expts-plot.png'))
plt.show()

## Plots for training sample throughput and evaluation reward

In [None]:
def plot_expt(expt_id, base_dir="./data/sacred-runs"):
    """Plot throughput and evaluation-reward curves for one Sacred run.

    Reads ``config.json`` and ``progress.csv`` from ``<base_dir>/<expt_id>/``,
    prints a one-line summary of the run configuration, then shows two
    figures against ``time_since_restore``:

    1. log-scale throughput of discriminator training, TD3 training and
       rollouts (cumulative count divided by ``time_since_restore``);
    2. ``evaluation.episode_reward_mean``.

    Args:
        expt_id: Integer run ID, formatted with ``%d`` to find the run
            directory.
        base_dir: Directory holding one subdirectory per run. Defaults to the
            path this notebook has always used.
    """
    import json

    run_dir = osp.join(base_dir, "%d" % expt_id)

    # config.json is JSON, so parse it with json: PyYAML's YAML 1.1 resolver
    # reads floats such as "1e-05" back as strings. The `with` also closes
    # the file handle.
    with open(osp.join(run_dir, "config.json"), "r") as f:
        config = json.load(f)

    # Only whether the discriminator's gpu_num is set is checked, so the
    # "gpus" value printed here is a 0/1 flag, not a GPU count.
    uses_gpu = config["tf_configs"]["discrim"]["gpu_num"] is not None
    print("Run %d, %d workers, %d gpus"
          % (expt_id, config["td3_conf"]["num_workers"], uses_gpu))

    data = pd.read_csv(osp.join(run_dir, "progress.csv"))

    def add_nested_columns(column, keys):
        # Cells of `column` hold serialized dicts. Parse each cell once, then
        # copy every requested key into a "<column>.<key>" column. Non-string
        # cells (NaN, presumably rows where nothing was logged) become NaN
        # instead of crashing yaml.safe_load.
        parsed = data[column].map(
            lambda cell: yaml.safe_load(cell) if isinstance(cell, str) else None)
        for key in keys:
            data[column + "." + key] = parsed.map(
                lambda d, key=key: np.nan if d is None else d[key])

    add_nested_columns("info", ["num_steps_trained", "num_steps_sampled"])
    add_nested_columns("evaluation", ["episode_reward_mean"])

    elapsed = data["time_since_restore"]
    # (cumulative count column, legend label)
    throughput_series = [
        ("disc_samples_seen", "disc_throughput"),
        ("info.num_steps_trained", "td3_throughput"),
        ("info.num_steps_sampled", "rollout_throughput"),
    ]

    # Draw on explicit axes of a fresh figure rather than relying on whatever
    # figure/axes happens to be current.
    _, ax = plt.subplots()
    for count_col, label in throughput_series:
        sns.lineplot(x=elapsed, y=data[count_col] / elapsed, label=label, ax=ax)
    ax.set_ylabel("throughput")
    ax.set_title("throughput for id %d" % expt_id)
    ax.set_yscale("log")
    plt.show()

    _, ax = plt.subplots()
    sns.lineplot(data=data, x="time_since_restore",
                 y="evaluation.episode_reward_mean", ax=ax)
    ax.set_title("reward for id %d" % expt_id)
    plt.show()
# Render throughput and reward plots for each of these Sacred run IDs in turn.
for run_id in (112, 113, 114):
    plot_expt(run_id)

## Loss curves