In [1]:
import numpy as np
import gym

In [2]:
from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten
from keras.optimizers import Adam

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [3]:
from rl.agents.dqn import DQNAgent
from rl.policy import BoltzmannQPolicy
from rl.memory import SequentialMemory

In [4]:
def train_cartpole_nnet():
    #from test import CartPoleContEnv

    ENV_NAME = 'CartPole-v0'
    gym.undo_logger_setup()

    # Get the environment and extract the number of actions.
    env = gym.make(ENV_NAME)

    np.random.seed(123)
    env.seed(123)
    nb_actions = env.action_space.n

    # Next, we build a very simple model.
    model = Sequential()
    model.add(Flatten(input_shape=(1,) + env.observation_space.shape))
    model.add(Dense(16))
    model.add(Activation('relu'))
    model.add(Dense(16))
    model.add(Activation('relu'))
    model.add(Dense(16))
    model.add(Activation('relu'))
    model.add(Dense(nb_actions))

    # Finally, we configure and compile our agent. You can use every built-in Keras optimizer and
    # even the metrics!
    memory = SequentialMemory(limit=60000, window_length=1)
    policy = BoltzmannQPolicy()
    dqn = DQNAgent(model=model, nb_actions=nb_actions, memory=memory, nb_steps_warmup=100,
          target_model_update=1e-2, policy=policy)
    dqn.compile(Adam(lr=1e-3), metrics=['mae'])

    # Okay, now it's time to learn something! We visualize the training here for show, but this
    # slows down training quite a lot. You can always safely abort the training prematurely using
    # Ctrl + C.
    dqn.fit(env, nb_steps=60000, visualize=False, verbose=2)

    # get model weights
    weights = model.get_weights()

    # Finally, evaluate our algorithm for 5 episodes.
    dqn.test(env, nb_episodes=5, visualize=True)
    return weights


In [5]:
def weights_to_txt(weights, f):
    shape = weights.shape
    n_rows = shape[0]
    n_cols = shape[1]
    weights = np.reshape((n_cols, n_rows))
    for r in range(0,n_cols):
            for c in range(0,n_rows - 1):
                    f.write(str(weights[r,c])+", ")
            f.write(str(weights[r,c])+"\n")

In [6]:
def bias_to_txt(bias, f):
    for i in range(0,len(bias)):
        f.write(str(bias[i])+"\n")

In [8]:
'''
Write cartpole nnet to text so that its format is the same as that of small_nnet
'''
def cartpole_nnet_to_txt():
    layers = train_cartpole_nnet()
    f = open("cartpole_nnet.txt", "w+")
    f.write("4\n")
    f.write("4, 16, 16, 16, 2\n")
    for i in range(5):
        f.write("0\n")
    for i,layer in enumerate(layers):
        if i % 2 == 0:
            weights_to_txt(layer, f)
        else:
            bias_to_txt(layer, f)

In [9]:
weights = cartpole_nnet_to_txt()



[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
Training for 60000 steps ...
    31/60000: episode: 1, duration: 0.062s, episode steps: 31, steps per second: 502, episode reward: 31.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.419 [0.000, 1.000], mean observation: 0.015 [-1.187, 1.822], loss: --, mean_absolute_error: --, mean_q: --
    54/60000: episode: 2, duration: 0.018s, episode steps: 23, steps per second: 1282, episode reward: 23.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.391 [0.000, 1.000], mean observation: 0.099 [-0.940, 1.811], loss: --, mean_absolute_error: --, mean_q: --
    79/60000: episode: 3, duration: 0.020s, episode steps: 25, steps per second: 1234, episode reward: 25.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.480 [0.000, 1.000], mean observation: 0.069 [-0.545, 1.124], loss: --, mean_absolute_error: --, mean_q: --
    99/60000: episode: 4, duration: 0.017s, episode steps: 2

   677/60000: episode: 32, duration: 0.104s, episode steps: 17, steps per second: 164, episode reward: 17.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.647 [0.000, 1.000], mean observation: -0.090 [-1.678, 0.941], loss: 0.129312, mean_absolute_error: 2.513361, mean_q: 4.780386
   698/60000: episode: 33, duration: 0.090s, episode steps: 21, steps per second: 232, episode reward: 21.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.524 [0.000, 1.000], mean observation: -0.079 [-1.221, 0.805], loss: 0.123624, mean_absolute_error: 2.586859, mean_q: 4.962061
   719/60000: episode: 34, duration: 0.094s, episode steps: 21, steps per second: 222, episode reward: 21.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.429 [0.000, 1.000], mean observation: 0.092 [-0.746, 1.574], loss: 0.138616, mean_absolute_error: 2.665992, mean_q: 5.126945
   762/60000: episode: 35, duration: 0.238s, episode steps: 43, steps per second: 181, episode reward: 43.000, mean reward: 1.000 [1.000, 1.

  3642/60000: episode: 61, duration: 1.338s, episode steps: 200, steps per second: 149, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.495 [0.000, 1.000], mean observation: -0.017 [-1.258, 1.152], loss: 1.430736, mean_absolute_error: 15.365068, mean_q: 31.303343
  3798/60000: episode: 62, duration: 1.100s, episode steps: 156, steps per second: 142, episode reward: 156.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.462 [0.000, 1.000], mean observation: -0.383 [-2.407, 0.970], loss: 1.993280, mean_absolute_error: 16.192730, mean_q: 32.923141
  3978/60000: episode: 63, duration: 1.431s, episode steps: 180, steps per second: 126, episode reward: 180.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.467 [0.000, 1.000], mean observation: -0.355 [-2.420, 0.900], loss: 1.808962, mean_absolute_error: 16.795731, mean_q: 34.183064
  4147/60000: episode: 64, duration: 1.172s, episode steps: 169, steps per second: 144, episode reward: 169.000, mean reward: 1

  8819/60000: episode: 90, duration: 1.002s, episode steps: 169, steps per second: 169, episode reward: 169.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.462 [0.000, 1.000], mean observation: -0.386 [-2.531, 1.179], loss: 2.951592, mean_absolute_error: 30.318909, mean_q: 61.339157
  8992/60000: episode: 91, duration: 0.822s, episode steps: 173, steps per second: 211, episode reward: 173.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.462 [0.000, 1.000], mean observation: -0.365 [-2.535, 0.804], loss: 1.652319, mean_absolute_error: 30.429821, mean_q: 61.692703
  9192/60000: episode: 92, duration: 1.120s, episode steps: 200, steps per second: 179, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.470 [0.000, 1.000], mean observation: -0.320 [-2.309, 1.175], loss: 3.946754, mean_absolute_error: 30.645880, mean_q: 62.004505
  9359/60000: episode: 93, duration: 1.545s, episode steps: 167, steps per second: 108, episode reward: 167.000, mean reward: 1

 14257/60000: episode: 119, duration: 0.883s, episode steps: 200, steps per second: 226, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.475 [0.000, 1.000], mean observation: -0.303 [-2.119, 1.174], loss: 2.572878, mean_absolute_error: 35.663273, mean_q: 72.044075
 14446/60000: episode: 120, duration: 1.146s, episode steps: 189, steps per second: 165, episode reward: 189.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.471 [0.000, 1.000], mean observation: -0.342 [-2.314, 1.558], loss: 2.897006, mean_absolute_error: 35.887707, mean_q: 72.354416
 14639/60000: episode: 121, duration: 1.182s, episode steps: 193, steps per second: 163, episode reward: 193.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.477 [0.000, 1.000], mean observation: -0.348 [-2.433, 1.445], loss: 1.997161, mean_absolute_error: 35.568996, mean_q: 71.852516
 14839/60000: episode: 122, duration: 1.773s, episode steps: 200, steps per second: 113, episode reward: 200.000, mean rewar

 19871/60000: episode: 148, duration: 1.038s, episode steps: 200, steps per second: 193, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.500 [0.000, 1.000], mean observation: 0.029 [-1.199, 1.510], loss: 3.443650, mean_absolute_error: 38.714245, mean_q: 77.978409
 20052/60000: episode: 149, duration: 0.939s, episode steps: 181, steps per second: 193, episode reward: 181.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.541 [0.000, 1.000], mean observation: 0.348 [-1.436, 2.738], loss: 4.629624, mean_absolute_error: 38.132523, mean_q: 76.908417
 20252/60000: episode: 150, duration: 1.198s, episode steps: 200, steps per second: 167, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.505 [0.000, 1.000], mean observation: -0.031 [-1.159, 1.269], loss: 5.357780, mean_absolute_error: 38.177036, mean_q: 76.817322
 20424/60000: episode: 151, duration: 0.807s, episode steps: 172, steps per second: 213, episode reward: 172.000, mean reward:

 25086/60000: episode: 177, duration: 0.985s, episode steps: 167, steps per second: 169, episode reward: 167.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.539 [0.000, 1.000], mean observation: 0.362 [-1.075, 2.788], loss: 5.215618, mean_absolute_error: 38.766438, mean_q: 78.074104
 25286/60000: episode: 178, duration: 0.942s, episode steps: 200, steps per second: 212, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.500 [0.000, 1.000], mean observation: -0.019 [-1.442, 1.325], loss: 3.559634, mean_absolute_error: 38.894726, mean_q: 78.441055
 25466/60000: episode: 179, duration: 0.995s, episode steps: 180, steps per second: 181, episode reward: 180.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.539 [0.000, 1.000], mean observation: 0.347 [-0.881, 2.744], loss: 4.637564, mean_absolute_error: 39.207375, mean_q: 78.983917
 25652/60000: episode: 180, duration: 0.943s, episode steps: 186, steps per second: 197, episode reward: 186.000, mean reward:

 30532/60000: episode: 206, duration: 0.777s, episode steps: 188, steps per second: 242, episode reward: 188.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.516 [0.000, 1.000], mean observation: 0.208 [-1.603, 2.302], loss: 2.303527, mean_absolute_error: 40.188740, mean_q: 81.286888
 30732/60000: episode: 207, duration: 0.858s, episode steps: 200, steps per second: 233, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.510 [0.000, 1.000], mean observation: 0.202 [-1.372, 2.405], loss: 3.803193, mean_absolute_error: 39.853279, mean_q: 80.527756
 30932/60000: episode: 208, duration: 0.832s, episode steps: 200, steps per second: 240, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.500 [0.000, 1.000], mean observation: 0.011 [-1.163, 1.453], loss: 3.470919, mean_absolute_error: 40.239456, mean_q: 81.375923
 31132/60000: episode: 209, duration: 0.813s, episode steps: 200, steps per second: 246, episode reward: 200.000, mean reward: 

 36332/60000: episode: 235, duration: 0.838s, episode steps: 200, steps per second: 239, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.500 [0.000, 1.000], mean observation: 0.029 [-0.817, 0.799], loss: 8.699410, mean_absolute_error: 47.563148, mean_q: 96.235138
 36532/60000: episode: 236, duration: 1.119s, episode steps: 200, steps per second: 179, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.500 [0.000, 1.000], mean observation: 0.028 [-0.716, 0.667], loss: 10.286829, mean_absolute_error: 48.480637, mean_q: 98.071312
 36732/60000: episode: 237, duration: 1.270s, episode steps: 200, steps per second: 157, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.500 [0.000, 1.000], mean observation: 0.025 [-0.791, 0.754], loss: 13.554796, mean_absolute_error: 48.739403, mean_q: 98.399002
 36932/60000: episode: 238, duration: 0.829s, episode steps: 200, steps per second: 241, episode reward: 200.000, mean reward

 41932/60000: episode: 263, duration: 0.841s, episode steps: 200, steps per second: 238, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.500 [0.000, 1.000], mean observation: 0.084 [-1.132, 1.004], loss: 24.184238, mean_absolute_error: 55.820736, mean_q: 112.449219
 42132/60000: episode: 264, duration: 0.842s, episode steps: 200, steps per second: 237, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.495 [0.000, 1.000], mean observation: 0.071 [-0.949, 1.089], loss: 14.577908, mean_absolute_error: 56.100372, mean_q: 113.000725
 42332/60000: episode: 265, duration: 0.835s, episode steps: 200, steps per second: 239, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.500 [0.000, 1.000], mean observation: 0.086 [-0.827, 1.118], loss: 14.016582, mean_absolute_error: 56.432686, mean_q: 113.692406
 42532/60000: episode: 266, duration: 0.845s, episode steps: 200, steps per second: 237, episode reward: 200.000, mean re

 47169/60000: episode: 291, duration: 1.031s, episode steps: 172, steps per second: 167, episode reward: 172.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.512 [0.000, 1.000], mean observation: 0.310 [-1.336, 1.594], loss: 24.552502, mean_absolute_error: 58.331913, mean_q: 117.064224
 47353/60000: episode: 292, duration: 2.920s, episode steps: 184, steps per second: 63, episode reward: 184.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.505 [0.000, 1.000], mean observation: 0.197 [-0.928, 1.291], loss: 17.603958, mean_absolute_error: 57.882122, mean_q: 116.400749
 47486/60000: episode: 293, duration: 0.791s, episode steps: 133, steps per second: 168, episode reward: 133.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.511 [0.000, 1.000], mean observation: 0.261 [-1.229, 1.360], loss: 17.591984, mean_absolute_error: 58.563999, mean_q: 117.873932
 47669/60000: episode: 294, duration: 1.013s, episode steps: 183, steps per second: 181, episode reward: 183.000, mean rew

 51944/60000: episode: 319, duration: 0.883s, episode steps: 200, steps per second: 227, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.525 [0.000, 1.000], mean observation: 0.332 [-0.861, 2.315], loss: 12.244852, mean_absolute_error: 53.340523, mean_q: 106.882721
 52144/60000: episode: 320, duration: 0.894s, episode steps: 200, steps per second: 224, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.485 [0.000, 1.000], mean observation: -0.028 [-1.474, 1.430], loss: 21.545027, mean_absolute_error: 52.900059, mean_q: 105.667992
 52264/60000: episode: 321, duration: 0.615s, episode steps: 120, steps per second: 195, episode reward: 120.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.525 [0.000, 1.000], mean observation: 0.428 [-1.113, 1.910], loss: 7.633661, mean_absolute_error: 52.842766, mean_q: 106.182854
 52448/60000: episode: 322, duration: 1.151s, episode steps: 184, steps per second: 160, episode reward: 184.000, mean re

 56903/60000: episode: 347, duration: 0.596s, episode steps: 144, steps per second: 242, episode reward: 144.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.507 [0.000, 1.000], mean observation: 0.229 [-1.379, 1.369], loss: 9.912623, mean_absolute_error: 46.189026, mean_q: 92.241035
 57103/60000: episode: 348, duration: 0.837s, episode steps: 200, steps per second: 239, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.495 [0.000, 1.000], mean observation: -0.059 [-1.479, 1.578], loss: 15.702417, mean_absolute_error: 45.211514, mean_q: 90.110977
 57303/60000: episode: 349, duration: 0.835s, episode steps: 200, steps per second: 240, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.495 [0.000, 1.000], mean observation: -0.048 [-1.495, 1.664], loss: 12.863317, mean_absolute_error: 45.199547, mean_q: 90.285645
 57503/60000: episode: 350, duration: 0.835s, episode steps: 200, steps per second: 239, episode reward: 200.000, mean rewa

NotImplementedError: abstract