In [None]:
import os
import json
from itertools import zip_longest
import numpy as np
import matplotlib.pyplot as plt
from experiments.base.iqm import get_iqm_and_conf_per_epoch

# Game whose runs are compared; it is part of every experiment folder name and is the plot title.
game = "BattleZone"

# Folders are "<experiment prefix>_<game>/<algorithm>", relative to `base_path`.
experiment_folders = [
    "lr_1e-5_5e-5_1e-4_" + game + "/adadqnstatic",
    "olr_working_lr_6.25e-5_" + game + "/dqn",
]

# Root directory of all experiment outputs, resolved against the current working directory.
base_path = os.path.join(os.path.abspath(os.curdir), "exp_output")



def load_seed_returns(experiment_path):
	"""Load the per-epoch mean episode return of every seed stored in ``experiment_path``.

	Every ``<seed>.json`` file in the folder must contain an ``"episode_returns"`` entry holding
	one list of episode returns per epoch; each of those lists is averaged. Seeds with fewer
	epochs than the longest seed are padded with NaN so that all seeds fit in one matrix.

	Args:
		experiment_path: folder containing one ``<seed>.json`` file per seed.

	Returns:
		seeds: int array with the seed parsed from each file name, aligned with the rows of ``returns``.
		returns: float array of shape (n_seeds, n_epochs); NaN marks missing epochs.

	Raises:
		FileNotFoundError: if the folder contains no ``.json`` files.
	"""
	# Sorted so the seed order (and thus the printed report) does not depend on os.listdir order;
	# entries that are not .json files are skipped instead of being passed to json.load.
	result_files = sorted(name for name in os.listdir(experiment_path) if name.endswith(".json"))
	if not result_files:
		raise FileNotFoundError(f"No .json result files found in {experiment_path}")

	seeds, returns_per_seed = [], []
	for result_file in result_files:
		# `with` closes the handle right away instead of leaking it until garbage collection.
		with open(os.path.join(experiment_path, result_file), "r") as json_file:
			list_episode_returns = json.load(json_file)["episode_returns"]
		# splitext removes the ".json" suffix; str.strip(".json") would strip a *set* of characters.
		seeds.append(int(os.path.splitext(result_file)[0]))
		returns_per_seed.append([np.mean(episode_returns) for episode_returns in list_episode_returns])

	# zip_longest walks epoch-by-epoch across seeds, padding missing epochs with NaN;
	# the transpose restores the (seed, epoch) layout.
	returns = np.array(list(zip_longest(*returns_per_seed, fillvalue=np.nan))).T
	return np.array(seeds), returns


# experiment name -> {"iqm", "confidence", "x_values"} used by the plotting cell.
experiment_data = {experiment: {} for experiment in experiment_folders}

for experiment in experiment_folders:
	# NOTE: "lenghts" is kept as spelled because it must match the folder name on disk.
	experiment_path = os.path.join(base_path, experiment, "episode_returns_and_lenghts")
	seeds, returns_experiment = load_seed_returns(experiment_path)

	print(f"Plot {experiment} with {returns_experiment.shape[0]} seeds.")
	# A seed is incomplete when any of its epochs is NaN (padding, or an epoch with no episodes).
	incomplete_seeds = np.isnan(returns_experiment).any(axis=1)
	if incomplete_seeds.any():
		print(f"!!! Seeds {seeds[incomplete_seeds]} are not complete !!!")

	results = experiment_data[experiment]
	results["iqm"], results["confidence"] = get_iqm_and_conf_per_epoch(returns_experiment)
	# 1-based epoch index for every column of the returns matrix.
	results["x_values"] = np.arange(1, returns_experiment.shape[1] + 1)

In [None]:
# Global figure style: serif font at size 15 (also used for axis labels) and thick lines.
plt.rc("font", family="serif", serif="Times New Roman", size=15)
plt.rc("lines", linewidth=4)

fig, ax = plt.subplots()

for experiment in experiment_folders:
	results = experiment_data[experiment]
	x_values = results["x_values"]
	confidence = results["confidence"]

	# The legend shows only the algorithm part of "<experiment>/<algorithm>".
	(line,) = ax.plot(x_values, results["iqm"], label=experiment.split("/")[1])
	# Give the band the line's colour explicitly, so the pairing does not depend on the
	# colour cycle happening to line up.
	ax.fill_between(
		x_values,
		confidence[0],
		confidence[1],
		color=line.get_color(),
		alpha=0.3,
	)

# x_values are 1-based epoch indices, and the curves are IQMs of raw mean episode returns.
# No human-score normalisation is applied when loading the data, so the labels say so.
ax.set_xlabel("Epochs")
ax.set_ylabel("IQM Return")

ax.grid()
ax.legend()
_ = ax.set_title(game)