[rllib] Support batch norm layers #3369

Merged: 10 commits, Nov 29, 2018
14 changes: 10 additions & 4 deletions doc/source/rllib-models.rst
@@ -30,7 +30,7 @@ The following is a list of the built-in model hyperparameters:
Custom Models
-------------

Custom models should subclass the common RLlib `model class <https://github.com/ray-project/ray/blob/master/python/ray/rllib/models/model.py>`__ and override the ``_build_layers_v2`` method. This method takes in a dict of tensor inputs (the observation ``obs``, ``prev_action``, and ``prev_reward``), and returns a feature layer and float vector of the specified output size. You can also override the ``value_function`` method to implement a custom value branch. A self-supervised loss can be defined via the ``loss`` method. The model can then be registered and used in place of a built-in model:
Custom models should subclass the common RLlib `model class <https://github.com/ray-project/ray/blob/master/python/ray/rllib/models/model.py>`__ and override the ``_build_layers_v2`` method. This method takes in a dict of tensor inputs (the observation ``obs``, ``prev_action``, ``prev_reward``, and ``is_training``), and returns a feature layer and float vector of the specified output size. You can also override the ``value_function`` method to implement a custom value branch. A self-supervised loss can be defined via the ``loss`` method. The model can then be registered and used in place of a built-in model:

.. code-block:: python

@@ -44,7 +44,7 @@ Custom models should subclass the common RLlib `model class <https://github.com/

Arguments:
input_dict (dict): Dictionary of input tensors, including "obs",
"prev_action", "prev_reward".
"prev_action", "prev_reward", "is_training".
num_outputs (int): Output tensor must be of size
[BATCH_SIZE, num_outputs].
options (dict): Model options.
@@ -60,6 +60,7 @@ Custom models should subclass the common RLlib `model class <https://github.com/
>>> print(input_dict)
{'prev_actions': <tf.Tensor shape=(?,) dtype=int64>,
'prev_rewards': <tf.Tensor shape=(?,) dtype=float32>,
'is_training': <tf.Tensor shape=(), dtype=bool>,
'obs': OrderedDict([
('sensors', OrderedDict([
('front_cam', [
@@ -115,7 +116,6 @@ Custom Recurrent Models

Instead of using the ``use_lstm: True`` option, it can be preferable to use a custom recurrent model. This provides more control over postprocessing of the LSTM output and can also allow the use of multiple LSTM cells to process different portions of the input. The only difference from a normal custom model is that you have to define ``self.state_init``, ``self.state_in``, and ``self.state_out``. You can refer to the existing `lstm.py <https://github.com/ray-project/ray/blob/master/python/ray/rllib/models/lstm.py>`__ model as an example to implement your own model:


.. code-block:: python

class MyCustomLSTM(Model):
@@ -147,6 +147,11 @@ Instead of using the ``use_lstm: True`` option, it can be preferable use a custo
normc_initializer(0.01))
return logits, last_layer

Batch Normalization
~~~~~~~~~~~~~~~~~~~

You can use ``tf.layers.batch_normalization(x, training=input_dict["is_training"])`` to add batch norm layers to your custom model: `code example <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/batch_norm_model.py>`__. RLlib will automatically run the update ops for the batch norm layers during optimization (see `tf_policy_graph.py <https://github.com/ray-project/ray/blob/master/python/ray/rllib/evaluation/tf_policy_graph.py>`__ and `multi_gpu_impl.py <https://github.com/ray-project/ray/blob/master/python/ray/rllib/optimizers/multi_gpu_impl.py>`__ for the exact handling of these updates).
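For illustration, a custom model using this pattern might look roughly like the following (a minimal sketch in the spirit of the linked example; the class name and layer sizes are illustrative, not taken verbatim from it):

.. code-block:: python

    import tensorflow as tf

    from ray.rllib.models import Model

    class MyBatchNormModel(Model):
        """Example custom model with batch norm after each hidden layer."""

        def _build_layers_v2(self, input_dict, num_outputs, options):
            last_layer = input_dict["obs"]
            for i, size in enumerate([64, 64]):
                last_layer = tf.layers.dense(
                    last_layer, size, activation=tf.nn.tanh,
                    name="fc{}".format(i))
                # "is_training" switches batch norm between training and
                # inference mode; RLlib runs the resulting update ops for us.
                last_layer = tf.layers.batch_normalization(
                    last_layer, training=input_dict["is_training"],
                    name="bn{}".format(i))
            output = tf.layers.dense(
                last_layer, num_outputs, activation=None, name="out")
            return output, last_layer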

Custom Preprocessors
--------------------

@@ -283,7 +288,8 @@ With a custom policy graph, you can also perform model-based rollouts and option
def compute_actions(self,
obs_batch,
state_batches,
is_training=False,
prev_action_batch=None,
prev_reward_batch=None,
episodes=None):
# compute a batch of actions based on the current obs_batch
# and state of each episode (i.e., for multiagent). You can do
Expand Down
3 changes: 2 additions & 1 deletion python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py
@@ -53,7 +53,8 @@ def __init__(self, observation_space, action_space, config):
self.model = ModelCatalog.get_model({
"obs": self.observations,
"prev_actions": prev_actions,
"prev_rewards": prev_rewards
"prev_rewards": prev_rewards,
"is_training": self._get_is_training_placeholder(),
}, observation_space, logit_dim, self.config["model"])
action_dist = dist_class(self.model.outputs)
self.vf = self.model.value_function()
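The _get_is_training_placeholder() helper used here (and in the other policy graphs below) is defined on TFPolicyGraph, whose diff is not part of this excerpt. Presumably it lazily creates one shared boolean placeholder that defaults to False, along these lines (a sketch, not the verbatim implementation):

    import tensorflow as tf

    def _get_is_training_placeholder(self):
        # Method of TFPolicyGraph (sketch). Cache a single bool placeholder
        # so every model built by this policy graph shares the same
        # "is_training" input; it defaults to False so plain inference runs
        # batch norm in evaluation mode.
        if not hasattr(self, "_is_training"):
            self._is_training = tf.placeholder_with_default(False, ())
        return self._is_training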
8 changes: 3 additions & 5 deletions python/ray/rllib/agents/agent.py
@@ -385,13 +385,11 @@ def compute_action(self, observation, state=None, policy_id="default"):
observation, update=False)
if state:
return self.local_evaluator.for_policy(
lambda p: p.compute_single_action(
filtered_obs, state, is_training=False),
lambda p: p.compute_single_action(filtered_obs, state),
policy_id=policy_id)
return self.local_evaluator.for_policy(
lambda p: p.compute_single_action(
filtered_obs, state, is_training=False)[0],
policy_id=policy_id)
lambda p: p.compute_single_action(filtered_obs, state)[0],
policy_id=policy_id)

def get_weights(self, policies=None):
"""Return a dictionary of policy ids to weights.
20 changes: 16 additions & 4 deletions python/ray/rllib/agents/ddpg/ddpg_policy_graph.py
@@ -199,7 +199,9 @@ def __init__(self, observation_space, action_space, config):
self.stochastic = tf.placeholder(tf.bool, (), name="stochastic")
self.eps = tf.placeholder(tf.float32, (), name="eps")
self.cur_observations = tf.placeholder(
tf.float32, shape=(None, ) + observation_space.shape)
tf.float32,
shape=(None, ) + observation_space.shape,
name="cur_obs")

# Actor: P (policy) network
with tf.variable_scope(P_SCOPE) as scope:
@@ -236,7 +238,11 @@ def __init__(self, observation_space, action_space, config):

# p network evaluation
with tf.variable_scope(P_SCOPE, reuse=True) as scope:
prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
self.p_t = self._build_p_network(self.obs_t, observation_space)
p_batchnorm_update_ops = list(
set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) -
prev_update_ops)

# target p network evaluation
with tf.variable_scope(P_TARGET_SCOPE) as scope:
@@ -257,6 +263,7 @@
is_target=True)

# q network evaluation
prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
with tf.variable_scope(Q_SCOPE) as scope:
q_t, model = self._build_q_network(self.obs_t, observation_space,
self.act_t)
@@ -269,6 +276,8 @@
twin_q_t, twin_model = self._build_q_network(
self.obs_t, observation_space, self.act_t)
self.twin_q_func_vars = _scope_vars(scope.name)
q_batchnorm_update_ops = list(
set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops)

# target q network evalution
with tf.variable_scope(Q_TARGET_SCOPE) as scope:
@@ -345,7 +354,8 @@ def __init__(self, observation_space, action_space, config):
obs_input=self.cur_observations,
action_sampler=self.output_actions,
loss=model.loss() + self.loss.total_loss,
loss_inputs=self.loss_inputs)
loss_inputs=self.loss_inputs,
update_ops=q_batchnorm_update_ops + p_batchnorm_update_ops)
self.sess.run(tf.global_variables_initializer())

# Note that this encompasses both the policy and Q-value networks and
@@ -359,7 +369,8 @@
def _build_q_network(self, obs, obs_space, actions):
q_net = QNetwork(
ModelCatalog.get_model({
"obs": obs
"obs": obs,
"is_training": self._get_is_training_placeholder(),
}, obs_space, 1, self.config["model"]), actions,
self.config["critic_hiddens"],
self.config["critic_hidden_activation"])
@@ -368,7 +379,8 @@ def _build_q_network(self, obs, obs_space, actions):
def _build_p_network(self, obs, obs_space):
return PNetwork(
ModelCatalog.get_model({
"obs": obs
"obs": obs,
"is_training": self._get_is_training_placeholder(),
}, obs_space, 1, self.config["model"]), self.dim_actions,
self.config["actor_hiddens"],
self.config["actor_hidden_activation"]).action_scores
10 changes: 8 additions & 2 deletions python/ray/rllib/agents/dqn/dqn_policy_graph.py
@@ -306,8 +306,12 @@ def __init__(self, observation_space, action_space, config):

# q network evaluation
with tf.variable_scope(Q_SCOPE, reuse=True):
prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
q_t, q_logits_t, q_dist_t, model = self._build_q_network(
self.obs_t, observation_space)
q_batchnorm_update_ops = list(
set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) -
prev_update_ops)

# target q network evalution
with tf.variable_scope(Q_TARGET_SCOPE) as scope:
@@ -372,13 +376,15 @@ def __init__(self, observation_space, action_space, config):
obs_input=self.cur_observations,
action_sampler=self.output_actions,
loss=model.loss() + self.loss.loss,
loss_inputs=self.loss_inputs)
loss_inputs=self.loss_inputs,
update_ops=q_batchnorm_update_ops)
self.sess.run(tf.global_variables_initializer())

def _build_q_network(self, obs, space):
qnet = QNetwork(
ModelCatalog.get_model({
"obs": obs
"obs": obs,
"is_training": self._get_is_training_placeholder(),
}, space, self.num_actions, self.config["model"]),
self.num_actions, self.config["dueling"], self.config["hiddens"],
self.config["noisy"], self.config["num_atoms"],
1 change: 1 addition & 0 deletions python/ray/rllib/agents/impala/vtrace_policy_graph.py
@@ -133,6 +133,7 @@ def __init__(self,
"obs": observations,
"prev_actions": prev_actions,
"prev_rewards": prev_rewards,
"is_training": self._get_is_training_placeholder(),
},
observation_space,
logit_dim,
3 changes: 2 additions & 1 deletion python/ray/rllib/agents/pg/pg_policy_graph.py
@@ -35,7 +35,8 @@ def __init__(self, obs_space, action_space, config):
self.model = ModelCatalog.get_model({
"obs": obs,
"prev_actions": prev_actions,
"prev_rewards": prev_rewards
"prev_rewards": prev_rewards,
"is_training": self._get_is_training_placeholder(),
}, obs_space, self.logit_dim, self.config["model"])
action_dist = dist_class(self.model.outputs) # logit for each action

5 changes: 3 additions & 2 deletions python/ray/rllib/agents/ppo/ppo.py
@@ -24,7 +24,7 @@
"sample_batch_size": 200,
# Number of timesteps collected for each SGD round
"train_batch_size": 4000,
# Total SGD batch size across all devices for SGD (multi-gpu only)
# Total SGD batch size across all devices for SGD
"sgd_minibatch_size": 128,
# Number of SGD iterations in each outer loop
"num_sgd_iter": 30,
@@ -49,7 +49,8 @@
"batch_mode": "truncate_episodes",
# Which observation filter to apply to the observation
"observation_filter": "MeanStdFilter",
# Use the sync samples optimizer instead of the multi-gpu one
# Uses the sync samples optimizer instead of the multi-gpu one. This does
# not support minibatches.
"simple_optimizer": False,
})
# __sphinx_doc_end__
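As a usage note, these options are passed through the agent config; a sketch with the agent API of this Ray version (the environment name and values are illustrative):

    import ray
    from ray.rllib.agents.ppo import PPOAgent

    ray.init()
    agent = PPOAgent(
        env="CartPole-v0",
        config={
            "train_batch_size": 4000,
            # Each SGD round splits the train batch into minibatches of this
            # size (total across all devices).
            "sgd_minibatch_size": 128,
            "num_sgd_iter": 30,
            # Setting this to True swaps in the sync samples optimizer, which
            # does not support minibatching.
            "simple_optimizer": False,
        })
    print(agent.train())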
6 changes: 4 additions & 2 deletions python/ray/rllib/agents/ppo/ppo_policy_graph.py
@@ -158,7 +158,8 @@ def __init__(self,
{
"obs": obs_ph,
"prev_actions": prev_actions_ph,
"prev_rewards": prev_rewards_ph
"prev_rewards": prev_rewards_ph,
"is_training": self._get_is_training_placeholder(),
},
observation_space,
logit_dim,
@@ -191,7 +192,8 @@ def __init__(self,
self.value_function = ModelCatalog.get_model({
"obs": obs_ph,
"prev_actions": prev_actions_ph,
"prev_rewards": prev_rewards_ph
"prev_rewards": prev_rewards_ph,
"is_training": self._get_is_training_placeholder(),
}, observation_space, 1, vf_config).outputs
self.value_function = tf.reshape(self.value_function, [-1])
else:
6 changes: 1 addition & 5 deletions python/ray/rllib/evaluation/policy_graph.py
@@ -42,7 +42,6 @@ def compute_actions(self,
state_batches,
prev_action_batch=None,
prev_reward_batch=None,
is_training=False,
episodes=None):
"""Compute actions for the current policy.

@@ -51,7 +50,6 @@ def compute_actions(self,
state_batches (list): list of RNN state input batches, if any
prev_action_batch (np.ndarray): batch of previous action values
prev_reward_batch (np.ndarray): batch of previous rewards
is_training (bool): whether we are training the policy
episodes (list): MultiAgentEpisode for each obs in obs_batch.
This provides access to all of the internal episode state,
which may be useful for model-based or multiagent algorithms.
@@ -71,7 +69,6 @@ def compute_single_action(self,
state,
prev_action_batch=None,
prev_reward_batch=None,
is_training=False,
episode=None):
"""Unbatched version of compute_actions.

@@ -80,7 +77,6 @@ def compute_single_action(self,
state_batches (list): list of RNN state inputs, if any
prev_action_batch (np.ndarray): batch of previous action values
prev_reward_batch (np.ndarray): batch of previous rewards
is_training (bool): whether we are training the policy
episode (MultiAgentEpisode): this provides access to all of the
internal episode state, which may be useful for model-based or
multi-agent algorithms.
@@ -92,7 +88,7 @@ def compute_single_action(self,
"""

[action], state_out, info = self.compute_actions(
[obs], [[s] for s in state], is_training, episodes=[episode])
[obs], [[s] for s in state], episodes=[episode])
return action, [s[0] for s in state_out], \
{k: v[0] for k, v in info.items()}

4 changes: 1 addition & 3 deletions python/ray/rllib/evaluation/sampler.py
@@ -436,15 +436,13 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
builder, [t.obs for t in eval_data],
rnn_in_cols,
prev_action_batch=[t.prev_action for t in eval_data],
prev_reward_batch=[t.prev_reward for t in eval_data],
is_training=True)
prev_reward_batch=[t.prev_reward for t in eval_data])
else:
eval_results[policy_id] = policy.compute_actions(
[t.obs for t in eval_data],
rnn_in_cols,
prev_action_batch=[t.prev_action for t in eval_data],
prev_reward_batch=[t.prev_reward for t in eval_data],
is_training=True,
episodes=[active_episodes[t.env_id] for t in eval_data])
if builder:
for k, v in pending_fetches.items():