12 changes: 11 additions & 1 deletion docs/source/reference/envs.rst
@@ -111,12 +111,20 @@ It is also possible to reset some but not all of the environments:
is_shared=True)


A note on performance: launching a :obj:`ParallelEnv` can take quite some time
*A note on performance*: launching a :obj:`ParallelEnv` can take quite some time,
as it requires launching as many Python interpreters as there are processes. Given
how long it takes to run :obj:`import torch` (and the other imports), starting the
parallel environment can be a bottleneck. This is why, for instance, TorchRL tests are so slow.
Once the environment is launched, a great speedup should be observed.

Another thing to take into consideration is that :obj:`ParallelEnv` instances (as well as data collectors)
create data buffers based on the environment specs to pass data from one process
to another. A misspecified spec (input, observation or reward) will therefore
cause a breakage at runtime, as the data cannot be written to the preallocated buffer.
In general, an environment should be validated with the :obj:`check_env_specs`
function before being used in a :obj:`ParallelEnv`. This function raises
an assertion error whenever the preallocated buffer and the collected data mismatch.
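
A minimal sketch of this workflow (the environment name, the worker count and the
use of a Gym backend are illustrative assumptions)::

    from torchrl.envs import ParallelEnv
    from torchrl.envs.libs.gym import GymEnv
    from torchrl.envs.utils import check_env_specs

    def make_env():
        return GymEnv("Pendulum-v1")

    # validate the specs on a single instance first: this raises an
    # AssertionError if the spec-based fake data and a short rollout disagree
    check_env_specs(make_env())

    # only then pay the cost of spawning the worker processes
    parallel_env = ParallelEnv(4, make_env)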

We also offer the :obj:`SerialEnv` class that enjoys the exact same API but is executed
serially. This is mostly useful for testing purposes, when one wants to assess the
behaviour of a :obj:`ParallelEnv` without launching the subprocesses.
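
For instance, reusing the illustrative :obj:`make_env` helper from the sketch above::

    from torchrl.envs import SerialEnv

    # same constructor arguments as ParallelEnv, but no subprocess is spawned:
    # the environments are stepped one after the other in the main process
    serial_env = SerialEnv(4, make_env)
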
@@ -210,6 +218,7 @@ in the environment. The keys to be included in this inverse transform are passed
TensorDictPrimer
R3MTransform
VIPTransform
VIPRewardTransform

Helpers
-------
@@ -223,6 +232,7 @@ Helpers
get_available_libraries
set_exploration_mode
exploration_mode
check_env_specs

Domain-specific
---------------
40 changes: 0 additions & 40 deletions test/_utils_internal.py
@@ -12,7 +12,6 @@
import pytest
import torch.cuda
from torchrl._utils import implement_for, seed_generator
from torchrl.envs import EnvBase
from torchrl.envs.libs.gym import _has_gym

# Specified for test_utils.py
@@ -70,45 +69,6 @@ def generate_seeds(seed, repeat):
return seeds


def _test_fake_tensordict(env: EnvBase):
fake_tensordict = env.fake_tensordict().flatten_keys(".")
real_tensordict = env.rollout(3).flatten_keys(".")

keys1 = set(fake_tensordict.keys())
keys2 = set(real_tensordict.keys())
assert keys1 == keys2
fake_tensordict = fake_tensordict.unsqueeze(real_tensordict.batch_dims - 1)
fake_tensordict = fake_tensordict.expand(*real_tensordict.shape)
fake_tensordict = fake_tensordict.to_tensordict()
assert (
fake_tensordict.apply(lambda x: torch.zeros_like(x))
== real_tensordict.apply(lambda x: torch.zeros_like(x))
).all()
for key in keys2:
assert fake_tensordict[key].shape == real_tensordict[key].shape

# test dtypes
for key, value in real_tensordict.unflatten_keys(".").items():
_check_dtype(key, value, env.observation_spec, env.input_spec)


def _check_dtype(key, value, obs_spec, input_spec):
if key in {"reward", "done"}:
return
elif key == "next":
for _key, _value in value.items():
_check_dtype(_key, _value, obs_spec, input_spec)
return
elif key in input_spec.keys(yield_nesting_keys=True):
assert input_spec[key].is_in(value), (input_spec[key], value)
return
elif key in obs_spec.keys(yield_nesting_keys=True):
assert obs_spec[key].is_in(value), (input_spec[key], value)
return
else:
raise KeyError(key)


# Decorator to retry upon certain Exceptions.
def retry(ExceptionToCheck, tries=3, delay=3, skip_after_retries=False):
def deco_retry(f):
10 changes: 5 additions & 5 deletions test/test_libs.py
@@ -9,7 +9,6 @@
import pytest
import torch
from _utils_internal import (
_test_fake_tensordict,
get_available_devices,
HALFCHEETAH_VERSIONED,
PENDULUM_VERSIONED,
@@ -25,6 +24,7 @@
from torchrl.envs.libs.gym import _has_gym, _is_from_pixels, GymEnv, GymWrapper
from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv
from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv
from torchrl.envs.utils import check_env_specs

if _has_gym:
import gym
@@ -136,7 +136,7 @@ def test_gym_fake_td(self, env_name, frame_skip, from_pixels, pixels_only):
from_pixels=from_pixels,
pixels_only=pixels_only,
)
_test_fake_tensordict(env)
check_env_specs(env)


@implement_for("gym", None, "0.26")
@@ -243,7 +243,7 @@ def test_faketd(self, env_name, task, frame_skip, from_pixels, pixels_only):
from_pixels=from_pixels,
pixels_only=pixels_only,
)
_test_fake_tensordict(env)
check_env_specs(env)


@pytest.mark.skipif(
@@ -337,7 +337,7 @@ class TestHabitat:
def test_habitat(self, envname):
env = HabitatEnv(envname)
rollout = env.rollout(3)
_test_fake_tensordict(env)
check_env_specs(env)


@pytest.mark.skipif(not _has_jumanji, reason="jumanji not installed")
@@ -375,7 +375,7 @@ def test_jumanji_batch_size(self, envname, batch_size):
def test_jumanji_spec_rollout(self, envname, batch_size):
env = JumanjiEnv(envname, batch_size=batch_size)
env.set_seed(0)
_test_fake_tensordict(env)
check_env_specs(env)

@pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)])
def test_jumanji_consistency(self, envname, batch_size):
50 changes: 50 additions & 0 deletions torchrl/envs/utils.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import pkg_resources
import torch
from tensordict.nn.probabilistic import ( # noqa
interaction_mode as exploration_mode,
set_interaction_mode as set_exploration_mode,
@@ -151,3 +152,52 @@ def _check_dmlab():
# "screeps": None, # https://github.com/screeps/screeps
# "ml-agents": None,
}


def check_env_specs(env):
"""Tests an environment specs against the results of short rollout.

This test function should be used as a sanity check for an env wrapped with
torchrl's EnvBase subclasses: any discrepency between the expected data and
the data collected should raise an assertion error.

A broken environment spec will likely make it impossible to use parallel
environments.

"""
fake_tensordict = env.fake_tensordict().flatten_keys(".")
real_tensordict = env.rollout(3).flatten_keys(".")

keys1 = set(fake_tensordict.keys())
keys2 = set(real_tensordict.keys())
assert keys1 == keys2
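# broadcast the spec-based tensordict to the rollout batch shape so that
# per-entry shapes can be compared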
fake_tensordict = fake_tensordict.unsqueeze(real_tensordict.batch_dims - 1)
fake_tensordict = fake_tensordict.expand(*real_tensordict.shape)
fake_tensordict = fake_tensordict.to_tensordict()
assert (
fake_tensordict.apply(lambda x: torch.zeros_like(x))
== real_tensordict.apply(lambda x: torch.zeros_like(x))
).all()
for key in keys2:
assert fake_tensordict[key].shape == real_tensordict[key].shape

# test dtypes
for key, value in real_tensordict.unflatten_keys(".").items():
_check_dtype(key, value, env.observation_spec, env.input_spec)


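# recursively assert that every collected entry lies within the corresponding
# input or observation spec (reward and done are skipped here)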
def _check_dtype(key, value, obs_spec, input_spec):
if key in {"reward", "done"}:
return
elif key == "next":
for _key, _value in value.items():
_check_dtype(_key, _value, obs_spec, input_spec)
return
elif key in input_spec.keys(yield_nesting_keys=True):
assert input_spec[key].is_in(value), (input_spec[key], value)
return
elif key in obs_spec.keys(yield_nesting_keys=True):
assert obs_spec[key].is_in(value), (obs_spec[key], value)
return
else:
raise KeyError(key)