In [None]:
%%HTML
<!-- Improve readability when the notebook is shown on a projector -->
<style>
.rendered_html {font-size: 1.2em; line-height: 150%;}
div.prompt {min-width: 0ex; padding: 0px;}
.container {width:95% !important;}
</style>

In [None]:
%autosave 0
%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt
import torch

# DQN a partir de píxeles

A continuación entrenaremos una Deep Q network para jugar *Space invaders* a partir de imágenes

El ambiente de OpenAI Gym es

In [None]:
import gym

# Alternative environment, kept commented out for quick switching:
#env = gym.make("SpaceInvaders-v0") 
env = gym.make("PongNoFrameskip-v4")
# The value returned by reset() is used below as the first raw frame
state = env.reset()

# Show the shape of the raw observation
display(state.shape)

# Draw the first frame without axes
fig, ax = plt.subplots(figsize=(3, 4), tight_layout=True)
ax.axis('off')
ax.imshow(state);

Notemos que el estado es una imagen de 210 x 160 x 3 pixeles

En primer lugar haremos un preprocesamiento simple 
1. Combinar los canales en una imagen en blanco y negro
1. Reescalar a 110x84 píxeles y recortar la región central de 84x84 píxeles
1. Convertir los pixeles a float y normalizar al rango [0, 1]
1. Crear un stack de cuatro frames

Para esto usaremos la librería *torchvision*

In [None]:
import gym
import torchvision

#env = gym.make("SpaceInvaders-v0")
env = gym.make("PongNoFrameskip-v4")

# Gather four raw frames: the reset frame followed by three random-action steps
states = [env.reset()]
for k in range(3):
    a = env.action_space.sample()
    s, r, end, info = env.step(a)
    states.append(s)

# Per-frame preprocessing pipeline:
# array -> PIL image -> grayscale -> 110x84 resize -> central 84x84 crop -> tensor
transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToPILImage(),
    torchvision.transforms.Grayscale(),
    torchvision.transforms.Resize(size=(110, 84)),
    torchvision.transforms.CenterCrop((84, 84)),
    torchvision.transforms.ToTensor(),
])
# Preprocessing helper
def preprocess(states):
    """Run every frame through the module-level `transforms` pipeline and stack the results.

    Parameters
    ----------
    states : iterable
        Raw frames, each in a format accepted by `transforms`.

    Returns
    -------
    torch.Tensor
        The per-frame outputs concatenated along dimension 0. With the pipeline
        defined in this notebook (grayscale, central 84x84 crop) and four input
        frames, this is a 4x84x84 tensor.
    """
    return torch.cat([transforms(frame) for frame in states])

transformed_state = preprocess(states)
display("Tamaño del tensor transformado:", transformed_state.shape)

# One grayscale panel per stacked frame
fig, ax = plt.subplots(1, 4, figsize=(8, 2), tight_layout=True)
for k, panel in enumerate(ax):
    panel.matshow(transformed_state[k].numpy(), cmap=plt.cm.Greys_r)
    panel.axis('off')

##### Frame skipping

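Usaremos la técnica propuesta en (Mnih et al., 2013) conocida como *frame skipping*. Utilizamos 4 frames como estado, pero el agente solo escoge una acción cada 4 frames: la misma acción se repite en los frames intermedios y sus recompensas se acumulan.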

In [None]:
class ConvolutionalNeuralNetwork(torch.nn.Module):
    """Three convolutional layers followed by two linear layers.

    Used in this notebook as the Q-model: the input is a stack of frames and
    the output has one value per action.

    The flattened size fed to `linear1` is fixed at 7 * 7 * n_filters, which is
    what the three convolutions produce for 84x84 inputs (84 -> 20 -> 9 -> 7
    with the kernel sizes and strides below).

    Parameters
    ----------
    n_input : int
        Number of input channels.
    n_output : int
        Number of output units.
    n_filters : int, optional
        Number of filters in each convolutional layer.
    n_hidden : int, optional
        Width of the hidden linear layer.
    """
    def __init__(self, n_input, n_output, n_filters=32, n_hidden=256):
        # Zero-argument super(): `super(type(self), self)` recurses forever
        # as soon as this class is subclassed.
        super().__init__()
        # Flattened feature size after conv3; stored so forward() follows n_filters
        self.n_flat = 7 * 7 * n_filters
        self.conv1 = torch.nn.Conv2d(n_input, n_filters, kernel_size=8, stride=4)
        self.conv2 = torch.nn.Conv2d(n_filters, n_filters, kernel_size=4, stride=2)
        self.conv3 = torch.nn.Conv2d(n_filters, n_filters, kernel_size=3, stride=1)
        self.linear1 = torch.nn.Linear(self.n_flat, n_hidden)
        self.output = torch.nn.Linear(n_hidden, n_output)
        self.activation = torch.nn.ReLU()

    def forward(self, x):
        """Return the network output for a batch `x` of stacked 84x84 frames."""
        h = self.activation(self.conv1(x))
        h = self.activation(self.conv2(h))
        h = self.activation(self.conv3(h))
        # Previously hard-coded as 7*7*32, which broke any n_filters != 32
        h = h.view(-1, self.n_flat)
        h = self.activation(self.linear1(h))
        return self.output(h)

In [None]:
%matplotlib notebook
import matplotlib.pyplot as plt
from scipy.signal import convolve

# Shared figure with four stacked panels; update_plot() clears and redraws them
fig, ax = plt.subplots(4, figsize=(6, 5), sharex=True, tight_layout=True)

def smooth_data(x, window_length=10):
    """Moving average of `x` over `window_length` samples.

    Uses a 'valid' convolution with a uniform kernel, so for
    len(x) >= window_length the result has len(x) - window_length + 1 entries.
    """
    kernel = np.ones(window_length) / window_length
    return convolve(x, kernel, mode='valid')

def update_plot(step, episode, smooth_window=10, target=195, target_update=500):
    """Clear and redraw the four training-diagnostic panels of the shared figure.

    Reads the module-level `diagnostics` dict, the `epsilon` schedule and the
    `fig`/`ax` objects.

    Parameters
    ----------
    step : int
        Training step shown in the title.
    episode : int
        Number of episodes to plot; the x axis is `np.arange(episode)`, so each
        list in `diagnostics` is expected to have `episode` entries.
    smooth_window : int, optional
        Window length of the moving average drawn over the rewards.
    target, target_update : optional
        Currently unused; kept so existing calls keep working.
    """
    for ax_ in ax:
        ax_.cla()
    episodes = np.arange(episode)
    ax[0].scatter(episodes, diagnostics['rewards'], s=1)
    if episode > smooth_window:
        # Forward smooth_window: smooth_data used to run with its default window,
        # so any other value produced x/y arrays of different lengths.
        smoothed = smooth_data(diagnostics['rewards'], smooth_window)
        # Slice by the smoothed length; `[:-smooth_window+1]` is empty for a window of 1
        ax[0].plot(episodes[:len(smoothed)], smoothed, alpha=0.5, lw=2)
    ax[1].plot(episodes, diagnostics['loss'])
    # Mean Q per episode; the small constant avoids dividing by zero when q_N is 0
    ax[2].plot(episodes, np.array(diagnostics['q_sum'])/(np.array(diagnostics['q_N'])+1e-4))

    #ax[0].plot(episodes, [target]*len(episodes), 'k--')
    ax[0].set_ylabel('Recompensa');
    ax[1].set_ylabel('Loss')
    ax[2].set_ylabel('Q promedio')
    ax[3].plot(episodes, epsilon(episodes))
    ax[3].set_ylabel('Epsilon')
    ax[3].set_xlabel('Episodios')
    ax[0].set_title("Paso %d" % (step))
    fig.canvas.draw()

In [None]:
import numpy as np
import gym
from tqdm.notebook import tqdm

# Seed the torch random number generator
torch.manual_seed(123)

env = gym.make("SpaceInvaders-v0")
# State = 4 stacked frames of 84x84 pixels. preprocess() center-crops every frame
# to 84x84 and the network's 7*7*n_filters flatten size assumes 84x84 inputs,
# so the previous (4, 110, 84) did not match the tensors actually produced.
n_state = (4, 84, 84)
n_action = env.action_space.n

# DeepQNetwork is defined outside this section of the notebook
dqn_model = DeepQNetwork(q_model=ConvolutionalNeuralNetwork(n_state[0], n_action),
                         gamma = 0.99,
                         double_dqn=True,
                         target_update_freq=100,
                         learning_rate=1e-4)

def epsilon(episode, epsilon_init=1., epsilon_end=0.1, epsilon_rate=1e-2):
    """Exponentially decaying exploration rate.

    Starts at `epsilon_init` for episode 0 and decays towards `epsilon_end`
    at rate `epsilon_rate` per episode. `episode` may be a scalar or an array.
    """
    decay = np.exp(-epsilon_rate*episode)
    return epsilon_end + (epsilon_init - epsilon_end) * decay

# Replay buffer; ReplayMemory is defined outside this section of the notebook
memory = ReplayMemory(n_state, memory_length=10000)

# Per-episode accumulators: the last entry of each list belongs to the episode in progress
diagnostics = {'rewards': [0], 'loss': [0],
               'q_sum': [0], 'q_N': [0]}


def reset_frame_stack(env, n_frames=4, warmup_action=0):
    """Reset `env` and return the first `n_frames` raw frames of the new episode.

    The first frame comes from `env.reset()`; the remaining ones are obtained by
    stepping with `warmup_action` (0 -- presumably the no-op action; confirm
    against the environment's action meanings). Rewards and done flags of the
    warm-up steps are discarded.
    """
    frames = [env.reset()]
    for _ in range(n_frames - 1):
        frame, _reward, _done, _info = env.step(warmup_action)
        frames.append(frame)
    return frames


episode = 1
end = False
stacked_states = reset_frame_stack(env)

for step in tqdm(range(100000)):
    # Choose an action
    state = preprocess(stacked_states)
    a, q = dqn_model.select_action(state.unsqueeze(0), epsilon(episode))
    if q is not None:
        diagnostics['q_sum'][-1] += q
        diagnostics['q_N'][-1] += 1

    # Repeat the action for up to 4 frames, accumulating the reward
    stacked_states_next = []
    r = 0.0
    end = False
    for k in range(4):
        s, r_tmp, end, info = env.step(a)
        stacked_states_next.append(s)
        r += r_tmp
        if end:
            # Do not keep stepping an episode that has already finished
            break
    # If the episode ended early, repeat the last frame so the stack keeps 4 frames
    stacked_states_next += [s] * (4 - len(stacked_states_next))
    diagnostics['rewards'][-1] += r

    # Store the transition in the replay memory
    memory.push(state, preprocess(stacked_states_next),
                a, torch.tensor(r), end)

    stacked_states = stacked_states_next

    # Update the model; sample() returning None skips the update
    mini_batch = memory.sample(32)
    if mini_batch is not None:
        diagnostics['loss'][-1] += dqn_model.update(mini_batch)

    # Prepare the next episode
    if end:
        if episode % 5 == 0:
            update_plot(step, episode)
        episode += 1
        end = False
        stacked_states = reset_frame_stack(env)
        diagnostics['rewards'].append(0)
        diagnostics['loss'].append(0)
        diagnostics['q_sum'].append(0)
        diagnostics['q_N'].append(0)

In [None]:
import gym
from time import sleep

env = gym.make("SpaceInvaders-v0") 
# NOTE(review): env.reset() is called again below when the frame stack is built,
# so this first reset looks redundant
env.reset()
end = False

# Initial stack: the reset frame plus three frames obtained by stepping with action 0
stacked_states = []
stacked_states.append(env.reset())
for k in range(3):
    s, r, end, info = env.step(0)  
    stacked_states.append(s)

while not end:
    state = preprocess(stacked_states)
    # select_action is called here without the epsilon argument used during training
    a, q = dqn_model.select_action(state.unsqueeze(0))
    print(a)
    # Alternative kept for reference: repeat the action for 4 frames, as the training loop does
    #stacked_states = []
    #for k in range(4):
    #    s, r, end, info = env.step(a)  
    #    stacked_states.append(s)
    # Sliding window: one environment step per action, dropping the oldest frame.
    # NOTE(review): the training loop builds each state from 4 frames under one repeated
    # action, so the states seen here are built differently -- confirm this is intended
    s, r, end, info = env.step(a)
    stacked_states = stacked_states[1:]
    stacked_states.append(s)
    #env.render() 
    sleep(.02)     

In [None]:
# Close the environment currently bound to `env`
env.close()

In [None]:
import gym
from time import sleep
# Import the functions rather than binding the `IPython.display` module to the name
# `display`: the module would shadow the notebook's built-in `display()` that earlier
# cells call, breaking them when re-run.
from IPython.display import clear_output, display

env = gym.make("SpaceInvaders-v0")

# Initial stack: the reset frame plus three frames obtained by stepping with action 0.
# A single reset is enough; the environment used to be reset twice in a row.
stacked_states = [env.reset()]
for k in range(3):
    s, r, end, info = env.step(0)
    stacked_states.append(s)

# Image artist that is updated in place on every step
fig, ax = plt.subplots(figsize=(7, 4))
img = ax.imshow(env.render(mode='rgb_array'))

while not end:
    state = preprocess(stacked_states)
    a, q = dqn_model.select_action(state.unsqueeze(0))
    # Sliding window: one environment step per action, dropping the oldest frame
    s, r, end, info = env.step(a)
    stacked_states = stacked_states[1:]
    stacked_states.append(s)

    # Refresh the rendered frame; wait=True delays the clear until new output arrives
    img.set_data(env.render(mode='rgb_array'))
    display(fig)
    clear_output(wait=True)

    #sleep(.02)