In [13]:
# -*- coding: utf-8 -*-
import random
import gym
import numpy as np
from collections import deque
import tensorflow as tf
from value_function import Qnet
from replay_buffer import Replay_Buffer

EPISODES = 5000

tf.reset_default_graph() # THIS IS NECESSARY BEFORE MAKING NEW SESSION TO STOP IT ERRORING!!
try:
    sess
except:
    pass
else:
    sess.close()
    del sess
sess = tf.InteractiveSession()



class DQNAgent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        
        self.gamma = 0.95    # discount rate
        self.epsilon = 1.0  # exploration rate
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.99
        self.learning_rate = 0.001
        self.tau =0.01
        self.batch_size=32
#         self.memory = deque(maxlen=2000)
        
        self.optimizer = tf.train.AdamOptimizer(learning_rate = self.learning_rate)
        
        self.observations = tf.placeholder(tf.float32,shape = [None,self.state_size],name = 'observations')
        self.actions = tf.placeholder(tf.int32,shape = [None],name = 'actions')
    
        self.model = Qnet(self.action_size,self.observations,24,2)
        self.m_predict = tf.argmax(self.model.output,axis=1)
        
        self.rb = Replay_Buffer(self.state_size,self.action_size,max_buffer_size = 2000,min_pool_size=self.batch_size,batch_size=self.batch_size)
        
        # Duplicate the Qnet with different variables for the target network
        with tf.variable_scope('qNet_T'):
            self.target_Q_outputs = self.model.make_network(inputs = self.observations,reuse=False) 
            self.target_Q_params = self.model.get_params_internal()
            self.t_predict = tf.argmax(self.target_Q_outputs,axis=1)
        
        self.init_Q_net_training()
        self.update_target_model()

    def init_Q_net_training(self):
        training_variables = self.model.get_params_internal()
        with tf.variable_scope('Q_loss'):
            self.Q_t = tf.placeholder(tf.float32,[None], name = 'Loss')
            self.Q_for_action = tf.reduce_sum(self.model.output*tf.one_hot(self.actions,self.action_size),axis=1)
            self.Q_Loss = tf.reduce_mean(tf.sqrt(1+tf.square(self.Q_for_action - self.Q_t))-1)
        
        self.train_Q = self.optimizer.minimize(self.Q_Loss,var_list = training_variables)

    def update_target_model(self):
        # Pull the qnet params
        qnet_params = self.model.get_params_internal()
        
        with tf.variable_scope('Target_Q_update'):
            self.tQnet_update = []
            for tQ_p in self.target_Q_params:
                # Match each target net param with equiv from vnet
                Q_p = [v for v in qnet_params if tQ_p.name[(tQ_p.name.index('/')+1):] in v.name]
                assert(len(Q_p) == 1) # Check that only found one variable
                Q_p = Q_p[0]
                self.tQnet_update.append(tQ_p.assign((self.tau * Q_p + (1-self.tau)* tQ_p)))
            self.tQnet_update = tf.group(self.tQnet_update)
          

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        return self.m_predict.eval(feed_dict = {self.observations : state})[0]  # returns action

    def replay(self):
        for i in range(self.batch_size):
            minibatch = self.rb.get_samples()
            t = self.t_predict.eval(feed_dict = {self.observations : minibatch['next_observations']})
            loss = minibatch['rewards'] + (1-minibatch['dones'])*self.gamma * t
            sess.run(self.train_Q,feed_dict = {self.Q_t : loss,self.observations : minibatch['observations'],self.actions : minibatch['actions']})
            
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def load(self, name):
        self.model.load_weights(name)

    def save(self, name):
        self.model.save_weights(name)


if __name__ == "__main__":
    env = gym.make('CartPole-v1')
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n
    agent = DQNAgent(state_size, action_size)
    # agent.load("./save/cartpole-ddqn.h5")
    
    tf.global_variables_initializer().run()
    for e in range(EPISODES):
        state = env.reset()
        done = False
        reward_t = 0 
        state = np.reshape(state, [1, state_size])
        while not(done):
            # env.render()
            action = agent.act(state)
            next_state, reward, done, _ = env.step(action)
            reward = reward if not done else -10
            reward_t += reward
            
            next_state = np.reshape(next_state, [1, state_size])
            agent.rb.add_sample(action,state, next_state,reward, done)
            state = next_state
            sess.run(agent.tQnet_update)
        
            if agent.rb.batch_ready():
                agent.replay()
                
        print("episode: {}/{}, score: {}, e: {:.2}"
                      .format(e, EPISODES, reward_t, agent.epsilon))
        # if e % 10 == 0:
        #     agent.save("./save/cartpole-ddqn.h5")

[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
episode: 0/5000, score: 8.0, e: 1.0
episode: 1/5000, score: 23.0, e: 0.8
episode: 2/5000, score: 4.0, e: 0.69
episode: 3/5000, score: 13.0, e: 0.54
episode: 4/5000, score: 11.0, e: 0.43
episode: 5/5000, score: 2.0, e: 0.38
episode: 6/5000, score: 15.0, e: 0.29
episode: 7/5000, score: 17.0, e: 0.22
episode: 8/5000, score: 3.0, e: 0.19
episode: 9/5000, score: -2.0, e: 0.18
episode: 10/5000, score: 6.0, e: 0.15
episode: 11/5000, score: 19.0, e: 0.11
episode: 12/5000, score: -1.0, e: 0.099
episode: 13/5000, score: 2.0, e: 0.087
episode: 14/5000, score: -2.0, e: 0.079
episode: 15/5000, score: 7.0, e: 0.066
episode: 16/5000, score: 13.0, e: 0.052
episode: 17/5000, score: -3.0, e: 0.048
episode: 18/5000, score: 0.0, e: 0.043
episode: 19/5000, score: 3.0, e: 0.037
episode: 20/5000, score: 11.0, e: 0.03
episode: 21/5000, score: 1.0, e: 0.027
episode: 22/5000, score: -2.0, e: 0.024
episode

KeyboardInterrupt: 