[rllib] Support parallel, parameterized evaluation (#6981)
* eval api

* update

* sync eval filters

* sync fix

* docs

* update

* docs

* update

* link

* nit

* doc updates

* format
ericl committed Feb 2, 2020
1 parent b9ad79d commit fbc545c
Showing 8 changed files with 381 additions and 46 deletions.
3 changes: 3 additions & 0 deletions ci/jenkins_tests/run_rllib_tests.sh
@@ -1,3 +1,6 @@
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/rllib/examples/custom_eval.py --custom-eval

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/rllib/tests/test_catalog.py

4 changes: 4 additions & 0 deletions doc/source/rllib-models.rst
@@ -212,6 +212,10 @@ You can mix supervised losses into any RLlib algorithm through custom models. Fo

**PyTorch**: There is no explicit API for adding losses to custom torch models. However, you can modify the loss in the policy definition directly. Like for TF models, offline datasets can be incorporated by creating an input reader and calling ``reader.next()`` in the loss forward pass.
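
As a rough sketch of that approach (assuming experiences were previously written as JSON to ``/tmp/cartpole-out``; the ``0.1`` imitation weight is arbitrary and observation preprocessing/compression details are glossed over):

.. code-block:: python

    import torch
    from ray.rllib.offline import JsonReader

    reader = JsonReader("/tmp/cartpole-out")

    def add_offline_imitation_term(policy_loss, model):
        # Read a batch of saved experiences and add a supervised term
        # to the ordinary policy loss.
        batch = reader.next()
        obs = torch.from_numpy(batch["obs"]).float()
        actions = torch.from_numpy(batch["actions"]).long()
        logits, _ = model({"obs": obs}, [], None)  # forward pass on offline obs
        imitation = torch.nn.functional.cross_entropy(logits, actions)
        return policy_loss + 0.1 * imitation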

Self-Supervised Model Losses
----------------------------

You can also use the ``custom_loss()`` API to add in self-supervised losses such as VAE reconstruction loss and L2-regularization.
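
As a rough sketch (module paths and the ``1e-4`` coefficient may vary across RLlib versions), a model that wraps the built-in fully connected network and adds an L2 penalty might look like:

.. code-block:: python

    import tensorflow as tf
    from ray.rllib.models import ModelCatalog
    from ray.rllib.models.tf.tf_modelv2 import TFModelV2
    from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork

    class L2RegularizedModel(TFModelV2):
        """Delegates to a default FC net and adds an L2 penalty on its weights."""

        def __init__(self, obs_space, action_space, num_outputs, model_config,
                     name):
            super().__init__(obs_space, action_space, num_outputs, model_config,
                             name)
            self.fcnet = FullyConnectedNetwork(
                obs_space, action_space, num_outputs, model_config, name + "_fc")
            self.register_variables(self.fcnet.variables())

        def forward(self, input_dict, state, seq_lens):
            return self.fcnet.forward(input_dict, state, seq_lens)

        def value_function(self):
            return self.fcnet.value_function()

        def custom_loss(self, policy_loss, loss_inputs):
            # Add a small L2 term over all registered model variables.
            l2_term = tf.add_n([tf.nn.l2_loss(v) for v in self.variables()])
            return policy_loss + 1e-4 * l2_term

    # Use via: config={"model": {"custom_model": "l2_model"}}
    ModelCatalog.register_custom_model("l2_model", L2RegularizedModel)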

Variable-length / Parametric Action Spaces
------------------------------------------
2 changes: 1 addition & 1 deletion doc/source/rllib-offline.rst
@@ -13,7 +13,7 @@ Example: Training on previously saved experiences

.. note::

For custom models and environments, you'll need to use the `Python API <rllib-training.html#python-api>`__.
For custom models and environments, you'll need to use the `Python API <rllib-training.html#basic-python-api>`__.

In this example, we will save batches of experiences generated during online training to disk, and then leverage this saved data to train a policy offline using DQN. First, we run a simple policy gradient algorithm for 100k steps with ``"output": "/tmp/cartpole-out"`` to tell RLlib to write simulation outputs to the ``/tmp/cartpole-out`` directory.
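
A rough Python equivalent of that first step (the file-size cap and exact stop criterion are illustrative):

.. code-block:: python

    from ray import tune

    tune.run(
        "PG",
        stop={"timesteps_total": 100000},
        config={
            "env": "CartPole-v0",
            # Write experiences generated during training to this directory.
            "output": "/tmp/cartpole-out",
            "output_max_file_size": 5000000,
        },
    )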

14 changes: 11 additions & 3 deletions doc/source/rllib-toc.rst
@@ -4,6 +4,9 @@ RLlib Table of Contents
Training APIs
-------------
* `Command-line <rllib-training.html>`__

- `Evaluating Trained Policies <rllib-training.html#evaluating-trained-policies>`__

* `Configuration <rllib-training.html#configuration>`__

- `Specifying Parameters <rllib-training.html#specifying-parameters>`__
@@ -14,20 +17,24 @@ Training APIs

- `Tuned Examples <rllib-training.html#tuned-examples>`__

* `Python API <rllib-training.html#python-api>`__

- `Custom Training Workflows <rllib-training.html#custom-training-workflows>`__
* `Basic Python API <rllib-training.html#basic-python-api>`__

- `Computing Actions <rllib-training.html#computing-actions>`__

- `Accessing Policy State <rllib-training.html#accessing-policy-state>`__

- `Accessing Model State <rllib-training.html#accessing-model-state>`__

* `Advanced Python APIs <rllib-training.html#advanced-python-apis>`__

- `Custom Training Workflows <rllib-training.html#custom-training-workflows>`__

- `Global Coordination <rllib-training.html#global-coordination>`__

- `Callbacks and Custom Metrics <rllib-training.html#callbacks-and-custom-metrics>`__

- `Customized Evaluation During Training <rllib-training.html#customized-evaluation-during-training>`__

- `Rewriting Trajectories <rllib-training.html#rewriting-trajectories>`__

- `Curriculum Learning <rllib-training.html#curriculum-learning>`__
@@ -64,6 +71,7 @@ Models, Preprocessors, and Action Distributions
* `Custom Preprocessors <rllib-models.html#custom-preprocessors>`__
* `Custom Action Distributions <rllib-models.html#custom-action-distributions>`__
* `Supervised Model Losses <rllib-models.html#supervised-model-losses>`__
* `Self-Supervised Model Losses <rllib-models.html#self-supervised-model-losses>`__
* `Variable-length / Parametric Action Spaces <rllib-models.html#variable-length-parametric-action-spaces>`__
* `Autoregressive Action Distributions <rllib-models.html#autoregressive-action-distributions>`__

85 changes: 76 additions & 9 deletions doc/source/rllib-training.rst
@@ -61,6 +61,8 @@ and renders its behavior in the environment specified by ``--env``.

(Type ``rllib rollout --help`` to see the available evaluation options.)

For more advanced evaluation functionality, refer to `Customized Evaluation During Training <#customized-evaluation-during-training>`__.

Configuration
-------------

@@ -107,8 +109,8 @@ You can run these with the ``rllib train`` command as follows:
rllib train -f /path/to/tuned/example.yaml
Python API
----------
Basic Python API
----------------

The Python API provides the needed flexibility for applying RLlib to new problems. You will need to use this API if you wish to use `custom environments, preprocessors, or models <rllib-models.html>`__ with RLlib.

@@ -177,13 +179,6 @@ Tune will schedule the trials to run in parallel on your Ray cluster:
- PPO_CartPole-v0_0_lr=0.01: RUNNING [pid=21940], 16 s, 4013 ts, 22 rew
- PPO_CartPole-v0_1_lr=0.001: RUNNING [pid=21942], 27 s, 8111 ts, 54.7 rew

Custom Training Workflows
~~~~~~~~~~~~~~~~~~~~~~~~~

In the `basic training example <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_env.py>`__, Tune will call ``train()`` on your trainer once per training iteration and report the new training results. Sometimes, it is desirable to have full control over training, but still run inside Tune. Tune supports `custom trainable functions <tune-usage.html#trainable-api>`__ that can be used to implement `custom training workflows (example) <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_train_fn.py>`__.

For even finer-grained control over training, you can use RLlib's lower-level `building blocks <rllib-concepts.html>`__ directly to implement `fully customized training workflows <https://github.com/ray-project/ray/blob/master/rllib/examples/rollout_worker_custom_workflow.py>`__.

Computing Actions
~~~~~~~~~~~~~~~~~

@@ -431,6 +426,16 @@ Similar to accessing policy state, you may want to get a reference to the underl
This is especially useful when used with `custom model classes <rllib-models.html>`__.

Advanced Python APIs
--------------------

Custom Training Workflows
~~~~~~~~~~~~~~~~~~~~~~~~~

In the `basic training example <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_env.py>`__, Tune will call ``train()`` on your trainer once per training iteration and report the new training results. Sometimes, it is desirable to have full control over training, but still run inside Tune. Tune supports `custom trainable functions <tune-usage.html#trainable-api>`__ that can be used to implement `custom training workflows (example) <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_train_fn.py>`__.

For even finer-grained control over training, you can use RLlib's lower-level `building blocks <rllib-concepts.html>`__ directly to implement `fully customized training workflows <https://github.com/ray-project/ray/blob/master/rllib/examples/rollout_worker_custom_workflow.py>`__.
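
As a rough sketch of such a trainable function (loosely modeled on the linked ``custom_train_fn.py`` example; the iteration count and reward threshold are arbitrary):

.. code-block:: python

    import ray
    from ray import tune
    from ray.rllib.agents.ppo import PPOTrainer

    def my_train_fn(config, reporter):
        # Full control over the training loop, while still reporting to Tune.
        trainer = PPOTrainer(config=config, env="CartPole-v0")
        for _ in range(100):
            result = trainer.train()
            reporter(**result)
            if result["episode_reward_mean"] > 150:
                break  # e.g., stop early once a target reward is reached
        trainer.stop()

    if __name__ == "__main__":
        ray.init()
        tune.run(my_train_fn, config={"num_workers": 0})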

Global Coordination
~~~~~~~~~~~~~~~~~~~
Sometimes, it is necessary to coordinate between pieces of code that live in different processes managed by RLlib. For example, it can be useful to maintain a global average of a certain variable, or centrally control a hyperparameter used by policies. Ray provides a general way to achieve this through *named actors* (learn more about Ray actors `here <actors.html>`__). As an example, consider maintaining a shared global counter that is incremented by environments and read periodically from your driver program:
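
A minimal sketch of the shared-counter idea (driver side only; how the handle is shared with environment processes, e.g. via named actors, depends on the Ray version and is omitted here):

.. code-block:: python

    import ray

    ray.init()

    @ray.remote
    class Counter:
        def __init__(self):
            self.count = 0

        def inc(self, n=1):
            self.count += n

        def get(self):
            return self.count

    counter = Counter.remote()
    counter.inc.remote(5)                 # e.g., called from an env after a step
    print(ray.get(counter.get.remote()))  # the driver reads the global value: 5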
@@ -515,6 +520,68 @@ Custom metrics can be accessed and visualized like any other training result:

.. image:: custom_metric.png

Customized Evaluation During Training
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

RLlib will report online training rewards; however, in some cases you may want to compute
rewards with different settings (e.g., with exploration turned off, or on a specific set
of environment configurations). You can evaluate policies during training by setting one
or more of the ``evaluation_interval``, ``evaluation_num_episodes``, ``evaluation_config``,
``evaluation_num_workers``, and ``custom_eval_function`` configs
(see `trainer.py <https://github.com/ray-project/ray/blob/master/rllib/agents/trainer.py>`__ for further documentation).
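
A sketch of such a configuration (the env and override values are illustrative; which exploration-related keys apply depends on the algorithm):

.. code-block:: python

    from ray import tune

    tune.run(
        "PG",
        config={
            "env": "CartPole-v0",
            # Evaluate every 2 training iterations...
            "evaluation_interval": 2,
            # ...running at least 20 episodes per evaluation...
            "evaluation_num_episodes": 20,
            # ...in parallel on 2 dedicated evaluation workers...
            "evaluation_num_workers": 2,
            # ...with overrides applied only to the evaluation workers.
            "evaluation_config": {
                "env_config": {"some_env_flag": True},  # illustrative override
            },
        },
    )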

There is an end-to-end example of how to set up custom online evaluation in `custom_eval.py <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_eval.py>`__. Note that if you only want to evaluate your policy at the end of training, you can set ``evaluation_interval: N``, where ``N`` is the number of training iterations before stopping.
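
A sketch of a custom evaluation function with the required ``(trainer, eval_workers) -> dict`` signature (the two sampling rounds and the extra ``foo`` metric are illustrative):

.. code-block:: python

    import ray
    from ray.rllib.evaluation.metrics import collect_metrics

    def custom_eval_function(trainer, eval_workers):
        """Runs a couple of sampling rounds on the eval workers, then summarizes."""
        workers = eval_workers.remote_workers()
        for _ in range(2):
            ray.get([w.sample.remote() for w in workers])
        metrics = collect_metrics(eval_workers.local_worker(), workers)
        metrics["foo"] = 1  # extra metrics can be added to the returned dict
        return metrics

    # Enabled via the trainer config, e.g.:
    # config={"custom_eval_function": custom_eval_function,
    #         "evaluation_interval": 1, "evaluation_num_workers": 2}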

Below are some examples of how the custom evaluation metrics are reported nested under the ``evaluation`` key of normal training results:

.. code-block:: bash

    ------------------------------------------------------------------------
    Sample output for `python custom_eval.py`
    ------------------------------------------------------------------------

    INFO trainer.py:623 -- Evaluating current policy for 10 episodes.
    INFO trainer.py:650 -- Running round 0 of parallel evaluation (2/10 episodes)
    INFO trainer.py:650 -- Running round 1 of parallel evaluation (4/10 episodes)
    INFO trainer.py:650 -- Running round 2 of parallel evaluation (6/10 episodes)
    INFO trainer.py:650 -- Running round 3 of parallel evaluation (8/10 episodes)
    INFO trainer.py:650 -- Running round 4 of parallel evaluation (10/10 episodes)

    Result for PG_SimpleCorridor_2c6b27dc:
      ...
      evaluation:
        custom_metrics: {}
        episode_len_mean: 15.864661654135338
        episode_reward_max: 1.0
        episode_reward_mean: 0.49624060150375937
        episode_reward_min: 0.0
        episodes_this_iter: 133

.. code-block:: bash

    ------------------------------------------------------------------------
    Sample output for `python custom_eval.py --custom-eval`
    ------------------------------------------------------------------------

    INFO trainer.py:631 -- Running custom eval function <function ...>
    Update corridor length to 4
    Update corridor length to 7
    Custom evaluation round 1
    Custom evaluation round 2
    Custom evaluation round 3
    Custom evaluation round 4

    Result for PG_SimpleCorridor_0de4e686:
      ...
      evaluation:
        custom_metrics: {}
        episode_len_mean: 9.15695067264574
        episode_reward_max: 1.0
        episode_reward_mean: 0.9596412556053812
        episode_reward_min: 0.0
        episodes_this_iter: 223
        foo: 1

Rewriting Trajectories
~~~~~~~~~~~~~~~~~~~~~~

115 changes: 84 additions & 31 deletions rllib/agents/trainer.py
@@ -1,6 +1,7 @@
from datetime import datetime
import copy
import logging
import math
import os
import pickle
import six
@@ -170,13 +171,30 @@
# Note that evaluation is currently not parallelized, and that for Ape-X
# metrics are already only reported for the lowest epsilon workers.
"evaluation_interval": None,
# Number of episodes to run per evaluation period.
# Number of episodes to run per evaluation period. If using multiple
# evaluation workers, we will run at least this many episodes total.
"evaluation_num_episodes": 10,
# Extra arguments to pass to evaluation workers.
# Typical usage is to pass extra args to evaluation env creator
# and to disable exploration by computing deterministic actions
# TODO(kismuz): implement determ. actions and include relevant keys hints
"evaluation_config": {},
"evaluation_config": {
# Example: overriding env_config, exploration, etc:
# "env_config": {...},
# "exploration_fraction": 0,
# "exploration_final_eps": 0,
},
# Number of parallel workers to use for evaluation. Note that this is set
# to zero by default, which means evaluation will be run in the trainer
# process. If you increase this, it will increase the Ray resource usage
# of the trainer since evaluation workers are created separately from
# rollout workers.
"evaluation_num_workers": 0,
# Customize the evaluation method. This must be a function of signature
# (trainer: Trainer, eval_workers: WorkerSet) -> metrics: dict. See the
# Trainer._evaluate() method to see the default implementation. The
# trainer guarantees all eval workers have the latest policy state before
# this function is called.
"custom_eval_function": None,

# === Advanced Rollout Settings ===
# Use a background thread for sampling (slightly off-policy, usually not
@@ -408,17 +426,18 @@ def default_resource_request(cls, config):
def default_resource_request(cls, config):
cf = dict(cls._default_config, **config)
Trainer._validate_config(cf)
num_workers = cf["num_workers"] + cf["evaluation_num_workers"]
# TODO(ekl): add custom resources here once tune supports them
return Resources(
cpu=cf["num_cpus_for_driver"],
gpu=cf["num_gpus"],
memory=cf["memory"],
object_store_memory=cf["object_store_memory"],
extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"],
extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"],
extra_memory=cf["memory_per_worker"] * cf["num_workers"],
extra_cpu=cf["num_cpus_per_worker"] * num_workers,
extra_gpu=cf["num_gpus_per_worker"] * num_workers,
extra_memory=cf["memory_per_worker"] * num_workers,
extra_object_store_memory=cf["object_store_memory_per_worker"] *
cf["num_workers"])
num_workers)

@override(Trainable)
@PublicAPI
@@ -456,29 +475,32 @@ def train(self):
if result is None:
raise RuntimeError("Failed to recover from worker crash")

if (self.config.get("observation_filter", "NoFilter") != "NoFilter"
and hasattr(self, "workers")
and isinstance(self.workers, WorkerSet)):
FilterManager.synchronize(
self.workers.local_worker().filters,
self.workers.remote_workers(),
update_remote=self.config["synchronize_filters"])
logger.debug("synchronized filters: {}".format(
self.workers.local_worker().filters))
if hasattr(self, "workers") and isinstance(self.workers, WorkerSet):
self._sync_filters_if_needed(self.workers)

if self._has_policy_optimizer():
result["num_healthy_workers"] = len(
self.optimizer.workers.remote_workers())

if self.config["evaluation_interval"]:
if self._iteration % self.config["evaluation_interval"] == 0:
evaluation_metrics = self._evaluate()
assert isinstance(evaluation_metrics, dict), \
"_evaluate() needs to return a dict."
result.update(evaluation_metrics)
if self.config["evaluation_interval"] == 1 or (
self._iteration > 0 and self.config["evaluation_interval"]
and self._iteration % self.config["evaluation_interval"] == 0):
evaluation_metrics = self._evaluate()
assert isinstance(evaluation_metrics, dict), \
"_evaluate() needs to return a dict."
result.update(evaluation_metrics)

return result

def _sync_filters_if_needed(self, workers):
if self.config.get("observation_filter", "NoFilter") != "NoFilter":
FilterManager.synchronize(
workers.local_worker().filters,
workers.remote_workers(),
update_remote=self.config["synchronize_filters"])
logger.debug("synchronized filters: {}".format(
workers.local_worker().filters))

@override(Trainable)
def _log_result(self, result):
if self.config["callbacks"].get("on_train_result"):
@@ -548,8 +570,8 @@ def get_scope():
self.env_creator,
self._policy,
merge_dicts(self.config, extra_config),
num_workers=0)
self.evaluation_metrics = self._evaluate()
num_workers=self.config["evaluation_num_workers"])
self.evaluation_metrics = {}

@override(Trainable)
def _stop(self):
@@ -600,15 +622,46 @@ def _evaluate(self):
"overrides, since the results will be the "
"same as reported during normal policy evaluation.")

logger.info("Evaluating current policy for {} episodes".format(
self.config["evaluation_num_episodes"]))
self._before_evaluate()
self.evaluation_workers.local_worker().restore(
self.workers.local_worker().save())
for _ in range(self.config["evaluation_num_episodes"]):
self.evaluation_workers.local_worker().sample()

metrics = collect_metrics(self.evaluation_workers.local_worker())
# Broadcast the new policy weights to all evaluation workers.
logger.info("Synchronizing weights to evaluation workers.")
weights = ray.put(self.workers.local_worker().save())
self.evaluation_workers.foreach_worker(
lambda w: w.restore(ray.get(weights)))
self._sync_filters_if_needed(self.evaluation_workers)

if self.config["custom_eval_function"]:
logger.info("Running custom eval function {}".format(
self.config["custom_eval_function"]))
metrics = self.config["custom_eval_function"](
self, self.evaluation_workers)
if not metrics or not isinstance(metrics, dict):
raise ValueError("Custom eval function must return "
"dict of metrics, got {}.".format(metrics))
else:
logger.info("Evaluating current policy for {} episodes.".format(
self.config["evaluation_num_episodes"]))
if self.config["evaluation_num_workers"] == 0:
for _ in range(self.config["evaluation_num_episodes"]):
self.evaluation_workers.local_worker().sample()
else:
num_rounds = int(
math.ceil(self.config["evaluation_num_episodes"] /
self.config["evaluation_num_workers"]))
num_workers = len(self.evaluation_workers.remote_workers())
num_episodes = num_rounds * num_workers
for i in range(num_rounds):
logger.info("Running round {} of parallel evaluation "
"({}/{} episodes)".format(
i, (i + 1) * num_workers, num_episodes))
ray.get([
w.sample.remote()
for w in self.evaluation_workers.remote_workers()
])

metrics = collect_metrics(self.evaluation_workers.local_worker(),
self.evaluation_workers.remote_workers())
return {"evaluation": metrics}

@DeveloperAPI
8 changes: 6 additions & 2 deletions rllib/evaluation/metrics.py
@@ -81,14 +81,18 @@ def collect_episodes(local_worker=None,


@DeveloperAPI
def summarize_episodes(episodes, new_episodes):
def summarize_episodes(episodes, new_episodes=None):
"""Summarizes a set of episode metrics tuples.
Arguments:
episodes: smoothed set of episodes including historical ones
new_episodes: just the new episodes in this iteration
new_episodes: just the new episodes in this iteration. This must be
a subset of `episodes`. If None, assumes all episodes are new.
"""

if new_episodes is None:
new_episodes = episodes

episodes, estimates = _partition(episodes)
new_episodes, _ = _partition(new_episodes)
