diff --git a/test/test_libs.py b/test/test_libs.py index 53fc10c41fb..e6891148c2b 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -636,7 +636,7 @@ def test_vmas_spec_rollout( ) for e in [env, wrapped]: e.set_seed(0) - check_env_specs(e, check_dtype=False) + check_env_specs(e) del e @pytest.mark.parametrize("num_envs", [1, 20]) diff --git a/torchrl/envs/libs/vmas.py b/torchrl/envs/libs/vmas.py index e3880504660..0ec49794756 100644 --- a/torchrl/envs/libs/vmas.py +++ b/torchrl/envs/libs/vmas.py @@ -2,6 +2,7 @@ import torch from tensordict.tensordict import TensorDict, TensorDictBase + from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec from torchrl.envs.common import _EnvWrapper from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform @@ -178,6 +179,7 @@ def _make_specs( value, batch_size=torch.Size((self.num_envs,)) ).shape[1:], device=self.device, + dtype=torch.float32, ) for key, value in self.scenario.info(agent0).items() }, @@ -285,7 +287,7 @@ def read_info(self, infos: Dict[str, torch.Tensor]) -> torch.Tensor: infos = TensorDict( source={ key: _selective_unsqueeze( - value, batch_size=torch.Size((self.num_envs,)) + value.to(torch.float32), batch_size=torch.Size((self.num_envs,)) ) for key, value in infos.items() },