Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1729,7 +1729,7 @@ def test_reset_heterogeneous_envs(
cls = ParallelEnv
else:
cls = SerialEnv
env = cls(2, [env1, env2], device=env_device)
env = cls(2, [env1, env2], device=env_device, share_individual_td=True)
collector = SyncDataCollector(
env,
RandomPolicy(env.action_spec),
Expand Down
120 changes: 105 additions & 15 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,15 @@
UnboundedContinuousTensorSpec,
)
from torchrl.envs import (
CatFrames,
CatTensors,
DoubleToFloat,
EnvBase,
EnvCreator,
ParallelEnv,
SerialEnv,
)
from torchrl.envs.batched_envs import _stackable
from torchrl.envs.gym_like import default_info_dict_reader
from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv
from torchrl.envs.libs.gym import _has_gym, GymEnv, GymWrapper
Expand Down Expand Up @@ -498,19 +500,6 @@ def env_make():
lambda task=task: DMControlEnv("humanoid", task) for task in tasks
]

if not share_individual_td and not single_task:
with pytest.raises(
ValueError, match="share_individual_td must be set to None"
):
SerialEnv(3, env_make, share_individual_td=share_individual_td)
with pytest.raises(
ValueError, match="share_individual_td must be set to None"
):
maybe_fork_ParallelEnv(
3, env_make, share_individual_td=share_individual_td
)
return

env_serial = SerialEnv(3, env_make, share_individual_td=share_individual_td)
env_serial.start()
assert env_serial._single_task is single_task
Expand Down Expand Up @@ -2617,7 +2606,8 @@ def test_auto_cast_to_device(break_when_any_done):


@pytest.mark.parametrize("device", get_default_devices())
def test_backprop(device, maybe_fork_ParallelEnv):
@pytest.mark.parametrize("share_individual_td", [True, False])
def test_backprop(device, maybe_fork_ParallelEnv, share_individual_td):
# Tests that backprop through a series of single envs and through a serial env are identical
# Also tests that no backprop can be achieved with parallel env.
class DifferentiableEnv(EnvBase):
Expand Down Expand Up @@ -2677,8 +2667,14 @@ def make_env(seed, device=device):
2,
[functools.partial(make_env, seed=0), functools.partial(make_env, seed=seed)],
device=device,
share_individual_td=share_individual_td,
)
r_serial = serial_env.rollout(10, policy)
if share_individual_td:
r_serial = serial_env.rollout(10, policy)
else:
with pytest.raises(RuntimeError, match="Cannot update a view of a tensordict"):
r_serial = serial_env.rollout(10, policy)
return

g_serial = torch.autograd.grad(
r_serial["next", "reward"].sum(), policy.parameters()
Expand Down Expand Up @@ -2735,6 +2731,100 @@ def test_parallel_another_ctx():
pass


@pytest.mark.skipif(not _has_gym, reason="gym not found")
def test_single_task_share_individual_td():
    """Check how ``SerialEnv`` resolves the ``share_individual_td`` flag.

    Homogeneous constructors default to a dense shared parent tensordict,
    heterogeneous ones auto-enable individual (lazily stacked) storage, and
    explicitly forcing ``share_individual_td=False`` on non-stackable envs
    must raise.
    """
    cartpole = CARTPOLE_VERSIONED()

    def _check(env, *, individual, single_task):
        # Verify the resolved flags, then roll out to make sure the chosen
        # storage layout actually works end to end.
        if individual:
            assert env.share_individual_td
        else:
            assert not env.share_individual_td
        assert bool(env._single_task) is single_task
        env.rollout(2)
        expected_cls = LazyStackedTensorDict if individual else TensorDict
        assert isinstance(env.shared_tensordict_parent, expected_cls)

    # A single callable constructor: dense storage by default.
    _check(
        SerialEnv(2, lambda: GymEnv(cartpole)),
        individual=False,
        single_task=True,
    )
    _check(
        SerialEnv(2, lambda: GymEnv(cartpole), share_individual_td=True),
        individual=True,
        single_task=True,
    )

    # A list repeating the same constructor still counts as a single task.
    _check(
        SerialEnv(2, [lambda: GymEnv(cartpole)] * 2),
        individual=False,
        single_task=True,
    )
    _check(
        SerialEnv(2, [lambda: GymEnv(cartpole)] * 2, share_individual_td=True),
        individual=True,
        single_task=True,
    )

    # Distinct EnvCreator instances are treated as a multi-task setup.
    _check(
        SerialEnv(2, [EnvCreator(lambda: GymEnv(cartpole)) for _ in range(2)]),
        individual=False,
        single_task=False,
    )
    _check(
        SerialEnv(
            2,
            [EnvCreator(lambda: GymEnv(cartpole)) for _ in range(2)],
            share_individual_td=True,
        ),
        individual=True,
        single_task=False,
    )

    # Heterogeneous output shapes (CatFrames changes the observation shape):
    # share_individual_td is auto-enabled because results are non-stackable.
    _check(
        SerialEnv(
            2,
            [
                EnvCreator(lambda: GymEnv(cartpole)),
                EnvCreator(
                    lambda: TransformedEnv(
                        GymEnv(cartpole),
                        CatFrames(N=4, dim=-1, in_keys=["observation"]),
                    )
                ),
            ],
        ),
        individual=True,
        single_task=False,
    )

    # ...and explicitly requesting dense storage for those envs must fail.
    with pytest.raises(ValueError, match="share_individual_td=False"):
        SerialEnv(
            2,
            [
                EnvCreator(lambda: GymEnv(cartpole)),
                EnvCreator(
                    lambda: TransformedEnv(
                        GymEnv(cartpole),
                        CatFrames(N=4, dim=-1, in_keys=["observation"]),
                    )
                ),
            ],
            share_individual_td=False,
        )


def test_stackable():
    """Sanity-check the ``_stackable`` util on representative tensordict pairs."""
    # Each case: a pair of tensordicts and whether a dense stack is possible.
    cases = [
        # Disjoint keys cannot be densely stacked.
        ((TensorDict({"a": 0}), TensorDict({"b": 1})), False),
        # Same key but mismatched shapes (1-d vs scalar).
        ((TensorDict({"a": [0]}), TensorDict({"a": 1})), False),
        # Same key, same shape: stackable.
        ((TensorDict({"a": [0]}), TensorDict({"a": [1]})), True),
        # An extra *empty* sub-tensordict does not break stackability.
        ((TensorDict({"a": [0]}), TensorDict({"a": [1], "b": {}})), True),
        # Nested matching entries: stackable.
        ((TensorDict({"a": {"b": [0]}}), TensorDict({"a": {"b": [1]}})), True),
        # Nested shape mismatch: not stackable.
        ((TensorDict({"a": {"b": [0]}}), TensorDict({"a": {"b": 1}})), False),
        # Non-tensor (string) leaves are fine.
        ((TensorDict({"a": "a string"}), TensorDict({"a": "another string"})), True),
    ]
    for tds, expected in cases:
        assert bool(_stackable(*tds)) is expected


if __name__ == "__main__":
    # Forward any unrecognized CLI arguments straight to pytest.
    _, extra_args = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst", *extra_args])
9 changes: 8 additions & 1 deletion torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,13 @@ def __init__(
# we we did not receive an env device, we use the device of the env
self.env_device = self.env.device

# If the storing device is not the same as the policy device, we have
# no guarantee that the "next" entry from the policy will be on the
# same device as the collector metadata.
self._cast_to_env_device = self._cast_to_policy_device or (
self.env.device != self.storing_device
)

self.max_frames_per_traj = (
int(max_frames_per_traj) if max_frames_per_traj is not None else 0
)
Expand Down Expand Up @@ -923,7 +930,7 @@ def rollout(self) -> TensorDictBase:
policy_output, keys_to_update=self._policy_output_keys
)

if self._cast_to_policy_device:
if self._cast_to_env_device:
if self.env_device is not None:
env_input = self._shuttle.to(self.env_device, non_blocking=True)
elif self.env_device is None:
Expand Down
86 changes: 56 additions & 30 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,19 +295,11 @@ def __init__(
self._single_task = callable(create_env_fn) or (len(set(create_env_fn)) == 1)
if callable(create_env_fn):
create_env_fn = [create_env_fn for _ in range(num_workers)]
else:
if len(create_env_fn) != num_workers:
raise RuntimeError(
f"num_workers and len(create_env_fn) mismatch, "
f"got {len(create_env_fn)} and {num_workers}"
)
if (
share_individual_td is False and not self._single_task
): # then it has been explicitly set by the user
raise ValueError(
"share_individual_td must be set to None or True when using multi-task batched environments"
)
share_individual_td = True
elif len(create_env_fn) != num_workers:
raise RuntimeError(
f"num_workers and len(create_env_fn) mismatch, "
f"got {len(create_env_fn)} and {num_workers}"
)
create_env_kwargs = {} if create_env_kwargs is None else create_env_kwargs
if isinstance(create_env_kwargs, dict):
create_env_kwargs = [
Expand All @@ -322,7 +314,8 @@ def __init__(
if pin_memory:
raise ValueError("pin_memory for batched envs is deprecated")

self.share_individual_td = bool(share_individual_td)
# if share_individual_td is None, we will assess later if the output can be stacked
self.share_individual_td = share_individual_td
self._share_memory = shared_memory
self._memmap = memmap
self.allow_step_when_done = allow_step_when_done
Expand Down Expand Up @@ -365,13 +358,25 @@ def _get_metadata(
self.meta_data = meta_data.expand(
*(self.num_workers, *meta_data.batch_size)
)
if self.share_individual_td is None:
self.share_individual_td = False
else:
n_tasks = len(create_env_fn)
self.meta_data = []
for i in range(n_tasks):
self.meta_data.append(
get_env_metadata(create_env_fn[i], create_env_kwargs[i]).clone()
)
if self.share_individual_td is not True:
share_individual_td = not _stackable(
*[meta_data.tensordict for meta_data in self.meta_data]
)
if share_individual_td and self.share_individual_td is False:
raise ValueError(
"share_individual_td=False was provided but share_individual_td must "
"be True to accomodate non-stackable tensors."
)
self.share_individual_td = share_individual_td
self._set_properties()

def update_kwargs(self, kwargs: Union[dict, List[dict]]) -> None:
Expand Down Expand Up @@ -484,9 +489,14 @@ def map_device(key, value, device_map=device_map):
self.done_spec = output_spec["full_done_spec"]

self._dummy_env_str = str(meta_data[0])
self._env_tensordict = LazyStackedTensorDict.lazy_stack(
[meta_data.tensordict for meta_data in meta_data], 0
)
if self.share_individual_td:
self._env_tensordict = LazyStackedTensorDict.lazy_stack(
[meta_data.tensordict for meta_data in meta_data], 0
)
else:
self._env_tensordict = torch.stack(
[meta_data.tensordict for meta_data in meta_data], 0
)
self._batch_locked = meta_data[0].batch_locked
self.has_lazy_inputs = contains_lazy_spec(self.input_spec)

Expand All @@ -503,14 +513,11 @@ def load_state_dict(self, state_dict: OrderedDict) -> None:

def _create_td(self) -> None:
"""Creates self.shared_tensordict_parent, a TensorDict used to store the most recent observations."""
if self._single_task:
shared_tensordict_parent = self._env_tensordict.clone()
if not self._env_tensordict.shape[0] == self.num_workers:
raise RuntimeError(
"batched environment base tensordict has the wrong shape"
)
else:
shared_tensordict_parent = self._env_tensordict.clone()
shared_tensordict_parent = self._env_tensordict.clone()
if self._env_tensordict.shape[0] != self.num_workers:
raise RuntimeError(
"batched environment base tensordict has the wrong shape"
)

if self._single_task:
self._env_input_keys = sorted(
Expand All @@ -525,6 +532,7 @@ def _create_td(self) -> None:
self._env_obs_keys.append(key)
self._env_output_keys += self.reward_keys + self.done_keys
else:
# this is only possible if _single_task=False
env_input_keys = set()
for meta_data in self.meta_data:
if meta_data.specs["input_spec", "full_state_spec"] is not None:
Expand Down Expand Up @@ -577,7 +585,7 @@ def _create_td(self) -> None:
# output keys after step
self._selected_step_keys = {unravel_key(key) for key in self._env_output_keys}

if self._single_task:
if not self.share_individual_td:
shared_tensordict_parent = shared_tensordict_parent.select(
*self._selected_keys,
*(unravel_key(("next", key)) for key in self._env_output_keys),
Expand Down Expand Up @@ -807,10 +815,19 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
tensordict_ = None

_td = _env.reset(tensordict=tensordict_, **kwargs)
self.shared_tensordicts[i].update_(
_td,
keys_to_update=list(self._selected_reset_keys_filt),
)
try:
self.shared_tensordicts[i].update_(
_td,
keys_to_update=list(self._selected_reset_keys_filt),
)
except RuntimeError as err:
if "no_grad mode" in str(err):
raise RuntimeError(
"Cannot update a view of a tensordict when gradients are required. "
"To collect gradient across sub-environments, please set the "
"share_individual_td argument to True."
)
raise
selected_output_keys = self._selected_reset_keys_filt
device = self.device

Expand Down Expand Up @@ -1703,5 +1720,14 @@ def _filter_empty(tensordict):
return tensordict.select(*tensordict.keys(True, True))


def _stackable(*tensordicts):
    """Return ``True`` if the tensordicts can be densely stacked along dim 0.

    The check builds a lazy stack and attempts to make it contiguous; a
    ``RuntimeError`` during that process, or the presence of keys exclusive to
    a subset of the inputs, means a dense stack is not possible.
    """
    try:
        lazy_stack = LazyStackedTensorDict(*tensordicts, stack_dim=0)
        lazy_stack.contiguous()
        stackable = not lazy_stack._has_exclusive_keys
    except RuntimeError:
        # Shape/dtype mismatches surface here when densifying the stack.
        return False
    return stackable


# Alias kept so external code can still import the class under the old
# ``_BatchedEnv`` name — presumably for backward compatibility; do not remove.
_BatchedEnv = BatchedEnvBase
1 change: 0 additions & 1 deletion torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2080,7 +2080,6 @@ def reset(
raise RuntimeError(
f"env._reset returned an object of type {type(tensordict_reset)} but a TensorDict was expected."
)

return self._reset_proc_data(tensordict, tensordict_reset)

def _reset_proc_data(self, tensordict, tensordict_reset):
Expand Down