diff --git a/README.md b/README.md index e12592a9838..4ec18685919 100644 --- a/README.md +++ b/README.md @@ -212,7 +212,6 @@ The associated [`SafeModule` class](torchrl/modules/tensordict_module/common.py) scratch_dir="/tmp/" ) buffer = TensorDictPrioritizedReplayBuffer( - buffer_size=10000, alpha=0.7, beta=0.5, collate_fn=lambda x: x, diff --git a/benchmarks/storage/benchmark_sample_latency_over_rpc.py b/benchmarks/storage/benchmark_sample_latency_over_rpc.py index 6bd1619419c..be1055e8b8a 100644 --- a/benchmarks/storage/benchmark_sample_latency_over_rpc.py +++ b/benchmarks/storage/benchmark_sample_latency_over_rpc.py @@ -24,7 +24,7 @@ import torch import torch.distributed.rpc as rpc from tensordict import TensorDict -from torchrl.data.replay_buffers.rb_prototype import RemoteTensorDictReplayBuffer +from torchrl.data.replay_buffers import RemoteTensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import RandomSampler from torchrl.data.replay_buffers.storages import ( LazyMemmapStorage, @@ -92,10 +92,10 @@ def train(self, batch_size: int) -> None: if self._ret is None: self._ret = ret else: - self._ret[0].update_(ret[0]) + self._ret.update_(ret) # make sure the content is read - self._ret[0]["observation"] + 1 - self._ret[0]["next_observation"] + 1 + self._ret["observation"] + 1 + self._ret["next_observation"] + 1 return timeit.default_timer() - start_time def _create_replay_buffer(self) -> rpc.RRef: diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index d709abf6cf3..6ed97a74466 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -19,10 +19,10 @@ widely used replay buffers: TensorDictReplayBuffer TensorDictPrioritizedReplayBuffer -Composable Replay Buffers (Prototype) +Composable Replay Buffers ------------------------------------- -We also provide a prototyped composable replay buffer. +We also give users the ability to compose a replay buffer using the following components: .. autosummary:: :toctree: generated/ @@ -30,9 +30,6 @@ We also provide a prototyped composable replay buffer. .. 
currentmodule:: torchrl.data.replay_buffers - torchrl.data.replay_buffers.rb_prototype.ReplayBuffer - torchrl.data.replay_buffers.rb_prototype.TensorDictReplayBuffer - torchrl.data.replay_buffers.rb_prototype.RemoteTensorDictReplayBuffer torchrl.data.replay_buffers.samplers.Sampler torchrl.data.replay_buffers.samplers.RandomSampler torchrl.data.replay_buffers.samplers.PrioritizedSampler diff --git a/examples/distributed/distributed_replay_buffer.py b/examples/distributed/distributed_replay_buffer.py index b36e11a625e..c228e416a7f 100644 --- a/examples/distributed/distributed_replay_buffer.py +++ b/examples/distributed/distributed_replay_buffer.py @@ -16,7 +16,7 @@ import torch import torch.distributed.rpc as rpc from tensordict import TensorDict -from torchrl.data.replay_buffers.rb_prototype import RemoteTensorDictReplayBuffer +from torchrl.data.replay_buffers import RemoteTensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import RandomSampler from torchrl.data.replay_buffers.storages import LazyMemmapStorage from torchrl.data.replay_buffers.utils import accept_remote_rref_invocation diff --git a/test/test_rb.py b/test/test_rb.py index f6a8316d72a..1d2967e2b6d 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -14,18 +14,20 @@ from _utils_internal import get_available_devices from tensordict.prototype import is_tensorclass, tensorclass from tensordict.tensordict import assert_allclose_td, TensorDict, TensorDictBase -from torchrl.data import PrioritizedReplayBuffer, ReplayBuffer, TensorDictReplayBuffer -from torchrl.data.replay_buffers import ( - rb_prototype, - samplers, +from torchrl.data import ( + PrioritizedReplayBuffer, + RemoteTensorDictReplayBuffer, + ReplayBuffer, TensorDictPrioritizedReplayBuffer, - writers, + TensorDictReplayBuffer, ) +from torchrl.data.replay_buffers import samplers, writers from torchrl.data.replay_buffers.samplers import ( PrioritizedSampler, RandomSampler, SamplerWithoutReplacement, ) + from torchrl.data.replay_buffers.storages import ( LazyMemmapStorage, LazyTensorStorage, @@ -60,9 +62,9 @@ @pytest.mark.parametrize( "rb_type", [ - rb_prototype.ReplayBuffer, - rb_prototype.TensorDictReplayBuffer, - rb_prototype.RemoteTensorDictReplayBuffer, + ReplayBuffer, + TensorDictReplayBuffer, + RemoteTensorDictReplayBuffer, ], ) @pytest.mark.parametrize( @@ -87,11 +89,10 @@ def _get_rb(self, rb_type, size, sampler, writer, storage): return rb def _get_datum(self, rb_type): - if rb_type is rb_prototype.ReplayBuffer: + if rb_type is ReplayBuffer: data = torch.randint(100, (1,)) elif ( - rb_type is rb_prototype.TensorDictReplayBuffer - or rb_type is rb_prototype.RemoteTensorDictReplayBuffer + rb_type is TensorDictReplayBuffer or rb_type is RemoteTensorDictReplayBuffer ): data = TensorDict({"a": torch.randint(100, (1,))}, []) else: @@ -99,11 +100,10 @@ def _get_datum(self, rb_type): return data def _get_data(self, rb_type, size): - if rb_type is rb_prototype.ReplayBuffer: + if rb_type is ReplayBuffer: data = torch.randint(100, (size, 1)) elif ( - rb_type is rb_prototype.TensorDictReplayBuffer - or rb_type is rb_prototype.RemoteTensorDictReplayBuffer + rb_type is TensorDictReplayBuffer or rb_type is RemoteTensorDictReplayBuffer ): data = TensorDict( { @@ -298,7 +298,7 @@ def test_set_tensorclass(self, max_size, shape, storage): def test_prototype_prb(priority_key, contiguous, device): torch.manual_seed(0) np.random.seed(0) - rb = rb_prototype.TensorDictReplayBuffer( + rb = TensorDictReplayBuffer( sampler=samplers.PrioritizedSampler(5, alpha=0.7, beta=0.9), 
priority_key=priority_key, ) @@ -311,7 +311,7 @@ def test_prototype_prb(priority_key, contiguous, device): batch_size=[3], ).to(device) rb.extend(td1) - s, _ = rb.sample(2) + s = rb.sample(2) assert s.batch_size == torch.Size( [ 2, @@ -330,7 +330,7 @@ def test_prototype_prb(priority_key, contiguous, device): batch_size=[5], ).to(device) rb.extend(td2) - s, _ = rb.sample(5) + s = rb.sample(5) assert s.batch_size == torch.Size([5]) assert (td2[s.get("_idx").squeeze()].get("a") == s.get("a")).all() assert_allclose_td(td2[s.get("_idx").squeeze()].select("a"), s.select("a")) @@ -353,18 +353,18 @@ def test_prototype_prb(priority_key, contiguous, device): idx0 = s.get("_idx")[0] rb.update_tensordict_priority(s) - s, _ = rb.sample(5) + s = rb.sample(5) assert (val == s.get("a")).sum() >= 1 torch.testing.assert_close(td2[idx0].get("a").view(1), s.get("a").unique().view(1)) # test updating values of original td td2.set_("a", torch.ones_like(td2.get("a"))) - s, _ = rb.sample(5) + s = rb.sample(5) torch.testing.assert_close(td2[idx0].get("a").view(1), s.get("a").unique().view(1)) @pytest.mark.parametrize("stack", [False, True]) -def test_rb_prototype_trajectories(stack): +def test_replay_buffer_trajectories(stack): traj_td = TensorDict( {"obs": torch.randn(3, 4, 5), "actions": torch.randn(3, 4, 2)}, batch_size=[3, 4], @@ -372,7 +372,7 @@ def test_rb_prototype_trajectories(stack): if stack: traj_td = torch.stack([td.to_tensordict() for td in traj_td], 0) - rb = rb_prototype.TensorDictReplayBuffer( + rb = TensorDictReplayBuffer( sampler=samplers.PrioritizedSampler( 5, alpha=0.7, @@ -381,10 +381,10 @@ def test_rb_prototype_trajectories(stack): priority_key="td_error", ) rb.extend(traj_td) - sampled_td, _ = rb.sample(3) + sampled_td = rb.sample(3) sampled_td.set("td_error", torch.rand(3)) rb.update_tensordict_priority(sampled_td) - sampled_td, _ = rb.sample(3, include_info=True) + sampled_td = rb.sample(3, include_info=True) assert (sampled_td.get("_weight") > 0).all() assert sampled_td.batch_size == torch.Size([3]) @@ -433,7 +433,7 @@ def _get_rb(self, rbtype, size, storage, prefetch): params = self._default_params_td_prb else: raise NotImplementedError(rbtype) - rb = rbtype(size=size, storage=storage, prefetch=prefetch, **params) + rb = rbtype(storage=storage, prefetch=prefetch, **params) return rb def _get_datum(self, rbtype): @@ -481,17 +481,17 @@ def test_cursor_position2(self, rbtype, storage, size, prefetch): rb.extend(batch1) # Added less data than storage max size - if size > 5: - assert rb._cursor == 5 + if size > 5 or storage is None: + assert rb._writer._cursor == 5 # Added more data than storage max size elif size < 5: - assert rb._cursor == 5 - size + assert rb._writer._cursor == 5 - size # Added as data as storage max size else: - assert rb._cursor == 0 + assert rb._writer._cursor == 0 batch2 = self._get_data(rbtype, size=size - 1) rb.extend(batch2) - assert rb._cursor == size - 1 + assert rb._writer._cursor == size - 1 def test_add(self, rbtype, storage, size, prefetch): torch.manual_seed(0) @@ -575,10 +575,10 @@ def test_prb(priority_key, contiguous, device): torch.manual_seed(0) np.random.seed(0) rb = TensorDictPrioritizedReplayBuffer( - 5, alpha=0.7, beta=0.9, priority_key=priority_key, + storage=ListStorage(5), ) td1 = TensorDict( source={ @@ -630,7 +630,7 @@ def test_prb(priority_key, contiguous, device): val = s.get("a")[0] idx0 = s.get("_idx")[0] - rb.update_priority(s) + rb.update_tensordict_priority(s) s = rb.sample(5) assert (val == s.get("a")).sum() >= 1 
torch.testing.assert_close(td2[idx0].get("a").view(1), s.get("a").unique().view(1)) @@ -651,16 +651,16 @@ def test_rb_trajectories(stack): traj_td = torch.stack([td.to_tensordict() for td in traj_td], 0) rb = TensorDictPrioritizedReplayBuffer( - 5, alpha=0.7, beta=0.9, priority_key="td_error", + storage=ListStorage(5), ) rb.extend(traj_td) sampled_td = rb.sample(3) sampled_td.set("td_error", torch.rand(3)) - rb.update_priority(sampled_td) - sampled_td = rb.sample(3, return_weight=True) + rb.update_tensordict_priority(sampled_td) + sampled_td = rb.sample(3, include_info=True) assert (sampled_td.get("_weight") > 0).all() assert sampled_td.batch_size == torch.Size([3]) @@ -680,12 +680,12 @@ def test_shared_storage_prioritized_sampler(): sampler0 = RandomSampler() sampler1 = PrioritizedSampler(max_capacity=n, alpha=0.7, beta=1.1) - rb0 = rb_prototype.ReplayBuffer( + rb0 = ReplayBuffer( storage=storage, writer=writer, sampler=sampler0, ) - rb1 = rb_prototype.ReplayBuffer( + rb1 = ReplayBuffer( storage=storage, writer=writer, sampler=sampler1, @@ -708,25 +708,8 @@ def test_shared_storage_prioritized_sampler(): assert rb1._sampler._sum_tree.query(0, 70) == 50 -def test_legacy_rb_does_not_attach(): - n = 10 - storage = LazyMemmapStorage(n) - writer = RoundRobinWriter() - sampler = RandomSampler() - rb = ReplayBuffer(storage=storage, size=n, prefetch=0) - prb = rb_prototype.ReplayBuffer( - storage=storage, - writer=writer, - sampler=sampler, - ) - - assert len(storage._attached_entities) == 1 - assert prb in storage._attached_entities - assert rb not in storage._attached_entities - - def test_append_transform(): - rb = rb_prototype.ReplayBuffer(collate_fn=lambda x: torch.stack(x, 0)) + rb = ReplayBuffer(collate_fn=lambda x: torch.stack(x, 0)) td = TensorDict( { "observation": torch.randn(2, 4, 3, 16), @@ -741,7 +724,7 @@ def test_append_transform(): rb.append_transform(flatten) - sampled, _ = rb.sample(1) + sampled = rb.sample(1) assert sampled.get("observation_cat").shape[-1] == 32 @@ -750,13 +733,11 @@ def test_init_transform(): -2, -1, in_keys=["observation"], out_keys=["flattened"] ) - rb = rb_prototype.ReplayBuffer( - collate_fn=lambda x: torch.stack(x, 0), transform=flatten - ) + rb = ReplayBuffer(collate_fn=lambda x: torch.stack(x, 0), transform=flatten) td = TensorDict({"observation": torch.randn(2, 4, 3, 16)}, []) rb.add(td) - sampled, _ = rb.sample(1) + sampled = rb.sample(1) assert sampled.get("flattened").shape[-1] == 48 @@ -764,15 +745,13 @@ def test_insert_transform(): flatten = FlattenObservation( -2, -1, in_keys=["observation"], out_keys=["flattened"] ) - rb = rb_prototype.ReplayBuffer( - collate_fn=lambda x: torch.stack(x, 0), transform=flatten - ) + rb = ReplayBuffer(collate_fn=lambda x: torch.stack(x, 0), transform=flatten) td = TensorDict({"observation": torch.randn(2, 4, 3, 16, 1)}, []) rb.add(td) rb.insert_transform(0, SqueezeTransform(-1, in_keys=["observation"])) - sampled, _ = rb.sample(1) + sampled = rb.sample(1) assert sampled.get("flattened").shape[-1] == 48 with pytest.raises(ValueError): @@ -810,7 +789,7 @@ def test_insert_transform(): @pytest.mark.parametrize("transform", transforms) def test_smoke_replay_buffer_transform(transform): - rb = rb_prototype.ReplayBuffer( + rb = ReplayBuffer( transform=transform(in_keys="observation"), ) @@ -833,9 +812,7 @@ def test_smoke_replay_buffer_transform(transform): @pytest.mark.parametrize("transform", transforms) def test_smoke_replay_buffer_transform_no_inkeys(transform): - rb = rb_prototype.ReplayBuffer( - collate_fn=lambda x: 
torch.stack(x, 0), transform=transform() - ) + rb = ReplayBuffer(collate_fn=lambda x: torch.stack(x, 0), transform=transform()) td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1)}, []) rb.add(td) diff --git a/test/test_rb_distributed.py b/test/test_rb_distributed.py index 15e594fa78c..9f59dcb353f 100644 --- a/test/test_rb_distributed.py +++ b/test/test_rb_distributed.py @@ -11,7 +11,7 @@ import torch.distributed.rpc as rpc import torch.multiprocessing as mp from tensordict.tensordict import TensorDict -from torchrl.data.replay_buffers.rb_prototype import RemoteTensorDictReplayBuffer +from torchrl.data.replay_buffers import RemoteTensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import RandomSampler from torchrl.data.replay_buffers.storages import LazyMemmapStorage from torchrl.data.replay_buffers.writers import RoundRobinWriter @@ -50,7 +50,7 @@ def sample_from_buffer_remotely_returns_correct_tensordict_test(rank, name, worl if name == "TRAINER": buffer = _construct_buffer("BUFFER") _, inserted = _add_random_tensor_dict_to_buffer(buffer) - sampled, _ = _sample_from_buffer(buffer, 1) + sampled = _sample_from_buffer(buffer, 1) assert type(sampled) is type(inserted) is TensorDict assert (sampled == inserted)["a"].item() diff --git a/test/test_trainer.py b/test/test_trainer.py index 87919df592d..7f15911a179 100644 --- a/test/test_trainer.py +++ b/test/test_trainer.py @@ -200,10 +200,11 @@ def test_rb_trainer(self, prioritized): torch.manual_seed(0) trainer = mocking_trainer() S = 100 + storage = ListStorage(S) if prioritized: - replay_buffer = TensorDictPrioritizedReplayBuffer(S, 1.1, 0.9) + replay_buffer = TensorDictPrioritizedReplayBuffer(1.1, 0.9, storage=storage) else: - replay_buffer = TensorDictReplayBuffer(S) + replay_buffer = TensorDictReplayBuffer(storage=storage) N = 9 rb_trainer = ReplayBufferTrainer(replay_buffer=replay_buffer, batch_size=N) @@ -234,11 +235,9 @@ def test_rb_trainer(self, prioritized): if prioritized: for idx in range(min(S, batch)): if idx in td_out.get("index"): - assert replay_buffer._sum_tree[idx] != 1.0 + assert replay_buffer._sampler._sum_tree[idx] != 1.0 else: - assert replay_buffer._sum_tree[idx] == 1.0 - else: - assert "index" not in td_out.keys() + assert replay_buffer._sampler._sum_tree[idx] == 1.0 @pytest.mark.parametrize( "storage_type", @@ -260,14 +259,12 @@ def test_rb_trainer_state_dict(self, prioritized, storage_type): if prioritized: replay_buffer = TensorDictPrioritizedReplayBuffer( - S, 1.1, 0.9, storage=storage, ) else: replay_buffer = TensorDictReplayBuffer( - S, storage=storage, ) @@ -295,18 +292,21 @@ def test_rb_trainer_state_dict(self, prioritized, storage_type): trainer2 = mocking_trainer() if prioritized: replay_buffer2 = TensorDictPrioritizedReplayBuffer( - S, 1.1, 0.9, storage=storage + 1.1, 0.9, storage=storage ) else: - replay_buffer2 = TensorDictReplayBuffer(S, storage=storage) + replay_buffer2 = TensorDictReplayBuffer(storage=storage) N = 9 rb_trainer2 = ReplayBufferTrainer(replay_buffer=replay_buffer2, batch_size=N) rb_trainer2.register(trainer2) sd = trainer.state_dict() trainer2.load_state_dict(sd) - assert rb_trainer2.replay_buffer.cursor > 0 - assert rb_trainer2.replay_buffer.cursor == rb_trainer.replay_buffer.cursor + assert rb_trainer2.replay_buffer._writer._cursor > 0 + assert ( + rb_trainer2.replay_buffer._writer._cursor + == rb_trainer.replay_buffer._writer._cursor + ) if storage_type == "list": assert len(rb_trainer2.replay_buffer._storage._storage) > 0 @@ -397,14 +397,12 @@ def make_storage(): 
storage = make_storage() if prioritized: replay_buffer = TensorDictPrioritizedReplayBuffer( - S, 1.1, 0.9, storage=storage, ) else: replay_buffer = TensorDictReplayBuffer( - S, storage=storage, ) @@ -431,14 +429,12 @@ def make_storage(): storage2 = make_storage() if prioritized: replay_buffer2 = TensorDictPrioritizedReplayBuffer( - S, 1.1, 0.9, storage=storage2, ) else: replay_buffer2 = TensorDictReplayBuffer( - S, storage=storage2, ) N = 9 diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index 4f0fe172d04..9510cca0309 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -9,6 +9,7 @@ LazyTensorStorage, ListStorage, PrioritizedReplayBuffer, + RemoteTensorDictReplayBuffer, ReplayBuffer, Storage, TensorDictPrioritizedReplayBuffer, diff --git a/torchrl/data/replay_buffers/__init__.py b/torchrl/data/replay_buffers/__init__.py index 53e363855ef..6a9911afa74 100644 --- a/torchrl/data/replay_buffers/__init__.py +++ b/torchrl/data/replay_buffers/__init__.py @@ -5,6 +5,7 @@ from .replay_buffers import ( PrioritizedReplayBuffer, + RemoteTensorDictReplayBuffer, ReplayBuffer, TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer, diff --git a/torchrl/data/replay_buffers/rb_prototype.py b/torchrl/data/replay_buffers/rb_prototype.py deleted file mode 100644 index 1817ae7d67a..00000000000 --- a/torchrl/data/replay_buffers/rb_prototype.py +++ /dev/null @@ -1,313 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import collections -import threading -from concurrent.futures import ThreadPoolExecutor -from typing import Any, Callable, List, Optional, Sequence, Tuple, Union - -import torch -from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase - -from torchrl.envs.transforms.transforms import Compose, Transform - -from .replay_buffers import pin_memory_output -from .samplers import RandomSampler, Sampler -from .storages import _get_default_collate, ListStorage, Storage -from .utils import _to_numpy, accept_remote_rref_udf_invocation, INT_CLASSES -from .writers import RoundRobinWriter, Writer - - -class ReplayBuffer: - """A generic, composable replay buffer class. - - Args: - storage (Storage, optional): the storage to be used. If none is provided - a default ListStorage with max_size of 1_000 will be created. - sampler (Sampler, optional): the sampler to be used. If none is provided - a default RandomSampler() will be used. - writer (Writer, optional): the writer to be used. If none is provided - a default RoundRobinWriter() will be used. - collate_fn (callable, optional): merges a list of samples to form a - mini-batch of Tensor(s)/outputs. Used when using batched - loading from a map-style dataset. - pin_memory (bool): whether pin_memory() should be called on the rb - samples. - prefetch (int, optional): number of next batches to be prefetched - using multithreading. - transform (Transform, optional): Transform to be executed when sample() is called. - To chain transforms use the :obj:`Compose` class. 
- """ - - def __init__( - self, - storage: Optional[Storage] = None, - sampler: Optional[Sampler] = None, - writer: Optional[Writer] = None, - collate_fn: Optional[Callable] = None, - pin_memory: bool = False, - prefetch: Optional[int] = None, - transform: Optional[Transform] = None, - ) -> None: - self._storage = storage if storage is not None else ListStorage(max_size=1_000) - self._storage.attach(self) - self._sampler = sampler if sampler is not None else RandomSampler() - self._writer = writer if writer is not None else RoundRobinWriter() - self._writer.register_storage(self._storage) - - self._collate_fn = ( - collate_fn - if collate_fn is not None - else _get_default_collate(self._storage) - ) - self._pin_memory = pin_memory - - self._prefetch = bool(prefetch) - self._prefetch_cap = prefetch or 0 - self._prefetch_queue = collections.deque() - if self._prefetch_cap: - self._prefetch_executor = ThreadPoolExecutor(max_workers=self._prefetch_cap) - - self._replay_lock = threading.RLock() - self._futures_lock = threading.RLock() - if transform is None: - transform = Compose() - elif not isinstance(transform, Compose): - transform = Compose(transform) - transform.eval() - self._transform = transform - - def __len__(self) -> int: - with self._replay_lock: - return len(self._storage) - - def __repr__(self) -> str: - return ( - f"{type(self).__name__}(" - f"storage={self._storage}, " - f"sampler={self._sampler}, " - f"writer={self._writer}" - ")" - ) - - @pin_memory_output - def __getitem__(self, index: Union[int, torch.Tensor]) -> Any: - index = _to_numpy(index) - with self._replay_lock: - data = self._storage[index] - - if not isinstance(index, INT_CLASSES): - data = self._collate_fn(data) - - return data - - def add(self, data: Any) -> int: - """Add a single element to the replay buffer. - - Args: - data (Any): data to be added to the replay buffer - - Returns: - index where the data lives in the replay buffer. - """ - with self._replay_lock: - index = self._writer.add(data) - self._sampler.add(index) - return index - - def extend(self, data: Sequence) -> torch.Tensor: - """Extends the replay buffer with one or more elements contained in an iterable. - - Args: - data (iterable): collection of data to be added to the replay - buffer. - - Returns: - Indices of the data aded to the replay buffer. - """ - with self._replay_lock: - index = self._writer.extend(data) - self._sampler.extend(index) - return index - - def update_priority( - self, - index: Union[int, torch.Tensor], - priority: Union[int, torch.Tensor], - ) -> None: - with self._replay_lock: - self._sampler.update_priority(index, priority) - - @pin_memory_output - def _sample(self, batch_size: int) -> Tuple[Any, dict]: - with self._replay_lock: - index, info = self._sampler.sample(self._storage, batch_size) - data = self._storage[index] - if not isinstance(index, INT_CLASSES): - data = self._collate_fn(data) - data = self._transform(data) - return data, info - - def sample(self, batch_size: int) -> Tuple[Any, dict]: - """Samples a batch of data from the replay buffer. - - Uses Sampler to sample indices, and retrieves them from Storage. - - Args: - batch_size (int): size of data to be collected. - - Returns: - A batch of data selected in the replay buffer. 
- """ - if not self._prefetch: - return self._sample(batch_size) - - if len(self._prefetch_queue) == 0: - ret = self._sample(batch_size) - else: - with self._futures_lock: - ret = self._prefetch_queue.popleft().result() - - with self._futures_lock: - while len(self._prefetch_queue) < self._prefetch_cap: - fut = self._prefetch_executor.submit(self._sample, batch_size) - self._prefetch_queue.append(fut) - - return ret - - def mark_update(self, index: Union[int, torch.Tensor]) -> None: - self._sampler.mark_update(index) - - def append_transform(self, transform: Transform) -> None: - """Appends transform at the end. - - Transforms are applied in order when `sample` is called. - - Args: - transform (Transform): The transform to be appended - """ - transform.eval() - self._transform.append(transform) - - def insert_transform(self, index: int, transform: Transform) -> None: - """Inserts transform. - - Transforms are executed in order when `sample` is called. - - Args: - index (int): Position to insert the transform. - transform (Transform): The transform to be appended - """ - transform.eval() - self._transform.insert(index, transform) - - -class TensorDictReplayBuffer(ReplayBuffer): - """TensorDict-specific wrapper around the ReplayBuffer class. - - Args: - priority_key (str): the key at which priority is assumed to be stored - within TensorDicts added to this ReplayBuffer. - """ - - def __init__(self, priority_key: str = "td_error", **kw) -> None: - super().__init__(**kw) - self.priority_key = priority_key - - def _get_priority(self, tensordict: TensorDictBase) -> Optional[torch.Tensor]: - if self.priority_key not in tensordict.keys(): - return self._sampler.default_priority - if tensordict.batch_dims: - tensordict = tensordict.clone(recurse=False) - tensordict.batch_size = [] - try: - priority = tensordict.get(self.priority_key).item() - except ValueError: - raise ValueError( - f"Found a priority key of size" - f" {tensordict.get(self.priority_key).shape} but expected " - f"scalar value" - ) - return priority - - def add(self, data: TensorDictBase) -> int: - index = super().add(data) - data.set("index", index) - - priority = self._get_priority(data) - if priority: - self.update_priority(index, priority) - return index - - def extend(self, tensordicts: Union[List, TensorDictBase]) -> torch.Tensor: - if isinstance(tensordicts, TensorDictBase): - if tensordicts.batch_dims > 1: - # we want the tensordict to have one dimension only. 
The batch size - # of the sampled tensordicts can be changed thereafter - if not isinstance(tensordicts, LazyStackedTensorDict): - tensordicts = tensordicts.clone(recurse=False) - else: - tensordicts = tensordicts.contiguous() - tensordicts.batch_size = tensordicts.batch_size[:1] - tensordicts.set( - "index", - torch.zeros( - tensordicts.shape, device=tensordicts.device, dtype=torch.int - ), - ) - - if not isinstance(tensordicts, TensorDictBase): - stacked_td = torch.stack(tensordicts, 0) - else: - stacked_td = tensordicts - - index = super().extend(stacked_td) - stacked_td.set( - "index", - torch.tensor(index, dtype=torch.int, device=stacked_td.device), - inplace=True, - ) - self.update_tensordict_priority(stacked_td) - return index - - def update_tensordict_priority(self, data: TensorDictBase) -> None: - priority = torch.tensor( - [self._get_priority(td) for td in data], - dtype=torch.float, - device=data.device, - ) - self.update_priority(data.get("index"), priority) - - def sample(self, batch_size: int, include_info: bool = False) -> TensorDictBase: - data, info = super().sample(batch_size) - if include_info: - for k, v in info.items(): - data.set(k, torch.tensor(v, device=data.device), inplace=True) - return data, info - - -@accept_remote_rref_udf_invocation -class RemoteTensorDictReplayBuffer(TensorDictReplayBuffer): - """A remote invocation friendly ReplayBuffer class. Public methods can be invoked by remote agents using `torch.rpc` or called locally as normal.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def sample(self, batch_size: int, include_info: bool = False) -> TensorDictBase: - return super().sample(batch_size, include_info) - - def add(self, data: TensorDictBase) -> int: - return super().add(data) - - def extend(self, tensordicts: Union[List, TensorDictBase]) -> torch.Tensor: - return super().extend(tensordicts) - - def update_priority( - self, index: Union[int, torch.Tensor], priority: Union[int, torch.Tensor] - ) -> None: - return super().update_priority(index, priority) - - def update_tensordict_priority(self, data: TensorDictBase) -> None: - return super().update_tensordict_priority(data) diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index 053384e6bb8..47f878e11af 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -4,30 +4,20 @@ # LICENSE file in the root directory of this source tree. 
import collections -import concurrent.futures import threading -from copy import deepcopy +from concurrent.futures import ThreadPoolExecutor from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union -import numpy as np import torch from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase -from torch import Tensor - -from torchrl._torchrl import ( - MinSegmentTreeFp32, - MinSegmentTreeFp64, - SumSegmentTreeFp32, - SumSegmentTreeFp64, -) -from torchrl.data.replay_buffers.storages import ( - _get_default_collate, - ListStorage, - Storage, -) -from torchrl.data.replay_buffers.utils import _to_numpy, _to_torch, INT_CLASSES + from torchrl.data.utils import DEVICE_TYPING +from .samplers import PrioritizedSampler, RandomSampler, Sampler +from .storages import _get_default_collate, ListStorage, Storage +from .utils import _to_numpy, accept_remote_rref_udf_invocation, INT_CLASSES +from .writers import RoundRobinWriter, Writer + def stack_tensors(list_of_tensor_iterators: List) -> Tuple[torch.Tensor]: """Zips a list of iterables containing tensor-like objects and stacks the resulting lists of tensors together. @@ -86,10 +76,15 @@ def decorated_fun(self, *args, **kwargs): class ReplayBuffer: - """Circular replay buffer. + """A generic, composable replay buffer class. Args: - size (int): integer indicating the maximum size of the replay buffer. + storage (Storage, optional): the storage to be used. If none is provided + a default ListStorage with max_size of 1_000 will be created. + sampler (Sampler, optional): the sampler to be used. If none is provided + a default RandomSampler() will be used. + writer (Writer, optional): the writer to be used. If none is provided + a default RoundRobinWriter() will be used. collate_fn (callable, optional): merges a list of samples to form a mini-batch of Tensor(s)/outputs. Used when using batched loading from a map-style dataset. @@ -97,23 +92,26 @@ class ReplayBuffer: samples. prefetch (int, optional): number of next batches to be prefetched using multithreading. - storage (Storage, optional): the storage to be used. If none is provided, - a ListStorage will be instantiated. + transform (Transform, optional): Transform to be executed when sample() is called. + To chain transforms use the :obj:`Compose` class. 
""" def __init__( self, - size: int, + storage: Optional[Storage] = None, + sampler: Optional[Sampler] = None, + writer: Optional[Writer] = None, collate_fn: Optional[Callable] = None, pin_memory: bool = False, prefetch: Optional[int] = None, - storage: Optional[Storage] = None, - ): - if storage is None: - storage = ListStorage(size) - self._storage = storage - self._capacity = size - self._cursor = 0 + transform: Optional["Transform"] = None, # noqa-F821 + ) -> None: + self._storage = storage if storage is not None else ListStorage(max_size=1_000) + self._storage.attach(self) + self._sampler = sampler if sampler is not None else RandomSampler() + self._writer = writer if writer is not None else RoundRobinWriter() + self._writer.register_storage(self._storage) + self._collate_fn = ( collate_fn if collate_fn is not None @@ -121,50 +119,58 @@ def __init__( ) self._pin_memory = pin_memory - self._prefetch = prefetch is not None and prefetch > 0 - self._prefetch_cap = prefetch if prefetch is not None else 0 - self._prefetch_fut = collections.deque() - if self._prefetch_cap > 0: - self._prefetch_executor = concurrent.futures.ThreadPoolExecutor( - max_workers=self._prefetch_cap - ) + self._prefetch = bool(prefetch) + self._prefetch_cap = prefetch or 0 + self._prefetch_queue = collections.deque() + if self._prefetch_cap: + self._prefetch_executor = ThreadPoolExecutor(max_workers=self._prefetch_cap) self._replay_lock = threading.RLock() - self._future_lock = threading.RLock() + self._futures_lock = threading.RLock() + from torchrl.envs.transforms.transforms import Compose + + if transform is None: + transform = Compose() + elif not isinstance(transform, Compose): + transform = Compose(transform) + transform.eval() + self._transform = transform def __len__(self) -> int: with self._replay_lock: return len(self._storage) + def __repr__(self) -> str: + return ( + f"{type(self).__name__}(" + f"storage={self._storage}, " + f"sampler={self._sampler}, " + f"writer={self._writer}" + ")" + ) + @pin_memory_output - def __getitem__(self, index: Union[int, Tensor]) -> Any: + def __getitem__(self, index: Union[int, torch.Tensor]) -> Any: index = _to_numpy(index) - with self._replay_lock: data = self._storage[index] if not isinstance(index, INT_CLASSES): data = self._collate_fn(data) + return data def state_dict(self) -> Dict[str, Any]: return { "_storage": self._storage.state_dict(), - "_cursor": self._cursor, + "_sampler": self._sampler.state_dict(), + "_writer": self._writer.state_dict(), } def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self._storage.load_state_dict(state_dict["_storage"]) - self._cursor = state_dict["_cursor"] - - @property - def capacity(self) -> int: - return self._capacity - - @property - def cursor(self) -> int: - with self._replay_lock: - return self._cursor + self._sampler.load_state_dict(state_dict["_sampler"]) + self._writer.load_state_dict(state_dict["_writer"]) def add(self, data: Any) -> int: """Add a single element to the replay buffer. @@ -176,12 +182,11 @@ def add(self, data: Any) -> int: index where the data lives in the replay buffer. """ with self._replay_lock: - ret = self._cursor - self._storage[self._cursor] = data - self._cursor = (self._cursor + 1) % self._capacity - return ret + index = self._writer.add(data) + self._sampler.add(index) + return index - def extend(self, data: Sequence[Any]): + def extend(self, data: Sequence) -> torch.Tensor: """Extends the replay buffer with one or more elements contained in an iterable. 
Args: @@ -190,66 +195,87 @@ def extend(self, data: Sequence[Any]): Returns: Indices of the data aded to the replay buffer. - """ - if not len(data): - raise Exception("extending with empty data is not supported") with self._replay_lock: - batch_size = len(data) - if self._cursor + batch_size <= self._capacity: - index = np.arange(self._cursor, self._cursor + batch_size) - self._cursor = (self._cursor + batch_size) % self._capacity - else: - d = self._capacity - self._cursor - index = np.empty(batch_size, dtype=np.int64) - index[:d] = np.arange(self._cursor, self._capacity) - index[d:] = np.arange(batch_size - d) - self._cursor = batch_size - d - # storage must convert the data to the appropriate format if needed - self._storage[index] = data - return index + index = self._writer.extend(data) + self._sampler.extend(index) + return index - @pin_memory_output - def _sample(self, batch_size: int) -> Any: - index = torch.randint(0, len(self._storage), (batch_size,)) + def update_priority( + self, + index: Union[int, torch.Tensor], + priority: Union[int, torch.Tensor], + ) -> None: + with self._replay_lock: + self._sampler.update_priority(index, priority) + @pin_memory_output + def _sample(self, batch_size: int) -> Tuple[Any, dict]: with self._replay_lock: + index, info = self._sampler.sample(self._storage, batch_size) data = self._storage[index] + if not isinstance(index, INT_CLASSES): + data = self._collate_fn(data) + data = self._transform(data) + return data, info - data = self._collate_fn(data) - return data - - def sample(self, batch_size: int) -> Any: + def sample(self, batch_size: int, return_info: bool = False) -> Any: """Samples a batch of data from the replay buffer. + Uses Sampler to sample indices, and retrieves them from Storage. + Args: - batch_size (int): float of data to be collected. + batch_size (int): size of data to be collected. + return_info (bool): whether to return info. If True, the result + is a tuple (data, info). If False, the result is the data. Returns: - A batch of data randomly selected in the replay buffer. - + A batch of data selected in the replay buffer. + A tuple containing this batch and info if return_info flag is set to True. """ if not self._prefetch: - return self._sample(batch_size) - - with self._future_lock: - if len(self._prefetch_fut) == 0: + ret = self._sample(batch_size) + else: + if len(self._prefetch_queue) == 0: ret = self._sample(batch_size) else: - ret = self._prefetch_fut.popleft().result() + with self._futures_lock: + ret = self._prefetch_queue.popleft().result() - while len(self._prefetch_fut) < self._prefetch_cap: - fut = self._prefetch_executor.submit(self._sample, batch_size) - self._prefetch_fut.append(fut) + with self._futures_lock: + while len(self._prefetch_queue) < self._prefetch_cap: + fut = self._prefetch_executor.submit(self._sample, batch_size) + self._prefetch_queue.append(fut) + if return_info: return ret + return ret[0] - def __repr__(self) -> str: - string = ( - f"{type(self).__name__}(size={len(self)}, " - f"pin_memory={self._pin_memory})" - ) - return string + def mark_update(self, index: Union[int, torch.Tensor]) -> None: + self._sampler.mark_update(index) + + def append_transform(self, transform: "Transform") -> None: # noqa-F821 + """Appends transform at the end. + + Transforms are applied in order when `sample` is called. 
+ + Args: + transform (Transform): The transform to be appended + """ + transform.eval() + self._transform.append(transform) + + def insert_transform(self, index: int, transform: "Transform") -> None: # noqa-F821 + """Inserts transform. + + Transforms are executed in order when `sample` is called. + + Args: + index (int): Position to insert the transform. + transform (Transform): The transform to be appended + """ + transform.eval() + self._transform.insert(index, transform) class PrioritizedReplayBuffer(ReplayBuffer): @@ -261,12 +287,14 @@ class PrioritizedReplayBuffer(ReplayBuffer): (https://arxiv.org/abs/1511.05952) Args: - size (int): integer indicating the maximum size of the replay buffer. alpha (float): exponent α determines how much prioritization is used, with α = 0 corresponding to the uniform case. beta (float): importance sampling negative exponent. eps (float): delta added to the priorities to ensure that the buffer does not contain null priorities. + dtype (torch.dtype): type of the data. Can be torch.float or torch.double. + storage (Storage, optional): the storage to be used. If none is provided + a default ListStorage with max_size of 1_000 will be created. collate_fn (callable, optional): merges a list of samples to form a mini-batch of Tensor(s)/outputs. Used when using batched loading from a map-style dataset. @@ -274,286 +302,156 @@ class PrioritizedReplayBuffer(ReplayBuffer): samples. prefetch (int, optional): number of next batches to be prefetched using multithreading. - storage (Storage, optional): the storage to be used. If none is provided, - a ListStorage will be instantiated. + transform (Transform, optional): Transform to be executed when sample() is called. + To chain transforms use the :obj:`Compose` class. """ def __init__( self, - size: int, alpha: float, beta: float, eps: float = 1e-8, dtype: torch.dtype = torch.float, - collate_fn=None, + storage: Optional[Storage] = None, + collate_fn: Optional[Callable] = None, pin_memory: bool = False, prefetch: Optional[int] = None, - storage: Optional[Storage] = None, + transform: Optional["Transform"] = None, # noqa-F821 ) -> None: + if storage is None: + storage = ListStorage(max_size=1_000) + sampler = PrioritizedSampler(storage.max_size, alpha, beta, eps, dtype) super(PrioritizedReplayBuffer, self).__init__( - size, - collate_fn, - pin_memory, - prefetch, storage=storage, + sampler=sampler, + collate_fn=collate_fn, + pin_memory=pin_memory, + prefetch=prefetch, + transform=transform, ) - if alpha <= 0: - raise ValueError( - f"alpha must be strictly greater than 0, got alpha={alpha}" - ) - if beta < 0: - raise ValueError(f"beta must be greater or equal to 0, got beta={beta}") - - self._alpha = alpha - self._beta = beta - self._eps = eps - if dtype in (torch.float, torch.FloatType, torch.float32): - self._sum_tree = SumSegmentTreeFp32(size) - self._min_tree = MinSegmentTreeFp32(size) - elif dtype in (torch.double, torch.DoubleTensor, torch.float64): - self._sum_tree = SumSegmentTreeFp64(size) - self._min_tree = MinSegmentTreeFp64(size) - else: - raise NotImplementedError( - f"dtype {dtype} not supported by PrioritizedReplayBuffer" - ) - self._max_priority = 1.0 - - def state_dict(self) -> Dict[str, Any]: - state_dict = super().state_dict() - state_dict["_sum_tree"] = deepcopy(self._sum_tree) - state_dict["_min_tree"] = deepcopy(self._min_tree) - state_dict["_max_priority"] = self._max_priority - return state_dict - - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - self._sum_tree = 
state_dict.pop("_sum_tree") - self._min_tree = state_dict.pop("_min_tree") - self._max_priority = state_dict.pop("_max_priority") - super().load_state_dict(state_dict) - - @pin_memory_output - def __getitem__(self, index: Union[int, Tensor]) -> Any: - index = _to_numpy(index) - with self._replay_lock: - p_min = self._min_tree.query(0, self._capacity) - if p_min <= 0: - raise ValueError(f"p_min must be greater than 0, got p_min={p_min}") - data = self._storage[index] - if isinstance(index, INT_CLASSES): - weight = np.array(self._sum_tree[index]) - else: - weight = self._sum_tree[index] - if not isinstance(index, INT_CLASSES): - data = self._collate_fn(data) - # weight = np.power(weight / (p_min + self._eps), -self._beta) - weight = np.power(weight / p_min, -self._beta) - # x = first_field(data) - # if isinstance(x, torch.Tensor): - device = data.device if hasattr(data, "device") else torch.device("cpu") - weight = _to_torch(weight, device, self._pin_memory) - return data, weight - - @property - def alpha(self) -> float: - return self._alpha - - @property - def beta(self) -> float: - return self._beta - - @property - def eps(self) -> float: - return self._eps - - @property - def max_priority(self) -> float: - with self._replay_lock: - return self._max_priority +class TensorDictReplayBuffer(ReplayBuffer): + """TensorDict-specific wrapper around the ReplayBuffer class. - @property - def _default_priority(self) -> float: - return (self._max_priority + self._eps) ** self._alpha + Args: + priority_key (str): the key at which priority is assumed to be stored + within TensorDicts added to this ReplayBuffer. + """ - def _add_or_extend( - self, - data: Any, - priority: Optional[torch.Tensor] = None, - do_add: bool = True, - ) -> torch.Tensor: - if priority is not None: - priority = _to_numpy(priority) - max_priority = np.max(priority) - with self._replay_lock: - self._max_priority = max(self._max_priority, max_priority) - priority = np.power(priority + self._eps, self._alpha) - else: - with self._replay_lock: - priority = self._default_priority + def __init__(self, priority_key: str = "td_error", **kw) -> None: + super().__init__(**kw) + self.priority_key = priority_key - if do_add: - index = super(PrioritizedReplayBuffer, self).add(data) - else: - index = super(PrioritizedReplayBuffer, self).extend(data) - - if not ( - isinstance(priority, float) - or len(priority) == 1 - or len(priority) == len(index) - ): - raise RuntimeError( - "priority should be a scalar or an iterable of the same " - "length as index" + def _get_priority(self, tensordict: TensorDictBase) -> Optional[torch.Tensor]: + if self.priority_key not in tensordict.keys(): + return self._sampler.default_priority + if tensordict.batch_dims: + tensordict = tensordict.clone(recurse=False) + tensordict.batch_size = [] + try: + priority = tensordict.get(self.priority_key).item() + except ValueError: + raise ValueError( + f"Found a priority key of size" + f" {tensordict.get(self.priority_key).shape} but expected " + f"scalar value" ) + return priority - with self._replay_lock: - self._sum_tree[index] = priority - self._min_tree[index] = priority + def add(self, data: TensorDictBase) -> int: + index = super().add(data) + data.set("index", index) + priority = self._get_priority(data) + if priority: + self.update_priority(index, priority) return index - def add(self, data: Any, priority: Optional[torch.Tensor] = None) -> torch.Tensor: - return self._add_or_extend(data, priority, True) - - def extend( - self, data: Sequence, priority: 
Optional[torch.Tensor] = None - ) -> torch.Tensor: - return self._add_or_extend(data, priority, False) - - @pin_memory_output - def _sample(self, batch_size: int) -> Tuple[Any, torch.Tensor, torch.Tensor]: - with self._replay_lock: - p_sum = self._sum_tree.query(0, self._capacity) - p_min = self._min_tree.query(0, self._capacity) - if p_sum <= 0: - raise RuntimeError("negative p_sum") - if p_min <= 0: - raise RuntimeError("negative p_min") - mass = np.random.uniform(0.0, p_sum, size=batch_size) - index = self._sum_tree.scan_lower_bound(mass) - if not isinstance(index, torch.Tensor): - index = torch.tensor(index) - if not index.ndimension(): - index = index.reshape((1,)) - index.clamp_max_(len(self._storage) - 1) - data = self._storage[index] - weight = self._sum_tree[index] - - data = self._collate_fn(data) - - # Importance sampling weight formula: - # w_i = (p_i / sum(p) * N) ^ (-beta) - # weight_i = w_i / max(w) - # weight_i = (p_i / sum(p) * N) ^ (-beta) / - # ((min(p) / sum(p) * N) ^ (-beta)) - # weight_i = ((p_i / sum(p) * N) / (min(p) / sum(p) * N)) ^ (-beta) - # weight_i = (p_i / min(p)) ^ (-beta) - # weight = np.power(weight / (p_min + self._eps), -self._beta) - weight = np.power(weight / p_min, -self._beta) - - # x = first_field(data) # avoid calling tree.flatten - # if isinstance(x, torch.Tensor): - device = data.device if hasattr(data, "device") else torch.device("cpu") - weight = _to_torch(weight, device, self._pin_memory) - return data, weight, index - - def sample(self, batch_size: int) -> Tuple[Any, np.ndarray, torch.Tensor]: - """Gathers a batch of data according to the non-uniform multinomial distribution with weights computed with the provided priorities of each input. - - Args: - batch_size (int): float of data to be collected. - - Returns: a random sample from the replay buffer. + def extend(self, tensordicts: Union[List, TensorDictBase]) -> torch.Tensor: + if isinstance(tensordicts, TensorDictBase): + if tensordicts.batch_dims > 1: + # we want the tensordict to have one dimension only. 
The batch size + # of the sampled tensordicts can be changed thereafter + if not isinstance(tensordicts, LazyStackedTensorDict): + tensordicts = tensordicts.clone(recurse=False) + else: + tensordicts = tensordicts.contiguous() + tensordicts.batch_size = tensordicts.batch_size[:1] + tensordicts.set( + "index", + torch.zeros( + tensordicts.shape, device=tensordicts.device, dtype=torch.int + ), + ) - """ - if not self._prefetch: - return self._sample(batch_size) + if not isinstance(tensordicts, TensorDictBase): + stacked_td = torch.stack(tensordicts, 0) + else: + stacked_td = tensordicts - with self._future_lock: - if len(self._prefetch_fut) == 0: - ret = self._sample(batch_size) - else: - ret = self._prefetch_fut.popleft().result() + index = super().extend(stacked_td) + stacked_td.set( + "index", + torch.tensor(index, dtype=torch.int, device=stacked_td.device), + inplace=True, + ) + self.update_tensordict_priority(stacked_td) + return index - while len(self._prefetch_fut) < self._prefetch_cap: - fut = self._prefetch_executor.submit(self._sample, batch_size) - self._prefetch_fut.append(fut) + def update_tensordict_priority(self, data: TensorDictBase) -> None: + priority = torch.tensor( + [self._get_priority(td) for td in data], + dtype=torch.float, + device=data.device, + ) + self.update_priority(data.get("index"), priority) - return ret + def sample( + self, batch_size: int, include_info: bool = False, return_info: bool = False + ) -> TensorDictBase: + """Samples a batch of data from the replay buffer. - def update_priority( - self, index: Union[int, Tensor], priority: Union[float, Tensor] - ) -> None: - """Updates the priority of the data pointed by the index. + Uses Sampler to sample indices, and retrieves them from Storage. Args: - index (int or torch.Tensor): indexes of the priorities to be - updated. - priority (Number or torch.Tensor): new priorities of the - indexed elements - + batch_size (int): size of data to be collected. + include_info (bool): whether to add info to the returned tensordict. + return_info (bool): whether to return info. If True, the result + is a tuple (data, info). If False, the result is the data. + Returns: + A tensordict containing a batch of data selected in the replay buffer. + A tuple containing this tensordict and info if return_info flag is set to True. 
""" - if isinstance(index, INT_CLASSES): - if not isinstance(priority, float): - if len(priority) != 1: - raise RuntimeError( - f"priority length should be 1, got {len(priority)}" - ) - priority = priority.item() - else: - if not ( - isinstance(priority, float) - or len(priority) == 1 - or len(index) == len(priority) - ): - raise RuntimeError( - "priority should be a number or an iterable of the same " - "length as index" - ) - index = _to_numpy(index) - priority = _to_numpy(priority) - - with self._replay_lock: - self._max_priority = max(self._max_priority, np.max(priority)) - priority = np.power(priority + self._eps, self._alpha) - self._sum_tree[index] = priority - self._min_tree[index] = priority - - -class TensorDictReplayBuffer(ReplayBuffer): - """TensorDict-specific wrapper around the ReplayBuffer class.""" - - def __init__( - self, - size: int, - collate_fn: Optional[Callable] = None, - pin_memory: bool = False, - prefetch: Optional[int] = None, - storage: Optional[Storage] = None, - ): - super().__init__(size, collate_fn, pin_memory, prefetch, storage=storage) + data, info = super().sample(batch_size, return_info=True) + if include_info: + for k, v in info.items(): + data.set(k, torch.tensor(v, device=data.device), inplace=True) + if return_info: + return data, info + return data -class TensorDictPrioritizedReplayBuffer(PrioritizedReplayBuffer): +class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer): """TensorDict-specific wrapper around the PrioritizedReplayBuffer class. This class returns tensordicts with a new key "index" that represents - the index of each element in the replay buffer. It also facilitates the - call to the 'update_priority' method, as it only requires for the + the index of each element in the replay buffer. It also provides the + 'update_tensordict_priority' method that only requires for the tensordict to be passed to it with its new priority value. Args: - size (int): integer indicating the maximum size of the replay buffer. - alpha (flaot): exponent α determines how much prioritization is + alpha (float): exponent α determines how much prioritization is used, with α = 0 corresponding to the uniform case. beta (float): importance sampling negative exponent. priority_key (str, optional): key where the priority value can be found in the stored tensordicts. Default is :obj:`"td_error"` eps (float, optional): delta added to the priorities to ensure that the buffer does not contain null priorities. + dtype (torch.dtype): type of the data. Can be torch.float or torch.double. + storage (Storage, optional): the storage to be used. If none is provided + a default ListStorage with max_size of 1_000 will be created. collate_fn (callable, optional): merges a list of samples to form a mini-batch of Tensor(s)/outputs. Used when using batched loading from a map-style dataset. @@ -561,131 +459,61 @@ class TensorDictPrioritizedReplayBuffer(PrioritizedReplayBuffer): the rb samples. Default is :obj:`False`. prefetch (int, optional): number of next batches to be prefetched using multithreading. - storage (Storage, optional): the storage to be used. If none is provided, - a ListStorage will be instantiated. + transform (Transform, optional): Transform to be executed when sample() is called. + To chain transforms use the :obj:`Compose` class. 
""" def __init__( self, - size: int, alpha: float, beta: float, priority_key: str = "td_error", eps: float = 1e-8, - collate_fn=None, + storage: Optional[Storage] = None, + collate_fn: Optional[Callable] = None, pin_memory: bool = False, prefetch: Optional[int] = None, - storage: Optional[Storage] = None, + transform: Optional["Transform"] = None, # noqa-F821 ) -> None: + if storage is None: + storage = ListStorage(max_size=1_000) + sampler = PrioritizedSampler(storage.max_size, alpha, beta, eps) super(TensorDictPrioritizedReplayBuffer, self).__init__( - size=size, - alpha=alpha, - beta=beta, - eps=eps, + priority_key=priority_key, + storage=storage, + sampler=sampler, collate_fn=collate_fn, pin_memory=pin_memory, prefetch=prefetch, - storage=storage, + transform=transform, ) - self.priority_key = priority_key - def _get_priority(self, tensordict: TensorDictBase) -> torch.Tensor: - if self.priority_key in tensordict.keys(): - if tensordict.batch_dims: - tensordict = tensordict.clone(recurse=False) - tensordict.batch_size = [] - try: - priority = tensordict.get(self.priority_key).item() - except ValueError: - raise ValueError( - f"Found a priority key of size" - f" {tensordict.get(self.priority_key).shape} but expected " - f"scalar value" - ) - else: - priority = self._default_priority - return priority - def add(self, tensordict: TensorDictBase) -> torch.Tensor: - priority = self._get_priority(tensordict) - index = super().add(tensordict, priority) - tensordict.set("index", index) - return index +@accept_remote_rref_udf_invocation +class RemoteTensorDictReplayBuffer(TensorDictReplayBuffer): + """A remote invocation friendly ReplayBuffer class. Public methods can be invoked by remote agents using `torch.rpc` or called locally as normal.""" - def extend( - self, tensordicts: Union[TensorDictBase, List[TensorDictBase]] - ) -> torch.Tensor: - if isinstance(tensordicts, TensorDictBase): - if self.priority_key in tensordicts.keys(): - priorities = tensordicts.get(self.priority_key) - else: - priorities = None + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) - if tensordicts.batch_dims > 1: - # we want the tensordict to have one dimension only. The batch size - # of the sampled tensordicts can be changed thereafter - if not isinstance(tensordicts, LazyStackedTensorDict): - tensordicts = tensordicts.clone(recurse=False) - else: - tensordicts = tensordicts.contiguous() - tensordicts.batch_size = tensordicts.batch_size[:1] - tensordicts.set( - "index", - torch.zeros( - tensordicts.shape, - device=tensordicts.device, - dtype=torch.int, - ), - ) - else: - priorities = [self._get_priority(td) for td in tensordicts] - - if not isinstance(tensordicts, TensorDictBase): - stacked_td = torch.stack(tensordicts, 0) - else: - stacked_td = tensordicts - idx = super().extend(tensordicts, priorities) - stacked_td.set( - "index", - torch.tensor(idx, dtype=torch.int, device=stacked_td.device), - inplace=True, - ) - return idx + def sample( + self, batch_size: int, include_info: bool = False, return_info: bool = False + ) -> TensorDictBase: + return super().sample(batch_size, include_info, return_info) - def update_priority(self, tensordict: TensorDictBase) -> None: - """Updates the priorities of the tensordicts stored in the replay buffer. + def add(self, data: TensorDictBase) -> int: + return super().add(data) - Args: - tensordict: tensordict with key-value pairs 'self.priority_key' - and 'index'. 
+ def extend(self, tensordicts: Union[List, TensorDictBase]) -> torch.Tensor: + return super().extend(tensordicts) + def update_priority( + self, index: Union[int, torch.Tensor], priority: Union[int, torch.Tensor] + ) -> None: + return super().update_priority(index, priority) - """ - priority = tensordict.get(self.priority_key) - if (priority < 0).any(): - raise RuntimeError( - f"Priority must be a positive value, got " - f"{(priority < 0).sum()} negative priority values." - ) - return super().update_priority(tensordict.get("index"), priority=priority) - - def sample(self, size: int, return_weight: bool = False) -> TensorDictBase: - """Gather a batch of tensordicts according to the non-uniform multinomial distribution with weights computed with the priority_key of each input tensordict. - - Args: - size (int): size of the batch to be returned - return_weight (bool, optional): if True, a '_weight' key will be - written in the output tensordict that indicates the weight - of the selected items - - Returns: - Stack of tensordicts - - """ - td, weight, _ = super(TensorDictPrioritizedReplayBuffer, self).sample(size) - if return_weight: - td.set("_weight", weight) - return td + def update_tensordict_priority(self, data: TensorDictBase) -> None: + return super().update_tensordict_priority(data) class InPlaceSampler: diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 1dafaa31a98..78355d329c1 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -4,7 +4,8 @@ # LICENSE file in the root directory of this source tree. from abc import ABC, abstractmethod -from typing import Any, Tuple, Union +from copy import deepcopy +from typing import Any, Dict, Tuple, Union import numpy as np import torch @@ -45,6 +46,12 @@ def mark_update(self, index: Union[int, torch.Tensor]) -> None: def default_priority(self) -> float: return 1.0 + def state_dict(self) -> Dict[str, Any]: + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + return + class RandomSampler(Sampler): """A uniformly random sampler for composable replay buffers.""" @@ -157,7 +164,7 @@ def __init__( self._min_tree = MinSegmentTreeFp64(self._max_capacity) else: raise NotImplementedError( - f"dtype {dtype} not supported by PrioritizedReplayBuffer" + f"dtype {dtype} not supported by PrioritizedSampler" ) self._max_priority = 1.0 @@ -254,3 +261,21 @@ def update_priority( def mark_update(self, index: Union[int, torch.Tensor]) -> None: self.update_priority(index, self.default_priority) + + def state_dict(self) -> Dict[str, Any]: + return { + "_alpha": self._alpha, + "_beta": self._beta, + "_eps": self._eps, + "_max_priority": self._max_priority, + "_sum_tree": deepcopy(self._sum_tree), + "_min_tree": deepcopy(self._min_tree), + } + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self._alpha = state_dict["_alpha"] + self._beta = state_dict["_beta"] + self._eps = state_dict["_eps"] + self._max_priority = state_dict["_max_priority"] + self._sum_tree = state_dict.pop("_sum_tree") + self._min_tree = state_dict.pop("_min_tree") diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py index 75a6aa5d971..8aacda38213 100644 --- a/torchrl/data/replay_buffers/writers.py +++ b/torchrl/data/replay_buffers/writers.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. 
from abc import ABC, abstractmethod -from typing import Any, Sequence +from typing import Any, Dict, Sequence import numpy as np import torch @@ -31,6 +31,12 @@ def extend(self, data: Sequence) -> torch.Tensor: """Inserts a series of data points at appropriate indices, and returns a tensor containing the indices.""" raise NotImplementedError + def state_dict(self) -> Dict[str, Any]: + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + return + class RoundRobinWriter(Writer): """A RoundRobin Writer class for composable replay buffers.""" @@ -60,3 +66,9 @@ def extend(self, data: Sequence) -> torch.Tensor: # storage must convert the data to the appropriate format if needed self._storage[index] = data return index + + def state_dict(self) -> Dict[str, Any]: + return {"_cursor": self._cursor} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self._cursor = state_dict["_cursor"] diff --git a/torchrl/trainers/__init__.py b/torchrl/trainers/__init__.py index 507c914d1e7..2c799e17fb5 100644 --- a/torchrl/trainers/__init__.py +++ b/torchrl/trainers/__init__.py @@ -11,7 +11,6 @@ mask_batch, OptimizerHook, Recorder, - ReplayBuffer, ReplayBufferTrainer, RewardNormalizer, SelectKeys, diff --git a/torchrl/trainers/helpers/replay_buffer.py b/torchrl/trainers/helpers/replay_buffer.py index c7bbee12d82..4f9c48bf4b9 100644 --- a/torchrl/trainers/helpers/replay_buffer.py +++ b/torchrl/trainers/helpers/replay_buffer.py @@ -7,11 +7,8 @@ import torch -from torchrl.data import ( - ReplayBuffer, - TensorDictPrioritizedReplayBuffer, - TensorDictReplayBuffer, -) +from torchrl.data import ReplayBuffer, TensorDictReplayBuffer +from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler from torchrl.data.replay_buffers.storages import LazyMemmapStorage from torchrl.data.utils import DEVICE_TYPING @@ -22,29 +19,23 @@ def make_replay_buffer( """Builds a replay buffer using the config built from ReplayArgsConfig.""" device = torch.device(device) if not cfg.prb: - buffer = TensorDictReplayBuffer( - cfg.buffer_size, - pin_memory=device != torch.device("cpu"), - prefetch=cfg.buffer_prefetch, - storage=LazyMemmapStorage( - cfg.buffer_size, - scratch_dir=cfg.buffer_scratch_dir, - # device=device, # when using prefetch, this can overload the GPU memory - ), - ) + sampler = RandomSampler() else: - buffer = TensorDictPrioritizedReplayBuffer( - cfg.buffer_size, + sampler = PrioritizedSampler( + max_capacity=cfg.buffer_size, alpha=0.7, beta=0.5, - pin_memory=device != torch.device("cpu"), - prefetch=cfg.buffer_prefetch, - storage=LazyMemmapStorage( - cfg.buffer_size, - scratch_dir=cfg.buffer_scratch_dir, - # device=device, # when using prefetch, this can overload the GPU memory - ), ) + buffer = TensorDictReplayBuffer( + storage=LazyMemmapStorage( + cfg.buffer_size, + scratch_dir=cfg.buffer_scratch_dir, + # device=device, # when using prefetch, this can overload the GPU memory + ), + sampler=sampler, + pin_memory=device != torch.device("cpu"), + prefetch=cfg.buffer_prefetch, + ) return buffer diff --git a/torchrl/trainers/helpers/trainers.py b/torchrl/trainers/helpers/trainers.py index 73691357c8d..5ffbb435410 100644 --- a/torchrl/trainers/helpers/trainers.py +++ b/torchrl/trainers/helpers/trainers.py @@ -129,7 +129,7 @@ def make_trainer( >>> recorder = env_proof >>> target_net_updater = None >>> policy_exploration = EGreedyWrapper(policy) - >>> replay_buffer = TensorDictReplayBuffer(1000) + >>> replay_buffer = TensorDictReplayBuffer() >>> dir = 
tempfile.gettempdir() >>> logger = TensorboardLogger(exp_name=dir) >>> trainer = make_trainer(collector, loss_module, recorder, target_net_updater, policy_exploration, diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index da312aaf9f5..a69edc2aac2 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -21,11 +21,7 @@ from torchrl._utils import _CKPT_BACKEND, KeyDependentDefaultDict from torchrl.collectors.collectors import _DataCollector -from torchrl.data import ( - ReplayBuffer, - TensorDictPrioritizedReplayBuffer, - TensorDictReplayBuffer, -) +from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.common import EnvBase from torchrl.envs.utils import set_exploration_mode @@ -600,7 +596,7 @@ class ReplayBufferTrainer(TrainerHookBase): """Replay buffer hook provider. Args: - replay_buffer (ReplayBuffer): replay buffer to be used. + replay_buffer (TensorDictReplayBuffer): replay buffer to be used. batch_size (int): batch size when sampling data from the latest collection or from the replay buffer. memmap (bool, optional): if True, a memmap tensordict is created. @@ -629,7 +625,7 @@ class ReplayBufferTrainer(TrainerHookBase): def __init__( self, - replay_buffer: ReplayBuffer, + replay_buffer: TensorDictReplayBuffer, batch_size: int, memmap: bool = False, device: DEVICE_TYPING = "cpu", @@ -673,8 +669,7 @@ def sample(self, batch: TensorDictBase) -> TensorDictBase: return sample.to(self.device, non_blocking=True) def update_priority(self, batch: TensorDictBase) -> None: - if isinstance(self.replay_buffer, TensorDictPrioritizedReplayBuffer): - self.replay_buffer.update_priority(batch) + self.replay_buffer.update_tensordict_priority(batch) def state_dict(self) -> Dict[str, Any]: return { diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index 91965a9f091..37f8aaa520b 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -49,12 +49,9 @@ from tensordict.nn import TensorDictModule from torch import nn, optim from torchrl.collectors import MultiaSyncDataCollector -from torchrl.data import ( - CompositeSpec, - TensorDictPrioritizedReplayBuffer, - TensorDictReplayBuffer, -) +from torchrl.data import CompositeSpec, TensorDictReplayBuffer from torchrl.data.postprocs import MultiStep +from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler from torchrl.data.replay_buffers.storages import LazyMemmapStorage from torchrl.envs import ( CatTensors, @@ -383,29 +380,23 @@ def make_recorder(actor_model_explore, stats): def make_replay_buffer(make_replay_buffer=3): if prb: - replay_buffer = TensorDictPrioritizedReplayBuffer( - buffer_size, + sampler = PrioritizedSampler( + max_capacity=buffer_size, alpha=0.7, beta=0.5, - pin_memory=False, - prefetch=make_replay_buffer, - storage=LazyMemmapStorage( - buffer_size, - scratch_dir=buffer_scratch_dir, - device=device, - ), ) else: - replay_buffer = TensorDictReplayBuffer( + sampler = RandomSampler() + replay_buffer = TensorDictReplayBuffer( + storage=LazyMemmapStorage( buffer_size, - pin_memory=False, - prefetch=make_replay_buffer, - storage=LazyMemmapStorage( - buffer_size, - scratch_dir=buffer_scratch_dir, - device=device, - ), - ) + scratch_dir=buffer_scratch_dir, + device=device, + ), + sampler=sampler, + pin_memory=False, + prefetch=make_replay_buffer, + ) return replay_buffer @@ -696,7 +687,7 @@ def 
make_replay_buffer(make_replay_buffer=3): # update priority if prb: - replay_buffer.update_priority(sampled_tensordict) + replay_buffer.update_tensordict_priority(sampled_tensordict) rewards.append( (i, tensordict["reward"].mean().item() / norm_factor_training / frame_skip) @@ -958,7 +949,7 @@ def make_replay_buffer(make_replay_buffer=3): sampled_tensordict["td_error"] = advantage.detach().pow(2).mean(1) sampled_tensordict["index"] = index if prb: - replay_buffer.update_priority(sampled_tensordict) + replay_buffer.update_tensordict_priority(sampled_tensordict) rewards.append( (i, tensordict["reward"].mean().item() / norm_factor_training / frame_skip) diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index c5b728f440c..5998ec21127 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -312,7 +312,6 @@ def make_model(): # shape. This storage will be instantiated later. replay_buffer = TensorDictReplayBuffer( - buffer_size, storage=LazyMemmapStorage(buffer_size), prefetch=n_optim, ) @@ -556,8 +555,7 @@ def make_model(): max_size = frames_per_batch // n_workers replay_buffer = TensorDictReplayBuffer( - -(-buffer_size // max_size), - storage=LazyMemmapStorage(buffer_size), + storage=LazyMemmapStorage(-(-buffer_size // max_size)), prefetch=n_optim, ) diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py index 41f61665a96..f977b97a7ef 100644 --- a/tutorials/sphinx-tutorials/torchrl_demo.py +++ b/tutorials/sphinx-tutorials/torchrl_demo.py @@ -224,7 +224,7 @@ ############################################################################### -rb = ReplayBuffer(100, collate_fn=lambda x: x) +rb = ReplayBuffer(collate_fn=lambda x: x) rb.add(1) rb.sample(1) @@ -235,7 +235,7 @@ ############################################################################### -rb = PrioritizedReplayBuffer(100, alpha=0.7, beta=1.1, collate_fn=lambda x: x) +rb = PrioritizedReplayBuffer(alpha=0.7, beta=1.1, collate_fn=lambda x: x) rb.add(1) rb.sample(1) rb.update_priority(1, 0.5) @@ -244,7 +244,7 @@ # Here are examples of using a replaybuffer with tensordicts. collate_fn = torch.stack -rb = ReplayBuffer(100, collate_fn=collate_fn) +rb = ReplayBuffer(collate_fn=collate_fn) rb.add(TensorDict({"a": torch.randn(3)}, batch_size=[])) len(rb) @@ -260,9 +260,7 @@ torch.manual_seed(0) from torchrl.data import TensorDictPrioritizedReplayBuffer -rb = TensorDictPrioritizedReplayBuffer( - 100, alpha=0.7, beta=1.1, priority_key="td_error" -) +rb = TensorDictPrioritizedReplayBuffer(alpha=0.7, beta=1.1, priority_key="td_error") rb.extend(TensorDict({"a": torch.randn(2, 3)}, batch_size=[2])) tensordict_sample = rb.sample(2).contiguous() tensordict_sample @@ -274,9 +272,9 @@ ############################################################################### tensordict_sample["td_error"] = torch.rand(2) -rb.update_priority(tensordict_sample) +rb.update_tensordict_priority(tensordict_sample) -for i, val in enumerate(rb._sum_tree): +for i, val in enumerate(rb._sampler._sum_tree): print(i, val) if i == len(rb): break
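
The hunks above all converge on one pattern: a TensorDictReplayBuffer assembled from an explicit storage and sampler, with priorities refreshed through update_tensordict_priority and with sampler/writer state exposed through state_dict()/load_state_dict(). The sketch below is a minimal illustration of that wiring, mirroring the updated helper and tutorial code; the capacity and the alpha/beta values are illustrative placeholders rather than values mandated by this patch, and the standalone sampler handle is kept only so its state can be saved without reaching into the buffer's private attributes.

    import torch
    from tensordict import TensorDict
    from torchrl.data import TensorDictReplayBuffer
    from torchrl.data.replay_buffers.samplers import PrioritizedSampler
    from torchrl.data.replay_buffers.storages import LazyTensorStorage

    capacity = 100  # illustrative placeholder, not a value required by this patch
    sampler = PrioritizedSampler(max_capacity=capacity, alpha=0.7, beta=0.5)
    rb = TensorDictReplayBuffer(
        storage=LazyTensorStorage(capacity),
        sampler=sampler,
    )

    # Writing and sampling keep the TensorDict-first interface.
    rb.extend(TensorDict({"a": torch.randn(8, 3)}, batch_size=[8]))
    sample = rb.sample(4)

    # Priorities are refreshed via update_tensordict_priority, matching the
    # renamed call sites in the trainer hook and the tutorials above.
    sample["td_error"] = torch.rand(sample.shape[0])
    rb.update_tensordict_priority(sample)

    # Samplers (and writers) now expose state_dict()/load_state_dict(), so the
    # priority trees can be checkpointed and restored alongside the model.
    sampler_state = sampler.state_dict()
    sampler.load_state_dict(sampler_state)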