In [None]:
import numpy as np
import tensorflow as tf
import random
import dqn
from collections import deque

import gym
env = gym.make('CartPole-v0')

# Constants defining our neural network: both sizes are read from the
# environment (length of an observation vector, number of available actions).
input_size = env.observation_space.shape[0]
output_size = env.action_space.n

# Discount factor applied to the best next-state Q value in simple_replay_train.
dis = 0.9
# Maximum number of transitions kept in the replay buffer; once exceeded,
# the oldest transition is dropped.
REPLAY_MEMORY = 50000

[2017-02-18 12:30:52,483] Making new env: CartPole-v0


In [None]:
def simple_replay_train(DQN, train_batch, discount=None):
    """Fit the network on one batch of replayed transitions.

    For every ``(state, action, reward, next_state, done)`` tuple the
    network's current prediction for ``state`` is taken as the training
    target, except for the entry of the action actually taken, which is
    overwritten with ``reward`` (terminal transition) or
    ``reward + discount * max(Q(next_state))`` (non-terminal transition).

    Args:
        DQN: object exposing ``input_size``, ``output_size``,
            ``predict(state)`` and ``update(x_stack, y_stack)``. The
            prediction is indexed as ``Q[0, action]``, so ``predict`` must
            return a 2-D array with a single row.
        train_batch: iterable of transition tuples as described above.
        discount: discount factor for the bootstrapped target. When ``None``
            (the default) the module-level ``dis`` is used.

    Returns:
        Whatever ``DQN.update`` returns for the stacked inputs and targets.
    """
    if discount is None:
        discount = dis

    # Collect rows in Python lists and stack once at the end; calling
    # np.vstack inside the loop re-copies the whole array on every sample.
    states = []
    targets = []

    # Get stored information from the buffer
    for state, action, reward, next_state, done in train_batch:
        Q = DQN.predict(state)

        if done:
            # Terminal transition: there is no future value to bootstrap from.
            Q[0, action] = reward
        else:
            # Bootstrapped target from the best action in the next state.
            Q[0, action] = reward + discount * np.max(DQN.predict(next_state))

        states.append(state)
        targets.append(Q)

    # The leading empty arrays keep the result 2-D (0 rows) for an empty
    # batch and keep the default float dtype of the original accumulator.
    x_stack = np.vstack([np.empty((0, DQN.input_size))] + states)
    y_stack = np.vstack([np.empty((0, DQN.output_size))] + targets)

    # Train our network using target and predicted Q values
    return DQN.update(x_stack, y_stack)

In [None]:
def bot_play(mainDQN, play_env=None):
    """Render one episode played greedily by the given network.

    Args:
        mainDQN: object whose ``predict(state)`` returns the Q values of a
            state; the action with the largest value is taken at each step.
        play_env: environment exposing ``reset()``, ``render()`` and
            ``step(action)`` returning ``(state, reward, done, info)``.
            When ``None`` (the default) the module-level ``env`` is used.

    Returns:
        The sum of the rewards collected during the episode. The same value
        is also printed.
    """
    if play_env is None:
        play_env = env

    # See our trained network in action
    s = play_env.reset()
    reward_sum = 0
    while True:
        play_env.render()
        a = np.argmax(mainDQN.predict(s))
        s, reward, done, _ = play_env.step(a)
        reward_sum += reward
        if done:
            print("Total score: {}".format(reward_sum))
            return reward_sum

In [None]:
def main():
    """Train a DQN on the module-level ``env`` and then replay it once.

    Each episode is played with an epsilon-greedy policy whose exploration
    rate decays with the episode index. Every transition is stored in a
    bounded replay buffer. On every tenth episode (``episode % 10 == 1``)
    the network is trained on a series of random minibatches drawn from that
    buffer. After the last episode ``bot_play`` renders one greedy episode.
    """
    max_episodes = 5000
    max_steps = 10000        # an episode this long is considered good enough
    train_iterations = 50    # minibatch updates per training round
    batch_size = 10          # transitions per minibatch

    # Store the previous observations in replay memory. ``maxlen`` makes the
    # deque discard its oldest entry automatically once it is full.
    replay_buffer = deque(maxlen=REPLAY_MEMORY)

    with tf.Session() as sess:
        mainDQN = dqn.DQN(sess, input_size, output_size)
        tf.global_variables_initializer().run()

        for episode in range(max_episodes):
            # Exploration rate: 1.0 on the first episode, decaying afterwards.
            e = 1. / ((episode / 10) + 1)
            done = False
            step_count = 0

            state = env.reset()

            while not done:
                if np.random.rand(1) < e:
                    action = env.action_space.sample()
                else:
                    # Choose an action greedily from the Q-network
                    action = np.argmax(mainDQN.predict(state))

                # Get new state and reward from environment
                next_state, reward, done, _ = env.step(action)
                if done:    # big penalty
                    reward = -100

                # Save the experience to our buffer
                replay_buffer.append((state, action, reward, next_state, done))

                state = next_state
                step_count += 1
                if step_count > max_steps:    # Good enough
                    break

            print("Episode: {}    steps: {}".format(episode, step_count))

            if episode % 10 == 1:
                # random.sample raises ValueError when asked for more items
                # than the population holds, so never request more than the
                # buffer currently contains.
                sample_size = min(batch_size, len(replay_buffer))
                for _ in range(train_iterations):
                    # Minibatch works better
                    minibatch = random.sample(replay_buffer, sample_size)
                    loss, _ = simple_replay_train(mainDQN, minibatch)
                print("Loss: ", loss)

        bot_play(mainDQN)
    
# Run the training loop only when this code is executed as the main module,
# not when it is imported.
if __name__ == "__main__":
    main()

Episode: 0    steps: 14
Episode: 1    steps: 10
Loss:  852.971
Episode: 2    steps: 24
Episode: 3    steps: 17
Episode: 4    steps: 58
Episode: 5    steps: 37
Episode: 6    steps: 57
Episode: 7    steps: 51
Episode: 8    steps: 32
Episode: 9    steps: 44
Episode: 10    steps: 48
Episode: 11    steps: 77
Loss:  3.84558
Episode: 12    steps: 39
Episode: 13    steps: 33
Episode: 14    steps: 113
Episode: 15    steps: 75
Episode: 16    steps: 58
Episode: 17    steps: 55
Episode: 18    steps: 62
Episode: 19    steps: 139
Episode: 20    steps: 67
Episode: 21    steps: 80
Loss:  2.28475
Episode: 22    steps: 14
Episode: 23    steps: 96
Episode: 24    steps: 42
Episode: 25    steps: 64
Episode: 26    steps: 82
Episode: 27    steps: 40
Episode: 28    steps: 15
Episode: 29    steps: 69
Episode: 30    steps: 70
Episode: 31    steps: 88
Loss:  4.02009
Episode: 32    steps: 101
Episode: 33    steps: 79
Episode: 34    steps: 85
Episode: 35    steps: 43
Episode: 36    steps: 103
Episode: 37    steps:

KeyboardInterrupt: 