Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,17 @@ def test_brax_grad(self, envname, batch_size):
env.close()
del env

@pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)])
def test_brax_parallel(self, envname, batch_size, n=1):
    """Run BraxEnv workers under ParallelEnv and check that the rollout
    tensordict carries the worker dimension in front of the env batch size."""

    def make_brax():
        brax_env = BraxEnv(envname, batch_size=batch_size, requires_grad=False)
        brax_env.set_seed(1)
        return brax_env

    parallel_env = ParallelEnv(n, make_brax)
    rollout = parallel_env.rollout(3)
    expected_shape = torch.Size([n, *batch_size, 3])
    assert rollout.shape == expected_shape


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
Expand Down
30 changes: 23 additions & 7 deletions torchrl/envs/libs/brax.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def _set_seed(self, seed: int):
raise Exception("Brax requires an integer seed.")
self._key = jax.random.PRNGKey(seed)

def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
def _reset(self, tensordict: TensorDictBase = None, **kwargs) -> TensorDictBase:

# generate random keys
self._key, *keys = jax.random.split(self._key, 1 + self.numel())
Expand All @@ -164,11 +164,13 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
state = _object_to_tensordict(state, self.device, self.batch_size)

# build result
reward = state.get("reward").view(*self.batch_size, *self.reward_spec.shape)
done = state.get("done").bool().view(*self.batch_size, *self.reward_spec.shape)
tensordict_out = TensorDict(
source={
"observation": state.get("obs"),
"reward": state.get("reward"),
"done": state.get("done").bool(),
"reward": reward,
"done": done,
"state": state,
},
batch_size=self.batch_size,
Expand All @@ -195,11 +197,19 @@ def _step_without_grad(self, tensordict: TensorDictBase):
next_state = _object_to_tensordict(next_state, self.device, self.batch_size)

# build result
reward = next_state.get("reward").view(
*self.batch_size, *self.reward_spec.shape
)
done = (
next_state.get("done")
.bool()
.view(*self.batch_size, *self.reward_spec.shape)
)
tensordict_out = TensorDict(
source={
"observation": next_state.get("obs"),
"reward": next_state.get("reward"),
"done": next_state.get("done").bool(),
"reward": reward,
"done": done,
"state": next_state,
},
batch_size=self.batch_size,
Expand All @@ -222,12 +232,18 @@ def _step_with_grad(self, tensordict: TensorDictBase):
)

# extract done values
next_done = next_state_nograd["done"].bool()
next_done = (
next_state_nograd.get("done")
.bool()
.view(*self.batch_size, *self.reward_spec.shape)
)

# merge with tensors with grad function
next_state = next_state_nograd
next_state["obs"] = next_obs
next_state["reward"] = next_reward
next_state["reward"] = next_reward.view(
*self.batch_size, *self.reward_spec.shape
)
next_state["qp"].update(dict(zip(qp_keys, next_qp_values)))

# build result
Expand Down
12 changes: 9 additions & 3 deletions torchrl/envs/vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ def decorated_fun(self: _BatchedEnv, *args, **kwargs):
return decorated_fun


def _sort_keys(element):
if isinstance(element, tuple):
return "_-|-_".join(element)
return element


class _dispatch_caller_parallel:
def __init__(self, attr, parallel_env):
self.attr = attr
Expand Down Expand Up @@ -420,17 +426,17 @@ def _create_td(self) -> None:
)
else:
if self._single_task:
self.env_input_keys = sorted(self.input_spec.keys())
self.env_input_keys = sorted(self.input_spec.keys(), key=_sort_keys)
else:
env_input_keys = set()
for meta_data in self.meta_data:
env_input_keys = env_input_keys.union(
meta_data.specs["input_spec"].keys()
)
self.env_input_keys = sorted(env_input_keys)
self.env_input_keys = sorted(env_input_keys, key=_sort_keys)
if not len(self.env_input_keys):
raise RuntimeError(
f"found 0 action keys in {sorted(self.selected_keys)}"
f"found 0 action keys in {sorted(self.selected_keys,key=_sort_keys)}"
)
if self._single_task:
shared_tensordict_parent = shared_tensordict_parent.select(
Expand Down