# Randomized Image Sampling for Explanations (RISE)

In [None]:
import os
import numpy as np
from matplotlib import pyplot as plt
from skimage.transform import resize as resize1
from tqdm import tqdm
import tensorflow as tf
import tensorflow.keras as keras
import gym
import tensorforce
from tensorforce import Agent, Environment

from PIL import Image
import torchvision.transforms as T
import torch

import math
import random
import numpy as np

In [None]:
import retro
import time
from tensorforce import Agent, Environment

## Cargar agente y definición de entorno

In [None]:
# Load a Tensorforce agent from the given saved-agent directory.
agent = Agent.load(directory='DQN-ATARI-VISION-GRAYSCALE-HALF-COMPLETO')

In [None]:
# Create the Gym Space Invaders environment and wrap it in a Tensorforce Environment.
# The raw `env` handle is kept because it is queried directly for action names.
env = gym.make('SpaceInvaders-v0')
environment = Environment.create(environment=env)

In [None]:
# Action names reported by the underlying Gym environment, indexed by action id;
# saliency() uses them to build plot titles.
meanings = env.unwrapped.get_action_meanings()

In [None]:
# Call tracked_tensors() once; its return value is discarded here.
agent.tracked_tensors()
# Reset the environment to obtain the initial observation.
states = environment.reset()
# Print the size of the observation's first axis.
print(states.shape[0])
# Interpret the observation as an RGB image and display it.
# NOTE(review): assumes `states` is an (H, W, 3) uint8 array -- confirm against the environment.
img_ = Image.fromarray(states, 'RGB')
plt.imshow(img_)

## Métodos específicos RISE

In [None]:
def generate_masks(N, s, p1, dimx, dimy):
    """Generate N random, smoothly varying occlusion masks for RISE.

    Each mask starts as an ``s x s`` binary grid whose cells are 1 with
    probability ``p1``. The grid is upsampled with linear interpolation to a
    canvas that is one grid cell larger than the target along each axis, and
    a randomly shifted ``(dimx, dimy)`` window is cropped out of it.

    Args:
        N: Number of masks to generate.
        s: Side length of the coarse random grid.
        p1: Probability that a grid cell is set to 1.
        dimx: Size of the first axis of each mask.
        dimy: Size of the second axis of each mask.

    Returns:
        Array of shape ``(N, dimx, dimy, 1)`` with the generated masks.
    """
    # Size of one upsampled grid cell, rounded up so that s cells cover the
    # whole target. Kept as integers because the values are used as random
    # bounds and as an output shape.
    cell_size = np.ceil(np.array((dimx, dimy)) / s).astype(int)
    # One extra cell per axis leaves room for the random shift below.
    up_size = tuple(int(v) for v in (s + 1) * cell_size)

    # N binary s x s grids: a cell is 1.0 with probability p1, else 0.0.
    grid = (np.random.rand(N, s, s) < p1).astype('float32')
    masks = np.empty((N, dimx, dimy))  # filled in below

    for i in tqdm(range(N), desc='Generating masks'):
        # Random shifts of less than one cell along each axis
        x = np.random.randint(0, cell_size[0])
        y = np.random.randint(0, cell_size[1])
        # Linear upsampling and cropping
        upsampled = resize1(grid[i], up_size, order=1, mode='reflect',
                            anti_aliasing=False)
        masks[i, :, :] = upsampled[x:x + dimx, y:y + dimy]

    # Trailing singleton axis; explicit N also keeps N == 0 valid.
    return masks.reshape(N, dimx, dimy, 1)

In [None]:
def explain(inp, masks, agent, dimx, dimy, keep_prob=None):
    """Compute one RISE saliency map per action for a single observation.

    The observation is multiplied by each mask and fed to the agent. After
    each call the tracked tensor ``'agent/policy/action-values'`` is read and
    passed through a softmax. The saliency map of an action is the sum of all
    masks weighted by that action's softmax score, divided by the number of
    masks and by the mask keep probability.

    Args:
        inp: Observation to explain; it must broadcast against a single mask
            of shape ``(dimx, dimy, 1)``.
        masks: Array whose first axis indexes the masks and whose remaining
            axes hold ``dimx * dimy`` elements per mask.
        agent: Agent exposing ``act(...)`` and ``tracked_tensors()``.
        dimx: Size of the first axis of each saliency map.
        dimy: Size of the second axis of each saliency map.
        keep_prob: Probability used when the masks were generated. When
            omitted, the module-level ``p1`` is used.

    Returns:
        Array reshaped to ``(-1, dimx, dimy)``; the leading axis follows the
        length of the softmax output (presumably one entry per action --
        confirm against the agent's action space).

    Raises:
        ValueError: If ``masks`` contains no masks.
    """
    # Take the mask count from the data rather than from a global, so the
    # function stays correct for any mask array passed in.
    n_masks = masks.shape[0]
    if n_masks == 0:
        raise ValueError('masks must contain at least one mask')
    scale = p1 if keep_prob is None else keep_prob

    preds = []
    for i in tqdm(range(n_masks), desc='Explaining'):
        # Mask one frame at a time: same values as masking everything up
        # front, without holding all masked frames in memory at once.
        agent.act(states=inp * masks[i], independent=True, deterministic=True)
        # Read the action values recorded during the act() call above.
        elem = agent.tracked_tensors()['agent/policy/action-values']
        preds.append(tf.nn.softmax(elem).numpy())

    preds = np.array(preds)

    # Score-weighted sum of the flattened masks, one row per softmax entry.
    sal = preds.T.dot(masks.reshape(n_masks, -1)).reshape(-1, dimx, dimy)
    return sal / n_masks / scale


In [None]:
def saliency(class_idx,img, sal):
    """Show the saliency map of one action overlaid on the given frame.

    Args:
        class_idx: Index of the action; selects both the title text from the
            module-level ``meanings`` and the map from ``sal``.
        img: Image drawn underneath the heatmap.
        sal: Indexable collection of saliency maps, one per action index.
    """
    action_name = meanings[class_idx]
    heatmap = sal[class_idx]

    plt.title('Explanation for `{}`'.format(action_name))
    plt.axis('off')
    # Frame first, then the half-transparent heatmap on top of it
    plt.imshow(img)
    plt.imshow(heatmap, cmap='jet', alpha=0.5)
    plt.show()

## Creación de la explicación

In [None]:
# Start a new episode and fetch its first observation.
states = environment.reset()
terminal = False
# Mask-generation constants
# N: number of masks, s: side length of the coarse random grid,
# p1: probability that a grid cell is set to 1.
N = 2000
s = 8
p1 = 0.9
# Build the masks once, sized to the observation's first two axes; the same
# masks are reused for every explained frame.
masks = generate_masks(N, s, p1,states.shape[0],states.shape[1])

In [None]:
# Run one episode with the deterministic policy, plotting a RISE explanation
# for every tenth frame.
i = 0
while not terminal:
    # Current frame as an image; used as the background of the saliency plot
    img = Image.fromarray(states, 'RGB')
    actions = agent.act(states=states, independent=True, deterministic=True)
    if i % 10 == 0:
        sal = explain(states, masks, agent, states.shape[0], states.shape[1])
        saliency(actions, img, sal)
    # Advance the environment with the chosen action
    states, terminal, reward = environment.execute(actions=actions)
    i += 1

In [None]:
# Close the environment once the episode loop has finished.
environment.close()