
Commit

[rllib] example and docs on how to use parametric actions with DQN / PG algorithms (#3384)
ericl committed Nov 28, 2018
1 parent c2108ca commit f0df97d
Showing 15 changed files with 366 additions and 45 deletions.
6 changes: 3 additions & 3 deletions doc/source/rllib-algorithms.rst
@@ -133,10 +133,10 @@ Tuned examples: `Pendulum-v0 <https://github.com/ray-project/ray/blob/master/pyt
:start-after: __sphinx_doc_begin__
:end-before: __sphinx_doc_end__

Deep Q Networks (DQN, Rainbow)
------------------------------
Deep Q Networks (DQN, Rainbow, Parametric DQN)
----------------------------------------------
`[paper] <https://arxiv.org/abs/1312.5602>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/dqn/dqn.py>`__
RLlib DQN is implemented using the SyncReplayOptimizer. The algorithm can be scaled by increasing the number of workers, using the AsyncGradientsOptimizer for async DQN, or using Ape-X. Memory usage is reduced by compressing samples in the replay buffer with LZ4. All of the DQN improvements evaluated in `Rainbow <https://arxiv.org/abs/1710.02298>`__ are available, though not all are enabled by default.
RLlib DQN is implemented using the SyncReplayOptimizer. The algorithm can be scaled by increasing the number of workers, using the AsyncGradientsOptimizer for async DQN, or using Ape-X. Memory usage is reduced by compressing samples in the replay buffer with LZ4. All of the DQN improvements evaluated in `Rainbow <https://arxiv.org/abs/1710.02298>`__ are available, though not all are enabled by default. See also how to use `parametric-actions in DQN <rllib-models.html#variable-length-parametric-action-spaces>`__.
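
A minimal training sketch for the scaling knobs mentioned above (the ``DQNAgent`` import path and constructor arguments are assumptions about the RLlib API of this release, not part of this commit):

.. code-block:: python

    import ray
    from ray.rllib.agents.dqn import DQNAgent  # assumed import path

    ray.init()
    # More workers parallelize sample collection feeding the replay optimizer.
    agent = DQNAgent(env="PongDeterministic-v4", config={"num_workers": 4})
    for _ in range(10):
        result = agent.train()
        print(result["episode_reward_mean"])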

Tuned examples: `PongDeterministic-v4 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pong-dqn.yaml>`__, `Rainbow configuration <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pong-rainbow.yaml>`__, `{BeamRider,Breakout,Qbert,SpaceInvaders}NoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/atari-basic-dqn.yaml>`__, `with Dueling and Double-Q <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/atari-duel-ddqn.yaml>`__, `with Distributional DQN <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/atari-dist-dqn.yaml>`__.

30 changes: 16 additions & 14 deletions doc/source/rllib-env.rst
@@ -7,20 +7,22 @@ RLlib works with several different types of environments, including `OpenAI Gym

**Compatibility matrix**:

============= ================ ================== =========== ==================
Algorithm     Discrete Actions Continuous Actions Multi-Agent Recurrent Policies
============= ================ ================== =========== ==================
A2C, A3C      **Yes**          **Yes**            **Yes**     **Yes**
PPO           **Yes**          **Yes**            **Yes**     **Yes**
PG            **Yes**          **Yes**            **Yes**     **Yes**
IMPALA        **Yes**          No                 **Yes**     **Yes**
DQN, Rainbow  **Yes**          No                 **Yes**     No
DDPG, TD3     No               **Yes**            **Yes**     No
APEX-DQN      **Yes**          No                 **Yes**     No
APEX-DDPG     No               **Yes**            **Yes**     No
ES            **Yes**          **Yes**            No          No
ARS           **Yes**          **Yes**            No          No
============= ================ ================== =========== ==================
============= ======================= ================== =========== ==================
Algorithm     Discrete Actions        Continuous Actions Multi-Agent Recurrent Policies
============= ======================= ================== =========== ==================
A2C, A3C      **Yes** `+parametric`_  **Yes**            **Yes**     **Yes**
PPO           **Yes** `+parametric`_  **Yes**            **Yes**     **Yes**
PG            **Yes** `+parametric`_  **Yes**            **Yes**     **Yes**
IMPALA        **Yes** `+parametric`_  No                 **Yes**     **Yes**
DQN, Rainbow  **Yes** `+parametric`_  No                 **Yes**     No
DDPG, TD3     No                      **Yes**            **Yes**     No
APEX-DQN      **Yes** `+parametric`_  No                 **Yes**     No
APEX-DDPG     No                      **Yes**            **Yes**     No
ES            **Yes**                 **Yes**            No          No
ARS           **Yes**                 **Yes**            No          No
============= ======================= ================== =========== ==================

.. _`+parametric`: rllib-models.html#variable-length-parametric-action-spaces

In the high-level agent APIs, environments are identified with string names. By default, the string will be interpreted as a gym `environment name <https://gym.openai.com/envs>`__; however, you can also register custom environments by name:
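
A minimal sketch of such a registration (``MyEnv`` is a hypothetical ``gym.Env`` subclass; ``register_env`` is the Tune registry helper):

.. code-block:: python

    from ray.tune.registry import register_env

    def env_creator(env_config):
        # MyEnv is hypothetical; any gym.Env subclass works here.
        return MyEnv(env_config)

    register_env("my_env", env_creator)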

84 changes: 84 additions & 0 deletions doc/source/rllib-models.rst
@@ -110,6 +110,43 @@ Custom models should subclass the common RLlib `model class <https://github.com/
For a full example of a custom model in code, see the `Carla RLlib model <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/carla/models.py>`__ and associated `training scripts <https://github.com/ray-project/ray/tree/master/python/ray/rllib/examples/carla>`__. You can also reference the `unit tests <https://github.com/ray-project/ray/blob/master/python/ray/rllib/test/test_nested_spaces.py>`__ for Tuple and Dict spaces, which show how to access nested observation fields.

Custom Recurrent Models
~~~~~~~~~~~~~~~~~~~~~~~

Instead of using the ``use_lstm: True`` option, it can be preferable to use a custom recurrent model. This provides more control over postprocessing of the LSTM output and also allows the use of multiple LSTM cells to process different portions of the input. The only difference from a normal custom model is that you have to define ``self.state_init``, ``self.state_in``, and ``self.state_out``. You can refer to the existing `lstm.py <https://github.com/ray-project/ray/blob/master/python/ray/rllib/models/lstm.py>`__ model as an example to implement your own model:


.. code-block:: python

    class MyCustomLSTM(Model):
        def _build_layers_v2(self, input_dict, num_outputs, options):
            # Some initial layers to process inputs, shape [BATCH, OBS...].
            features = some_hidden_layers(input_dict["obs"])

            # Add back the nested time dimension for tf.dynamic_rnn, new shape
            # will be [BATCH, MAX_SEQ_LEN, OBS...].
            last_layer = add_time_dimension(features, self.seq_lens)

            # Setup the LSTM cell (see lstm.py for an example)
            lstm = rnn.BasicLSTMCell(256, state_is_tuple=True)
            self.state_init = ...
            self.state_in = ...
            lstm_out, lstm_state = tf.nn.dynamic_rnn(
                lstm,
                last_layer,
                initial_state=...,
                sequence_length=self.seq_lens,
                time_major=False,
                dtype=tf.float32)
            self.state_out = list(lstm_state)

            # Drop the time dimension again so back to shape [BATCH, OBS...].
            # Note that we retain the zero padding (see issue #2992).
            last_layer = tf.reshape(lstm_out, [-1, cell_size])
            logits = linear(last_layer, num_outputs, "action",
                            normc_initializer(0.01))
            return logits, last_layer

Custom Preprocessors
--------------------

@@ -188,6 +225,53 @@ Then, you can create an agent with your custom policy graph by:
In this example we overrode methods of the existing DDPG policy graph, i.e., `_build_q_network`, `_build_p_network`, `_build_action_network`, and `_build_actor_critic_loss`, but you can also replace the graph class entirely.
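
A rough sketch of that pattern (the import paths, the ``*args`` signature, and the ``_policy_graph`` hook are assumptions about RLlib internals of this release; treat them as illustrative only):

.. code-block:: python

    from ray.rllib.agents.ddpg.ddpg import DDPGAgent  # assumed paths
    from ray.rllib.agents.ddpg.ddpg_policy_graph import DDPGPolicyGraph

    class CustomDDPGPolicyGraph(DDPGPolicyGraph):
        def _build_q_network(self, *args, **kwargs):
            # Delegate to a custom network builder; the real signature is
            # inherited from DDPGPolicyGraph and is version-specific.
            return my_custom_q_network(*args, **kwargs)  # hypothetical helper

    class CustomDDPGAgent(DDPGAgent):
        # Point the agent at the customized policy graph.
        _policy_graph = CustomDDPGPolicyGraph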

Variable-length / Parametric Action Spaces
------------------------------------------

Custom models can be used to work with environments where (1) the set of valid actions varies per step, and/or (2) the number of valid actions is very large, as in `OpenAI Five <https://neuro.cs.ut.ee/the-use-of-embeddings-in-openai-five/>`__ and `Horizon <https://arxiv.org/abs/1811.00260>`__. The general idea is that the meaning of actions can be completely conditioned on the observation, that is, the ``a`` in ``Q(s, a)`` is just a token in ``[0, MAX_AVAIL_ACTIONS)`` that only has meaning in the context of ``s``. This works with algorithms in the `DQN and policy-gradient families <rllib-env.html>`__ and can be implemented as follows:

1. The environment should return a mask and/or list of valid action embeddings as part of the observation for each step. To enable batching, the number of actions can be allowed to vary from 1 to some max number:

.. code-block:: python

    class MyParamActionEnv(gym.Env):
        def __init__(self, max_avail_actions):
            self.action_space = Discrete(max_avail_actions)
            self.observation_space = Dict({
                "action_mask": Box(0, 1, shape=(max_avail_actions, )),
                "avail_actions": Box(-1, 1, shape=(max_avail_actions, action_embedding_sz)),
                "real_obs": ...,
            })

2. A custom model can then be defined that interprets the ``action_mask`` and ``avail_actions`` portions of the observation. Here the model computes the action logits via the dot product of some network output and each action embedding. Invalid actions are masked out of the softmax by forcing their probability to zero:

.. code-block:: python

    class MyParamActionModel(Model):
        def _build_layers_v2(self, input_dict, num_outputs, options):
            avail_actions = input_dict["obs"]["avail_actions"]
            action_mask = input_dict["obs"]["action_mask"]

            output = FullyConnectedNetwork(
                input_dict["obs"]["real_obs"], num_outputs=action_embedding_sz)

            # Expand the model output to [BATCH, 1, EMBED_SIZE]. Note that the
            # avail actions tensor is of shape [BATCH, MAX_ACTIONS, EMBED_SIZE].
            intent_vector = tf.expand_dims(output, 1)

            # Shape of logits is [BATCH, MAX_ACTIONS].
            action_logits = tf.reduce_sum(avail_actions * intent_vector, axis=2)

            # Mask out invalid actions (use tf.float32.min for stability)
            inf_mask = tf.maximum(tf.log(action_mask), tf.float32.min)
            masked_logits = inf_mask + action_logits

            return masked_logits, last_layer

Depending on your use case it may make sense to use just the masking, just action embeddings, or both. For a runnable example of this in code, check out `parametric_action_cartpole.py <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/parametric_action_cartpole.py>`__. Note that since masking introduces ``tf.float32.min`` values into the model output, this technique might not work with all algorithm options. For example, algorithms might crash if they incorrectly process the ``tf.float32.min`` values. The cartpole example has working configurations for DQN and several policy gradient algorithms.
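
As a rough usage sketch (the registration names and experiment layout are illustrative; the ``"hiddens": []`` setting reflects this commit's change that makes DQN use the model outputs directly as Q-values so the mask survives):

.. code-block:: python

    import ray
    from ray.rllib.models import ModelCatalog
    from ray.tune import run_experiments
    from ray.tune.registry import register_env

    ModelCatalog.register_custom_model("pa_model", MyParamActionModel)
    register_env("pa_env",
                 lambda config: MyParamActionEnv(config["max_avail_actions"]))

    ray.init()
    run_experiments({
        "parametric_dqn": {
            "run": "DQN",
            "env": "pa_env",
            "config": {
                "env_config": {"max_avail_actions": 10},
                "model": {"custom_model": "pa_model"},
                # Skip the fully connected postprocessing layers so that the
                # masked logits from the model are used as Q-values as-is;
                # other DQN options (e.g. dueling) may also need adjusting,
                # as in the cartpole example.
                "hiddens": [],
            },
        },
    })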


Model-Based Rollouts
--------------------

4 changes: 2 additions & 2 deletions doc/source/rllib-training.rst
@@ -73,13 +73,13 @@ In an example below, we train A2C by specifying 8 workers through the config fla
python ray/python/ray/rllib/train.py --env=PongDeterministic-v4 \
--run=A2C --config '{"num_workers": 8}'
.. image:: rllib-config.svg

Specifying Resources
~~~~~~~~~~~~~~~~~~~~

You can control the degree of parallelism used by setting the ``num_workers`` hyperparameter for most agents. The number of GPUs the driver should use can be set via the ``num_gpus`` option. Similarly, the resource allocation to workers can be controlled via ``num_cpus_per_worker``, ``num_gpus_per_worker``, and ``custom_resources_per_worker``. The number of GPUs can be a fractional quantity to allocate only a fraction of a GPU. For example, with DQN you can pack five agents onto one GPU by setting ``num_gpus: 0.2``. Note that in Ray < 0.6.0 fractional GPU support requires setting the environment variable ``RAY_USE_XRAY=1``.
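
A minimal config sketch combining these options (the fractional-GPU value mirrors the DQN example above; the ``DQNAgent`` import path is an assumption):

.. code-block:: python

    import ray
    from ray.rllib.agents.dqn import DQNAgent  # assumed import path

    ray.init()
    agent = DQNAgent(
        env="PongDeterministic-v4",
        config={
            "num_workers": 2,           # parallel sample collection
            "num_gpus": 0.2,            # driver gets a fifth of one GPU
            "num_cpus_per_worker": 1,   # resources reserved per worker
            "num_gpus_per_worker": 0,
        })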

.. image:: rllib-config.svg

Common Parameters
~~~~~~~~~~~~~~~~~

3 changes: 2 additions & 1 deletion doc/source/rllib.rst
@@ -56,7 +56,7 @@ Algorithms

- `Deep Deterministic Policy Gradients (DDPG, TD3) <rllib-algorithms.html#deep-deterministic-policy-gradients-ddpg-td3>`__

- `Deep Q Networks (DQN, Rainbow) <rllib-algorithms.html#deep-q-networks-dqn-rainbow>`__
- `Deep Q Networks (DQN, Rainbow, Parametric DQN) <rllib-algorithms.html#deep-q-networks-dqn-rainbow-parametric-dqn>`__

- `Policy Gradients <rllib-algorithms.html#policy-gradients>`__

@@ -75,6 +75,7 @@ Models and Preprocessors
* `Custom Models <rllib-models.html#custom-models>`__
* `Custom Preprocessors <rllib-models.html#custom-preprocessors>`__
* `Customizing Policy Graphs <rllib-models.html#customizing-policy-graphs>`__
* `Variable-length / Parametric Action Spaces <rllib-models.html#variable-length-parametric-action-spaces>`__
* `Model-Based Rollouts <rllib-models.html#model-based-rollouts>`__

RLlib Concepts
57 changes: 38 additions & 19 deletions python/ray/rllib/agents/dqn/dqn_policy_graph.py
@@ -30,28 +30,35 @@ def __init__(self,
sigma0=0.5):
self.model = model
with tf.variable_scope("action_value"):
action_out = model.last_layer
for i in range(len(hiddens)):
if use_noisy:
action_out = self.noisy_layer("hidden_%d" % i, action_out,
hiddens[i], sigma0)
else:
action_out = layers.fully_connected(
action_out,
num_outputs=hiddens[i],
activation_fn=tf.nn.relu)
if hiddens:
action_out = model.last_layer
for i in range(len(hiddens)):
if use_noisy:
action_out = self.noisy_layer(
"hidden_%d" % i, action_out, hiddens[i], sigma0)
else:
action_out = layers.fully_connected(
action_out,
num_outputs=hiddens[i],
activation_fn=tf.nn.relu)
else:
# Avoid postprocessing the outputs. This enables custom models
# to be used for parametric action DQN.
action_out = model.outputs
if use_noisy:
action_scores = self.noisy_layer(
"output",
action_out,
num_actions * num_atoms,
sigma0,
non_linear=False)
else:
elif hiddens:
action_scores = layers.fully_connected(
action_out,
num_outputs=num_actions * num_atoms,
activation_fn=None)
else:
action_scores = model.outputs
if num_atoms > 1:
# Distributional Q-learning uses a discrete support z
# to represent the action value distribution
@@ -107,7 +114,7 @@ def __init__(self,
self.logits = support_logits_per_action
self.dist = support_prob_per_action
else:
action_scores_mean = tf.reduce_mean(action_scores, 1)
action_scores_mean = _reduce_mean_ignore_inf(action_scores, 1)
action_scores_centered = action_scores - tf.expand_dims(
action_scores_mean, 1)
self.value = state_score + action_scores_centered
@@ -176,11 +183,15 @@ class QValuePolicy(object):
def __init__(self, q_values, observations, num_actions, stochastic, eps):
deterministic_actions = tf.argmax(q_values, axis=1)
batch_size = tf.shape(observations)[0]
random_actions = tf.random_uniform(
tf.stack([batch_size]),
minval=0,
maxval=num_actions,
dtype=tf.int64)

# Special case masked out actions (q_value ~= -inf) so that we don't
# even consider them for exploration.
random_valid_action_logits = tf.where(
tf.equal(q_values, tf.float32.min),
tf.ones_like(q_values) * tf.float32.min, tf.ones_like(q_values))
random_actions = tf.squeeze(
tf.multinomial(random_valid_action_logits, 1), axis=1)

chose_random = tf.random_uniform(
tf.stack([batch_size]), minval=0, maxval=1, dtype=tf.float32) < eps
stochastic_actions = tf.where(chose_random, random_actions,
@@ -368,8 +379,8 @@ def _build_q_network(self, obs, space):
qnet = QNetwork(
ModelCatalog.get_model({
"obs": obs
}, space, 1, self.config["model"]), self.num_actions,
self.config["dueling"], self.config["hiddens"],
}, space, self.num_actions, self.config["model"]),
self.num_actions, self.config["dueling"], self.config["hiddens"],
self.config["noisy"], self.config["num_atoms"],
self.config["v_min"], self.config["v_max"], self.config["sigma0"])
return qnet.value, qnet.logits, qnet.dist, qnet.model
@@ -507,6 +518,14 @@ def _postprocess_dqn(policy_graph, sample_batch):
return batch


def _reduce_mean_ignore_inf(x, axis):
"""Same as tf.reduce_mean() but ignores -inf values."""
mask = tf.not_equal(x, tf.float32.min)
x_zeroed = tf.where(mask, x, tf.zeros_like(x))
return (tf.reduce_sum(x_zeroed, axis) / tf.reduce_sum(
tf.cast(mask, tf.float32), axis))


def _huber_loss(x, delta=1.0):
"""Reference: https://en.wikipedia.org/wiki/Huber_loss"""
return tf.where(
5 changes: 5 additions & 0 deletions python/ray/rllib/agents/ppo/ppo.py
@@ -110,6 +110,11 @@ def _validate_config(self):
and not self.config["simple_optimizer"]):
logger.warn("forcing simple_optimizer=True in multi-agent mode")
self.config["simple_optimizer"] = True
if self.config["observation_filter"] != "NoFilter":
# TODO(ekl): consider setting the default to be NoFilter
logger.warn(
"By default, observations will be normalized with {}".format(
self.config["observation_filter"]))

def _train(self):
prev_steps = self.optimizer.num_steps_sampled

