clean up a3c code - pep8
yrlu committed Apr 29, 2017
1 parent b6505e9 commit b4db54c
Showing 6 changed files with 231 additions and 191 deletions.
65 changes: 38 additions & 27 deletions A3C/ac_net.py
@@ -1,13 +1,15 @@
'''Actor-critic network class for a3c'''
import numpy as np
import tensorflow as tf

import tf_utils
import numpy as np


class AC_Net(object):
'''Actor-critic network class for a3c'''


def __init__(self, state_size, action_size, lr,
name, n_h1=400, n_h2=300, global_name='global'):
def __init__(self, state_size, action_size, lr,
name, n_h1=400, n_h2=300, global_name='global'):

self.state_size = state_size
self.action_size = action_size
@@ -16,13 +18,14 @@ def __init__(self, state_size, action_size, lr,
self.n_h2 = n_h2

self.optimizer = tf.train.AdamOptimizer(lr)
self.input_s, self.input_a, self.advantage, self.target_v, self.policy, self.value, self.action_est, self.model_variables = self._build_network(name)
self.input_s, self.input_a, self.advantage, self.target_v, self.policy, self.value, self.action_est, self.model_variables = self._build_network(
name)

# 0.5, 0.2, 1.0
self.value_loss = 0.5 * tf.reduce_sum(tf.square(self.target_v - tf.reshape(self.value,[-1])))
self.value_loss = 0.5 * tf.reduce_sum(tf.square(self.target_v - tf.reshape(self.value, [-1])))
self.entropy_loss = 1.0 * tf.reduce_sum(self.policy * tf.log(self.policy))
self.policy_loss = 1.0 * tf.reduce_sum(-tf.log(self.action_est) * self.advantage)
self.l2_loss = tf.add_n([tf.nn.l2_loss(v) for v in self.model_variables])
self.l2_loss = tf.add_n([tf.nn.l2_loss(v) for v in self.model_variables])
# self.loss = 0.5 * self.value_loss + self.policy_loss + 0.2 * self.entropy_loss
self.loss = self.value_loss + self.policy_loss + self.entropy_loss
self.gradients = tf.gradients(self.loss, self.model_variables)
@@ -31,44 +34,52 @@ def __init__(self, state_size, action_size, lr,
global_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, global_name)
self.apply_gradients = self.optimizer.apply_gradients(zip(self.gradients, global_variables))


def _build_network(self, name):
input_s = tf.placeholder(tf.float32, [None, self.state_size])
input_a = tf.placeholder(tf.int32, [None])
advantage = tf.placeholder(tf.float32, [None])
target_v = tf.placeholder(tf.float32, [None])

with tf.variable_scope(name):
layer_1 = tf_utils.fc(input_s, self.n_h1, scope="fc1", activation_fn=tf.nn.relu,
initializer=tf.contrib.layers.variance_scaling_initializer(mode="FAN_IN"))
layer_2 = tf_utils.fc(layer_1, self.n_h2, scope="fc2", activation_fn=tf.nn.relu,
initializer=tf.contrib.layers.variance_scaling_initializer(mode="FAN_IN"))
policy = tf_utils.fc(layer_2, self.action_size, activation_fn=tf.nn.softmax,
scope="policy", initializer=tf_utils.normalized_columns_initializer(0.01))
layer_1 = tf_utils.fc(
input_s,
self.n_h1,
scope="fc1",
activation_fn=tf.nn.relu,
initializer=tf.contrib.layers.variance_scaling_initializer(
mode="FAN_IN"))
layer_2 = tf_utils.fc(
layer_1,
self.n_h2,
scope="fc2",
activation_fn=tf.nn.relu,
initializer=tf.contrib.layers.variance_scaling_initializer(
mode="FAN_IN"))
policy = tf_utils.fc(
layer_2,
self.action_size,
activation_fn=tf.nn.softmax,
scope="policy",
initializer=tf_utils.normalized_columns_initializer(0.01))
value = tf_utils.fc(layer_2, 1, activation_fn=None,
scope="value", initializer=tf_utils.normalized_columns_initializer(1.0))
scope="value", initializer=tf_utils.normalized_columns_initializer(1.0))

action_mask = tf.one_hot(input_a, self.action_size, 1.0, 0.0)
action_est = tf.reduce_sum(policy * action_mask, 1)

model_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=name)
return input_s, input_a, advantage, target_v, policy, value, action_est, model_variables


def get_action(self, state, sess):
state = np.reshape(state,[-1, self.state_size])
pi = sess.run(self.policy, feed_dict={self.input_s: state})
return np.random.choice(range(self.action_size), p=pi[0])

state = np.reshape(state, [-1, self.state_size])
policy = sess.run(self.policy, feed_dict={self.input_s: state})
return np.random.choice(range(self.action_size), p=policy[0])

def predict_policy(self, state, sess):
state = np.reshape(state,[-1, self.state_size])
pi = sess.run(self.policy, feed_dict={self.input_s: state})
return pi[0]

state = np.reshape(state, [-1, self.state_size])
policy = sess.run(self.policy, feed_dict={self.input_s: state})
return policy[0]

def predict_value(self, state, sess):
state = np.reshape(state,[-1, self.state_size])
state = np.reshape(state, [-1, self.state_size])
return sess.run(self.value, feed_dict={self.input_s: state})
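
The hunks above keep AC_Net's public interface (get_action, predict_policy, predict_value) unchanged. A minimal usage sketch, not part of this commit, assuming TF 1.x and the gym reset/step API of that era; the CartPole sizes and learning rate are illustrative, not taken from the repository:

    # Sketch: drive AC_Net's policy and value heads directly, without the A3C workers.
    import gym
    import tensorflow as tf

    import ac_net

    env = gym.make('CartPole-v0')                    # 4-dim state, 2 discrete actions
    net = ac_net.AC_Net(state_size=4, action_size=2, lr=0.0001, name='global')

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        state = env.reset()
        done = False
        while not done:
            action = net.get_action(state, sess)     # sample from the softmax policy
            value = net.predict_value(state, sess)   # critic's estimate of V(state)
            state, reward, done, _ = env.step(action)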


84 changes: 46 additions & 38 deletions A3C/acrobot_a3c.py
@@ -1,62 +1,70 @@
import tensorflow as tf
import gym
'''Example of A3C running on Acrobot environment
'''
import argparse
import time
import threading
import tensorflow as tf
import gym
# import multiprocessing
import argparse

import ac_net
import worker

parser = argparse.ArgumentParser(description=None)
parser.add_argument('-d', '--device', default='cpu', type=str, help='choose device: cpu/gpu')
parser.add_argument('-e', '--episodes', default=500, type=int, help='number of episodes')
parser.add_argument('-w', '--workers', default=4, type=int, help='number of workers')
parser.add_argument('-l', '--log_dir', default='acrobot_logs', type=str, help='log directory')
args = parser.parse_args()
print(args)

PARSER = argparse.ArgumentParser(description=None)
PARSER.add_argument('-d', '--device', default='cpu', type=str, help='choose device: cpu/gpu')
PARSER.add_argument('-e', '--episodes', default=500, type=int, help='number of episodes')
PARSER.add_argument('-w', '--workers', default=4, type=int, help='number of workers')
PARSER.add_argument('-l', '--log_dir', default='acrobot_logs', type=str, help='log directory')
ARGS = PARSER.parse_args()
print ARGS

DEVICE = args.device
DEVICE = ARGS.device
STATE_SIZE = 6
ACTION_SIZE = 3
LEARNING_RATE = 0.0001
GAMMA = 0.99
T_MAX = 5
# NUM_WORKERS = multiprocessing.cpu_count()
NUM_WORKERS = args.workers
NUM_EPISODES = args.episodes
LOG_DIR = args.log_dir
NUM_WORKERS = ARGS.workers
NUM_EPISODES = ARGS.episodes
LOG_DIR = ARGS.log_dir

N_H1 = 300
N_H2 = 300

tf.reset_default_graph()

history = []

with tf.device('/{}:0'.format(DEVICE)):
sess = tf.Session()
global_model = ac_net.AC_Net(STATE_SIZE, ACTION_SIZE, LEARNING_RATE, 'global', n_h1=N_H1, n_h2=N_H2)
workers = []
for i in xrange(NUM_WORKERS):
env = gym.make('Acrobot-v1')
env._max_episode_steps = 3000
workers.append(worker.Worker(env,
state_size=STATE_SIZE, action_size=ACTION_SIZE,
worker_name='worker_{}'.format(i), global_name='global',
lr=LEARNING_RATE, gamma=GAMMA, t_max=T_MAX, sess=sess,
history=history, n_h1=N_H1, n_h2=N_H2, logdir=LOG_DIR))

sess.run(tf.global_variables_initializer())

for worker in workers:
worker_work = lambda: worker.work(NUM_EPISODES)
t = threading.Thread(target=worker_work)
t.start()
def main():
'''Example of A3C running on Acrobot environment'''
tf.reset_default_graph()

history = []

with tf.device('/{}:0'.format(DEVICE)):
sess = tf.Session()
global_model = ac_net.AC_Net(
STATE_SIZE,
ACTION_SIZE,
LEARNING_RATE,
'global',
n_h1=N_H1,
n_h2=N_H2)
workers = []
for i in xrange(NUM_WORKERS):
env = gym.make('Acrobot-v1')
env._max_episode_steps = 3000
workers.append(worker.Worker(env,
state_size=STATE_SIZE, action_size=ACTION_SIZE,
worker_name='worker_{}'.format(i), global_name='global',
lr=LEARNING_RATE, gamma=GAMMA, t_max=T_MAX, sess=sess,
history=history, n_h1=N_H1, n_h2=N_H2, logdir=LOG_DIR))

sess.run(tf.global_variables_initializer())

for workeri in workers:
worker_work = lambda: workeri.work(NUM_EPISODES)
thread = threading.Thread(target=worker_work)
thread.start()



if __name__ == "__main__":
main()
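
A note on the thread launch kept by this cleanup: worker_work = lambda: workeri.work(NUM_EPISODES) closes over the loop variable, so a thread that is scheduled late can look up workeri after the loop has moved on and end up calling a different Worker. A hedged alternative sketch that binds each worker at thread-creation time (the threads list is introduced here only for illustration):

    # Sketch (not in this commit): bind each thread to its own Worker instance.
    threads = []
    for workeri in workers:
        thread = threading.Thread(target=workeri.work, args=(NUM_EPISODES,))
        thread.start()
        threads.append(thread)

The same launch pattern appears in cartpole_a3c.py below.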
103 changes: 51 additions & 52 deletions A3C/cartpole_a3c.py
@@ -1,72 +1,71 @@
import tensorflow as tf
import gym
'''Example of A3C running on Cartpole environment'''
import argparse
import time
import threading
import tensorflow as tf
import gym
# import multiprocessing
import argparse

import ac_net
import worker

parser = argparse.ArgumentParser(description=None)
parser.add_argument('-d', '--device', default='cpu', type=str, help='choose device: cpu/gpu')
parser.add_argument('-e', '--episodes', default=1000, type=int, help='number of episodes')
parser.add_argument('-w', '--workers', default=4, type=int, help='number of workers')
parser.add_argument('-l', '--log_dir', default='cartpole_logs', type=str, help='log directory')
args = parser.parse_args()
print(args)
PARSER = argparse.ArgumentParser(description=None)
PARSER.add_argument('-d', '--device', default='cpu', type=str, help='choose device: cpu/gpu')
PARSER.add_argument('-e', '--episodes', default=1000, type=int, help='number of episodes')
PARSER.add_argument('-w', '--workers', default=4, type=int, help='number of workers')
PARSER.add_argument('-l', '--log_dir', default='cartpole_logs', type=str, help='log directory')
ARGS = PARSER.parse_args()
print ARGS



DEVICE = args.device
DEVICE = ARGS.device
STATE_SIZE = 4
ACTION_SIZE = 2
LEARNING_RATE = 0.0001
GAMMA = 0.99
T_MAX = 5
# NUM_WORKERS = multiprocessing.cpu_count()
NUM_WORKERS = args.workers
NUM_EPISODES = args.episodes
LOG_DIR = args.log_dir
NUM_WORKERS = ARGS.workers
NUM_EPISODES = ARGS.episodes
LOG_DIR = ARGS.log_dir


N_H1 = 300
N_H2 = 300

tf.reset_default_graph()

history = []

with tf.device('/{}:0'.format(DEVICE)):
sess = tf.Session()
global_model = ac_net.AC_Net(STATE_SIZE, ACTION_SIZE, LEARNING_RATE, 'global', n_h1=N_H1, n_h2=N_H2)
workers = []
for i in xrange(NUM_WORKERS):
env = gym.make('CartPole-v0')
env._max_episode_steps = 200
workers.append(worker.Worker(env,
state_size=STATE_SIZE, action_size=ACTION_SIZE,
worker_name='worker_{}'.format(i), global_name='global',
lr=LEARNING_RATE, gamma=GAMMA, t_max=T_MAX, sess=sess,
history=history, n_h1=N_H1, n_h2=N_H2, logdir=LOG_DIR))

sess.run(tf.global_variables_initializer())

for worker in workers:
worker_work = lambda: worker.work(NUM_EPISODES)
t = threading.Thread(target=worker_work)
t.start()















def main():
'''Example of A3C running on Cartpole environment'''
tf.reset_default_graph()

history = []

with tf.device('/{}:0'.format(DEVICE)):
sess = tf.Session()
global_model = ac_net.AC_Net(
STATE_SIZE,
ACTION_SIZE,
LEARNING_RATE,
'global',
n_h1=N_H1,
n_h2=N_H2)
workers = []
for i in xrange(NUM_WORKERS):
env = gym.make('CartPole-v0')
env._max_episode_steps = 200
workers.append(worker.Worker(env,
state_size=STATE_SIZE, action_size=ACTION_SIZE,
worker_name='worker_{}'.format(i), global_name='global',
lr=LEARNING_RATE, gamma=GAMMA, t_max=T_MAX, sess=sess,
history=history, n_h1=N_H1, n_h2=N_H2, logdir=LOG_DIR))

sess.run(tf.global_variables_initializer())

for workeri in workers:
worker_work = lambda: workeri.work(NUM_EPISODES)
thread = threading.Thread(target=worker_work)
thread.start()


if __name__ == "__main__":
main()
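
As in acrobot_a3c.py, main() starts the worker threads and returns without joining them; the non-daemon threads keep the process alive, but nothing waits on them or reads the shared history list that every Worker receives. A hypothetical follow-up sketch, assuming the workers append their per-episode results to history (worker.py is not shown in this commit):

    # Hypothetical sketch: wait for the workers, then inspect the shared history list.
    threads = []
    for workeri in workers:
        thread = threading.Thread(target=workeri.work, args=(NUM_EPISODES,))
        thread.start()
        threads.append(thread)

    for thread in threads:
        thread.join()

    print 'entries collected in history: {}'.format(len(history))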
