
Commit

work for tl2 tf2
quantumiracle committed May 16, 2019
1 parent f88f23e commit a6652b0
Showing 5 changed files with 192 additions and 272 deletions.
38 changes: 6 additions & 32 deletions examples/reinforcement_learning/tutorial_atari_pong.py
@@ -34,11 +34,7 @@
import gym
import tensorlayer as tl

-## enable eager mode
-tf.enable_eager_execution()


-tf.logging.set_verbosity(tf.logging.DEBUG) # enable logging
tl.logging.set_verbosity(tl.logging.DEBUG)

# hyper-parameters
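The lines removed above are unnecessary once the tutorial targets TensorFlow 2.x: eager execution is the default and tf.logging no longer exists. A minimal sketch of the equivalent setup, assuming TF 2.x and TensorLayer 2.x are installed:

    import tensorflow as tf
    import tensorlayer as tl

    # TF 2.x runs eagerly by default, so tf.enable_eager_execution() is gone.
    assert tf.executing_eagerly()

    # tf.logging was removed; TF 2.x exposes a standard Python logger instead.
    tf.get_logger().setLevel('DEBUG')
    tl.logging.set_verbosity(tl.logging.DEBUG)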
@@ -52,7 +48,7 @@
render = False # display the game environment
# resume = True # load existing policy network
model_file_name = "model_pong"
-np.set_printoptions(threshold=np.nan)
+np.set_printoptions(threshold=np.inf)


def prepro(I):
@@ -73,10 +69,7 @@ def prepro(I):
episode_number = 0

xs, ys, rs = [], [], []
-# observation for training and inference
-# t_states = tf.placeholder(tf.float32, shape=[None, D])
-# policy network

def get_model(inputs_shape):
ni = tl.layers.Input(inputs_shape)
nn = tl.layers.Dense(n_units=H, act=tf.nn.relu, name='hidden')(ni)
@@ -85,22 +78,9 @@ def get_model(inputs_shape):
return M
model = get_model([None, D])
train_weights = model.trainable_weights
-# probs = model(t_states, is_train=True).outputs
-# sampling_prob = tf.nn.softmax(probs)

-# t_actions = tf.placeholder(tf.int32, shape=[None])
-# t_discount_rewards = tf.placeholder(tf.float32, shape=[None])
-# loss = tl.rein.cross_entropy_reward_loss(probs, t_actions, t_discount_rewards)
-optimizer = tf.train.RMSPropOptimizer(learning_rate, decay_rate)#.minimize(loss)

-# with tf.Session() as sess:
-# sess.run(tf.global_variables_initializer())
-# if resume: TODO
-# load_params = tl.files.load_npz(name=model_file_name+'.npz')
-# tl.files.assign_params(sess, load_params, network)
-# tl.files.load_and_assign_npz(sess, model_file_name + '.npz', network)
-# network.print_params()
-# network.print_layers()

+optimizer = tf.optimizers.RMSprop(lr=learning_rate, decay=decay_rate)

+model.train() # set model to train mode (in case you add dropout into the model)

start_time = time.time()
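The optimizer moves from the TF1 tf.train API to the Keras-style tf.optimizers API, and there is no longer a graph-mode minimize(loss) call; gradients are applied explicitly later in the script. One subtlety: in tf.train.RMSPropOptimizer the second positional argument (decay) is RMSProp's gradient-history coefficient, whereas tf.optimizers.RMSprop calls that coefficient rho and uses decay for learning-rate decay. A sketch of the closer equivalent, assuming decay_rate was meant as the history coefficient and using the tutorial's usual hyper-parameter values:

    import tensorflow as tf

    learning_rate = 1e-4   # assumed from the tutorial's hyper-parameter block
    decay_rate = 0.99      # assumed from the tutorial's hyper-parameter block

    # TF1: optimizer = tf.train.RMSPropOptimizer(learning_rate, decay_rate)
    # TF2 equivalent; rho (not decay) is the gradient-history coefficient:
    optimizer = tf.optimizers.RMSprop(learning_rate=learning_rate, rho=decay_rate)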
@@ -114,14 +94,12 @@ def get_model(inputs_shape):
x = x.reshape(1, D)
prev_x = cur_x

-# prob = sess.run(sampling_prob, feed_dict={t_states: x})
-_prob = model(x).outputs
+_prob = model(x)
prob = tf.nn.softmax(_prob)

# action. 1: STOP 2: UP 3: DOWN
# action = np.random.choice([1,2,3], p=prob.flatten())
# action = tl.rein.choice_action_by_probs(prob.flatten(), [1, 2, 3])
# action = np.random.choice([1,2,3], p=prob.numpy())
action = tl.rein.choice_action_by_probs(prob[0].numpy(), [1, 2, 3])

observation, reward, done, _ = env.step(action)
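In TensorLayer 2 a model call returns the output tensor directly, so the old .outputs attribute disappears from the sampling step. A self-contained sketch of the forward pass and action sampling; the hidden size, the 3-unit output layer and the dummy input are assumptions standing in for the collapsed parts of the file:

    import numpy as np
    import tensorflow as tf
    import tensorlayer as tl

    D = 80 * 80  # flattened 80x80 difference frame, as in the tutorial
    H = 200      # hidden units, assumed from the tutorial's hyper-parameters

    ni = tl.layers.Input([None, D])
    nn = tl.layers.Dense(n_units=H, act=tf.nn.relu, name='hidden')(ni)
    nn = tl.layers.Dense(n_units=3, name='output')(nn)  # assumed output layer
    model = tl.models.Model(inputs=ni, outputs=nn)
    model.eval()  # inference mode for this sketch

    x = np.zeros((1, D), dtype=np.float32)   # stand-in for a preprocessed frame
    prob = tf.nn.softmax(model(x))           # model(x) is the tensor itself, no .outputs
    action = np.random.choice([1, 2, 3], p=prob[0].numpy())  # 1: STOP, 2: UP, 3: DOWN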
@@ -145,12 +123,8 @@ def get_model(inputs_shape):

xs, ys, rs = [], [], []

-# sess.run(train_op, feed_dict={t_states: epx, t_actions: epy, t_discount_rewards: disR})
-# t_actions = tf.placeholder(tf.int32, shape=[None])
-# t_discount_rewards = tf.placeholder(tf.float32, shape=[None])
-# loss = tl.rein.cross_entropy_reward_loss(probs, t_actions, t_discount_rewards)
with tf.GradientTape() as tape:
-    _prob = model(epx).outputs
+    _prob = model(epx)
    _loss = tl.rein.cross_entropy_reward_loss(_prob, epy, disR)
grad = tape.gradient(_loss, train_weights)
optimizer.apply_gradients(zip(grad, train_weights))
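The update itself follows the standard eager pattern: run the forward pass and the loss under tf.GradientTape, then apply the gradients explicitly. A compact sketch of one policy-gradient update; epx, epy, rs, gamma, model and optimizer are assumed to be the stacked episode observations, actions and rewards, the discount factor, and the objects defined earlier in the tutorial:

    import numpy as np
    import tensorflow as tf
    import tensorlayer as tl

    # Discount and standardize episode rewards (done elsewhere in the tutorial).
    disR = tl.rein.discount_episode_rewards(np.asarray(rs, dtype=np.float32), gamma)
    disR -= np.mean(disR)
    disR /= np.std(disR)

    with tf.GradientTape() as tape:
        _prob = model(epx)                                           # logits; no .outputs in TL2
        _loss = tl.rein.cross_entropy_reward_loss(_prob, epy, disR)  # policy-gradient loss
    grad = tape.gradient(_loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grad, model.trainable_weights))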
