## Imports

In [1]:
# Super Mario Bros env dependencies
from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT, RIGHT_ONLY
import gym

# Torch
import torch
import torch.nn as nn

# Stable Baselines
from stable_baselines3.common.vec_env import DummyVecEnv

# Networks to Evaluate
from ..DQN.Agent import MarioNet
from ..a2c.a2c.model import ACNetwork

# CNN Visualisation (Lucent)
from lucent.optvis import render, param, transform, objectives
from lucent.misc.io import show

# Utilities
from PIL import Image
import numpy as np
from matplotlib import pyplot as plt
import os

def create_env(random=False, record=None):
    """Build a Super Mario Bros env restricted to the ``SIMPLE_MOVEMENT`` action set.

    Args:
        random: if truthy, use ``SuperMarioBrosRandomStage-v3`` instead of
            ``SuperMarioBros-v3``.
        record: optional ``(out_dir, max_length, model_name)`` tuple. When given, the env is
            wrapped in a ``DummyVecEnv`` + ``VecVideoRecorder`` that writes a video of at most
            ``max_length`` steps into ``out_dir`` (which must already exist), with
            ``model_name`` as the file-name prefix.

    Returns:
        ``(env, env_name)``. Note that ``env`` is a vectorised env when ``record`` is set.

    Raises:
        AssertionError: if ``record`` is given and ``out_dir`` is not an existing directory.
    """
    env_name = "SuperMarioBros" + ("RandomStage" if random else "") + "-v3"
    env = JoypadSpace(gym_super_mario_bros.make(env_name), SIMPLE_MOVEMENT)
    if record:
        print("Setting up recorder")
        out_dir, max_length, model_name = record
        assert os.path.isdir(out_dir), f"video output directory does not exist: {out_dir}"
        from stable_baselines3.common.vec_env import VecVideoRecorder

        # Bind the base env to its own name so the factory does not late-bind to `env`,
        # which is rebound to the vectorised wrapper below.
        base_env = env
        env = DummyVecEnv([lambda: base_env])
        env = VecVideoRecorder(
            env,
            video_folder=out_dir,
            # Start recording at step 0. An always-falsy trigger never starts a recording on
            # SB3 versions whose reset() consults the trigger.
            record_video_trigger=lambda step: step == 0,
            video_length=max_length,
            name_prefix=f"{model_name}",
        )
    return env, env_name

  from .autonotebook import tqdm as notebook_tqdm


Can torch see a GPU via cuda? Yes


## Model and Env Prep

In [None]:
def create_env(random=False, record=None):
    """Build a Super Mario Bros env restricted to the ``SIMPLE_MOVEMENT`` action set.

    Args:
        random: if truthy, use ``SuperMarioBrosRandomStage-v3`` instead of
            ``SuperMarioBros-v3``.
        record: optional ``(out_dir, max_length, model_name)`` tuple. When given, the env is
            wrapped in a ``DummyVecEnv`` + ``VecVideoRecorder`` that writes a video of at most
            ``max_length`` steps into ``out_dir`` (which must already exist), with
            ``model_name`` as the file-name prefix.

    Returns:
        ``(env, env_name)``. Note that ``env`` is a vectorised env when ``record`` is set.

    Raises:
        AssertionError: if ``record`` is given and ``out_dir`` is not an existing directory.
    """
    # NOTE(review): this re-defines create_env from the imports cell; keep the two in sync
    # (or delete one) so the later definition does not silently diverge.
    env_name = "SuperMarioBros" + ("RandomStage" if random else "") + "-v3"
    env = JoypadSpace(gym_super_mario_bros.make(env_name), SIMPLE_MOVEMENT)
    if record:
        print("Setting up recorder")
        out_dir, max_length, model_name = record
        assert os.path.isdir(out_dir), f"video output directory does not exist: {out_dir}"
        from stable_baselines3.common.vec_env import VecVideoRecorder

        # Bind the base env to its own name so the factory does not late-bind to `env`,
        # which is rebound to the vectorised wrapper below.
        base_env = env
        env = DummyVecEnv([lambda: base_env])
        env = VecVideoRecorder(
            env,
            video_folder=out_dir,
            # Start recording at step 0. An always-falsy trigger never starts a recording on
            # SB3 versions whose reset() consults the trigger.
            record_video_trigger=lambda step: step == 0,
            video_length=max_length,
            name_prefix=f"{model_name}",
        )
    return env, env_name



## Setup Feature Visualisation Functions

In [None]:
# Code derived from: https://colab.research.google.com/github/greentfrapp/lucent-notebooks/blob/master/notebooks/feature_inversion.ipynb#scrollTo=d47pkOPKvNjs
@objectives.wrap_objective()
def dot_compare(layer, batch=1, cossim_pow=0):
    """Lucent objective pulling batch item 0's activations toward item ``batch``'s.

    Designed for the stacked batch built by ``get_param_f``: index 0 is the image being
    optimised and index ``batch`` (default 1) is the fixed target image.

    The loss is ``-dot * cossim ** cossim_pow``, where ``dot`` is the summed elementwise
    product of the two activation tensors at ``layer``. ``cossim`` is ``dot`` divided by the
    L2 norm of the optimised image's activations only, so it is not a true cosine
    similarity. With the default ``cossim_pow=0`` the loss is just ``-dot``.
    """
    def inner(T):
        activations = T(layer)
        generated = activations[0]
        target = activations[batch]
        dot = (target * generated).sum()
        magnitude = torch.sqrt(torch.sum(generated ** 2))
        similarity = dot / (1e-6 + magnitude)  # epsilon guards against a zero norm
        return (-dot) * (similarity ** cossim_pow)

    return inner

# Random augmentations applied to the (stacked) input at every optimisation step, as in the
# lucid/lucent feature-inversion recipe: constant-pad with mid-grey, jitter, random rescale,
# small random rotation, then a final fine jitter. The repeated 1s / 0s bias the random
# choice toward "no change". NOTE(review): range(-5, 5) covers -5..4 degrees, not -5..5.
transforms = [
    transform.pad(8, mode="constant", constant_value=0.5),
    transform.jitter(8),
    transform.random_scale([0.9, 0.95, 1.05, 1.1, 1, 1, 1, 1]),
    transform.random_rotate([*range(-5, 5), 0, 0, 0, 0, 0]),
    transform.jitter(2),
]

def get_param_f(img, device=None, param=None):
    """Build a lucent ``param_f`` that stacks a learnable image with a fixed target image.

    The returned callable yields ``(params, image_fn)``. ``image_fn()`` returns a batch of
    two images: index 0 is the image being optimised and index 1 is ``img``. This is the
    ordering ``dot_compare`` expects.

    Args:
        img: target image as an ``(H, W, C)`` array. A 2-D ``(H, W)`` array is treated as a
            single-channel image. Any numeric dtype is converted to float32.
        device: torch device for the target tensor. Defaults to ``cuda:0`` when available,
            otherwise CPU.
        param: object providing ``image(w, h=..., channels=...)``. Defaults to
            ``lucent.optvis.param``. This parameter name shadows the module-level ``param``
            import, which is why the default is resolved here.

    Returns:
        A zero-argument callable suitable as ``param_f`` for ``render.render_vis``.
    """
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if param is None:
        from lucent.optvis import param

    # HWC -> CHW. Use float32 so it can be stacked with lucent's float32 image.
    target = torch.tensor(
        np.transpose(np.atleast_3d(img), (2, 0, 1)), dtype=torch.float32, device=device
    )
    channels, height, width = target.shape
    # param.image takes the width first and the height via `h`. Passing both lets
    # non-square targets stack with the generated image.
    params, image_f = param.image(width, h=height, channels=channels)

    def stacked_param_f():
        return params, lambda: torch.stack([image_f()[0], target])

    return stacked_param_f

def feature_inversion(img, layer, model, n_steps=512, cossim_pow=0.0):
    """Optimise an image whose activations at ``layer`` of ``model`` match those of ``img``.

    The objective is ``dot_compare`` (aligning activations of the optimised image with the
    target's) plus lucent's ``blur_input_each_step`` regulariser. The global ``transforms``
    are applied at each step.

    Args:
        img: target image as an ``(H, W, C)`` array (see ``get_param_f``).
        layer: lucent layer name, e.g. ``"0"`` for the first module of an ``nn.Sequential``.
        model: torch module to probe.
        n_steps: number of optimisation steps (used as the single render threshold).
        cossim_pow: exponent forwarded to ``dot_compare``.

    Returns:
        The list returned by ``render.render_vis``: one batched image per threshold. In each
        batch, index 0 is the reconstruction and index 1 is the target.
    """
    obj = objectives.Objective.sum([
        1.0 * dot_compare(layer, cossim_pow=cossim_pow),
        objectives.blur_input_each_step(),
    ])

    # Place the target image on the same device as the model's weights (CPU if none).
    first_weight = next(model.parameters(), None)
    device = first_weight.device if first_weight is not None else torch.device("cpu")
    # get_param_f needs the device and lucent's `param` module passed explicitly.
    param_f = get_param_f(img, device, param)

    return render.render_vis(
        model,
        obj,
        param_f,
        transforms=transforms,
        preprocess=False,
        thresholds=(n_steps,),
        show_image=False,
    )

def visualise_cnn_layers(src_img_path, model, convergence_steps, out_img_path, layers=("0", "2", "4")):
    """Feature-invert one image at several layers and save the results side by side.

    The figure shows the source image first, then one reconstruction per layer. It is saved
    to ``{out_img_path}.png``.

    Args:
        src_img_path: path to the source image (any format PIL can open).
        model: torch module whose layers are addressed by lucent names. For an
            ``nn.Sequential`` these are the module indices as strings.
        convergence_steps: optimisation steps per layer.
        out_img_path: output path without extension.
        layers: lucent layer names to invert. The default is the previously hard-coded list.

    Raises:
        ValueError: if ``layers`` is empty.
    """
    layers = list(layers)
    if not layers:
        raise ValueError("layers must contain at least one layer name")

    with Image.open(src_img_path) as pil_image:
        image = np.array(pil_image, np.float32)
    # Grayscale images load as (H, W); add an explicit channel axis.
    image = np.atleast_3d(image)
    # NOTE(review): pixel values stay in PIL's native range (0-255 for 8-bit images), while
    # lucent's optimised image is presumably in [0, 1]. Confirm whether the target should be
    # scaled.

    # Final rendered batch per layer: index 0 = reconstruction, index 1 = target image.
    inversions = [
        feature_inversion(image, layer, model, n_steps=convergence_steps)[-1]
        for layer in layers
    ]
    panels = [inversions[0][1]] + [batch[0] for batch in inversions]
    titles = ["Input"] + [f"Layer {layer}" for layer in layers]

    fig, axs = plt.subplots(1, len(panels), squeeze=False)
    for ax, panel, title in zip(axs[0], panels, titles):
        # imshow rejects (H, W, 1) arrays. Drop the singleton channel so the panel is drawn
        # as grayscale. cmap is ignored for RGB/RGBA panels.
        if panel.shape[-1] == 1:
            panel = panel[..., 0]
        ax.imshow(panel, cmap="gray")
        ax.set_title(title)
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
    fig.savefig(f"{out_img_path}.png")

def visualise_cnn_layer_neurons(src_img_path, model, convergence_steps, out_img_path, layers=None):
    """Visualise individual channels ("neurons") of layers, saving one figure per layer.

    For every ``(layer, channel)`` pair, a freshly initialised lucent image is optimised for
    the ``"layer:channel"`` objective. The source image is shown as the first panel of each
    figure. Figures are saved to ``{out_img_path}_layer-{layer}.png``.

    Args:
        src_img_path: path to the source image (any format PIL can open).
        model: torch module to probe.
        convergence_steps: optimisation steps per channel.
        out_img_path: output path prefix.
        layers: mapping of lucent layer name to an iterable of channel indices. When None
            or empty, every ``nn.Conv2d`` in ``model`` is used with all of its output
            channels. Lucent names nested modules by joining their names with "_".

    Raises:
        ValueError: if there are no layers to visualise.
    """
    def displayable(arr):
        # imshow rejects (H, W, 1) arrays; drop the singleton channel axis.
        return arr[..., 0] if arr.shape[-1] == 1 else arr

    with Image.open(src_img_path) as pil_image:
        # Grayscale images load as (H, W); atleast_3d gives them an explicit channel axis.
        image = np.atleast_3d(np.array(pil_image, np.float32))

    if not layers:
        layers = {
            name.replace(".", "_"): range(module.out_channels)
            for name, module in model.named_modules()
            if isinstance(module, nn.Conv2d)
        }
    if not layers:
        raise ValueError("no layers to visualise: pass `layers` or use a model with nn.Conv2d layers")

    first_weight = next(model.parameters(), None)
    device = first_weight.device if first_weight is not None else torch.device("cpu")

    for layer_id, channels in layers.items():
        panels, titles = [displayable(image)], ["Input"]
        for channel in channels:
            # Build a fresh param_f per channel. Sharing one would resume optimisation from
            # the previous channel's image instead of starting from scratch.
            param_f = get_param_f(image, device, param)
            rendered = render.render_vis(
                model,
                f"{layer_id}:{channel}",
                param_f,
                transforms=transforms,
                preprocess=False,
                thresholds=(convergence_steps,),
                show_image=False,
            )[-1]
            # Index 0 is the optimised image; index 1 is the stacked target.
            panels.append(displayable(rendered[0]))
            titles.append(f"Channel {channel}")
        if len(panels) == 1:
            continue  # no channels requested for this layer

        fig, axs = plt.subplots(1, len(panels), squeeze=False)
        for ax, panel, title in zip(axs[0], panels, titles):
            ax.imshow(panel, cmap="gray")  # cmap only applies to single-channel panels
            ax.set_title(title)
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)
        fig.savefig(f"{out_img_path}_layer-{layer_id}.png")
        # One figure per layer can add up; close each so they do not accumulate in memory.
        plt.close(fig)

## A2C


In [None]:
# NOTE(review): `load_model` and `args` are not defined anywhere in this notebook; they look
# like leftovers from a CLI script. Define them in an earlier cell before running this one.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Keep only the convolutional trunk of the loaded model: policy -> features_extractor -> cnn.
model = load_model(args.model_path).policy.features_extractor.cnn
print(model)
model = model.to(device).eval()

## DQN

In [None]:
# NOTE(review): `load_model` and `args` are not defined anywhere in this notebook; define
# them in an earlier cell. This cell repeats the A2C attribute path
# (policy.features_extractor.cnn) even though the DQN network imported above is `MarioNet`.
# TODO confirm that path exists on the DQN model.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Keep only the convolutional trunk of the loaded model.
model = load_model(args.model_path).policy.features_extractor.cnn
print(model)
model = model.to(device).eval()