Docs and test code for PrioritizedEpisodeReplayBuffer #43458
Conversation
Signed-off-by: Simon Zehnder <simon.zehnder@gmail.com>
PrioritizedEpisodeReplayBuffer.
rllib/utils/replay_buffers/prioritized_episode_replay_buffer.py
capacity: The total number of timesteps to be storable in this buffer.
    Will start ejecting old episodes once this limit is reached.
batch_size_B: The number of rows in a SampleBatch returned from `sample()`.
batch_length_T: The length of each row in a SampleBatch returned from
Can we specify here, how this arg is related to n-step? My understanding is: Not at all, correct?
So, batch_length_T is the same as in DreamerV3, meaning if >1, then my returned batch will contain B rows, each row consisting of T consecutive-in-an-episode(!) n-step tuples, correct?
If so, can we give these details here in the description of the arg?
At this moment `batch_length_T=1` is the only possible option. As this is a method inherited from `EpisodeReplayBuffer`, the argument has to be there.
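For context, the core idea of priority-proportional sampling in such a buffer can be sketched as follows. This is purely illustrative: `sample_indices` is a hypothetical helper, and RLlib's actual implementation uses a segment tree for efficient sampling rather than `random.choices`.

```python
import random

def sample_indices(priorities, batch_size_B, alpha=1.0, seed=0):
    """Draw batch_size_B indices with probability proportional to priority**alpha.

    Illustrative sketch only; not the RLlib implementation.
    """
    rng = random.Random(seed)
    weights = [p ** alpha for p in priorities]
    # Each of the batch_size_B rows is one independently drawn (n-step) tuple.
    return rng.choices(range(len(priorities)), weights=weights, k=batch_size_B)

# Four stored timesteps with different priorities; sample a batch of 8 rows.
idxs = sample_indices([0.1, 0.9, 0.5, 2.0], batch_size_B=8)
```

With `batch_length_T=1`, each sampled index corresponds to exactly one row of the returned batch, which is why the batch has `batch_size_B` single tuples rather than `batch_size_B x batch_length_T` sequences.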
@@ -259,8 +371,8 @@ def sample(
     next_observations = [[] for _ in range(batch_size_B)]
     actions = [[] for _ in range(batch_size_B)]
     rewards = [[] for _ in range(batch_size_B)]
-    is_terminated = [[False] * batch_length_T for _ in range(batch_size_B)]
-    is_truncated = [[False] * batch_length_T for _ in range(batch_size_B)]
+    is_terminated = [[False] for _ in range(batch_size_B)]
Can you explain, why this fix? B/c a returned batch only has one terminated/truncated flag per batch row?
Mainly because each batch row has only a single tuple. See my comment above.
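The shape change can be checked directly in a standalone snippet (not buffer code): for `batch_length_T == 1` the old and new initializations coincide, but the new form no longer depends on `batch_length_T` at all, matching the fact that each batch row holds a single tuple.

```python
batch_size_B = 4
batch_length_T = 1  # currently the only supported value in this buffer

# Old initialization: one flag per timestep in each row.
is_terminated_old = [[False] * batch_length_T for _ in range(batch_size_B)]
# New initialization: exactly one flag per row.
is_terminated_new = [[False] for _ in range(batch_size_B)]

# For batch_length_T == 1 both produce batch_size_B rows of a single flag.
```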
LGTM, just a few questions and nits to fix before we can merge.
Why are these changes needed?
We move from plain timesteps to episodes in the buffer. This needs rigorous documentation with appropriate test code to show users how to use the class.
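As a rough illustration of what episode-based storage means (a hypothetical sketch, not the actual `EpisodeReplayBuffer` API): episodes are stored whole, and the oldest episodes are evicted once the total stored timestep count exceeds `capacity`.

```python
from collections import deque

class EpisodeBufferSketch:
    """Hypothetical sketch of episode-based storage; not the RLlib API.

    Stores whole episodes and evicts the oldest once the total number
    of stored timesteps exceeds `capacity`.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.episodes = deque()
        self.num_timesteps = 0

    def add(self, episode):
        self.episodes.append(episode)
        self.num_timesteps += len(episode)
        # Eject whole old episodes, not individual timesteps.
        while self.num_timesteps > self.capacity and len(self.episodes) > 1:
            evicted = self.episodes.popleft()
            self.num_timesteps -= len(evicted)

buf = EpisodeBufferSketch(capacity=5)
buf.add([1, 2, 3])   # 3 timesteps stored
buf.add([4, 5, 6])   # total would be 6 > 5, so the oldest episode is ejected
```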
Related issue number

Checks
- I've signed off every commit (`git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.
- If I've added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.