diff --git a/doc/source/rllib-env.rst b/doc/source/rllib-env.rst index a4a659b8e00d5..ea91c2e30d97a 100644 --- a/doc/source/rllib-env.rst +++ b/doc/source/rllib-env.rst @@ -5,6 +5,23 @@ RLlib works with several different types of environments, including `OpenAI Gym .. image:: rllib-envs.svg +**Compatibility matrix**: + +============= ================ ================== =========== ================== +Algorithm Discrete Actions Continuous Actions Multi-Agent Recurrent Policies +============= ================ ================== =========== ================== +A2C, A3C **Yes** **Yes** **Yes** **Yes** +PPO **Yes** **Yes** **Yes** **Yes** +PG **Yes** **Yes** **Yes** **Yes** +IMPALA **Yes** No **Yes** **Yes** +DQN, Rainbow **Yes** No **Yes** No +DDPG No **Yes** **Yes** No +APEX-DQN **Yes** No **Yes** No +APEX-DDPG No **Yes** **Yes** No +ES **Yes** **Yes** No No +ARS **Yes** **Yes** No No +============= ================ ================== =========== ================== + In the high-level agent APIs, environments are identified with string names. By default, the string will be interpreted as a gym `environment name `__, however you can also register custom environments by name: .. code-block:: python @@ -132,7 +149,35 @@ If all the agents will be using the same algorithm class to train, then you can RLlib will create three distinct policies and route agent decisions to its bound policy. When an agent first appears in the env, ``policy_mapping_fn`` will be called to determine which policy it is bound to. RLlib reports separate training statistics for each policy in the return from ``train()``, along with the combined reward. -Here is a simple `example training script `__ in which you can vary the number of agents and policies in the environment. For how to use multiple training methods at once (here DQN and PPO), see the `two-trainer example `__. +Here is a simple `example training script `__ in which you can vary the number of agents and policies in the environment. For how to use multiple training methods at once (here DQN and PPO), see the `two-trainer example `__. Metrics are reported for each policy separately, for example: + +.. code-block:: bash + :emphasize-lines: 6,14,22 + + Result for PPO_multi_cartpole_0: + episode_len_mean: 34.025862068965516 + episode_reward_max: 159.0 + episode_reward_mean: 86.06896551724138 + info: + policy_0: + cur_lr: 4.999999873689376e-05 + entropy: 0.6833480000495911 + kl: 0.010264254175126553 + policy_loss: -11.95590591430664 + total_loss: 197.7039794921875 + vf_explained_var: 0.0010995268821716309 + vf_loss: 209.6578826904297 + policy_1: + cur_lr: 4.999999873689376e-05 + entropy: 0.6827034950256348 + kl: 0.01119876280426979 + policy_loss: -8.787769317626953 + total_loss: 88.26161193847656 + vf_explained_var: 0.0005457401275634766 + vf_loss: 97.0471420288086 + policy_reward_mean: + policy_0: 21.194444444444443 + policy_1: 21.798387096774192 To scale to hundreds of agents, MultiAgentEnv batches policy evaluations across multiple agents internally. It can also be auto-vectorized by setting ``num_envs_per_worker > 1``. 
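For readers of the doc change above, here is a minimal sketch of the ``multiagent`` setup it describes, loosely based on the ``multiagent_cartpole.py`` example modified later in this patch. The ``MultiCartpole`` import path, the observation/action spaces, and the two-policy split are assumptions for illustration only, not part of this change.

.. code-block:: python

    import random

    import gym
    import ray
    from ray import tune
    from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
    from ray.tune import run_experiments
    from ray.tune.registry import register_env

    # Assumed import path for the toy multi-agent env used by the example.
    from ray.rllib.test.test_multi_agent_env import MultiCartpole

    # Env with four independent cartpole agents; spaces are hard-coded
    # here for illustration.
    register_env("multi_cartpole", lambda _: MultiCartpole(4))
    obs_space = gym.spaces.Box(low=-10, high=10, shape=(4, ))
    act_space = gym.spaces.Discrete(2)

    # Two policies sharing the same graph class but trained separately;
    # per-policy stats show up under info/policy_0 and info/policy_1.
    policy_graphs = {
        "policy_0": (PPOPolicyGraph, obs_space, act_space, {}),
        "policy_1": (PPOPolicyGraph, obs_space, act_space, {}),
    }

    ray.init()
    run_experiments({
        "test": {
            "run": "PPO",
            "env": "multi_cartpole",
            "stop": {"training_iteration": 20},
            "config": {
                # Mirrors the multiagent_cartpole.py change in this patch.
                "simple_optimizer": True,
                "multiagent": {
                    "policy_graphs": policy_graphs,
                    # Called when an agent first appears in the env to
                    # pick the policy it is bound to.
                    "policy_mapping_fn": tune.function(
                        lambda agent_id: random.choice(
                            ["policy_0", "policy_1"])),
                },
            },
        },
    })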
diff --git a/python/ray/rllib/agents/ddpg/apex.py b/python/ray/rllib/agents/ddpg/apex.py index c9053ca8a00a0..e809ac8650a05 100644 --- a/python/ray/rllib/agents/ddpg/apex.py +++ b/python/ray/rllib/agents/ddpg/apex.py @@ -47,7 +47,7 @@ class ApexDDPGAgent(DDPGAgent): def default_resource_request(cls, config): cf = merge_dicts(cls._default_config, config) return Resources( - cpu=1 + cf["optimizer"]["num_replay_buffer_shards"], + cpu=1, gpu=cf["gpu"] and cf["gpu_fraction"] or 0, extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"], extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"]) diff --git a/python/ray/rllib/agents/dqn/apex.py b/python/ray/rllib/agents/dqn/apex.py index ac8ec4490e609..585bdeea89667 100644 --- a/python/ray/rllib/agents/dqn/apex.py +++ b/python/ray/rllib/agents/dqn/apex.py @@ -50,7 +50,7 @@ class ApexAgent(DQNAgent): def default_resource_request(cls, config): cf = merge_dicts(cls._default_config, config) return Resources( - cpu=1 + cf["optimizer"]["num_replay_buffer_shards"], + cpu=1, gpu=cf["gpu"] and cf["gpu_fraction"] or 0, extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"], extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"]) diff --git a/python/ray/rllib/examples/multiagent_cartpole.py b/python/ray/rllib/examples/multiagent_cartpole.py index fff7066390cb1..002e1fa98db11 100644 --- a/python/ray/rllib/examples/multiagent_cartpole.py +++ b/python/ray/rllib/examples/multiagent_cartpole.py @@ -85,7 +85,6 @@ def gen_policy(i): "custom_model": ["model1", "model2"][i % 2], }, "gamma": random.choice([0.5, 0.8, 0.9, 0.95, 0.99]), - "n_step": random.choice([1, 2, 3, 4, 5]), } return (PPOPolicyGraph, obs_space, act_space, config) @@ -98,12 +97,13 @@ def gen_policy(i): run_experiments({ "test": { - "run": "PG", + "run": "PPO", "env": "multi_cartpole", "stop": { "training_iteration": args.num_iters }, "config": { + "simple_optimizer": True, "multiagent": { "policy_graphs": policy_graphs, "policy_mapping_fn": tune.function( diff --git a/python/ray/rllib/optimizers/async_replay_optimizer.py b/python/ray/rllib/optimizers/async_replay_optimizer.py index c48fd18601974..bacec3675c401 100644 --- a/python/ray/rllib/optimizers/async_replay_optimizer.py +++ b/python/ray/rllib/optimizers/async_replay_optimizer.py @@ -6,6 +6,7 @@ from __future__ import division from __future__ import print_function +import collections import os import random import time @@ -15,9 +16,10 @@ from six.moves import queue import ray +from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ + MultiAgentBatch from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer -from ray.rllib.evaluation.sample_batch import SampleBatch from ray.rllib.utils.actors import TaskPool, create_colocated from ray.rllib.utils.timer import TimerStat from ray.rllib.utils.window_stat import WindowStat @@ -43,49 +45,61 @@ def __init__(self, num_shards, learning_starts, buffer_size, self.prioritized_replay_beta = prioritized_replay_beta self.prioritized_replay_eps = prioritized_replay_eps - self.replay_buffer = PrioritizedReplayBuffer( - self.buffer_size, alpha=prioritized_replay_alpha) + def new_buffer(): + return PrioritizedReplayBuffer( + self.buffer_size, alpha=prioritized_replay_alpha) + + self.replay_buffers = collections.defaultdict(new_buffer) # Metrics self.add_batch_timer = TimerStat() self.replay_timer = TimerStat() self.update_priorities_timer = TimerStat() + self.num_added = 0 def get_host(self): return 
os.uname()[1] def add_batch(self, batch): - PolicyOptimizer._check_not_multiagent(batch) + # Handle everything as if multiagent + if isinstance(batch, SampleBatch): + batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count) with self.add_batch_timer: - for row in batch.rows(): - self.replay_buffer.add(row["obs"], row["actions"], - row["rewards"], row["new_obs"], - row["dones"], row["weights"]) + for policy_id, s in batch.policy_batches.items(): + for row in s.rows(): + self.replay_buffers[policy_id].add( + row["obs"], row["actions"], row["rewards"], + row["new_obs"], row["dones"], row["weights"]) + self.num_added += batch.count def replay(self): - with self.replay_timer: - if len(self.replay_buffer) < self.replay_starts: - return None - - (obses_t, actions, rewards, obses_tp1, dones, weights, - batch_indexes) = self.replay_buffer.sample( - self.train_batch_size, beta=self.prioritized_replay_beta) - - batch = SampleBatch({ - "obs": obses_t, - "actions": actions, - "rewards": rewards, - "new_obs": obses_tp1, - "dones": dones, - "weights": weights, - "batch_indexes": batch_indexes - }) - return batch + if self.num_added < self.replay_starts: + return None - def update_priorities(self, batch_indexes, td_errors): + with self.replay_timer: + samples = {} + for policy_id, replay_buffer in self.replay_buffers.items(): + (obses_t, actions, rewards, obses_tp1, dones, weights, + batch_indexes) = replay_buffer.sample( + self.train_batch_size, beta=self.prioritized_replay_beta) + samples[policy_id] = SampleBatch({ + "obs": obses_t, + "actions": actions, + "rewards": rewards, + "new_obs": obses_tp1, + "dones": dones, + "weights": weights, + "batch_indexes": batch_indexes + }) + return MultiAgentBatch(samples, self.train_batch_size) + + def update_priorities(self, prio_dict): with self.update_priorities_timer: - new_priorities = (np.abs(td_errors) + self.prioritized_replay_eps) - self.replay_buffer.update_priorities(batch_indexes, new_priorities) + for policy_id, (batch_indexes, td_errors) in prio_dict.items(): + new_priorities = ( + np.abs(td_errors) + self.prioritized_replay_eps) + self.replay_buffers[policy_id].update_priorities( + batch_indexes, new_priorities) def stats(self, debug=False): stat = { @@ -94,7 +108,10 @@ def stats(self, debug=False): "update_priorities_time_ms": round( 1000 * self.update_priorities_timer.mean, 3), } - stat.update(self.replay_buffer.stats(debug=debug)) + for policy_id, replay_buffer in self.replay_buffers.items(): + stat.update({ + "policy_{}".format(policy_id): replay_buffer.stats(debug=debug) + }) return stat @@ -126,10 +143,16 @@ def step(self): with self.queue_timer: ra, replay = self.inqueue.get() if replay is not None: + prio_dict = {} with self.grad_timer: - td_error = self.local_evaluator.compute_apply(replay)[ - "td_error"] - self.outqueue.put((ra, replay, td_error, replay.count)) + grad_out = self.local_evaluator.compute_apply(replay) + for pid, info in grad_out.items(): + prio_dict[pid] = ( + replay.policy_batches[pid]["batch_indexes"], + info["td_error"]) + # send `replay` back also so that it gets released by the original + # thread: https://github.com/ray-project/ray/issues/2610 + self.outqueue.put((ra, replay, prio_dict, replay.count)) self.learner_queue_size.push(self.inqueue.qsize()) self.weights_updated = True @@ -267,8 +290,8 @@ def _step(self): with self.timers["update_priorities"]: while not self.learner.outqueue.empty(): - ra, replay, td_error, count = self.learner.outqueue.get() - ra.update_priorities.remote(replay["batch_indexes"], 
td_error) + ra, _, prio_dict, count = self.learner.outqueue.get() + ra.update_priorities.remote(prio_dict) train_timesteps += count return sample_timesteps, train_timesteps
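As a companion to the ``async_replay_optimizer.py`` changes above, here is a standalone sketch of the "treat everything as multi-agent" replay pattern they introduce, using the same ``SampleBatch`` / ``MultiAgentBatch`` / ``PrioritizedReplayBuffer`` interfaces the patch itself calls. The buffer size, alpha, the epsilon constant, and the toy batch at the end are made-up values for illustration.

.. code-block:: python

    import collections

    import numpy as np

    from ray.rllib.evaluation.sample_batch import SampleBatch, \
        DEFAULT_POLICY_ID, MultiAgentBatch
    from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer

    # One prioritized buffer per policy id, created lazily on first use,
    # mirroring the ReplayActor's collections.defaultdict(new_buffer).
    buffers = collections.defaultdict(
        lambda: PrioritizedReplayBuffer(1000, alpha=0.6))

    def add_batch(batch):
        # Normalize: a plain SampleBatch becomes a single-policy
        # MultiAgentBatch under DEFAULT_POLICY_ID, so one code path
        # serves both the single-agent and multi-agent cases.
        if isinstance(batch, SampleBatch):
            batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)
        for policy_id, s in batch.policy_batches.items():
            for row in s.rows():
                buffers[policy_id].add(
                    row["obs"], row["actions"], row["rewards"],
                    row["new_obs"], row["dones"], row["weights"])

    def update_priorities(prio_dict):
        # prio_dict maps policy id -> (batch_indexes, td_errors), the
        # dict-based signature update_priorities() takes after this
        # patch; 1e-6 stands in for prioritized_replay_eps.
        for policy_id, (batch_indexes, td_errors) in prio_dict.items():
            buffers[policy_id].update_priorities(
                batch_indexes, np.abs(td_errors) + 1e-6)

    # Toy single-agent batch (arbitrary values) showing that it gets
    # routed into buffers[DEFAULT_POLICY_ID].
    add_batch(SampleBatch({
        "obs": np.zeros((2, 4)), "actions": np.array([0, 1]),
        "rewards": np.ones(2), "new_obs": np.zeros((2, 4)),
        "dones": np.array([False, True]), "weights": np.ones(2),
    }))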