From 0945d1e7b091eef191a081c21cd2ff583ce2c638 Mon Sep 17 00:00:00 2001 From: Kamil Piechowiak Date: Wed, 4 Jan 2023 15:06:06 +0000 Subject: [PATCH 1/5] Graduate Replay Buffer prototype --- README.md | 7 +- .../benchmark_sample_latency_over_rpc.py | 2 +- docs/source/reference/data.rst | 4 +- .../distributed/distributed_replay_buffer.py | 4 +- examples/dreamer/dreamer.py | 2 +- test/test_rb.py | 102 +-- test/test_rb_distributed.py | 2 +- test/test_trainer.py | 28 +- torchrl/data/__init__.py | 1 + torchrl/data/replay_buffers/__init__.py | 1 + torchrl/data/replay_buffers/rb_prototype.py | 308 -------- torchrl/data/replay_buffers/replay_buffers.py | 675 +++++++----------- torchrl/data/replay_buffers/samplers.py | 29 +- torchrl/data/replay_buffers/writers.py | 14 +- torchrl/trainers/helpers/replay_buffer.py | 39 +- torchrl/trainers/helpers/trainers.py | 2 +- torchrl/trainers/trainers.py | 9 +- tutorials/sphinx-tutorials/coding_ddpg.py | 45 +- tutorials/sphinx-tutorials/coding_dqn.py | 8 +- tutorials/sphinx-tutorials/torchrl_demo.py | 18 +- 20 files changed, 393 insertions(+), 907 deletions(-) delete mode 100644 torchrl/data/replay_buffers/rb_prototype.py diff --git a/README.md b/README.md index 050c1956796..80ede3fccf4 100644 --- a/README.md +++ b/README.md @@ -68,9 +68,9 @@ Here's another example of an off-policy training loop in TorchRL (assuming that - replay_buffer.add((obs, next_obs, action, log_prob, reward, done)) + replay_buffer.add(tensordict) for j in range(num_optim_steps): - - obs, next_obs, action, hidden_state, reward, done = replay_buffer.sample(batch_size) + - obs, next_obs, action, hidden_state, reward, done = replay_buffer.sample(batch_size)[0] - loss = loss_fn(obs, next_obs, action, hidden_state, reward, done) - + tensordict = replay_buffer.sample(batch_size) + + tensordict = replay_buffer.sample(batch_size)[0] + loss = loss_fn(tensordict) loss.backward() optim.step() @@ -203,7 +203,6 @@ The associated [`SafeModule` class](torchrl/modules/tensordict_module/common.py) scratch_dir="/tmp/" ) buffer = TensorDictPrioritizedReplayBuffer( - buffer_size=10000, alpha=0.7, beta=0.5, collate_fn=lambda x: x, @@ -327,7 +326,7 @@ The associated [`SafeModule` class](torchrl/modules/tensordict_module/common.py) ```python from torchrl.objectives import DQNLoss loss_module = DQNLoss(value_network=value_network, gamma=0.99) - tensordict = replay_buffer.sample(batch_size) + tensordict = replay_buffer.sample(batch_size)[0] loss = loss_module(tensordict) ``` diff --git a/benchmarks/storage/benchmark_sample_latency_over_rpc.py b/benchmarks/storage/benchmark_sample_latency_over_rpc.py index d922095de5f..669a3795831 100644 --- a/benchmarks/storage/benchmark_sample_latency_over_rpc.py +++ b/benchmarks/storage/benchmark_sample_latency_over_rpc.py @@ -19,7 +19,7 @@ import torch import torch.distributed.rpc as rpc from tensordict import TensorDict -from torchrl.data.replay_buffers.rb_prototype import RemoteTensorDictReplayBuffer +from torchrl.data.replay_buffers import RemoteTensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import RandomSampler from torchrl.data.replay_buffers.storages import ( LazyMemmapStorage, diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index b3341904ae4..297c4b856b8 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -19,10 +19,10 @@ widely used replay buffers: TensorDictReplayBuffer TensorDictPrioritizedReplayBuffer -Composable Replay Buffers (Prototype) +Composable Replay Buffers ------------------------------------- -We also provide a prototyped composable replay buffer. +We also give users the ability to compose a replay buffer using the following components: .. autosummary:: :toctree: generated/ diff --git a/examples/distributed/distributed_replay_buffer.py b/examples/distributed/distributed_replay_buffer.py index b36e11a625e..18b89921cb1 100644 --- a/examples/distributed/distributed_replay_buffer.py +++ b/examples/distributed/distributed_replay_buffer.py @@ -16,7 +16,7 @@ import torch import torch.distributed.rpc as rpc from tensordict import TensorDict -from torchrl.data.replay_buffers.rb_prototype import RemoteTensorDictReplayBuffer +from torchrl.data.replay_buffers import RemoteTensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import RandomSampler from torchrl.data.replay_buffers.storages import LazyMemmapStorage from torchrl.data.replay_buffers.utils import accept_remote_rref_invocation @@ -86,7 +86,7 @@ def train(self, iterations: int) -> None: for iteration in range(iterations): print(f"[{self.id}] Training Iteration: {iteration}") time.sleep(3) - batch = rpc.rpc_sync( + batch, _ = rpc.rpc_sync( self.replay_buffer.owner(), ReplayBufferNode.sample, args=(self.replay_buffer, 16), diff --git a/examples/dreamer/dreamer.py b/examples/dreamer/dreamer.py index 83609ff6c96..be300d0c7d7 100644 --- a/examples/dreamer/dreamer.py +++ b/examples/dreamer/dreamer.py @@ -249,7 +249,7 @@ def main(cfg: "DictConfig"): # noqa: F821 for j in range(cfg.optim_steps_per_batch): cmpt += 1 # sample from replay buffer - sampled_tensordict = replay_buffer.sample(cfg.batch_size).to( + sampled_tensordict = replay_buffer.sample(cfg.batch_size)[0].to( device, non_blocking=True ) if reward_normalizer is not None: diff --git a/test/test_rb.py b/test/test_rb.py index 74bce8d770d..0fdf9144a43 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -13,13 +13,14 @@ import torch from _utils_internal import get_available_devices from tensordict.tensordict import assert_allclose_td, TensorDict, TensorDictBase -from torchrl.data import PrioritizedReplayBuffer, ReplayBuffer, TensorDictReplayBuffer -from torchrl.data.replay_buffers import ( - rb_prototype, - samplers, +from torchrl.data import ( + PrioritizedReplayBuffer, + RemoteTensorDictReplayBuffer, + ReplayBuffer, TensorDictPrioritizedReplayBuffer, - writers, + TensorDictReplayBuffer, ) +from torchrl.data.replay_buffers import samplers, writers from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler from torchrl.data.replay_buffers.storages import ( LazyMemmapStorage, @@ -55,9 +56,9 @@ @pytest.mark.parametrize( "rb_type", [ - rb_prototype.ReplayBuffer, - rb_prototype.TensorDictReplayBuffer, - rb_prototype.RemoteTensorDictReplayBuffer, + ReplayBuffer, + TensorDictReplayBuffer, + RemoteTensorDictReplayBuffer, ], ) @pytest.mark.parametrize( @@ -82,11 +83,10 @@ def _get_rb(self, rb_type, size, sampler, writer, storage): return rb def _get_datum(self, rb_type): - if rb_type is rb_prototype.ReplayBuffer: + if rb_type is ReplayBuffer: data = torch.randint(100, (1,)) elif ( - rb_type is rb_prototype.TensorDictReplayBuffer - or rb_type is rb_prototype.RemoteTensorDictReplayBuffer + rb_type is TensorDictReplayBuffer or rb_type is RemoteTensorDictReplayBuffer ): data = TensorDict({"a": torch.randint(100, (1,))}, []) else: @@ -94,11 +94,10 @@ def _get_datum(self, rb_type): return data def _get_data(self, rb_type, size): - if rb_type is rb_prototype.ReplayBuffer: + if rb_type is ReplayBuffer: data = torch.randint(100, (size, 1)) elif ( - rb_type is rb_prototype.TensorDictReplayBuffer - or rb_type is rb_prototype.RemoteTensorDictReplayBuffer + rb_type is TensorDictReplayBuffer or rb_type is RemoteTensorDictReplayBuffer ): data = TensorDict( { @@ -248,7 +247,7 @@ def test_init(self, max_size, shape, storage): def test_prototype_prb(priority_key, contiguous, device): torch.manual_seed(0) np.random.seed(0) - rb = rb_prototype.TensorDictReplayBuffer( + rb = TensorDictReplayBuffer( sampler=samplers.PrioritizedSampler(5, alpha=0.7, beta=0.9), priority_key=priority_key, ) @@ -314,7 +313,7 @@ def test_prototype_prb(priority_key, contiguous, device): @pytest.mark.parametrize("stack", [False, True]) -def test_rb_prototype_trajectories(stack): +def test_replay_buffer_trajectories(stack): traj_td = TensorDict( {"obs": torch.randn(3, 4, 5), "actions": torch.randn(3, 4, 2)}, batch_size=[3, 4], @@ -322,7 +321,7 @@ def test_rb_prototype_trajectories(stack): if stack: traj_td = torch.stack([td.to_tensordict() for td in traj_td], 0) - rb = rb_prototype.TensorDictReplayBuffer( + rb = TensorDictReplayBuffer( sampler=samplers.PrioritizedSampler( 5, alpha=0.7, @@ -383,7 +382,7 @@ def _get_rb(self, rbtype, size, storage, prefetch): params = self._default_params_td_prb else: raise NotImplementedError(rbtype) - rb = rbtype(size=size, storage=storage, prefetch=prefetch, **params) + rb = rbtype(storage=storage, prefetch=prefetch, **params) return rb def _get_datum(self, rbtype): @@ -431,17 +430,17 @@ def test_cursor_position2(self, rbtype, storage, size, prefetch): rb.extend(batch1) # Added less data than storage max size - if size > 5: - assert rb._cursor == 5 + if size > 5 or storage is None: + assert rb._writer._cursor == 5 # Added more data than storage max size elif size < 5: - assert rb._cursor == 5 - size + assert rb._writer._cursor == 5 - size # Added as data as storage max size else: - assert rb._cursor == 0 + assert rb._writer._cursor == 0 batch2 = self._get_data(rbtype, size=size - 1) rb.extend(batch2) - assert rb._cursor == size - 1 + assert rb._writer._cursor == size - 1 def test_add(self, rbtype, storage, size, prefetch): torch.manual_seed(0) @@ -525,10 +524,10 @@ def test_prb(priority_key, contiguous, device): torch.manual_seed(0) np.random.seed(0) rb = TensorDictPrioritizedReplayBuffer( - 5, alpha=0.7, beta=0.9, priority_key=priority_key, + storage=ListStorage(5), ) td1 = TensorDict( source={ @@ -539,7 +538,7 @@ def test_prb(priority_key, contiguous, device): batch_size=[3], ).to(device) rb.extend(td1) - s = rb.sample(2) + s, _ = rb.sample(2) assert s.batch_size == torch.Size( [ 2, @@ -558,7 +557,7 @@ def test_prb(priority_key, contiguous, device): batch_size=[5], ).to(device) rb.extend(td2) - s = rb.sample(5) + s, _ = rb.sample(5) assert s.batch_size == torch.Size([5]) assert (td2[s.get("_idx").squeeze()].get("a") == s.get("a")).all() assert_allclose_td(td2[s.get("_idx").squeeze()].select("a"), s.select("a")) @@ -580,14 +579,14 @@ def test_prb(priority_key, contiguous, device): val = s.get("a")[0] idx0 = s.get("_idx")[0] - rb.update_priority(s) - s = rb.sample(5) + rb.update_tensordict_priority(s) + s, _ = rb.sample(5) assert (val == s.get("a")).sum() >= 1 torch.testing.assert_close(td2[idx0].get("a").view(1), s.get("a").unique().view(1)) # test updating values of original td td2.set_("a", torch.ones_like(td2.get("a"))) - s = rb.sample(5) + s, _ = rb.sample(5) torch.testing.assert_close(td2[idx0].get("a").view(1), s.get("a").unique().view(1)) @@ -601,16 +600,16 @@ def test_rb_trajectories(stack): traj_td = torch.stack([td.to_tensordict() for td in traj_td], 0) rb = TensorDictPrioritizedReplayBuffer( - 5, alpha=0.7, beta=0.9, priority_key="td_error", + storage=ListStorage(5), ) rb.extend(traj_td) - sampled_td = rb.sample(3) + sampled_td, _ = rb.sample(3) sampled_td.set("td_error", torch.rand(3)) - rb.update_priority(sampled_td) - sampled_td = rb.sample(3, return_weight=True) + rb.update_tensordict_priority(sampled_td) + sampled_td, _ = rb.sample(3, include_info=True) assert (sampled_td.get("_weight") > 0).all() assert sampled_td.batch_size == torch.Size([3]) @@ -630,12 +629,12 @@ def test_shared_storage_prioritized_sampler(): sampler0 = RandomSampler() sampler1 = PrioritizedSampler(max_capacity=n, alpha=0.7, beta=1.1) - rb0 = rb_prototype.ReplayBuffer( + rb0 = ReplayBuffer( storage=storage, writer=writer, sampler=sampler0, ) - rb1 = rb_prototype.ReplayBuffer( + rb1 = ReplayBuffer( storage=storage, writer=writer, sampler=sampler1, @@ -658,25 +657,8 @@ def test_shared_storage_prioritized_sampler(): assert rb1._sampler._sum_tree.query(0, 70) == 50 -def test_legacy_rb_does_not_attach(): - n = 10 - storage = LazyMemmapStorage(n) - writer = RoundRobinWriter() - sampler = RandomSampler() - rb = ReplayBuffer(storage=storage, size=n, prefetch=0) - prb = rb_prototype.ReplayBuffer( - storage=storage, - writer=writer, - sampler=sampler, - ) - - assert len(storage._attached_entities) == 1 - assert prb in storage._attached_entities - assert rb not in storage._attached_entities - - def test_append_transform(): - rb = rb_prototype.ReplayBuffer(collate_fn=lambda x: torch.stack(x, 0)) + rb = ReplayBuffer(collate_fn=lambda x: torch.stack(x, 0)) td = TensorDict( { "observation": torch.randn(2, 4, 3, 16), @@ -700,9 +682,7 @@ def test_init_transform(): -2, -1, in_keys=["observation"], out_keys=["flattened"] ) - rb = rb_prototype.ReplayBuffer( - collate_fn=lambda x: torch.stack(x, 0), transform=flatten - ) + rb = ReplayBuffer(collate_fn=lambda x: torch.stack(x, 0), transform=flatten) td = TensorDict({"observation": torch.randn(2, 4, 3, 16)}, []) rb.add(td) @@ -714,9 +694,7 @@ def test_insert_transform(): flatten = FlattenObservation( -2, -1, in_keys=["observation"], out_keys=["flattened"] ) - rb = rb_prototype.ReplayBuffer( - collate_fn=lambda x: torch.stack(x, 0), transform=flatten - ) + rb = ReplayBuffer(collate_fn=lambda x: torch.stack(x, 0), transform=flatten) td = TensorDict({"observation": torch.randn(2, 4, 3, 16, 1)}, []) rb.add(td) @@ -760,7 +738,7 @@ def test_insert_transform(): @pytest.mark.parametrize("transform", transforms) def test_smoke_replay_buffer_transform(transform): - rb = rb_prototype.ReplayBuffer( + rb = ReplayBuffer( transform=transform(in_keys="observation"), ) @@ -783,9 +761,7 @@ def test_smoke_replay_buffer_transform(transform): @pytest.mark.parametrize("transform", transforms) def test_smoke_replay_buffer_transform_no_inkeys(transform): - rb = rb_prototype.ReplayBuffer( - collate_fn=lambda x: torch.stack(x, 0), transform=transform() - ) + rb = ReplayBuffer(collate_fn=lambda x: torch.stack(x, 0), transform=transform()) td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1)}, []) rb.add(td) diff --git a/test/test_rb_distributed.py b/test/test_rb_distributed.py index 6b3b8482705..3c24a135d2f 100644 --- a/test/test_rb_distributed.py +++ b/test/test_rb_distributed.py @@ -6,7 +6,7 @@ import torch.distributed.rpc as rpc import torch.multiprocessing as mp from tensordict.tensordict import TensorDict -from torchrl.data.replay_buffers.rb_prototype import RemoteTensorDictReplayBuffer +from torchrl.data.replay_buffers import RemoteTensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import RandomSampler from torchrl.data.replay_buffers.storages import LazyMemmapStorage from torchrl.data.replay_buffers.writers import RoundRobinWriter diff --git a/test/test_trainer.py b/test/test_trainer.py index bd0c8a8ea59..c6e4a8a75d2 100644 --- a/test/test_trainer.py +++ b/test/test_trainer.py @@ -200,10 +200,11 @@ def test_rb_trainer(self, prioritized): torch.manual_seed(0) trainer = mocking_trainer() S = 100 + storage = ListStorage(S) if prioritized: - replay_buffer = TensorDictPrioritizedReplayBuffer(S, 1.1, 0.9) + replay_buffer = TensorDictPrioritizedReplayBuffer(1.1, 0.9, storage=storage) else: - replay_buffer = TensorDictReplayBuffer(S) + replay_buffer = TensorDictReplayBuffer(storage=storage) N = 9 rb_trainer = ReplayBufferTrainer(replay_buffer=replay_buffer, batch_size=N) @@ -234,11 +235,9 @@ def test_rb_trainer(self, prioritized): if prioritized: for idx in range(min(S, batch)): if idx in td_out.get("index"): - assert replay_buffer._sum_tree[idx] != 1.0 + assert replay_buffer._sampler._sum_tree[idx] != 1.0 else: - assert replay_buffer._sum_tree[idx] == 1.0 - else: - assert "index" not in td_out.keys() + assert replay_buffer._sampler._sum_tree[idx] == 1.0 @pytest.mark.parametrize( "storage_type", @@ -260,14 +259,12 @@ def test_rb_trainer_state_dict(self, prioritized, storage_type): if prioritized: replay_buffer = TensorDictPrioritizedReplayBuffer( - S, 1.1, 0.9, storage=storage, ) else: replay_buffer = TensorDictReplayBuffer( - S, storage=storage, ) @@ -295,18 +292,21 @@ def test_rb_trainer_state_dict(self, prioritized, storage_type): trainer2 = mocking_trainer() if prioritized: replay_buffer2 = TensorDictPrioritizedReplayBuffer( - S, 1.1, 0.9, storage=storage + 1.1, 0.9, storage=storage ) else: - replay_buffer2 = TensorDictReplayBuffer(S, storage=storage) + replay_buffer2 = TensorDictReplayBuffer(storage=storage) N = 9 rb_trainer2 = ReplayBufferTrainer(replay_buffer=replay_buffer2, batch_size=N) rb_trainer2.register(trainer2) sd = trainer.state_dict() trainer2.load_state_dict(sd) - assert rb_trainer2.replay_buffer.cursor > 0 - assert rb_trainer2.replay_buffer.cursor == rb_trainer.replay_buffer.cursor + assert rb_trainer2.replay_buffer._writer._cursor > 0 + assert ( + rb_trainer2.replay_buffer._writer._cursor + == rb_trainer.replay_buffer._writer._cursor + ) if storage_type == "list": assert len(rb_trainer2.replay_buffer._storage._storage) > 0 @@ -397,14 +397,12 @@ def make_storage(): storage = make_storage() if prioritized: replay_buffer = TensorDictPrioritizedReplayBuffer( - S, 1.1, 0.9, storage=storage, ) else: replay_buffer = TensorDictReplayBuffer( - S, storage=storage, ) @@ -431,14 +429,12 @@ def make_storage(): storage2 = make_storage() if prioritized: replay_buffer2 = TensorDictPrioritizedReplayBuffer( - S, 1.1, 0.9, storage=storage2, ) else: replay_buffer2 = TensorDictReplayBuffer( - S, storage=storage2, ) N = 9 diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index 44822cbfa7e..5d519234c3f 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -9,6 +9,7 @@ LazyTensorStorage, ListStorage, PrioritizedReplayBuffer, + RemoteTensorDictReplayBuffer, ReplayBuffer, Storage, TensorDictPrioritizedReplayBuffer, diff --git a/torchrl/data/replay_buffers/__init__.py b/torchrl/data/replay_buffers/__init__.py index 53e363855ef..6a9911afa74 100644 --- a/torchrl/data/replay_buffers/__init__.py +++ b/torchrl/data/replay_buffers/__init__.py @@ -5,6 +5,7 @@ from .replay_buffers import ( PrioritizedReplayBuffer, + RemoteTensorDictReplayBuffer, ReplayBuffer, TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer, diff --git a/torchrl/data/replay_buffers/rb_prototype.py b/torchrl/data/replay_buffers/rb_prototype.py deleted file mode 100644 index 8534bba46b1..00000000000 --- a/torchrl/data/replay_buffers/rb_prototype.py +++ /dev/null @@ -1,308 +0,0 @@ -import collections -import threading -from concurrent.futures import ThreadPoolExecutor -from typing import Any, Callable, List, Optional, Sequence, Tuple, Union - -import torch -from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase - -from torchrl.envs.transforms.transforms import Compose, Transform - -from .replay_buffers import pin_memory_output -from .samplers import RandomSampler, Sampler -from .storages import _get_default_collate, ListStorage, Storage -from .utils import _to_numpy, accept_remote_rref_udf_invocation, INT_CLASSES -from .writers import RoundRobinWriter, Writer - - -class ReplayBuffer: - """A generic, composable replay buffer class. - - Args: - storage (Storage, optional): the storage to be used. If none is provided - a default ListStorage with max_size of 1_000 will be created. - sampler (Sampler, optional): the sampler to be used. If none is provided - a default RandomSampler() will be used. - writer (Writer, optional): the writer to be used. If none is provided - a default RoundRobinWriter() will be used. - collate_fn (callable, optional): merges a list of samples to form a - mini-batch of Tensor(s)/outputs. Used when using batched - loading from a map-style dataset. - pin_memory (bool): whether pin_memory() should be called on the rb - samples. - prefetch (int, optional): number of next batches to be prefetched - using multithreading. - transform (Transform, optional): Transform to be executed when sample() is called. - To chain transforms use the :obj:`Compose` class. - """ - - def __init__( - self, - storage: Optional[Storage] = None, - sampler: Optional[Sampler] = None, - writer: Optional[Writer] = None, - collate_fn: Optional[Callable] = None, - pin_memory: bool = False, - prefetch: Optional[int] = None, - transform: Optional[Transform] = None, - ) -> None: - self._storage = storage if storage is not None else ListStorage(max_size=1_000) - self._storage.attach(self) - self._sampler = sampler if sampler is not None else RandomSampler() - self._writer = writer if writer is not None else RoundRobinWriter() - self._writer.register_storage(self._storage) - - self._collate_fn = ( - collate_fn - if collate_fn is not None - else _get_default_collate(self._storage) - ) - self._pin_memory = pin_memory - - self._prefetch = bool(prefetch) - self._prefetch_cap = prefetch or 0 - self._prefetch_queue = collections.deque() - if self._prefetch_cap: - self._prefetch_executor = ThreadPoolExecutor(max_workers=self._prefetch_cap) - - self._replay_lock = threading.RLock() - self._futures_lock = threading.RLock() - if transform is None: - transform = Compose() - elif not isinstance(transform, Compose): - transform = Compose(transform) - transform.eval() - self._transform = transform - - def __len__(self) -> int: - with self._replay_lock: - return len(self._storage) - - def __repr__(self) -> str: - return ( - f"{type(self).__name__}(" - f"storage={self._storage}, " - f"sampler={self._sampler}, " - f"writer={self._writer}" - ")" - ) - - @pin_memory_output - def __getitem__(self, index: Union[int, torch.Tensor]) -> Any: - index = _to_numpy(index) - with self._replay_lock: - data = self._storage[index] - - if not isinstance(index, INT_CLASSES): - data = self._collate_fn(data) - - return data - - def add(self, data: Any) -> int: - """Add a single element to the replay buffer. - - Args: - data (Any): data to be added to the replay buffer - - Returns: - index where the data lives in the replay buffer. - """ - with self._replay_lock: - index = self._writer.add(data) - self._sampler.add(index) - return index - - def extend(self, data: Sequence) -> torch.Tensor: - """Extends the replay buffer with one or more elements contained in an iterable. - - Args: - data (iterable): collection of data to be added to the replay - buffer. - - Returns: - Indices of the data aded to the replay buffer. - """ - with self._replay_lock: - index = self._writer.extend(data) - self._sampler.extend(index) - return index - - def update_priority( - self, - index: Union[int, torch.Tensor], - priority: Union[int, torch.Tensor], - ) -> None: - with self._replay_lock: - self._sampler.update_priority(index, priority) - - @pin_memory_output - def _sample(self, batch_size: int) -> Tuple[Any, dict]: - with self._replay_lock: - index, info = self._sampler.sample(self._storage, batch_size) - data = self._storage[index] - if not isinstance(index, INT_CLASSES): - data = self._collate_fn(data) - data = self._transform(data) - return data, info - - def sample(self, batch_size: int) -> Tuple[Any, dict]: - """Samples a batch of data from the replay buffer. - - Uses Sampler to sample indices, and retrieves them from Storage. - - Args: - batch_size (int): size of data to be collected. - - Returns: - A batch of data selected in the replay buffer. - """ - if not self._prefetch: - return self._sample(batch_size) - - if len(self._prefetch_queue) == 0: - ret = self._sample(batch_size) - else: - with self._futures_lock: - ret = self._prefetch_queue.popleft().result() - - with self._futures_lock: - while len(self._prefetch_queue) < self._prefetch_cap: - fut = self._prefetch_executor.submit(self._sample, batch_size) - self._prefetch_queue.append(fut) - - return ret - - def mark_update(self, index: Union[int, torch.Tensor]) -> None: - self._sampler.mark_update(index) - - def append_transform(self, transform: Transform) -> None: - """Appends transform at the end. - - Transforms are applied in order when `sample` is called. - - Args: - transform (Transform): The transform to be appended - """ - transform.eval() - self._transform.append(transform) - - def insert_transform(self, index: int, transform: Transform) -> None: - """Inserts transform. - - Transforms are executed in order when `sample` is called. - - Args: - index (int): Position to insert the transform. - transform (Transform): The transform to be appended - """ - transform.eval() - self._transform.insert(index, transform) - - -class TensorDictReplayBuffer(ReplayBuffer): - """TensorDict-specific wrapper around the ReplayBuffer class. - - Args: - priority_key (str): the key at which priority is assumed to be stored - within TensorDicts added to this ReplayBuffer. - """ - - def __init__(self, priority_key: str = "td_error", **kw) -> None: - super().__init__(**kw) - self.priority_key = priority_key - - def _get_priority(self, tensordict: TensorDictBase) -> Optional[torch.Tensor]: - if self.priority_key not in tensordict.keys(): - return self._sampler.default_priority - if tensordict.batch_dims: - tensordict = tensordict.clone(recurse=False) - tensordict.batch_size = [] - try: - priority = tensordict.get(self.priority_key).item() - except ValueError: - raise ValueError( - f"Found a priority key of size" - f" {tensordict.get(self.priority_key).shape} but expected " - f"scalar value" - ) - return priority - - def add(self, data: TensorDictBase) -> int: - index = super().add(data) - data.set("index", index) - - priority = self._get_priority(data) - if priority: - self.update_priority(index, priority) - return index - - def extend(self, tensordicts: Union[List, TensorDictBase]) -> torch.Tensor: - if isinstance(tensordicts, TensorDictBase): - if tensordicts.batch_dims > 1: - # we want the tensordict to have one dimension only. The batch size - # of the sampled tensordicts can be changed thereafter - if not isinstance(tensordicts, LazyStackedTensorDict): - tensordicts = tensordicts.clone(recurse=False) - else: - tensordicts = tensordicts.contiguous() - tensordicts.batch_size = tensordicts.batch_size[:1] - tensordicts.set( - "index", - torch.zeros( - tensordicts.shape, device=tensordicts.device, dtype=torch.int - ), - ) - - if not isinstance(tensordicts, TensorDictBase): - stacked_td = torch.stack(tensordicts, 0) - else: - stacked_td = tensordicts - - index = super().extend(stacked_td) - stacked_td.set( - "index", - torch.tensor(index, dtype=torch.int, device=stacked_td.device), - inplace=True, - ) - self.update_tensordict_priority(stacked_td) - return index - - def update_tensordict_priority(self, data: TensorDictBase) -> None: - priority = torch.tensor( - [self._get_priority(td) for td in data], - dtype=torch.float, - device=data.device, - ) - self.update_priority(data.get("index"), priority) - - def sample(self, batch_size: int, include_info: bool = False) -> TensorDictBase: - data, info = super().sample(batch_size) - if include_info: - for k, v in info.items(): - data.set(k, torch.tensor(v, device=data.device), inplace=True) - return data, info - - -@accept_remote_rref_udf_invocation -class RemoteTensorDictReplayBuffer(TensorDictReplayBuffer): - """A remote invocation friendly ReplayBuffer class. Public methods can be invoked by remote agents using `torch.rpc` or called locally as normal.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def sample(self, batch_size: int, include_info: bool = False) -> TensorDictBase: - return super().sample(batch_size, include_info) - - def add(self, data: TensorDictBase) -> int: - return super().add(data) - - def extend(self, tensordicts: Union[List, TensorDictBase]) -> torch.Tensor: - return super().extend(tensordicts) - - def update_priority( - self, index: Union[int, torch.Tensor], priority: Union[int, torch.Tensor] - ) -> None: - return super().update_priority(index, priority) - - def update_tensordict_priority(self, data: TensorDictBase) -> None: - return super().update_tensordict_priority(data) diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index 053384e6bb8..fcc28e08c58 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -4,30 +4,24 @@ # LICENSE file in the root directory of this source tree. import collections -import concurrent.futures import threading -from copy import deepcopy +from concurrent.futures import ThreadPoolExecutor from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union -import numpy as np import torch from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase -from torch import Tensor - -from torchrl._torchrl import ( - MinSegmentTreeFp32, - MinSegmentTreeFp64, - SumSegmentTreeFp32, - SumSegmentTreeFp64, -) -from torchrl.data.replay_buffers.storages import ( - _get_default_collate, - ListStorage, - Storage, -) -from torchrl.data.replay_buffers.utils import _to_numpy, _to_torch, INT_CLASSES + from torchrl.data.utils import DEVICE_TYPING +# from torchrl.envs.transforms.transforms import Compose, Transform # FIXME(kamilpi) +Compose = "Compose" +Transform = "Transform" + +from .samplers import PrioritizedSampler, RandomSampler, Sampler +from .storages import _get_default_collate, ListStorage, Storage +from .utils import _to_numpy, accept_remote_rref_udf_invocation, INT_CLASSES +from .writers import RoundRobinWriter, Writer + def stack_tensors(list_of_tensor_iterators: List) -> Tuple[torch.Tensor]: """Zips a list of iterables containing tensor-like objects and stacks the resulting lists of tensors together. @@ -86,10 +80,15 @@ def decorated_fun(self, *args, **kwargs): class ReplayBuffer: - """Circular replay buffer. + """A generic, composable replay buffer class. Args: - size (int): integer indicating the maximum size of the replay buffer. + storage (Storage, optional): the storage to be used. If none is provided + a default ListStorage with max_size of 1_000 will be created. + sampler (Sampler, optional): the sampler to be used. If none is provided + a default RandomSampler() will be used. + writer (Writer, optional): the writer to be used. If none is provided + a default RoundRobinWriter() will be used. collate_fn (callable, optional): merges a list of samples to form a mini-batch of Tensor(s)/outputs. Used when using batched loading from a map-style dataset. @@ -97,23 +96,26 @@ class ReplayBuffer: samples. prefetch (int, optional): number of next batches to be prefetched using multithreading. - storage (Storage, optional): the storage to be used. If none is provided, - a ListStorage will be instantiated. + transform (Transform, optional): Transform to be executed when sample() is called. + To chain transforms use the :obj:`Compose` class. """ def __init__( self, - size: int, + storage: Optional[Storage] = None, + sampler: Optional[Sampler] = None, + writer: Optional[Writer] = None, collate_fn: Optional[Callable] = None, pin_memory: bool = False, prefetch: Optional[int] = None, - storage: Optional[Storage] = None, - ): - if storage is None: - storage = ListStorage(size) - self._storage = storage - self._capacity = size - self._cursor = 0 + transform: Optional[Transform] = None, + ) -> None: + self._storage = storage if storage is not None else ListStorage(max_size=1_000) + self._storage.attach(self) + self._sampler = sampler if sampler is not None else RandomSampler() + self._writer = writer if writer is not None else RoundRobinWriter() + self._writer.register_storage(self._storage) + self._collate_fn = ( collate_fn if collate_fn is not None @@ -121,50 +123,58 @@ def __init__( ) self._pin_memory = pin_memory - self._prefetch = prefetch is not None and prefetch > 0 - self._prefetch_cap = prefetch if prefetch is not None else 0 - self._prefetch_fut = collections.deque() - if self._prefetch_cap > 0: - self._prefetch_executor = concurrent.futures.ThreadPoolExecutor( - max_workers=self._prefetch_cap - ) + self._prefetch = bool(prefetch) + self._prefetch_cap = prefetch or 0 + self._prefetch_queue = collections.deque() + if self._prefetch_cap: + self._prefetch_executor = ThreadPoolExecutor(max_workers=self._prefetch_cap) self._replay_lock = threading.RLock() - self._future_lock = threading.RLock() + self._futures_lock = threading.RLock() + from torchrl.envs.transforms.transforms import Compose # FIXME(kamilpi) + + if transform is None: + transform = Compose() + elif not isinstance(transform, Compose): + transform = Compose(transform) + transform.eval() + self._transform = transform def __len__(self) -> int: with self._replay_lock: return len(self._storage) + def __repr__(self) -> str: + return ( + f"{type(self).__name__}(" + f"storage={self._storage}, " + f"sampler={self._sampler}, " + f"writer={self._writer}" + ")" + ) + @pin_memory_output - def __getitem__(self, index: Union[int, Tensor]) -> Any: + def __getitem__(self, index: Union[int, torch.Tensor]) -> Any: index = _to_numpy(index) - with self._replay_lock: data = self._storage[index] if not isinstance(index, INT_CLASSES): data = self._collate_fn(data) + return data def state_dict(self) -> Dict[str, Any]: return { "_storage": self._storage.state_dict(), - "_cursor": self._cursor, + "_sampler": self._sampler.state_dict(), + "_writer": self._writer.state_dict(), } def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self._storage.load_state_dict(state_dict["_storage"]) - self._cursor = state_dict["_cursor"] - - @property - def capacity(self) -> int: - return self._capacity - - @property - def cursor(self) -> int: - with self._replay_lock: - return self._cursor + self._sampler.load_state_dict(state_dict["_sampler"]) + self._writer.load_state_dict(state_dict["_writer"]) def add(self, data: Any) -> int: """Add a single element to the replay buffer. @@ -176,12 +186,11 @@ def add(self, data: Any) -> int: index where the data lives in the replay buffer. """ with self._replay_lock: - ret = self._cursor - self._storage[self._cursor] = data - self._cursor = (self._cursor + 1) % self._capacity - return ret + index = self._writer.add(data) + self._sampler.add(index) + return index - def extend(self, data: Sequence[Any]): + def extend(self, data: Sequence) -> torch.Tensor: """Extends the replay buffer with one or more elements contained in an iterable. Args: @@ -190,66 +199,82 @@ def extend(self, data: Sequence[Any]): Returns: Indices of the data aded to the replay buffer. - """ - if not len(data): - raise Exception("extending with empty data is not supported") with self._replay_lock: - batch_size = len(data) - if self._cursor + batch_size <= self._capacity: - index = np.arange(self._cursor, self._cursor + batch_size) - self._cursor = (self._cursor + batch_size) % self._capacity - else: - d = self._capacity - self._cursor - index = np.empty(batch_size, dtype=np.int64) - index[:d] = np.arange(self._cursor, self._capacity) - index[d:] = np.arange(batch_size - d) - self._cursor = batch_size - d - # storage must convert the data to the appropriate format if needed - self._storage[index] = data - return index + index = self._writer.extend(data) + self._sampler.extend(index) + return index - @pin_memory_output - def _sample(self, batch_size: int) -> Any: - index = torch.randint(0, len(self._storage), (batch_size,)) + def update_priority( + self, + index: Union[int, torch.Tensor], + priority: Union[int, torch.Tensor], + ) -> None: + with self._replay_lock: + self._sampler.update_priority(index, priority) + @pin_memory_output + def _sample(self, batch_size: int) -> Tuple[Any, dict]: with self._replay_lock: + index, info = self._sampler.sample(self._storage, batch_size) data = self._storage[index] + if not isinstance(index, INT_CLASSES): + data = self._collate_fn(data) + data = self._transform(data) + return data, info - data = self._collate_fn(data) - return data - - def sample(self, batch_size: int) -> Any: + def sample(self, batch_size: int) -> Tuple[Any, dict]: """Samples a batch of data from the replay buffer. + Uses Sampler to sample indices, and retrieves them from Storage. + Args: - batch_size (int): float of data to be collected. + batch_size (int): size of data to be collected. Returns: - A batch of data randomly selected in the replay buffer. - + A batch of data selected in the replay buffer. """ if not self._prefetch: return self._sample(batch_size) - with self._future_lock: - if len(self._prefetch_fut) == 0: - ret = self._sample(batch_size) - else: - ret = self._prefetch_fut.popleft().result() + if len(self._prefetch_queue) == 0: + ret = self._sample(batch_size) + else: + with self._futures_lock: + ret = self._prefetch_queue.popleft().result() - while len(self._prefetch_fut) < self._prefetch_cap: + with self._futures_lock: + while len(self._prefetch_queue) < self._prefetch_cap: fut = self._prefetch_executor.submit(self._sample, batch_size) - self._prefetch_fut.append(fut) + self._prefetch_queue.append(fut) - return ret + return ret - def __repr__(self) -> str: - string = ( - f"{type(self).__name__}(size={len(self)}, " - f"pin_memory={self._pin_memory})" - ) - return string + def mark_update(self, index: Union[int, torch.Tensor]) -> None: + self._sampler.mark_update(index) + + def append_transform(self, transform: Transform) -> None: + """Appends transform at the end. + + Transforms are applied in order when `sample` is called. + + Args: + transform (Transform): The transform to be appended + """ + transform.eval() + self._transform.append(transform) + + def insert_transform(self, index: int, transform: Transform) -> None: + """Inserts transform. + + Transforms are executed in order when `sample` is called. + + Args: + index (int): Position to insert the transform. + transform (Transform): The transform to be appended + """ + transform.eval() + self._transform.insert(index, transform) class PrioritizedReplayBuffer(ReplayBuffer): @@ -261,12 +286,14 @@ class PrioritizedReplayBuffer(ReplayBuffer): (https://arxiv.org/abs/1511.05952) Args: - size (int): integer indicating the maximum size of the replay buffer. alpha (float): exponent α determines how much prioritization is used, with α = 0 corresponding to the uniform case. beta (float): importance sampling negative exponent. eps (float): delta added to the priorities to ensure that the buffer does not contain null priorities. + dtype (torch.dtype): type of the data. Can be torch.float or torch.double. + storage (Storage, optional): the storage to be used. If none is provided + a default ListStorage with max_size of 1_000 will be created. collate_fn (callable, optional): merges a list of samples to form a mini-batch of Tensor(s)/outputs. Used when using batched loading from a map-style dataset. @@ -274,286 +301,138 @@ class PrioritizedReplayBuffer(ReplayBuffer): samples. prefetch (int, optional): number of next batches to be prefetched using multithreading. - storage (Storage, optional): the storage to be used. If none is provided, - a ListStorage will be instantiated. + transform (Transform, optional): Transform to be executed when sample() is called. + To chain transforms use the :obj:`Compose` class. """ def __init__( self, - size: int, alpha: float, beta: float, eps: float = 1e-8, dtype: torch.dtype = torch.float, - collate_fn=None, + storage: Optional[Storage] = None, + collate_fn: Optional[Callable] = None, pin_memory: bool = False, prefetch: Optional[int] = None, - storage: Optional[Storage] = None, + transform: Optional[Transform] = None, ) -> None: + if storage is None: + storage = ListStorage(max_size=1_000) + sampler = PrioritizedSampler(storage.max_size, alpha, beta, eps, dtype) super(PrioritizedReplayBuffer, self).__init__( - size, - collate_fn, - pin_memory, - prefetch, storage=storage, + sampler=sampler, + collate_fn=collate_fn, + pin_memory=pin_memory, + prefetch=prefetch, + transform=transform, ) - if alpha <= 0: - raise ValueError( - f"alpha must be strictly greater than 0, got alpha={alpha}" - ) - if beta < 0: - raise ValueError(f"beta must be greater or equal to 0, got beta={beta}") - - self._alpha = alpha - self._beta = beta - self._eps = eps - if dtype in (torch.float, torch.FloatType, torch.float32): - self._sum_tree = SumSegmentTreeFp32(size) - self._min_tree = MinSegmentTreeFp32(size) - elif dtype in (torch.double, torch.DoubleTensor, torch.float64): - self._sum_tree = SumSegmentTreeFp64(size) - self._min_tree = MinSegmentTreeFp64(size) - else: - raise NotImplementedError( - f"dtype {dtype} not supported by PrioritizedReplayBuffer" - ) - self._max_priority = 1.0 - def state_dict(self) -> Dict[str, Any]: - state_dict = super().state_dict() - state_dict["_sum_tree"] = deepcopy(self._sum_tree) - state_dict["_min_tree"] = deepcopy(self._min_tree) - state_dict["_max_priority"] = self._max_priority - return state_dict - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - self._sum_tree = state_dict.pop("_sum_tree") - self._min_tree = state_dict.pop("_min_tree") - self._max_priority = state_dict.pop("_max_priority") - super().load_state_dict(state_dict) - - @pin_memory_output - def __getitem__(self, index: Union[int, Tensor]) -> Any: - index = _to_numpy(index) - - with self._replay_lock: - p_min = self._min_tree.query(0, self._capacity) - if p_min <= 0: - raise ValueError(f"p_min must be greater than 0, got p_min={p_min}") - data = self._storage[index] - if isinstance(index, INT_CLASSES): - weight = np.array(self._sum_tree[index]) - else: - weight = self._sum_tree[index] - - if not isinstance(index, INT_CLASSES): - data = self._collate_fn(data) - # weight = np.power(weight / (p_min + self._eps), -self._beta) - weight = np.power(weight / p_min, -self._beta) - # x = first_field(data) - # if isinstance(x, torch.Tensor): - device = data.device if hasattr(data, "device") else torch.device("cpu") - weight = _to_torch(weight, device, self._pin_memory) - return data, weight - - @property - def alpha(self) -> float: - return self._alpha - - @property - def beta(self) -> float: - return self._beta - - @property - def eps(self) -> float: - return self._eps - - @property - def max_priority(self) -> float: - with self._replay_lock: - return self._max_priority +class TensorDictReplayBuffer(ReplayBuffer): + """TensorDict-specific wrapper around the ReplayBuffer class. - @property - def _default_priority(self) -> float: - return (self._max_priority + self._eps) ** self._alpha + Args: + priority_key (str): the key at which priority is assumed to be stored + within TensorDicts added to this ReplayBuffer. + """ - def _add_or_extend( - self, - data: Any, - priority: Optional[torch.Tensor] = None, - do_add: bool = True, - ) -> torch.Tensor: - if priority is not None: - priority = _to_numpy(priority) - max_priority = np.max(priority) - with self._replay_lock: - self._max_priority = max(self._max_priority, max_priority) - priority = np.power(priority + self._eps, self._alpha) - else: - with self._replay_lock: - priority = self._default_priority + def __init__(self, priority_key: str = "td_error", **kw) -> None: + super().__init__(**kw) + self.priority_key = priority_key - if do_add: - index = super(PrioritizedReplayBuffer, self).add(data) - else: - index = super(PrioritizedReplayBuffer, self).extend(data) - - if not ( - isinstance(priority, float) - or len(priority) == 1 - or len(priority) == len(index) - ): - raise RuntimeError( - "priority should be a scalar or an iterable of the same " - "length as index" + def _get_priority(self, tensordict: TensorDictBase) -> Optional[torch.Tensor]: + if self.priority_key not in tensordict.keys(): + return self._sampler.default_priority + if tensordict.batch_dims: + tensordict = tensordict.clone(recurse=False) + tensordict.batch_size = [] + try: + priority = tensordict.get(self.priority_key).item() + except ValueError: + raise ValueError( + f"Found a priority key of size" + f" {tensordict.get(self.priority_key).shape} but expected " + f"scalar value" ) + return priority - with self._replay_lock: - self._sum_tree[index] = priority - self._min_tree[index] = priority + def add(self, data: TensorDictBase) -> int: + index = super().add(data) + data.set("index", index) + priority = self._get_priority(data) + if priority: + self.update_priority(index, priority) return index - def add(self, data: Any, priority: Optional[torch.Tensor] = None) -> torch.Tensor: - return self._add_or_extend(data, priority, True) - - def extend( - self, data: Sequence, priority: Optional[torch.Tensor] = None - ) -> torch.Tensor: - return self._add_or_extend(data, priority, False) - - @pin_memory_output - def _sample(self, batch_size: int) -> Tuple[Any, torch.Tensor, torch.Tensor]: - with self._replay_lock: - p_sum = self._sum_tree.query(0, self._capacity) - p_min = self._min_tree.query(0, self._capacity) - if p_sum <= 0: - raise RuntimeError("negative p_sum") - if p_min <= 0: - raise RuntimeError("negative p_min") - mass = np.random.uniform(0.0, p_sum, size=batch_size) - index = self._sum_tree.scan_lower_bound(mass) - if not isinstance(index, torch.Tensor): - index = torch.tensor(index) - if not index.ndimension(): - index = index.reshape((1,)) - index.clamp_max_(len(self._storage) - 1) - data = self._storage[index] - weight = self._sum_tree[index] - - data = self._collate_fn(data) - - # Importance sampling weight formula: - # w_i = (p_i / sum(p) * N) ^ (-beta) - # weight_i = w_i / max(w) - # weight_i = (p_i / sum(p) * N) ^ (-beta) / - # ((min(p) / sum(p) * N) ^ (-beta)) - # weight_i = ((p_i / sum(p) * N) / (min(p) / sum(p) * N)) ^ (-beta) - # weight_i = (p_i / min(p)) ^ (-beta) - # weight = np.power(weight / (p_min + self._eps), -self._beta) - weight = np.power(weight / p_min, -self._beta) - - # x = first_field(data) # avoid calling tree.flatten - # if isinstance(x, torch.Tensor): - device = data.device if hasattr(data, "device") else torch.device("cpu") - weight = _to_torch(weight, device, self._pin_memory) - return data, weight, index - - def sample(self, batch_size: int) -> Tuple[Any, np.ndarray, torch.Tensor]: - """Gathers a batch of data according to the non-uniform multinomial distribution with weights computed with the provided priorities of each input. - - Args: - batch_size (int): float of data to be collected. - - Returns: a random sample from the replay buffer. - - """ - if not self._prefetch: - return self._sample(batch_size) - - with self._future_lock: - if len(self._prefetch_fut) == 0: - ret = self._sample(batch_size) - else: - ret = self._prefetch_fut.popleft().result() - - while len(self._prefetch_fut) < self._prefetch_cap: - fut = self._prefetch_executor.submit(self._sample, batch_size) - self._prefetch_fut.append(fut) - - return ret - - def update_priority( - self, index: Union[int, Tensor], priority: Union[float, Tensor] - ) -> None: - """Updates the priority of the data pointed by the index. - - Args: - index (int or torch.Tensor): indexes of the priorities to be - updated. - priority (Number or torch.Tensor): new priorities of the - indexed elements - + def extend(self, tensordicts: Union[List, TensorDictBase]) -> torch.Tensor: + if isinstance(tensordicts, TensorDictBase): + if tensordicts.batch_dims > 1: + # we want the tensordict to have one dimension only. The batch size + # of the sampled tensordicts can be changed thereafter + if not isinstance(tensordicts, LazyStackedTensorDict): + tensordicts = tensordicts.clone(recurse=False) + else: + tensordicts = tensordicts.contiguous() + tensordicts.batch_size = tensordicts.batch_size[:1] + tensordicts.set( + "index", + torch.zeros( + tensordicts.shape, device=tensordicts.device, dtype=torch.int + ), + ) - """ - if isinstance(index, INT_CLASSES): - if not isinstance(priority, float): - if len(priority) != 1: - raise RuntimeError( - f"priority length should be 1, got {len(priority)}" - ) - priority = priority.item() + if not isinstance(tensordicts, TensorDictBase): + stacked_td = torch.stack(tensordicts, 0) else: - if not ( - isinstance(priority, float) - or len(priority) == 1 - or len(index) == len(priority) - ): - raise RuntimeError( - "priority should be a number or an iterable of the same " - "length as index" - ) - index = _to_numpy(index) - priority = _to_numpy(priority) - - with self._replay_lock: - self._max_priority = max(self._max_priority, np.max(priority)) - priority = np.power(priority + self._eps, self._alpha) - self._sum_tree[index] = priority - self._min_tree[index] = priority + stacked_td = tensordicts + index = super().extend(stacked_td) + stacked_td.set( + "index", + torch.tensor(index, dtype=torch.int, device=stacked_td.device), + inplace=True, + ) + self.update_tensordict_priority(stacked_td) + return index -class TensorDictReplayBuffer(ReplayBuffer): - """TensorDict-specific wrapper around the ReplayBuffer class.""" + def update_tensordict_priority(self, data: TensorDictBase) -> None: + priority = torch.tensor( + [self._get_priority(td) for td in data], + dtype=torch.float, + device=data.device, + ) + self.update_priority(data.get("index"), priority) - def __init__( - self, - size: int, - collate_fn: Optional[Callable] = None, - pin_memory: bool = False, - prefetch: Optional[int] = None, - storage: Optional[Storage] = None, - ): - super().__init__(size, collate_fn, pin_memory, prefetch, storage=storage) + def sample(self, batch_size: int, include_info: bool = False) -> Tuple[TensorDictBase, dict]: + data, info = super().sample(batch_size) + if include_info: + for k, v in info.items(): + data.set(k, torch.tensor(v, device=data.device), inplace=True) + return data, info -class TensorDictPrioritizedReplayBuffer(PrioritizedReplayBuffer): +class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer): """TensorDict-specific wrapper around the PrioritizedReplayBuffer class. This class returns tensordicts with a new key "index" that represents - the index of each element in the replay buffer. It also facilitates the - call to the 'update_priority' method, as it only requires for the + the index of each element in the replay buffer. It also provides the + 'update_tensordict_priority' method that only requires for the tensordict to be passed to it with its new priority value. Args: - size (int): integer indicating the maximum size of the replay buffer. - alpha (flaot): exponent α determines how much prioritization is + alpha (float): exponent α determines how much prioritization is used, with α = 0 corresponding to the uniform case. beta (float): importance sampling negative exponent. priority_key (str, optional): key where the priority value can be found in the stored tensordicts. Default is :obj:`"td_error"` eps (float, optional): delta added to the priorities to ensure that the buffer does not contain null priorities. + dtype (torch.dtype): type of the data. Can be torch.float or torch.double. + storage (Storage, optional): the storage to be used. If none is provided + a default ListStorage with max_size of 1_000 will be created. collate_fn (callable, optional): merges a list of samples to form a mini-batch of Tensor(s)/outputs. Used when using batched loading from a map-style dataset. @@ -561,131 +440,59 @@ class TensorDictPrioritizedReplayBuffer(PrioritizedReplayBuffer): the rb samples. Default is :obj:`False`. prefetch (int, optional): number of next batches to be prefetched using multithreading. - storage (Storage, optional): the storage to be used. If none is provided, - a ListStorage will be instantiated. + transform (Transform, optional): Transform to be executed when sample() is called. + To chain transforms use the :obj:`Compose` class. """ def __init__( self, - size: int, alpha: float, beta: float, priority_key: str = "td_error", eps: float = 1e-8, - collate_fn=None, + storage: Optional[Storage] = None, + collate_fn: Optional[Callable] = None, pin_memory: bool = False, prefetch: Optional[int] = None, - storage: Optional[Storage] = None, + transform: Optional[Transform] = None, ) -> None: + if storage is None: + storage = ListStorage(max_size=1_000) + sampler = PrioritizedSampler(storage.max_size, alpha, beta, eps) super(TensorDictPrioritizedReplayBuffer, self).__init__( - size=size, - alpha=alpha, - beta=beta, - eps=eps, + priority_key=priority_key, + storage=storage, + sampler=sampler, collate_fn=collate_fn, pin_memory=pin_memory, prefetch=prefetch, - storage=storage, - ) - self.priority_key = priority_key - - def _get_priority(self, tensordict: TensorDictBase) -> torch.Tensor: - if self.priority_key in tensordict.keys(): - if tensordict.batch_dims: - tensordict = tensordict.clone(recurse=False) - tensordict.batch_size = [] - try: - priority = tensordict.get(self.priority_key).item() - except ValueError: - raise ValueError( - f"Found a priority key of size" - f" {tensordict.get(self.priority_key).shape} but expected " - f"scalar value" - ) - else: - priority = self._default_priority - return priority - - def add(self, tensordict: TensorDictBase) -> torch.Tensor: - priority = self._get_priority(tensordict) - index = super().add(tensordict, priority) - tensordict.set("index", index) - return index - - def extend( - self, tensordicts: Union[TensorDictBase, List[TensorDictBase]] - ) -> torch.Tensor: - if isinstance(tensordicts, TensorDictBase): - if self.priority_key in tensordicts.keys(): - priorities = tensordicts.get(self.priority_key) - else: - priorities = None - - if tensordicts.batch_dims > 1: - # we want the tensordict to have one dimension only. The batch size - # of the sampled tensordicts can be changed thereafter - if not isinstance(tensordicts, LazyStackedTensorDict): - tensordicts = tensordicts.clone(recurse=False) - else: - tensordicts = tensordicts.contiguous() - tensordicts.batch_size = tensordicts.batch_size[:1] - tensordicts.set( - "index", - torch.zeros( - tensordicts.shape, - device=tensordicts.device, - dtype=torch.int, - ), - ) - else: - priorities = [self._get_priority(td) for td in tensordicts] - - if not isinstance(tensordicts, TensorDictBase): - stacked_td = torch.stack(tensordicts, 0) - else: - stacked_td = tensordicts - idx = super().extend(tensordicts, priorities) - stacked_td.set( - "index", - torch.tensor(idx, dtype=torch.int, device=stacked_td.device), - inplace=True, + transform=transform, ) - return idx - def update_priority(self, tensordict: TensorDictBase) -> None: - """Updates the priorities of the tensordicts stored in the replay buffer. - Args: - tensordict: tensordict with key-value pairs 'self.priority_key' - and 'index'. +@accept_remote_rref_udf_invocation +class RemoteTensorDictReplayBuffer(TensorDictReplayBuffer): + """A remote invocation friendly ReplayBuffer class. Public methods can be invoked by remote agents using `torch.rpc` or called locally as normal.""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) - """ - priority = tensordict.get(self.priority_key) - if (priority < 0).any(): - raise RuntimeError( - f"Priority must be a positive value, got " - f"{(priority < 0).sum()} negative priority values." - ) - return super().update_priority(tensordict.get("index"), priority=priority) + def sample(self, batch_size: int, include_info: bool = False) -> TensorDictBase: + return super().sample(batch_size, include_info) - def sample(self, size: int, return_weight: bool = False) -> TensorDictBase: - """Gather a batch of tensordicts according to the non-uniform multinomial distribution with weights computed with the priority_key of each input tensordict. + def add(self, data: TensorDictBase) -> int: + return super().add(data) - Args: - size (int): size of the batch to be returned - return_weight (bool, optional): if True, a '_weight' key will be - written in the output tensordict that indicates the weight - of the selected items + def extend(self, tensordicts: Union[List, TensorDictBase]) -> torch.Tensor: + return super().extend(tensordicts) - Returns: - Stack of tensordicts + def update_priority( + self, index: Union[int, torch.Tensor], priority: Union[int, torch.Tensor] + ) -> None: + return super().update_priority(index, priority) - """ - td, weight, _ = super(TensorDictPrioritizedReplayBuffer, self).sample(size) - if return_weight: - td.set("_weight", weight) - return td + def update_tensordict_priority(self, data: TensorDictBase) -> None: + return super().update_tensordict_priority(data) class InPlaceSampler: diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 2bb159d0b8d..a4402c1cc04 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod -from typing import Any, Tuple, Union +from copy import deepcopy +from typing import Any, Dict, Tuple, Union import numpy as np import torch @@ -40,6 +41,12 @@ def mark_update(self, index: Union[int, torch.Tensor]) -> None: def default_priority(self) -> float: return 1.0 + def state_dict(self) -> Dict[str, Any]: + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + return + class RandomSampler(Sampler): """A uniformly random sampler for composable replay buffers.""" @@ -92,7 +99,7 @@ def __init__( self._min_tree = MinSegmentTreeFp64(self._max_capacity) else: raise NotImplementedError( - f"dtype {dtype} not supported by PrioritizedReplayBuffer" + f"dtype {dtype} not supported by PrioritizedSampler" ) self._max_priority = 1.0 @@ -189,3 +196,21 @@ def update_priority( def mark_update(self, index: Union[int, torch.Tensor]) -> None: self.update_priority(index, self.default_priority) + + def state_dict(self) -> Dict[str, Any]: + return { + "_alpha": self._alpha, + "_beta": self._beta, + "_eps": self._eps, + "_max_priority": self._max_priority, + "_sum_tree": deepcopy(self._sum_tree), + "_min_tree": deepcopy(self._min_tree), + } + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self._alpha = state_dict["_alpha"] + self._beta = state_dict["_beta"] + self._eps = state_dict["_eps"] + self._max_priority = state_dict["_max_priority"] + self._sum_tree = state_dict.pop("_sum_tree") + self._min_tree = state_dict.pop("_min_tree") diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py index f058dd32f2d..6c9c62a6117 100644 --- a/torchrl/data/replay_buffers/writers.py +++ b/torchrl/data/replay_buffers/writers.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Sequence +from typing import Any, Dict, Sequence import numpy as np import torch @@ -26,6 +26,12 @@ def extend(self, data: Sequence) -> torch.Tensor: """Inserts a series of data points at appropriate indices, and returns a tensor containing the indices.""" raise NotImplementedError + def state_dict(self) -> Dict[str, Any]: + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + return + class RoundRobinWriter(Writer): """A RoundRobin Writer class for composable replay buffers.""" @@ -55,3 +61,9 @@ def extend(self, data: Sequence) -> torch.Tensor: # storage must convert the data to the appropriate format if needed self._storage[index] = data return index + + def state_dict(self) -> Dict[str, Any]: + return {"_cursor": self._cursor} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self._cursor = state_dict["_cursor"] diff --git a/torchrl/trainers/helpers/replay_buffer.py b/torchrl/trainers/helpers/replay_buffer.py index c7bbee12d82..4f9c48bf4b9 100644 --- a/torchrl/trainers/helpers/replay_buffer.py +++ b/torchrl/trainers/helpers/replay_buffer.py @@ -7,11 +7,8 @@ import torch -from torchrl.data import ( - ReplayBuffer, - TensorDictPrioritizedReplayBuffer, - TensorDictReplayBuffer, -) +from torchrl.data import ReplayBuffer, TensorDictReplayBuffer +from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler from torchrl.data.replay_buffers.storages import LazyMemmapStorage from torchrl.data.utils import DEVICE_TYPING @@ -22,29 +19,23 @@ def make_replay_buffer( """Builds a replay buffer using the config built from ReplayArgsConfig.""" device = torch.device(device) if not cfg.prb: - buffer = TensorDictReplayBuffer( - cfg.buffer_size, - pin_memory=device != torch.device("cpu"), - prefetch=cfg.buffer_prefetch, - storage=LazyMemmapStorage( - cfg.buffer_size, - scratch_dir=cfg.buffer_scratch_dir, - # device=device, # when using prefetch, this can overload the GPU memory - ), - ) + sampler = RandomSampler() else: - buffer = TensorDictPrioritizedReplayBuffer( - cfg.buffer_size, + sampler = PrioritizedSampler( + max_capacity=cfg.buffer_size, alpha=0.7, beta=0.5, - pin_memory=device != torch.device("cpu"), - prefetch=cfg.buffer_prefetch, - storage=LazyMemmapStorage( - cfg.buffer_size, - scratch_dir=cfg.buffer_scratch_dir, - # device=device, # when using prefetch, this can overload the GPU memory - ), ) + buffer = TensorDictReplayBuffer( + storage=LazyMemmapStorage( + cfg.buffer_size, + scratch_dir=cfg.buffer_scratch_dir, + # device=device, # when using prefetch, this can overload the GPU memory + ), + sampler=sampler, + pin_memory=device != torch.device("cpu"), + prefetch=cfg.buffer_prefetch, + ) return buffer diff --git a/torchrl/trainers/helpers/trainers.py b/torchrl/trainers/helpers/trainers.py index 73691357c8d..5ffbb435410 100644 --- a/torchrl/trainers/helpers/trainers.py +++ b/torchrl/trainers/helpers/trainers.py @@ -129,7 +129,7 @@ def make_trainer( >>> recorder = env_proof >>> target_net_updater = None >>> policy_exploration = EGreedyWrapper(policy) - >>> replay_buffer = TensorDictReplayBuffer(1000) + >>> replay_buffer = TensorDictReplayBuffer() >>> dir = tempfile.gettempdir() >>> logger = TensorboardLogger(exp_name=dir) >>> trainer = make_trainer(collector, loss_module, recorder, target_net_updater, policy_exploration, diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index 48de6a77719..4017898559f 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -600,7 +600,7 @@ class ReplayBufferTrainer(TrainerHookBase): """Replay buffer hook provider. Args: - replay_buffer (ReplayBuffer): replay buffer to be used. + replay_buffer (TensorDictReplayBuffer): replay buffer to be used. batch_size (int): batch size when sampling data from the latest collection or from the replay buffer. memmap (bool, optional): if True, a memmap tensordict is created. @@ -629,7 +629,7 @@ class ReplayBufferTrainer(TrainerHookBase): def __init__( self, - replay_buffer: ReplayBuffer, + replay_buffer: TensorDictReplayBuffer, batch_size: int, memmap: bool = False, device: DEVICE_TYPING = "cpu", @@ -669,12 +669,11 @@ def extend(self, batch: TensorDictBase) -> TensorDictBase: self.replay_buffer.extend(batch) def sample(self, batch: TensorDictBase) -> TensorDictBase: - sample = self.replay_buffer.sample(self.batch_size) + sample, _ = self.replay_buffer.sample(self.batch_size) return sample.to(self.device, non_blocking=True) def update_priority(self, batch: TensorDictBase) -> None: - if isinstance(self.replay_buffer, TensorDictPrioritizedReplayBuffer): - self.replay_buffer.update_priority(batch) + self.replay_buffer.update_tensordict_priority(batch) def state_dict(self) -> Dict[str, Any]: return { diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index 67cdaab8e4b..fa0c3a384fb 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -49,12 +49,9 @@ from tensordict.nn import TensorDictModule from torch import nn, optim from torchrl.collectors import MultiaSyncDataCollector -from torchrl.data import ( - CompositeSpec, - TensorDictPrioritizedReplayBuffer, - TensorDictReplayBuffer, -) +from torchrl.data import CompositeSpec, TensorDictReplayBuffer from torchrl.data.postprocs import MultiStep +from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler from torchrl.data.replay_buffers.storages import LazyMemmapStorage from torchrl.envs import ( CatTensors, @@ -383,29 +380,23 @@ def make_recorder(actor_model_explore, stats): def make_replay_buffer(make_replay_buffer=3): if prb: - replay_buffer = TensorDictPrioritizedReplayBuffer( - buffer_size, + sampler = PrioritizedSampler( + max_capacity=buffer_size, alpha=0.7, beta=0.5, - pin_memory=False, - prefetch=make_replay_buffer, - storage=LazyMemmapStorage( - buffer_size, - scratch_dir=buffer_scratch_dir, - device=device, - ), ) else: - replay_buffer = TensorDictReplayBuffer( + sampler = RandomSampler() + replay_buffer = TensorDictReplayBuffer( + storage=LazyMemmapStorage( buffer_size, - pin_memory=False, - prefetch=make_replay_buffer, - storage=LazyMemmapStorage( - buffer_size, - scratch_dir=buffer_scratch_dir, - device=device, - ), - ) + scratch_dir=buffer_scratch_dir, + device=device, + ), + sampler=sampler, + pin_memory=False, + prefetch=make_replay_buffer, + ) return replay_buffer @@ -654,7 +645,7 @@ def make_replay_buffer(make_replay_buffer=3): if collected_frames >= init_random_frames: for _ in range(optim_steps_per_batch): # sample from replay buffer - sampled_tensordict = replay_buffer.sample(batch_size).clone() + sampled_tensordict = replay_buffer.sample(batch_size)[0].clone() # compute loss for qnet and backprop with hold_out_net(actor): @@ -696,7 +687,7 @@ def make_replay_buffer(make_replay_buffer=3): # update priority if prb: - replay_buffer.update_priority(sampled_tensordict) + replay_buffer.update_tensordict_priority(sampled_tensordict) rewards.append( (i, tensordict["reward"].mean().item() / norm_factor_training / frame_skip) @@ -898,7 +889,7 @@ def make_replay_buffer(make_replay_buffer=3): if collected_frames >= init_random_frames: for _ in range(optim_steps_per_batch): # sample from replay buffer - sampled_tensordict = replay_buffer.sample(batch_size_traj) + sampled_tensordict, _ = replay_buffer.sample(batch_size_traj) # reset the batch size temporarily, and exclude index whose shape is incompatible with the new size index = sampled_tensordict.get("index") sampled_tensordict.exclude("index", inplace=True) @@ -958,7 +949,7 @@ def make_replay_buffer(make_replay_buffer=3): sampled_tensordict["td_error"] = advantage.detach().pow(2).mean(1) sampled_tensordict["index"] = index if prb: - replay_buffer.update_priority(sampled_tensordict) + replay_buffer.update_tensordict_priority(sampled_tensordict) rewards.append( (i, tensordict["reward"].mean().item() / norm_factor_training / frame_skip) diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index 78003cbd9b2..d20976f1ecb 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -312,7 +312,6 @@ def make_model(): # shape. This storage will be instantiated later. replay_buffer = TensorDictReplayBuffer( - buffer_size, storage=LazyMemmapStorage(buffer_size), prefetch=n_optim, ) @@ -385,7 +384,7 @@ def make_model(): if sum(frames) > init_random_frames: for _ in range(n_optim): # sample from the RB and send to device - sampled_data = replay_buffer.sample(batch_size) + sampled_data, _ = replay_buffer.sample(batch_size) sampled_data = sampled_data.to(device, non_blocking=True) # collect data from RB @@ -556,8 +555,7 @@ def make_model(): max_size = frames_per_batch // n_workers replay_buffer = TensorDictReplayBuffer( - -(-buffer_size // max_size), - storage=LazyMemmapStorage(buffer_size), + storage=LazyMemmapStorage(-(-buffer_size // max_size)), prefetch=n_optim, ) @@ -618,7 +616,7 @@ def make_model(): if sum(frames) > init_random_frames: for _ in range(n_optim): - sampled_data = replay_buffer.sample(batch_size // max_size) + sampled_data, _ = replay_buffer.sample(batch_size // max_size) sampled_data = sampled_data.clone().to(device, non_blocking=True) reward = sampled_data["reward"] diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py index f7ac96f35e1..9f6c67959ea 100644 --- a/tutorials/sphinx-tutorials/torchrl_demo.py +++ b/tutorials/sphinx-tutorials/torchrl_demo.py @@ -224,7 +224,7 @@ ############################################################################### -rb = ReplayBuffer(100, collate_fn=lambda x: x) +rb = ReplayBuffer(collate_fn=lambda x: x) rb.add(1) rb.sample(1) @@ -235,7 +235,7 @@ ############################################################################### -rb = PrioritizedReplayBuffer(100, alpha=0.7, beta=1.1, collate_fn=lambda x: x) +rb = PrioritizedReplayBuffer(alpha=0.7, beta=1.1, collate_fn=lambda x: x) rb.add(1) rb.sample(1) rb.update_priority(1, 0.5) @@ -244,7 +244,7 @@ # Here are examples of using a replaybuffer with tensordicts. collate_fn = torch.stack -rb = ReplayBuffer(100, collate_fn=collate_fn) +rb = ReplayBuffer(collate_fn=collate_fn) rb.add(TensorDict({"a": torch.randn(3)}, batch_size=[])) len(rb) @@ -253,18 +253,16 @@ rb.extend(TensorDict({"a": torch.randn(2, 3)}, batch_size=[2])) print(len(rb)) print(rb.sample(10)) -print(rb.sample(2).contiguous()) +print(rb.sample(2)[0].contiguous()) ############################################################################### torch.manual_seed(0) from torchrl.data import TensorDictPrioritizedReplayBuffer -rb = TensorDictPrioritizedReplayBuffer( - 100, alpha=0.7, beta=1.1, priority_key="td_error" -) +rb = TensorDictPrioritizedReplayBuffer(alpha=0.7, beta=1.1, priority_key="td_error") rb.extend(TensorDict({"a": torch.randn(2, 3)}, batch_size=[2])) -tensordict_sample = rb.sample(2).contiguous() +tensordict_sample = rb.sample(2)[0].contiguous() tensordict_sample ############################################################################### @@ -274,9 +272,9 @@ ############################################################################### tensordict_sample["td_error"] = torch.rand(2) -rb.update_priority(tensordict_sample) +rb.update_tensordict_priority(tensordict_sample) -for i, val in enumerate(rb._sum_tree): +for i, val in enumerate(rb._sampler._sum_tree): print(i, val) if i == len(rb): break From 206ac97dae3856009126da798456c8915a38144e Mon Sep 17 00:00:00 2001 From: Kamil Piechowiak Date: Thu, 5 Jan 2023 17:04:02 +0000 Subject: [PATCH 2/5] change returned values of ReplayBuffer sample method --- README.md | 6 +-- .../distributed/distributed_replay_buffer.py | 2 +- examples/dreamer/dreamer.py | 2 +- test/test_rb.py | 30 +++++++------- torchrl/data/replay_buffers/replay_buffers.py | 40 ++++++++++--------- torchrl/trainers/trainers.py | 2 +- tutorials/sphinx-tutorials/coding_ddpg.py | 4 +- tutorials/sphinx-tutorials/coding_dqn.py | 4 +- tutorials/sphinx-tutorials/torchrl_demo.py | 4 +- 9 files changed, 49 insertions(+), 45 deletions(-) diff --git a/README.md b/README.md index 4436803c910..4ec18685919 100644 --- a/README.md +++ b/README.md @@ -77,9 +77,9 @@ Here's another example of an off-policy training loop in TorchRL (assuming that - replay_buffer.add((obs, next_obs, action, log_prob, reward, done)) + replay_buffer.add(tensordict) for j in range(num_optim_steps): - - obs, next_obs, action, hidden_state, reward, done = replay_buffer.sample(batch_size)[0] + - obs, next_obs, action, hidden_state, reward, done = replay_buffer.sample(batch_size) - loss = loss_fn(obs, next_obs, action, hidden_state, reward, done) - + tensordict = replay_buffer.sample(batch_size)[0] + + tensordict = replay_buffer.sample(batch_size) + loss = loss_fn(tensordict) loss.backward() optim.step() @@ -334,7 +334,7 @@ The associated [`SafeModule` class](torchrl/modules/tensordict_module/common.py) ```python from torchrl.objectives import DQNLoss loss_module = DQNLoss(value_network=value_network, gamma=0.99) - tensordict = replay_buffer.sample(batch_size)[0] + tensordict = replay_buffer.sample(batch_size) loss = loss_module(tensordict) ``` diff --git a/examples/distributed/distributed_replay_buffer.py b/examples/distributed/distributed_replay_buffer.py index 18b89921cb1..c228e416a7f 100644 --- a/examples/distributed/distributed_replay_buffer.py +++ b/examples/distributed/distributed_replay_buffer.py @@ -86,7 +86,7 @@ def train(self, iterations: int) -> None: for iteration in range(iterations): print(f"[{self.id}] Training Iteration: {iteration}") time.sleep(3) - batch, _ = rpc.rpc_sync( + batch = rpc.rpc_sync( self.replay_buffer.owner(), ReplayBufferNode.sample, args=(self.replay_buffer, 16), diff --git a/examples/dreamer/dreamer.py b/examples/dreamer/dreamer.py index be300d0c7d7..83609ff6c96 100644 --- a/examples/dreamer/dreamer.py +++ b/examples/dreamer/dreamer.py @@ -249,7 +249,7 @@ def main(cfg: "DictConfig"): # noqa: F821 for j in range(cfg.optim_steps_per_batch): cmpt += 1 # sample from replay buffer - sampled_tensordict = replay_buffer.sample(cfg.batch_size)[0].to( + sampled_tensordict = replay_buffer.sample(cfg.batch_size).to( device, non_blocking=True ) if reward_normalizer is not None: diff --git a/test/test_rb.py b/test/test_rb.py index a2a442875fb..1d2967e2b6d 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -311,7 +311,7 @@ def test_prototype_prb(priority_key, contiguous, device): batch_size=[3], ).to(device) rb.extend(td1) - s, _ = rb.sample(2) + s = rb.sample(2) assert s.batch_size == torch.Size( [ 2, @@ -330,7 +330,7 @@ def test_prototype_prb(priority_key, contiguous, device): batch_size=[5], ).to(device) rb.extend(td2) - s, _ = rb.sample(5) + s = rb.sample(5) assert s.batch_size == torch.Size([5]) assert (td2[s.get("_idx").squeeze()].get("a") == s.get("a")).all() assert_allclose_td(td2[s.get("_idx").squeeze()].select("a"), s.select("a")) @@ -353,13 +353,13 @@ def test_prototype_prb(priority_key, contiguous, device): idx0 = s.get("_idx")[0] rb.update_tensordict_priority(s) - s, _ = rb.sample(5) + s = rb.sample(5) assert (val == s.get("a")).sum() >= 1 torch.testing.assert_close(td2[idx0].get("a").view(1), s.get("a").unique().view(1)) # test updating values of original td td2.set_("a", torch.ones_like(td2.get("a"))) - s, _ = rb.sample(5) + s = rb.sample(5) torch.testing.assert_close(td2[idx0].get("a").view(1), s.get("a").unique().view(1)) @@ -381,10 +381,10 @@ def test_replay_buffer_trajectories(stack): priority_key="td_error", ) rb.extend(traj_td) - sampled_td, _ = rb.sample(3) + sampled_td = rb.sample(3) sampled_td.set("td_error", torch.rand(3)) rb.update_tensordict_priority(sampled_td) - sampled_td, _ = rb.sample(3, include_info=True) + sampled_td = rb.sample(3, include_info=True) assert (sampled_td.get("_weight") > 0).all() assert sampled_td.batch_size == torch.Size([3]) @@ -589,7 +589,7 @@ def test_prb(priority_key, contiguous, device): batch_size=[3], ).to(device) rb.extend(td1) - s, _ = rb.sample(2) + s = rb.sample(2) assert s.batch_size == torch.Size( [ 2, @@ -608,7 +608,7 @@ def test_prb(priority_key, contiguous, device): batch_size=[5], ).to(device) rb.extend(td2) - s, _ = rb.sample(5) + s = rb.sample(5) assert s.batch_size == torch.Size([5]) assert (td2[s.get("_idx").squeeze()].get("a") == s.get("a")).all() assert_allclose_td(td2[s.get("_idx").squeeze()].select("a"), s.select("a")) @@ -631,13 +631,13 @@ def test_prb(priority_key, contiguous, device): idx0 = s.get("_idx")[0] rb.update_tensordict_priority(s) - s, _ = rb.sample(5) + s = rb.sample(5) assert (val == s.get("a")).sum() >= 1 torch.testing.assert_close(td2[idx0].get("a").view(1), s.get("a").unique().view(1)) # test updating values of original td td2.set_("a", torch.ones_like(td2.get("a"))) - s, _ = rb.sample(5) + s = rb.sample(5) torch.testing.assert_close(td2[idx0].get("a").view(1), s.get("a").unique().view(1)) @@ -657,10 +657,10 @@ def test_rb_trajectories(stack): storage=ListStorage(5), ) rb.extend(traj_td) - sampled_td, _ = rb.sample(3) + sampled_td = rb.sample(3) sampled_td.set("td_error", torch.rand(3)) rb.update_tensordict_priority(sampled_td) - sampled_td, _ = rb.sample(3, include_info=True) + sampled_td = rb.sample(3, include_info=True) assert (sampled_td.get("_weight") > 0).all() assert sampled_td.batch_size == torch.Size([3]) @@ -724,7 +724,7 @@ def test_append_transform(): rb.append_transform(flatten) - sampled, _ = rb.sample(1) + sampled = rb.sample(1) assert sampled.get("observation_cat").shape[-1] == 32 @@ -737,7 +737,7 @@ def test_init_transform(): td = TensorDict({"observation": torch.randn(2, 4, 3, 16)}, []) rb.add(td) - sampled, _ = rb.sample(1) + sampled = rb.sample(1) assert sampled.get("flattened").shape[-1] == 48 @@ -751,7 +751,7 @@ def test_insert_transform(): rb.insert_transform(0, SqueezeTransform(-1, in_keys=["observation"])) - sampled, _ = rb.sample(1) + sampled = rb.sample(1) assert sampled.get("flattened").shape[-1] == 48 with pytest.raises(ValueError): diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index 07ab8e393c0..eaaa305294f 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -219,7 +219,7 @@ def _sample(self, batch_size: int) -> Tuple[Any, dict]: data = self._transform(data) return data, info - def sample(self, batch_size: int) -> Tuple[Any, dict]: + def sample(self, batch_size: int, return_info: bool = False) -> Any: """Samples a batch of data from the replay buffer. Uses Sampler to sample indices, and retrieves them from Storage. @@ -231,20 +231,22 @@ def sample(self, batch_size: int) -> Tuple[Any, dict]: A batch of data selected in the replay buffer. """ if not self._prefetch: - return self._sample(batch_size) - - if len(self._prefetch_queue) == 0: ret = self._sample(batch_size) else: - with self._futures_lock: - ret = self._prefetch_queue.popleft().result() - - with self._futures_lock: - while len(self._prefetch_queue) < self._prefetch_cap: - fut = self._prefetch_executor.submit(self._sample, batch_size) - self._prefetch_queue.append(fut) + if len(self._prefetch_queue) == 0: + ret = self._sample(batch_size) + else: + with self._futures_lock: + ret = self._prefetch_queue.popleft().result() - return ret + with self._futures_lock: + while len(self._prefetch_queue) < self._prefetch_cap: + fut = self._prefetch_executor.submit(self._sample, batch_size) + self._prefetch_queue.append(fut) + + if return_info: + return ret + return ret[0] def mark_update(self, index: Union[int, torch.Tensor]) -> None: self._sampler.mark_update(index) @@ -403,13 +405,15 @@ def update_tensordict_priority(self, data: TensorDictBase) -> None: self.update_priority(data.get("index"), priority) def sample( - self, batch_size: int, include_info: bool = False - ) -> Tuple[TensorDictBase, dict]: - data, info = super().sample(batch_size) + self, batch_size: int, include_info: bool = False, return_info: bool = False + ) -> TensorDictBase: + data, info = super().sample(batch_size, return_info=True) if include_info: for k, v in info.items(): data.set(k, torch.tensor(v, device=data.device), inplace=True) - return data, info + if return_info: + return data, info + return data class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer): @@ -475,8 +479,8 @@ class RemoteTensorDictReplayBuffer(TensorDictReplayBuffer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def sample(self, batch_size: int, include_info: bool = False) -> TensorDictBase: - return super().sample(batch_size, include_info) + def sample(self, batch_size: int, include_info: bool = False, return_info: bool = False) -> TensorDictBase: + return super().sample(batch_size, include_info, return_info) def add(self, data: TensorDictBase) -> int: return super().add(data) diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index cd56b189742..fe701bc676c 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -669,7 +669,7 @@ def extend(self, batch: TensorDictBase) -> TensorDictBase: self.replay_buffer.extend(batch) def sample(self, batch: TensorDictBase) -> TensorDictBase: - sample, _ = self.replay_buffer.sample(self.batch_size) + sample = self.replay_buffer.sample(self.batch_size) return sample.to(self.device, non_blocking=True) def update_priority(self, batch: TensorDictBase) -> None: diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index 3f029d87f97..37f8aaa520b 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -645,7 +645,7 @@ def make_replay_buffer(make_replay_buffer=3): if collected_frames >= init_random_frames: for _ in range(optim_steps_per_batch): # sample from replay buffer - sampled_tensordict = replay_buffer.sample(batch_size)[0].clone() + sampled_tensordict = replay_buffer.sample(batch_size).clone() # compute loss for qnet and backprop with hold_out_net(actor): @@ -889,7 +889,7 @@ def make_replay_buffer(make_replay_buffer=3): if collected_frames >= init_random_frames: for _ in range(optim_steps_per_batch): # sample from replay buffer - sampled_tensordict, _ = replay_buffer.sample(batch_size_traj) + sampled_tensordict = replay_buffer.sample(batch_size_traj) # reset the batch size temporarily, and exclude index whose shape is incompatible with the new size index = sampled_tensordict.get("index") sampled_tensordict.exclude("index", inplace=True) diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index 49645f9c286..5998ec21127 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -384,7 +384,7 @@ def make_model(): if sum(frames) > init_random_frames: for _ in range(n_optim): # sample from the RB and send to device - sampled_data, _ = replay_buffer.sample(batch_size) + sampled_data = replay_buffer.sample(batch_size) sampled_data = sampled_data.to(device, non_blocking=True) # collect data from RB @@ -616,7 +616,7 @@ def make_model(): if sum(frames) > init_random_frames: for _ in range(n_optim): - sampled_data, _ = replay_buffer.sample(batch_size // max_size) + sampled_data = replay_buffer.sample(batch_size // max_size) sampled_data = sampled_data.clone().to(device, non_blocking=True) reward = sampled_data["reward"] diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py index 4ad7799bf4a..f977b97a7ef 100644 --- a/tutorials/sphinx-tutorials/torchrl_demo.py +++ b/tutorials/sphinx-tutorials/torchrl_demo.py @@ -253,7 +253,7 @@ rb.extend(TensorDict({"a": torch.randn(2, 3)}, batch_size=[2])) print(len(rb)) print(rb.sample(10)) -print(rb.sample(2)[0].contiguous()) +print(rb.sample(2).contiguous()) ############################################################################### @@ -262,7 +262,7 @@ rb = TensorDictPrioritizedReplayBuffer(alpha=0.7, beta=1.1, priority_key="td_error") rb.extend(TensorDict({"a": torch.randn(2, 3)}, batch_size=[2])) -tensordict_sample = rb.sample(2)[0].contiguous() +tensordict_sample = rb.sample(2).contiguous() tensordict_sample ############################################################################### From f721ff6848951f6f1e9a3d4951c34ad82a633793 Mon Sep 17 00:00:00 2001 From: Kamil Piechowiak Date: Fri, 6 Jan 2023 13:04:44 +0000 Subject: [PATCH 3/5] Fix sample return values indexing in rpc --- benchmarks/storage/benchmark_sample_latency_over_rpc.py | 6 +++--- test/test_rb_distributed.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/benchmarks/storage/benchmark_sample_latency_over_rpc.py b/benchmarks/storage/benchmark_sample_latency_over_rpc.py index 669a3795831..4dbbafcfcc1 100644 --- a/benchmarks/storage/benchmark_sample_latency_over_rpc.py +++ b/benchmarks/storage/benchmark_sample_latency_over_rpc.py @@ -87,10 +87,10 @@ def train(self, batch_size: int) -> None: if self._ret is None: self._ret = ret else: - self._ret[0].update_(ret[0]) + self._ret.update_(ret) # make sure the content is read - self._ret[0]["observation"] + 1 - self._ret[0]["next_observation"] + 1 + self._ret["observation"] + 1 + self._ret["next_observation"] + 1 return timeit.default_timer() - start_time def _create_replay_buffer(self) -> rpc.RRef: diff --git a/test/test_rb_distributed.py b/test/test_rb_distributed.py index 3c24a135d2f..30bbb9418d9 100644 --- a/test/test_rb_distributed.py +++ b/test/test_rb_distributed.py @@ -45,7 +45,7 @@ def sample_from_buffer_remotely_returns_correct_tensordict_test(rank, name, worl if name == "TRAINER": buffer = _construct_buffer("BUFFER") _, inserted = _add_random_tensor_dict_to_buffer(buffer) - sampled, _ = _sample_from_buffer(buffer, 1) + sampled = _sample_from_buffer(buffer, 1) assert type(sampled) is type(inserted) is TensorDict assert (sampled == inserted)["a"].item() From de0a1e61056e1e6b3a5759bfa53b6cb31ace9b6b Mon Sep 17 00:00:00 2001 From: Kamil Piechowiak Date: Fri, 6 Jan 2023 14:21:40 +0000 Subject: [PATCH 4/5] add docstrings to sample method in replay buffers --- torchrl/data/replay_buffers/replay_buffers.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index eaaa305294f..9ba9560e63d 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -226,9 +226,12 @@ def sample(self, batch_size: int, return_info: bool = False) -> Any: Args: batch_size (int): size of data to be collected. + return_info (bool): whether to return info. If True, the result + is a tuple (data, info). If False, the result is the data. Returns: A batch of data selected in the replay buffer. + A tuple containing this batch and info if return_info flag is set to True. """ if not self._prefetch: ret = self._sample(batch_size) @@ -407,6 +410,20 @@ def update_tensordict_priority(self, data: TensorDictBase) -> None: def sample( self, batch_size: int, include_info: bool = False, return_info: bool = False ) -> TensorDictBase: + """Samples a batch of data from the replay buffer. + + Uses Sampler to sample indices, and retrieves them from Storage. + + Args: + batch_size (int): size of data to be collected. + include_info (bool): whether to add info to the returned tensordict. + return_info (bool): whether to return info. If True, the result + is a tuple (data, info). If False, the result is the data. + + Returns: + A tensordict containing a batch of data selected in the replay buffer. + A tuple containing this tensordict and info if return_info flag is set to True. + """ data, info = super().sample(batch_size, return_info=True) if include_info: for k, v in info.items(): From d55306371d300ad9fa353388456c626b3e269de1 Mon Sep 17 00:00:00 2001 From: Kamil Piechowiak Date: Fri, 6 Jan 2023 14:47:53 +0000 Subject: [PATCH 5/5] fix linter errors in replay_buffer.py and trainers.py --- docs/source/reference/data.rst | 3 --- torchrl/data/replay_buffers/replay_buffers.py | 18 ++++++++++-------- torchrl/trainers/__init__.py | 1 - torchrl/trainers/trainers.py | 6 +----- 4 files changed, 11 insertions(+), 17 deletions(-) diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index aa1b704ebc0..695ef5398c8 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -30,9 +30,6 @@ We also give users the ability to compose a replay buffer using the following co .. currentmodule:: torchrl.data.replay_buffers - torchrl.data.replay_buffers.rb_prototype.ReplayBuffer - torchrl.data.replay_buffers.rb_prototype.TensorDictReplayBuffer - torchrl.data.replay_buffers.rb_prototype.RemoteTensorDictReplayBuffer torchrl.data.replay_buffers.samplers.Sampler torchrl.data.replay_buffers.samplers.RandomSampler torchrl.data.replay_buffers.samplers.PrioritizedSampler diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index 9ba9560e63d..47f878e11af 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -104,7 +104,7 @@ def __init__( collate_fn: Optional[Callable] = None, pin_memory: bool = False, prefetch: Optional[int] = None, - transform: Optional["Transform"] = None, + transform: Optional["Transform"] = None, # noqa-F821 ) -> None: self._storage = storage if storage is not None else ListStorage(max_size=1_000) self._storage.attach(self) @@ -246,7 +246,7 @@ def sample(self, batch_size: int, return_info: bool = False) -> Any: while len(self._prefetch_queue) < self._prefetch_cap: fut = self._prefetch_executor.submit(self._sample, batch_size) self._prefetch_queue.append(fut) - + if return_info: return ret return ret[0] @@ -254,7 +254,7 @@ def sample(self, batch_size: int, return_info: bool = False) -> Any: def mark_update(self, index: Union[int, torch.Tensor]) -> None: self._sampler.mark_update(index) - def append_transform(self, transform: "Transform") -> None: + def append_transform(self, transform: "Transform") -> None: # noqa-F821 """Appends transform at the end. Transforms are applied in order when `sample` is called. @@ -265,7 +265,7 @@ def append_transform(self, transform: "Transform") -> None: transform.eval() self._transform.append(transform) - def insert_transform(self, index: int, transform: "Transform") -> None: + def insert_transform(self, index: int, transform: "Transform") -> None: # noqa-F821 """Inserts transform. Transforms are executed in order when `sample` is called. @@ -316,7 +316,7 @@ def __init__( collate_fn: Optional[Callable] = None, pin_memory: bool = False, prefetch: Optional[int] = None, - transform: Optional["Transform"] = None, + transform: Optional["Transform"] = None, # noqa-F821 ) -> None: if storage is None: storage = ListStorage(max_size=1_000) @@ -408,7 +408,7 @@ def update_tensordict_priority(self, data: TensorDictBase) -> None: self.update_priority(data.get("index"), priority) def sample( - self, batch_size: int, include_info: bool = False, return_info: bool = False + self, batch_size: int, include_info: bool = False, return_info: bool = False ) -> TensorDictBase: """Samples a batch of data from the replay buffer. @@ -473,7 +473,7 @@ def __init__( collate_fn: Optional[Callable] = None, pin_memory: bool = False, prefetch: Optional[int] = None, - transform: Optional["Transform"] = None, + transform: Optional["Transform"] = None, # noqa-F821 ) -> None: if storage is None: storage = ListStorage(max_size=1_000) @@ -496,7 +496,9 @@ class RemoteTensorDictReplayBuffer(TensorDictReplayBuffer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def sample(self, batch_size: int, include_info: bool = False, return_info: bool = False) -> TensorDictBase: + def sample( + self, batch_size: int, include_info: bool = False, return_info: bool = False + ) -> TensorDictBase: return super().sample(batch_size, include_info, return_info) def add(self, data: TensorDictBase) -> int: diff --git a/torchrl/trainers/__init__.py b/torchrl/trainers/__init__.py index 035cdc60b27..5621a7f3993 100644 --- a/torchrl/trainers/__init__.py +++ b/torchrl/trainers/__init__.py @@ -11,7 +11,6 @@ mask_batch, OptimizerHook, Recorder, - ReplayBuffer, ReplayBufferTrainer, RewardNormalizer, SelectKeys, diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index fe701bc676c..a69edc2aac2 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -21,11 +21,7 @@ from torchrl._utils import _CKPT_BACKEND, KeyDependentDefaultDict from torchrl.collectors.collectors import _DataCollector -from torchrl.data import ( - ReplayBuffer, - TensorDictPrioritizedReplayBuffer, - TensorDictReplayBuffer, -) +from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.common import EnvBase from torchrl.envs.utils import set_exploration_mode