# Train / Test errors for the autoencoder

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as onp
import os
import math

import matplotlib.pyplot as plt
from temporal_explanations_4_rl.autoencoder import AtariAutoencoder

from temporal_explanations_4_rl.plot import plot_embedding
from temporal_explanations_4_rl.utils import load_embedding
from temporal_explanations_4_rl.dataset import (
    load_atari_obs,
    load_state_values,
    load_trajectories,
)

from ipywidgets import HTML
from matplotlib.animation import FuncAnimation

## Hyperparameters

In [None]:
# Root folder of the datasets; the cells below read its `training/` and
# `testing/` sub-folders.
root_dataset_folder = "../datasets/"

# Layout of the loss-curve grid: subplots per row, overall figure width, and
# the height given to each row of subplots (both sizes are passed to `figsize`).
num_cols = 3
fig_width = 12
sub_axes_height = 3

## Plot the training with observation error

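Plot the training and testing loss (mean squared error) of the observation autoencoder for every agent/environment pair that has a saved model embedding.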

In [None]:
# Every folder under `training/` that contains a saved observation-autoencoder
# embedding; the loss curves are read from that file's "metadata" entry.
directories = sorted(
    [
        directory
        for directory in os.listdir(f"{root_dataset_folder}/training")
        if os.path.exists(
            f"{root_dataset_folder}/training/{directory}/autoencoder-obs-model-embedding.npz"
        )
    ]
)
assert len(directories) > 0, "no training folder has an autoencoder-obs-model-embedding.npz"
num_rows = math.ceil(len(directories) / num_cols)
# squeeze=False keeps `axs` two-dimensional even when there is a single row of
# subplots, so flattening and indexing behave the same for any number of runs.
fig, axs = plt.subplots(
    nrows=num_rows,
    ncols=num_cols,
    figsize=(fig_width, num_rows * sub_axes_height),
    squeeze=False,
)
flat_axes = axs.flatten()
for ax, directory in zip(flat_axes, directories):
    # Split on the first hyphen only, so a folder name with extra hyphens does
    # not break the two-way unpacking.
    # NOTE(review): assumes the agent name itself contains no hyphen - confirm.
    agent_name, env_name = directory.split("-", maxsplit=1)

    with onp.load(
        f"{root_dataset_folder}/training/{directory}/autoencoder-obs-model-embedding.npz",
        allow_pickle=True,
    ) as file:
        # "metadata" is stored as a pickled object array; `.item()` unwraps it.
        metadata = file["metadata"].item()

        ax.set_title(f"{agent_name} - {env_name}")
        # Cap the training curve just above the largest testing loss so early
        # training spikes do not flatten the rest of the plot.
        clip_value = onp.max(metadata["testing loss"]) * 1.1
        ax.plot(
            onp.arange(10_000),
            onp.clip(metadata["training loss"], 0, clip_value),
            label="Training loss",
        )
        # The testing loss is plotted every 500 steps, from step 500 to 10,000.
        ax.plot(
            onp.arange(500, 10_001, 500),
            metadata["testing loss"],
            label="Testing loss",
            linestyle="dashed",
        )

# Hide the grid cells left over when the number of runs is not a multiple of
# `num_cols`, and put the legend on the last subplot that actually has curves.
for unused_ax in flat_axes[len(directories):]:
    unused_ax.set_visible(False)
flat_axes[len(directories) - 1].legend()
plt.tight_layout()
# Create the output folder on first use so saving does not fail on a fresh checkout.
os.makedirs("figs", exist_ok=True)
plt.savefig("figs/autoencoder-obs-training.png")

## Compare Atari Observation error to Policy error

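Load a testing trajectory for one agent and animate it side by side: panel 1 shows the true observation, panel 2 shows the reconstruction from the autoencoder trained on the observation error, and panel 3 shows the reconstruction from the autoencoder trained on the policy error (the policy panel is currently commented out in the code below).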

In [None]:
# Agent / environment pair to inspect; together they name the
# `<agent>-<env>` dataset folder read by the cells below.
agent_name = "dqn_adam_mse"  # change here
env_name = "Pong"  # change here

# Portion (indices 1000-1199 along the first axis) of the loaded testing
# observations that is reconstructed and animated.
trajectory_slice = slice(1000, 1200)

In [None]:
# Fail early if the observation-autoencoder embedding for the selected
# agent/environment is missing from the testing dataset.
assert os.path.exists(
    f"{root_dataset_folder}/testing/{agent_name}-{env_name}/autoencoder-obs-embedding.npz"
)
# assert os.path.exists(
#     f"{root_dataset_folder}/testing/{agent_name}-{env_name}/autoencoder-policy-embedding.npz"
# )

# Load the testing observations and keep only the entries selected by `trajectory_slice`.
trajectory_obs = load_atari_obs(
    f"{root_dataset_folder}/testing/{agent_name}-{env_name}/trajectories"
)[trajectory_slice]

# Build the autoencoder and restore the weights saved in the training folder.
autoencoder_obs_model = AtariAutoencoder()
autoencoder_obs_model.load_weights(
    f"{root_dataset_folder}/training/{agent_name}-{env_name}/autoencoder-obs-model"
)
# autoencoder_policy_model = AtariAutoencoder()
# autoencoder_policy_model.load_weights(
#     f"{root_dataset_folder}/training/{agent_name}-{env_name}/autoencoder-policy-model"
# )

# Pass the observations through the autoencoder; the output is required to
# have exactly the same shape as the input (checked by the assert below).
obs_reconstructed_obs = autoencoder_obs_model(trajectory_obs)
# policy_reconstructed_obs = autoencoder_policy_model(trajectory_obs)
assert (
    trajectory_obs.shape
    == obs_reconstructed_obs.shape
    # == policy_reconstructed_obs.shape
)

# Two side-by-side panels: the true observation and its reconstruction.
# fig, axs = plt.subplots(ncols=3, figsize=(4, 8))
fig, axs = plt.subplots(ncols=2, figsize=(4, 8))
axs[0].set_title("True Obs")
axs[1].set_title("Reconstructed Obs Error")
# axs[2].set_title("Reconstructed Policy Error")

# Draw time step 0 (last index of the final axis) and keep the returned image
# artists so that `make_frame` can update them in place for each frame.
# NOTE(review): the last index is presumably the most recent frame of a frame
# stack - confirm against `load_atari_obs`.
true_plot = axs[0].imshow(trajectory_obs[0, :, :, -1], cmap="gray")
obs_plot = axs[1].imshow(obs_reconstructed_obs[0, :, :, -1], cmap="gray")
# policy_plot = axs[2].imshow(policy_reconstructed_obs[0, :, :, -1], cmap='gray')

def make_frame(t):
    """Update both image panels to show time step `t` and return their artists.

    Reads the module-level image artists (`true_plot`, `obs_plot`) and the
    arrays they display (`trajectory_obs`, `obs_reconstructed_obs`).
    """
    panels = [true_plot, obs_plot]  # , policy_plot
    sources = [trajectory_obs, obs_reconstructed_obs]  # , policy_reconstructed_obs

    for panel, frames in zip(panels, sources):
        # Same indexing as the initial `imshow` calls: last index of the final axis.
        panel.set_data(frames[t, :, :, -1])

    return panels


# One animation frame per loaded observation, calling `make_frame` for each
# frame index, with an interval of 100 (milliseconds in matplotlib) between frames.
animation_fn = FuncAnimation(
    fig, make_frame, frames=trajectory_obs.shape[0], interval=100
)

# Render the animation to an HTML5 video and wrap it in an HTML widget; as the
# last expression of the cell, this is what the notebook displays.
HTML(animation_fn.to_html5_video())

## Visualise embedding

In [None]:
# Training folder of the agent/environment pair selected earlier in the notebook.
dataset_path = f"{root_dataset_folder}/training/{agent_name}-{env_name}"

# Load the saved t-SNE file for the observation autoencoder, plus the state
# values and trajectories that are handed to the plot alongside it.
# NOTE(review): how `plot_embedding` uses the last two arguments is not
# visible here - confirm in `temporal_explanations_4_rl.plot`.
autoencoder_obs_tsne = load_embedding(f"{dataset_path}/autoencoder-obs-tsne.npz")
state_values = load_state_values(f"{dataset_path}/trajectories")
trajectories = load_trajectories(f"{dataset_path}/trajectories")

plot_embedding(
    autoencoder_obs_tsne,
    state_values,
    trajectories,
    title="t-SNE visualisation of the autoencoder obs-error embedding",
)
# Save the current figure into the same training folder the data came from.
plt.savefig(f"{dataset_path}/autoencoder-obs-visualise.png")

In [None]:
# if not os.path.exists(f"{dataset_path}/autoencoder-policy-tsne.npz"):
#     autoencoder_policy_embedding = load_embedding(
#         f"{dataset_path}/autoencoder-policy-embedding.npz"
#     )
#     run_tsne(
#         autoencoder_policy_embedding,
#         save_path=f"{dataset_path}/autoencoder-policy-tsne.npz",
#     )
#
# autoencoder_policy_tsne = load_embedding(f"{dataset_path}/autoencoder-policy-tsne.npz")

In [None]:
# plot_embedding(
#     autoencoder_policy_tsne,
#     state_values,
#     trajectories,
#     title="t-SNE visualisation of the autoencoder policy-error embedding",
# )
# plt.savefig(f"{dataset_path}/autoencoder-policy-visualise.png")