From a2b745ad864632136990c05d316f0fc6043c883a Mon Sep 17 00:00:00 2001
From: Alan Schelten
Date: Mon, 21 Nov 2022 13:40:45 +0100
Subject: [PATCH] Prototype transformed replay buffers using extra class.

---
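[Note, not part of the commit: a minimal usage sketch of the prototype API
introduced below. It assumes the prototype ReplayBuffer from
torchrl.data.replay_buffers.rb_prototype, whose sample() returns a
(data, info) pair as exercised by the new tests; key names and shapes are
taken from test_transformed_replay_buffer.]

    import torch
    from tensordict.tensordict import TensorDict

    from torchrl.data.replay_buffers.rb_prototype import ReplayBuffer
    from torchrl.envs.transforms.transforms import (
        FlattenObservation,
        TransformedReplayBuffer,
    )

    rb = ReplayBuffer()
    rb.add(TensorDict({"observation": torch.randn(2, 4, 3, 16)}, []))

    # The transform runs on read (sample), not on write (add/extend).
    trb = TransformedReplayBuffer(
        rb,
        FlattenObservation(-2, -1, in_keys=["observation"], out_keys=["flattened"]),
    )
    sampled, info = trb.sample(1)
    assert sampled.get("flattened").shape[-1] == 48  # last two dims, 3 * 16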
 test/test_rb.py                       | 98 +++++++++++++++++++++++++++
 torchrl/envs/transforms/transforms.py | 58 ++++++++++++----
 2 files changed, 143 insertions(+), 13 deletions(-)

diff --git a/test/test_rb.py b/test/test_rb.py
index 36915407450..f44937d3bd2 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -4,12 +4,14 @@
 # LICENSE file in the root directory of this source tree.

 import argparse
+from functools import partial

 import numpy as np
 import pytest
 import torch
 from _utils_internal import get_available_devices
 from tensordict.tensordict import assert_allclose_td, TensorDictBase, TensorDict
+from torchrl._utils import prod
 from torchrl.data import (
     PrioritizedReplayBuffer,
     ReplayBuffer,
@@ -28,6 +30,7 @@
     ListStorage,
 )
 from torchrl.data.replay_buffers.writers import RoundRobinWriter
+from torchrl.envs.transforms.transforms import *


 collate_fn_dict = {
@@ -183,6 +186,101 @@ def test_index(self, rb_type, sampler, writer, storage, size):
         b = b.all()
         assert b

+    def test_transformed_replay_buffer(self, rb_type, sampler, writer, storage, size):
+        torch.manual_seed(0)
+        rb = self._get_rb(
+            rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size
+        )
+        td = TensorDict({"observation": torch.randn(2, 4, 3, 16)}, [])
+        rb.add(td)
+        sampled, _ = rb.sample(1)
+        # assert sampled.get("observation").shape[0] == 4
+        flatten = FlattenObservation(
+            -2, -1, in_keys=["observation"], out_keys=["flattened"]
+        )
+        transformed_rb = TransformedReplayBuffer(rb, flatten)
+
+        sampled, _ = transformed_rb.sample(1)
+        assert sampled.get("flattened").shape[-1] == 48
+
+    def test_transformed_replay_buffer_append_transform(
+        self, rb_type, sampler, writer, storage, size
+    ):
+        rb = self._get_rb(
+            rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size
+        )
+        td = TensorDict({"observation": torch.randn(2, 4, 3, 16)}, [])
+        rb.add(td)
+        transformed_rb = TransformedReplayBuffer(
+            rb,
+            FlattenObservation(-2, -1, in_keys=["observation"], out_keys=["flattened"]),
+        )
+        transformed_rb.append_transform(
+            FlattenObservation(-2, -1, in_keys=["flattened"])
+        )
+        sampled, _ = transformed_rb.sample(1)
+        assert sampled.get("flattened").shape[-1] == 192
+
+    def test_transformed_replay_buffer_insert_transform(
+        self, rb_type, sampler, writer, storage, size
+    ):
+        rb = self._get_rb(
+            rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size
+        )
+        td = TensorDict({"observation": torch.randn(2, 4, 1, 3, 16)}, [])
+        rb.add(td)
+        transformed_rb = TransformedReplayBuffer(
+            rb,
+            FlattenObservation(
+                -2, -1, in_keys=["observation"], out_keys=["transformed"]
+            ),
+        )
+        transformed_rb.append_transform(
+            FlattenObservation(-2, -1, in_keys=["transformed"])
+        )
+        sampled, _ = transformed_rb.sample(1)
+        assert sampled.get("transformed").shape[-1] == 48
+        transformed_rb.insert_transform(
+            0, SqueezeTransform(-3, in_keys=["observation"])
+        )
+        sampled, _ = transformed_rb.sample(1)
+        assert sampled.get("transformed").shape[-1] == 192
+
+
+transforms = [
+    ToTensorImage,
+    partial(RewardClipping, clamp_min=0.1, clamp_max=0.9),
+    BinarizeReward,
+    partial(Resize, w=2, h=2),
+    partial(CenterCrop, w=1),
+    FlattenObservation,
+    partial(UnsqueezeTransform, unsqueeze_dim=0),  # FIXME: Assumes existence of parent, cannot call Compose
+    partial(SqueezeTransform, squeeze_dim=0),  # FIXME: Assumes existence of parent, cannot call Compose
+    GrayScale,
+    ObservationNorm,
+    CatFrames,
+    partial(RewardScaling, loc=1, scale=2),
+    FiniteTensorDictCheck,
+    DoubleToFloat,
+    CatTensors,
+    partial(DiscreteActionProjection, max_n=1, m=1),
+    NoopResetEnv,
+    TensorDictPrimer,
+    PinMemoryTransform,
+    gSDENoise,
+    VecNorm,
+]
+
+
+@pytest.mark.parametrize("transform", transforms)
+def test_transformed_replay_buffer_smoke(transform):
+    rb = ReplayBuffer()
+    td = TensorDict({"observation": torch.randn(2, 4, 3, 16)}, [])
+    rb.add(td)
+    transformed_rb = TransformedReplayBuffer(rb, transform())
+    transformed_rb.append_transform(transform())
+    transformed_rb.insert_transform(0, transform())
+
+
 @pytest.mark.parametrize("priority_key", ["pk", "td_error"])
 @pytest.mark.parametrize("contiguous", [True, False])
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index b3c92d15932..5a9570eeebd 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -14,6 +14,7 @@
 import torch
 from tensordict.tensordict import TensorDictBase, TensorDict
 from torch import nn, Tensor
+from torchrl.data.replay_buffers.rb_prototype import ReplayBuffer

 from torchrl.data.tensor_specs import (
     BoundedTensorSpec,
@@ -213,7 +214,7 @@ def dump(self, **kwargs) -> None:
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}(keys={self.in_keys})"

-    def set_parent(self, parent: Union[Transform, EnvBase]) -> None:
+    def set_parent(self, parent: Union[Transform, EnvBase, "TransformedReplayBuffer"]) -> None:
         if self.__dict__["_parent"] is not None:
             raise AttributeError("parent of transform already set")
         self.__dict__["_parent"] = parent
@@ -228,7 +229,7 @@ def parent(self) -> EnvBase:
         parent = self._parent
         if parent is None:
             return parent
-        if not isinstance(parent, EnvBase):
+        if not isinstance(parent, (EnvBase, TransformedReplayBuffer)):
             # if it's not an env, it should be a Compose transform
             if not isinstance(parent, Compose):
                 raise ValueError(
@@ -256,6 +257,8 @@ def parent(self) -> EnvBase:
                 transform = copy(orig_trans)
                 transform.reset_parent()
                 out.append_transform(transform)
+        elif isinstance(parent, TransformedReplayBuffer):
+            return parent
         elif isinstance(parent, TransformedEnv):
             out = TransformedEnv(parent.base_env)
         else:
@@ -267,6 +270,32 @@ def empty_cache(self):
         self.parent.empty_cache()


+class TransformedReplayBuffer:
+    """Wraps a replay buffer and transforms the data it samples."""
+
+    def __init__(self, rb: ReplayBuffer, transform: Transform):
+        self.rb = rb
+        if isinstance(transform, Compose):
+            self.transform = transform
+        else:
+            self.transform = Compose(transform)
+        # assign the parent directly: set_parent would raise if wrapping the
+        # transform in a Compose already assigned one
+        transform._parent = self
+
+    def sample(self, batch_size: int) -> Tuple[Any, dict]:
+        data, info = self.rb.sample(batch_size)
+        data = self.transform(data)
+        return data, info
+
+    def append_transform(self, transform: Transform):
+        self.transform.append(transform)
+
+    def insert_transform(self, index: int, transform: Transform):
+        self.transform.insert(index, transform)
+
+
 class TransformedEnv(EnvBase):
     """A transformed_in environment.
@@ -1029,10 +1058,11 @@ def __init__(
         first_dim: int = 0,
         last_dim: int = -3,
         in_keys: Optional[Sequence[str]] = None,
+        out_keys: Optional[Sequence[str]] = None,
     ):
         if in_keys is None:
             in_keys = IMAGE_KEYS  # default
-        super().__init__(in_keys=in_keys)
+        super().__init__(in_keys=in_keys, out_keys=out_keys)
         self.first_dim = first_dim
         self.last_dim = last_dim

@@ -1042,15 +1072,17 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:

     def set_parent(self, parent: Union[Transform, EnvBase]) -> None:
         out = super().set_parent(parent)
-        observation_spec = self.parent.observation_spec
-        for key in self.in_keys:
-            if key in observation_spec:
-                observation_spec = observation_spec[key]
-                if self.first_dim >= 0:
-                    self.first_dim = self.first_dim - len(observation_spec.shape)
-                if self.last_dim >= 0:
-                    self.last_dim = self.last_dim - len(observation_spec.shape)
-                break
+        # FIXME: Cannot run Compose(FlattenObservation)
+        if isinstance(parent, EnvBase):
+            observation_spec = self.parent.observation_spec
+            for key in self.in_keys:
+                if key in observation_spec:
+                    observation_spec = observation_spec[key]
+                    if self.first_dim >= 0:
+                        self.first_dim = self.first_dim - len(observation_spec.shape)
+                    if self.last_dim >= 0:
+                        self.last_dim = self.last_dim - len(observation_spec.shape)
+                    break
         return out

     @_apply_to_composite
@@ -1111,7 +1143,7 @@ def set_parent(self, parent: Union[Transform, EnvBase]) -> None:
         if self._unsqueeze_dim_orig < 0:
             self._unsqueeze_dim = self._unsqueeze_dim_orig
         else:
-            parent = self.parent
+            parent = self.parent  # FIXME: Assumes parent exists and has batch_size
             batch_size = parent.batch_size
             self._unsqueeze_dim = self._unsqueeze_dim_orig + len(batch_size)
         return super().set_parent(parent)
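[Note, not part of the commit: with this patch, Transform.parent resolves
directly to the wrapping buffer instead of rebuilding a Compose or
TransformedEnv chain. A sketch of the intended resolution, assuming the
constructor wiring in TransformedReplayBuffer.__init__ above:]

    flatten = FlattenObservation(-2, -1, in_keys=["observation"])
    trb = TransformedReplayBuffer(ReplayBuffer(), flatten)
    # __init__ assigns flatten._parent = trb, so the new elif branch in
    # Transform.parent returns the buffer itself.
    assert flatten.parent is trb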