diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index 9a702c8d290..cdb4d35f4ce 100644 --- a/torchrl/envs/transforms/__init__.py +++ b/torchrl/envs/transforms/__init__.py @@ -28,5 +28,6 @@ VecNorm, gSDENoise, TensorDictPrimer, + SqueezeTransform, ) -from .vip import VIPTransform +from .vip import VIPTransform, VIPRewardTransform diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 91ea7049a1f..6e0f612f511 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1049,15 +1049,26 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: def set_parent(self, parent: Union[Transform, EnvBase]) -> None: out = super().set_parent(parent) - observation_spec = self.parent.observation_spec - for key in self.in_keys: - if key in observation_spec: - observation_spec = observation_spec[key] - if self.first_dim >= 0: - self.first_dim = self.first_dim - len(observation_spec.shape) - if self.last_dim >= 0: - self.last_dim = self.last_dim - len(observation_spec.shape) - break + try: + observation_spec = self.parent.observation_spec + for key in self.in_keys: + if key in observation_spec: + observation_spec = observation_spec[key] + if self.first_dim >= 0: + self.first_dim = self.first_dim - len(observation_spec.shape) + if self.last_dim >= 0: + self.last_dim = self.last_dim - len(observation_spec.shape) + break + except ValueError: + if self.first_dim >= 0 or self.last_dim >= 0: + raise ValueError( + f"FlattenObservation got first and last dim {self.first_dim} amd {self.last_dim}. " + f"Those values assume that the observation spec is known, which requires the " + f"parent environment to be set. " + f"Consider setting the parent environment beforehand (ie passing the transform " + f"to `TransformedEnv.append_transform()`) or setting strictly negative " + f"flatten dimensions to the transform." + ) return out @_apply_to_composite @@ -1119,7 +1130,15 @@ def set_parent(self, parent: Union[Transform, EnvBase]) -> None: self._unsqueeze_dim = self._unsqueeze_dim_orig else: parent = self.parent - batch_size = parent.batch_size + try: + batch_size = parent.batch_size + except AttributeError: + raise ValueError( + f"Got the unsqueeze dimension {self._unsqueeze_dim_orig} which is greater or equal to zero. " + f"However this requires to know what the parent environment is, but it has not been provided. " + f"Consider providing a negative dimension or setting the transform using the " + f"`TransformedEnv.append_transform()` method." + ) self._unsqueeze_dim = self._unsqueeze_dim_orig + len(batch_size) return super().set_parent(parent)