From 563fbef6c398e582ec13b8290377091bafeff197 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 8 Dec 2022 08:29:09 +0000 Subject: [PATCH 1/4] init --- docs/source/reference/envs.rst | 2 ++ test/_utils_internal.py | 40 -------------------------- test/test_libs.py | 10 +++---- torchrl/envs/utils.py | 51 ++++++++++++++++++++++++++++++++++ 4 files changed, 58 insertions(+), 45 deletions(-) diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 748f65d6b68..2d3bc55d754 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -210,6 +210,7 @@ in the environment. The keys to be included in this inverse transform are passed TensorDictPrimer R3MTransform VIPTransform + VIPRewardTransform Helpers ------- @@ -223,6 +224,7 @@ Helpers get_available_libraries set_exploration_mode exploration_mode + test_fake_tensordict Domain-specific --------------- diff --git a/test/_utils_internal.py b/test/_utils_internal.py index a11bbf4ab1d..a9a3db889f7 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -12,7 +12,6 @@ import pytest import torch.cuda from torchrl._utils import implement_for, seed_generator -from torchrl.envs import EnvBase from torchrl.envs.libs.gym import _has_gym # Specified for test_utils.py @@ -70,45 +69,6 @@ def generate_seeds(seed, repeat): return seeds -def _test_fake_tensordict(env: EnvBase): - fake_tensordict = env.fake_tensordict().flatten_keys(".") - real_tensordict = env.rollout(3).flatten_keys(".") - - keys1 = set(fake_tensordict.keys()) - keys2 = set(real_tensordict.keys()) - assert keys1 == keys2 - fake_tensordict = fake_tensordict.unsqueeze(real_tensordict.batch_dims - 1) - fake_tensordict = fake_tensordict.expand(*real_tensordict.shape) - fake_tensordict = fake_tensordict.to_tensordict() - assert ( - fake_tensordict.apply(lambda x: torch.zeros_like(x)) - == real_tensordict.apply(lambda x: torch.zeros_like(x)) - ).all() - for key in keys2: - assert fake_tensordict[key].shape == 
real_tensordict[key].shape - - # test dtypes - for key, value in real_tensordict.unflatten_keys(".").items(): - _check_dtype(key, value, env.observation_spec, env.input_spec) - - -def _check_dtype(key, value, obs_spec, input_spec): - if key in {"reward", "done"}: - return - elif key == "next": - for _key, _value in value.items(): - _check_dtype(_key, _value, obs_spec, input_spec) - return - elif key in input_spec.keys(yield_nesting_keys=True): - assert input_spec[key].is_in(value), (input_spec[key], value) - return - elif key in obs_spec.keys(yield_nesting_keys=True): - assert obs_spec[key].is_in(value), (input_spec[key], value) - return - else: - raise KeyError(key) - - # Decorator to retry upon certain Exceptions. def retry(ExceptionToCheck, tries=3, delay=3, skip_after_retries=False): def deco_retry(f): diff --git a/test/test_libs.py b/test/test_libs.py index 4e4b4b811b2..3ca2a58596e 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -9,7 +9,6 @@ import pytest import torch from _utils_internal import ( - _test_fake_tensordict, get_available_devices, HALFCHEETAH_VERSIONED, PENDULUM_VERSIONED, @@ -25,6 +24,7 @@ from torchrl.envs.libs.gym import _has_gym, _is_from_pixels, GymEnv, GymWrapper from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv +from torchrl.envs.utils import test_fake_tensordict if _has_gym: import gym @@ -136,7 +136,7 @@ def test_gym_fake_td(self, env_name, frame_skip, from_pixels, pixels_only): from_pixels=from_pixels, pixels_only=pixels_only, ) - _test_fake_tensordict(env) + test_fake_tensordict(env) @implement_for("gym", None, "0.26") @@ -243,7 +243,7 @@ def test_faketd(self, env_name, task, frame_skip, from_pixels, pixels_only): from_pixels=from_pixels, pixels_only=pixels_only, ) - _test_fake_tensordict(env) + test_fake_tensordict(env) @pytest.mark.skipif( @@ -337,7 +337,7 @@ class TestHabitat: def test_habitat(self, envname): env = HabitatEnv(envname) rollout = 
env.rollout(3) - _test_fake_tensordict(env) + test_fake_tensordict(env) @pytest.mark.skipif(not _has_jumanji, reason="jumanji not installed") @@ -375,7 +375,7 @@ def test_jumanji_batch_size(self, envname, batch_size): def test_jumanji_spec_rollout(self, envname, batch_size): env = JumanjiEnv(envname, batch_size=batch_size) env.set_seed(0) - _test_fake_tensordict(env) + test_fake_tensordict(env) @pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)]) def test_jumanji_consistency(self, envname, batch_size): diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index ef86b45b0c4..220c78abfd9 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -4,11 +4,13 @@ # LICENSE file in the root directory of this source tree. import pkg_resources +import torch from tensordict.nn.probabilistic import ( # noqa interaction_mode as exploration_mode, set_interaction_mode as set_exploration_mode, ) from tensordict.tensordict import TensorDictBase +from torchrl.envs import EnvBase AVAILABLE_LIBRARIES = {pkg.key for pkg in pkg_resources.working_set} @@ -151,3 +153,52 @@ def _check_dmlab(): # "screeps": None, # https://github.com/screeps/screeps # "ml-agents": None, } + + +def test_fake_tensordict(env: EnvBase): + """Tests an environment specs against the results of short rollout. + + This test function should be used as a sanity check for an env wrapped with + torchrl's EnvBase subclasses: any discrepancy between the expected data and + the data collected should raise an assertion error. + + A broken environment spec will likely make it impossible to use parallel + environments. 
+ + """ + fake_tensordict = env.fake_tensordict().flatten_keys(".") + real_tensordict = env.rollout(3).flatten_keys(".") + + keys1 = set(fake_tensordict.keys()) + keys2 = set(real_tensordict.keys()) + assert keys1 == keys2 + fake_tensordict = fake_tensordict.unsqueeze(real_tensordict.batch_dims - 1) + fake_tensordict = fake_tensordict.expand(*real_tensordict.shape) + fake_tensordict = fake_tensordict.to_tensordict() + assert ( + fake_tensordict.apply(lambda x: torch.zeros_like(x)) + == real_tensordict.apply(lambda x: torch.zeros_like(x)) + ).all() + for key in keys2: + assert fake_tensordict[key].shape == real_tensordict[key].shape + + # test dtypes + for key, value in real_tensordict.unflatten_keys(".").items(): + _check_dtype(key, value, env.observation_spec, env.input_spec) + + +def _check_dtype(key, value, obs_spec, input_spec): + if key in {"reward", "done"}: + return + elif key == "next": + for _key, _value in value.items(): + _check_dtype(_key, _value, obs_spec, input_spec) + return + elif key in input_spec.keys(yield_nesting_keys=True): + assert input_spec[key].is_in(value), (input_spec[key], value) + return + elif key in obs_spec.keys(yield_nesting_keys=True): + assert obs_spec[key].is_in(value), (input_spec[key], value) + return + else: + raise KeyError(key) From caa2782817a3f710d5b3c70225486105e7aa4b02 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 8 Dec 2022 08:39:15 +0000 Subject: [PATCH 2/4] amend --- docs/source/reference/envs.rst | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 2d3bc55d754..a50b3e32fb8 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -111,12 +111,20 @@ It is also possible to reset some but not all of the environments: is_shared=True) -A note on performance: launching a :obj:`ParallelEnv` can take quite some time +*A note on performance*: launching a :obj:`ParallelEnv` can take quite some time as it 
requires to launch as many python instances as there are processes. Due to the time that it takes to run :obj:`import torch` (and other imports), starting the parallel env can be a bottleneck. This is why, for instance, TorchRL tests are so slow. Once the environment is launched, a great speedup should be observed. +Another thing to take into consideration is that :obj:`ParallelEnv` instances (as well as data collectors) +will create data buffers based on the environment specs to pass data from one process +to another. This means that a misspecified spec (input, observation or reward) will +cause a breakage at runtime as the data can't be written on the preallocated buffer. +In general, an environment should be tested using the :obj:`test_fake_tensordict` +test function before being used in a :obj:`ParallelEnv`. This function will raise +an assertion error whenever the preallocated buffer and the collected data mismatch. + We also offer the :obj:`SerialEnv` class that enjoys the exact same API but is executed serially. This is mostly useful for testing purposes, when one wants to assess the behaviour of a :obj:`ParallelEnv` without launching the subprocesses. From fd224efc3f9df035d4075aac939fa9c1f825f5cb Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 8 Dec 2022 09:46:31 +0000 Subject: [PATCH 3/4] amend --- torchrl/envs/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 220c78abfd9..e70b88605ed 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -10,7 +10,6 @@ set_interaction_mode as set_exploration_mode, ) from tensordict.tensordict import TensorDictBase -from torchrl.envs import EnvBase AVAILABLE_LIBRARIES = {pkg.key for pkg in pkg_resources.working_set} @@ -155,7 +154,7 @@ def _check_dmlab(): } -def test_fake_tensordict(env: EnvBase): +def test_fake_tensordict(env): """Tests an environment specs against the results of short rollout. 
This test function should be used as a sanity check for an env wrapped with From 70f4c4dd4f25e90ac3f17773c2724c9b2c24641b Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 8 Dec 2022 10:17:48 +0000 Subject: [PATCH 4/4] renaming to make pytest happy --- docs/source/reference/envs.rst | 4 ++-- test/test_libs.py | 10 +++++----- torchrl/envs/utils.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index a50b3e32fb8..1b8bfdeb240 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -121,7 +121,7 @@ Another thing to take in consideration is that :obj:`ParallelEnv`s (as well as d will create data buffers based on the environment specs to pass data from one process to another. This means that a misspecified spec (input, observation or reward) will cause a breakage at runtime as the data can't be written on the preallocated buffer. -In general, an environment should be tested using the :obj:`test_fake_tensordict` +In general, an environment should be tested using the :obj:`check_env_specs` test function before being used in a :obj:`ParallelEnv`. This function will raise an assertion error whenever the preallocated buffer and the collected data mismatch. 
@@ -232,7 +232,7 @@ Helpers get_available_libraries set_exploration_mode exploration_mode - test_fake_tensordict + check_env_specs Domain-specific --------------- diff --git a/test/test_libs.py b/test/test_libs.py index 3ca2a58596e..a655c33e7d7 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -24,7 +24,7 @@ from torchrl.envs.libs.gym import _has_gym, _is_from_pixels, GymEnv, GymWrapper from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv -from torchrl.envs.utils import test_fake_tensordict +from torchrl.envs.utils import check_env_specs if _has_gym: import gym @@ -136,7 +136,7 @@ def test_gym_fake_td(self, env_name, frame_skip, from_pixels, pixels_only): from_pixels=from_pixels, pixels_only=pixels_only, ) - test_fake_tensordict(env) + check_env_specs(env) @implement_for("gym", None, "0.26") @@ -243,7 +243,7 @@ def test_faketd(self, env_name, task, frame_skip, from_pixels, pixels_only): from_pixels=from_pixels, pixels_only=pixels_only, ) - test_fake_tensordict(env) + check_env_specs(env) @pytest.mark.skipif( @@ -337,7 +337,7 @@ class TestHabitat: def test_habitat(self, envname): env = HabitatEnv(envname) rollout = env.rollout(3) - test_fake_tensordict(env) + check_env_specs(env) @pytest.mark.skipif(not _has_jumanji, reason="jumanji not installed") @@ -375,7 +375,7 @@ def test_jumanji_batch_size(self, envname, batch_size): def test_jumanji_spec_rollout(self, envname, batch_size): env = JumanjiEnv(envname, batch_size=batch_size) env.set_seed(0) - test_fake_tensordict(env) + check_env_specs(env) @pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)]) def test_jumanji_consistency(self, envname, batch_size): diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index e70b88605ed..f525ff43d05 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -154,7 +154,7 @@ def _check_dmlab(): } -def test_fake_tensordict(env): +def check_env_specs(env): """Tests an environment specs 
against the results of short rollout. This test function should be used as a sanity check for an env wrapped with