# Train an agent on the CartPole balancing problem (OpenAI Gym) using a Dueling DQN network with prioritized experience replay backed by a sum tree

In [80]:
#Initialize
import math, random
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd
import torch.autograd as Variable
import torch.nn.functional as F
from collections import deque

env = gym.make("CartPole-v0")

# Epsilon-greedy exploration: start fully random, then multiply by `decay`
# after each episode while epsilon is still above `epsilonMin`.
epsilon, epsilonMin, decay = 1.0, 0.01, 0.999

# Training loop settings.
episodes = 500      # number of episodes to run
batch_size = 32     # transitions sampled from replay per update
gamma = 0.99        # discount factor for future rewards
goal_steps = 200    # NOTE(review): not referenced anywhere in this notebook

In [81]:
# Network
class DuelingDQN(nn.Module):
    """Dueling Q-network: a shared trunk feeding separate value and advantage streams.

    Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a')
    """

    def __init__(self, num_inputs, num_outputs):
        super(DuelingDQN, self).__init__()
        self.fc1 = nn.Linear(num_inputs, 128)

        # Advantage stream: one output per action.
        self.a1 = nn.Linear(128, 128)
        self.a2 = nn.Linear(128, num_outputs)

        # Value stream: a single scalar per state.
        self.val1 = nn.Linear(128, 128)
        self.val2 = nn.Linear(128, 1)

    def forward(self, x):
        """Return Q-values for ``x``.

        ``x`` may be a single observation of shape (num_inputs,) or a batch of
        shape (batch, num_inputs); the result keeps the leading shape and has a
        trailing dimension of ``num_outputs``.
        """
        x = F.relu(self.fc1(x))

        adv = self.a2(F.relu(self.a1(x)))
        val = self.val2(F.relu(self.val1(x)))
        # Centre advantages per sample over the action dimension only; a global
        # mean would couple the Q-values of different states in the same batch.
        return val + adv - adv.mean(dim=-1, keepdim=True)

    def act(self, state, epsilon):
        """Pick an action epsilon-greedily.

        When ``random.random() > epsilon`` the greedy action (argmax Q) is
        returned as a Python ``int``; otherwise a random action is drawn from
        the global ``env``'s action space.
        """
        if random.random() > epsilon:
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                obs = torch.as_tensor(np.asarray(state, dtype=np.float32)).unsqueeze(0)
                q_value = self.forward(obs)
            return int(q_value.argmax(dim=1)[0])
        return env.action_space.sample()

In [82]:
# model1 is optimised by `optimizer`; model2 receives copies of it via `sync`.
model1, model2 = (
    DuelingDQN(env.observation_space.shape[0], env.action_space.n)
    for _ in range(2)
)

optimizer = optim.Adam(model1.parameters())
def sync(model1, model2):
    """Overwrite ``model2``'s parameters and buffers with ``model1``'s, in place."""
    weights = model1.state_dict()
    model2.load_state_dict(weights)

sync(model1=model1, model2=model2)  # start model2 as an exact copy of model1

In [83]:
class SumTree:
    """Binary sum tree holding ``capacity`` priorities and their payloads.

    Leaves ``capacity-1 .. 2*capacity-2`` of ``tree`` store priorities; each
    internal node stores the sum of its two children, so ``tree[0]`` is the
    total priority. ``data`` is a ring buffer aligned with the leaves.
    """

    write = 0  # next ring-buffer slot to overwrite

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        self.data = np.zeros(capacity, dtype=object)

    def _propagate(self, idx, change):
        """Add ``change`` to every ancestor of node ``idx``.

        The root has no ancestors, which matters when ``capacity == 1`` and
        the single leaf is the root itself.
        """
        while idx != 0:
            idx = (idx - 1) // 2
            self.tree[idx] += change

    def _retrieve(self, idx, s):
        """Descend from node ``idx`` to the leaf whose cumulative range contains ``s``."""
        n_nodes = len(self.tree)
        while True:
            left = 2 * idx + 1
            if left >= n_nodes:
                return idx
            right = left + 1
            left_sum = self.tree[left]
            right_sum = self.tree[right]
            if right_sum <= 0:
                # Nothing selectable on the right (e.g. unfilled slots). This also
                # clamps an ``s`` that overshoots the total through float rounding.
                go_left = True
            elif left_sum <= 0:
                go_left = False
            else:
                go_left = s <= left_sum
            if go_left:
                idx = left
            else:
                s -= left_sum
                idx = right

    def total(self):
        """Return the sum of all stored priorities."""
        return self.tree[0]

    def add(self, p, data):
        """Store ``data`` with priority ``p``, overwriting the oldest slot when full."""
        idx = self.write + self.capacity - 1

        self.data[self.write] = data
        self.update(idx, p)

        self.write += 1
        if self.write >= self.capacity:
            self.write = 0

    def update(self, idx, p):
        """Set the priority of tree node ``idx`` (a leaf index) to ``p``."""
        change = p - self.tree[idx]
        self.tree[idx] = p
        self._propagate(idx, change)

    def get(self, s):
        """Return ``(tree_index, priority, data)`` for the leaf selected by ``s``."""
        idx = self._retrieve(0, s)
        data_idx = idx - self.capacity + 1
        return (idx, self.tree[idx], self.data[data_idx])

In [84]:
# Replay memory
class Memory:
    """Prioritized replay buffer backed by a ``SumTree``.

    A transition's priority is ``(error + e) ** a``: ``e`` keeps every stored
    transition sampleable, and ``a`` sets how strongly sampling favours large
    errors.
    """

    samples = []  # NOTE(review): unused, and shared by all instances (class attribute)

    def __init__(self, capacity):
        self.tree = SumTree(capacity)
        self.e = 0.01
        self.a = 0.6

    def _getPriority(self, error):
        """Map an absolute TD error to a sampling priority."""
        return (error + self.e) ** self.a

    def add(self, error, sample):
        """Store ``sample`` with a priority derived from ``error``."""
        self.tree.add(self._getPriority(error), sample)

    def sample(self, n):
        """Draw ``n`` ``(tree_index, transition)`` pairs by stratified proportional sampling.

        The total priority is split into ``n`` equal segments and one point is
        drawn uniformly from each, so the batch covers the whole distribution.
        """
        segment = self.tree.total() / n
        batch = []
        for i in range(n):
            # Segment i spans [segment * i, segment * (i + 1)].
            s = random.uniform(segment * i, segment * (i + 1))
            idx, _, data = self.tree.get(s)
            batch.append((idx, data))
        return batch

    def update(self, idx, error):
        """Re-prioritise tree node ``idx`` from a fresh TD ``error``."""
        self.tree.update(idx, self._getPriority(error))

In [75]:
# Estimating error
def get_error(state, action, reward, next_state, done,
              online_net=None, target_net=None, discount=None):
    """Return the absolute Double-DQN TD error for a batch of transitions.

    Args:
        state, next_state: sequences of observations, stacked into float32 batches.
        action: integer action index per transition.
        reward: reward per transition.
        done: 1 (or True) for terminal transitions, 0 otherwise.
        online_net: network being trained; defaults to the global ``model1``.
        target_net: network that scores the bootstrap action; defaults to the
            global ``model2``.
        discount: discount factor; defaults to the global ``gamma``.

    Returns:
        1-D tensor ``|Q(s, a) - target|``. Gradients flow only through
        ``Q(s, a)``; the target is built without autograd, so ``target_net``
        never accumulates gradients.
    """
    if online_net is None:
        online_net = model1
    if target_net is None:
        target_net = model2
    if discount is None:
        discount = gamma

    state = torch.as_tensor(np.asarray(state, dtype=np.float32))
    next_state = torch.as_tensor(np.asarray(next_state, dtype=np.float32))
    action = torch.as_tensor(np.asarray(action, dtype=np.int64))
    reward = torch.as_tensor(np.asarray(reward, dtype=np.float32))
    done = torch.as_tensor(np.asarray(done, dtype=np.float32))

    q_value = online_net(state).gather(1, action.unsqueeze(1)).squeeze(1)

    with torch.no_grad():
        # Double DQN: the online net chooses the next action, the target net scores it.
        best_next = online_net(next_state).argmax(dim=1, keepdim=True)
        next_q_value = target_net(next_state).gather(1, best_next).squeeze(1)
        expected_q_value = reward + discount * next_q_value * (1 - done)

    return (q_value - expected_q_value).abs()

In [76]:
memory = Memory(10000)

for idx in range(episodes):
    state = env.reset()
    total_reward = 0
    done = False
    while not done:
        action = int(model1.act(state, epsilon))
        next_state, reward, done, _ = env.step(action)
        exp = state, action, reward, next_state, done

        # Initial replay priority: this transition's TD error (no autograd needed).
        with torch.no_grad():
            error = float(get_error([state], [action], [reward], [next_state], [float(done)])[0])
        memory.add(error, exp)

        state = next_state  # keep the raw observation; never store tensors in replay
        total_reward += reward

    print("Episode = " + str(idx) + " , Score = " + str(total_reward))

    if epsilon > epsilonMin:
        epsilon *= decay

    if idx % 100 == 0:
        sync(model1, model2)

    if idx > 3:
        batch = memory.sample(batch_size)
        tree_idxs = [tree_idx for tree_idx, _ in batch]
        transitions = [transition for _, transition in batch]
        state = [np.asarray(t[0], dtype=np.float32) for t in transitions]
        action = np.array([t[1] for t in transitions])
        reward = np.array([t[2] for t in transitions])
        next_state = np.array([t[3] for t in transitions])
        done = [1 if t[4] else 0 for t in transitions]

        error = get_error(state, action, reward, next_state, done)
        loss = error.pow(2).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Prioritized replay: refresh the sampled transitions' priorities with
        # the TD errors just used for the loss.
        for tree_idx, new_error in zip(tree_idxs, error.detach().numpy()):
            memory.update(tree_idx, float(new_error))

Episode = 0 , Score = 35.0
Episode = 1 , Score = 200.0
Episode = 2 , Score = 187.0
Episode = 3 , Score = 58.0
Episode = 4 , Score = 200.0
Episode = 5 , Score = 131.0
Episode = 6 , Score = 134.0
Episode = 7 , Score = 54.0
Episode = 8 , Score = 115.0
Episode = 9 , Score = 95.0
Episode = 10 , Score = 112.0
Episode = 11 , Score = 11.0
Episode = 12 , Score = 85.0
Episode = 13 , Score = 105.0
Episode = 14 , Score = 27.0
Episode = 15 , Score = 20.0
Episode = 16 , Score = 139.0
Episode = 17 , Score = 97.0
Episode = 18 , Score = 187.0
Episode = 19 , Score = 200.0
Episode = 20 , Score = 118.0
Episode = 21 , Score = 140.0
Episode = 22 , Score = 54.0
Episode = 23 , Score = 112.0
Episode = 24 , Score = 93.0
Episode = 25 , Score = 154.0
Episode = 26 , Score = 138.0
Episode = 27 , Score = 59.0
Episode = 28 , Score = 85.0
Episode = 29 , Score = 124.0
Episode = 30 , Score = 71.0
Episode = 31 , Score = 135.0
Episode = 32 , Score = 83.0
Episode = 33 , Score = 34.0
Episode = 34 , Score = 41.0
Episode = 35

Episode = 279 , Score = 143.0
Episode = 280 , Score = 141.0
Episode = 281 , Score = 123.0
Episode = 282 , Score = 147.0
Episode = 283 , Score = 136.0
Episode = 284 , Score = 17.0
Episode = 285 , Score = 155.0
Episode = 286 , Score = 177.0
Episode = 287 , Score = 189.0
Episode = 288 , Score = 188.0
Episode = 289 , Score = 200.0
Episode = 290 , Score = 200.0
Episode = 291 , Score = 200.0
Episode = 292 , Score = 200.0
Episode = 293 , Score = 200.0
Episode = 294 , Score = 200.0
Episode = 295 , Score = 200.0
Episode = 296 , Score = 200.0
Episode = 297 , Score = 200.0
Episode = 298 , Score = 200.0
Episode = 299 , Score = 200.0
Episode = 300 , Score = 200.0
Episode = 301 , Score = 200.0
Episode = 302 , Score = 57.0
Episode = 303 , Score = 40.0
Episode = 304 , Score = 194.0
Episode = 305 , Score = 161.0
Episode = 306 , Score = 152.0
Episode = 307 , Score = 157.0
Episode = 308 , Score = 157.0
Episode = 309 , Score = 143.0
Episode = 310 , Score = 131.0
Episode = 311 , Score = 23.0
Episode = 312 

Episode = 553 , Score = 200.0
Episode = 554 , Score = 200.0
Episode = 555 , Score = 200.0
Episode = 556 , Score = 200.0
Episode = 557 , Score = 200.0
Episode = 558 , Score = 169.0
Episode = 559 , Score = 197.0
Episode = 560 , Score = 200.0
Episode = 561 , Score = 169.0
Episode = 562 , Score = 156.0
Episode = 563 , Score = 158.0
Episode = 564 , Score = 156.0
Episode = 565 , Score = 98.0
Episode = 566 , Score = 164.0
Episode = 567 , Score = 163.0
Episode = 568 , Score = 200.0


KeyboardInterrupt: 