From ea04e4f6e0197bd6ad5ec881eae16880907be806 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sat, 31 Dec 2022 12:54:20 +0100 Subject: [PATCH 1/5] PR fixes --- torchrl/envs/vec_env.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index f119e9e2151..66d007ba1da 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -209,9 +209,10 @@ def _get_metadata( ): if self._single_task: # if EnvCreator, the metadata are already there - self.meta_data = get_env_metadata( + meta_data = get_env_metadata( create_env_fn[0], create_env_kwargs[0] - ).expand(self.num_workers) + ) + self.meta_data = meta_data.expand(*(self.num_workers, *meta_data.batch_size)) else: n_tasks = len(create_env_fn) self.meta_data = [] From 01398d85c4dd5faba022584ffa5d1b8b755b50d7 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sat, 31 Dec 2022 12:56:49 +0100 Subject: [PATCH 2/5] Linting --- torchrl/envs/vec_env.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 66d007ba1da..a74a70c3d86 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -209,10 +209,10 @@ def _get_metadata( ): if self._single_task: # if EnvCreator, the metadata are already there - meta_data = get_env_metadata( - create_env_fn[0], create_env_kwargs[0] + meta_data = get_env_metadata(create_env_fn[0], create_env_kwargs[0]) + self.meta_data = meta_data.expand( + *(self.num_workers, *meta_data.batch_size) ) - self.meta_data = meta_data.expand(*(self.num_workers, *meta_data.batch_size)) else: n_tasks = len(create_env_fn) self.meta_data = [] From 8e4f70b847309eca134074c8f845c4f6a4f4c759 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sat, 31 Dec 2022 15:27:47 +0100 Subject: [PATCH 3/5] Testing --- test/mocking_classes.py | 14 ++++++++++---- test/test_env.py | 16 +++++++++++++--- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 5f92e0da52c..757f12a98c7 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -7,6 +7,7 @@ import torch import torch.nn as nn from tensordict.tensordict import TensorDict, TensorDictBase + from torchrl.data.tensor_specs import ( BinaryDiscreteTensorSpec, BoundedTensorSpec, @@ -212,13 +213,16 @@ def _step(self, tensordict): self.counter += 1 # We use tensordict.batch_size instead of self.batch_size since this method will also be used by MockBatchedUnLockedEnv n = ( - torch.full(tensordict.batch_size, self.counter) + torch.full( + (*tensordict.batch_size, *self.observation_spec["observation"].shape), + self.counter, + ) .to(self.device) .to(torch.get_default_dtype()) ) done = self.counter >= self.max_val done = torch.full( - tensordict.batch_size, done, dtype=torch.bool, device=self.device + (*tensordict.batch_size, 1), done, dtype=torch.bool, device=self.device ) return TensorDict( @@ -235,12 +239,14 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: batch_size = tensordict.batch_size n = ( - torch.full(batch_size, self.counter) + torch.full( + (*batch_size, *self.observation_spec["observation"].shape), self.counter + ) .to(self.device) .to(torch.get_default_dtype()) ) done = self.counter >= self.max_val - done = torch.full(batch_size, done, dtype=torch.bool, device=self.device) + done = torch.full((*batch_size, 1), done, dtype=torch.bool, device=self.device) return TensorDict( {"reward": n, "done": done, "observation": n}, diff --git a/test/test_env.py b/test/test_env.py index a1e6a725e34..c77938c9ce7 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -11,6 +11,10 @@ import pytest import torch import yaml +from packaging import version +from tensordict.tensordict import assert_allclose_td, TensorDict +from torch import nn + from _utils_internal import ( CARTPOLE_VERSIONED, get_available_devices, @@ -27,9 +31,6 @@ MockBatchedUnLockedEnv, MockSerialEnv, ) -from packaging import version -from tensordict.tensordict import assert_allclose_td, TensorDict -from torch import nn from torchrl.data.tensor_specs import ( OneHotDiscreteTensorSpec, UnboundedContinuousTensorSpec, @@ -367,6 +368,15 @@ def test_mb_env_batch_lock(self, device, seed=0): class TestParallel: + @pytest.mark.parametrize("num_parallel_env", [1, 10]) + @pytest.mark.parametrize("env_batch_size", [[], (32,), (32, 1), (32, 0)]) + def test_env_with_batch_size(self, num_parallel_env, env_batch_size): + env = MockBatchedLockedEnv(device="cpu", batch_size=torch.Size(env_batch_size)) + env.set_seed(1) + parallel_env = ParallelEnv(num_parallel_env, lambda: env) + parallel_env.start() + assert parallel_env.batch_size == (num_parallel_env, *env_batch_size) + @pytest.mark.skipif(not _has_dmc, reason="no dm_control") @pytest.mark.parametrize("env_task", ["stand,stand,stand", "stand,walk,stand"]) @pytest.mark.parametrize("share_individual_td", [True, False]) From 26b974b0d99c8ed81c3327e0239a8c977a0b8ef6 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sat, 31 Dec 2022 15:35:11 +0100 Subject: [PATCH 4/5] Testing --- test/test_env.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_env.py b/test/test_env.py index c77938c9ce7..8c41077d322 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -376,6 +376,7 @@ def test_env_with_batch_size(self, num_parallel_env, env_batch_size): parallel_env = ParallelEnv(num_parallel_env, lambda: env) parallel_env.start() assert parallel_env.batch_size == (num_parallel_env, *env_batch_size) + parallel_env.close() @pytest.mark.skipif(not _has_dmc, reason="no dm_control") @pytest.mark.parametrize("env_task", ["stand,stand,stand", "stand,walk,stand"]) From aaf21703ad11abd526c445e4fe4b2b64502fdeb3 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sat, 31 Dec 2022 18:28:19 +0100 Subject: [PATCH 5/5] Linting --- test/test_env.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/test_env.py b/test/test_env.py index 8c41077d322..a36c88b6282 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -11,10 +11,6 @@ import pytest import torch import yaml -from packaging import version -from tensordict.tensordict import assert_allclose_td, TensorDict -from torch import nn - from _utils_internal import ( CARTPOLE_VERSIONED, get_available_devices, @@ -31,6 +27,9 @@ MockBatchedUnLockedEnv, MockSerialEnv, ) +from packaging import version +from tensordict.tensordict import assert_allclose_td, TensorDict +from torch import nn from torchrl.data.tensor_specs import ( OneHotDiscreteTensorSpec, UnboundedContinuousTensorSpec,