# Explanations of an RL agent with Baseline Temporal Explanations, Graying the Black Box, Perturbation Saliency maps and Grad-CAM saliency maps

In [None]:
# IPython magics: load the autoreload extension and enable mode 2
%load_ext autoreload
%autoreload 2

In [None]:
import os
from temporal_explanations_4_rl.agent_networks import load_dopamine_dqn_flax_model

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import tensorflow as tf

from temporal_explanations_4_rl.plan import Plan

from temporal_explanations_4_rl.dataset import (
    load_atari_obs,
    load_state_values,
    load_trajectories,
    load_discrete_actions,
    load_q_values,
)
from temporal_explanations_4_rl.explain import (
    generate_dataset_explanation,
    generate_skill_explanation,
)
from temporal_explanations_4_rl.graying_the_black_box import (
    load_network_features,
    pca_reduce_features,
    run_tsne,
    SpatioTemporalKMeans,
)
from temporal_explanations_4_rl.plot import animate_observations
from temporal_explanations_4_rl.explain import (
    generate_atari_perturbation_saliency_explanation,
    generate_atari_grad_cam_explanation,
    atari_greyscale_saliency_map,
)
from temporal_explanations_4_rl.skill import (
    skill_labels_to_trajectory_skills,
)
from temporal_explanations_4_rl.utils import (
    load_embedding,
)

## Hyperparameters

In [None]:
# Agent and environment names; combined below to form the dataset folder name
agent_name = "dqn_adam_mse"
env_name = "Breakout"

# Folder that the agent network is loaded from
network_root_folder = "../models/dopamine/jax"
# Folder holding the training dataset for this agent / environment pair
# NOTE(review): the testing dataset folder defined later in this notebook is
#   rooted at "../../datasets" rather than "../datasets" - confirm which
#   relative root is intended
dataset_folder = f"../datasets/{agent_name}-{env_name}"

In [None]:
# Load the agent's model definition and parameters for `env_name` from
#   `network_root_folder`; the pair is used later in the notebook through
#   `agent_model_def.apply(agent_model_params, ...)` and the saliency helpers
agent_model_def, agent_model_params = load_dopamine_dqn_flax_model(
    env_name, network_root_folder
)

In [None]:
# All training data (observations, state values, trajectory information)
#   is read from the same trajectories folder
training_trajectories_folder = f"{dataset_folder}/trajectories"

training_dataset_obs = load_atari_obs(training_trajectories_folder)
training_dataset_state_values = load_state_values(training_trajectories_folder)
training_dataset_trajectories = load_trajectories(training_trajectories_folder)

## Generate observations to explain

In [None]:
# Separate dataset that the observations to explain are selected from
testing_dataset_folder = f"../../datasets/explanation-testing/{agent_name}-{env_name}"
testing_trajectories_folder = f"{testing_dataset_folder}/trajectories"

# Observations, discrete actions and q-values are all read from the same folder
testing_obs_dataset = load_atari_obs(testing_trajectories_folder)
testing_action_dataset = load_discrete_actions(testing_trajectories_folder)
testing_q_values_dataset = load_q_values(testing_trajectories_folder)

# Bare expression so the notebook displays the dataset shapes
f"{testing_obs_dataset.shape=}, {testing_action_dataset.shape=}, {testing_q_values_dataset.shape=}"

In [None]:
# Load the cached observations to explain if they exist, otherwise select
#   them from the testing dataset and cache them to disk.
# Both branches define `num_explanations`, `obs_length`, `explain_time_steps`,
#   `explain_obs` and `explain_actions` as later cells read all of them.
explanation_obs_filename = f"{testing_dataset_folder}/explanation-obs.npz"

if os.path.exists(explanation_obs_filename):
    # allow_pickle is required as each "explain-{pos}" entry is a saved dict.
    # NOTE(review): unpickling is only safe for files generated by this notebook
    with np.load(explanation_obs_filename, allow_pickle=True) as file:
        num_explanations = int(file["num_explanations"])
        obs_length = int(file["obs_length"])
        # Read each entry a single time rather than once per field
        saved_explanations = [
            file[f"explain-{pos}"].item() for pos in range(num_explanations)
        ]

    explain_time_steps = np.array(
        [explanation["time_step"] for explanation in saved_explanations], dtype=int
    )
    explain_obs = np.array(
        [explanation["obs"] for explanation in saved_explanations], dtype=np.uint8
    )
    explain_actions = np.array(
        [explanation["action"] for explanation in saved_explanations], dtype=int
    )
else:
    num_explanations = 5
    padding = 20
    obs_length = 20

    # Spread of the q-values at each time step, used to rank the time steps
    # assumes the q-values are (time step, action) and the action dataset
    #   has one entry per time step - TODO confirm
    q_values_std = np.std(testing_q_values_dataset, axis=1)

    explain_time_steps = np.zeros(num_explanations, dtype=int)
    # Non-zero mask entries mark the time steps that can still be selected
    mask = np.ones_like(testing_action_dataset)
    for pos in range(num_explanations):
        # Select the remaining time step with the largest q-value spread
        time_step = np.argmax(np.where(mask, q_values_std, mask))
        explain_time_steps[pos] = time_step
        # Exclude the neighbouring time steps so the selections do not overlap
        mask[max(0, time_step - padding) : min(time_step + padding, len(mask))] = 0

    explain_obs = testing_obs_dataset[explain_time_steps]
    # Previously only defined when loading from the cache, which caused a
    #   NameError in the Grad-CAM cell on the first run of the notebook
    explain_actions = np.asarray(testing_action_dataset[explain_time_steps], dtype=int)

    np.savez_compressed(
        explanation_obs_filename,
        num_explanations=num_explanations,
        padding=padding,
        obs_length=obs_length,
        **{
            f"explain-{pos}": {
                "time_step": time_step,
                "obs": testing_obs_dataset[time_step],
                "action": testing_action_dataset[time_step],
            }
            for pos, time_step in enumerate(explain_time_steps)
        },
    )

# Bare expression so the notebook displays the result
f"{explain_obs.shape=}, {num_explanations=}, {obs_length=}"

In [None]:
# Save a video of the observations leading up to each explanation time step
explanations_folder = f"{testing_dataset_folder}/explanations"
if not os.path.exists(explanations_folder):
    os.mkdir(explanations_folder)

for pos, time_step in enumerate(explain_time_steps):
    video_filename = f"{explanations_folder}/explain-{pos}.mp4"
    # Videos that already exist are not regenerated
    if os.path.exists(video_filename):
        continue

    print(f"Animating explanation {pos=} at {time_step=}")
    # Up to `obs_length` observations before (and excluding) the time step
    window_start = max(0, time_step - obs_length)
    _, _, animation = animate_observations(
        testing_obs_dataset[window_start:time_step], return_html_animation=False
    )
    animation.save(video_filename)
    plt.close()

## Agents have a cunning plan

In [None]:
# Load the saved environment autoencoder and policy similarity models.
# `tf.keras.models.load_model` is the Keras loading function; the previous
#   `load_models` spelling does not exist and raised an AttributeError.
# NOTE(review): later cells call `env_autoencoder.encode(...)`, so the saved
#   model presumably is a custom class providing `encode` - confirm that it
#   can be restored by `load_model` (e.g. whether `custom_objects` is needed)
env_autoencoder = tf.keras.models.load_model(
    f"{dataset_folder}/models/autoencoder.params"
)
policy_similarity_model = tf.keras.models.load_model(
    f"{dataset_folder}/models/policy-similarity.params"
)

# Embedding of the training dataset used by the dataset and skill explanations
pse_embedding = load_embedding(f"{dataset_folder}/embedding/pse-dense.npz")

### Dataset similarity explanation

In [None]:
# Generate and save the dataset similarity explanation videos for each
#   observation to explain
dataset_similarity_folder = f"{testing_dataset_folder}/cunning-dataset-similarity"
if not os.path.exists(dataset_similarity_folder):
    os.mkdir(dataset_similarity_folder)

for pos, obs in enumerate(explain_obs):
    print(f"{pos=}")
    obs_embedding = env_autoencoder.encode(obs)
    # Only the last index of the final axis of the training obs is passed on
    dataset_explanation_obs = generate_dataset_explanation(
        obs_embedding, pse_embedding, training_dataset_obs[..., -1]
    )

    # One video per returned explanation
    for num, explanation in enumerate(tqdm(dataset_explanation_obs)):
        video_filename = f"{dataset_similarity_folder}/explain-{pos}-version-{num}.mp4"
        _, _, animation = animate_observations(explanation, return_html_animation=False)
        animation.save(video_filename)
        plt.close()

### Skill explanations

In [None]:
# Generate and save the skill explanation videos for each observation to explain
if not os.path.exists(f"{testing_dataset_folder}/cunning-skill-explain"):
    os.mkdir(f"{testing_dataset_folder}/cunning-skill-explain")

# NOTE(review): both of these are `None` placeholders, as written the loop
#   below raises an AttributeError on `pse_clustering.label(...)`.
#   They need replacing with a fitted clustering model and the skill
#   transitions before this cell can run.
pse_clustering = None
pse_skill_transitions = None

for pos, obs in enumerate(tqdm(explain_obs)):
    # Embed the observation and find the skill that the embedding belongs to
    obs_embedding = env_autoencoder.encode(obs)
    obs_skill = pse_clustering.label(obs_embedding)
    # Only the last index of the final axis of the training obs is passed on
    skill_explanation_obs = generate_skill_explanation(
        obs_embedding,
        obs_skill,
        pse_embedding,
        pse_skill_transitions,
        training_dataset_obs[..., -1],
    )

    # One video per returned explanation
    for num, explanation in enumerate(skill_explanation_obs):
        _, _, animation = animate_observations(explanation, return_html_animation=False)
        animation.save(
            f"{testing_dataset_folder}/cunning-skill-explain/explain-{pos}-version-{num}.mp4"
        )
        plt.close()

## Graying the black box explanation

### Retrain graying the black box with explanation obs

In [None]:
# Generate (if not cached) and load the graying the black box embedding of
#   the training dataset with the observations to explain appended at the end.
# Later cells rely on the last `num_explanations` rows of the embedding being
#   the explanation obs.
# NOTE(review): an embedding cached before the explanation features were
#   correctly appended will not contain those rows - delete the file to
#   regenerate it
graying_embedding_filename = f"{testing_dataset_folder}/graying-embedding.npz"

if not os.path.exists(graying_embedding_filename):
    print("Generating graying the black box embedding")
    training_dataset_dense_features = load_network_features(
        f"{dataset_folder}/dense-features.npz"
    )

    explanation_dense_features = []
    for time_step in explain_time_steps:
        # Run the agent on the observation capturing the intermediate values
        _, network_state = agent_model_def.apply(
            agent_model_params, testing_obs_dataset[time_step], mutable="intermediates"
        )
        # One row per explanation obs
        # assumes the dense feature is a single feature vector - TODO confirm
        explanation_dense_features.append(
            np.atleast_2d(np.asarray(network_state["intermediates"]["dense"][0]))
        )

    # `np.append` returns a new array (flattened when no axis is given) and
    #   its result was previously discarded, so the explanation features were
    #   never added. Concatenate the rows explicitly instead.
    # assumes the training features are (num obs, num features) - TODO confirm
    training_dataset_dense_features = np.concatenate(
        [training_dataset_dense_features, *explanation_dense_features], axis=0
    )

    graying_reduced_features, _, _ = pca_reduce_features(
        training_dataset_dense_features
    )
    print(f"Graying reduced features: {graying_reduced_features.shape}")
    run_tsne(
        graying_reduced_features,
        graying_embedding_filename,
        n_iter=1000,
    )

graying_embedding = load_embedding(graying_embedding_filename)

In [None]:
# Cluster the training part of the graying embedding into skills
num_clusters = 15
window_size = 10
graying_clustering_model = SpatioTemporalKMeans(
    n_clusters=num_clusters, window_size=window_size
)

# The trailing `num_explanations` rows (the explanation obs) are left out
#   of the fit as their trajectory information is unknown
training_graying_embedding = graying_embedding[:-num_explanations]
# NOTE(review): the value returned by `fit` is used as the cluster labels -
#   confirm that `fit` returns labels rather than the fitted model
clustering_labels = graying_clustering_model.fit(
    training_graying_embedding,
    training_dataset_state_values,
    training_dataset_trajectories,
)

# Convert the labels to skills using the training trajectory information
skill_trajectories = skill_labels_to_trajectory_skills(
    clustering_labels, num_clusters, training_dataset_trajectories
)

In [None]:
# Generate and save the graying the black box skill explanation videos for
#   each observation to explain
graying_skill_folder = f"{testing_dataset_folder}/graying-skill-explain"
if not os.path.exists(graying_skill_folder):
    os.mkdir(graying_skill_folder)

for pos in tqdm(range(num_explanations)):
    # The explanation obs are the last `num_explanations` rows of the embedding
    explain_obs_embedding = graying_embedding[-num_explanations + pos]
    # Assign the obs to the nearest cluster centre, i.e. the smallest distance
    #   (`argmax` previously selected the centre furthest away)
    explain_obs_skill = np.argmin(
        np.linalg.norm(
            graying_clustering_model.cluster_centers - explain_obs_embedding, axis=-1
        )
    )
    plan = Plan(skill_trajectories)

    # Only the training rows of the embedding are used for the explanation
    skill_explanation_obs = generate_skill_explanation(
        explain_obs_embedding,
        explain_obs_skill,
        graying_embedding[:-num_explanations],
        plan,
        training_dataset_obs[..., -1],
    )

    # One video per returned explanation, saved to the graying folder
    #   (previously written to `cunning-skill-explain`, overwriting the
    #   videos of the earlier skill explanation section)
    for num, explanation in enumerate(skill_explanation_obs):
        _, _, animation = animate_observations(explanation, return_html_animation=False)
        animation.save(f"{graying_skill_folder}/explain-{pos}-version-{num}.mp4")
        plt.close()

## Perturbation-based saliency map

In [None]:
# Generate and save a perturbation-based saliency map image for each
#   observation to explain
perturbation_saliency_folder = f"{testing_dataset_folder}/perturbation-saliency"
if not os.path.exists(perturbation_saliency_folder):
    os.mkdir(perturbation_saliency_folder)

for pos, agent_obs in enumerate(tqdm(explain_obs)):
    # The helpers are given the observation with an extra leading axis
    batched_agent_obs = np.expand_dims(agent_obs, axis=0)
    saliency_map, saliency_values = generate_atari_perturbation_saliency_explanation(
        batched_agent_obs, agent_model_def, agent_model_params
    )

    # Plot the greyscale saliency map and save the figure
    greyscale_saliency = atari_greyscale_saliency_map(batched_agent_obs, saliency_map)
    fig, ax = plt.subplots()
    ax.imshow(greyscale_saliency, cmap="gray")
    plt.savefig(f"{perturbation_saliency_folder}/explain-{pos}.png")
    plt.close()

## Grad-CAM saliency maps

In [None]:
# Generate and save a Grad-CAM saliency map image for each observation to
#   explain, using the action stored for that observation
grad_cam_folder = f"{testing_dataset_folder}/grad-cam-saliency"
if not os.path.exists(grad_cam_folder):
    os.mkdir(grad_cam_folder)

for pos, obs in enumerate(tqdm(explain_obs)):
    explain_action = explain_actions[pos]
    saliency_map, _ = generate_atari_grad_cam_explanation(
        agent_model_def, agent_model_params, obs, explain_action
    )

    # Plot the saliency map and save the figure
    greyscale_saliency = atari_greyscale_saliency_map(obs, saliency_map)
    fig, ax = plt.subplots()
    ax.imshow(greyscale_saliency)
    plt.savefig(f"{grad_cam_folder}/explain-{pos}.png")
    plt.close()