In [1]:
import math, random

import gym
import numpy as np

import mxnet as mx
from mxnet import gluon, autograd, nd
from mxnet.gluon import nn

from tqdm import tqdm, trange
from mxboard import SummaryWriter

In [2]:
from collections import deque

class ReplayBuffer(object):
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)
    
    def push(self, state, action, reward, next_state, done):
        state      = np.expand_dims(state, 0)
        next_state = np.expand_dims(next_state, 0)
            
        self.buffer.append((state, action, reward, next_state, done))
    
    def sample(self, batch_size):
        state, action, reward, next_state, done = zip(*random.sample(self.buffer, batch_size))
        return np.concatenate(state), action, reward, np.concatenate(next_state), done
    
    def __len__(self):
        return len(self.buffer)

In [3]:
from wrappers import make_atari, wrap_deepmind, wrap_mxnet

In [83]:
env_id = "PongNoFrameskip-v4"
env    = make_atari(env_id)
env    = wrap_deepmind(env)
env    = wrap_mxnet(env)

In [5]:
class DQN(nn.Block):
    def __init__(self, input_shape, n_actions, **kwargs):
        super(DQN, self).__init__(**kwargs)
        
        with self.name_scope():
            self.conv1 = nn.Conv2D(32, 8, 4, in_channels=input_shape[0])
            self.bn1 = nn.BatchNorm()
            self.conv2 = nn.Conv2D(64, 4, 2, in_channels=32)
            self.bn2 = nn.BatchNorm()
            self.conv3 = nn.Conv2D(64, 3, 1, in_channels=64)
            self.bn3 = nn.BatchNorm()
            self.fc1 = nn.Dense(512)
            self.fc2 = nn.Dense(n_actions, in_units=512)
    
    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = nd.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = nd.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        out = nd.relu(out)
        out = nd.reshape(out, shape=(x.shape[0],-1))
        out = self.fc1(out)
        out = nd.relu(out)
        out = self.fc2(out)
        return out
    
    def act(self, state, epsilon, ctx):
        if random.random() > epsilon:
            state = nd.array(np.float32(state), ctx=ctx).expand_dims(0)
            q_value = self.forward(state)
            action = nd.argmax(q_value, axis=1)
            action = int(action.asnumpy())
        else:
            action = random.randrange(env.action_space.n)
        return action

In [6]:
ctx = mx.gpu()

In [63]:
model = DQN(env.observation_space.shape, env.action_space.n)
model.load_parameters('./models/double_dqn_best_model', ctx=ctx)

In [82]:
%matplotlib inline
from JSAnimation.IPython_display import display_animation
from matplotlib import animation
import matplotlib.pyplot as plt
from IPython.display import display

In [84]:
state = env.reset()

In [85]:
frames = []
while True:
    frames.append(env.render(mode = 'rgb_array'))
    action = model.act(state, 0.0, ctx)
    state, reward, done, _ = env.step(action)
    if done:
        break
env.close()

TypeError: render() got an unexpected keyword argument 'close'

In [86]:
len(frames)

2202

In [87]:
def display_frames_as_gif(frames):
    """
    Displays a list of frames as a gif, with controls
    """
    plt.figure(figsize=(frames[0].shape[1] / 72.0, frames[0].shape[0] / 72.0), dpi = 72)
    patch = plt.imshow(frames[0])
    plt.axis('off')

    def animate(i):
        patch.set_data(frames[i])

    anim = animation.FuncAnimation(plt.gcf(), animate, frames = len(frames), interval=50)
    display(display_animation(anim, default_mode='loop'))

In [88]:
display_frames_as_gif(frames)

In [None]:
env.close()