From 0346923fdbd5d9a7dacb6e1fa658bb0585486ce9 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 3 Jan 2023 15:27:54 +0100 Subject: [PATCH 1/6] Fixed nested key sorting --- torchrl/envs/vec_env.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index a74a70c3d86..8ad132b0a2c 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -18,7 +18,6 @@ from tensordict import TensorDict from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase from torch import multiprocessing as mp - from torchrl._utils import _check_for_faulty_process from torchrl.data import CompositeSpec, TensorSpec from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING @@ -39,6 +38,12 @@ def decorated_fun(self: _BatchedEnv, *args, **kwargs): return decorated_fun +def _sort_keys(element): + if isinstance(element, tuple): + return "".join(element) + return element + + class _dispatch_caller_parallel: def __init__(self, attr, parallel_env): self.attr = attr @@ -421,17 +426,17 @@ def _create_td(self) -> None: ) else: if self._single_task: - self.env_input_keys = sorted(self.input_spec.keys()) + self.env_input_keys = sorted(self.input_spec.keys(), key=_sort_keys) else: env_input_keys = set() for meta_data in self.meta_data: env_input_keys = env_input_keys.union( meta_data.specs["input_spec"].keys() ) - self.env_input_keys = sorted(env_input_keys) + self.env_input_keys = sorted(env_input_keys, key=_sort_keys) if not len(self.env_input_keys): raise RuntimeError( - f"found 0 action keys in {sorted(self.selected_keys)}" + f"found 0 action keys in {sorted(self.selected_keys,key=_sort_keys)}" ) if self._single_task: shared_tensordict_parent = shared_tensordict_parent.select( From 50ecbc16781211d376f6d462404abab8462dd8d4 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 3 Jan 2023 22:41:53 +0100 Subject: [PATCH 2/6] Better separator --- torchrl/envs/vec_env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 8ad132b0a2c..5b0caad6381 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -40,7 +40,7 @@ def decorated_fun(self: _BatchedEnv, *args, **kwargs): def _sort_keys(element): if isinstance(element, tuple): - return "".join(element) + return "_-|-_".join(element) return element From e59b4bf37950e3d4220491c8d6ce972bf6275390 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 4 Jan 2023 11:16:58 +0000 Subject: [PATCH 3/6] brax fix --- test/test_libs.py | 10 ++++++++++ torchrl/envs/libs/brax.py | 21 ++++++++++++++------- torchrl/envs/vec_env.py | 2 ++ 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index de3e8c665c4..a30d1dd364a 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -513,6 +513,16 @@ def test_brax_grad(self, envname, batch_size): env.close() del env + @pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)]) + def test_brax_parallel(self, envname, batch_size): + def make_brax(): + env = BraxEnv(envname, batch_size=batch_size, requires_grad=False) + env.set_seed(1) + return env + env = ParallelEnv(1, make_brax) + tensordict = env.rollout(3) + print(tensordict) + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index 7d768f34cd3..daee4a9b4aa 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -151,7 +151,7 @@ def _set_seed(self, seed: int): raise Exception("Brax requires an integer seed.") self._key = jax.random.PRNGKey(seed) - def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: + def _reset(self, tensordict: TensorDictBase=None, **kwargs) -> TensorDictBase: # generate random keys self._key, *keys = jax.random.split(self._key, 1 + self.numel()) @@ -164,11 +164,15 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: state = _object_to_tensordict(state, self.device, self.batch_size) # build result + reward = state.pop("reward").view(*self.batch_size, *self.reward_spec.shape) + done = state.pop("done").view(*self.batch_size, *self.reward_spec.shape) + print("reward shape:", reward.shape) + print("done shape:", done.shape) tensordict_out = TensorDict( source={ "observation": state.get("obs"), - "reward": state.get("reward"), - "done": state.get("done").bool(), + "reward": reward, + "done": done, "state": state, }, batch_size=self.batch_size, @@ -195,17 +199,20 @@ def _step_without_grad(self, tensordict: TensorDictBase): next_state = _object_to_tensordict(next_state, self.device, self.batch_size) # build result + reward = next_state.pop("reward").view(*self.batch_size, *self.reward_spec.shape) + done = next_state.pop("done").view(*self.batch_size, *self.reward_spec.shape) tensordict_out = TensorDict( source={ "observation": next_state.get("obs"), - "reward": next_state.get("reward"), - "done": next_state.get("done").bool(), + "reward": reward, + "done": done, "state": next_state, }, batch_size=self.batch_size, device=self.device, _run_checks=False, ) + print("tensordict_out", tensordict_out) return tensordict_out def _step_with_grad(self, tensordict: TensorDictBase): @@ -222,12 +229,12 @@ def _step_with_grad(self, tensordict: TensorDictBase): ) # extract done values - next_done = next_state_nograd["done"].bool() + next_done = next_state_nograd["done"].view(*self.batch_size, *self.reward_spec.shape) # merge with tensors with grad function next_state = next_state_nograd next_state["obs"] = next_obs - next_state["reward"] = next_reward + next_state["reward"] = next_reward.view(*self.batch_size, *self.reward_spec.shape) next_state["qp"].update(dict(zip(qp_keys, next_qp_values))) # build result diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 5b0caad6381..70050bb1314 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -1033,6 +1033,8 @@ def _run_worker_pipe_shared_mem( ) if pin_memory: _td.pin_memory() + print("tensordict", tensordict) + print("_td", _td) tensordict.update_(_td.select(*step_keys, strict=False)) if _td.get("done"): msg = "done" From 76f41282e9fd30e8f9251a4338af10a387eda61b Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 4 Jan 2023 11:36:41 +0000 Subject: [PATCH 4/6] brax fix --- test/test_libs.py | 5 +++-- torchrl/envs/libs/brax.py | 23 +++++++++++++---------- torchrl/envs/vec_env.py | 2 -- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index a30d1dd364a..6f1eb81cba0 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -519,9 +519,10 @@ def make_brax(): env = BraxEnv(envname, batch_size=batch_size, requires_grad=False) env.set_seed(1) return env - env = ParallelEnv(1, make_brax) + + env = ParallelEnv(2, make_brax) tensordict = env.rollout(3) - print(tensordict) + assert tensordict.shape == torch.Size([2, *batch_size, 3]) if __name__ == "__main__": diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index daee4a9b4aa..da049f73bfb 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -151,7 +151,7 @@ def _set_seed(self, seed: int): raise Exception("Brax requires an integer seed.") self._key = jax.random.PRNGKey(seed) - def _reset(self, tensordict: TensorDictBase=None, **kwargs) -> TensorDictBase: + def _reset(self, tensordict: TensorDictBase = None, **kwargs) -> TensorDictBase: # generate random keys self._key, *keys = jax.random.split(self._key, 1 + self.numel()) @@ -164,10 +164,8 @@ def _reset(self, tensordict: TensorDictBase=None, **kwargs) -> TensorDictBase: state = _object_to_tensordict(state, self.device, self.batch_size) # build result - reward = state.pop("reward").view(*self.batch_size, *self.reward_spec.shape) - done = state.pop("done").view(*self.batch_size, *self.reward_spec.shape) - print("reward shape:", reward.shape) - print("done shape:", done.shape) + reward = state.get("reward").view(*self.batch_size, *self.reward_spec.shape) + done = state.get("done").bool().view(*self.batch_size, *self.reward_spec.shape) tensordict_out = TensorDict( source={ "observation": state.get("obs"), @@ -199,8 +197,10 @@ def _step_without_grad(self, tensordict: TensorDictBase): next_state = _object_to_tensordict(next_state, self.device, self.batch_size) # build result - reward = next_state.pop("reward").view(*self.batch_size, *self.reward_spec.shape) - done = next_state.pop("done").view(*self.batch_size, *self.reward_spec.shape) + reward = next_state.get("reward").view( + *self.batch_size, *self.reward_spec.shape + ) + done = next_state.get("done").bool().view(*self.batch_size, *self.reward_spec.shape) tensordict_out = TensorDict( source={ "observation": next_state.get("obs"), @@ -212,7 +212,6 @@ def _step_without_grad(self, tensordict: TensorDictBase): device=self.device, _run_checks=False, ) - print("tensordict_out", tensordict_out) return tensordict_out def _step_with_grad(self, tensordict: TensorDictBase): @@ -229,12 +228,16 @@ def _step_with_grad(self, tensordict: TensorDictBase): ) # extract done values - next_done = next_state_nograd["done"].view(*self.batch_size, *self.reward_spec.shape) + next_done = next_state_nograd.get("done").bool().view( + *self.batch_size, *self.reward_spec.shape + ) # merge with tensors with grad function next_state = next_state_nograd next_state["obs"] = next_obs - next_state["reward"] = next_reward.view(*self.batch_size, *self.reward_spec.shape) + next_state["reward"] = next_reward.view( + *self.batch_size, *self.reward_spec.shape + ) next_state["qp"].update(dict(zip(qp_keys, next_qp_values))) # build result diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 006b3092ec7..733de7d1c45 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -1037,8 +1037,6 @@ def _run_worker_pipe_shared_mem( ) if pin_memory: _td.pin_memory() - print("tensordict", tensordict) - print("_td", _td) tensordict.update_(_td.select(*step_keys, strict=False)) msg = "step_result" data = (msg, step_keys) From 276f9473caba8bb046ae56c21265e68157e63413 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 4 Jan 2023 11:41:23 +0000 Subject: [PATCH 5/6] brax fix --- test/test_libs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 6f1eb81cba0..51f3a5745cb 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -514,15 +514,15 @@ def test_brax_grad(self, envname, batch_size): del env @pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)]) - def test_brax_parallel(self, envname, batch_size): + def test_brax_parallel(self, envname, batch_size, n=1): def make_brax(): env = BraxEnv(envname, batch_size=batch_size, requires_grad=False) env.set_seed(1) return env - env = ParallelEnv(2, make_brax) + env = ParallelEnv(n, make_brax) tensordict = env.rollout(3) - assert tensordict.shape == torch.Size([2, *batch_size, 3]) + assert tensordict.shape == torch.Size([n, *batch_size, 3]) if __name__ == "__main__": From e714dd8cbe6df1a7723519e1c5bde394290f81b2 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 4 Jan 2023 11:47:43 +0000 Subject: [PATCH 6/6] lint --- torchrl/envs/libs/brax.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index da049f73bfb..a1fbca2079b 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -200,7 +200,11 @@ def _step_without_grad(self, tensordict: TensorDictBase): reward = next_state.get("reward").view( *self.batch_size, *self.reward_spec.shape ) - done = next_state.get("done").bool().view(*self.batch_size, *self.reward_spec.shape) + done = ( + next_state.get("done") + .bool() + .view(*self.batch_size, *self.reward_spec.shape) + ) tensordict_out = TensorDict( source={ "observation": next_state.get("obs"), @@ -228,8 +232,10 @@ def _step_with_grad(self, tensordict: TensorDictBase): ) # extract done values - next_done = next_state_nograd.get("done").bool().view( - *self.batch_size, *self.reward_spec.shape + next_done = ( + next_state_nograd.get("done") + .bool() + .view(*self.batch_size, *self.reward_spec.shape) ) # merge with tensors with grad function