diff --git a/doc/source/rllib-models.rst b/doc/source/rllib-models.rst index 883b0a6bb5266..09c49e2751bf0 100644 --- a/doc/source/rllib-models.rst +++ b/doc/source/rllib-models.rst @@ -30,7 +30,7 @@ The following is a list of the built-in model hyperparameters: Custom Models ------------- -Custom models should subclass the common RLlib `model class `__ and override the ``_build_layers_v2`` method. This method takes in a dict of tensor inputs (the observation ``obs``, ``prev_action``, and ``prev_reward``), and returns a feature layer and float vector of the specified output size. You can also override the ``value_function`` method to implement a custom value branch. A self-supervised loss can be defined via the ``loss`` method. The model can then be registered and used in place of a built-in model: +Custom models should subclass the common RLlib `model class `__ and override the ``_build_layers_v2`` method. This method takes in a dict of tensor inputs (the observation ``obs``, ``prev_action``, and ``prev_reward``, ``is_training``), and returns a feature layer and float vector of the specified output size. You can also override the ``value_function`` method to implement a custom value branch. A self-supervised loss can be defined via the ``loss`` method. The model can then be registered and used in place of a built-in model: .. code-block:: python @@ -44,7 +44,7 @@ Custom models should subclass the common RLlib `model class >> print(input_dict) {'prev_actions': , 'prev_rewards': , + 'is_training': , 'obs': OrderedDict([ ('sensors', OrderedDict([ ('front_cam', [ @@ -115,7 +116,6 @@ Custom Recurrent Models Instead of using the ``use_lstm: True`` option, it can be preferable use a custom recurrent model. This provides more control over postprocessing of the LSTM output and can also allow the use of multiple LSTM cells to process different portions of the input. The only difference from a normal custom model is that you have to define ``self.state_init``, ``self.state_in``, and ``self.state_out``. You can refer to the existing `lstm.py `__ model as an example to implement your own model: - .. code-block:: python class MyCustomLSTM(Model): @@ -147,6 +147,11 @@ Instead of using the ``use_lstm: True`` option, it can be preferable use a custo normc_initializer(0.01)) return logits, last_layer +Batch Normalization +~~~~~~~~~~~~~~~~~~~ + +You can use ``tf.layers.batch_normalization(x, training=input_dict["is_training"])`` to add batch norm layers to your custom model: `code example `__. RLlib will automatically run the update ops for the batch norm layers during optimization (see `tf_policy_graph.py `__ and `multi_gpu_impl.py `__ for the exact handling of these updates). + Custom Preprocessors -------------------- @@ -283,7 +288,8 @@ With a custom policy graph, you can also perform model-based rollouts and option def compute_actions(self, obs_batch, state_batches, - is_training=False, + prev_action_batch=None, + prev_reward_batch=None, episodes=None): # compute a batch of actions based on the current obs_batch # and state of each episode (i.e., for multiagent). 
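As a quick reference for the doc changes above, here is a condensed sketch of a custom model that consumes the new ``is_training`` input and adds a batch norm layer. The class name and layer sizes are illustrative only; the full runnable version is the ``batch_norm_model.py`` example added later in this diff.

.. code-block:: python

    import tensorflow as tf
    import tensorflow.contrib.slim as slim

    from ray.rllib.models import Model, ModelCatalog
    from ray.rllib.models.misc import normc_initializer


    class MySimpleBNModel(Model):
        """Illustrative custom model using the new "is_training" input."""

        def _build_layers_v2(self, input_dict, num_outputs, options):
            # "is_training" is a bool tensor: True only while the policy
            # loss is being optimized, False during rollouts.
            hidden = slim.fully_connected(
                input_dict["obs"], 64,
                weights_initializer=normc_initializer(1.0),
                activation_fn=tf.nn.tanh)
            hidden = tf.layers.batch_normalization(
                hidden, training=input_dict["is_training"])
            output = slim.fully_connected(
                hidden, num_outputs,
                weights_initializer=normc_initializer(0.01),
                activation_fn=None)
            return output, hidden


    ModelCatalog.register_custom_model("my_bn_model", MySimpleBNModel)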
You can do diff --git a/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py b/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py index 6f079713abaea..8aa60645aaebd 100644 --- a/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py +++ b/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py @@ -53,7 +53,8 @@ def __init__(self, observation_space, action_space, config): self.model = ModelCatalog.get_model({ "obs": self.observations, "prev_actions": prev_actions, - "prev_rewards": prev_rewards + "prev_rewards": prev_rewards, + "is_training": self._get_is_training_placeholder(), }, observation_space, logit_dim, self.config["model"]) action_dist = dist_class(self.model.outputs) self.vf = self.model.value_function() diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py index 18adda82d178b..f0d9510756b93 100644 --- a/python/ray/rllib/agents/agent.py +++ b/python/ray/rllib/agents/agent.py @@ -385,13 +385,11 @@ def compute_action(self, observation, state=None, policy_id="default"): observation, update=False) if state: return self.local_evaluator.for_policy( - lambda p: p.compute_single_action( - filtered_obs, state, is_training=False), + lambda p: p.compute_single_action(filtered_obs, state), policy_id=policy_id) return self.local_evaluator.for_policy( - lambda p: p.compute_single_action( - filtered_obs, state, is_training=False)[0], - policy_id=policy_id) + lambda p: p.compute_single_action(filtered_obs, state)[0], + policy_id=policy_id) def get_weights(self, policies=None): """Return a dictionary of policy ids to weights. diff --git a/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py b/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py index 738c4e9ac130e..eb5f14c2d1c99 100644 --- a/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py +++ b/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py @@ -199,7 +199,9 @@ def __init__(self, observation_space, action_space, config): self.stochastic = tf.placeholder(tf.bool, (), name="stochastic") self.eps = tf.placeholder(tf.float32, (), name="eps") self.cur_observations = tf.placeholder( - tf.float32, shape=(None, ) + observation_space.shape) + tf.float32, + shape=(None, ) + observation_space.shape, + name="cur_obs") # Actor: P (policy) network with tf.variable_scope(P_SCOPE) as scope: @@ -236,7 +238,11 @@ def __init__(self, observation_space, action_space, config): # p network evaluation with tf.variable_scope(P_SCOPE, reuse=True) as scope: + prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) self.p_t = self._build_p_network(self.obs_t, observation_space) + p_batchnorm_update_ops = list( + set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - + prev_update_ops) # target p network evaluation with tf.variable_scope(P_TARGET_SCOPE) as scope: @@ -257,6 +263,7 @@ def __init__(self, observation_space, action_space, config): is_target=True) # q network evaluation + prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) with tf.variable_scope(Q_SCOPE) as scope: q_t, model = self._build_q_network(self.obs_t, observation_space, self.act_t) @@ -269,6 +276,8 @@ def __init__(self, observation_space, action_space, config): twin_q_t, twin_model = self._build_q_network( self.obs_t, observation_space, self.act_t) self.twin_q_func_vars = _scope_vars(scope.name) + q_batchnorm_update_ops = list( + set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops) # target q network evalution with tf.variable_scope(Q_TARGET_SCOPE) as scope: @@ -345,7 +354,8 @@ def __init__(self, observation_space, action_space, config): 
obs_input=self.cur_observations, action_sampler=self.output_actions, loss=model.loss() + self.loss.total_loss, - loss_inputs=self.loss_inputs) + loss_inputs=self.loss_inputs, + update_ops=q_batchnorm_update_ops + p_batchnorm_update_ops) self.sess.run(tf.global_variables_initializer()) # Note that this encompasses both the policy and Q-value networks and @@ -359,7 +369,8 @@ def __init__(self, observation_space, action_space, config): def _build_q_network(self, obs, obs_space, actions): q_net = QNetwork( ModelCatalog.get_model({ - "obs": obs + "obs": obs, + "is_training": self._get_is_training_placeholder(), }, obs_space, 1, self.config["model"]), actions, self.config["critic_hiddens"], self.config["critic_hidden_activation"]) @@ -368,7 +379,8 @@ def _build_q_network(self, obs, obs_space, actions): def _build_p_network(self, obs, obs_space): return PNetwork( ModelCatalog.get_model({ - "obs": obs + "obs": obs, + "is_training": self._get_is_training_placeholder(), }, obs_space, 1, self.config["model"]), self.dim_actions, self.config["actor_hiddens"], self.config["actor_hidden_activation"]).action_scores diff --git a/python/ray/rllib/agents/dqn/dqn_policy_graph.py b/python/ray/rllib/agents/dqn/dqn_policy_graph.py index 2bbff99246f3c..c883ef25067dc 100644 --- a/python/ray/rllib/agents/dqn/dqn_policy_graph.py +++ b/python/ray/rllib/agents/dqn/dqn_policy_graph.py @@ -306,8 +306,12 @@ def __init__(self, observation_space, action_space, config): # q network evaluation with tf.variable_scope(Q_SCOPE, reuse=True): + prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) q_t, q_logits_t, q_dist_t, model = self._build_q_network( self.obs_t, observation_space) + q_batchnorm_update_ops = list( + set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - + prev_update_ops) # target q network evalution with tf.variable_scope(Q_TARGET_SCOPE) as scope: @@ -372,13 +376,15 @@ def __init__(self, observation_space, action_space, config): obs_input=self.cur_observations, action_sampler=self.output_actions, loss=model.loss() + self.loss.loss, - loss_inputs=self.loss_inputs) + loss_inputs=self.loss_inputs, + update_ops=q_batchnorm_update_ops) self.sess.run(tf.global_variables_initializer()) def _build_q_network(self, obs, space): qnet = QNetwork( ModelCatalog.get_model({ - "obs": obs + "obs": obs, + "is_training": self._get_is_training_placeholder(), }, space, self.num_actions, self.config["model"]), self.num_actions, self.config["dueling"], self.config["hiddens"], self.config["noisy"], self.config["num_atoms"], diff --git a/python/ray/rllib/agents/impala/vtrace_policy_graph.py b/python/ray/rllib/agents/impala/vtrace_policy_graph.py index 3d9e4214b7c7a..cfa2f1373aae2 100644 --- a/python/ray/rllib/agents/impala/vtrace_policy_graph.py +++ b/python/ray/rllib/agents/impala/vtrace_policy_graph.py @@ -133,6 +133,7 @@ def __init__(self, "obs": observations, "prev_actions": prev_actions, "prev_rewards": prev_rewards, + "is_training": self._get_is_training_placeholder(), }, observation_space, logit_dim, diff --git a/python/ray/rllib/agents/pg/pg_policy_graph.py b/python/ray/rllib/agents/pg/pg_policy_graph.py index 8cbb3a588b491..2a342c117fb3f 100644 --- a/python/ray/rllib/agents/pg/pg_policy_graph.py +++ b/python/ray/rllib/agents/pg/pg_policy_graph.py @@ -35,7 +35,8 @@ def __init__(self, obs_space, action_space, config): self.model = ModelCatalog.get_model({ "obs": obs, "prev_actions": prev_actions, - "prev_rewards": prev_rewards + "prev_rewards": prev_rewards, + "is_training": self._get_is_training_placeholder(), }, 
obs_space, self.logit_dim, self.config["model"]) action_dist = dist_class(self.model.outputs) # logit for each action diff --git a/python/ray/rllib/agents/ppo/ppo.py b/python/ray/rllib/agents/ppo/ppo.py index 722f9263d8167..d5e50832f4515 100644 --- a/python/ray/rllib/agents/ppo/ppo.py +++ b/python/ray/rllib/agents/ppo/ppo.py @@ -24,7 +24,7 @@ "sample_batch_size": 200, # Number of timesteps collected for each SGD round "train_batch_size": 4000, - # Total SGD batch size across all devices for SGD (multi-gpu only) + # Total SGD batch size across all devices for SGD "sgd_minibatch_size": 128, # Number of SGD iterations in each outer loop "num_sgd_iter": 30, @@ -49,7 +49,8 @@ "batch_mode": "truncate_episodes", # Which observation filter to apply to the observation "observation_filter": "MeanStdFilter", - # Use the sync samples optimizer instead of the multi-gpu one + # Uses the sync samples optimizer instead of the multi-gpu one. This does + # not support minibatches. "simple_optimizer": False, }) # __sphinx_doc_end__ diff --git a/python/ray/rllib/agents/ppo/ppo_policy_graph.py b/python/ray/rllib/agents/ppo/ppo_policy_graph.py index f43a336253beb..3762f16f9084e 100644 --- a/python/ray/rllib/agents/ppo/ppo_policy_graph.py +++ b/python/ray/rllib/agents/ppo/ppo_policy_graph.py @@ -158,7 +158,8 @@ def __init__(self, { "obs": obs_ph, "prev_actions": prev_actions_ph, - "prev_rewards": prev_rewards_ph + "prev_rewards": prev_rewards_ph, + "is_training": self._get_is_training_placeholder(), }, observation_space, logit_dim, @@ -191,7 +192,8 @@ def __init__(self, self.value_function = ModelCatalog.get_model({ "obs": obs_ph, "prev_actions": prev_actions_ph, - "prev_rewards": prev_rewards_ph + "prev_rewards": prev_rewards_ph, + "is_training": self._get_is_training_placeholder(), }, observation_space, 1, vf_config).outputs self.value_function = tf.reshape(self.value_function, [-1]) else: diff --git a/python/ray/rllib/evaluation/policy_graph.py b/python/ray/rllib/evaluation/policy_graph.py index 9de59d269a03c..c19da286b0b9a 100644 --- a/python/ray/rllib/evaluation/policy_graph.py +++ b/python/ray/rllib/evaluation/policy_graph.py @@ -42,7 +42,6 @@ def compute_actions(self, state_batches, prev_action_batch=None, prev_reward_batch=None, - is_training=False, episodes=None): """Compute actions for the current policy. @@ -51,7 +50,6 @@ def compute_actions(self, state_batches (list): list of RNN state input batches, if any prev_action_batch (np.ndarray): batch of previous action values prev_reward_batch (np.ndarray): batch of previous rewards - is_training (bool): whether we are training the policy episodes (list): MultiAgentEpisode for each obs in obs_batch. This provides access to all of the internal episode state, which may be useful for model-based or multiagent algorithms. @@ -71,7 +69,6 @@ def compute_single_action(self, state, prev_action_batch=None, prev_reward_batch=None, - is_training=False, episode=None): """Unbatched version of compute_actions. @@ -80,7 +77,6 @@ def compute_single_action(self, state_batches (list): list of RNN state inputs, if any prev_action_batch (np.ndarray): batch of previous action values prev_reward_batch (np.ndarray): batch of previous rewards - is_training (bool): whether we are training the policy episode (MultiAgentEpisode): this provides access to all of the internal episode state, which may be useful for model-based or multi-agent algorithms. 
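To make the ``compute_actions`` signature change concrete, here is a toy sketch of a custom ``PolicyGraph`` written against the updated interface. ``RandomPolicyGraph`` is a hypothetical name; note that the ``is_training`` argument is gone entirely, since TF-based graphs now feed a dedicated placeholder through ``_get_is_training_placeholder()`` instead.

.. code-block:: python

    from ray.rllib.evaluation.policy_graph import PolicyGraph


    class RandomPolicyGraph(PolicyGraph):
        """Toy policy graph using the updated compute_actions signature."""

        def compute_actions(self,
                            obs_batch,
                            state_batches,
                            prev_action_batch=None,
                            prev_reward_batch=None,
                            episodes=None):
            # Sample uniformly at random; no is_training flag is needed here.
            return [self.action_space.sample() for _ in obs_batch], [], {}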
@@ -92,7 +88,7 @@ def compute_single_action(self, """ [action], state_out, info = self.compute_actions( - [obs], [[s] for s in state], is_training, episodes=[episode]) + [obs], [[s] for s in state], episodes=[episode]) return action, [s[0] for s in state_out], \ {k: v[0] for k, v in info.items()} diff --git a/python/ray/rllib/evaluation/sampler.py b/python/ray/rllib/evaluation/sampler.py index 2fd2fc4e272ae..2c6411f33510f 100644 --- a/python/ray/rllib/evaluation/sampler.py +++ b/python/ray/rllib/evaluation/sampler.py @@ -436,15 +436,13 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes): builder, [t.obs for t in eval_data], rnn_in_cols, prev_action_batch=[t.prev_action for t in eval_data], - prev_reward_batch=[t.prev_reward for t in eval_data], - is_training=True) + prev_reward_batch=[t.prev_reward for t in eval_data]) else: eval_results[policy_id] = policy.compute_actions( [t.obs for t in eval_data], rnn_in_cols, prev_action_batch=[t.prev_action for t in eval_data], prev_reward_batch=[t.prev_reward for t in eval_data], - is_training=True, episodes=[active_episodes[t.env_id] for t in eval_data]) if builder: for k, v in pending_fetches.items(): diff --git a/python/ray/rllib/evaluation/tf_policy_graph.py b/python/ray/rllib/evaluation/tf_policy_graph.py index 40e540013fef4..95e7a5d66bcbf 100644 --- a/python/ray/rllib/evaluation/tf_policy_graph.py +++ b/python/ray/rllib/evaluation/tf_policy_graph.py @@ -30,7 +30,7 @@ class TFPolicyGraph(PolicyGraph): Examples: >>> policy = TFPolicyGraphSubclass( - sess, obs_input, action_sampler, loss, loss_inputs, is_training) + sess, obs_input, action_sampler, loss, loss_inputs) >>> print(policy.compute_actions([1, 0, 2])) (array([0, 1, 1]), [], {}) @@ -53,7 +53,8 @@ def __init__(self, prev_reward_input=None, seq_lens=None, max_seq_len=20, - batch_divisibility_req=1): + batch_divisibility_req=1, + update_ops=None): """Initialize the policy graph. Arguments: @@ -82,6 +83,9 @@ def __init__(self, batch_divisibility_req (int): pad all agent experiences batches to multiples of this value. This only has an effect if not using a LSTM model. + update_ops (list): override the batchnorm update ops to run when + applying gradients. Otherwise we run all update ops found in + the current variable scope. 
""" self.observation_space = observation_space @@ -94,7 +98,7 @@ def __init__(self, self._loss = loss self._loss_inputs = loss_inputs self._loss_input_dict = dict(self._loss_inputs) - self._is_training = tf.placeholder_with_default(True, ()) + self._is_training = self._get_is_training_placeholder() self._state_inputs = state_inputs or [] self._state_outputs = state_outputs or [] for i, ph in enumerate(self._state_inputs): @@ -108,14 +112,24 @@ def __init__(self, for (g, v) in self.gradients(self._optimizer) if g is not None] self._grads = [g for (g, v) in self._grads_and_vars] - # specify global_step for TD3 which needs to count the num updates - self._apply_op = self._optimizer.apply_gradients( - self._grads_and_vars, - global_step=tf.train.get_or_create_global_step()) - self._variables = ray.experimental.TensorFlowVariables( self._loss, self._sess) + # gather update ops for any batch norm layers + if update_ops: + self._update_ops = update_ops + else: + self._update_ops = tf.get_collection( + tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name) + if self._update_ops: + logger.debug("Update ops to run on apply gradient: {}".format( + self._update_ops)) + with tf.control_dependencies(self._update_ops): + # specify global_step for TD3 which needs to count the num updates + self._apply_op = self._optimizer.apply_gradients( + self._grads_and_vars, + global_step=tf.train.get_or_create_global_step()) + if len(self._state_inputs) != len(self._state_outputs): raise ValueError( "Number of state input and output tensors must match, got: " @@ -138,7 +152,6 @@ def build_compute_actions(self, state_batches=None, prev_action_batch=None, prev_reward_batch=None, - is_training=False, episodes=None): state_batches = state_batches or [] assert len(self._state_inputs) == len(state_batches), \ @@ -151,7 +164,7 @@ def build_compute_actions(self, builder.add_feed_dict({self._prev_action_input: prev_action_batch}) if self._prev_reward_input is not None and prev_reward_batch: builder.add_feed_dict({self._prev_reward_input: prev_reward_batch}) - builder.add_feed_dict({self._is_training: is_training}) + builder.add_feed_dict({self._is_training: False}) builder.add_feed_dict(dict(zip(self._state_inputs, state_batches))) fetches = builder.add_fetches([self._sampler] + self._state_outputs + [self.extra_compute_action_fetches()]) @@ -162,12 +175,11 @@ def compute_actions(self, state_batches=None, prev_action_batch=None, prev_reward_batch=None, - is_training=False, episodes=None): builder = TFRunBuilder(self._sess, "compute_actions") fetches = self.build_compute_actions(builder, obs_batch, state_batches, prev_action_batch, - prev_reward_batch, is_training) + prev_reward_batch) return builder.get(fetches) def _get_loss_inputs_dict(self, batch): @@ -287,6 +299,15 @@ def gradients(self, optimizer): def loss_inputs(self): return self._loss_inputs + def _get_is_training_placeholder(self): + """Get the placeholder for _is_training, i.e., for batch norm layers. + + This can be called safely before __init__ has run. 
+ """ + if not hasattr(self, "_is_training"): + self._is_training = tf.placeholder_with_default(False, ()) + return self._is_training + class LearningRateSchedule(object): """Mixin for TFPolicyGraph that adds a learning rate schedule.""" diff --git a/python/ray/rllib/evaluation/torch_policy_graph.py b/python/ray/rllib/evaluation/torch_policy_graph.py index cb990c36f8bff..a762927bab442 100644 --- a/python/ray/rllib/evaluation/torch_policy_graph.py +++ b/python/ray/rllib/evaluation/torch_policy_graph.py @@ -72,7 +72,6 @@ def compute_actions(self, state_batches=None, prev_action_batch=None, prev_reward_batch=None, - is_training=False, episodes=None): if state_batches: raise NotImplementedError("Torch RNN support") diff --git a/python/ray/rllib/examples/batch_norm_model.py b/python/ray/rllib/examples/batch_norm_model.py new file mode 100644 index 0000000000000..abd4b53666a2a --- /dev/null +++ b/python/ray/rllib/examples/batch_norm_model.py @@ -0,0 +1,64 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +"""Example of using a custom model with batch norm.""" + +import argparse + +import tensorflow as tf +import tensorflow.contrib.slim as slim + +import ray +from ray.rllib.models import Model, ModelCatalog +from ray.rllib.models.misc import normc_initializer +from ray.tune import run_experiments + +parser = argparse.ArgumentParser() +parser.add_argument("--num-iters", type=int, default=200) +parser.add_argument("--run", type=str, default="PPO") + + +class BatchNormModel(Model): + def _build_layers_v2(self, input_dict, num_outputs, options): + last_layer = input_dict["obs"] + hiddens = [256, 256] + for i, size in enumerate(hiddens): + label = "fc{}".format(i) + last_layer = slim.fully_connected( + last_layer, + size, + weights_initializer=normc_initializer(1.0), + activation_fn=tf.nn.tanh, + scope=label) + # Add a batch norm layer + last_layer = tf.layers.batch_normalization( + last_layer, training=input_dict["is_training"]) + output = slim.fully_connected( + last_layer, + num_outputs, + weights_initializer=normc_initializer(0.01), + activation_fn=None, + scope="fc_out") + return output, last_layer + + +if __name__ == "__main__": + args = parser.parse_args() + ray.init() + + ModelCatalog.register_custom_model("bn_model", BatchNormModel) + run_experiments({ + "batch_norm_demo": { + "run": args.run, + "env": "Pendulum-v0" if args.run == "DDPG" else "CartPole-v0", + "stop": { + "training_iteration": args.num_iters + }, + "config": { + "model": { + "custom_model": "bn_model", + }, + "num_workers": 0, + }, + }, + }) diff --git a/python/ray/rllib/models/model.py b/python/ray/rllib/models/model.py index d5147168c2fb2..561b636dc863e 100644 --- a/python/ray/rllib/models/model.py +++ b/python/ray/rllib/models/model.py @@ -23,7 +23,7 @@ class Model(object): Attributes: input_dict (dict): Dictionary of input tensors, including "obs", - "prev_action", "prev_reward". + "prev_action", "prev_reward", "is_training". outputs (Tensor): The output vector of this model, of shape [BATCH_SIZE, num_outputs]. last_layer (Tensor): The feature layer right before the model output, @@ -108,7 +108,7 @@ def _build_layers_v2(self, input_dict, num_outputs, options): Arguments: input_dict (dict): Dictionary of input tensors, including "obs", - "prev_action", "prev_reward". + "prev_action", "prev_reward", "is_training". num_outputs (int): Output tensor must be of size [BATCH_SIZE, num_outputs]. options (dict): Model options. 
@@ -124,6 +124,7 @@ def _build_layers_v2(self, input_dict, num_outputs, options): >>> print(input_dict) {'prev_actions': , 'prev_rewards': , + 'is_training': , 'obs': OrderedDict([ ('sensors', OrderedDict([ ('front_cam', [ diff --git a/python/ray/rllib/optimizers/multi_gpu_impl.py b/python/ray/rllib/optimizers/multi_gpu_impl.py index 1affe8df395e4..c548b20cc022d 100644 --- a/python/ray/rllib/optimizers/multi_gpu_impl.py +++ b/python/ray/rllib/optimizers/multi_gpu_impl.py @@ -3,12 +3,15 @@ from __future__ import print_function from collections import namedtuple +import logging import tensorflow as tf # Variable scope in which created variables will be placed under TOWER_SCOPE_NAME = "tower" +logger = logging.getLogger(__name__) + class LocalSyncParallelOptimizer(object): """Optimizer that runs in parallel across multiple local devices. @@ -63,6 +66,8 @@ def __init__(self, # First initialize the shared loss network with tf.name_scope(TOWER_SCOPE_NAME): self._shared_loss = build_graph(self.loss_inputs) + shared_ops = tf.get_collection( + tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name) # Then setup the per-device loss graphs that use the shared weights self._batch_index = tf.placeholder(tf.int32, name="batch_index") @@ -95,7 +100,20 @@ def __init__(self, clipped, _ = tf.clip_by_global_norm(clipped, grad_norm_clipping) for i, (grad, var) in enumerate(avg): avg[i] = (clipped[i], var) - self._train_op = self.optimizer.apply_gradients(avg) + + # gather update ops for any batch norm layers. TODO(ekl) here we will + # use all the ops found which won't work for DQN / DDPG, but those + # aren't supported with multi-gpu right now anyways. + self._update_ops = tf.get_collection( + tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name) + for op in shared_ops: + self._update_ops.remove(op) # only care about tower update ops + if self._update_ops: + logger.debug("Update ops to run on apply gradient: {}".format( + self._update_ops)) + + with tf.control_dependencies(self._update_ops): + self._train_op = self.optimizer.apply_gradients(avg) def load_data(self, sess, inputs, state_inputs): """Bulk loads the specified inputs into device memory. diff --git a/python/ray/rllib/test/test_multi_agent_env.py b/python/ray/rllib/test/test_multi_agent_env.py index 5b4099b3c71f3..5712390c05c6e 100644 --- a/python/ray/rllib/test/test_multi_agent_env.py +++ b/python/ray/rllib/test/test_multi_agent_env.py @@ -323,7 +323,6 @@ def compute_actions(self, state_batches, prev_action_batch=None, prev_reward_batch=None, - is_training=False, episodes=None): return [0] * len(obs_batch), [[h] * len(obs_batch)], {} @@ -348,7 +347,6 @@ def compute_actions(self, state_batches, prev_action_batch=None, prev_reward_batch=None, - is_training=False, episodes=None): # Pretend we did a model-based rollout and want to return # the extra trajectory. 
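A small sketch of the scoped ``tf.get_collection`` behavior that ``LocalSyncParallelOptimizer`` relies on to separate the shared graph's update ops from the per-tower ones. The scope names below are made up for illustration:

.. code-block:: python

    import tensorflow as tf

    x = tf.placeholder(tf.float32, [None, 4])

    with tf.variable_scope("shared"):
        tf.layers.batch_normalization(x, training=True)
    with tf.variable_scope("tower_0"):
        tf.layers.batch_normalization(x, training=True)

    shared_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope="shared")
    all_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # Only the tower's moving-average updates remain after removing the
    # shared ones, mirroring what LocalSyncParallelOptimizer does.
    tower_ops = [op for op in all_ops if op not in shared_ops]

Only the tower copies' update ops are kept in the multi-GPU path, since those are the graphs the optimizer actually trains.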
diff --git a/python/ray/rllib/test/test_policy_evaluator.py b/python/ray/rllib/test/test_policy_evaluator.py index 7b4d6c8b5ae09..cf319a7e922b2 100644 --- a/python/ray/rllib/test/test_policy_evaluator.py +++ b/python/ray/rllib/test/test_policy_evaluator.py @@ -25,7 +25,6 @@ def compute_actions(self, state_batches, prev_action_batch=None, prev_reward_batch=None, - is_training=False, episodes=None): return [0] * len(obs_batch), [], {} @@ -43,7 +42,6 @@ def compute_actions(self, state_batches, prev_action_batch=None, prev_reward_batch=None, - is_training=False, episodes=None): raise Exception("intentional error") diff --git a/test/jenkins_tests/run_multi_node_tests.sh b/test/jenkins_tests/run_multi_node_tests.sh index 86fd98af21d09..9b8d9295eae33 100755 --- a/test/jenkins_tests/run_multi_node_tests.sh +++ b/test/jenkins_tests/run_multi_node_tests.sh @@ -269,6 +269,18 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/test/test_lstm.py +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ + python /ray/python/ray/rllib/examples/batch_norm_model.py --num-iters=1 --run=PPO + +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ + python /ray/python/ray/rllib/examples/batch_norm_model.py --num-iters=1 --run=PG + +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ + python /ray/python/ray/rllib/examples/batch_norm_model.py --num-iters=1 --run=DQN + +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ + python /ray/python/ray/rllib/examples/batch_norm_model.py --num-iters=1 --run=DDPG + docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/test/test_multi_agent_env.py