From ec1aa7ecf6d2af1c987e24abaefeb905aa9c1794 Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 23 Nov 2022 18:32:57 +0100 Subject: [PATCH 1/6] bugfix --- torchrl/data/replay_buffers/storages.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index b37cd2fb589..3c81e5e29b6 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -241,7 +241,10 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: for key, tensor in data.items(): if isinstance(tensor, TensorDictBase): out[key] = ( - tensor.expand(self.max_size).clone().to(self.device).zero_() + tensor.expand( + self.max_size, + *tensor.shape, + ).clone().to(self.device).zero_() ) else: out[key] = torch.empty( From 6c97358ebd9114cef0b11af9de2e0ffc45537154 Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 23 Nov 2022 18:48:47 +0100 Subject: [PATCH 2/6] formatting --- torchrl/data/replay_buffers/storages.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 3c81e5e29b6..799fc868493 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -7,11 +7,11 @@ import os from collections import OrderedDict from copy import copy -from typing import Any, Sequence, Union, Dict +from typing import Any, Dict, Sequence, Union import torch from tensordict.memmap import MemmapTensor -from tensordict.tensordict import TensorDictBase, TensorDict +from tensordict.tensordict import TensorDict, TensorDictBase from torchrl._utils import _CKPT_BACKEND from torchrl.data.replay_buffers.utils import INT_CLASSES @@ -244,7 +244,10 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: tensor.expand( self.max_size, *tensor.shape, - ).clone().to(self.device).zero_() + ) + .clone() + .to(self.device) + .zero_() ) else: out[key] = torch.empty( From 761acb5e796ec8c8aa78f5c76c29883d6fe71e90 Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 23 Nov 2022 20:37:13 +0100 Subject: [PATCH 3/6] reset --- torchrl/data/replay_buffers/storages.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 799fc868493..46e4a143221 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -240,7 +240,7 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: print("The storage is being created: ") for key, tensor in data.items(): if isinstance(tensor, TensorDictBase): - out[key] = ( + out[key] = ( tensor.expand( self.max_size, *tensor.shape, From 1bb40432391b49c0a54b5e1ae2f66353dad5d4eb Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 23 Nov 2022 20:59:11 +0100 Subject: [PATCH 4/6] LazyTensorStorage fix init --- torchrl/data/replay_buffers/storages.py | 31 +++++++------------------ 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 46e4a143221..1c3819a11f5 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -7,11 +7,11 @@ import os from collections import OrderedDict from copy import copy -from typing import Any, Dict, Sequence, Union +from typing import Any, Sequence, Union, Dict import torch from tensordict.memmap import MemmapTensor -from tensordict.tensordict import TensorDict, TensorDictBase +from tensordict.tensordict import TensorDictBase, TensorDict from torchrl._utils import _CKPT_BACKEND from torchrl.data.replay_buffers.utils import INT_CLASSES @@ -236,26 +236,13 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: dtype=data.dtype, ) else: - out = TensorDict({}, [self.max_size, *data.shape]) - print("The storage is being created: ") - for key, tensor in data.items(): - if isinstance(tensor, TensorDictBase): - out[key] = ( - tensor.expand( - self.max_size, - *tensor.shape, - ) - .clone() - .to(self.device) - .zero_() - ) - else: - out[key] = torch.empty( - self.max_size, - *tensor.shape, - device=self.device, - dtype=tensor.dtype, - ) + out = ( + data.expand(self.max_size, *data.shape) + .to_tensordict() + .zero_() + .clone() + .to(self.device) + ) self._storage = out self.initialized = True From a75f138124fcb21350b7823b10a645ccade0cffb Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 23 Nov 2022 21:05:56 +0100 Subject: [PATCH 5/6] format --- torchrl/data/replay_buffers/storages.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 1c3819a11f5..670076bc9f7 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -7,11 +7,11 @@ import os from collections import OrderedDict from copy import copy -from typing import Any, Sequence, Union, Dict +from typing import Any, Dict, Sequence, Union import torch from tensordict.memmap import MemmapTensor -from tensordict.tensordict import TensorDictBase, TensorDict +from tensordict.tensordict import TensorDict, TensorDictBase from torchrl._utils import _CKPT_BACKEND from torchrl.data.replay_buffers.utils import INT_CLASSES @@ -238,10 +238,10 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: else: out = ( data.expand(self.max_size, *data.shape) - .to_tensordict() - .zero_() - .clone() - .to(self.device) + .to_tensordict() + .zero_() + .clone() + .to(self.device) ) self._storage = out From 74108579a2c7712d8f21d0a2d93d837ae4f906ab Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 24 Nov 2022 08:09:07 +0100 Subject: [PATCH 6/6] add testing --- test/test_rb.py | 62 ++++++++++++++++++++++++++++++++++--------------- 1 file changed, 43 insertions(+), 19 deletions(-) diff --git a/test/test_rb.py b/test/test_rb.py index 0ac6228fcc2..c95bcfdf936 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -12,12 +12,8 @@ import pytest import torch from _utils_internal import get_available_devices -from tensordict.tensordict import assert_allclose_td, TensorDictBase, TensorDict -from torchrl.data import ( - PrioritizedReplayBuffer, - ReplayBuffer, - TensorDictReplayBuffer, -) +from tensordict.tensordict import assert_allclose_td, TensorDict, TensorDictBase +from torchrl.data import PrioritizedReplayBuffer, ReplayBuffer, TensorDictReplayBuffer from torchrl.data.replay_buffers import ( rb_prototype, samplers, @@ -32,25 +28,25 @@ ) from torchrl.data.replay_buffers.writers import RoundRobinWriter from torchrl.envs.transforms.transforms import ( - CatTensors, - FlattenObservation, - SqueezeTransform, - ToTensorImage, - RewardClipping, BinarizeReward, - Resize, - CenterCrop, - UnsqueezeTransform, - GrayScale, - ObservationNorm, CatFrames, - RewardScaling, - DoubleToFloat, - VecNorm, + CatTensors, + CenterCrop, DiscreteActionProjection, + DoubleToFloat, FiniteTensorDictCheck, + FlattenObservation, + GrayScale, gSDENoise, + ObservationNorm, PinMemoryTransform, + Resize, + RewardClipping, + RewardScaling, + SqueezeTransform, + ToTensorImage, + UnsqueezeTransform, + VecNorm, ) _has_tv = importlib.util.find_spec("torchvision") is not None @@ -198,6 +194,34 @@ def test_index(self, rb_type, sampler, writer, storage, size): assert b +@pytest.mark.parametrize("max_size", [1000]) +@pytest.mark.parametrize("shape", [[3, 4]]) +@pytest.mark.parametrize("storage", [LazyTensorStorage, LazyMemmapStorage]) +class TestStorages: + def _get_nested_td(self, shape): + nested_td = TensorDict( + { + "key1": torch.ones(*shape), + "key2": torch.ones(*shape), + "next": TensorDict( + { + "key1": torch.ones(*shape), + "key2": torch.ones(*shape), + }, + shape, + ), + }, + shape, + ) + return nested_td + + def test_init(self, max_size, shape, storage): + td = self._get_nested_td(shape) + mystorage = storage(max_size=max_size) + mystorage._init(td) + assert mystorage._storage.shape == (max_size, *shape) + + @pytest.mark.parametrize("priority_key", ["pk", "td_error"]) @pytest.mark.parametrize("contiguous", [True, False]) @pytest.mark.parametrize("device", get_available_devices())