[BugFix] ReplayBuffer's storage now signals back when changes happen #614
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,7 +28,7 @@ def extend(self, index: torch.Tensor) -> None: | |
pass | ||
|
||
def update_priority( | ||
- self, index: Union[int, torch.Tensor], priority: Union[int, torch.Tensor] | ||
+ self, index: Union[int, torch.Tensor], priority: Union[float, torch.Tensor] | ||
Review comment: thx for that :) |
||
) -> dict: | ||
pass | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,6 +29,9 @@ class Storage: | |
|
||
def __init__(self, max_size: int) -> None: | ||
self.max_size = int(max_size) | ||
+ # Prototype feature. RBs that use a given instance of Storage should add | ||
+ # themselves to this set. | ||
+ self.attached_entities = set() | ||
Review comment: Do we want this to be public? Is any other object accessing it? |
Reply: Nope. I can rename it. |
||
|
||
@abc.abstractmethod | ||
def set(self, cursor: int, data: Any): | ||
|
@@ -38,11 +41,28 @@ def set(self, cursor: int, data: Any): | |
def get(self, index: int) -> Any: | ||
raise NotImplementedError | ||
|
||
+ def attach(self, buffer: Any) -> None: | ||
+ """This function attaches a buffer to this storage. | ||
|
||
+ Replay Buffers that read from this storage must call this | ||
+ method to attach themselves. This guarantees that when data | ||
+ in the storage changes, all relevant pieces of the buffer are | ||
+ aware of it even if the storage is shared with another buffer | ||
+ (eg. Priority Samplers). | ||
|
||
+ Args: | ||
+ buffer: the object that reads from this storage. | ||
+ """ | ||
+ self.attached_entities.add(buffer) | ||
|
||
def __getitem__(self, item): | ||
return self.get(item) | ||
|
||
def __setitem__(self, index, value): | ||
- return self.set(index, value) | ||
+ ret = self.set(index, value) | ||
+ for i in self.attached_entities: | ||
Review comment: usually (at least in our code base) |
Reply: sure. |
||
+ i.mark_update(index) | ||
+ return ret | ||
|
||
def __iter__(self): | ||
for i in range(len(self)): | ||
|
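The notification mechanism in this diff can be tried outside the library with a minimal sketch. It assumes a hypothetical list-backed `ListStorage` and a hypothetical `RecordingBuffer` stand-in for a replay buffer; neither class is part of the PR. Only `attach`, `__setitem__`, and `mark_update` mirror names that appear in the diff.

```python
from typing import Any, List


class ListStorage:
    """Hypothetical list-backed storage that notifies attached buffers on writes."""

    def __init__(self, max_size: int) -> None:
        self.max_size = int(max_size)
        self._data: List[Any] = [None] * self.max_size
        # Buffers that read from this storage register themselves here.
        self._attached_entities = set()

    def attach(self, buffer: Any) -> None:
        self._attached_entities.add(buffer)

    def set(self, cursor: int, data: Any) -> None:
        self._data[cursor] = data

    def get(self, index: int) -> Any:
        return self._data[index]

    def __getitem__(self, item: int) -> Any:
        return self.get(item)

    def __setitem__(self, index: int, value: Any) -> None:
        ret = self.set(index, value)
        # Tell every attached buffer which index just changed.
        for entity in self._attached_entities:
            entity.mark_update(index)
        return ret


class RecordingBuffer:
    """Hypothetical stand-in for a replay buffer; it records notified indices."""

    def __init__(self, storage: ListStorage) -> None:
        self.storage = storage
        self.updated: List[int] = []
        storage.attach(self)

    def mark_update(self, index: int) -> None:
        self.updated.append(index)


storage = ListStorage(max_size=4)
first, second = RecordingBuffer(storage), RecordingBuffer(storage)
storage[2] = "transition"  # both buffers are notified about index 2
```

Because both buffers share one storage, a single write reaches each of them, which is the shared-storage case the `attach` docstring describes.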
Review comment:
That's cool.
In the future we will probably want the user to update the priority of the sampler directly, without passing through the replay buffer (otherwise the replay buffer will need to implement custom update methods for all the modules we put in it).
One fix for now could be that the sampler has a `mark_update` method that falls back on `update_priority` for prioritized sampling. If sampling is uniform, `mark_update` is a no-op.
Like this we could remove the `ReplayBuffer.update_priority` method altogether (in a future PR).
Would that make sense?
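A rough sketch of that proposal, assuming two hypothetical sampler classes that are not the library's real ones. What `mark_update` should pass to `update_priority` is not stated in the thread; resetting to a default priority is an assumption made here for illustration.

```python
from typing import List


class UniformSampler:
    """Hypothetical uniform sampler: it keeps no per-item state."""

    def mark_update(self, index: int) -> None:
        # Nothing to refresh when every index is equally likely.
        return None


class PrioritizedSampler:
    """Hypothetical prioritized sampler holding one priority per storage slot."""

    def __init__(self, size: int, default_priority: float = 1.0) -> None:
        self.default_priority = float(default_priority)
        self.priorities: List[float] = [self.default_priority] * size

    def update_priority(self, index: int, priority: float) -> None:
        self.priorities[index] = float(priority)

    def mark_update(self, index: int) -> None:
        # Assumption: overwritten data gets the default priority back,
        # because the old priority described the previous item.
        self.update_priority(index, self.default_priority)


sampler = PrioritizedSampler(size=3)
sampler.update_priority(1, 5.0)  # priority learned for the old item
sampler.mark_update(1)           # the slot was overwritten in storage
```

With this shape, callers only need `mark_update`, and each sampler decides what a changed index means for it.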
Reply:
It makes sense, yes.
I debated with myself for a while whether we should attach samplers to storage or attach RBs to storage, and decided on the latter because it is more extensible: it's possible that we need components other than the sampler to be aware of changes, and it's the RB's duty to coordinate the different moving parts.
The counterpoint to this would be YAGNI. It's simpler if we attach the samplers directly. I can definitely make that change.
By the way, another advantage of having the `mark_update` method instead of directly calling `update_priority` is that samplers can implement the update whenever they feel like it. We could, for example, simply store the modified indexes inside the sampler until the next call to sample, then update everything lazily.
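The lazy variant mentioned here could look roughly like the following. `LazyPrioritizedSampler`, its `_pending` set, and the default-priority reset are all hypothetical choices for this sketch, not code from the PR.

```python
import random
from typing import List, Set


class LazyPrioritizedSampler:
    """Hypothetical sampler that defers priority resets until sample() is called."""

    def __init__(self, size: int, default_priority: float = 1.0, seed: int = 0) -> None:
        self.default_priority = float(default_priority)
        self.priorities: List[float] = [self.default_priority] * size
        self._pending: Set[int] = set()
        self._rng = random.Random(seed)

    def update_priority(self, index: int, priority: float) -> None:
        self.priorities[index] = float(priority)

    def mark_update(self, index: int) -> None:
        # Only remember the index; the actual work happens in _flush().
        self._pending.add(index)

    def _flush(self) -> None:
        # Apply all pending resets in one pass (assumed reset value: default).
        for index in self._pending:
            self.priorities[index] = self.default_priority
        self._pending.clear()

    def sample(self, batch_size: int) -> List[int]:
        self._flush()
        indices = range(len(self.priorities))
        return self._rng.choices(indices, weights=self.priorities, k=batch_size)


lazy = LazyPrioritizedSampler(size=4)
lazy.update_priority(2, 10.0)
lazy.mark_update(2)
priority_before_sample = lazy.priorities[2]  # still the stale value
batch = lazy.sample(batch_size=3)            # flushes, then draws indices
```

Repeated writes to the same index between two `sample` calls collapse into one reset, since `_pending` is a set.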
Reply:
I fully agree. I still think we should attach the RB to the storage, not all other children (the parent is the minimal sufficient information).
My point is mostly about the generality of the name `update_priority` in the replay buffer. I would rather name that `mark_update`.
Yes, that would be cool! That would definitely be an advantage (we could also update things in batch when needed, which I would expect to cut the compute time of those ops by a bit).
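The design agreed on here, where only the replay buffer attaches to the storage and then coordinates its children through a generically named `mark_update`, might be sketched as follows. `SketchReplayBuffer`, `CountingSampler`, and `WriteLog` are hypothetical names, and the storage is left out: the final call stands in for the notification the storage would send.

```python
from typing import Any, Iterable, List


class CountingSampler:
    """Hypothetical sampler that records which indices changed."""

    def __init__(self) -> None:
        self.seen: List[int] = []

    def mark_update(self, index: int) -> None:
        self.seen.append(index)


class WriteLog:
    """Hypothetical second component that also cares about storage writes."""

    def __init__(self) -> None:
        self.writes = 0

    def mark_update(self, index: int) -> None:
        self.writes += 1


class SketchReplayBuffer:
    """Hypothetical buffer: the single object the storage needs to know about."""

    def __init__(self, components: Iterable[Any]) -> None:
        self._components = list(components)

    def mark_update(self, index: int) -> None:
        # Fan the notification out to every child component.
        for component in self._components:
            component.mark_update(index)


sampler, log = CountingSampler(), WriteLog()
buffer = SketchReplayBuffer([sampler, log])
buffer.mark_update(3)  # what the storage would call after a write to index 3
```

Adding another component later only changes the list passed to the buffer; the storage side stays untouched.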
Reply:
Ah yes. Just made that change.