[RLlib] R2D2 Implementation. #13933

Merged (34 commits, Feb 25, 2021)

Commits
5e7b0cf  wip. (sven1977, Feb 4, 2021)
0ae233e  wip. (sven1977, Feb 5, 2021)
554fa79  wip. (sven1977, Feb 5, 2021)
0bde31a  Merge branch 'master' of https://github.com/ray-project/ray into r2d2… (sven1977, Feb 9, 2021)
863eda9  wip. (sven1977, Feb 9, 2021)
7762d4c  wip. (sven1977, Feb 9, 2021)
b86e01f  wip. (sven1977, Feb 10, 2021)
dde42b3  Merge branch 'master' of https://github.com/ray-project/ray into r2d2… (sven1977, Feb 10, 2021)
21e1679  Merge branch 'master' of https://github.com/ray-project/ray into r2d2… (sven1977, Feb 10, 2021)
6762b62  wip. (sven1977, Feb 11, 2021)
6cd1d67  wip. (sven1977, Feb 11, 2021)
3feaff7  wip. (sven1977, Feb 12, 2021)
e2e9784  Test case is learning CartPole-v0. (sven1977, Feb 16, 2021)
352454a  Learning tf CartPole! (sven1977, Feb 16, 2021)
0ceb56c  Merge branch 'master' of https://github.com/ray-project/ray into r2d2… (sven1977, Feb 16, 2021)
e073e69  Compilation test case passing also for tf-eager. (sven1977, Feb 16, 2021)
6a85845  wip. (sven1977, Feb 16, 2021)
51c6d38  Merge branch 'master' of https://github.com/ray-project/ray into r2d2… (sven1977, Feb 17, 2021)
fbf63ae  fix (sven1977, Feb 17, 2021)
29e9fea  Merge branch 'master' of https://github.com/ray-project/ray into r2d2… (sven1977, Feb 21, 2021)
ad51946  wip. (sven1977, Feb 21, 2021)
b68da54  wip. (sven1977, Feb 21, 2021)
88ad777  LINT. (sven1977, Feb 22, 2021)
6e8cb0f  LINT. (sven1977, Feb 22, 2021)
3fa6335  Merge branch 'master' of https://github.com/ray-project/ray into r2d2… (sven1977, Feb 22, 2021)
5c46b19  wip. (sven1977, Feb 22, 2021)
44f3ae0  wip. (sven1977, Feb 22, 2021)
8cfd107  wip. (sven1977, Feb 22, 2021)
643d719  wip. (sven1977, Feb 22, 2021)
2dcb5e2  Merge branch 'master' of https://github.com/ray-project/ray into r2d2… (sven1977, Feb 24, 2021)
3774987  wip and LINT. (sven1977, Feb 24, 2021)
982df1f  wip and LINT. (sven1977, Feb 24, 2021)
7c2de5f  fix. (sven1977, Feb 25, 2021)
f8c13c0  fix. (sven1977, Feb 25, 2021)
Changes from all commits
13 changes: 13 additions & 0 deletions doc/source/rllib-algorithms.rst
@@ -26,6 +26,7 @@ Algorithm Frameworks Discrete Actions Continuous Actions Multi-
`MBMPO`_ torch No **Yes** No
`PG`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+LSTM auto-wrapping`_, `+Attention`_, `+autoreg`_
`PPO`_, `APPO`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+LSTM auto-wrapping`_, `+Attention`_, `+autoreg`_
`R2D2`_ tf + torch **Yes** `+parametric`_ No **Yes** `+RNN`_, `+LSTM auto-wrapping`_, `+autoreg`_
`SAC`_ tf + torch **Yes** **Yes** **Yes**
`SlateQ`_ torch **Yes** No No
`LinUCB`_, `LinTS`_ torch **Yes** `+parametric`_ No **Yes**
@@ -323,6 +324,18 @@ SpaceInvaders 650 1001 1025
:start-after: __sphinx_doc_begin__
:end-before: __sphinx_doc_end__


.. _r2d2:

Recurrent Replay Distributed DQN (R2D2)
---------------------------------------
|pytorch| |tensorflow|
`[paper] <https://openreview.net/pdf?id=r1lyTjAqYX>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/dqn/r2d2.py>`__
R2D2 can be scaled by increasing the number of workers. All of the DQN improvements evaluated in `Rainbow <https://arxiv.org/abs/1710.02298>`__ are available, though not all are enabled by default.

Tuned examples: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/dqn/cartpole-r2d2.yaml>`__


.. _pg:

Policy Gradients
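For context on the R2D2 documentation added above, here is a minimal usage sketch. It assumes the standard RLlib Trainer API and an LSTM-wrapped model; the worker count and model settings are illustrative, not the tuned values from cartpole-r2d2.yaml.

```python
import ray
from ray.rllib.agents.dqn import R2D2_DEFAULT_CONFIG, R2D2Trainer

ray.init()

# Illustrative settings only; see tuned_examples/dqn/cartpole-r2d2.yaml for
# the configuration actually validated by this PR's regression tests.
config = R2D2_DEFAULT_CONFIG.copy()
config["num_workers"] = 2             # R2D2 scales by adding rollout workers.
config["model"] = {"use_lstm": True}  # R2D2 expects a recurrent model.

trainer = R2D2Trainer(config=config, env="CartPole-v0")
for _ in range(3):
    result = trainer.train()
    print(result["episode_reward_mean"])
```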
4 changes: 2 additions & 2 deletions doc/source/rllib-concepts.rst
@@ -287,7 +287,7 @@ Now let's look at each PPO policy definition:
get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG,
loss_fn=ppo_surrogate_loss,
stats_fn=kl_and_loss_stats,
extra_action_fetches_fn=vf_preds_and_logits_fetches,
extra_action_out_fn=vf_preds_and_logits_fetches,
postprocess_fn=postprocess_ppo_gae,
gradients_fn=clip_gradients,
before_loss_init=setup_mixins,
@@ -363,7 +363,7 @@ Let's look at how to implement a different family of policies, by looking at the
action_sampler_fn=build_action_sampler,
loss_fn=build_q_losses,
extra_action_feed_fn=exploration_setting_inputs,
extra_action_fetches_fn=lambda policy: {"q_values": policy.q_values},
extra_action_out_fn=lambda policy: {"q_values": policy.q_values},
extra_learn_fetches_fn=lambda policy: {"td_error": policy.td_error},
before_init=setup_early_mixins,
after_init=setup_late_mixins,
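The documentation edits above track an API rename in the policy template builders: extra_action_fetches_fn is now extra_action_out_fn. A hedged migration sketch for a custom policy follows (MyPolicy, my_loss, and the fetched q_values attribute are hypothetical; only the keyword name changes, the callable's signature and return value do not):

```python
import tensorflow as tf
from ray.rllib.policy.tf_policy_template import build_tf_policy


def my_loss(policy, model, dist_class, train_batch):
    # Hypothetical placeholder loss, just for the sketch.
    return tf.reduce_mean(train_batch["rewards"]) * 0.0


MyPolicy = build_tf_policy(
    name="MyPolicy",
    loss_fn=my_loss,
    # Before this PR the keyword was extra_action_fetches_fn; the callable
    # still returns a dict of extra tensors to fetch during compute_actions.
    extra_action_out_fn=lambda policy: {"q_values": policy.q_values},
)
```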
4 changes: 3 additions & 1 deletion rllib/BUILD
@@ -162,7 +162,7 @@ py_test(
args = ["--yaml-dir=tuned_examples/ppo", "--framework=torch"]
)

# DQN/Simple-Q
# DQN/Simple-Q/R2D2
py_test(
name = "run_regression_tests_cartpole_dqn_tf",
main = "tests/run_regression_tests.py",
@@ -174,6 +174,7 @@ py_test(
"tuned_examples/dqn/cartpole-dqn.yaml",
"tuned_examples/dqn/cartpole-dqn-softq.yaml",
"tuned_examples/dqn/cartpole-dqn-param-noise.yaml",
"tuned_examples/dqn/cartpole-r2d2.yaml",
],
args = ["--yaml-dir=tuned_examples/dqn"]
)
@@ -189,6 +190,7 @@ py_test(
"tuned_examples/dqn/cartpole-dqn.yaml",
"tuned_examples/dqn/cartpole-dqn-softq.yaml",
"tuned_examples/dqn/cartpole-dqn-param-noise.yaml",
"tuned_examples/dqn/cartpole-r2d2.yaml",
],
args = ["--yaml-dir=tuned_examples/dqn", "--framework=torch"]
)
2 changes: 1 addition & 1 deletion rllib/agents/a3c/a3c_tf_policy.py
@@ -108,6 +108,6 @@ def setup_mixins(policy, obs_space, action_space, config):
grad_stats_fn=grad_stats,
gradients_fn=clip_gradients,
postprocess_fn=compute_gae_for_sample_batch,
extra_action_fetches_fn=add_value_function_fetch,
extra_action_out_fn=add_value_function_fetch,
before_loss_init=setup_mixins,
mixins=[ValueNetworkMixin, LearningRateSchedule])
6 changes: 6 additions & 0 deletions rllib/agents/dqn/__init__.py
@@ -2,6 +2,9 @@
from ray.rllib.agents.dqn.dqn import DQNTrainer, DEFAULT_CONFIG
from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.agents.dqn.dqn_torch_policy import DQNTorchPolicy
from ray.rllib.agents.dqn.r2d2 import R2D2Trainer, DEFAULT_CONFIG as \
R2D2_DEFAULT_CONFIG
from ray.rllib.agents.dqn.r2d2_torch_policy import R2D2TorchPolicy
from ray.rllib.agents.dqn.simple_q import SimpleQTrainer, \
DEFAULT_CONFIG as SIMPLE_Q_DEFAULT_CONFIG
from ray.rllib.agents.dqn.simple_q_tf_policy import SimpleQTFPolicy
@@ -13,6 +16,9 @@
"DQNTorchPolicy",
"DQNTrainer",
"DEFAULT_CONFIG",
"R2D2TorchPolicy",
"R2D2Trainer",
"R2D2_DEFAULT_CONFIG",
"SIMPLE_Q_DEFAULT_CONFIG",
"SimpleQTFPolicy",
"SimpleQTorchPolicy",
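With the __init__.py change above, the new R2D2 symbols become part of the public DQN package namespace. A quick import check (assumes PyTorch is installed, since the policy exported here is the torch one):

```python
# New public exports added by this PR.
from ray.rllib.agents.dqn import (
    R2D2_DEFAULT_CONFIG,
    R2D2TorchPolicy,
    R2D2Trainer,
)

print(R2D2Trainer.__name__, R2D2TorchPolicy.__name__)
print(len(R2D2_DEFAULT_CONFIG), "config keys")
```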
2 changes: 1 addition & 1 deletion rllib/agents/dqn/apex.py
@@ -107,7 +107,7 @@ def apex_execution_plan(workers: WorkerSet,
config["prioritized_replay_beta"],
config["prioritized_replay_eps"],
config["multiagent"]["replay_mode"],
config["replay_sequence_length"],
config.get("replay_sequence_length", 1),
], num_replay_buffer_shards)

# Start the learner thread.
8 changes: 7 additions & 1 deletion rllib/agents/dqn/dqn.py
@@ -82,6 +82,9 @@
# Size of the replay buffer. Note that if async_updates is set, then
# each worker will have a replay buffer of this size.
"buffer_size": 50000,
# The number of contiguous environment steps to replay at once. This may
# be set to greater than 1 to support recurrent models.
"replay_sequence_length": 1,
# If True prioritized replay buffer will be used.
"prioritized_replay": True,
# Alpha parameter for prioritized replay buffer.
@@ -94,6 +97,7 @@
"prioritized_replay_beta_annealing_timesteps": 20000,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,

# Whether to LZ4 compress observations
"compress_observations": False,
# Callback to run before learning on a multi-agent batch of experiences.
@@ -194,7 +198,9 @@ def execution_plan(workers: WorkerSet,
buffer_size=config["buffer_size"],
replay_batch_size=config["train_batch_size"],
replay_mode=config["multiagent"]["replay_mode"],
replay_sequence_length=config["replay_sequence_length"],
replay_sequence_length=config.get("replay_sequence_length", 1),
replay_burn_in=config.get("burn_in", 0),
replay_zero_init_states=config.get("zero_init_states", True),
**prio_args)

rollouts = ParallelRollouts(workers, mode="bulk_sync")
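The execution-plan change above threads three replay options into the replay buffer. The fragment below sketches how a recurrent, R2D2-style setup would set them. Since `burn_in` and `zero_init_states` are read with `config.get(...)` here, I assume they are declared by the R2D2 config rather than the base DQN config; all numeric values are illustrative.

```python
# Illustrative replay settings for sequence-based (recurrent) replay.
recurrent_replay_config = {
    "buffer_size": 50000,
    # Store and replay contiguous 20-step sequences instead of single steps.
    "replay_sequence_length": 20,
    # Assumed to be defined by the R2D2 config (read via config.get() above):
    "burn_in": 10,             # steps used only to warm up the RNN state
    "zero_init_states": True,  # start each stored sequence from zeroed states
    "model": {
        "use_lstm": True,
        "max_seq_len": 20,     # keep in sync with replay_sequence_length
    },
}
```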
44 changes: 26 additions & 18 deletions rllib/agents/dqn/dqn_tf_policy.py
@@ -1,4 +1,4 @@
"""Tensorflow policy class used for DQN"""
"""TensorFlow policy class used for DQN"""

from typing import Dict

@@ -215,7 +215,8 @@ def get_distribution_inputs_and_class(policy: Policy,
*,
explore=True,
**kwargs):
q_vals = compute_q_values(policy, model, obs_batch, explore)
q_vals = compute_q_values(
policy, model, {"obs": obs_batch}, state_batches=None, explore=explore)
q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals

policy.q_values = q_vals
@@ -237,19 +238,20 @@ def build_q_losses(policy: Policy, model, _,
"""
config = policy.config
# q network evaluation
q_t, q_logits_t, q_dist_t = compute_q_values(
q_t, q_logits_t, q_dist_t, _ = compute_q_values(
policy,
policy.q_model,
train_batch[SampleBatch.CUR_OBS],
policy.q_model, {"obs": train_batch[SampleBatch.CUR_OBS]},
state_batches=None,
explore=False)

# target q network evaluation
q_tp1, q_logits_tp1, q_dist_tp1 = compute_q_values(
q_tp1, q_logits_tp1, q_dist_tp1, _ = compute_q_values(
policy,
policy.target_q_model,
train_batch[SampleBatch.NEXT_OBS],
policy.target_q_model, {"obs": train_batch[SampleBatch.NEXT_OBS]},
state_batches=None,
explore=False)
policy.target_q_func_vars = policy.target_q_model.variables()
if not hasattr(policy, "target_q_func_vars"):
policy.target_q_func_vars = policy.target_q_model.variables()

# q scores for actions which we know were selected in the given state.
one_hot_selection = tf.one_hot(
@@ -262,9 +264,10 @@
# compute estimate of best possible value starting from state at t + 1
if config["double_q"]:
q_tp1_using_online_net, q_logits_tp1_using_online_net, \
q_dist_tp1_using_online_net = compute_q_values(
q_dist_tp1_using_online_net, _ = compute_q_values(
policy, policy.q_model,
train_batch[SampleBatch.NEXT_OBS],
{"obs": train_batch[SampleBatch.NEXT_OBS]},
state_batches=None,
explore=False)
q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1)
q_tp1_best_one_hot_selection = tf.one_hot(q_tp1_best_using_online_net,
@@ -329,13 +332,18 @@ def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
TargetNetworkMixin.__init__(policy, obs_space, action_space, config)


def compute_q_values(policy: Policy, model: ModelV2, obs: TensorType, explore):
def compute_q_values(policy: Policy,
model: ModelV2,
input_dict,
state_batches=None,
seq_lens=None,
explore=None,
is_training: bool = False):

config = policy.config

model_out, state = model({
SampleBatch.CUR_OBS: obs,
"is_training": policy._get_is_training_placeholder(),
}, [], None)
input_dict["is_training"] = policy._get_is_training_placeholder()
model_out, state = model(input_dict, state_batches or [], seq_lens)

if config["num_atoms"] > 1:
(action_scores, z, support_logits_per_action, logits,
@@ -368,7 +376,7 @@ def compute_q_values(policy: Policy, model: ModelV2, obs: TensorType, explore):
else:
value = action_scores

return value, logits, dist
return value, logits, dist, state


def _adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones):
@@ -433,7 +441,7 @@ def postprocess_nstep_and_prio(policy: Policy,
postprocess_fn=postprocess_nstep_and_prio,
optimizer_fn=adam_optimizer,
gradients_fn=clip_gradients,
extra_action_fetches_fn=lambda policy: {"q_values": policy.q_values},
extra_action_out_fn=lambda policy: {"q_values": policy.q_values},
extra_learn_fetches_fn=lambda policy: {"td_error": policy.q_loss.td_error},
before_init=setup_early_mixins,
before_loss_init=setup_mid_mixins,
34 changes: 18 additions & 16 deletions rllib/agents/dqn/dqn_torch_policy.py
@@ -214,7 +214,11 @@ def get_distribution_inputs_and_class(
explore: bool = True,
is_training: bool = False,
**kwargs) -> Tuple[TensorType, type, List[TensorType]]:
q_vals = compute_q_values(policy, model, obs_batch, explore, is_training)
q_vals = compute_q_values(
policy,
model, {"obs": obs_batch},
explore=explore,
is_training=is_training)
q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals

policy.q_values = q_vals
@@ -235,18 +239,16 @@ def build_q_losses(policy: Policy, model, _,
"""
config = policy.config
# Q-network evaluation.
q_t, q_logits_t, q_probs_t = compute_q_values(
q_t, q_logits_t, q_probs_t, _ = compute_q_values(
policy,
policy.q_model,
train_batch[SampleBatch.CUR_OBS],
policy.q_model, {"obs": train_batch[SampleBatch.CUR_OBS]},
explore=False,
is_training=True)

# Target Q-network evaluation.
q_tp1, q_logits_tp1, q_probs_tp1 = compute_q_values(
q_tp1, q_logits_tp1, q_probs_tp1, _ = compute_q_values(
policy,
policy.target_q_model,
train_batch[SampleBatch.NEXT_OBS],
policy.target_q_model, {"obs": train_batch[SampleBatch.NEXT_OBS]},
explore=False,
is_training=True)

@@ -263,10 +265,10 @@
# compute estimate of best possible value starting from state at t + 1
if config["double_q"]:
q_tp1_using_online_net, q_logits_tp1_using_online_net, \
q_dist_tp1_using_online_net = compute_q_values(
q_dist_tp1_using_online_net, _ = compute_q_values(
policy,
policy.q_model,
train_batch[SampleBatch.NEXT_OBS],
{"obs": train_batch[SampleBatch.NEXT_OBS]},
explore=False,
is_training=True)
q_tp1_best_using_online_net = torch.argmax(q_tp1_using_online_net, 1)
@@ -327,15 +329,15 @@ def before_loss_init(policy: Policy, obs_space: gym.spaces.Space,

def compute_q_values(policy: Policy,
model: ModelV2,
obs: TensorType,
explore,
input_dict,
state_batches=None,
seq_lens=None,
explore=None,
is_training: bool = False):
config = policy.config

model_out, state = model({
SampleBatch.CUR_OBS: obs,
"is_training": is_training,
}, [], None)
input_dict["is_training"] = is_training
model_out, state = model(input_dict, state_batches or [], seq_lens)

if config["num_atoms"] > 1:
(action_scores, z, support_logits_per_action, logits,
@@ -367,7 +369,7 @@ def compute_q_values(policy: Policy,
else:
value = action_scores

return value, logits, probs_or_logits
return value, logits, probs_or_logits, state


def grad_process_and_td_error_fn(policy: Policy,
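As a side note on the double-Q branch visible in both policy diffs above, here is a small self-contained sketch of how the online and target networks combine into the TD target (toy tensors only, not RLlib code):

```python
import torch

gamma = 0.99
rewards = torch.tensor([1.0, 0.0])
dones = torch.tensor([0.0, 1.0])  # second transition is terminal

# Toy Q-values for 2 transitions x 3 actions.
q_tp1_online = torch.tensor([[1.0, 2.0, 0.5],
                             [0.3, 0.1, 0.2]])  # online net at t+1
q_tp1_target = torch.tensor([[0.9, 1.8, 0.4],
                             [0.2, 0.1, 0.3]])  # target net at t+1

# Double Q-learning: pick actions with the online net, evaluate them with the
# target net.
best_actions = torch.argmax(q_tp1_online, dim=1)
q_tp1_best = q_tp1_target.gather(1, best_actions.unsqueeze(1)).squeeze(1)

# One-step TD target, masking out terminal transitions.
td_target = rewards + gamma * (1.0 - dones) * q_tp1_best
print(td_target)  # tensor([2.7820, 0.0000])
```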