diff --git a/examples/dreamer/dreamer_utils.py b/examples/dreamer/dreamer_utils.py index 36dbdc365fe..2ab51221cd2 100644 --- a/examples/dreamer/dreamer_utils.py +++ b/examples/dreamer/dreamer_utils.py @@ -89,7 +89,7 @@ def make_env_transforms( env.append_transform(Resize(cfg.image_size, cfg.image_size)) if cfg.grayscale: env.append_transform(GrayScale()) - env.append_transform(FlattenObservation(0)) + env.append_transform(FlattenObservation(0, -3)) env.append_transform(CatFrames(N=cfg.catframes, in_keys=["pixels"])) if stats is None: obs_stats = { diff --git a/test/test_helpers.py b/test/test_helpers.py index 73d8d082965..63f96d57dd5 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -1027,7 +1027,7 @@ def test_initialize_stats_from_non_obs_transform(device): env.set_seed(1) t_env = TransformedEnv(env) - t_env.transform = FlattenObservation(first_dim=0) + t_env.transform = FlattenObservation(first_dim=0, last_dim=-3) pre_init_state_dict = t_env.transform.state_dict() initialize_observation_norm_transforms(proof_environment=t_env, num_iter=100) post_init_state_dict = t_env.transform.state_dict() diff --git a/test/test_transforms.py b/test/test_transforms.py index 8d3e40f7550..32478315095 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -2284,7 +2284,7 @@ def test_batch_unlocked_with_batch_size_transformed(device): id="CenterCrop", marks=pytest.mark.skipif(not _has_tv, reason="needs torchvision dependency"), ), - pytest.param(partial(FlattenObservation, first_dim=-3), id="FlattenObservation"), + pytest.param(partial(FlattenObservation, first_dim=-3, last_dim=-3), id="FlattenObservation"), pytest.param( partial(UnsqueezeTransform, unsqueeze_dim=-1), id="UnsqueezeTransform" ), diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 3dc630b7f31..888b5e41b60 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1027,10 +1027,8 @@ class FlattenObservation(ObservationTransform): """Flatten adjacent dimensions of a tensor. Args: - first_dim (int, optional): first dimension of the dimensions to flatten. - Default is 0. - last_dim (int, optional): last dimension of the dimensions to flatten. - Default is -3. + first_dim (int): first dimension of the dimensions to flatten. + last_dim (int): last dimension of the dimensions to flatten. """ inplace = False @@ -1038,7 +1036,7 @@ class FlattenObservation(ObservationTransform): def __init__( self, first_dim: int, - last_dim: int = -3, + last_dim: int, in_keys: Optional[Sequence[str]] = None, out_keys: Optional[Sequence[str]] = None, ): diff --git a/torchrl/trainers/helpers/envs.py b/torchrl/trainers/helpers/envs.py index 4ba2a3213af..ef03e1b849b 100644 --- a/torchrl/trainers/helpers/envs.py +++ b/torchrl/trainers/helpers/envs.py @@ -123,7 +123,7 @@ def make_env_transforms( env.append_transform(Resize(cfg.image_size, cfg.image_size)) if cfg.grayscale: env.append_transform(GrayScale()) - env.append_transform(FlattenObservation(0)) + env.append_transform(FlattenObservation(0, -3)) env.append_transform(CatFrames(N=cfg.catframes, in_keys=["pixels"])) if stats is None: obs_stats = {