Implementation of "Playing Atari with Reinforcement Learning Mnih et al. arXiv:1312.5602v1
# https://arxiv.org/pdf/1312.5602v1.pdf

In [1]:
import numpy as np
import random
import gym

In [2]:
from numpy import tanh
from scipy.special import softmax
import pdb
from collections import deque

In [3]:
class DQNNModel(object):
    '''
    Two layer neural net
    input X
    Z1 = dot(X, W1) + B1
    Z2 = tanh(Z1)
    Z3 = dot(Z2, W2) + B2
    y_hat = Z3
    
    Loss(L) = 0.5 * (y_hat - y)**2
    
    dL/dy_hat = y_hat-y
    dy_hat/dZ3 = y_hat-y
    dZ3/dW2 = Z2
    dZ3/dB2 = 1
    dZ3/dZ2 = W2
    dZ2/dZ1 = 1 - Z2^2
    dZ1/dW1 = X
    dZ1/dB1 = 1
    
    Using chain rule
    
    dL/dW2 = (y_hat-y)*Z2
    dL/dB2 = (y_hat-y)
    
    dL/dW1 = (y_hat-y)*W2*(1-Z2^2)*X
    (input,hidden) = (1,output)*(hidden,output)*(1,hidden)*(1,input)
    
    dL/dB1 = (y_hat-y)*W2*(1-Z2^2)
    
    
    '''
    def __init__(self, ninput, nhidden, noutput):
        self.ninput_ = ninput
        self.nhidden_ = nhidden
        self.noutput_ = noutput
        #initialize model weights
        #X=(1,input)
        self.W1_ = np.random.randn(ninput, nhidden) / np.sqrt(nhidden)
        self.B1_ = np.random.randn(1, nhidden) / np.sqrt(nhidden)
        self.W2_ = np.random.randn(nhidden, noutput) / np.sqrt(noutput)
        self.B2_ = np.random.randn(1, noutput) / np.sqrt(noutput)
        
        self.dW1_cache_ = 0.
        self.dB1_cache_ = 0.
        self.dW2_cache_ = 0.
        self.dB2_cache_ = 0.
    
    def __str__(self):
        return_str = ""
        for i,j in self.__dict__.items():
            return_str += str(i) + " " + str(j) + "\n"  
        return return_str
    
    def forward(self, x):
        #x=(1,input)
        #W1=(input,hidden)
        z1 = np.dot(x, self.W1_) + self.B1_ #(1,hidden)
        z2 = tanh(z1) #(1,hidden)
        z3 = np.dot(z2, self.W2_) + self.B2_ #(1,output)
        #W2=(hidden, output)
        y_hat = z3 #(1,output)
        return y_hat, z1, z2, z3, self.W1_, self.W2_
    
    def loss(self, y, y_hat):
        loss = 0.5 * (y_hat - y)**2
        return loss
    
    def backward(self, x, y, y_hat, z1, z2, z3, w1, w2):
        #pdb.set_trace()
        dlogp = y_hat-y        
        dB2 = dlogp # (1, output)
        dW2 = np.dot(z2.T, dlogp) #(hidden, output)
        
        dB1 = np.dot(dlogp, w2.T)*(1-(z2*z2)) #(1,hidden)
        dW1 = np.dot(x.T, dB1) #(input,hidden)
        return dW1, dB1, dW2, dB2
            
    
    def update_SGD(self, dW1, dB1, dW2, dB2, lr=1e-3):
        self.W1_ -= lr*dW1
        self.B1_ -= lr*dB1
        self.W2_ -= lr*dW2
        self.B2_ -= lr*dB2
                

In [4]:
env_name = 'CartPole-v1'
env = gym.make(env_name)

In [5]:
state = env.reset()
INPUT_UNITS = state.shape[0]
HIDDEN_UNITS = 100
OUTPUT_UNITS = 1
models = []
#we create one model per action 
for i in range(env.action_space.n):
    model = DQNNModel(INPUT_UNITS, HIDDEN_UNITS, OUTPUT_UNITS)
    models.append(model)

In [None]:
num_episodes = 10000

experience_replay_size = 1000
experience_replay_queue_max = 10000

display_after = 20
display_after_counter = 0

total_steps = 0
total_episode_steps_rs = []
episode_steps_rs = []
batch_loss_rs = []

epsilon = 1
epsilon_min = 0.01
epsilon_decay = 0.995

replay_queue = deque(maxlen=experience_replay_queue_max)

for episode in range(num_episodes):
    episode_steps = 0
    while True:
        state = state.reshape(1,-1)
        
        action_values = []
        for model in models:
            y_hat, _, _, _, _, _ = model.forward(state)
            action_values.append(y_hat)
                            
        actions_sm = softmax(action_values)
        
        #exploration vs exploitation
        if epsilon > random.random():
            action = 0 if 0.5>random.random() else 1
        else:
            action = 0 if actions_sm[0]>random.random() else 1
        
        new_state, reward, done, info = env.step(action)
        episode_steps += 1
        
        replay_queue.append([state,action,reward-done,new_state,done])

        state = new_state
        total_steps += 1

        if (total_steps%experience_replay_size == 0):
            batch_loss = 0.
            batch = random.sample(list(replay_queue), experience_replay_size)
            model_grad = {'dw1':0., 'db1':0., 'dw2':0., 'db2':0.}
            model_grads = []
            for i in range(len(models)):
                model_grads.append(model_grad)
            
            for i,obs in enumerate(batch):
                s, a, r, n_s, d = obs
                n_s = n_s.reshape(1,-1)

                if d:
                    target = r
                else:
                    target_vals = []
                    for model in models:
                        t_val,_,_,_,_,_ = model.forward(n_s)
                        target_vals.append(t_val)
                    model = None
                    target = r + 0.99 * max(target_vals)


                y_hat, z1, z2, z3, w1, w2 = models[a].forward(s)
                y = target
                dw1, db1, dw2, db2 = models[a].backward(s, y, y_hat, z1, z2, z3, w1, w2)
                #accumulate
                model_grads[a]['dw1'] += dw1
                model_grads[a]['db1'] += db1
                model_grads[a]['dw2'] += dw2
                model_grads[a]['db2'] += db2
                
                batch_loss += models[a].loss(y, y_hat)

            #update model
            for i, model in enumerate(models):
                model.update_SGD(model_grads[i]['dw1'], model_grads[i]['db1'], model_grads[i]['dw2'], model_grads[i]['db2'], lr=1e-4)
                
            average_loss = batch_loss / len(batch)                         
            batch_loss_rs.append(average_loss)
            display_after_counter += 1
            
            if epsilon > epsilon_min:
                epsilon *= epsilon_decay
            
            if (display_after_counter%display_after == 0):
                episode_steps_buffer = episode_steps_rs
                episode_loss_buffer = batch_loss_rs[-display_after:]
                #print(episode_loss_buffer)
                min_steps = np.min(episode_steps_buffer)
                max_steps = np.max(episode_steps_buffer)
                avg_steps = np.mean(episode_steps_buffer)
                print("Episode %d, batch episodes %d, max %d, min %d, avg %d steps, avg loss %f"%(episode+1, len(episode_steps_rs),max_steps, 
                                                                                                  min_steps, avg_steps, np.mean(episode_loss_buffer)))
                episode_steps_rs = []
                                
        if done:
            state = env.reset()
            episode_steps_rs.append(episode_steps)
            total_episode_steps_rs.append(episode_steps)
            break
            


Episode 844, batch episodes 843, max 86, min 8, avg 23 steps, avg loss 135.153612
Episode 1524, batch episodes 680, max 185, min 8, avg 29 steps, avg loss 28.420147
Episode 2112, batch episodes 588, max 132, min 8, avg 34 steps, avg loss 30.206123
Episode 2587, batch episodes 475, max 216, min 9, avg 42 steps, avg loss 34.429649
Episode 2938, batch episodes 351, max 191, min 9, avg 56 steps, avg loss 32.358351
Episode 3178, batch episodes 240, max 244, min 10, avg 83 steps, avg loss 40.350849
Episode 3380, batch episodes 202, max 390, min 12, avg 99 steps, avg loss 37.418711
Episode 3532, batch episodes 152, max 303, min 9, avg 131 steps, avg loss 30.819214
Episode 3668, batch episodes 136, max 399, min 11, avg 146 steps, avg loss 22.702351
Episode 3798, batch episodes 130, max 363, min 17, avg 153 steps, avg loss 14.583669
Episode 3893, batch episodes 95, max 363, min 47, avg 208 steps, avg loss 9.643365
Episode 4004, batch episodes 111, max 394, min 26, avg 183 steps, avg loss 7.3571

In [None]:
from matplotlib import pyplot as plt

In [None]:
plt.scatter(np.array(range(len(total_episode_steps_rs))), total_episode_steps_rs)

In [None]:
import pandas as pd
s1 = pd.Series(total_episode_steps_rs).rolling(20).mean()
s1.dropna(inplace=True)
plt.plot(s1)

In [None]:
import pandas as pd
s2 = pd.Series(total_episode_steps_rs).rolling(50).mean()
s2.dropna(inplace=True)
plt.plot(s2)