diff --git a/test/test_transforms.py b/test/test_transforms.py index c63f6f69828..c07ee7cb8b3 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -59,6 +59,7 @@ _has_tv, CenterCrop, DiscreteActionProjection, + FrameSkipTransform, gSDENoise, NoopResetEnv, PinMemoryTransform, @@ -375,17 +376,53 @@ def test_transform_parent(): t3 = RewardClipping(0.1, 0.5) env.append_transform(t3) - t1_parent_gt = t1._parent - t2_parent_gt = t2._parent - t3_parent_gt = t3._parent + t1_parent_gt = t1._container + t2_parent_gt = t2._container + t3_parent_gt = t3._container _ = t1.parent _ = t2.parent _ = t3.parent - assert t1_parent_gt == t1._parent - assert t2_parent_gt == t2._parent - assert t3_parent_gt == t3._parent + assert t1_parent_gt == t1._container + assert t2_parent_gt == t2._container + assert t3_parent_gt == t3._container + + +def test_transform_parent_cache(): + """Tests the caching and uncaching of the transformed envs.""" + env = TransformedEnv( + ContinuousActionVecMockEnv(), + FrameSkipTransform(3), + ) + + # print the parent + assert ( + type(env.transform.parent.transform) is Compose + and len(env.transform.parent.transform) == 0 + ) + transform = env.transform + parent1 = env.transform.parent + parent2 = env.transform.parent + assert parent1 is parent2 + + # change the env, re-print the parent + env.insert_transform(0, NoopResetEnv(3)) + parent3 = env.transform[-1].parent + assert parent1 is not parent3 + assert type(parent3.transform[0]) is NoopResetEnv + + # change the env, re-print the parent + env.insert_transform(0, CatTensors(["observation"])) + parent4 = env.transform[-1].parent + assert parent1 is not parent4 + assert parent3 is not parent4 + assert type(parent4.transform[0]) is CatTensors + assert type(parent4.transform[1]) is NoopResetEnv + + # check that we don't keep track of the wrong parent + env.transform = NoopResetEnv(3) + assert transform.parent is None class TestTransforms: @@ -510,6 +547,66 @@ def test_flatten(self, keys, size, nchannels, batch, device): for key in keys: assert observation_spec[key].shape[-3] == expected_size + @pytest.mark.skipif(not _has_gym, reason="gym not installed") + @pytest.mark.parametrize("skip", [-1, 1, 2, 3]) + def test_frame_skip_transform_builtin(self, skip): + torch.manual_seed(0) + if skip < 0: + with pytest.raises( + ValueError, + match="frame_skip should have a value greater or equal to one", + ): + FrameSkipTransform(skip) + return + else: + fs = FrameSkipTransform(skip) + base_env = GymEnv(PENDULUM_VERSIONED, frame_skip=skip) + tensordicts = TensorDict({"action": base_env.action_spec.rand((10,))}, [10]) + env = TransformedEnv(GymEnv(PENDULUM_VERSIONED), fs) + base_env.set_seed(0) + env.base_env.set_seed(0) + td1 = base_env.reset() + td2 = env.reset() + for key in td1.keys(): + torch.testing.assert_close(td1[key], td2[key]) + for i in range(10): + td1 = base_env.step(tensordicts[i].clone()).flatten_keys() + td2 = env.step(tensordicts[i].clone()).flatten_keys() + for key in td1.keys(): + torch.testing.assert_close(td1[key], td2[key]) + + @pytest.mark.skipif(not _has_gym, reason="gym not installed") + @pytest.mark.parametrize("skip", [-1, 1, 2, 3]) + def test_frame_skip_transform_unroll(self, skip): + torch.manual_seed(0) + if skip < 0: + with pytest.raises( + ValueError, + match="frame_skip should have a value greater or equal to one", + ): + FrameSkipTransform(skip) + return + else: + fs = FrameSkipTransform(skip) + base_env = GymEnv(PENDULUM_VERSIONED) + tensordicts = TensorDict({"action": base_env.action_spec.rand((10,))}, [10]) + env = TransformedEnv(GymEnv(PENDULUM_VERSIONED), fs) + base_env.set_seed(0) + env.base_env.set_seed(0) + td1 = base_env.reset() + td2 = env.reset() + for key in td1.keys(): + torch.testing.assert_close(td1[key], td2[key]) + for i in range(10): + r = 0.0 + for _ in range(skip): + td1 = base_env.step(tensordicts[i].clone()).flatten_keys() + r = td1.get("reward") + r + td1.set("reward", r) + td2 = env.step(tensordicts[i].clone()).flatten_keys() + for key in td1.keys(): + torch.testing.assert_close(td1[key], td2[key]) + @pytest.mark.parametrize("unsqueeze_dim", [1, -2]) @pytest.mark.parametrize("nchannels", [1, 3]) @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index b93ccdb48f3..9565959ec4a 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -143,7 +143,7 @@ def build_tensordict( """ # build a tensordict from specs - td = TensorDict({}, batch_size=torch.Size([])) + td = TensorDict({}, batch_size=torch.Size([]), _run_checks=False) action_placeholder = torch.zeros( self["action_spec"].shape, dtype=self["action_spec"].dtype ) @@ -518,7 +518,9 @@ def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBa """ if tensordict is None: - tensordict = TensorDict({}, device=self.device, batch_size=self.batch_size) + tensordict = TensorDict( + {}, device=self.device, batch_size=self.batch_size, _run_checks=False + ) action = self.action_spec.rand(self.batch_size) tensordict.set("action", action) return self.step(tensordict) @@ -702,6 +704,7 @@ def fake_tensordict(self) -> TensorDictBase: }, batch_size=self.batch_size, device=self.device, + _run_checks=False, ) return fake_td diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index 098c7545812..6c682c6210b 100644 --- a/torchrl/envs/transforms/__init__.py +++ b/torchrl/envs/transforms/__init__.py @@ -13,6 +13,7 @@ DoubleToFloat, FiniteTensorDictCheck, FlattenObservation, + FrameSkipTransform, GrayScale, gSDENoise, NoopResetEnv, diff --git a/torchrl/envs/transforms/functional.py b/torchrl/envs/transforms/functional.py index e2135c3e814..6ef23b11fd5 100644 --- a/torchrl/envs/transforms/functional.py +++ b/torchrl/envs/transforms/functional.py @@ -22,8 +22,8 @@ def _assert_channels(img: Tensor, permitted: List[int]) -> None: c = _get_image_num_channels(img) if c not in permitted: raise TypeError( - "Input image tensor permitted channel values are {}, but found" - "{}".format(permitted, c) + f"Input image tensor permitted channel values are {permitted}, but found " + f"{c} (full shape: {img.shape})" ) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 8ef11c6f691..647760ba4cd 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -107,6 +107,7 @@ def __init__( if out_keys_inv is None: out_keys_inv = copy(self.in_keys_inv) self.out_keys_inv = out_keys_inv + self.__dict__["_container"] = None self.__dict__["_parent"] = None def reset(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -213,15 +214,16 @@ def dump(self, **kwargs) -> None: def __repr__(self) -> str: return f"{self.__class__.__name__}(keys={self.in_keys})" - def set_parent(self, parent: Union[Transform, EnvBase]) -> None: - if self.__dict__["_parent"] is not None: + def set_container(self, container: Union[Transform, EnvBase]) -> None: + if self.__dict__["_container"] is not None: raise AttributeError( - "parent of transform already set. " + f"parent of transform {type(self)} already set. " "Call `transform.clone()` to get a similar transform with no parent set." ) - self.__dict__["_parent"] = parent + self.__dict__["_container"] = container def reset_parent(self) -> None: + self.__dict__["_container"] = None self.__dict__["_parent"] = None def clone(self): @@ -231,45 +233,48 @@ def clone(self): @property def parent(self) -> Optional[EnvBase]: - if not hasattr(self, "_parent"): - raise AttributeError("transform parent uninitialized") - parent = self._parent - if parent is None: - return parent - out = None - if not isinstance(parent, EnvBase): - # if it's not an env, it should be a Compose transform - if not isinstance(parent, Compose): - raise ValueError( - "A transform parent must be either another Compose transform or an environment object." - ) - compose = parent - if compose.parent: - # the parent of the compose must be a TransformedEnv - compose_parent = compose.parent - if compose_parent.transform is not compose: - comp_parent_trans = compose_parent.transform.clone() - else: - comp_parent_trans = None - out = TransformedEnv( - compose_parent.base_env, - transform=comp_parent_trans, - ) - for orig_trans in compose.transforms: - if orig_trans is self: - break - transform = copy(orig_trans) - transform.reset_parent() - out.append_transform(transform) - elif isinstance(parent, TransformedEnv): - out = TransformedEnv(parent.base_env) - else: - raise ValueError(f"parent is of type {type(parent)}") - return out + if self.__dict__.get("_parent", None) is None: + if "_container" not in self.__dict__: + raise AttributeError("transform parent uninitialized") + container = self.__dict__["_container"] + if container is None: + return container + out = None + if not isinstance(container, EnvBase): + # if it's not an env, it should be a Compose transform + if not isinstance(container, Compose): + raise ValueError( + "A transform parent must be either another Compose transform or an environment object." + ) + compose = container + if compose.__dict__["_container"]: + # the parent of the compose must be a TransformedEnv + compose_parent = TransformedEnv( + compose.__dict__["_container"].base_env + ) + if compose_parent.transform is not compose: + comp_parent_trans = compose_parent.transform.clone() + else: + comp_parent_trans = None + out = TransformedEnv( + compose_parent.base_env, + transform=comp_parent_trans, + ) + for orig_trans in compose.transforms: + if orig_trans is self: + break + transform = orig_trans.clone() + transform.reset_parent() + out.append_transform(transform) + elif isinstance(container, TransformedEnv): + out = TransformedEnv(container.base_env) + else: + raise ValueError(f"container is of type {type(container)}") + self.__dict__["_parent"] = out + return self.__dict__["_parent"] def empty_cache(self): - if self.parent is not None: - self.parent.empty_cache() + self.__dict__["_parent"] = None class TransformedEnv(EnvBase): @@ -354,7 +359,11 @@ def transform(self, transform: Transform): f"""Expected a transform of type torchrl.envs.transforms.Transform, but got an object of type {type(transform)}.""" ) - transform.set_parent(self) + prev_transform = self.transform + if prev_transform is not None: + prev_transform.empty_cache() + prev_transform.__dict__["_container"] = None + transform.set_container(self) transform.eval() self._transform = transform @@ -433,13 +442,9 @@ def reward_spec(self) -> TensorSpec: return reward_spec def _step(self, tensordict: TensorDictBase) -> TensorDictBase: - # selected_keys = [key for key in tensordict.keys() if "action" in key] - # tensordict_in = tensordict.select(*selected_keys).clone() - tensordict = tensordict.clone() + tensordict = tensordict.clone(False) tensordict_in = self.transform.inv(tensordict) tensordict_out = self.base_env._step(tensordict_in) - # tensordict should already have been processed by the transforms - # for logging purposes tensordict_out = tensordict_out.update( tensordict.exclude(*tensordict_out.keys()) ) @@ -519,8 +524,8 @@ def insert_transform(self, index: int, transform: Transform) -> None: ) transform = transform.to(self.device) if not isinstance(self.transform, Compose): - self.transform = Compose(self.transform) - self.transform.set_parent(self) + compose = Compose(self.transform.clone()) + self.transform = compose # parent set automatically self.transform.insert(index, transform) self._erase_metadata() @@ -621,7 +626,7 @@ def __init__(self, *transforms: Transform): super().__init__(in_keys=[]) self.transforms = nn.ModuleList(transforms) for t in self.transforms: - t.set_parent(self) + t.set_container(self) def forward(self, tensordict: TensorDictBase) -> TensorDictBase: for t in self.transforms: @@ -658,7 +663,7 @@ def __getitem__(self, item: Union[int, slice, List]) -> Union: transform = transform[item] if not isinstance(transform, Transform): out = Compose(*self.transforms[item]) - out.set_parent(self.parent) + out.set_container(self.parent) return out return transform @@ -684,7 +689,7 @@ def append(self, transform): ) transform.eval() self.transforms.append(transform) - transform.set_parent(self) + transform.set_container(self) def insert(self, index: int, transform: Transform) -> None: if not isinstance(transform, Transform): @@ -698,12 +703,13 @@ def insert(self, index: int, transform: Transform) -> None: f"Index expected to be between [-{len(self.transforms)}, {len(self.transforms)}] got index={index}" ) + # empty cache of all transforms to reset parents and specs self.empty_cache() if index < 0: index = index + len(self.transforms) transform.eval() self.transforms.insert(index, transform) - transform.set_parent(self) + transform.set_container(self) def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Compose: for t in self.transforms: @@ -722,6 +728,11 @@ def __repr__(self) -> str: ) return f"{self.__class__.__name__}(\n{indent(layers_str, 4 * ' ')})" + def empty_cache(self): + for t in self.transforms: + t.empty_cache() + super().empty_cache() + class ToTensorImage(ObservationTransform): """Transforms a numpy-like image (3 x W x H) to a pytorch image (3 x W x H). @@ -1035,8 +1046,8 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: observation = torch.flatten(observation, self.first_dim, self.last_dim) return observation - def set_parent(self, parent: Union[Transform, EnvBase]) -> None: - out = super().set_parent(parent) + def set_container(self, container: Union[Transform, EnvBase]) -> None: + out = super().set_container(container) try: observation_spec = self.parent.observation_spec for key in self.in_keys: @@ -1113,13 +1124,13 @@ def __init__( ) self._unsqueeze_dim_orig = unsqueeze_dim - def set_parent(self, parent: Union[Transform, EnvBase]) -> None: + def set_container(self, container: Union[Transform, EnvBase]) -> None: if self._unsqueeze_dim_orig < 0: self._unsqueeze_dim = self._unsqueeze_dim_orig else: - parent = self.parent + container = self.parent try: - batch_size = parent.batch_size + batch_size = container.batch_size except AttributeError: raise ValueError( f"Got the unsqueeze dimension {self._unsqueeze_dim_orig} which is greater or equal to zero. " @@ -1128,7 +1139,7 @@ def set_parent(self, parent: Union[Transform, EnvBase]) -> None: f"`TransformedEnv.append_transform()` method." ) self._unsqueeze_dim = self._unsqueeze_dim_orig + len(batch_size) - return super().set_parent(parent) + return super().set_container(container) @property def unsqueeze_dim(self): @@ -1916,6 +1927,36 @@ def __repr__(self) -> str: ) +class FrameSkipTransform(Transform): + """A frame-skip transform. + + This transform applies the same action repeatedly in the parent environment, + which improves stability on certain training algorithms. + + Args: + frame_skip (int, optional): a positive integer representing the number + of frames during which the same action must be applied. + + """ + + inplace = False + + def __init__(self, frame_skip: int = 1): + super().__init__([]) + if frame_skip < 1: + raise ValueError("frame_skip should have a value greater or equal to one.") + self.frame_skip = frame_skip + + def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + parent = self.parent + reward = tensordict.get("reward") + for _ in range(self.frame_skip - 1): + tensordict = parent._step(tensordict) + reward = reward + tensordict.get("reward") + tensordict.set("reward", reward) + return tensordict + + class NoopResetEnv(Transform): """Runs a series of random actions when an environment is reset. @@ -2080,8 +2121,8 @@ def transform_observation_spec( observation_spec[key] = spec.to(self.device) return observation_spec - def set_parent(self, parent: Union[Transform, EnvBase]) -> None: - super().set_parent(parent) + def set_container(self, container: Union[Transform, EnvBase]) -> None: + super().set_container(container) @property def _batch_size(self):