[BugFix] ReplayBuffer's storage now signal back when changes happen #614

Merged 5 commits on Oct 27, 2022
Changes from 1 commit

49 changes: 49 additions & 0 deletions test/test_rb.py
@@ -21,11 +21,13 @@
    TensorDictPrioritizedReplayBuffer,
    writers,
)
from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler
from torchrl.data.replay_buffers.storages import (
    LazyMemmapStorage,
    LazyTensorStorage,
    ListStorage,
)
from torchrl.data.replay_buffers.writers import RoundRobinWriter
from torchrl.data.tensordict.tensordict import assert_allclose_td, TensorDictBase


@@ -542,6 +544,53 @@ def test_rb_trajectories(stack):
    sampled_td_filtered.batch_size = [3, 4]


def test_shared_storage_prioritized_sampler():

    n = 100

    storage = LazyMemmapStorage(n)
    writer = RoundRobinWriter()
    sampler0 = RandomSampler()
    sampler1 = PrioritizedSampler(max_capacity=n, alpha=0.7, beta=1.1)

    rb0 = rb_prototype.ReplayBuffer(
        storage=storage, writer=writer, sampler=sampler0, collate_fn=lambda x: x
    )
    rb1 = rb_prototype.ReplayBuffer(
        storage=storage, writer=writer, sampler=sampler1, collate_fn=lambda x: x
    )

    data = TensorDict({"a": torch.arange(50)}, [50])

    # Extend rb0. rb1 should be aware of changes to storage.
    rb0.extend(data)

    assert len(rb0) == 50
    assert len(storage) == 50
    assert len(rb1) == 50

    rb0.sample(10)
    rb1.sample(10)

    # rb1 was never extended directly, but its prioritized sampler was notified
    # through the shared storage: only the 50 written slots contribute to the
    # sum tree, while the slots beyond index 50 contribute nothing.
    assert rb1._sampler._sum_tree.query(0, 10) == 10
    assert rb1._sampler._sum_tree.query(0, 50) == 50
    assert rb1._sampler._sum_tree.query(0, 70) == 50


def test_legacy_rb_does_not_attach():
    n = 10
    storage = LazyMemmapStorage(n)
    writer = RoundRobinWriter()
    sampler = RandomSampler()
    rb = ReplayBuffer(storage=storage, size=n, prefetch=0, collate_fn=lambda x: x)
    prb = rb_prototype.ReplayBuffer(
        storage=storage, writer=writer, sampler=sampler, collate_fn=lambda x: x
    )

    # Only the prototype ReplayBuffer attaches itself to the storage it reads from.
    assert rb not in storage.attached_entities
    assert prb in storage.attached_entities


if __name__ == "__main__":
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

12 changes: 12 additions & 0 deletions torchrl/data/replay_buffers/rb_prototype.py
@@ -42,6 +42,7 @@ def __init__(
        prefetch: Optional[int] = None,
    ) -> None:
        self._storage = storage if storage is not None else ListStorage(max_size=1_000)
        self._storage.attach(self)
        self._sampler = sampler if sampler is not None else RandomSampler()
        self._writer = writer if writer is not None else RoundRobinWriter()
        self._writer.register_storage(self._storage)
@@ -155,6 +156,17 @@ def sample(self, batch_size: int) -> Tuple[Any, dict]:

        return ret

    def mark_update(self, index) -> None:
        """Marks a given storage index as having changed.

        Derived classes can handle this as appropriate, forwarding the
        call to whichever components need to know about the change.

        Args:
            index: the index of the storage entry that was modified.
        """
        return self.update_priority(index, self._sampler.default_priority)

Contributor:
That's cool!
In the future we will probably want the user to update the priority of the sampler directly, without passing through the replay buffer (otherwise the replay buffer will need to implement custom update methods for all the modules we put in it).

One fix for now could be that the sampler has a mark_update method that falls back on update_priority for prioritized sampling. If sampling is uniform, mark_update is a no-op.
That way we could remove the ReplayBuffer.update_priority method altogether (in a future PR).

Would that make sense?
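
[Editor's illustration] For concreteness, a toy sketch of that suggestion. The two classes below are illustrative only, not code from this PR; a plain dict stands in for the real sum tree used by torchrl's prioritized sampler.

# Toy sketch of the suggestion above, not part of this PR: mark_update is a
# no-op for uniform sampling and falls back on update_priority otherwise.
class UniformSamplerSketch:
    def mark_update(self, index) -> None:
        # Uniform sampling ignores priorities, so there is nothing to do here.
        pass


class PrioritizedSamplerSketch:
    def __init__(self, default_priority: float = 1.0):
        self.default_priority = default_priority
        self._priorities = {}  # stand-in for the real sum tree

    def update_priority(self, index, priority) -> None:
        self._priorities[index] = priority

    def mark_update(self, index) -> None:
        # A changed storage entry simply gets the default priority again.
        self.update_priority(index, self.default_priority)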

Contributor Author:
It makes sense, yes.

I debated with myself for a while if we should attach samplers to storage or if we should attach RBs to storage, and decided for the latter because it is more extensible: It's possible that we need components other than the sampler to be aware of changes, and it's the RB's duty to coordinate the different moving parts.

The counterpoint to this would be YAGNI. It's simpler if we attach the samplers directly. I can definitely make that change.

By the way, another advantage of having the mark_update method instead of directly calling update_priority is that samplers can implement the update whenever they feel like it. We could, for example, simply store the modified indexes inside the sampler until the next call to sample, then update everything lazily.
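
[Editor's illustration] A hypothetical sketch of such a lazily-updating sampler. The LazyPrioritizedSampler class is illustrative and not part of this PR; it assumes the prototype sampler's sample(storage, batch_size) signature, together with the update_priority and default_priority members that the PR already relies on.

# Hypothetical sketch, not part of this PR: a prioritized sampler that only
# records the indexes touched by mark_update and flushes them in one batched
# update_priority call at the next sample().
import torch

from torchrl.data.replay_buffers.samplers import PrioritizedSampler


class LazyPrioritizedSampler(PrioritizedSampler):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._dirty_indexes = []  # indexes marked as changed since the last sample()

    def mark_update(self, index) -> None:
        # Defer the actual priority write; just remember which entries changed.
        self._dirty_indexes.append(torch.as_tensor(index).reshape(-1))

    def sample(self, storage, batch_size):
        if self._dirty_indexes:
            # Flush all pending updates in a single batched call before sampling.
            index = torch.cat(self._dirty_indexes)
            self.update_priority(index, self.default_priority)
            self._dirty_indexes.clear()
        return super().sample(storage, batch_size)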

Contributor @vmoens, Oct 27, 2022:

> I debated with myself for a while if we should attach samplers to storage or if we should attach RBs to storage, and decided for the latter because it is more extensible: It's possible that we need components other than the sampler to be aware of changes, and it's the RB's duty to coordinate the different moving parts.

I fully agree. I still think we should attach the RB to the storage, not all other children (the parent is the minimal sufficient information).

My point is mostly about generality of the name update_priority in the replay buffer. I would rather name that mark_update:

class ReplayBuffer:
    def mark_update(self, ...):
        self.sampler.mark_update(...)

class PrioritizedSampler:
    def mark_update(self, ...):
        self.update_priority(...)
    def update_priority(self, ...):
        foo(...)

> By the way, another advantage of having the mark_update method instead of directly calling update_priority is that samplers can implement the update whenever they feel like it. We could, for example, simply store the modified indexes inside the sampler until the next call to sample, then update everything lazily.

Yes that would be cool! That would definitely be an advantage (we could also update things in batch when needed, which I would expect to cut the compute time of those ops by a bit)

Contributor Author:
Ah yes. Just made that change.



class TensorDictReplayBuffer(ReplayBuffer):
    """TensorDict-specific wrapper around the ReplayBuffer class.

2 changes: 1 addition & 1 deletion torchrl/data/replay_buffers/samplers.py
@@ -28,7 +28,7 @@ def extend(self, index: torch.Tensor) -> None:
        pass

    def update_priority(
-        self, index: Union[int, torch.Tensor], priority: Union[int, torch.Tensor]
+        self, index: Union[int, torch.Tensor], priority: Union[float, torch.Tensor]

Contributor:
thx for that :)

    ) -> dict:
        pass

22 changes: 21 additions & 1 deletion torchrl/data/replay_buffers/storages.py
@@ -29,6 +29,9 @@ class Storage:

    def __init__(self, max_size: int) -> None:
        self.max_size = int(max_size)
        # Prototype feature. RBs that use a given instance of Storage should add
        # themselves to this set.
        self.attached_entities = set()

Contributor:
do we want this to be public? Is any other object accessing that?

Contributor Author:
Nope. I can rename it.


    @abc.abstractmethod
    def set(self, cursor: int, data: Any):
@@ -38,11 +41,28 @@ def set(self, cursor: int, data: Any):
    def get(self, index: int) -> Any:
        raise NotImplementedError

    def attach(self, buffer: Any) -> None:
        """Attaches a buffer to this storage.

        Replay buffers that read from this storage must call this
        method to attach themselves. This guarantees that when data
        in the storage changes, all relevant components of the buffer
        (e.g. prioritized samplers) are made aware of it, even when
        the storage is shared with other buffers.

        Args:
            buffer: the object that reads from this storage.
        """
        self.attached_entities.add(buffer)

    def __getitem__(self, item):
        return self.get(item)

    def __setitem__(self, index, value):
-        return self.set(index, value)
+        ret = self.set(index, value)
+        for i in self.attached_entities:

Contributor:
usually (at least in our code base) i points to an integer.
Can we rename this?

Contributor Author:
sure.

+            i.mark_update(index)
+        return ret

    def __iter__(self):
        for i in range(len(self)):
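
[Editor's illustration] Finally, a minimal usage sketch of the attach / mark_update contract added above. The DummyListener stub is illustrative (any object exposing mark_update can be attached), and the example assumes ListStorage accepts a write at the next free index.

# Minimal sketch of the new attach/mark_update round trip; DummyListener is a
# stand-in for a replay buffer and not part of torchrl.
import torch

from torchrl.data.replay_buffers.storages import ListStorage


class DummyListener:
    def __init__(self):
        self.updated = []

    def mark_update(self, index) -> None:
        # Record which storage index was written to.
        self.updated.append(index)


storage = ListStorage(max_size=10)
listener = DummyListener()
storage.attach(listener)

# __setitem__ stores the value and then notifies every attached entity.
storage[0] = torch.zeros(3)
assert listener.updated == [0]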