diff --git a/test/test_collector.py b/test/test_collector.py index 58dcfd8bcbd..3720f0f8349 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1729,7 +1729,7 @@ def test_reset_heterogeneous_envs( cls = ParallelEnv else: cls = SerialEnv - env = cls(2, [env1, env2], device=env_device) + env = cls(2, [env1, env2], device=env_device, share_individual_td=True) collector = SyncDataCollector( env, RandomPolicy(env.action_spec), diff --git a/test/test_env.py b/test/test_env.py index 9ac9427beed..048f7812903 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -67,6 +67,7 @@ UnboundedContinuousTensorSpec, ) from torchrl.envs import ( + CatFrames, CatTensors, DoubleToFloat, EnvBase, @@ -74,6 +75,7 @@ ParallelEnv, SerialEnv, ) +from torchrl.envs.batched_envs import _stackable from torchrl.envs.gym_like import default_info_dict_reader from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv from torchrl.envs.libs.gym import _has_gym, GymEnv, GymWrapper @@ -498,19 +500,6 @@ def env_make(): lambda task=task: DMControlEnv("humanoid", task) for task in tasks ] - if not share_individual_td and not single_task: - with pytest.raises( - ValueError, match="share_individual_td must be set to None" - ): - SerialEnv(3, env_make, share_individual_td=share_individual_td) - with pytest.raises( - ValueError, match="share_individual_td must be set to None" - ): - maybe_fork_ParallelEnv( - 3, env_make, share_individual_td=share_individual_td - ) - return - env_serial = SerialEnv(3, env_make, share_individual_td=share_individual_td) env_serial.start() assert env_serial._single_task is single_task @@ -2617,7 +2606,8 @@ def test_auto_cast_to_device(break_when_any_done): @pytest.mark.parametrize("device", get_default_devices()) -def test_backprop(device, maybe_fork_ParallelEnv): +@pytest.mark.parametrize("share_individual_td", [True, False]) +def test_backprop(device, maybe_fork_ParallelEnv, share_individual_td): # Tests that backprop through a series of single envs and 
through a serial env are identical # Also tests that no backprop can be achieved with parallel env. class DifferentiableEnv(EnvBase): @@ -2677,8 +2667,14 @@ def make_env(seed, device=device): 2, [functools.partial(make_env, seed=0), functools.partial(make_env, seed=seed)], device=device, + share_individual_td=share_individual_td, ) - r_serial = serial_env.rollout(10, policy) + if share_individual_td: + r_serial = serial_env.rollout(10, policy) + else: + with pytest.raises(RuntimeError, match="Cannot update a view of a tensordict"): + r_serial = serial_env.rollout(10, policy) + return g_serial = torch.autograd.grad( r_serial["next", "reward"].sum(), policy.parameters() @@ -2735,6 +2731,100 @@ def test_parallel_another_ctx(): pass +@pytest.mark.skipif(not _has_gym, reason="gym not found") +def test_single_task_share_individual_td(): + cartpole = CARTPOLE_VERSIONED() + env = SerialEnv(2, lambda: GymEnv(cartpole)) + assert not env.share_individual_td + assert env._single_task + env.rollout(2) + assert isinstance(env.shared_tensordict_parent, TensorDict) + + env = SerialEnv(2, lambda: GymEnv(cartpole), share_individual_td=True) + assert env.share_individual_td + assert env._single_task + env.rollout(2) + assert isinstance(env.shared_tensordict_parent, LazyStackedTensorDict) + + env = SerialEnv(2, [lambda: GymEnv(cartpole)] * 2) + assert not env.share_individual_td + assert env._single_task + env.rollout(2) + assert isinstance(env.shared_tensordict_parent, TensorDict) + + env = SerialEnv(2, [lambda: GymEnv(cartpole)] * 2, share_individual_td=True) + assert env.share_individual_td + assert env._single_task + env.rollout(2) + assert isinstance(env.shared_tensordict_parent, LazyStackedTensorDict) + + env = SerialEnv(2, [EnvCreator(lambda: GymEnv(cartpole)) for _ in range(2)]) + assert not env.share_individual_td + assert not env._single_task + env.rollout(2) + assert isinstance(env.shared_tensordict_parent, TensorDict) + + env = SerialEnv( + 2, + [EnvCreator(lambda: 
GymEnv(cartpole)) for _ in range(2)], + share_individual_td=True, + ) + assert env.share_individual_td + assert not env._single_task + env.rollout(2) + assert isinstance(env.shared_tensordict_parent, LazyStackedTensorDict) + + # Change shape: makes results non-stackable + env = SerialEnv( + 2, + [ + EnvCreator(lambda: GymEnv(cartpole)), + EnvCreator( + lambda: TransformedEnv( + GymEnv(cartpole), CatFrames(N=4, dim=-1, in_keys=["observation"]) + ) + ), + ], + ) + assert env.share_individual_td + assert not env._single_task + env.rollout(2) + assert isinstance(env.shared_tensordict_parent, LazyStackedTensorDict) + + with pytest.raises(ValueError, match="share_individual_td=False"): + SerialEnv( + 2, + [ + EnvCreator(lambda: GymEnv(cartpole)), + EnvCreator( + lambda: TransformedEnv( + GymEnv(cartpole), + CatFrames(N=4, dim=-1, in_keys=["observation"]), + ) + ), + ], + share_individual_td=False, + ) + + +def test_stackable(): + # Tests the _stackable util + stack = [TensorDict({"a": 0}), TensorDict({"b": 1})] + assert not _stackable(*stack), torch.stack(stack) + stack = [TensorDict({"a": [0]}), TensorDict({"a": 1})] + assert not _stackable(*stack) + stack = [TensorDict({"a": [0]}), TensorDict({"a": [1]})] + assert _stackable(*stack) + stack = [TensorDict({"a": [0]}), TensorDict({"a": [1], "b": {}})] + assert _stackable(*stack) + stack = [TensorDict({"a": {"b": [0]}}), TensorDict({"a": {"b": [1]}})] + assert _stackable(*stack) + stack = [TensorDict({"a": {"b": [0]}}), TensorDict({"a": {"b": 1}})] + assert not _stackable(*stack) + stack = [TensorDict({"a": "a string"}), TensorDict({"a": "another string"})] + assert _stackable(*stack) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index ff74beb8afd..2acf0b92e59 100644 --- a/torchrl/collectors/collectors.py +++ 
b/torchrl/collectors/collectors.py @@ -506,6 +506,13 @@ def __init__( # we we did not receive an env device, we use the device of the env self.env_device = self.env.device + # If the storing device is not the same as the policy device, we have + # no guarantee that the "next" entry from the policy will be on the + # same device as the collector metadata. + self._cast_to_env_device = self._cast_to_policy_device or ( + self.env.device != self.storing_device + ) + self.max_frames_per_traj = ( int(max_frames_per_traj) if max_frames_per_traj is not None else 0 ) @@ -923,7 +930,7 @@ def rollout(self) -> TensorDictBase: policy_output, keys_to_update=self._policy_output_keys ) - if self._cast_to_policy_device: + if self._cast_to_env_device: if self.env_device is not None: env_input = self._shuttle.to(self.env_device, non_blocking=True) elif self.env_device is None: diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index d67e6472d05..48085d21093 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -295,19 +295,11 @@ def __init__( self._single_task = callable(create_env_fn) or (len(set(create_env_fn)) == 1) if callable(create_env_fn): create_env_fn = [create_env_fn for _ in range(num_workers)] - else: - if len(create_env_fn) != num_workers: - raise RuntimeError( - f"num_workers and len(create_env_fn) mismatch, " - f"got {len(create_env_fn)} and {num_workers}" - ) - if ( - share_individual_td is False and not self._single_task - ): # then it has been explicitly set by the user - raise ValueError( - "share_individual_td must be set to None or True when using multi-task batched environments" - ) - share_individual_td = True + elif len(create_env_fn) != num_workers: + raise RuntimeError( + f"num_workers and len(create_env_fn) mismatch, " + f"got {len(create_env_fn)} and {num_workers}" + ) create_env_kwargs = {} if create_env_kwargs is None else create_env_kwargs if isinstance(create_env_kwargs, dict): create_env_kwargs = [ @@ 
-322,7 +314,8 @@ def __init__( if pin_memory: raise ValueError("pin_memory for batched envs is deprecated") - self.share_individual_td = bool(share_individual_td) + # if share_individual_td is None, we will assess later if the output can be stacked + self.share_individual_td = share_individual_td self._share_memory = shared_memory self._memmap = memmap self.allow_step_when_done = allow_step_when_done @@ -365,6 +358,8 @@ def _get_metadata( self.meta_data = meta_data.expand( *(self.num_workers, *meta_data.batch_size) ) + if self.share_individual_td is None: + self.share_individual_td = False else: n_tasks = len(create_env_fn) self.meta_data = [] @@ -372,6 +367,16 @@ def __init__( self.meta_data.append( get_env_metadata(create_env_fn[i], create_env_kwargs[i]).clone() ) + if self.share_individual_td is not True: + share_individual_td = not _stackable( + *[meta_data.tensordict for meta_data in self.meta_data] + ) + if share_individual_td and self.share_individual_td is False: + raise ValueError( + "share_individual_td=False was provided but share_individual_td must " + "be True to accommodate non-stackable tensors." 
+ ) + self.share_individual_td = share_individual_td self._set_properties() def update_kwargs(self, kwargs: Union[dict, List[dict]]) -> None: @@ -484,9 +489,14 @@ def map_device(key, value, device_map=device_map): self.done_spec = output_spec["full_done_spec"] self._dummy_env_str = str(meta_data[0]) - self._env_tensordict = LazyStackedTensorDict.lazy_stack( - [meta_data.tensordict for meta_data in meta_data], 0 - ) + if self.share_individual_td: + self._env_tensordict = LazyStackedTensorDict.lazy_stack( + [meta_data.tensordict for meta_data in meta_data], 0 + ) + else: + self._env_tensordict = torch.stack( + [meta_data.tensordict for meta_data in meta_data], 0 + ) self._batch_locked = meta_data[0].batch_locked self.has_lazy_inputs = contains_lazy_spec(self.input_spec) @@ -503,14 +513,11 @@ def load_state_dict(self, state_dict: OrderedDict) -> None: def _create_td(self) -> None: """Creates self.shared_tensordict_parent, a TensorDict used to store the most recent observations.""" - if self._single_task: - shared_tensordict_parent = self._env_tensordict.clone() - if not self._env_tensordict.shape[0] == self.num_workers: - raise RuntimeError( - "batched environment base tensordict has the wrong shape" - ) - else: - shared_tensordict_parent = self._env_tensordict.clone() + shared_tensordict_parent = self._env_tensordict.clone() + if self._env_tensordict.shape[0] != self.num_workers: + raise RuntimeError( + "batched environment base tensordict has the wrong shape" + ) if self._single_task: self._env_input_keys = sorted( @@ -525,6 +532,7 @@ def _create_td(self) -> None: self._env_obs_keys.append(key) self._env_output_keys += self.reward_keys + self.done_keys else: + # this is only possible if _single_task=False env_input_keys = set() for meta_data in self.meta_data: if meta_data.specs["input_spec", "full_state_spec"] is not None: @@ -577,7 +585,7 @@ def _create_td(self) -> None: # output keys after step self._selected_step_keys = {unravel_key(key) for key in 
self._env_output_keys} - if self._single_task: + if not self.share_individual_td: shared_tensordict_parent = shared_tensordict_parent.select( *self._selected_keys, *(unravel_key(("next", key)) for key in self._env_output_keys), @@ -807,10 +815,19 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: tensordict_ = None _td = _env.reset(tensordict=tensordict_, **kwargs) - self.shared_tensordicts[i].update_( - _td, - keys_to_update=list(self._selected_reset_keys_filt), - ) + try: + self.shared_tensordicts[i].update_( + _td, + keys_to_update=list(self._selected_reset_keys_filt), + ) + except RuntimeError as err: + if "no_grad mode" in str(err): + raise RuntimeError( + "Cannot update a view of a tensordict when gradients are required. " + "To collect gradient across sub-environments, please set the " + "share_individual_td argument to True." + ) + raise selected_output_keys = self._selected_reset_keys_filt device = self.device @@ -1703,5 +1720,14 @@ def _filter_empty(tensordict): return tensordict.select(*tensordict.keys(True, True)) +def _stackable(*tensordicts): + try: + ls = LazyStackedTensorDict(*tensordicts, stack_dim=0) + ls.contiguous() + return not ls._has_exclusive_keys + except RuntimeError: + return False + + # Create an alias for possible imports _BatchedEnv = BatchedEnvBase diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 59262c98eff..093ead56164 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -2080,7 +2080,6 @@ def reset( raise RuntimeError( f"env._reset returned an object of type {type(tensordict_reset)} but a TensorDict was expected." ) - return self._reset_proc_data(tensordict, tensordict_reset) def _reset_proc_data(self, tensordict, tensordict_reset):