diff --git a/test/test_libs.py b/test/test_libs.py index de3e8c665c4..51f3a5745cb 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -513,6 +513,17 @@ def test_brax_grad(self, envname, batch_size): env.close() del env + @pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)]) + def test_brax_parallel(self, envname, batch_size, n=1): + def make_brax(): + env = BraxEnv(envname, batch_size=batch_size, requires_grad=False) + env.set_seed(1) + return env + + env = ParallelEnv(n, make_brax) + tensordict = env.rollout(3) + assert tensordict.shape == torch.Size([n, *batch_size, 3]) + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index 7d768f34cd3..a1fbca2079b 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -151,7 +151,7 @@ def _set_seed(self, seed: int): raise Exception("Brax requires an integer seed.") self._key = jax.random.PRNGKey(seed) - def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: + def _reset(self, tensordict: TensorDictBase = None, **kwargs) -> TensorDictBase: # generate random keys self._key, *keys = jax.random.split(self._key, 1 + self.numel()) @@ -164,11 +164,13 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: state = _object_to_tensordict(state, self.device, self.batch_size) # build result + reward = state.get("reward").view(*self.batch_size, *self.reward_spec.shape) + done = state.get("done").bool().view(*self.batch_size, *self.reward_spec.shape) tensordict_out = TensorDict( source={ "observation": state.get("obs"), - "reward": state.get("reward"), - "done": state.get("done").bool(), + "reward": reward, + "done": done, "state": state, }, batch_size=self.batch_size, @@ -195,11 +197,19 @@ def _step_without_grad(self, tensordict: TensorDictBase): next_state = _object_to_tensordict(next_state, self.device, self.batch_size) # build result + reward = next_state.get("reward").view( + *self.batch_size, *self.reward_spec.shape + ) + done = ( + next_state.get("done") + .bool() + .view(*self.batch_size, *self.reward_spec.shape) + ) tensordict_out = TensorDict( source={ "observation": next_state.get("obs"), - "reward": next_state.get("reward"), - "done": next_state.get("done").bool(), + "reward": reward, + "done": done, "state": next_state, }, batch_size=self.batch_size, @@ -222,12 +232,18 @@ def _step_with_grad(self, tensordict: TensorDictBase): ) # extract done values - next_done = next_state_nograd["done"].bool() + next_done = ( + next_state_nograd.get("done") + .bool() + .view(*self.batch_size, *self.reward_spec.shape) + ) # merge with tensors with grad function next_state = next_state_nograd next_state["obs"] = next_obs - next_state["reward"] = next_reward + next_state["reward"] = next_reward.view( + *self.batch_size, *self.reward_spec.shape + ) next_state["qp"].update(dict(zip(qp_keys, next_qp_values))) # build result diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index af44cc533e5..733de7d1c45 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -38,6 +38,12 @@ def decorated_fun(self: _BatchedEnv, *args, **kwargs): return decorated_fun +def _sort_keys(element): + if isinstance(element, tuple): + return "_-|-_".join(element) + return element + + class _dispatch_caller_parallel: def __init__(self, attr, parallel_env): self.attr = attr @@ -420,17 +426,17 @@ def _create_td(self) -> None: ) else: if self._single_task: - self.env_input_keys = sorted(self.input_spec.keys()) + self.env_input_keys = sorted(self.input_spec.keys(), key=_sort_keys) else: env_input_keys = set() for meta_data in self.meta_data: env_input_keys = env_input_keys.union( meta_data.specs["input_spec"].keys() ) - self.env_input_keys = sorted(env_input_keys) + self.env_input_keys = sorted(env_input_keys, key=_sort_keys) if not len(self.env_input_keys): raise RuntimeError( - f"found 0 action keys in {sorted(self.selected_keys)}" + f"found 0 action keys in {sorted(self.selected_keys,key=_sort_keys)}" ) if self._single_task: shared_tensordict_parent = shared_tensordict_parent.select(