From 658dfa6951ce593ef3bcbc78d219aaa0923e711e Mon Sep 17 00:00:00 2001 From: Alexander Lobov Date: Mon, 31 Oct 2022 17:06:26 +0100 Subject: [PATCH 1/2] #602 Unfold transforms for folded TransformedEnv --- test/test_transforms.py | 14 ++++++++++++++ torchrl/envs/transforms/transforms.py | 22 ++++++++++++++-------- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 18bed4ebabf..305d94a0884 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -284,6 +284,20 @@ def test_added_transforms_are_in_eval_mode(): assert t.transform[1].training +def test_nested_transformed_env(): + base_env = ContinuousActionVecMockEnv() + t1 = RewardScaling(0, 1) + t2 = RewardScaling(0, 2) + env = TransformedEnv(TransformedEnv(base_env, t1), t2) + + assert(env.base_env is base_env) + assert(isinstance(env.transform, Compose)) + children = list(env.transform.transforms.children()) + assert len(children) == 2 + assert(children[0] == t1) + assert(children[1] == t2) + + class TestTransforms: @pytest.mark.skipif(not _has_tv, reason="no torchvision") @pytest.mark.parametrize("interpolation", ["bilinear", "bicubic"]) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 350a5bc175b..8bd51cb2eb8 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -294,18 +294,24 @@ def __init__( device = kwargs.pop("device", env.device) env = env.to(device) super().__init__(device=None, **kwargs) - self._set_env(env, device) - if transform is None: - transform = Compose() - transform.set_parent(self) + + if isinstance(env, TransformedEnv): + self._set_env(env.base_env, device) + self.transform = env.transform + self.transform.set_parent(self) + self.append_transform(transform) else: - transform = transform.to(device) - transform.eval() - self.transform = transform + self._set_env(env, device) + if transform is None: + transform = Compose() + transform.set_parent(self) + else: + transform = transform.to(device) + transform.eval() + self.transform = transform self._last_obs = None self.cache_specs = cache_specs - self._reward_spec = None self._observation_spec = None self.batch_size = self.base_env.batch_size From de7f2b1e0bc1b9adfee71713c76f72dd58726823 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 1 Nov 2022 14:13:49 +0000 Subject: [PATCH 2/2] solve bugs --- test/test_transforms.py | 16 +++++++--- torchrl/envs/transforms/transforms.py | 45 ++++++++++++++++++++++----- 2 files changed, 48 insertions(+), 13 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 305d94a0884..d7739ccbd6a 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -290,12 +290,12 @@ def test_nested_transformed_env(): t2 = RewardScaling(0, 2) env = TransformedEnv(TransformedEnv(base_env, t1), t2) - assert(env.base_env is base_env) - assert(isinstance(env.transform, Compose)) + assert env.base_env is base_env + assert isinstance(env.transform, Compose) children = list(env.transform.transforms.children()) assert len(children) == 2 - assert(children[0] == t1) - assert(children[1] == t2) + assert children[0] == t1 + assert children[1] == t2 class TestTransforms: @@ -1334,7 +1334,13 @@ def test_r3m_instantiation(self, model, tensor_pixels_key, device): transformed_env.close() @pytest.mark.parametrize("stack_images", [True, False]) - @pytest.mark.parametrize("parallel", [True, False]) + @pytest.mark.parametrize( + "parallel", + [ + False, + True, + ], + ) def test_r3m_mult_images(self, model, device, stack_images, parallel): keys_in = ["next_pixels", "next_pixels2"] keys_out = ["next_vec"] if stack_images else ["next_vec", "next_vec2"] diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 8bd51cb2eb8..bb03ac954d4 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -244,10 +244,20 @@ def parent(self) -> EnvBase: raise ValueError( "A transform parent must be either another Compose transform or an environment object." ) + compose = parent + # the parent of the compose must be a TransformedEnv + compose_parent = compose.parent + if not isinstance(compose_parent, TransformedEnv): + raise ValueError( + f"Compose parent was of type {type(compose_parent)} but expected TransformedEnv." + ) out = TransformedEnv( - parent.parent.base_env, + compose_parent.base_env, + transform=compose_parent.transform + if compose_parent.transform is not compose + else None, ) - for transform in parent.transforms: + for transform in compose.transforms: if transform is self: break out.append_transform(transform) @@ -291,24 +301,28 @@ def __init__( cache_specs: bool = True, **kwargs, ): + self._transform = None device = kwargs.pop("device", env.device) env = env.to(device) super().__init__(device=None, **kwargs) if isinstance(env, TransformedEnv): self._set_env(env.base_env, device) - self.transform = env.transform - self.transform.set_parent(self) - self.append_transform(transform) + if type(transform) is not Compose: + # we don't use isinstance as some transforms may be subclassed from + # Compose but with other features that we don't want to loose. + transform = [transform] + env_transform = env.transform + if type(env_transform) is not Compose: + env_transform = [env_transform] + transform = Compose(*env_transform, *transform).to(device) else: self._set_env(env, device) if transform is None: transform = Compose() - transform.set_parent(self) else: transform = transform.to(device) - transform.eval() - self.transform = transform + self.transform = transform self._last_obs = None self.cache_specs = cache_specs @@ -321,6 +335,21 @@ def _set_env(self, env: EnvBase, device) -> None: # updates need not be inplace, as transforms may modify values out-place self.base_env._inplace_update = False + @property + def transform(self) -> Transform: + return self._transform + + @transform.setter + def transform(self, transform: Transform): + if not isinstance(transform, Transform): + raise ValueError( + f"""Expected a transform of type torchrl.envs.transforms.Transform, +but got an object of type {type(transform)}.""" + ) + transform.set_parent(self) + transform.eval() + self._transform = transform + @property def device(self) -> bool: return self.base_env.device