In [1]:
import math, random

import gym
import numpy as np

import mxnet as mx
from mxnet import gluon, autograd, nd
from mxnet.gluon import nn

from tqdm import tqdm, trange
from mxboard import SummaryWriter

In [2]:

# from IPython.display import clear_output
# import matplotlib.pyplot as plt
# %matplotlib inline

In [3]:
from collections import deque

class ReplayBuffer(object):
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)
    
    def push(self, state, action, reward, next_state, done):
        state      = np.expand_dims(state, 0)
        next_state = np.expand_dims(next_state, 0)
            
        self.buffer.append((state, action, reward, next_state, done))
    
    def sample(self, batch_size):
        state, action, reward, next_state, done = zip(*random.sample(self.buffer, batch_size))
        return np.concatenate(state), action, reward, np.concatenate(next_state), done
    
    def __len__(self):
        return len(self.buffer)


In [4]:
from wrappers import make_atari, wrap_deepmind, wrap_mxnet

In [5]:
env_id = "PongNoFrameskip-v4"
env    = make_atari(env_id)
env    = wrap_deepmind(env,frame_stack=True)
env    = wrap_mxnet(env)

In [6]:
class DQN(nn.Block):
    def __init__(self, input_shape, n_actions, **kwargs):
        super(DQN, self).__init__(**kwargs)
        
        with self.name_scope():
            self.conv1 = nn.Conv2D(32, 8, 4, in_channels=input_shape[0])
            self.bn1 = nn.BatchNorm()
            self.conv2 = nn.Conv2D(64, 4, 2, in_channels=32)
            self.bn2 = nn.BatchNorm()
            self.conv3 = nn.Conv2D(64, 3, 1, in_channels=64)
            self.bn3 = nn.BatchNorm()
            self.fc1 = nn.Dense(512)
            self.fc2 = nn.Dense(n_actions, in_units=512)
    
    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = nd.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = nd.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        out = nd.relu(out)
        out = nd.reshape(out, shape=(x.shape[0],-1))
        out = self.fc1(out)
        out = nd.relu(out)
        out = self.fc2(out)
        return out
    
    def act(self, state, epsilon, ctx):
        if random.random() > epsilon:
            state = nd.array(np.float32(state), ctx=ctx).expand_dims(0) / 255.0
            q_value = self.forward(state)
            action = nd.argmax(q_value, axis=1)
            action = int(action.asnumpy())
        else:
            action = random.randrange(env.action_space.n)
        return action

In [13]:
def compute_td_loss(batch_size, net, loss_fn, ctx):
    state, action, reward, next_state, done = replay_buffer.sample(batch_size)

    state      = nd.array((np.float32(state)), ctx=ctx) / 255.0
    next_state = nd.array(np.float32(next_state), ctx=ctx)
    action     = nd.array((action), ctx=ctx)
    reward     = nd.array((reward), ctx=ctx)
    done       = nd.array((done), ctx=ctx)

    q_values      = net(state)
    next_q_values = net(next_state)
    
    q_values = nd.gather_nd(q_values, nd.stack(nd.arange(action.shape[0], ctx=ctx).expand_dims(-1),action.expand_dims(-1), axis=0))
    next_q_value     = next_q_values.max(1)
    
    q_values = q_values.squeeze()
    next_q_value = next_q_value.squeeze()
    expected_q_value = reward + gamma * next_q_value * (1 - done)
    print(q_values.shape)
    print(expected_q_value.shape)
    loss = loss_fn(q_values, expected_q_value)
#     loss = nd.power(q_values - expected_q_value,2).mean()
        
    return loss

In [14]:
def plot(frame_idx, rewards, losses):
    clear_output(True)
    plt.figure(figsize=(20,5))
    plt.subplot(131)
    plt.title('frame %s. reward: %s' % (frame_idx, np.mean(rewards[-10:])))
    plt.plot(rewards)
    plt.subplot(132)
    plt.title('loss')
    plt.plot(losses)
    plt.show()

In [15]:
ctx = mx.gpu()

In [16]:
replay_initial = 10000
replay_buffer = ReplayBuffer(100000)

net = DQN(env.observation_space.shape, env.action_space.n)
net.initialize(ctx=ctx)
loss_fn = gluon.loss.L2Loss()
trainer = gluon.Trainer(net.collect_params(), optimizer='adam', optimizer_params={'learning_rate':0.0001})

In [17]:
epsilon_start = 1.0
epsilon_final = 0.02
epsilon_decay = 30000

epsilon_by_frame = lambda frame_idx: epsilon_final + (epsilon_start - epsilon_final) * math.exp(-1. * frame_idx / epsilon_decay)

In [18]:
num_frames = 1400000
batch_size = 32
gamma      = 0.99

losses = []
all_rewards = []
episode_reward = 0

state = env.reset()
writer = SummaryWriter(logdir='./logs',filename_suffix="_DQN")

In [19]:
state = env.reset()
current_best = 0.0
for frame_idx in range(1, num_frames + 1):
    epsilon = epsilon_by_frame(frame_idx)
    action = net.act(state, epsilon, ctx)
    
    next_state, reward, done, _ = env.step(action)
    replay_buffer.push(state, action, reward, next_state, done)
    
    state = next_state
    episode_reward += reward
    
    if done:
        state = env.reset()
        all_rewards.append(episode_reward)
        writer.add_scalar("reward", episode_reward, frame_idx)  
        mean_reward = np.mean(all_rewards[-100:])
        print("%d: done %d games, mean reward %.3f, reward %.3f, eps %.2f" % (
                frame_idx, len(all_rewards), mean_reward, episode_reward, epsilon,
            ))
        if current_best < mean_reward:
            print("save current best model")
            net.save_parameters('./models/cartpole_best_model')
            current_best = mean_reward
        episode_reward = 0
        writer.add_scalar("epsilon", epsilon, frame_idx)
        writer.add_scalar("mean_reward", mean_reward, frame_idx)  
        
        
    if len(replay_buffer) > replay_initial:
        with autograd.record():
#             print("compute loss")
            loss = compute_td_loss(batch_size, net, loss_fn, ctx)
            loss.backward()
        trainer.step(batch_size)
        losses.append(loss.sum().asscalar())
        writer.add_scalar("loss", loss.mean().asscalar(), frame_idx) 
#     if frame_idx % 10000 == 0:
#         plot(frame_idx, all_rewards, losses)

870: done 1 games, mean reward -21.000, reward -21.000, eps 0.97
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)
(32,)

KeyboardInterrupt: 

In [15]:
    state, action, reward, next_state, done = replay_buffer.sample(batch_size)

    state      = nd.array((np.float32(state)), ctx=ctx) / 255.0
    next_state = nd.array(np.float32(next_state), ctx=ctx)
    action     = nd.array((action), ctx=ctx)
    reward     = nd.array((reward), ctx=ctx)
    done       = nd.array((done), ctx=ctx)

In [16]:
state


[[[[0.20392157 0.34117648 0.34117648 ... 0.9254902  0.9254902
    0.9254902 ]
   [0.20392157 0.34117648 0.34117648 ... 0.9254902  0.9254902
    0.9254902 ]
   [0.20392157 0.34117648 0.34117648 ... 0.9254902  0.9254902
    0.9254902 ]
   ...
   [0.34117648 0.34117648 0.34117648 ... 0.9254902  0.9254902
    0.9254902 ]
   [0.34117648 0.34117648 0.34117648 ... 0.9254902  0.9254902
    0.9254902 ]
   [0.34117648 0.34117648 0.34117648 ... 0.9254902  0.9254902
    0.9254902 ]]]


 [[[0.20392157 0.34117648 0.34117648 ... 0.9254902  0.9254902
    0.9254902 ]
   [0.20392157 0.34117648 0.34117648 ... 0.9254902  0.9254902
    0.9254902 ]
   [0.20392157 0.34117648 0.34117648 ... 0.9254902  0.9254902
    0.9254902 ]
   ...
   [0.34117648 0.34117648 0.34117648 ... 0.9254902  0.9254902
    0.9254902 ]
   [0.34117648 0.34117648 0.34117648 ... 0.9254902  0.9254902
    0.9254902 ]
   [0.34117648 0.34117648 0.34117648 ... 0.9254902  0.9254902
    0.9254902 ]]]


 [[[0.20392157 0.34117648 0.34117648 ... 

In [18]:
    q_values      = net(state)
    next_q_values = net(next_state)
    
#     q_values = nd.gather_nd(q_values, nd.stack(nd.arange(action.shape[0], ctx=ctx).expand_dims(-1),action.expand_dims(-1), axis=0))
#     next_q_value     = next_q_values.max(1)
#     expected_q_value = reward + gamma * next_q_value * (1 - done)

In [19]:
q_values


[[-0.01308831  0.02949993  0.13193344 -0.09351277  0.46219692  0.07861875]
 [-0.01264067  0.02947327  0.13350035 -0.09446591  0.46111673  0.0793374 ]
 [-0.01275302  0.02911252  0.13252924 -0.09362247  0.46371642  0.07859968]
 [-0.01360997  0.02901225  0.13291612 -0.09286225  0.46156415  0.07799529]
 [-0.01271808  0.02985693  0.13179749 -0.09364564  0.46239865  0.07754089]
 [-0.0155503   0.03114699  0.13334677 -0.09410661  0.46289864  0.08036682]
 [-0.01273784  0.02844112  0.1362626  -0.09261388  0.46247298  0.07853465]
 [-0.01379248  0.02906224  0.13446493 -0.09217975  0.4618057   0.07770078]
 [-0.01445401  0.02872759  0.13348524 -0.0936028   0.461296    0.08012594]
 [-0.01443531  0.02836615  0.13265958 -0.09343564  0.4619497   0.07793523]
 [-0.01370101  0.02943321  0.13264008 -0.09226813  0.4635886   0.08097759]
 [-0.01408887  0.02825195  0.13299163 -0.09351309  0.46188277  0.07867248]
 [-0.01215731  0.03016526  0.13175145 -0.09391582  0.46169758  0.07704474]
 [-0.01395422  0.0319978

In [20]:
    q_values = nd.gather_nd(q_values, nd.stack(nd.arange(action.shape[0], ctx=ctx).expand_dims(-1),action.expand_dims(-1), axis=0))
    next_q_value     = next_q_values.max(1)
    expected_q_value = reward + gamma * next_q_value * (1 - done)

In [21]:
q_values


[[ 0.07861875]
 [ 0.0793374 ]
 [ 0.46371642]
 [ 0.02901225]
 [ 0.46239865]
 [-0.0155503 ]
 [-0.09261388]
 [-0.01379248]
 [ 0.461296  ]
 [-0.01443531]
 [-0.01370101]
 [-0.01408887]
 [-0.09391582]
 [ 0.462777  ]
 [ 0.46184373]
 [ 0.46171343]
 [ 0.02894381]
 [ 0.13368975]
 [ 0.13274193]
 [ 0.0298419 ]
 [ 0.02755402]
 [ 0.13283585]
 [-0.09398049]
 [ 0.07823001]
 [-0.01322091]
 [ 0.03011356]
 [-0.01288527]
 [-0.09325534]
 [ 0.02965342]
 [-0.09314058]
 [-0.01386082]
 [-0.09322649]
 [ 0.13280994]
 [-0.09297051]
 [-0.01229654]
 [ 0.46115208]
 [-0.09481248]
 [ 0.13333605]
 [ 0.07782365]
 [-0.01369789]
 [ 0.1327418 ]
 [ 0.46162713]
 [-0.09172933]
 [ 0.46207103]
 [ 0.07731377]
 [ 0.07825123]
 [ 0.07780345]
 [-0.01338041]
 [ 0.02887231]
 [ 0.13418731]
 [ 0.13202138]
 [ 0.463687  ]
 [ 0.07862045]
 [ 0.07840838]
 [ 0.02867521]
 [ 0.02901225]
 [ 0.46290314]
 [-0.01253502]
 [ 0.46218073]
 [ 0.46250886]
 [-0.09293008]
 [ 0.07842351]
 [-0.01323756]
 [ 0.13251422]]
<NDArray 64x1 @gpu(0)>

In [25]:
l = gluon.loss.L2Loss()

In [25]:
l(q_values, expected_q_value).sum()


[13.505692]
<NDArray 1 @gpu(0)>

In [26]:
nd.power(q_values - expected_q_value,2).mean()


[0.42250335]
<NDArray 1 @gpu(0)>

In [23]:
loss = nd.power(q_values - expected_q_value,2).mean()

In [27]:
l(q_values, expected_q_value).sum()


[128.41974]
<NDArray 1 @gpu(0)>

In [33]:
net.conv1.weight.grad()


[[[[ 3.32642882e-03  1.02666570e-02  2.26427913e-02 ...  4.60287696e-03
    -1.99091937e-02 -1.99267287e-02]
   [ 3.27065517e-03  1.06781907e-02  2.27930583e-02 ...  4.76535875e-03
    -1.98565964e-02 -1.94186009e-02]
   [ 2.77602836e-03  1.04552601e-02  2.26033293e-02 ...  4.75762039e-03
    -2.00518426e-02 -1.96249709e-02]
   ...
   [ 7.85774842e-04  1.06070992e-02  2.29748543e-02 ...  4.91922488e-03
    -1.95663646e-02 -1.95613652e-02]
   [ 9.90324887e-04  1.05871679e-02  2.30480842e-02 ...  4.78323922e-03
    -1.93662345e-02 -1.94529854e-02]
   [-2.71517783e-05  8.45097844e-03  2.14985535e-02 ...  4.71983105e-03
    -1.95656475e-02 -1.99426636e-02]]]


 [[[ 1.99049129e-03  7.39674037e-03  3.48120555e-03 ... -8.61275941e-04
    -2.10155174e-03  3.75435222e-04]
   [-4.91639785e-03  9.54594929e-04 -2.20501982e-03 ...  1.37664750e-03
     6.42713159e-04  1.40175736e-03]
   [-6.87667681e-03 -6.84650615e-04 -2.79970095e-03 ...  2.07509287e-03
     1.74485496e-03  2.10008491e-03]
   ...
