diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst
index 748f65d6b68..1b8bfdeb240 100644
--- a/docs/source/reference/envs.rst
+++ b/docs/source/reference/envs.rst
@@ -111,12 +111,20 @@ It is also possible to reset some but not all of the environments:
         is_shared=True)
 
-A note on performance: launching a :obj:`ParallelEnv` can take quite some time
+*A note on performance*: launching a :obj:`ParallelEnv` can take quite some time
 as it requires to launch as many python instances as there are processes. Due to
 the time that it takes to run :obj:`import torch` (and other imports), starting the
 parallel env can be a bottleneck. This is why, for instance, TorchRL tests are so slow.
 Once the environment is launched, a great speedup should be observed.
 
+Another thing to take into consideration is that :obj:`ParallelEnv` instances (as well as data collectors)
+will create data buffers based on the environment specs to pass data from one process
+to another. This means that a misspecified spec (input, observation or reward) will
+cause a breakage at runtime, as the data can't be written to the preallocated buffer.
+In general, an environment should be tested with the :obj:`check_env_specs`
+function before being used in a :obj:`ParallelEnv`. This function will raise
+an assertion error whenever the preallocated buffer and the collected data do not match.
+
 We also offer the :obj:`SerialEnv` class that enjoys the exact same API but is executed
 serially. This is mostly useful for testing purposes, when one wants to assess the
 behaviour of a :obj:`ParallelEnv` without launching the subprocesses.
 
@@ -210,6 +218,7 @@ in the environment. The keys to be included in this inverse transform are passed
     TensorDictPrimer
     R3MTransform
     VIPTransform
+    VIPRewardTransform
 
 Helpers
 -------
@@ -223,6 +232,7 @@ Helpers
     get_available_libraries
     set_exploration_mode
     exploration_mode
+    check_env_specs
 
 Domain-specific
 ---------------
diff --git a/test/_utils_internal.py b/test/_utils_internal.py
index a11bbf4ab1d..a9a3db889f7 100644
--- a/test/_utils_internal.py
+++ b/test/_utils_internal.py
@@ -12,7 +12,6 @@
 import pytest
 import torch.cuda
 from torchrl._utils import implement_for, seed_generator
-from torchrl.envs import EnvBase
 from torchrl.envs.libs.gym import _has_gym
 
 # Specified for test_utils.py
@@ -70,45 +69,6 @@ def generate_seeds(seed, repeat):
     return seeds
 
 
-def _test_fake_tensordict(env: EnvBase):
-    fake_tensordict = env.fake_tensordict().flatten_keys(".")
-    real_tensordict = env.rollout(3).flatten_keys(".")
-
-    keys1 = set(fake_tensordict.keys())
-    keys2 = set(real_tensordict.keys())
-    assert keys1 == keys2
-    fake_tensordict = fake_tensordict.unsqueeze(real_tensordict.batch_dims - 1)
-    fake_tensordict = fake_tensordict.expand(*real_tensordict.shape)
-    fake_tensordict = fake_tensordict.to_tensordict()
-    assert (
-        fake_tensordict.apply(lambda x: torch.zeros_like(x))
-        == real_tensordict.apply(lambda x: torch.zeros_like(x))
-    ).all()
-    for key in keys2:
-        assert fake_tensordict[key].shape == real_tensordict[key].shape
-
-    # test dtypes
-    for key, value in real_tensordict.unflatten_keys(".").items():
-        _check_dtype(key, value, env.observation_spec, env.input_spec)
-
-
-def _check_dtype(key, value, obs_spec, input_spec):
-    if key in {"reward", "done"}:
-        return
-    elif key == "next":
-        for _key, _value in value.items():
-            _check_dtype(_key, _value, obs_spec, input_spec)
-        return
-    elif key in input_spec.keys(yield_nesting_keys=True):
-        assert input_spec[key].is_in(value), (input_spec[key], value)
-        return
-    elif key in obs_spec.keys(yield_nesting_keys=True):
-        assert obs_spec[key].is_in(value), (input_spec[key], value)
-        return
-    else:
-        raise KeyError(key)
-
-
 # Decorator to retry upon certain Exceptions.
 def retry(ExceptionToCheck, tries=3, delay=3, skip_after_retries=False):
     def deco_retry(f):
diff --git a/test/test_libs.py b/test/test_libs.py
index 4e4b4b811b2..a655c33e7d7 100644
--- a/test/test_libs.py
+++ b/test/test_libs.py
@@ -9,7 +9,6 @@
 import pytest
 import torch
 from _utils_internal import (
-    _test_fake_tensordict,
     get_available_devices,
     HALFCHEETAH_VERSIONED,
     PENDULUM_VERSIONED,
@@ -25,6 +24,7 @@
 from torchrl.envs.libs.gym import _has_gym, _is_from_pixels, GymEnv, GymWrapper
 from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv
 from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv
+from torchrl.envs.utils import check_env_specs
 
 if _has_gym:
     import gym
@@ -136,7 +136,7 @@ def test_gym_fake_td(self, env_name, frame_skip, from_pixels, pixels_only):
             from_pixels=from_pixels,
             pixels_only=pixels_only,
         )
-        _test_fake_tensordict(env)
+        check_env_specs(env)
 
 
 @implement_for("gym", None, "0.26")
@@ -243,7 +243,7 @@ def test_faketd(self, env_name, task, frame_skip, from_pixels, pixels_only):
             from_pixels=from_pixels,
             pixels_only=pixels_only,
         )
-        _test_fake_tensordict(env)
+        check_env_specs(env)
 
 
 @pytest.mark.skipif(
@@ -337,7 +337,7 @@ class TestHabitat:
     def test_habitat(self, envname):
         env = HabitatEnv(envname)
         rollout = env.rollout(3)
-        _test_fake_tensordict(env)
+        check_env_specs(env)
 
 
 @pytest.mark.skipif(not _has_jumanji, reason="jumanji not installed")
@@ -375,7 +375,7 @@ def test_jumanji_batch_size(self, envname, batch_size):
     def test_jumanji_spec_rollout(self, envname, batch_size):
         env = JumanjiEnv(envname, batch_size=batch_size)
         env.set_seed(0)
-        _test_fake_tensordict(env)
+        check_env_specs(env)
 
     @pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)])
     def test_jumanji_consistency(self, envname, batch_size):
diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py
index ef86b45b0c4..f525ff43d05 100644
--- a/torchrl/envs/utils.py
+++ b/torchrl/envs/utils.py
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import pkg_resources
+import torch
 from tensordict.nn.probabilistic import (  # noqa
     interaction_mode as exploration_mode,
     set_interaction_mode as set_exploration_mode,
@@ -151,3 +152,52 @@ def _check_dmlab():
 #     "screeps": None,  # https://github.com/screeps/screeps
 #     "ml-agents": None,
 }
+
+
+def check_env_specs(env):
+    """Tests an environment's specs against the results of a short rollout.
+
+    This test function should be used as a sanity check for an env wrapped in
+    torchrl's EnvBase subclasses: any discrepancy between the expected data and
+    the data collected should raise an assertion error.
+
+    A broken environment spec will likely make it impossible to use parallel
+    environments.
+
+    """
+    fake_tensordict = env.fake_tensordict().flatten_keys(".")
+    real_tensordict = env.rollout(3).flatten_keys(".")
+
+    keys1 = set(fake_tensordict.keys())
+    keys2 = set(real_tensordict.keys())
+    assert keys1 == keys2
+    fake_tensordict = fake_tensordict.unsqueeze(real_tensordict.batch_dims - 1)
+    fake_tensordict = fake_tensordict.expand(*real_tensordict.shape)
+    fake_tensordict = fake_tensordict.to_tensordict()
+    assert (
+        fake_tensordict.apply(lambda x: torch.zeros_like(x))
+        == real_tensordict.apply(lambda x: torch.zeros_like(x))
+    ).all()
+    for key in keys2:
+        assert fake_tensordict[key].shape == real_tensordict[key].shape
+
+    # check dtypes and values against the declared specs
+    for key, value in real_tensordict.unflatten_keys(".").items():
+        _check_dtype(key, value, env.observation_spec, env.input_spec)
+
+
+def _check_dtype(key, value, obs_spec, input_spec):
+    if key in {"reward", "done"}:
+        return
+    elif key == "next":
+        for _key, _value in value.items():
+            _check_dtype(_key, _value, obs_spec, input_spec)
+        return
+    elif key in input_spec.keys(yield_nesting_keys=True):
+        assert input_spec[key].is_in(value), (input_spec[key], value)
+        return
+    elif key in obs_spec.keys(yield_nesting_keys=True):
+        assert obs_spec[key].is_in(value), (obs_spec[key], value)
+        return
+    else:
+        raise KeyError(key)
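For illustration, here is a minimal usage sketch of the check_env_specs helper introduced in this patch, run before wrapping an environment in a ParallelEnv. The gym dependency and the "Pendulum-v1" environment name are assumptions made for the example; any EnvBase subclass can be checked the same way:

    from torchrl.envs import ParallelEnv
    from torchrl.envs.libs.gym import GymEnv
    from torchrl.envs.utils import check_env_specs

    # Validate a single instance first: any mismatch between the declared
    # specs and the data gathered in a short rollout raises an AssertionError.
    env = GymEnv("Pendulum-v1")
    check_env_specs(env)

    # With the specs validated, the buffers that ParallelEnv preallocates to
    # pass data across processes will match what the workers actually produce.
    parallel_env = ParallelEnv(2, lambda: GymEnv("Pendulum-v1"))
    rollout = parallel_env.rollout(3)

If the check fails, the environment's observation, input or reward spec should be corrected before attempting any parallel execution.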