-
Notifications
You must be signed in to change notification settings - Fork 306
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BugFix] ReplayBuffer's storage now signal back when changes happen #614
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome and really quick!
Thanks!
Can we discuss about ReplayBuffer.update_priority
which seems a bit too customize for prioritized sampling?
@@ -28,7 +28,7 @@ def extend(self, index: torch.Tensor) -> None: | |||
pass | |||
|
|||
def update_priority( | |||
self, index: Union[int, torch.Tensor], priority: Union[int, torch.Tensor] | |||
self, index: Union[int, torch.Tensor], priority: Union[float, torch.Tensor] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thx for that :)
@@ -29,6 +29,9 @@ class Storage: | |||
|
|||
def __init__(self, max_size: int) -> None: | |||
self.max_size = int(max_size) | |||
# Prototype feature. RBs that use a given instance of Storage should add | |||
# themselves to this set. | |||
self.attached_entities = set() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we want this to be public? Is any other object accessing that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nope. I can rename it.
def __getitem__(self, item): | ||
return self.get(item) | ||
|
||
def __setitem__(self, index, value): | ||
return self.set(index, value) | ||
ret = self.set(index, value) | ||
for i in self.attached_entities: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
usually (at least in our code base) i
points to an integer.
Can we rename this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sure.
Args: | ||
index: The modified index from storage. | ||
""" | ||
return self.update_priority(index, self._sampler.default_priority) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's cool
In the future we will probably want the user to update the priority of the sampler directly, without passing through the replay buffer (otherwise the replay buffer will need to implement custom update methods for all the modules e put in it).
One fix for now could be that the sampler has a mark_update
method that falls back on update_priority
for prioritized sampling. If sampling is uniform, mark_update
is a no-op.
Like this we could remove the ReplayBuffer.update_priority
method altogether (in a future PR).
Would that make sense?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It makes sense, yes.
I debated with myself for a while if we should attach samplers to storage or if we should attach RBs to storage, and decided for the latter because it is more extensible: It's possible that we need components other than the sampler to be aware of changes, and it's the RB's duty to coordinate the different moving parts.
The counterpoint to this would be YAGNI. It's simpler if we attach the samplers directly. I can definitely make that change.
By the way, another advantage of having the mark_update
method instead of directly calling update_priority
is that samplers can implement the update whenever they feel like it. We could, for example, simply store the modified indexes inside the sampler until the next call to sample, then update everything lazily.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I debated with myself for a while if we should attach samplers to storage or if we should attach RBs to storage, and decided for the latter because it is more extensible: It's possible that we need components other than the sampler to be aware of changes, and it's the RB's duty to coordinate the different moving parts.
I fully agree. I still think we should attach the RB to the storage, not all other children (the parent is the minimal sufficient information).
My point is mostly about generality of the name update_priority
in the replay buffer. I would rather name that mark_update
:
class ReplayBuffer:
def mark_update(self, ...):
self.sampler.mark_update(...)
class PrioritizedSampler:
def mark_update(self, ...):
self.update_priority(...)
def update_priority(self, ...):
foo(...)
By the way, another advantage of having the mark_update method instead of directly calling update_priority is that samplers can implement the update whenever they feel like it. We could, for example, simply store the modified indexes inside the sampler until the next call to sample, then update everything lazily.
Yes that would be cool! That would definitely be an advantage (we could also update things in batch when needed, which I would expect to cut the compute time of those ops by a bit)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah yes. Just made that change.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for the high-quality work
Can I ask you to create an issue for the lazy implementation of mark_update
-> update_priority
we talked about?
I can assign someone else to take care of it
I think not all tests are running now, probably because you create the PR from your main branch an not a separate branch. Not sure what to do in that case. I will run the tests locally and see if it's all working properly. |
Description
The prototype modular Replay Buffer has a bug (#606) when storage is shared across buffers. This happens because the different parts implicitly assume to always be aware of changes in storage, which was no true. This change introduces an API for storage to signal back to the buffers that changes were made. Buffers can then handle this signal however and whenever they see fit.
Motivation and Context
Why is this change required? What problem does it solve?
If it fixes an open issue, please link to the issue here.
You can use the syntax
close #15213
if this solves the issue #15213close #606
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an
x
in all the boxes that apply.If you are unsure about any of these, don't hesitate to ask. We are here to help!