In [1]:
import numpy as np
import gym

from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten
from keras.optimizers import Adam

from rl.agents.dqn import DQNAgent
from rl.policy import EpsGreedyQPolicy
from rl.memory import SequentialMemory

Using TensorFlow backend.


In [2]:
# Environment configuration for the experiment.
ENV_NAME = 'CartPole-v0'
SEED = 123

env = gym.make(ENV_NAME)

# Seed NumPy's global RNG and the environment with the same value so that
# repeated runs start from the same random state.
np.random.seed(SEED)
env.seed(SEED)

# Number of actions exposed by the environment's action space.
nb_actions = env.action_space.n

In [3]:
# Small feed-forward network: flattened input -> 16 ReLU units -> one linear
# output per action (nb_actions outputs in total).
model = Sequential()
# The input shape is the observation shape prefixed with a length-1 axis;
# Flatten collapses it into a single feature vector.
model.add(Flatten(input_shape=(1,) + env.observation_space.shape))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(nb_actions))
model.add(Activation('linear'))
# summary() prints the layer table itself and returns None, so it must not be
# wrapped in print() -- doing so emits a stray "None" line after the table.
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten_1 (Flatten)          (None, 4)                 0         
_________________________________________________________________
dense_1 (Dense)              (None, 16)                80        
_________________________________________________________________
activation_1 (Activation)    (None, 16)                0         
_________________________________________________________________
dense_2 (Dense)              (None, 2)                 34        
_________________________________________________________________
activation_2 (Activation)    (None, 2)                 0         
Total params: 114
Trainable params: 114
Non-trainable params: 0
_________________________________________________________________
None


In [4]:
# Build the DQN agent around `model`: an epsilon-greedy exploration policy and
# a sequential replay memory holding up to 50000 entries.
# NOTE(review): window_length=1 presumably corresponds to the leading (1,)
# axis of the model's input shape -- confirm against the keras-rl docs.
policy = EpsGreedyQPolicy()
memory = SequentialMemory(limit=50000, window_length=1)
dqn = DQNAgent(model=model, nb_actions=nb_actions, memory=memory, nb_steps_warmup=10,
               target_model_update=1e-2, policy=policy)
dqn.compile(Adam(lr=1e-3), metrics=['mae'])

# Okay, now it's time to learn something! We visualize the training here for
# show, but rendering slows training down quite a lot.
# The returned History is bound to a name so the training metrics stay
# available and its repr does not become the cell's output.
training_history = dqn.fit(env, nb_steps=5000, verbose=2, visualize=True)

Training for 5000 steps ...




   79/5000: episode: 1, duration: 7.909s, episode steps: 79, steps per second: 10, episode reward: 79.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.519 [0.000, 1.000], mean observation: 0.060 [-0.402, 0.722], loss: 0.427551, mean_absolute_error: 0.495523, mean_q: 0.052745
  113/5000: episode: 2, duration: 0.565s, episode steps: 34, steps per second: 60, episode reward: 34.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.529 [0.000, 1.000], mean observation: 0.151 [-0.159, 0.753], loss: 0.351142, mean_absolute_error: 0.445039, mean_q: 0.192280
  163/5000: episode: 3, duration: 0.832s, episode steps: 50, steps per second: 60, episode reward: 50.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.520 [0.000, 1.000], mean observation: 0.082 [-0.295, 0.778], loss: 0.313605, mean_absolute_error: 0.464687, mean_q: 0.319058
  197/5000: episode: 4, duration: 0.567s, episode steps: 34, steps per second: 60, episode reward: 34.000, mean reward: 1.000 [1.000, 1.000], mean action:

  695/5000: episode: 32, duration: 0.149s, episode steps: 9, steps per second: 60, episode reward: 9.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.111 [0.000, 1.000], mean observation: 0.130 [-1.418, 2.303], loss: 0.444906, mean_absolute_error: 2.360925, mean_q: 4.503819
  705/5000: episode: 33, duration: 0.166s, episode steps: 10, steps per second: 60, episode reward: 10.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.200 [0.000, 1.000], mean observation: 0.139 [-1.345, 2.096], loss: 0.545148, mean_absolute_error: 2.386504, mean_q: 4.583114
  714/5000: episode: 34, duration: 0.151s, episode steps: 9, steps per second: 60, episode reward: 9.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.111 [0.000, 1.000], mean observation: 0.151 [-1.516, 2.495], loss: 0.387792, mean_absolute_error: 2.420679, mean_q: 4.665392
  722/5000: episode: 35, duration: 0.141s, episode steps: 8, steps per second: 57, episode reward: 8.000, mean reward: 1.000 [1.000, 1.000], mean action: 0

  995/5000: episode: 62, duration: 0.200s, episode steps: 12, steps per second: 60, episode reward: 12.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.333 [0.000, 1.000], mean observation: 0.107 [-1.166, 1.775], loss: 0.637031, mean_absolute_error: 3.434582, mean_q: 6.499945
 1010/5000: episode: 63, duration: 0.250s, episode steps: 15, steps per second: 60, episode reward: 15.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.333 [0.000, 1.000], mean observation: 0.092 [-1.134, 1.835], loss: 0.963911, mean_absolute_error: 3.540878, mean_q: 6.652359
 1022/5000: episode: 64, duration: 0.199s, episode steps: 12, steps per second: 60, episode reward: 12.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.250 [0.000, 1.000], mean observation: 0.119 [-1.137, 1.980], loss: 0.554560, mean_absolute_error: 3.511672, mean_q: 6.700005
 1035/5000: episode: 65, duration: 0.216s, episode steps: 13, steps per second: 60, episode reward: 13.000, mean reward: 1.000 [1.000, 1.000], mean act

 1516/5000: episode: 91, duration: 1.082s, episode steps: 65, steps per second: 60, episode reward: 65.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.477 [0.000, 1.000], mean observation: -0.113 [-0.878, 0.425], loss: 1.006805, mean_absolute_error: 4.608014, mean_q: 8.677599
 1573/5000: episode: 92, duration: 0.949s, episode steps: 57, steps per second: 60, episode reward: 57.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.491 [0.000, 1.000], mean observation: -0.093 [-0.943, 0.452], loss: 1.267358, mean_absolute_error: 4.781363, mean_q: 8.922915
 1599/5000: episode: 93, duration: 0.435s, episode steps: 26, steps per second: 60, episode reward: 26.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.500 [0.000, 1.000], mean observation: -0.099 [-0.865, 0.390], loss: 1.179109, mean_absolute_error: 4.895144, mean_q: 9.212114
 1616/5000: episode: 94, duration: 0.283s, episode steps: 17, steps per second: 60, episode reward: 17.000, mean reward: 1.000 [1.000, 1.000], mean 

 2191/5000: episode: 120, duration: 0.314s, episode steps: 19, steps per second: 61, episode reward: 19.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.526 [0.000, 1.000], mean observation: -0.095 [-1.151, 0.458], loss: 2.999661, mean_absolute_error: 6.583381, mean_q: 12.291676
 2214/5000: episode: 121, duration: 0.383s, episode steps: 23, steps per second: 60, episode reward: 23.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.565 [0.000, 1.000], mean observation: -0.110 [-1.463, 0.573], loss: 2.673133, mean_absolute_error: 6.530562, mean_q: 12.196841
 2282/5000: episode: 122, duration: 1.131s, episode steps: 68, steps per second: 60, episode reward: 68.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.500 [0.000, 1.000], mean observation: 0.087 [-0.290, 1.172], loss: 2.261793, mean_absolute_error: 6.642526, mean_q: 12.531883
 2304/5000: episode: 123, duration: 0.368s, episode steps: 22, steps per second: 60, episode reward: 22.000, mean reward: 1.000 [1.000, 1.000],

 3418/5000: episode: 149, duration: 0.686s, episode steps: 41, steps per second: 60, episode reward: 41.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.488 [0.000, 1.000], mean observation: -0.112 [-0.870, 0.590], loss: 3.350235, mean_absolute_error: 8.724845, mean_q: 16.707369
 3466/5000: episode: 150, duration: 0.796s, episode steps: 48, steps per second: 60, episode reward: 48.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.500 [0.000, 1.000], mean observation: 0.049 [-0.414, 0.992], loss: 3.363184, mean_absolute_error: 8.742322, mean_q: 16.696054
 3517/5000: episode: 151, duration: 0.849s, episode steps: 51, steps per second: 60, episode reward: 51.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.510 [0.000, 1.000], mean observation: 0.089 [-0.274, 0.946], loss: 3.819530, mean_absolute_error: 8.809394, mean_q: 16.735388
 3565/5000: episode: 152, duration: 0.804s, episode steps: 48, steps per second: 60, episode reward: 48.000, mean reward: 1.000 [1.000, 1.000], 

 4937/5000: episode: 178, duration: 1.035s, episode steps: 62, steps per second: 60, episode reward: 62.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.532 [0.000, 1.000], mean observation: 0.105 [-0.267, 0.707], loss: 3.697470, mean_absolute_error: 11.146082, mean_q: 21.668655
 4997/5000: episode: 179, duration: 1.013s, episode steps: 60, steps per second: 59, episode reward: 60.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.500 [0.000, 1.000], mean observation: 0.130 [-0.268, 1.413], loss: 3.757988, mean_absolute_error: 11.263992, mean_q: 21.920809
done, took 90.059 seconds


<keras.callbacks.History at 0x205158d89b0>

In [5]:
# Evaluate the agent for 5 rendered episodes. The returned History is bound to
# a name so the results stay available and its repr does not become the
# cell's output.
test_history = dqn.test(env, nb_episodes=5, visualize=True)

Testing for 5 episodes ...
Episode 1: reward: 54.000, steps: 54
Episode 2: reward: 43.000, steps: 43
Episode 3: reward: 39.000, steps: 39
Episode 4: reward: 37.000, steps: 37
Episode 5: reward: 66.000, steps: 66


<keras.callbacks.History at 0x2051577bc88>