In [None]:
%load_ext autoreload
%autoreload 2

import json
import numpy as np


# Which run to plot: the results folder under figures/, which methods to draw, and which seeds to average.
experiment_name = "r20000_d50000_b512_t1000"
show_dqn = True
show_prodqn = False
# Bellman-iteration scope of the ProDQN run; it selects the ProDQN result files and the output PDF name.
bellman_iterations_scope = 5
seeds = [1]

# Experiment parameters saved next to the results (the cells below read p["n_epochs"]).
# Use a context manager so the file handle is closed instead of leaked.
with open(f"figures/{experiment_name}/parameters.json") as parameters_file:
    p = json.load(parameters_file)

### Extract data

In [None]:
def load_j_per_seed(path_of_seed, seeds, n_epochs):
    """Stack one saved J curve per seed into an array of shape ``(len(seeds), n_epochs)``.

    Parameters
    ----------
    path_of_seed : callable
        Maps a seed to the path of the ``.npy`` file holding that seed's curve.
    seeds : sequence
        Seeds to load; row ``i`` of the result holds the curve of ``seeds[i]``.
    n_epochs : int
        Expected curve length.

    Returns
    -------
    np.ndarray
        Float array of shape ``(len(seeds), n_epochs)``.

    Notes
    -----
    Assigning into a pre-allocated array (rather than stacking) makes a curve whose
    length differs from ``n_epochs`` raise a ValueError. A length-1 array would still
    broadcast.
    """
    j_per_seed = np.full((len(seeds), n_epochs), np.nan)

    for idx_seed, seed in enumerate(seeds):
        j_per_seed[idx_seed] = np.load(path_of_seed(seed))

    return j_per_seed


if show_dqn:
    dqn_j = load_j_per_seed(lambda seed: f"figures/{experiment_name}/DQN/J_{seed}.npy", seeds, p["n_epochs"])

if show_prodqn:
    prodqn_j = load_j_per_seed(
        lambda seed: f"figures/{experiment_name}/ProDQN/{bellman_iterations_scope}_J_{seed}.npy", seeds, p["n_epochs"]
    )

### Plot performance

In [None]:
import matplotlib.pyplot as plt 
from experiments import colors

from pbo.utils.confidence_interval import means_and_confidence_interval

plt.rc("font", size=15)
plt.rc("lines", linewidth=3)

iterations = range(p["n_epochs"])


def plot_mean_and_ci(ax, x, runs, label, color, zorder):
    """Draw the across-seed mean of ``runs`` and a shaded confidence band on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    x : x-coordinates (epochs) shared by the mean line and the band.
    runs : per-seed curves, as built by the extraction cell (one row per seed).
    label, color, zorder : forwarded to the mean line; ``color`` is reused for the band.

    Returns
    -------
    tuple
        ``(mean, confidence_interval)`` exactly as returned by ``means_and_confidence_interval``.
        The band is drawn between ``confidence_interval[0]`` and ``confidence_interval[1]``.
    """
    mean, confidence_interval = means_and_confidence_interval(runs)

    ax.plot(x, mean, label=label, color=color, zorder=zorder)
    ax.fill_between(x, confidence_interval[0], confidence_interval[1], color=color, alpha=0.3)

    return mean, confidence_interval


fig, ax = plt.subplots()

# Keep the module-level mean/CI names so later cells can still read them.
if show_dqn:
    dqn_j_mean, dqn_j_confidence_interval = plot_mean_and_ci(ax, iterations, dqn_j, "DQN", colors["DQN"], zorder=2)

if show_prodqn:
    prodqn_j_mean, prodqn_j_confidence_interval = plot_mean_and_ci(
        ax, iterations, prodqn_j, "ProDQN (ours)", colors["ProDQN"], zorder=3
    )

# About 5 ticks. `n_epochs // 5` is 0 when n_epochs < 5 and range() rejects a zero step, so clamp it to 1.
ax.set_xticks(range(0, p["n_epochs"], max(1, p["n_epochs"] // 5)))
ax.set_xlabel("#Epochs")
ax.set_title(r"$ J_i $")
ax.legend().set_zorder(1)
ax.grid(zorder=0)

fig.savefig(f"figures/{experiment_name}/J_{bellman_iterations_scope}.pdf", bbox_inches="tight")