# Randomized Input Sampling for Explanation (RISE)

In [None]:
import os
import numpy as np
from matplotlib import pyplot as plt
from skimage.transform import resize as resize1
from tqdm import tqdm
import tensorflow as tf
import tensorflow.keras as keras
import gym
import tensorforce
from tensorforce import Agent, Environment

from PIL import Image
import torchvision.transforms as T
import torch

import math
import random
import numpy as np

In [None]:
import retro
import time
from tensorforce import Agent, Environment

## Cargar agente y definición de entorno

In [None]:
# Load a saved Tensorforce agent from the given directory.
agent = Agent.load(directory='DQN-SPACEINVADERS-NES-VISION-HALF-GRAYSCALE-DISCRETIZADO')

In [None]:
class Discretizer(gym.ActionWrapper):
    """
    Action wrapper that exposes a fixed list of button combos as a
    ``gym.spaces.Discrete`` action space.

    Every combo is pre-encoded as a boolean array with one slot per button of
    the wrapped environment; the discrete action index selects one of them.

    Args:
        env: environment whose action space is ``gym.spaces.MultiBinary``
        combos: ordered list of lists of valid button combinations
    """

    def __init__(self, env, combos):
        super().__init__(env)
        assert isinstance(env.action_space, gym.spaces.MultiBinary)
        button_names = env.unwrapped.buttons
        self._decode_discrete_action = []
        for pressed_names in combos:
            # Start with every button released, then press the combo's buttons.
            pressed = np.array([False] * env.action_space.n)
            for name in pressed_names:
                pressed[button_names.index(name)] = True
            self._decode_discrete_action.append(pressed)

        n_actions = len(self._decode_discrete_action)
        self.action_space = gym.spaces.Discrete(n_actions)

    def action(self, act):
        """Return a copy of the button array encoded for discrete action ``act``."""
        encoded = self._decode_discrete_action[act]
        return encoded.copy()


class SpaceInvadersNesDiscretizer(Discretizer):
    """Discretizer preconfigured with three single-button combos."""

    def __init__(self, env):
        # Three discrete actions, in index order: press LEFT, press RIGHT,
        # press A.
        button_combos = [['LEFT'], ['RIGHT'], ['A']]
        super().__init__(env=env, combos=button_combos)

In [None]:
# Create the retro environment, wrap it with the discrete action set, and
# hand the wrapped environment to Tensorforce.
env = SpaceInvadersNesDiscretizer(retro.make(game='SpaceInvaders-Nes'))
environment = Environment.create(environment=env)

In [None]:
# Show the boolean button array that each discrete action index decodes to.
print(env._decode_discrete_action)

In [None]:
# List the button names of the unwrapped environment; the Discretizer looks up
# combo entries in this list by name.
buttons = env.unwrapped.buttons
print(buttons)

In [None]:
# NOTE(review): the return value is discarded here — presumably a leftover
# inspection call; confirm whether it is still needed.
agent.tracked_tensors()
# Reset the environment and keep the observation it returns.
states = environment.reset()
# Print the size of the observation's first axis.
print(states.shape[0])
# Preview the observation, interpreting the array as an RGB image.
img_ = Image.fromarray(states, 'RGB')
plt.imshow(img_)

## Métodos específicos RISE

In [None]:
def generate_masks(N, s, p1, dimx=40, dimy=90):
    """Generate ``N`` random masks of size ``dimx`` x ``dimy``.

    For every mask an ``s`` x ``s`` binary grid is drawn (a cell is 1 when a
    uniform sample is below ``p1``), upsampled with first-order interpolation
    to ``(s + 1)`` cells per axis, and cropped to ``dimx`` x ``dimy`` at a
    random offset smaller than one cell.

    Args:
        N: number of masks to generate.
        s: number of grid cells per axis before upsampling.
        p1: threshold applied to the uniform samples of the grid.
        dimx: size of the first mask axis.
        dimy: size of the second mask axis.

    Returns:
        Array of shape ``(N, dimx, dimy, 1)``.
    """
    target_shape = np.array((dimx, dimy))
    # Size of one grid cell when the target is split into s pieces (rounded up).
    cell_size = np.ceil(target_shape / s)
    # Upsampled size with one extra cell per axis, leaving room for the shift.
    up_size = cell_size * (s + 1)

    # N binary s x s grids, stored as float32 zeros and ones.
    grid = (np.random.rand(N, s, s) < p1).astype('float32')
    # Output buffer, filled mask by mask below.
    masks = np.empty((N, dimx, dimy))

    for idx in tqdm(range(N), desc='Generating masks'):
        # Random shift along each axis (x is drawn before y).
        shift_x = np.random.randint(0, cell_size[0])
        shift_y = np.random.randint(0, cell_size[1])
        # Linear upsampling followed by cropping at the shifted origin.
        upsampled = resize1(grid[idx], up_size, order=1, mode='reflect',
                            anti_aliasing=False)
        masks[idx, :, :] = upsampled[shift_x:shift_x + dimx,
                                     shift_y:shift_y + dimy]

    return masks.reshape(-1, dimx, dimy, 1)

In [None]:
def explain(inp, masks, agent, dimx=40, dimy=90, keep_prob=None):
    """Compute RISE saliency maps for a single observation.

    Each mask is multiplied into ``inp``, the agent is queried in independent
    mode on the masked observation, and the softmax of its tracked
    ``agent/policy/action-values`` tensor is recorded. The recorded scores
    weight the masks; the weighted sum is divided by the number of masks and
    by the keep probability.

    Args:
        inp: observation to explain; multiplied element-wise by every mask.
        masks: array of masks with the mask index on the first axis (as
            returned by ``generate_masks``).
        agent: agent providing ``act`` and ``tracked_tensors``.
        dimx: size of the first axis of each saliency map.
        dimy: size of the second axis of each saliency map.
        keep_prob: probability used when generating ``masks``. When ``None``
            the notebook-level ``p1`` is used, so existing calls behave as
            before.

    Returns:
        Array reshaped to ``(-1, dimx, dimy)``; presumably one map per action,
        assuming the tracked tensor is a flat vector of action values
        (TODO confirm).
    """
    # Take the mask count from the masks themselves rather than from the
    # notebook global N, so the function stays correct if the two ever differ.
    n_masks = masks.shape[0]
    if keep_prob is None:
        keep_prob = p1

    preds = []
    for i in tqdm(range(n_masks), desc='Explaining'):
        # Mask one frame at a time; building inp * masks up front would hold
        # all n_masks masked copies of the observation in memory at once.
        masked_input = inp * masks[i]
        # The chosen action is not needed; the call is made so that the
        # tracked action-values tensor reflects this masked input.
        agent.act(states=masked_input, independent=True)
        action_values = agent.tracked_tensors()['agent/policy/action-values']
        preds.append(tf.nn.softmax(action_values).numpy())

    preds = np.array(preds)

    # Weight every mask by the score of each action and sum over the masks.
    sal = preds.T.dot(masks.reshape(n_masks, -1)).reshape(-1, dimx, dimy)
    return sal / n_masks / keep_prob


In [None]:
# Button combos, listed in the same order as the ones passed to the
# Discretizer in SpaceInvadersNesDiscretizer.
combos = [['LEFT'], ['RIGHT'], ['A']]


def saliency(class_idx, img, sal):
    """Draw the saliency map of action ``class_idx`` over ``img`` and show it.

    Args:
        class_idx: index into ``combos`` and into the first axis of ``sal``.
        img: image drawn underneath the saliency map.
        sal: saliency maps indexed by action on the first axis.
    """
    action_label = combos[class_idx]
    plt.title(f'Explanation for `{action_label}`')
    plt.axis('off')
    plt.imshow(img)
    # Semi-transparent heat map drawn on top of the frame.
    plt.imshow(sal[class_idx], cmap='jet', alpha=0.5)
    plt.show()

## Creación de la explicación

In [None]:
# Reset the environment and keep the observation it returns.
states = environment.reset()
terminal = False

# Mask-generation constants: number of masks, grid cells per axis, and the
# threshold below which a grid cell is set to 1.
N, s, p1 = 2000, 8, 0.9
masks = generate_masks(N, s, p1, dimx=states.shape[0], dimy=states.shape[1])

In [None]:
i = 0
while not terminal:
    img = Image.fromarray(states, 'RGB')
    actions = agent.act(states=states)
    # Explain every 20th step, and every step on which the agent picks
    # action 2 (the index of ['A'] in combos).
    if i % 20 == 0 or actions == 2:
        sal = explain(states, masks, agent,
                      dimx=states.shape[0], dimy=states.shape[1])
        saliency(actions, img, sal)
    # Advance the environment with the chosen action and report the outcome
    # back to the agent.
    states, terminal, reward = environment.execute(actions=actions)
    agent.observe(terminal=terminal, reward=reward)
    i += 1

In [None]:
# Close the Tensorforce environment now that the episode loop has finished.
environment.close()