## Imports

In [38]:
# Super Mario Bros env dependencies
from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT, RIGHT_ONLY
import gym
from gym.spaces import Box
from gym.wrappers import FrameStack, GrayScaleObservation, TransformObservation

# Torch
import torch
import torch.nn as nn

# Stable Baselines
from stable_baselines3.common.vec_env import DummyVecEnv

# Networks to Evaluate
import sys, os
def add_to_path(model_dir):
    """Append a sibling project directory to ``sys.path`` so its modules can be imported.

    The directory is resolved as ``<notebook dir>/../<model_dir>``. Because the
    notebook name is given without a directory component, its dirname is the
    empty string and the path is effectively resolved against the current
    working directory.

    Args:
        model_dir: directory (relative to the parent of the working directory)
            to make importable.
    """
    base_dir = os.path.dirname("CNN_Feature_Visualisation.ipynb")
    candidate = os.path.join(base_dir, os.path.pardir, model_dir)
    resolved = os.path.normpath(os.path.abspath(candidate))
    # Already registered: nothing to do (keeps sys.path free of duplicates).
    if resolved in sys.path:
        return
    sys.path.append(resolved)
add_to_path('DQN')
from Agent import MarioNet, Mario
add_to_path('a2c/a2c')
from model import ACNetwork

# CNN Visualisation (Lucent)
from lucent.optvis import render, param, transform, objectives
from lucent.misc.io import show

# Utilities
add_to_path('a2c/utils')
from wrappers import ResizeObservation, SkipFrame
from PIL import Image
import numpy as np
from matplotlib import pyplot as plt
from pathlib import Path

## Model and Env Prep

In [8]:
def create_env(random = False, record = None):
    """Build a Super Mario Bros environment, optionally wrapped in a video recorder.

    Args:
        random: when truthy, the "RandomStage" variant of the environment id is used.
        record: falsy for no recording, otherwise a 3-tuple
            ``(out_dir, max_length, model_name)`` giving the (already existing)
            output directory, the video length passed to ``VecVideoRecorder``
            and the file-name prefix of the recorded video.

    Returns:
        ``(env, env_name)``: the environment (already reset once) and the id
        string it was created from.

    Raises:
        AssertionError: if recording is requested and ``out_dir`` is not an
            existing directory.
    """
    env_name = "SuperMarioBros"
    if random:
        env_name += "RandomStage"
    env_name += '-v3'
    env = gym_super_mario_bros.make(env_name)
    env = JoypadSpace(env, SIMPLE_MOVEMENT)
    if record:
        print("Setting up recorder")
        out_dir, max_length, model_name = record
        assert os.path.isdir(out_dir), f"Video output directory does not exist: {out_dir}"
        from stable_baselines3.common.vec_env import VecVideoRecorder
        # Bind the unwrapped env to its own name: a `lambda: env` closure would
        # see whatever `env` is rebound to later in this function.
        base_env = env
        vec_env = DummyVecEnv([lambda: base_env])
        # The trigger must be truthy for recording to start; the previous
        # `lambda _: 0` was always falsy, so no video was ever recorded.
        # Start recording at the first step instead.
        env = VecVideoRecorder(vec_env, video_folder=out_dir, record_video_trigger=lambda step: step == 0, video_length=max_length, name_prefix=f"{model_name}")
    env.reset()
    return env, env_name



## Setup Feature Visualisation Functions

In [9]:
# Code derived from: https://colab.research.google.com/github/greentfrapp/lucent-notebooks/blob/master/notebooks/feature_inversion.ipynb#scrollTo=d47pkOPKvNjs
@objectives.wrap_objective()
def dot_compare(layer, batch=1, cossim_pow=0):
    """Objective comparing the activations of two batch entries at ``layer``.

    The returned value is ``-dot * cossim ** cossim_pow`` where ``dot`` is the
    dot product between the activations of batch entry ``batch`` and batch
    entry 0, and ``cossim`` is that dot product divided by the L2 norm of
    entry 0's activations (plus a small epsilon).

    Args:
        layer: layer identifier handed to the hook accessor ``T``.
        batch: index of the batch entry compared against entry 0.
        cossim_pow: exponent applied to the normalised term.
    """
    def inner(T):
        activations = T(layer)
        reference = activations[0]
        dot = (activations[batch] * reference).sum()
        reference_norm = torch.sqrt(torch.sum(reference**2))
        # Epsilon guards against division by zero for all-zero activations.
        cossim = dot / (1e-6 + reference_norm)
        weight = cossim ** cossim_pow
        return -dot * weight
    return inner

# Augmentations applied to the parameterised image on every optimisation step.
transforms = [
    transform.pad(8, mode='constant', constant_value=.5),
    transform.jitter(8),
    # Four scaling factors plus four entries of 1 (no scaling).
    transform.random_scale([0.9, 0.95, 1.05, 1.1, 1, 1, 1, 1]),
    # Angles -5..4 plus five extra entries of 0 (no rotation).
    transform.random_rotate([-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 0, 0, 0, 0, 0]),
    transform.jitter(2),
]

def get_param_f(img, device=None, param=None):
    """Build a lucent ``param_f`` that stacks a trainable image with a fixed target.

    The trainable image is created once, when this function is called; the
    returned callable hands the same parameters to the renderer each time.

    Args:
        img: target image as an array indexed (height, width, channels).
        device: torch device for the target tensor. Defaults to CUDA when
            available, otherwise CPU. Optional because the callers in this
            notebook pass only ``img``.
        param: object exposing ``image(size, channels=...)`` that returns
            ``(params, image_f)``. Defaults to ``lucent.optvis.param``.

    Returns:
        A zero-argument callable returning ``(params, image_f)`` where
        ``image_f()`` is a batch of two: index 0 is the trainable image and
        index 1 is the target ``img``.
    """
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if param is None:
        # Imported here because the `param` argument shadows the module-level name.
        from lucent.optvis import param as lucent_param
        param = lucent_param

    # Channels-last -> channels-first, as expected by the network input.
    img = torch.tensor(np.transpose(img, [2, 0, 1])).to(device)
    # Initialize parameterized input and stack with target image
    # to be accessed in the objective function
    params, image_f = param.image(img.shape[1], channels=img.shape[0])

    def stacked_param_f():
        return params, lambda: torch.stack([image_f()[0], img])

    return stacked_param_f

def feature_inversion(img, layer, model, n_steps=512, cossim_pow=0.0):
    """Optimise an image whose activations at ``layer`` match those of ``img``.

    Args:
        img: target image as an array indexed (height, width, channels).
        layer: layer identifier understood by lucent for ``model``.
        model: network to visualise.
        n_steps: number of optimisation steps (used as the single threshold).
        cossim_pow: exponent forwarded to ``dot_compare``.

    Returns:
        The list returned by ``render.render_vis`` (one entry per threshold).
    """
    obj = objectives.Objective.sum([
        1.0 * dot_compare(layer, cossim_pow=cossim_pow),
        objectives.blur_input_each_step(),
    ])

    # get_param_f needs a device and the lucent `param` module; the previous
    # one-argument call raised a TypeError. Use the same CUDA-if-available
    # choice as the rest of the notebook.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    param_f = get_param_f(img, device, param)
    images = render.render_vis(model, obj, param_f, transforms=transforms, preprocess=False, thresholds=(n_steps,), show_image=False)
    return images

def visualise_cnn_layers(src_img_path, model, convergence_steps, out_img_path):
    """Run feature inversion on layers '0', '2' and '4' and save a one-row figure.

    Args:
        src_img_path: path of the image to invert.
        model: network passed through to ``feature_inversion``.
        convergence_steps: optimisation steps per layer.
        out_img_path: output path without extension; ``.png`` is appended.
    """
    source = np.array(Image.open(src_img_path), np.float32)
    if source.ndim == 2:
        # Single-channel image: add an explicit trailing channel axis.
        height, width = source.shape
        source = source.reshape((height, width, 1))

    # Extract Conv Layers
    renders = []
    for layer in ('0', '2', '4'):
        print(layer)
        renders.extend(feature_inversion(source, layer, model, n_steps=convergence_steps))
        print()
        print([len(img) for img in renders])

    # First panel is batch entry 1 of the first render (presumably the target
    # image stacked by get_param_f); the rest are batch entry 0 of every render.
    panels = [renders[0][1]]
    panels.extend(rendered[0] for rendered in renders)

    _, axs = plt.subplots(1, len(panels))
    for ax, panel in zip(axs, panels):
        ax.imshow(panel)
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
    axs[0].set_title("Input")
    plt.savefig(f"{out_img_path}.png")

def visualise_cnn_layer_neurons(src_img_path, model, convergence_steps, out_img_path, layers=None):
    """Render individual neurons/channels of selected layers and save one figure per layer.

    Args:
        src_img_path: path of the source image stacked next to the optimised image.
        model: network to visualise.
        convergence_steps: optimisation steps per neuron (single threshold).
        out_img_path: output path prefix; each layer is saved to
            ``{out_img_path}_layer-{layer_id}.png``.
        layers: mapping ``layer_id -> iterable of neuron indices``. Each pair is
            rendered with the lucent objective string ``"{layer_id}:{neuron}"``.

    Raises:
        ValueError: if ``layers`` is missing or empty. The previous hard-coded
            empty mapping made the function fail with an IndexError.
    """
    if not layers:
        raise ValueError("layers must map at least one layer id to the neuron indices to visualise")

    image = np.array(Image.open(src_img_path), np.float32)
    if image.ndim == 2:
        # Single-channel image: add an explicit trailing channel axis.
        image = image.reshape((image.shape[0], image.shape[1], 1))

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    for layer_id, neurons in layers.items():
        neurons = list(neurons)
        if not neurons:
            continue

        input_panel = None
        neuron_panels = []
        for neuron in neurons:
            obj = f"{layer_id}:{neuron}"
            print(obj)
            # Fresh parameters per neuron: get_param_f creates the trainable
            # image when called, so sharing one param_f would continue
            # optimising the previous neuron's result.
            param_f = get_param_f(image, device, param)
            rendered = render.render_vis(model, obj, param_f, transforms=transforms, preprocess=False, thresholds=(convergence_steps,), show_image=False)
            final = rendered[-1]  # only one threshold is requested
            if input_panel is None:
                input_panel = final[1]  # batch entry 1: stacked source image
            neuron_panels.append(final[0])  # batch entry 0: optimised image

        panels = [input_panel] + neuron_panels
        titles = ["Input"] + [f"{layer_id}:{neuron}" for neuron in neurons]
        # squeeze=False keeps `axs` two-dimensional even for a single panel.
        fig, axs = plt.subplots(1, len(panels), squeeze=False)
        for ax, panel, title in zip(axs[0], panels, titles):
            ax.imshow(panel)
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)
            ax.set_title(title)
        fig.savefig(f"{out_img_path}_layer-{layer_id}.png")

## A2C


In [36]:
# Recorded environment plus the observation preprocessing wrappers.
env, env_name = create_env(record=('out/video', 1000, 'A2C'))
env = SkipFrame(env, skip=4)
env = ResizeObservation(env, shape=84) # image dim: [84, 84]
env = GrayScaleObservation(env, keep_dim=False) # Grayscale images
env = FrameStack(env, num_stack=4) # 4 frames at a time
obs = (4, 84, 84)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = ACNetwork(obs, env.action_space.n)
checkpoint = torch.load('checkpoints/a2c/a2c_rollout10_ep100k.pt', map_location=device)
model.load_state_dict(checkpoint['model'])
print(model)
# map_location only places the checkpoint tensors; the module itself must be
# moved explicitly so it sits on the same device as the visualisation inputs.
model = model.to(device).eval()

Setting up recorder
ACNetwork(
  (conv): Sequential(
    (0): Conv2d(4, 32, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (6): ReLU()
    (7): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (actor): Sequential(
    (0): Linear(in_features=1024, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=7, bias=True)
  )
  (critic): Sequential(
    (0): Linear(in_features=1024, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=1, bias=True)
  )
)


## DQN

In [40]:
# Recorded environment plus the observation preprocessing wrappers for the DQN agent.
env, env_name = create_env(record=('out/video', 1000, 'DQN'))
# NOTE(review): create_env already wraps the env in JoypadSpace(SIMPLE_MOVEMENT)
# (and in a video recorder when `record` is set), so this second JoypadSpace is
# stacked on top of an already-wrapped env — confirm the action mapping is valid.
env = JoypadSpace(
    env,
    [['right'],
    ['right', 'A']]
)
env = SkipFrame(env, skip=4)
env = ResizeObservation(env, shape=84) # image dim: [84, 84]
env = GrayScaleObservation(env, keep_dim=False) # Grayscale images
env = FrameStack(env, num_stack=4) # 4 frames at a time
obs = (4, 84, 84)

# No torch device is selected in this cell; the agent is built with its own defaults.
model = Mario(state_dim=obs, action_dim=env.action_space.n, save_dir=".")
path = Path('checkpoints') / 'dqn/mario_net_12.chkpt'
# FIXME: this load currently raises RuntimeError. The checkpoint stores keys
# under "online.*" / "target.*", while the imported MarioNet expects
# "online_features.*", "online_td_est.*", "online_aux.*", "target_features.*"
# and "target_td_est.*". The checkpoint and the network definition in Agent
# must be brought back in sync before the lines below can run.
model.load(path)
# Keep only the underlying network for visualisation.
model = model.net
print(model)
model.eval()

# TODO: once the checkpoint loads, visualise the network (e.g. visualise_cnn_layers).

Setting up recorder
Loading model at checkpoints/dqn/mario_net_12.chkpt with exploration rate 0.11943293650685695


RuntimeError: Error(s) in loading state_dict for MarioNet:
	Missing key(s) in state_dict: "online_features.0.weight", "online_features.0.bias", "online_features.2.weight", "online_features.2.bias", "online_features.4.weight", "online_features.4.bias", "online_features.6.weight", "online_features.6.bias", "online_td_est.0.weight", "online_td_est.0.bias", "online_td_est.2.weight", "online_td_est.2.bias", "online_aux.0.weight", "online_aux.0.bias", "online_aux.2.weight", "online_aux.2.bias", "target_features.0.weight", "target_features.0.bias", "target_features.2.weight", "target_features.2.bias", "target_features.4.weight", "target_features.4.bias", "target_features.6.weight", "target_features.6.bias", "target_td_est.0.weight", "target_td_est.0.bias", "target_td_est.2.weight", "target_td_est.2.bias". 
	Unexpected key(s) in state_dict: "online.0.weight", "online.0.bias", "online.2.weight", "online.2.bias", "online.5.weight", "online.5.bias", "online.7.weight", "online.7.bias", "online.10.weight", "online.10.bias", "online.12.weight", "online.12.bias", "target.0.weight", "target.0.bias", "target.2.weight", "target.2.bias", "target.5.weight", "target.5.bias", "target.7.weight", "target.7.bias", "target.10.weight", "target.10.bias", "target.12.weight", "target.12.bias". 