Add a SAC-specific tanh-squashed normal projection network, change critic network initialization to glorot_uniform, change SAC default target entropy, and update other hyper-parameters for reproducing the published results.

PiperOrigin-RevId: 293246271
Change-Id: Id38b12f1f5fca15f16b7f034cf6c80b3bde58dfc
kuanghuei authored and copybara-github committed Feb 4, 2020
1 parent fa43695 commit 9057dd6
Showing 8 changed files with 304 additions and 98 deletions.
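
For reference, the core switch this commit makes in the SAC examples is from the old local normal_projection_net helper to the new TanhNormalProjectionNetwork. The sketch below is not part of the commit; it only illustrates the new actor construction, and the observation/action specs are made-up placeholders (real runs take them from the environment).

import tensorflow as tf

from tf_agents.agents.sac import tanh_normal_projection_network
from tf_agents.networks import actor_distribution_network
from tf_agents.specs import tensor_spec

# Placeholder specs (assumption); the example scripts read these from the env.
observation_spec = tf.TensorSpec([17], tf.float32, name='observation')
action_spec = tensor_spec.BoundedTensorSpec(
    [6], tf.float32, minimum=-1.0, maximum=1.0, name='action')

# Actor with the tanh-squashed normal projection head added by this commit.
actor_net = actor_distribution_network.ActorDistributionNetwork(
    observation_spec,
    action_spec,
    fc_layer_params=(256, 256),
    continuous_projection_net=(
        tanh_normal_projection_network.TanhNormalProjectionNetwork))
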
25 changes: 17 additions & 8 deletions tf_agents/agents/ddpg/critic_network.py
@@ -37,6 +37,8 @@ def __init__(self,
joint_dropout_layer_params=None,
activation_fn=tf.nn.relu,
output_activation_fn=None,
kernel_initializer=None,
last_kernel_initializer=None,
name='CriticNetwork'):
"""Creates an instance of `CriticNetwork`.
@@ -85,6 +87,10 @@ def __init__(self,
used to restrict the range of the output. For example, one can pass
tf.keras.activations.sigmoid here to restrict the output to be bounded
between 0 and 1.
kernel_initializer: kernel initializer for all layers except for the value
regression layer. If None, a VarianceScaling initializer will be used.
last_kernel_initializer: kernel initializer for the value regression
layer. If None, a RandomUniform initializer will be used.
name: A string representing name of the network.
Raises:
@@ -106,40 +112,43 @@ def __init__(self,
raise ValueError('Only a single action is supported by this network')
self._single_action_spec = flat_action_spec[0]

if kernel_initializer is None:
kernel_initializer = tf.compat.v1.keras.initializers.VarianceScaling(
scale=1. / 3., mode='fan_in', distribution='uniform')
if last_kernel_initializer is None:
last_kernel_initializer = tf.keras.initializers.RandomUniform(
minval=-0.003, maxval=0.003)

# TODO(kbanoop): Replace mlp_layers with encoding networks.
self._observation_layers = utils.mlp_layers(
observation_conv_layer_params,
observation_fc_layer_params,
observation_dropout_layer_params,
activation_fn=activation_fn,
kernel_initializer=tf.compat.v1.keras.initializers.VarianceScaling(
scale=1. / 3., mode='fan_in', distribution='uniform'),
kernel_initializer=kernel_initializer,
name='observation_encoding')

self._action_layers = utils.mlp_layers(
None,
action_fc_layer_params,
action_dropout_layer_params,
activation_fn=activation_fn,
kernel_initializer=tf.compat.v1.keras.initializers.VarianceScaling(
scale=1. / 3., mode='fan_in', distribution='uniform'),
kernel_initializer=kernel_initializer,
name='action_encoding')

self._joint_layers = utils.mlp_layers(
None,
joint_fc_layer_params,
joint_dropout_layer_params,
activation_fn=activation_fn,
kernel_initializer=tf.compat.v1.keras.initializers.VarianceScaling(
scale=1. / 3., mode='fan_in', distribution='uniform'),
kernel_initializer=kernel_initializer,
name='joint_mlp')

self._joint_layers.append(
tf.keras.layers.Dense(
1,
activation=output_activation_fn,
kernel_initializer=tf.keras.initializers.RandomUniform(
minval=-0.003, maxval=0.003),
kernel_initializer=last_kernel_initializer,
name='value'))

def call(self, inputs, step_type=(), network_state=(), training=False):
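
The two new constructor arguments above (the same pair is added to CriticRnnNetwork below) let callers override the critic initializers; the SAC examples in this commit pass 'glorot_uniform' for both, while leaving them as None keeps the previous VarianceScaling/RandomUniform behavior. A minimal sketch, not from the commit, with placeholder specs:

import tensorflow as tf

from tf_agents.agents.ddpg import critic_network
from tf_agents.specs import tensor_spec

# Placeholder specs (assumption); real scripts read these from the env.
observation_spec = tf.TensorSpec([17], tf.float32, name='observation')
action_spec = tensor_spec.BoundedTensorSpec(
    [6], tf.float32, minimum=-1.0, maximum=1.0, name='action')

# Glorot-uniform initialization for every layer, as the SAC examples now do.
critic_net = critic_network.CriticNetwork(
    (observation_spec, action_spec),
    observation_fc_layer_params=None,
    action_fc_layer_params=None,
    joint_fc_layer_params=(256, 256),
    kernel_initializer='glorot_uniform',
    last_kernel_initializer='glorot_uniform')
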
25 changes: 17 additions & 8 deletions tf_agents/agents/ddpg/critic_rnn_network.py
@@ -38,6 +38,8 @@ def __init__(self,
lstm_size=(40,),
output_fc_layer_params=(200, 100),
activation_fn=tf.keras.activations.relu,
kernel_initializer=None,
last_kernel_initializer=None,
name='CriticRnnNetwork'):
"""Creates an instance of `CriticRnnNetwork`.
@@ -61,6 +63,10 @@ def __init__(self,
each item is the number of units in the layer. This is applied after the
LSTM cell.
activation_fn: Activation function, e.g. tf.nn.relu, slim.leaky_relu, ...
kernel_initializer: kernel initializer for all layers except for the value
regression layer. If None, a VarianceScaling initializer will be used.
last_kernel_initializer: kernel initializer for the value regression
layer. If None, a RandomUniform initializer will be used.
name: A string representing name of the network.
Returns:
@@ -79,28 +85,32 @@ def __init__(self,
if len(tf.nest.flatten(action_spec)) > 1:
raise ValueError('Only a single action is supported by this network.')

if kernel_initializer is None:
kernel_initializer = tf.compat.v1.keras.initializers.VarianceScaling(
scale=1. / 3., mode='fan_in', distribution='uniform')
if last_kernel_initializer is None:
last_kernel_initializer = tf.keras.initializers.RandomUniform(
minval=-0.003, maxval=0.003)

observation_layers = utils.mlp_layers(
observation_conv_layer_params,
observation_fc_layer_params,
activation_fn=activation_fn,
kernel_initializer=tf.compat.v1.keras.initializers.VarianceScaling(
scale=1. / 3., mode='fan_in', distribution='uniform'),
kernel_initializer=kernel_initializer,
name='observation_encoding')

action_layers = utils.mlp_layers(
None,
action_fc_layer_params,
activation_fn=activation_fn,
kernel_initializer=tf.compat.v1.keras.initializers.VarianceScaling(
scale=1. / 3., mode='fan_in', distribution='uniform'),
kernel_initializer=kernel_initializer,
name='action_encoding')

joint_layers = utils.mlp_layers(
None,
joint_fc_layer_params,
activation_fn=activation_fn,
kernel_initializer=tf.compat.v1.keras.initializers.VarianceScaling(
scale=1. / 3., mode='fan_in', distribution='uniform'),
kernel_initializer=kernel_initializer,
name='joint_mlp')

# Create RNN cell
@@ -126,8 +136,7 @@ def create_spec(size):
tf.keras.layers.Dense(
1,
activation=None,
kernel_initializer=tf.keras.initializers.RandomUniform(
minval=-0.003, maxval=0.003),
kernel_initializer=last_kernel_initializer,
name='value'))

super(CriticRnnNetwork, self).__init__(
40 changes: 18 additions & 22 deletions tf_agents/agents/sac/examples/v1/train_eval.py
@@ -44,6 +44,7 @@

from tf_agents.agents.ddpg import critic_network
from tf_agents.agents.sac import sac_agent
from tf_agents.agents.sac import tanh_normal_projection_network
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_mujoco
from tf_agents.environments import tf_py_environment
@@ -52,7 +53,6 @@
from tf_agents.metrics import tf_metrics
from tf_agents.metrics import tf_py_metric
from tf_agents.networks import actor_distribution_network
from tf_agents.networks import normal_projection_network
from tf_agents.policies import greedy_policy
from tf_agents.policies import py_tf_policy
from tf_agents.policies import random_tf_policy
@@ -69,32 +69,26 @@
FLAGS = flags.FLAGS


@gin.configurable
def normal_projection_net(action_spec,
init_action_stddev=0.35,
init_means_output_factor=0.1):
del init_action_stddev
return normal_projection_network.NormalProjectionNetwork(
action_spec,
mean_transform=None,
state_dependent_std=True,
init_means_output_factor=init_means_output_factor,
std_transform=sac_agent.std_clip_transform,
scale_distribution=True)


@gin.configurable
def train_eval(
root_dir,
env_name='HalfCheetah-v2',
eval_env_name=None,
env_load_fn=suite_mujoco.load,
num_iterations=1000000,
# The SAC paper reported:
# Hopper and Cartpole results up to 1000000 iters,
# Humanoid results up to 10000000 iters,
# Other mujoco tasks up to 3000000 iters.
num_iterations=3000000,
actor_fc_layers=(256, 256),
critic_obs_fc_layers=None,
critic_action_fc_layers=None,
critic_joint_fc_layers=(256, 256),
# Params for collect
# Follow https://github.com/haarnoja/sac/blob/master/examples/variants.py
# HalfCheetah and Ant take 10000 initial collection steps.
# Other mujoco tasks take 1000.
# Different choices roughly keep the initial episodes about the same.
initial_collect_steps=10000,
collect_steps_per_iteration=1,
replay_buffer_capacity=1000000,
@@ -109,7 +103,7 @@ def train_eval(
alpha_learning_rate=3e-4,
td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
gamma=0.99,
reward_scale_factor=1.0,
reward_scale_factor=0.1,
gradient_clipping=None,
# Params for eval
num_eval_episodes=30,
@@ -159,12 +153,15 @@ def train_eval(
observation_spec,
action_spec,
fc_layer_params=actor_fc_layers,
continuous_projection_net=normal_projection_net)
continuous_projection_net=tanh_normal_projection_network
.TanhNormalProjectionNetwork)
critic_net = critic_network.CriticNetwork(
(observation_spec, action_spec),
observation_fc_layer_params=critic_obs_fc_layers,
action_fc_layer_params=critic_action_fc_layers,
joint_fc_layer_params=critic_joint_fc_layers)
joint_fc_layer_params=critic_joint_fc_layers,
kernel_initializer='glorot_uniform',
last_kernel_initializer='glorot_uniform')

tf_agent = sac_agent.SacAgent(
time_step_spec,
@@ -224,10 +221,9 @@ def train_eval(
def _filter_invalid_transition(trajectories, unused_arg1):
return ~trajectories.is_boundary()[0]
dataset = replay_buffer.as_dataset(
sample_batch_size=5 * batch_size,
sample_batch_size=batch_size,
num_steps=2).unbatch().filter(
_filter_invalid_transition).batch(batch_size).prefetch(
batch_size * 5)
_filter_invalid_transition).batch(batch_size).prefetch(5)
dataset_iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
trajectories, unused_info = dataset_iterator.get_next()
train_op = tf_agent.train(trajectories)
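
The pipeline above now samples batch_size two-step windows, unbatches them, drops windows that start on an episode boundary, then re-batches and prefetches 5 batches (instead of oversampling 5 * batch_size). Below is a toy, self-contained stand-in for that unbatch/filter/batch pattern; it uses fake boundary flags in place of replay_buffer.as_dataset(), which also yields a sample-info tuple (hence the extra unused_arg1 in the real filter function).

import tensorflow as tf

batch_size = 4
# Fake "is_boundary" flags for eight 2-step windows; True marks a window
# whose first step is an episode boundary and must be dropped.
is_boundary = tf.constant([[False, False], [True, False], [False, False],
                           [False, True], [False, False], [True, False],
                           [False, False], [False, False]])
dataset = tf.data.Dataset.from_tensor_slices(is_boundary).batch(batch_size)

def _filter_invalid_transition(window):
  # Keep windows that do not start on an episode boundary.
  return ~window[0]

dataset = (dataset.unbatch()
           .filter(_filter_invalid_transition)
           .batch(batch_size)
           .prefetch(5))

for batch in dataset:
  print(batch.numpy())
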
44 changes: 22 additions & 22 deletions tf_agents/agents/sac/examples/v2/train_eval.py
@@ -47,13 +47,13 @@

from tf_agents.agents.ddpg import critic_network
from tf_agents.agents.sac import sac_agent
from tf_agents.agents.sac import tanh_normal_projection_network
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_mujoco
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import actor_distribution_network
from tf_agents.networks import normal_projection_network
from tf_agents.policies import greedy_policy
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
@@ -68,32 +68,26 @@
FLAGS = flags.FLAGS


@gin.configurable
def normal_projection_net(action_spec,
init_action_stddev=0.35,
init_means_output_factor=0.1):
del init_action_stddev
return normal_projection_network.NormalProjectionNetwork(
action_spec,
mean_transform=None,
state_dependent_std=True,
init_means_output_factor=init_means_output_factor,
std_transform=sac_agent.std_clip_transform,
scale_distribution=True)


@gin.configurable
def train_eval(
root_dir,
env_name='HalfCheetah-v2',
eval_env_name=None,
env_load_fn=suite_mujoco.load,
num_iterations=1000000,
# The SAC paper reported:
# Hopper and Cartpole results up to 1000000 iters,
# Humanoid results up to 10000000 iters,
# Other mujoco tasks up to 3000000 iters.
num_iterations=3000000,
actor_fc_layers=(256, 256),
critic_obs_fc_layers=None,
critic_action_fc_layers=None,
critic_joint_fc_layers=(256, 256),
# Params for collect
# Follow https://github.com/haarnoja/sac/blob/master/examples/variants.py
# HalfCheetah and Ant take 10000 initial collection steps.
# Other mujoco tasks take 1000.
# Different choices roughly keep the initial episodes about the same.
initial_collect_steps=10000,
collect_steps_per_iteration=1,
replay_buffer_capacity=1000000,
@@ -108,7 +102,7 @@ def train_eval(
alpha_learning_rate=3e-4,
td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
gamma=0.99,
reward_scale_factor=1.0,
reward_scale_factor=0.1,
gradient_clipping=None,
use_tf_functions=True,
# Params for eval
@@ -155,12 +149,15 @@ def train_eval(
observation_spec,
action_spec,
fc_layer_params=actor_fc_layers,
continuous_projection_net=normal_projection_net)
continuous_projection_net=tanh_normal_projection_network
.TanhNormalProjectionNetwork)
critic_net = critic_network.CriticNetwork(
(observation_spec, action_spec),
observation_fc_layer_params=critic_obs_fc_layers,
action_fc_layer_params=critic_action_fc_layers,
joint_fc_layer_params=critic_joint_fc_layers)
joint_fc_layer_params=critic_joint_fc_layers,
kernel_initializer='glorot_uniform',
last_kernel_initializer='glorot_uniform')

tf_agent = sac_agent.SacAgent(
time_step_spec,
@@ -265,11 +262,14 @@ def train_eval(
timed_at_step = global_step.numpy()
time_acc = 0

# Dataset generates trajectories with shape [Bx2x...]
# Prepare replay buffer as dataset with invalid transitions filtered.
def _filter_invalid_transition(trajectories, unused_arg1):
return ~trajectories.is_boundary()[0]
dataset = replay_buffer.as_dataset(
num_parallel_calls=3,
sample_batch_size=batch_size,
num_steps=2).prefetch(3)
num_steps=2).unbatch().filter(
_filter_invalid_transition).batch(batch_size).prefetch(5)
# Dataset generates trajectories with shape [Bx2x...]
iterator = iter(dataset)

def train_step():
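
Because train_eval is @gin.configurable, the new defaults (3M iterations, 10k initial collect steps, reward scale 0.1, tuned around HalfCheetah) can be re-bound per task without editing this file. A hedged sketch follows: the Hopper values mirror the comments in this commit (1M iterations, 1k initial collect steps), and the root_dir path is a placeholder.

import gin

# Importing the example module registers the `train_eval` configurable.
from tf_agents.agents.sac.examples.v2 import train_eval as train_eval_lib

gin.parse_config("""
train_eval.env_name = 'Hopper-v2'
train_eval.num_iterations = 1000000
train_eval.initial_collect_steps = 1000
""")

# Would launch training with the bindings above (requires MuJoCo):
# train_eval_lib.train_eval(root_dir='/tmp/sac/hopper')
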