diff --git a/test/mocking_classes.py b/test/mocking_classes.py
index 5f92e0da52c..757f12a98c7 100644
--- a/test/mocking_classes.py
+++ b/test/mocking_classes.py
@@ -7,6 +7,7 @@
 import torch
 import torch.nn as nn
 from tensordict.tensordict import TensorDict, TensorDictBase
+
 from torchrl.data.tensor_specs import (
     BinaryDiscreteTensorSpec,
     BoundedTensorSpec,
@@ -212,13 +213,16 @@ def _step(self, tensordict):
         self.counter += 1
         # We use tensordict.batch_size instead of self.batch_size since this method will also be used by MockBatchedUnLockedEnv
         n = (
-            torch.full(tensordict.batch_size, self.counter)
+            torch.full(
+                (*tensordict.batch_size, *self.observation_spec["observation"].shape),
+                self.counter,
+            )
             .to(self.device)
             .to(torch.get_default_dtype())
         )
         done = self.counter >= self.max_val
         done = torch.full(
-            tensordict.batch_size, done, dtype=torch.bool, device=self.device
+            (*tensordict.batch_size, 1), done, dtype=torch.bool, device=self.device
         )
 
         return TensorDict(
@@ -235,12 +239,14 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
         batch_size = tensordict.batch_size
 
         n = (
-            torch.full(batch_size, self.counter)
+            torch.full(
+                (*batch_size, *self.observation_spec["observation"].shape), self.counter
+            )
             .to(self.device)
             .to(torch.get_default_dtype())
         )
         done = self.counter >= self.max_val
-        done = torch.full(batch_size, done, dtype=torch.bool, device=self.device)
+        done = torch.full((*batch_size, 1), done, dtype=torch.bool, device=self.device)
 
         return TensorDict(
             {"reward": n, "done": done, "observation": n},
diff --git a/test/test_env.py b/test/test_env.py
index a1e6a725e34..a36c88b6282 100644
--- a/test/test_env.py
+++ b/test/test_env.py
@@ -367,6 +367,16 @@ def test_mb_env_batch_lock(self, device, seed=0):
 
 
 class TestParallel:
+    @pytest.mark.parametrize("num_parallel_env", [1, 10])
+    @pytest.mark.parametrize("env_batch_size", [[], (32,), (32, 1), (32, 0)])
+    def test_env_with_batch_size(self, num_parallel_env, env_batch_size):
+        env = MockBatchedLockedEnv(device="cpu", batch_size=torch.Size(env_batch_size))
+        env.set_seed(1)
+        parallel_env = ParallelEnv(num_parallel_env, lambda: env)
+        parallel_env.start()
+        assert parallel_env.batch_size == (num_parallel_env, *env_batch_size)
+        parallel_env.close()
+
     @pytest.mark.skipif(not _has_dmc, reason="no dm_control")
     @pytest.mark.parametrize("env_task", ["stand,stand,stand", "stand,walk,stand"])
     @pytest.mark.parametrize("share_individual_td", [True, False])
diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py
index f119e9e2151..a74a70c3d86 100644
--- a/torchrl/envs/vec_env.py
+++ b/torchrl/envs/vec_env.py
@@ -209,9 +209,10 @@ def _get_metadata(
     ):
         if self._single_task:
             # if EnvCreator, the metadata are already there
-            self.meta_data = get_env_metadata(
-                create_env_fn[0], create_env_kwargs[0]
-            ).expand(self.num_workers)
+            meta_data = get_env_metadata(create_env_fn[0], create_env_kwargs[0])
+            self.meta_data = meta_data.expand(
+                *(self.num_workers, *meta_data.batch_size)
+            )
         else:
             n_tasks = len(create_env_fn)
             self.meta_data = []
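
For context, a minimal sketch (not part of the patch) of how the changed behaviour can be exercised by hand, mirroring the new test_env_with_batch_size above: a MockBatchedLockedEnv that already carries its own batch dimensions is wrapped in ParallelEnv, and the resulting batch_size prepends the worker dimension to the env's batch size instead of replacing it. The import of MockBatchedLockedEnv assumes the snippet is run from the repository's test/ directory; worker count and env_batch_size are arbitrary choices for illustration.

import torch

from mocking_classes import MockBatchedLockedEnv  # test helper from test/mocking_classes.py
from torchrl.envs.vec_env import ParallelEnv

env_batch_size = (32, 1)
env = MockBatchedLockedEnv(device="cpu", batch_size=torch.Size(env_batch_size))
env.set_seed(1)

# Two workers, each running a copy of the batched mock env.
parallel_env = ParallelEnv(2, lambda: env)
parallel_env.start()

# The worker dimension is prepended to the env's own batch dimensions,
# matching the expanded metadata computed in _get_metadata.
assert parallel_env.batch_size == (2, *env_batch_size)
parallel_env.close()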