[RLlib] Replay Buffer API and Training Iteration Fn for DQN. #23420

Merged: 134 commits merged into ray-project:master on Apr 18, 2022

Conversation

@ArturNiederfahrenhorst (Contributor) commented Mar 23, 2022:

Moving DQN into the new training iteration API (from execution_plan). First benchmarks indicate at least equal performance on a Breakout Atari task:

Config:

atari-basic-dqn:
    env: BreakoutNoFrameskip-v4
    run: DQN
    config:
        # Works for both torch and tf.
        framework: tf
        double_q: false
        dueling: false
        num_atoms: 1
        noisy: false
        prioritized_replay: false
        n_step: 1
        target_network_update_freq: 8000
        lr: .0000625
        adam_epsilon: .00015
        hiddens: [512]
        learning_starts: 20000

        replay_buffer_config:
            capacity: 1000000
            prioritized_replay_alpha: 0.5

        rollout_fragment_length: 4
        train_batch_size: 32
        exploration_config:
          epsilon_timesteps: 200000
          final_epsilon: 0.01
        num_gpus: 1
        timesteps_per_iteration: 10000

        _disable_execution_plan_api:
            grid_search: [true, false]

Results:

Current time: 2022-04-08 03:39:52 (running for 01:45:05.95)
Memory usage on this node: 159.2/239.9 GiB
Using FIFO scheduling algorithm.
Resources requested: 2.0/32 CPUs, 2.0/4 GPUs, 0.0/147.65 GiB heap, 0.0/67.27 GiB objects
Result logdir: /home/ray/ray_results/atari-basic-dqn
Number of trials: 2/2 (2 RUNNING)
+----------------------------------------+----------+-------------------+-------------------------------+--------+------------------+--------+----------+----------------------+----------------------+--------------------+
| Trial name                             | status   | loc               | _disable_execution_plan_api   |   iter |   total time (s) |     ts |   reward |   episode_reward_max |   episode_reward_min |   episode_len_mean |
|----------------------------------------+----------+-------------------+-------------------------------+--------+------------------+--------+----------+----------------------+----------------------+--------------------|
| DQN_BreakoutNoFrameskip-v4_8a7c0_00000 | RUNNING  | 10.0.67.156:42378 | True                          |     69 |          6243.36 | 690000 |    53.86 |                  242 |                   17 |            4805.07 |
| DQN_BreakoutNoFrameskip-v4_8a7c0_00001 | RUNNING  | 10.0.67.156:42377 | False                         |     68 |          6255.12 | 680000 |    41.1  |                   71 |                   18 |            4323.95 |
+----------------------------------------+----------+-------------------+-------------------------------+--------+------------------+--------+----------+----------------------+----------------------+--------------------+
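For readers skimming this PR, below is a rough, self-contained sketch of the general shape of a training-iteration-style step for a DQN-like algorithm (sample, store, replay, learn, sync target). It is not the code added in this PR; all names and the plain deque buffer are stand-ins chosen for this example only.

    import random
    from collections import deque
    from typing import Callable, Dict, List


    class DQNIterationSketch:
        """Illustrative only: the sample -> store -> replay -> learn -> target-sync loop."""

        def __init__(
            self,
            sample_rollout: Callable[[], List[dict]],      # returns e.g. rollout_fragment_length env steps
            train_on_batch: Callable[[List[dict]], Dict],  # one gradient step on a replayed batch
            update_target: Callable[[], None],             # syncs target network weights
            capacity: int = 1_000_000,
            learning_starts: int = 20_000,
            train_batch_size: int = 32,
            target_update_freq: int = 8_000,
        ) -> None:
            self.sample_rollout = sample_rollout
            self.train_on_batch = train_on_batch
            self.update_target = update_target
            self.buffer: deque = deque(maxlen=capacity)
            self.learning_starts = learning_starts
            self.train_batch_size = train_batch_size
            self.target_update_freq = target_update_freq
            self.env_steps_sampled = 0
            self.steps_since_target_update = 0

        def training_iteration(self) -> Dict:
            # 1) Collect new experience and add it to the replay buffer.
            fragment = self.sample_rollout()
            self.env_steps_sampled += len(fragment)
            self.buffer.extend(fragment)

            results: Dict = {}
            # 2) Learn only once `learning_starts` env steps (counted in timesteps) were collected.
            if (
                self.env_steps_sampled >= self.learning_starts
                and len(self.buffer) >= self.train_batch_size
            ):
                batch = random.sample(list(self.buffer), self.train_batch_size)
                results = self.train_on_batch(batch)

                # 3) Periodically sync the target network.
                self.steps_since_target_update += len(fragment)
                if self.steps_since_target_update >= self.target_update_freq:
                    self.update_target()
                    self.steps_since_target_update = 0
            return results

The stand-in defaults mirror the benchmark config above (capacity 1,000,000; learning_starts 20000; train_batch_size 32; target_network_update_freq 8000).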

Why are these changes needed?

Related issue number

Checks

  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

@sven1977 (Contributor) left a comment:

Very nice PR @ArturNiederfahrenhorst, just a few nits left to fix. The regression tests for DQN on Breakout, comparing execution_plan vs. training_iteration, look super solid so far.

@sven1977 (Contributor) commented:

Hey @ArturNiederfahrenhorst , great PR and the benchmarks look really cool!

Let's get this merged, but pull from master first, due to some changes made via @smorad's PPO PR.
There are some tests failing. They could be related to this, but I'm not sure (I think they were already failing before).

@sven1977 (Contributor) left a comment:

Looks good. Only one nit left.

rllib/agents/dqn/dqn.py (resolved)
rllib/agents/dqn/simple_q_tf_policy.py (resolved)
"no_local_replay_buffer": True,
"replay_buffer_config": {
# For now we don't use the new ReplayBuffer API here
"_enable_replay_buffer_api": False,
Contributor:

I'd really love to use the new replay buffer with APEX. What's the blocker here? Do we just need to rewrite the training_iteration function?

Contributor Author:

As soon as this PR is done I'll chat with Avnish to make sure we are aligned on changes to Ape-X and then do this!

"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.
"prioritized_replay_beta": 0.4,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
# The number of continuous environment steps to replay at once. This may
# be set to greater than 1 to support recurrent models.
"replay_sequence_length": 1,
Contributor:

Am I correct that this is analogous to max_sequence_length from policy-gradient-based policies? It might make sense to have a full_episode setting for variable-length episodes. The batch could then be right-zero-padded to the longest episode length in the train batch.

Contributor Author:

Yes, and the replayed sequences are often shorter than replay_sequence_length.
What you are describing can be accomplished by setting storage_unit="episodes" or storage_unit=StorageUnit.EPISODES. Padding of the batch is not yet handled by the buffers, which is open for discussion! I think it should not be handled by the buffers, especially so that buffers can be re-instantiated from checkpoints in any setting.
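As a hedged illustration of the storage_unit setting mentioned above (exact keys, buffer type, and import path may differ between Ray versions, so treat the names below as assumptions), episode-wise storage could look roughly like this:

    # Sketch only; verify key names against your Ray version's defaults.
    from ray.rllib.utils.replay_buffers.replay_buffer import StorageUnit

    config = {
        "replay_buffer_config": {
            "type": "MultiAgentPrioritizedReplayBuffer",
            "capacity": 100_000,
            # Store and replay whole (variable-length) episodes instead of single timesteps.
            "storage_unit": StorageUnit.EPISODES,  # the string "episodes" is typically accepted too
        },
    }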

batch_indices = batch_indices.reshape([-1, T])[:, 0]
assert len(batch_indices) == len(td_error)
prio_dict[policy_id] = (batch_indices, td_error)
local_replay_buffer.update_priorities(prio_dict)
Contributor:

How does PER sample using recurrent policies? Do you assign a single priority to the entire replay_sequence, or do you sum up the priorities of a replay_sequence?

Contributor Author:

The entire sequence!
Today, priority is assigned per slot in the buffer.
Depending on the unit you choose for replayed items (timestep, sequence, episode), one priority applies to one item.
Do you think we should change this? So far this replicates what we have done in the past.
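To make the "one priority per stored item" point concrete, here is a minimal, generic sketch of proportional prioritized replay (not RLlib's implementation; the class name and structure are made up for this example). The alpha, beta, and eps parameters play the same roles as prioritized_replay_alpha, prioritized_replay_beta, and prioritized_replay_eps in the config snippet quoted above.

    import numpy as np

    class ToyPrioritizedBuffer:
        """Minimal proportional PER: one priority per stored item
        (a timestep, a sequence, or an episode, depending on the storage unit)."""

        def __init__(self, alpha=0.6, beta=0.4, eps=1e-6):
            self.items, self.priorities = [], []
            self.alpha, self.beta, self.eps = alpha, beta, eps

        def add(self, item):
            # New items get the current max priority so they are replayed at least once.
            self.items.append(item)
            self.priorities.append(max(self.priorities, default=1.0))

        def sample(self, k):
            probs = np.asarray(self.priorities) ** self.alpha
            probs /= probs.sum()
            idx = np.random.choice(len(self.items), size=k, p=probs)
            # Importance-sampling weights correct for the non-uniform sampling (beta).
            weights = (len(self.items) * probs[idx]) ** (-self.beta)
            weights /= weights.max()
            return idx, [self.items[i] for i in idx], weights

        def update_priorities(self, idx, td_errors):
            # One priority per sampled item, e.g. derived from its sequence-level TD error.
            for i, err in zip(idx, td_errors):
                self.priorities[i] = abs(float(err)) + self.eps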

Contributor:

Yeah, we need to discuss restructuring our agents folder anyway. I'm actually in favor of separating every single algo, like APEX from DQN, and R2D2 from DQN (have them all as separate algos). APEX and R2D2 have no visibility right now, because they are buried inside DQN.

@@ -131,14 +127,14 @@ def validate_config(self, config: TrainerConfigDict) -> None:
     # Call super's validation method.
     super().validate_config(config)

-    if config["replay_sequence_length"] != -1:
+    if config["replay_buffer_config"]["replay_sequence_length"] != -1:
Contributor:

With your improved ReplayBuffer API I wonder if it makes sense to keep around R2D2 as a separate agent from DQN. AFAIK the only difference is the use of the TD error weighting function h and LSTM burn-in. If you plug these options into the DQN config, would you get distributed R2D2 for free via APEX?

Not worth doing in this PR, but might be worth doing after this is merged.
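For reference, these are roughly the knobs the comment refers to, as they appear in R2D2's config (key names and defaults may differ by Ray version, so treat this as an assumption-laden sketch); exposing them on DQN is the hypothetical idea discussed here, not something this PR does.

    # Hedged sketch of R2D2-specific options layered on top of a DQN-style config.
    r2d2_style_options = {
        "model": {
            "use_lstm": True,
            "max_seq_len": 20,
        },
        # Invertible value-function rescaling ("h" function) from the R2D2 paper.
        "use_h_function": True,
        "h_function_epsilon": 1e-3,
        # LSTM burn-in: timesteps used to warm up the recurrent state before the loss is computed.
        "burn_in": 20,
        # Whether replayed sequences start from zeroed RNN states.
        "zero_init_states": True,
    }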

Contributor Author:

I have a very similar opinion.
I believe we should keep our algorithms section slimmer and provide R2D2 as an example script.
Same goes for RNNSAC.

"replay_buffer_config": {
# Use the new ReplayBuffer API here
"_enable_replay_buffer_api": True,
# How many steps of the model to sample before learning starts.
"learning_starts": 1000,
@smorad (Contributor) commented Apr 17, 2022:

What are the units here? Is a model step equivalent to a timestep?

Contributor Author:

I'll specify this!
Thanks for all your awesome comments!
Really cool stuff 💯

Contributor Author:

Just realized that this phrase is all over the library.

@sven1977 merged commit e57ce7e into ray-project:master on Apr 18, 2022
@ArturNiederfahrenhorst deleted the ReplayBufferAPI_DQN branch on Apr 24, 2022