In [1]:
import numpy as np
import gym

In [2]:
from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten
from keras.optimizers import Adam

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [3]:
from rl.agents.dqn import DQNAgent
from rl.policy import BoltzmannQPolicy
from rl.memory import SequentialMemory

In [4]:
def train_cartpole_nnet(env_name='CartPole-v0', nb_steps=60000, seed=123,
                        nb_test_episodes=5, visualize_test=True):
    """Train a DQN agent on a gym environment and return the network weights.

    Builds a small fully connected Keras model (a Flatten layer, three Dense(16)
    layers each followed by ReLU, and a Dense output with one unit per action),
    trains it with keras-rl's DQNAgent using a Boltzmann policy, and then runs a
    few evaluation episodes.

    Called with no arguments, the function behaves as it did before these
    parameters were added.

    Args:
        env_name: environment id passed to ``gym.make``.
        nb_steps: number of training steps passed to ``dqn.fit``.
        seed: seed applied to numpy's global RNG and to the environment.
        nb_test_episodes: number of evaluation episodes run after training.
        visualize_test: whether ``dqn.test`` renders the environment. Pass
            False on machines without a display, where rendering may fail.

    Returns:
        The list produced by ``model.get_weights()``, captured after training
        and before the evaluation episodes.
    """
    gym.undo_logger_setup()

    # Get the environment and extract the number of actions.
    # ``action_space.n`` is read below, so the action space must expose ``n``.
    env = gym.make(env_name)

    np.random.seed(seed)
    env.seed(seed)
    nb_actions = env.action_space.n

    # Next, we build a very simple model. The leading (1,) in the input shape
    # pairs with window_length=1 of the replay memory configured below.
    model = Sequential()
    model.add(Flatten(input_shape=(1,) + env.observation_space.shape))
    for _ in range(3):
        model.add(Dense(16))
        model.add(Activation('relu'))
    model.add(Dense(nb_actions))

    # Configure and compile the agent. Any built-in Keras optimizer and metric can be used.
    memory = SequentialMemory(limit=60000, window_length=1)
    policy = BoltzmannQPolicy()
    dqn = DQNAgent(model=model, nb_actions=nb_actions, memory=memory, nb_steps_warmup=100,
                   target_model_update=1e-2, policy=policy)
    dqn.compile(Adam(lr=1e-3), metrics=['mae'])

    # Train. Visualization is disabled here because it slows training down considerably.
    # Training can be aborted prematurely with Ctrl + C.
    dqn.fit(env, nb_steps=nb_steps, visualize=False, verbose=2)

    # Grab the trained weights before evaluation so they are already in hand
    # when the evaluation episodes start.
    weights = model.get_weights()

    # Finally, evaluate the trained agent.
    dqn.test(env, nb_episodes=nb_test_episodes, visualize=visualize_test)
    return weights


In [5]:
def weights_to_txt(weights, f):
    """Write a 2-D weight matrix to ``f`` as text, transposed.

    An input of shape (n_rows, n_cols) produces n_cols lines. Line ``j`` holds
    the n_rows values ``weights[:, j]`` separated by ", " and terminated by a
    newline.

    Args:
        weights: two-dimensional array-like of values.
        f: object with a ``write(str)`` method, e.g. an open text file.

    Raises:
        ValueError: if ``weights`` is not two-dimensional.
    """
    matrix = np.asarray(weights)
    if matrix.ndim != 2:
        raise ValueError(
            "weights must be 2-dimensional, got %d dimension(s)" % matrix.ndim)

    # A true transpose is required here: reshaping to (n_cols, n_rows) would
    # keep the flat element order and scramble which value lands where.
    for row in matrix.T:
        f.write(", ".join(str(value) for value in row) + "\n")

In [6]:
def bias_to_txt(bias, f):
    """Write every entry of ``bias`` to ``f``, one value per line.

    Args:
        bias: sequence of values (e.g. a 1-D array of layer biases).
        f: object with a ``write(str)`` method, e.g. an open text file.
    """
    for entry in bias:
        line = str(entry) + "\n"
        f.write(line)

In [7]:
def cartpole_nnet_to_txt(path="cartpole_nnet.txt"):
    """Train the cartpole network and write it to a text file.

    The file is written so that its format is the same as that of small_nnet:
    a fixed header ("3", the layer sizes "4, 16, 16, 16, 2", then five "0"
    lines) followed by the entries of the trained weight list in order.
    Even-indexed entries are written with ``weights_to_txt`` and odd-indexed
    entries with ``bias_to_txt``.

    Args:
        path: destination file; it is created or truncated.

    Returns:
        The weight list returned by ``train_cartpole_nnet``.
    """
    layers = train_cartpole_nnet()
    # The context manager guarantees the file is flushed and closed even if
    # one of the writers raises part-way through.
    with open(path, "w+") as f:
        f.write("3\n")
        f.write("4, 16, 16, 16, 2\n")
        for _ in range(5):
            f.write("0\n")
        for i, layer in enumerate(layers):
            if i % 2 == 0:
                weights_to_txt(layer, f)
            else:
                bias_to_txt(layer, f)
    return layers

In [10]:
# Train the DQN agent and export the network to cartpole_nnet.txt.
# This is long-running: training uses 60000 steps by default.
weights = cartpole_nnet_to_txt()



[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
Training for 60000 steps ...
    31/60000: episode: 1, duration: 0.096s, episode steps: 31, steps per second: 322, episode reward: 31.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.419 [0.000, 1.000], mean observation: 0.015 [-1.187, 1.822], loss: --, mean_absolute_error: --, mean_q: --
    54/60000: episode: 2, duration: 0.020s, episode steps: 23, steps per second: 1144, episode reward: 23.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.391 [0.000, 1.000], mean observation: 0.099 [-0.940, 1.811], loss: --, mean_absolute_error: --, mean_q: --
    79/60000: episode: 3, duration: 0.023s, episode steps: 25, steps per second: 1069, episode reward: 25.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.480 [0.000, 1.000], mean observation: 0.069 [-0.545, 1.124], loss: --, mean_absolute_error: --, mean_q: --
    99/60000: episode: 4, duration: 0.018s, episode steps: 2

   626/60000: episode: 31, duration: 0.201s, episode steps: 38, steps per second: 189, episode reward: 38.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.500 [0.000, 1.000], mean observation: -0.158 [-1.354, 0.870], loss: 0.141783, mean_absolute_error: 2.356678, mean_q: 4.465603
   661/60000: episode: 32, duration: 0.159s, episode steps: 35, steps per second: 220, episode reward: 35.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.429 [0.000, 1.000], mean observation: -0.004 [-1.208, 1.627], loss: 0.182793, mean_absolute_error: 2.472494, mean_q: 4.678588
   679/60000: episode: 33, duration: 0.108s, episode steps: 18, steps per second: 166, episode reward: 18.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.611 [0.000, 1.000], mean observation: -0.074 [-1.769, 1.011], loss: 0.083976, mean_absolute_error: 2.596426, mean_q: 5.038244
   722/60000: episode: 34, duration: 0.190s, episode steps: 43, steps per second: 226, episode reward: 43.000, mean reward: 1.000 [1.000, 1

  2768/60000: episode: 60, duration: 0.801s, episode steps: 187, steps per second: 233, episode reward: 187.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.471 [0.000, 1.000], mean observation: -0.348 [-2.416, 1.255], loss: 0.952939, mean_absolute_error: 11.206593, mean_q: 22.859228
  2968/60000: episode: 61, duration: 0.857s, episode steps: 200, steps per second: 233, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.485 [0.000, 1.000], mean observation: -0.271 [-1.909, 1.013], loss: 1.110746, mean_absolute_error: 12.104518, mean_q: 24.673775
  3168/60000: episode: 62, duration: 0.842s, episode steps: 200, steps per second: 238, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.500 [0.000, 1.000], mean observation: 0.042 [-1.278, 1.197], loss: 1.375261, mean_absolute_error: 13.061201, mean_q: 26.563278
  3368/60000: episode: 63, duration: 0.871s, episode steps: 200, steps per second: 229, episode reward: 200.000, mean reward: 1.

  8428/60000: episode: 89, duration: 0.879s, episode steps: 200, steps per second: 228, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.485 [0.000, 1.000], mean observation: -0.150 [-1.342, 1.043], loss: 4.155929, mean_absolute_error: 30.800024, mean_q: 62.279736
  8628/60000: episode: 90, duration: 0.846s, episode steps: 200, steps per second: 236, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.480 [0.000, 1.000], mean observation: -0.292 [-2.035, 1.067], loss: 4.065501, mean_absolute_error: 31.096691, mean_q: 62.856640
  8828/60000: episode: 91, duration: 0.874s, episode steps: 200, steps per second: 229, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.485 [0.000, 1.000], mean observation: -0.143 [-1.258, 1.074], loss: 4.669473, mean_absolute_error: 31.580479, mean_q: 63.903099
  9028/60000: episode: 92, duration: 0.879s, episode steps: 200, steps per second: 228, episode reward: 200.000, mean reward: 1

 14047/60000: episode: 118, duration: 0.874s, episode steps: 200, steps per second: 229, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.480 [0.000, 1.000], mean observation: -0.200 [-1.714, 1.526], loss: 7.375447, mean_absolute_error: 38.528652, mean_q: 77.552094
 14247/60000: episode: 119, duration: 0.857s, episode steps: 200, steps per second: 234, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.485 [0.000, 1.000], mean observation: -0.153 [-1.675, 1.592], loss: 4.230103, mean_absolute_error: 38.705849, mean_q: 78.191727
 14447/60000: episode: 120, duration: 0.859s, episode steps: 200, steps per second: 233, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.525 [0.000, 1.000], mean observation: 0.283 [-1.198, 2.228], loss: 5.166572, mean_absolute_error: 38.793324, mean_q: 78.362175
 14616/60000: episode: 121, duration: 0.729s, episode steps: 169, steps per second: 232, episode reward: 169.000, mean reward

 19307/60000: episode: 147, duration: 0.845s, episode steps: 200, steps per second: 237, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.490 [0.000, 1.000], mean observation: -0.197 [-1.534, 1.238], loss: 8.354635, mean_absolute_error: 39.751591, mean_q: 80.115555
 19454/60000: episode: 148, duration: 0.623s, episode steps: 147, steps per second: 236, episode reward: 147.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.531 [0.000, 1.000], mean observation: 0.334 [-0.896, 2.022], loss: 5.478581, mean_absolute_error: 40.205608, mean_q: 81.094505
 19641/60000: episode: 149, duration: 0.801s, episode steps: 187, steps per second: 234, episode reward: 187.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.524 [0.000, 1.000], mean observation: 0.317 [-1.489, 2.380], loss: 5.708466, mean_absolute_error: 39.826935, mean_q: 80.330154
 19841/60000: episode: 150, duration: 0.885s, episode steps: 200, steps per second: 226, episode reward: 200.000, mean reward:

 24419/60000: episode: 176, duration: 0.944s, episode steps: 200, steps per second: 212, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.505 [0.000, 1.000], mean observation: -0.063 [-1.179, 1.358], loss: 5.261411, mean_absolute_error: 40.349491, mean_q: 81.347595
 24616/60000: episode: 177, duration: 1.057s, episode steps: 197, steps per second: 186, episode reward: 197.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.538 [0.000, 1.000], mean observation: 0.319 [-1.005, 2.774], loss: 3.008866, mean_absolute_error: 40.700863, mean_q: 82.152710
 24804/60000: episode: 178, duration: 1.243s, episode steps: 188, steps per second: 151, episode reward: 188.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.516 [0.000, 1.000], mean observation: 0.328 [-0.832, 2.356], loss: 5.785369, mean_absolute_error: 40.415112, mean_q: 81.433304
 24951/60000: episode: 179, duration: 0.692s, episode steps: 147, steps per second: 212, episode reward: 147.000, mean reward:

 29464/60000: episode: 205, duration: 1.003s, episode steps: 168, steps per second: 167, episode reward: 168.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.536 [0.000, 1.000], mean observation: 0.358 [-0.868, 2.283], loss: 3.420611, mean_absolute_error: 41.967278, mean_q: 84.785561
 29633/60000: episode: 206, duration: 1.244s, episode steps: 169, steps per second: 136, episode reward: 169.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.527 [0.000, 1.000], mean observation: 0.363 [-0.767, 2.286], loss: 4.884663, mean_absolute_error: 41.643204, mean_q: 84.280647
 29820/60000: episode: 207, duration: 0.996s, episode steps: 187, steps per second: 188, episode reward: 187.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.519 [0.000, 1.000], mean observation: 0.327 [-0.992, 2.282], loss: 3.025326, mean_absolute_error: 41.732155, mean_q: 84.553238
 29996/60000: episode: 208, duration: 0.965s, episode steps: 176, steps per second: 182, episode reward: 176.000, mean reward: 

 35043/60000: episode: 234, duration: 0.896s, episode steps: 196, steps per second: 219, episode reward: 196.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.520 [0.000, 1.000], mean observation: 0.280 [-0.909, 2.029], loss: 4.045689, mean_absolute_error: 45.252659, mean_q: 91.561195
 35243/60000: episode: 235, duration: 0.926s, episode steps: 200, steps per second: 216, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.515 [0.000, 1.000], mean observation: 0.101 [-0.948, 1.321], loss: 4.366701, mean_absolute_error: 44.964787, mean_q: 90.948341
 35443/60000: episode: 236, duration: 0.860s, episode steps: 200, steps per second: 232, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.515 [0.000, 1.000], mean observation: 0.113 [-0.942, 1.064], loss: 4.367002, mean_absolute_error: 44.945084, mean_q: 91.079857
 35643/60000: episode: 237, duration: 0.881s, episode steps: 200, steps per second: 227, episode reward: 200.000, mean reward: 

 40843/60000: episode: 263, duration: 0.858s, episode steps: 200, steps per second: 233, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.500 [0.000, 1.000], mean observation: -0.028 [-1.120, 1.050], loss: 13.727439, mean_absolute_error: 48.556984, mean_q: 98.146965
 41043/60000: episode: 264, duration: 0.853s, episode steps: 200, steps per second: 235, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.505 [0.000, 1.000], mean observation: -0.022 [-1.113, 0.823], loss: 8.249223, mean_absolute_error: 49.431446, mean_q: 99.936668
 41243/60000: episode: 265, duration: 0.865s, episode steps: 200, steps per second: 231, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.500 [0.000, 1.000], mean observation: -0.029 [-1.122, 1.289], loss: 8.964634, mean_absolute_error: 48.968945, mean_q: 99.051735
 41443/60000: episode: 266, duration: 0.848s, episode steps: 200, steps per second: 236, episode reward: 200.000, mean rewa

 46443/60000: episode: 291, duration: 0.857s, episode steps: 200, steps per second: 233, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.505 [0.000, 1.000], mean observation: -0.024 [-0.909, 0.941], loss: 17.694050, mean_absolute_error: 53.332245, mean_q: 107.383575
 46643/60000: episode: 292, duration: 1.117s, episode steps: 200, steps per second: 179, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.500 [0.000, 1.000], mean observation: -0.025 [-1.004, 0.804], loss: 9.503999, mean_absolute_error: 53.446495, mean_q: 107.792648
 46843/60000: episode: 293, duration: 1.240s, episode steps: 200, steps per second: 161, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.500 [0.000, 1.000], mean observation: -0.026 [-1.039, 0.849], loss: 9.999101, mean_absolute_error: 53.676289, mean_q: 108.291580
 47043/60000: episode: 294, duration: 1.175s, episode steps: 200, steps per second: 170, episode reward: 200.000, mean r

 52043/60000: episode: 319, duration: 0.865s, episode steps: 200, steps per second: 231, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.505 [0.000, 1.000], mean observation: 0.065 [-1.136, 1.222], loss: 21.559299, mean_absolute_error: 54.314842, mean_q: 108.723679
 52243/60000: episode: 320, duration: 0.858s, episode steps: 200, steps per second: 233, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.500 [0.000, 1.000], mean observation: 0.041 [-1.306, 1.338], loss: 25.911543, mean_absolute_error: 54.347725, mean_q: 108.699356
 52443/60000: episode: 321, duration: 0.851s, episode steps: 200, steps per second: 235, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.510 [0.000, 1.000], mean observation: 0.085 [-1.261, 1.184], loss: 20.164495, mean_absolute_error: 53.578178, mean_q: 107.162483
 52643/60000: episode: 322, duration: 0.870s, episode steps: 200, steps per second: 230, episode reward: 200.000, mean re

 57800/60000: episode: 348, duration: 0.901s, episode steps: 200, steps per second: 222, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.500 [0.000, 1.000], mean observation: 0.010 [-1.486, 1.456], loss: 19.606665, mean_absolute_error: 47.922676, mean_q: 95.683655
 58000/60000: episode: 349, duration: 0.884s, episode steps: 200, steps per second: 226, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.495 [0.000, 1.000], mean observation: 0.020 [-1.324, 1.186], loss: 9.468140, mean_absolute_error: 48.072128, mean_q: 96.391983
 58200/60000: episode: 350, duration: 0.882s, episode steps: 200, steps per second: 227, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.500 [0.000, 1.000], mean observation: 0.035 [-1.382, 1.339], loss: 11.498200, mean_absolute_error: 47.957520, mean_q: 96.162148
 58400/60000: episode: 351, duration: 0.863s, episode steps: 200, steps per second: 232, episode reward: 200.000, mean reward

NotImplementedError: abstract