In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
from experiments.atari import EXPERIMENTED_GAME


# Experiment folders read from / written to under figures/.
experiments = ["ut30_uh6000"]

# Names forwarded to get_baselines_scores for the IQM plot. Other names used so far:
# ['DQN (Nature)', 'Quantile (JAX)_dopamine', 'DQN (Adam)', 'C51', 'REM', 'Rainbow', 'IQN', 'M-IQN']
baselines = ["REM"]
# Names forwarded to get_baselines_scores for the performance profile.
baselines_performance_profile = []

games = ["Breakout"]

# Values of k to load for iDQN / iIQN; an empty list skips that family entirely.
ks_idqn = []
ks_iiqn = []

seeds = [11]

# Epoch indices plotted on the IQM figures. Sparse alternative:
# np.array([1, 10, 25, 50, 75, 100, 125, 150, 175, 200]) - 1
selected_epochs = np.arange(200)

# Thresholds used as the x-axis of the performance profile.
taus = np.linspace(0.0, 8.0, 81)

# Toggles for what gets loaded and drawn; "std" controls the shaded confidence band.
show = {
    "dqn": False,
    "iqn": False,
    "rem": False,
    "head_std": False,
    "approximation_error": False,
    "std": True,
}

### Extract data

In [None]:
from idqn.utils.baselines_scores import get_baselines_scores

def collect_data(scores, algorithm, idqn_key_k="", idqn_key_path="", n_epochs=200):
    """Load the per-seed ``J`` arrays of ``algorithm`` into ``scores``, in place.

    For every entry of the notebook-level ``experiments`` list, the key
    ``f"{algorithm}_{experiment}{idqn_key_k}"`` is (re)created in ``scores`` and
    maps each game of ``games`` to an array of shape ``(n_epochs, len(seeds))``.
    Column ``i`` is read from
    ``figures/{experiment}/{game}/{algorithm}/{idqn_key_path}J_{seeds[i]}.npy``.

    Args:
        scores: dict that receives the loaded arrays; existing entries with the
            same key are replaced.
        algorithm: used both in the dict key and as the sub-folder name.
        idqn_key_k: suffix appended to the dict key.
        idqn_key_path: prefix prepended to the file name.
        n_epochs: number of rows allocated per game; every loaded file must be
            assignable to a column of that length.

    Raises:
        Propagates whatever ``np.load`` raises (e.g. for a missing file) and the
        error numpy raises when a loaded array does not fit a column.
    """
    for experiment in experiments:
        experiment_key = f"{algorithm}_{experiment}{idqn_key_k}"
        scores[experiment_key] = {}
        for game in games:
            # NaN-filled so that a column which is never written shows up as missing, not as zeros.
            game_scores = np.full((n_epochs, len(seeds)), np.nan)
            scores[experiment_key][game] = game_scores
            for idx_seed, seed in enumerate(seeds):
                game_scores[:, idx_seed] = np.load(f"figures/{experiment}/{game}/{algorithm}/{idqn_key_path}J_{seed}.npy")

# Locally trained single-network algorithms: one score dict per enabled flag.
# The *_scores names are only defined when the matching flag is on; later cells
# test the same flags before using them.
if show["dqn"]:
    dqn_scores = {}
    collect_data(dqn_scores, "DQN")

if show["iqn"]:
    iqn_scores = {}
    collect_data(iqn_scores, "IQN")

if show["rem"]:
    rem_scores = {}
    collect_data(rem_scores, "REM")

# iDQN / iIQN: one entry per k, with "_{k}" appended to the key and "{k}_"
# prepended to the file name.
if len(ks_idqn) > 0:
    idqn_scores = {}
    for k in ks_idqn:
        collect_data(idqn_scores, "iDQN", idqn_key_k=f"_{k}", idqn_key_path=f"{k}_")

if len(ks_iiqn) > 0:
    iiqn_scores = {}
    for k in ks_iiqn:
        collect_data(iiqn_scores, "iIQN", idqn_key_k=f"_{k}", idqn_key_path=f"{k}_")

# Baseline scores are fetched separately for the IQM plot and for the performance profile.
baselines_scores = get_baselines_scores(baselines, games)
baselines_performance_profile_scores = get_baselines_scores(baselines_performance_profile, games)

### IQM vs iterations & performance profile

In [None]:
import matplotlib.pyplot as plt
from idqn.utils.process_scores import compute_iqm_and_confidence_interval
from experiments.atari import COLORS, LABEL, ORDER


# Global matplotlib defaults applied to the figures created below.
plt.rc("font", size=18)  # 21 for main paper, 18 for the table of figures and 15 big figures.
plt.rc("lines", linewidth=3)
# Two named figures: the curves are drawn on `ax`, the legend is rendered on
# `fig_legend` so that it can be saved as a separate file.
fig = plt.figure("Main figure")
ax = fig.add_subplot(111)
fig_legend = plt.figure("Legend figure")
# Line handles appended by plot_iqm; used afterwards to build the legend.
lines = []


def plot_iqm(scores, normalize=True):
    """Draw one IQM curve per entry of ``scores`` on the notebook-level ``ax``.

    For each key, the values are passed with ``selected_epochs`` and
    ``normalize`` to ``compute_iqm_and_confidence_interval``. The returned curve
    is plotted against ``selected_epochs + 1`` using the label, colour and
    z-order registered for that key in ``LABEL``, ``COLORS`` and ``ORDER``, and
    the resulting line handle is appended to the notebook-level ``lines``.
    When ``show["std"]`` is true, the band between rows 0 and 1 of the returned
    confidence interval is shaded as well.

    Args:
        scores: dict mapping a key to the data of that curve.
        normalize: forwarded unchanged to ``compute_iqm_and_confidence_interval``.
    """
    for key, runs in scores.items():
        epochs_axis = selected_epochs + 1
        iqm_curve, iqm_bounds = compute_iqm_and_confidence_interval(runs, selected_epochs, normalize)

        shared_style = dict(color=COLORS[key], zorder=ORDER[key])
        curve_handle = ax.plot(epochs_axis, iqm_curve, label=LABEL[key], **shared_style)[0]
        lines.append(curve_handle)

        if not show["std"]:
            continue
        lower_bound, upper_bound = iqm_bounds[0, :], iqm_bounds[1, :]
        ax.fill_between(epochs_axis, lower_bound, upper_bound, alpha=0.3, **shared_style)


# Curves are drawn in this fixed order; plot_iqm appends to `lines` in the same
# order, so this is also the order of the legend entries.
if len(ks_iiqn) > 0:
    plot_iqm(iiqn_scores)

if len(ks_idqn) > 0:
    plot_iqm(idqn_scores)

if show["dqn"]:
    plot_iqm(dqn_scores)

if show["iqn"]:
    plot_iqm(iqn_scores)

if show["rem"]:
    plot_iqm(rem_scores)

plot_iqm(baselines_scores)

ax.grid(zorder=0)
# ax.set_xticklabels([])
ax.set_xlabel("Number of Frames (in millions)")
# ax.set_ylabel("IQM Human Normalized Score")


legend_labels = [line.get_label() for line in lines]
if len(lines) < 6:
    # max(1, ...) guards the case where nothing was plotted: len(lines) would
    # otherwise request a legend with zero columns.
    fig_legend.legend(lines, legend_labels, ncols=max(1, len(lines)))
else:
    import itertools

    # Six or more entries are laid out on two rows.
    ncols = int(np.ceil(len(lines) / 2))

    def flip(items):
        """Reorder ``items`` so that a legend filled column by column reads row by row."""
        return itertools.chain(*[items[i::ncols] for i in range(ncols)])

    # Materialised as lists so that matplotlib receives plain sequences rather than one-shot iterators.
    fig_legend.legend(list(flip(lines)), list(flip(legend_labels)), ncols=ncols)

# Nothing is saved when no experiment is configured. With a single game the
# figure is titled and stored in that game's folder, otherwise directly in the
# experiment folder.
if len(experiments) > 0:
    if len(games) == 1:
        ax.set_title(games[0])
        output_dir = f"figures/{experiments[0]}/{games[0]}"
    else:
        output_dir = f"figures/{experiments[0]}"
    fig.savefig(f"{output_dir}/J.pdf", bbox_inches='tight')
    _ = fig_legend.savefig(f"{output_dir}/J_legend.pdf", bbox_inches='tight')

In [None]:
# The performance profile is only drawn when more than one game is configured.
if len(games) > 1:
    from idqn.utils.process_scores import compute_performance_profile_and_confidence_interval

    plt.rc("font", size=15)
    plt.rc("lines", linewidth=3)
    fig = plt.figure("Main figure")
    ax = fig.add_subplot(111)
    fig_legend = plt.figure("Legend figure")
    # Line handles appended by plot_performance_profile; used to build the legend.
    lines = []


    def plot_performance_profile(scores):
        """Draw one performance-profile curve per entry of ``scores`` on ``ax``.

        Each entry is passed with ``taus`` to
        ``compute_performance_profile_and_confidence_interval``; the curve is
        plotted against ``taus`` with the label, colour and z-order registered
        for its key, and its handle is appended to ``lines``. When
        ``show["std"]`` is true, the band between rows 0 and 1 of the returned
        confidence interval is shaded.
        """
        for experiment in scores.keys():
            performance_profile, performance_profile_confidence_interval = compute_performance_profile_and_confidence_interval(scores[experiment], taus)
            lines.append(ax.plot(taus, performance_profile, label=LABEL[experiment], color=COLORS[experiment], zorder=ORDER[experiment])[0])
            if show["std"]:
                ax.fill_between(taus, performance_profile_confidence_interval[0, :], performance_profile_confidence_interval[1, :], color=COLORS[experiment], zorder=ORDER[experiment], alpha=0.3)


    if show["dqn"]:
        plot_performance_profile(dqn_scores)

    if show["iqn"]:
        plot_performance_profile(iqn_scores)

    if show["rem"]:
        plot_performance_profile(rem_scores)

    if len(ks_idqn) > 0:
        plot_performance_profile(idqn_scores)

    if len(ks_iiqn) > 0:
        plot_performance_profile(iiqn_scores)

    plot_performance_profile(baselines_performance_profile_scores)

    ax.grid(zorder=0)
    ax.set_xlabel(r"Human Normalized Score $(\tau)$")
    ax.set_ylabel(r"Fraction of runs with score $> \tau$")
    # max(1, ...) guards the case where nothing was plotted: len(lines) would
    # otherwise request a legend with zero columns.
    fig_legend.legend(lines, [line.get_label() for line in lines], ncols=max(1, len(lines)))
    if len(experiments) > 0:
        fig.savefig(f"figures/{experiments[0]}/P.pdf", bbox_inches='tight')
        _ = fig_legend.savefig(f"figures/{experiments[0]}/P_legend.pdf", bbox_inches='tight')

### Head std

In [None]:
# Inter-head standard deviation of iDQN, one curve per (experiment, k).
if show["head_std"]:
    plt.rc("font", size=15)
    plt.rc("lines", linewidth=3)
    fig = plt.figure("Main figure")
    ax = fig.add_subplot(111)
    fig_legend = plt.figure("Legend figure")
    # plot_iqm appends its line handles to this notebook-level list.
    lines = []

    head_stds = {}
    for experiment in experiments:
        for k in ks_idqn:
            head_std_key = f"iDQN_{experiment}_{k}"
            head_stds[head_std_key] = {}
            for game in games:
                # NaN-filled so that a column which is never written shows up as missing, not as zeros.
                game_head_stds = np.full((200, len(seeds)), np.nan)
                head_stds[head_std_key][game] = game_head_stds
                for idx_seed, seed in enumerate(seeds):
                    game_head_stds[:, idx_seed] = np.load(f"figures/{experiment}/{game}/iDQN/{k}_S_{seed}.npy")


    plot_iqm(head_stds, normalize=False)

    ax.grid(zorder=0)
    ax.set_xlabel("Number of Frames (in millions)")
    ax.set_ylabel("IQM inter-head standard deviation")
    # max(1, ...) guards the case where nothing was plotted (e.g. ks_idqn is
    # empty): len(lines) would otherwise request a legend with zero columns.
    fig_legend.legend(lines, [line.get_label() for line in lines], ncols=max(1, len(lines)))
    # Figures are only saved in the single-game case.
    if len(games) == 1 and len(experiments) > 0:
        ax.set_title(games[0])
        fig.savefig(f"figures/{experiments[0]}/{games[0]}/S.pdf", bbox_inches='tight')
        _ = fig_legend.savefig(f"figures/{experiments[0]}/{games[0]}/S_legend.pdf", bbox_inches='tight')

### Approximation error

In [None]:
# Approximation error of iDQN, one curve per (experiment, k).
if show["approximation_error"]:
    plt.rc("font", size=15)
    plt.rc("lines", linewidth=3)
    fig = plt.figure("Main figure")
    ax = fig.add_subplot(111)
    fig_legend = plt.figure("Legend figure")
    # plot_iqm appends its line handles to this notebook-level list.
    lines = []

    approximation_errors = {}
    for experiment in experiments:
        for k in ks_idqn:
            # NOTE(review): unlike the head-std cell and collect_data, this key has
            # no "iDQN_" prefix; confirm that LABEL / COLORS / ORDER define it.
            approximation_error_key = f"{experiment}_{k}"
            approximation_errors[approximation_error_key] = {}
            for game in games:
                # NaN-filled so that a column which is never written shows up as missing, not as zeros.
                game_errors = np.full((200, len(seeds)), np.nan)
                approximation_errors[approximation_error_key][game] = game_errors
                for idx_seed, seed in enumerate(seeds):
                    game_errors[:, idx_seed] = np.load(f"figures/{experiment}/{game}/iDQN/{k}_A_{seed}.npy")


    # NOTE(review): plot_iqm runs with its default normalize=True here, whereas
    # the head-std cell passes normalize=False; confirm this is intended.
    plot_iqm(approximation_errors)

    ax.grid(zorder=0)
    ax.set_xlabel("Number of Frames (in millions)")
    ax.set_ylabel("IQM approximation error")
    # max(1, ...) guards the case where nothing was plotted (e.g. ks_idqn is
    # empty): len(lines) would otherwise request a legend with zero columns.
    fig_legend.legend(lines, [line.get_label() for line in lines], ncols=max(1, len(lines)))
    # Figures are only saved in the single-game case.
    if len(games) == 1 and len(experiments) > 0:
        ax.set_title(games[0])
        fig.savefig(f"figures/{experiments[0]}/{games[0]}/A.pdf", bbox_inches='tight')
        _ = fig_legend.savefig(f"figures/{experiments[0]}/{games[0]}/A_legend.pdf", bbox_inches='tight')