[RLlib] Replay Buffer API and Training Iteration Fn for DQN. #23420
Conversation
Very nice PR @ArturNiederfahrenhorst, just a few nits left to fix. The regression tests for DQN on Breakout, comparing execution_plan vs. training_iteration, look super solid so far.
Co-authored-by: Sven Mika <sven@anyscale.io>
Hey @ArturNiederfahrenhorst, great PR and the benchmarks look really cool! Let's get this merged, but pull from master first, due to some changes made via @smorad's PPO PR.
Looks good. Only one nit left.
"no_local_replay_buffer": True, | ||
"replay_buffer_config": { | ||
# For now we don't use the new ReplayBuffer API here | ||
"_enable_replay_buffer_api": False, |
I'd really love to use the new replay buffer with APEX. What's the blocker here? Do we just need to rewrite the training_iteration function?
As soon as this PR is done I'll chat with Avnish to make sure we are aligned on changes to Ape-X and then do this!
"prioritized_replay_alpha": 0.6, | ||
# Beta parameter for sampling from prioritized replay buffer. | ||
"prioritized_replay_beta": 0.4, | ||
# Epsilon to add to the TD errors when updating priorities. | ||
"prioritized_replay_eps": 1e-6, | ||
# The number of continuous environment steps to replay at once. This may | ||
# be set to greater than 1 to support recurrent models. | ||
"replay_sequence_length": 1, |
Am I correct in that this is analogous to max_sequence_length from policy-gradient based policies? It might make sense to have a full_episode setting for variable-length episodes. The batch can be right zero-padded to the longest episode length in the train batch.
Yes, the replayed sequences are often shorter than replay_sequence_length.
What you are describing can be accomplished by setting storage_unit="episodes" or storage_unit=StorageUnit.EPISODES. Padding of the batch is so far not handled by the buffers, which is open for discussion! I think it should not be handled by the buffers, especially so that buffers can be reinstantiated from checkpoints in any setting.
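For illustration, a minimal sketch of the episode-storage setting mentioned above; the key names follow the replay_buffer_config pattern used in this PR, but exact keys and defaults may differ by Ray version:

# Hedged sketch: store whole (variable-length) episodes in the buffer
# instead of fixed-length sequences. Any padding/length handling is then
# left to the model/loss side, as discussed above.
config = {
    "replay_buffer_config": {
        "_enable_replay_buffer_api": True,
        "storage_unit": "episodes",  # or the StorageUnit.EPISODES enum value
        "capacity": 50_000,  # illustrative capacity
    },
}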
# One sampled item may span T timesteps; keep one buffer index per item
# so the indices line up with the per-item TD errors.
batch_indices = batch_indices.reshape([-1, T])[:, 0]
assert len(batch_indices) == len(td_error)
prio_dict[policy_id] = (batch_indices, td_error)
local_replay_buffer.update_priorities(prio_dict)
How does PER sample using recurrent policies? Do you assign a single priority to the entire replay_sequence, or do you sum up the priorities of a replay_sequence?
The entire sequence!
Today, priority is assigned per slot in the buffer. Depending on the replayed item's unit (timestep, sequence, episode) you choose, one priority applies to one item.
Do you think we should change this? So far this replicates what we have done in the past.
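A toy, self-contained sketch of the scheme described above (exactly one priority per stored item, regardless of whether an item is a timestep, a sequence, or an episode); this is illustrative only, not the RLlib implementation:

import numpy as np

class TinyPrioritizedBuffer:
    """Toy buffer: exactly one priority per stored item (slot)."""

    def __init__(self, alpha=0.6, eps=1e-6):
        self.items, self.priorities = [], []
        self.alpha, self.eps = alpha, eps

    def add(self, item):
        # New items get the current max priority so they are replayed at least once.
        self.items.append(item)
        self.priorities.append(max(self.priorities, default=1.0))

    def sample(self, k):
        probs = np.array(self.priorities) ** self.alpha
        probs = probs / probs.sum()
        idxs = np.random.choice(len(self.items), size=k, p=probs)
        return idxs, [self.items[i] for i in idxs]

    def update_priorities(self, idxs, td_errors):
        # One TD error updates exactly one slot, whatever unit the item represents.
        for i, td in zip(idxs, td_errors):
            self.priorities[i] = abs(td) + self.eps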
Yeah, we need to discuss restructuring our agents folder anyways. I'm actually in favor of separating every single algo, like APEX from DQN, and R2D2 from DQN (have them all as separate algos). APEX and R2D2 have no visibility right now, because they are buried inside DQN.
@@ -131,14 +127,14 @@ def validate_config(self, config: TrainerConfigDict) -> None:
     # Call super's validation method.
     super().validate_config(config)

-    if config["replay_sequence_length"] != -1:
+    if config["replay_buffer_config"]["replay_sequence_length"] != -1:
With your improved ReplayBuffer API I wonder if it makes sense to keep around R2D2 as a separate agent from DQN. AFAIK the only difference is the use of the TD error weighting function h and LSTM burn-in. If you plug these options into the DQN config, would you get distributed R2D2 for free via APEX?
Not worth doing in this PR, but might be worth doing after this is merged.
I have a very similar opinion.
I believe we should keep our algorithms section slimmer and provide R2D2 as an example script.
Same goes for RNNSAC.
"replay_buffer_config": { | ||
# Use the new ReplayBuffer API here | ||
"_enable_replay_buffer_api": True, | ||
# How many steps of the model to sample before learning starts. | ||
"learning_starts": 1000, |
What are the units here? Is a model step equivalent to a timestep?
I'll specify this!
Thanks for all your awesome comments!
Really cool stuff 💯
Just realized that this phrase is all over the library.
Moving DQN into the new training iteration API (from execution_plan). First benchmarks indicate at least equal performance on a Breakout Atari task:
Config:
Results:
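For context on what "training iteration API" means here, below is a rough, hedged sketch of the shape of a replay-buffer-based training step (illustrative pseudostructure only; the worker/buffer interfaces and result keys are assumptions, not the code merged in this PR):

def dqn_style_training_step(workers, replay_buffer, config, num_steps_added):
    # 1) Collect fresh experience from the rollout workers and add it to the
    #    (prioritized) replay buffer.
    new_batch = workers.sample()  # assumed: returns a SampleBatch-like object
    replay_buffer.add(new_batch)
    num_steps_added += new_batch.count  # assumed: env timesteps in the batch

    results = {}
    # 2) Only start learning once `learning_starts` timesteps are stored.
    if num_steps_added >= config["replay_buffer_config"]["learning_starts"]:
        # 3) Replay a train batch according to the stored priorities.
        train_batch = replay_buffer.sample(config["train_batch_size"])
        # 4) One SGD update on the local policies.
        results = workers.learn_on_batch(train_batch)  # assumed interface
        # 5) Feed the new TD errors back into the buffer (cf. the
        #    update_priorities() snippet earlier in this conversation).
        prio_dict = {
            pid: (info["batch_indices"], info["td_error"])  # assumed keys
            for pid, info in results.items()
        }
        replay_buffer.update_priorities(prio_dict)
        # 6) Target-network sync and weight broadcast happen periodically
        #    (omitted in this sketch).
    return results, num_steps_added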
Why are these changes needed?
Related issue number
Checks
I've run scripts/format.sh to lint the changes in this PR.