Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion torchrl/envs/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,6 @@
VecNorm,
gSDENoise,
TensorDictPrimer,
SqueezeTransform,
)
from .vip import VIPTransform
from .vip import VIPTransform, VIPRewardTransform
39 changes: 29 additions & 10 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1049,15 +1049,26 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:

def set_parent(self, parent: Union[Transform, EnvBase]) -> None:
out = super().set_parent(parent)
observation_spec = self.parent.observation_spec
for key in self.in_keys:
if key in observation_spec:
observation_spec = observation_spec[key]
if self.first_dim >= 0:
self.first_dim = self.first_dim - len(observation_spec.shape)
if self.last_dim >= 0:
self.last_dim = self.last_dim - len(observation_spec.shape)
break
try:
observation_spec = self.parent.observation_spec
for key in self.in_keys:
if key in observation_spec:
observation_spec = observation_spec[key]
if self.first_dim >= 0:
self.first_dim = self.first_dim - len(observation_spec.shape)
if self.last_dim >= 0:
self.last_dim = self.last_dim - len(observation_spec.shape)
break
except ValueError:
if self.first_dim >= 0 or self.last_dim >= 0:
raise ValueError(
f"FlattenObservation got first and last dim {self.first_dim} amd {self.last_dim}. "
f"Those values assume that the observation spec is known, which requires the "
f"parent environment to be set. "
f"Consider setting the parent environment beforehand (ie passing the transform "
f"to `TransformedEnv.append_transform()`) or setting strictly negative "
f"flatten dimensions to the transform."
)
return out

@_apply_to_composite
Expand Down Expand Up @@ -1119,7 +1130,15 @@ def set_parent(self, parent: Union[Transform, EnvBase]) -> None:
self._unsqueeze_dim = self._unsqueeze_dim_orig
else:
parent = self.parent
batch_size = parent.batch_size
try:
batch_size = parent.batch_size
except AttributeError:
raise ValueError(
f"Got the unsqueeze dimension {self._unsqueeze_dim_orig} which is greater or equal to zero. "
f"However this requires to know what the parent environment is, but it has not been provided. "
f"Consider providing a negative dimension or setting the transform using the "
f"`TransformedEnv.append_transform()` method."
)
self._unsqueeze_dim = self._unsqueeze_dim_orig + len(batch_size)
return super().set_parent(parent)

Expand Down