diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index fe76e0b5d3c..8a4b3e517dd 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -612,7 +612,7 @@ def sample(self, batch_size: int | None = None, return_info: bool = False) -> An def mark_update(self, index: Union[int, torch.Tensor]) -> None: self._sampler.mark_update(index) - def append_transform(self, transform: "Transform") -> None: # noqa-F821 + def append_transform(self, transform: "Transform") -> ReplayBuffer: # noqa-F821 """Appends transform at the end. Transforms are applied in order when `sample` is called. @@ -626,8 +626,11 @@ def append_transform(self, transform: "Transform") -> None: # noqa-F821 transform = _CallableTransform(transform) transform.eval() self._transform.append(transform) + return self - def insert_transform(self, index: int, transform: "Transform") -> None: # noqa-F821 + def insert_transform( + self, index: int, transform: "Transform" # noqa-F821 + ) -> ReplayBuffer: """Inserts transform. Transforms are executed in order when `sample` is called. @@ -638,6 +641,7 @@ def insert_transform(self, index: int, transform: "Transform") -> None: # noqa- """ transform.eval() self._transform.insert(index, transform) + return self def __iter__(self): if self._sampler.ran_out: diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 6aeea0529ce..80998be36b2 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -877,7 +877,7 @@ def empty_cache(self): def append_transform( self, transform: Transform | Callable[[TensorDictBase], TensorDictBase] - ) -> None: + ) -> TransformedEnv: """Appends a transform to the env. :class:`~torchrl.envs.transforms.Transform` or callable are accepted. @@ -899,8 +899,9 @@ def append_transform( self.transform.append(prev_transform) self.transform.append(transform) + return self - def insert_transform(self, index: int, transform: Transform) -> None: + def insert_transform(self, index: int, transform: Transform) -> TransformedEnv: """Inserts a transform to the env at the desired index. :class:`~torchrl.envs.transforms.Transform` or callable are accepted. @@ -920,6 +921,7 @@ def insert_transform(self, index: int, transform: Transform) -> None: self.transform = compose # parent set automatically self.transform.insert(index, transform) + return self def __getattr__(self, attr: str) -> Any: try: