diff --git a/test/test_transforms.py b/test/test_transforms.py
index 18bed4ebabf..d7739ccbd6a 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -284,6 +284,21 @@ def test_added_transforms_are_in_eval_mode():
     assert t.transform[1].training
 
 
+def test_nested_transformed_env():
+    """Nesting TransformedEnv must yield one env with a flat Compose of both transforms."""
+    base_env = ContinuousActionVecMockEnv()
+    t1 = RewardScaling(0, 1)
+    t2 = RewardScaling(0, 2)
+    env = TransformedEnv(TransformedEnv(base_env, t1), t2)
+
+    assert env.base_env is base_env
+    assert isinstance(env.transform, Compose)
+    children = list(env.transform.transforms.children())
+    assert len(children) == 2
+    assert children[0] == t1
+    assert children[1] == t2
+
+
 class TestTransforms:
     @pytest.mark.skipif(not _has_tv, reason="no torchvision")
     @pytest.mark.parametrize("interpolation", ["bilinear", "bicubic"])
@@ -1320,7 +1335,13 @@ def test_r3m_instantiation(self, model, tensor_pixels_key, device):
         transformed_env.close()
 
     @pytest.mark.parametrize("stack_images", [True, False])
-    @pytest.mark.parametrize("parallel", [True, False])
+    @pytest.mark.parametrize(
+        "parallel",
+        [
+            False,
+            True,
+        ],
+    )
     def test_r3m_mult_images(self, model, device, stack_images, parallel):
         keys_in = ["next_pixels", "next_pixels2"]
         keys_out = ["next_vec"] if stack_images else ["next_vec", "next_vec2"]
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 350a5bc175b..bb03ac954d4 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -244,10 +244,20 @@ def parent(self) -> EnvBase:
                 raise ValueError(
                     "A transform parent must be either another Compose transform or an environment object."
                 )
+            compose = parent
+            # the parent of the compose must be a TransformedEnv
+            compose_parent = compose.parent
+            if not isinstance(compose_parent, TransformedEnv):
+                raise ValueError(
+                    f"Compose parent was of type {type(compose_parent)} but expected TransformedEnv."
+                )
             out = TransformedEnv(
-                parent.parent.base_env,
+                compose_parent.base_env,
+                transform=compose_parent.transform
+                if compose_parent.transform is not compose
+                else None,
             )
-            for transform in parent.transforms:
+            for transform in compose.transforms:
                 if transform is self:
                     break
                 out.append_transform(transform)
@@ -291,21 +301,36 @@ def __init__(
         cache_specs: bool = True,
         **kwargs,
     ):
+        self._transform = None
         device = kwargs.pop("device", env.device)
         env = env.to(device)
         super().__init__(device=None, **kwargs)
-        self._set_env(env, device)
-        if transform is None:
-            transform = Compose()
-            transform.set_parent(self)
+
+        if isinstance(env, TransformedEnv):
+            # flatten nested TransformedEnv instances: wrap the inner env's
+            # base_env and chain the inner transform(s) with the new one(s)
+            self._set_env(env.base_env, device)
+            if transform is None:
+                # no outer transform given: keep only the inner env's transform(s)
+                transform = []
+            elif type(transform) is not Compose:
+                # we don't use isinstance as some transforms may be subclassed from
+                # Compose but with other features that we don't want to lose.
+                transform = [transform]
+            env_transform = env.transform
+            if type(env_transform) is not Compose:
+                env_transform = [env_transform]
+            transform = Compose(*env_transform, *transform).to(device)
         else:
-            transform = transform.to(device)
-        transform.eval()
+            self._set_env(env, device)
+            if transform is None:
+                transform = Compose()
+            else:
+                transform = transform.to(device)
         self.transform = transform
 
         self._last_obs = None
         self.cache_specs = cache_specs
-
         self._reward_spec = None
         self._observation_spec = None
         self.batch_size = self.base_env.batch_size
@@ -315,6 +340,23 @@ def _set_env(self, env: EnvBase, device) -> None:
         # updates need not be inplace, as transforms may modify values out-place
         self.base_env._inplace_update = False
 
+    @property
+    def transform(self) -> Transform:
+        """The transform currently attached to this environment."""
+        return self._transform
+
+    @transform.setter
+    def transform(self, transform: Transform):
+        # validate the type, then bind the transform to this env and put it in eval mode
+        if not isinstance(transform, Transform):
+            raise ValueError(
+                f"""Expected a transform of type torchrl.envs.transforms.Transform,
+but got an object of type {type(transform)}."""
+            )
+        transform.set_parent(self)
+        transform.eval()
+        self._transform = transform
+
     @property
     def device(self) -> bool:
         return self.base_env.device