From e2deaf78e779d1927919e0582f39261a3f512ecf Mon Sep 17 00:00:00 2001 From: Ying-Chen Lin Date: Thu, 1 Dec 2022 10:34:35 +0000 Subject: [PATCH 01/12] [Brax] init brax environment --- test/_utils_internal.py | 3 +- test/test_libs.py | 39 +++++++ torchrl/envs/libs/brax.py | 217 +++++++++++++++++++++++++++++++++++ torchrl/envs/libs/jumanji.py | 34 ++++-- 4 files changed, 280 insertions(+), 13 deletions(-) create mode 100644 torchrl/envs/libs/brax.py diff --git a/test/_utils_internal.py b/test/_utils_internal.py index a11bbf4ab1d..ad9f5d732c4 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -88,7 +88,8 @@ def _test_fake_tensordict(env: EnvBase): assert fake_tensordict[key].shape == real_tensordict[key].shape # test dtypes - for key, value in real_tensordict.unflatten_keys(".").items(): + real_tensordict = env.rollout(3) # Empty structures will be missing in flattened keys. + for key, value in real_tensordict.items(): _check_dtype(key, value, env.observation_spec, env.input_spec) diff --git a/test/test_libs.py b/test/test_libs.py index 06e09a0521b..282a53fd6cc 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -25,6 +25,7 @@ from torchrl.envs.libs.gym import _has_gym, _is_from_pixels, GymEnv, GymWrapper from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv +from torchrl.envs.libs.brax import _has_brax, BraxEnv if _has_gym: import gym @@ -425,6 +426,44 @@ def test_jumanji_consistency(self, envname, batch_size): ) +@pytest.mark.skipif(not _has_brax, reason="brax not installed") +@pytest.mark.parametrize("envname", ["ant"]) +class TestBrax: + def test_brax_seeding(self, envname): + final_seed = [] + tdreset = [] + tdrollout = [] + for _ in range(2): + env = BraxEnv(envname) + torch.manual_seed(0) + np.random.seed(0) + final_seed.append(env.set_seed(0)) + tdreset.append(env.reset()) + tdrollout.append(env.rollout(max_steps=50)) + env.close() + del env + assert final_seed[0] == final_seed[1] + assert_allclose_td(*tdreset) + assert_allclose_td(*tdrollout) + + @pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)]) + def test_brax_batch_size(self, envname, batch_size): + env = BraxEnv(envname, batch_size=batch_size) + env.set_seed(0) + tdreset = env.reset() + tdrollout = env.rollout(max_steps=50) + env.close() + del env + assert tdreset.batch_size == batch_size + assert tdrollout.batch_size[:-1] == batch_size + + @pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)]) + def test_brax_spec_rollout(self, envname, batch_size): + env = BraxEnv(envname, batch_size=batch_size) + env.set_seed(0) + _test_fake_tensordict(env) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py new file mode 100644 index 00000000000..eccec3fd6e6 --- /dev/null +++ b/torchrl/envs/libs/brax.py @@ -0,0 +1,217 @@ +from typing import Dict, Optional, Union + +import torch +from tensordict.tensordict import TensorDict, TensorDictBase +from torchrl.data import CompositeSpec, NdBoundedTensorSpec, NdUnboundedContinuousTensorSpec +from torchrl.envs.common import _EnvWrapper +from torchrl.envs.libs.jumanji import _object_to_tensordict, _tensordict_to_object, _torchrl_data_to_spec_transform + +try: + import brax + import brax.envs + import brax.io.torch + from brax import jumpy as jp + import jax + import numpy as np + + _has_brax = True +except ImportError as err: 
+ _has_brax = False + IMPORT_ERR = str(err) + + +class BraxWrapper(_EnvWrapper): + """Google Brax environment wrapper. + + Examples: + >>> env = brax.envs.get_environment("ant") + >>> env = BraxWrapper(env) + >>> td = env.rand_step() + >>> print(td) + >>> print(env.available_envs) + + """ + + git_url = "https://github.com/google/brax" + + @property + def lib(self): + return brax + + def __init__(self, env=None, categorical_action_encoding=False, **kwargs): + if env is not None: + kwargs["env"] = env + self._seed_calls_reset = None + self._categorical_action_encoding = categorical_action_encoding + super().__init__(**kwargs) + + def _check_kwargs(self, kwargs: Dict): + if "env" not in kwargs: + raise TypeError("Could not find environment key 'env' in kwargs.") + env = kwargs["env"] + if not isinstance(env, brax.envs.env.Env): + raise TypeError("env is not of type 'brax.envs.env.Env'.") + + def _build_env( + self, + env, + _seed: Optional[int] = None, + from_pixels: bool = False, + render_kwargs: Optional[dict] = None, + pixels_only: bool = False, + camera_id: Union[int, str] = 0, + **kwargs, + ): + self.from_pixels = from_pixels + self.pixels_only = pixels_only + + if from_pixels: + raise NotImplementedError("TODO") + return env + + def _make_state_spec(self, env: "brax.envs.env.Env"): + key = jax.random.PRNGKey(0) + state = env.reset(key) + state_dict = _object_to_tensordict(state, self.device, batch_size=()) + state_spec = _torchrl_data_to_spec_transform(state_dict) + return state_spec + + def _make_specs(self, env: "brax.envs.env.Env") -> None: # noqa: F821 + self._input_spec = CompositeSpec( + action=NdBoundedTensorSpec( + minimum=-1, + maximum=1, + shape=(env.action_size,), + device=self.device + ) + ) + self._reward_spec = NdUnboundedContinuousTensorSpec( + shape=(), + device=self.device + ) + self._observation_spec = CompositeSpec( + observation=NdUnboundedContinuousTensorSpec( + shape=(env.observation_size,), + device=self.device + ) + ) + # extract state spec from instance + self._state_spec = self._make_state_spec(env) + self._input_spec["state"] = self._state_spec + + def _make_state_example(self): + key = jax.random.PRNGKey(0) + keys = jax.random.split(key, self.batch_size.numel()) + state = self._vmap_jit_env_reset(jax.numpy.stack(keys)) + state = self._reshape(state) + return state + + def _init_env(self) -> Optional[int]: + self._key = None + self._vmap_jit_env_reset = jp.vmap(jax.jit(self._env.reset)) + self._vmap_jit_env_step = jp.vmap(jax.jit(self._env.step)) + self._state_example = self._make_state_example() + + def _set_seed(self, seed: int): + if seed is None: + raise Exception("Brax requires an integer seed.") + self._key = jax.random.PRNGKey(seed) + + def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: + + self._key, *keys = jax.random.split(self._key, 1 + self.numel()) + state = self._vmap_jit_env_reset(jax.numpy.stack(keys)) + state = self._reshape(state) + state = _object_to_tensordict(state, self.device, self.batch_size) + + tensordict_out = TensorDict( + source={ + "observation": state.get("obs"), + "reward": state.get("reward"), + "done": state.get("done").bool(), + "state": state, + }, + batch_size=self.batch_size, + device=self.device, + ) + return tensordict_out + + + def _step( + self, + tensordict: TensorDictBase, + ) -> TensorDictBase: + + state = _tensordict_to_object(tensordict.get("state"), self._state_example) + action = tensordict.get("action").numpy() + + state = self._flatten(state) + action = self._flatten(action) + state = 
self._vmap_jit_env_step(state, action) + state = self._reshape(state) + state = _object_to_tensordict(state, self.device, self.batch_size) + + tensordict_out = TensorDict( + source={ + "observation": state.get("obs"), + "reward": state.get("reward"), + "done": state.get("done").bool(), + "state": state, + }, + batch_size=self.batch_size, + device=self.device, + ) + return tensordict_out + + def _reshape(self, x): + shape, n = self.batch_size, 1 + return jax.tree_util.tree_map(lambda x: x.reshape(shape + x.shape[n:]), x) + + def _flatten(self, x): + shape, n = (self.batch_size.numel(),), len(self.batch_size) + return jax.tree_util.tree_map(lambda x: x.reshape(shape + x.shape[n:]), x) + + +class BraxEnv(BraxWrapper): + """Google Brax environment wrapper. + + Examples: + >>> env = BraxEnv(env_name="ant") + >>> td = env.rand_step() + >>> print(td) + >>> print(env.available_envs) + + """ + + def __init__(self, env_name, **kwargs): + kwargs["env_name"] = env_name + super().__init__(**kwargs) + + def _build_env( + self, + env_name: str, + **kwargs, + ) -> "brax.envs.env.Env": + if not _has_brax: + raise RuntimeError( + f"brax not found, unable to create {env_name}. " + f"Consider downloading and installing brax from" + f" {self.git_url}" + ) + from_pixels = kwargs.pop("from_pixels", False) + pixels_only = kwargs.pop("pixels_only", True) + assert not kwargs + self.wrapper_frame_skip = 1 + env = self.lib.envs.get_environment(env_name, **kwargs) + return super()._build_env(env, pixels_only=pixels_only, from_pixels=from_pixels) + + @property + def env_name(self): + return self._constructor_kwargs["env_name"] + + def _check_kwargs(self, kwargs: Dict): + if "env_name" not in kwargs: + raise TypeError("Expected 'env_name' to be part of kwargs") + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(env={self.env_name}, batch_size={self.batch_size}, device={self.device})" diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index ef4bf02b8d4..c22339ee9a3 100644 --- a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -44,18 +44,28 @@ def _ndarray_to_tensor(value: Union["jnp.ndarray", np.ndarray], device) -> torch return torch.tensor(value).to(device) -def _object_to_tensordict(obj: Union, device, batch_size) -> TensorDictBase: - """Converts a namedtuple or a dataclass to a TensorDict.""" - t = {} +def _object_to_dict(obj) -> dict: if isinstance(obj, tuple) and hasattr(obj, "_fields"): # named tuple - _iter = obj._fields + return dict(zip(obj._fields, obj)) elif dataclasses.is_dataclass(obj): - _iter = (field.name for field in dataclasses.fields(obj)) + return { + field.name: getattr(obj, field.name) + for field in dataclasses.fields(obj) + } + elif isinstance(obj, dict): + return obj else: raise NotImplementedError(f"unsupported data type {type(obj)}") - for name in _iter: - value = getattr(obj, name) - if isinstance(value, (jnp.ndarray, np.ndarray)): + + +def _object_to_tensordict(obj: Union, device, batch_size) -> TensorDictBase: + """Converts a namedtuple or a dataclass to a TensorDict.""" + t = {} + _dict = _object_to_dict(obj) + for name, value in _dict.items(): + if isinstance(value, (np.number, int, float)): + t[name] = _ndarray_to_tensor(np.asarray([value]), device=device) + elif isinstance(value, (jnp.ndarray, np.ndarray)): t[name] = _ndarray_to_tensor(value, device=device) else: t[name] = _object_to_tensordict(value, device, batch_size) @@ -64,18 +74,18 @@ def _object_to_tensordict(obj: Union, device, batch_size) -> TensorDictBase: def 
_tensordict_to_object(tensordict: TensorDictBase, object_example): """Converts a TensorDict to a namedtuple or a dataclass.""" - object_type = type(object_example) t = {} + _dict = _object_to_dict(object_example) for name in tensordict.keys(): + example = _dict[name] value = tensordict[name] if isinstance(value, TensorDictBase): - t[name] = _tensordict_to_object(value, getattr(object_example, name)) + t[name] = _tensordict_to_object(value, example) else: - example = getattr(object_example, name) t[name] = ( value.detach().numpy().reshape(example.shape).astype(example.dtype) ) - return object_type(**t) + return type(object_example)(**t) def _jumanji_to_torchrl_spec_transform( From a0bae56b4d48d3eb305703b3bac9d2d7d61f391f Mon Sep 17 00:00:00 2001 From: Ying-Chen Lin Date: Thu, 1 Dec 2022 11:04:22 +0000 Subject: [PATCH 02/12] [Brax] init brax environment --- test/_utils_internal.py | 2 +- test/test_libs.py | 2 +- torchrl/envs/libs/brax.py | 30 +++++++++++++++--------------- torchrl/envs/libs/jumanji.py | 3 +-- 4 files changed, 18 insertions(+), 19 deletions(-) diff --git a/test/_utils_internal.py b/test/_utils_internal.py index ad9f5d732c4..d157216f188 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -88,7 +88,7 @@ def _test_fake_tensordict(env: EnvBase): assert fake_tensordict[key].shape == real_tensordict[key].shape # test dtypes - real_tensordict = env.rollout(3) # Empty structures will be missing in flattened keys. + real_tensordict = env.rollout(3) for key, value in real_tensordict.items(): _check_dtype(key, value, env.observation_spec, env.input_spec) diff --git a/test/test_libs.py b/test/test_libs.py index 282a53fd6cc..995eec69221 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -21,11 +21,11 @@ from torchrl.collectors import MultiaSyncDataCollector from torchrl.collectors.collectors import RandomPolicy from torchrl.envs import EnvCreator, ParallelEnv +from torchrl.envs.libs.brax import _has_brax, BraxEnv from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv, DMControlWrapper from torchrl.envs.libs.gym import _has_gym, _is_from_pixels, GymEnv, GymWrapper from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv -from torchrl.envs.libs.brax import _has_brax, BraxEnv if _has_gym: import gym diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index eccec3fd6e6..735f62de36c 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -1,18 +1,24 @@ from typing import Dict, Optional, Union -import torch from tensordict.tensordict import TensorDict, TensorDictBase -from torchrl.data import CompositeSpec, NdBoundedTensorSpec, NdUnboundedContinuousTensorSpec +from torchrl.data import ( + CompositeSpec, + NdBoundedTensorSpec, + NdUnboundedContinuousTensorSpec, +) from torchrl.envs.common import _EnvWrapper -from torchrl.envs.libs.jumanji import _object_to_tensordict, _tensordict_to_object, _torchrl_data_to_spec_transform +from torchrl.envs.libs.jumanji import ( + _object_to_tensordict, + _tensordict_to_object, + _torchrl_data_to_spec_transform, +) try: import brax import brax.envs import brax.io.torch - from brax import jumpy as jp import jax - import numpy as np + from brax import jumpy as jp _has_brax = True except ImportError as err: @@ -33,7 +39,7 @@ class BraxWrapper(_EnvWrapper): """ git_url = "https://github.com/google/brax" - + @property def lib(self): return brax @@ -79,20 +85,15 @@ def _make_state_spec(self, env: "brax.envs.env.Env"): def 
_make_specs(self, env: "brax.envs.env.Env") -> None: # noqa: F821 self._input_spec = CompositeSpec( action=NdBoundedTensorSpec( - minimum=-1, - maximum=1, - shape=(env.action_size,), - device=self.device + minimum=-1, maximum=1, shape=(env.action_size,), device=self.device ) ) self._reward_spec = NdUnboundedContinuousTensorSpec( - shape=(), - device=self.device + shape=(), device=self.device ) self._observation_spec = CompositeSpec( observation=NdUnboundedContinuousTensorSpec( - shape=(env.observation_size,), - device=self.device + shape=(env.observation_size,), device=self.device ) ) # extract state spec from instance @@ -136,7 +137,6 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: ) return tensordict_out - def _step( self, tensordict: TensorDictBase, diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index c22339ee9a3..77782b675f1 100644 --- a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -49,8 +49,7 @@ def _object_to_dict(obj) -> dict: return dict(zip(obj._fields, obj)) elif dataclasses.is_dataclass(obj): return { - field.name: getattr(obj, field.name) - for field in dataclasses.fields(obj) + field.name: getattr(obj, field.name) for field in dataclasses.fields(obj) } elif isinstance(obj, dict): return obj From 464932bb0322612837c119444a461c20326d1476 Mon Sep 17 00:00:00 2001 From: Ying-Chen Lin Date: Wed, 7 Dec 2022 12:24:13 +0000 Subject: [PATCH 03/12] use dlpack conversion to avoid moving data between devices --- test/_utils_internal.py | 2 +- test/test_libs.py | 3 +- torchrl/envs/libs/brax.py | 39 ++++++------ torchrl/envs/libs/jax_utils.py | 87 +++++++++++++++++++++++++++ torchrl/envs/libs/jumanji.py | 105 +++++++-------------------------- 5 files changed, 129 insertions(+), 107 deletions(-) create mode 100644 torchrl/envs/libs/jax_utils.py diff --git a/test/_utils_internal.py b/test/_utils_internal.py index d157216f188..00f69e6a40f 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -88,7 +88,7 @@ def _test_fake_tensordict(env: EnvBase): assert fake_tensordict[key].shape == real_tensordict[key].shape # test dtypes - real_tensordict = env.rollout(3) + real_tensordict = env.rollout(3) # keep empty structs. 
for key, value in real_tensordict.items(): _check_dtype(key, value, env.observation_spec, env.input_spec) diff --git a/test/test_libs.py b/test/test_libs.py index 995eec69221..80154257aa1 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -26,6 +26,7 @@ from torchrl.envs.libs.gym import _has_gym, _is_from_pixels, GymEnv, GymWrapper from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv +from torchrl.envs.libs.jax_utils import tree_flatten if _has_gym: import gym @@ -399,7 +400,7 @@ def test_jumanji_consistency(self, envname, batch_size): for i in range(rollout.shape[-1]): action = rollout[..., i]["action"] # state = env._flatten(state) - action = env._flatten(env.read_action(action)) + action = tree_flatten(env.read_action(action), env.batch_size) state, timestep = jax.vmap(base_env.step)(state, action) # state = env._reshape(state) # timesteps.append(timestep) diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index 735f62de36c..54218747c7d 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -7,10 +7,13 @@ NdUnboundedContinuousTensorSpec, ) from torchrl.envs.common import _EnvWrapper -from torchrl.envs.libs.jumanji import ( - _object_to_tensordict, - _tensordict_to_object, - _torchrl_data_to_spec_transform, +from torchrl.envs.libs.jumanji import _torchrl_data_to_spec_transform +from torchrl.envs.libs.jax_utils import ( + tree_flatten, + tree_reshape, + tensor_to_ndarray, + object_to_tensordict, + tensordict_to_object, ) try: @@ -78,7 +81,7 @@ def _build_env( def _make_state_spec(self, env: "brax.envs.env.Env"): key = jax.random.PRNGKey(0) state = env.reset(key) - state_dict = _object_to_tensordict(state, self.device, batch_size=()) + state_dict = object_to_tensordict(state, self.device, batch_size=()) state_spec = _torchrl_data_to_spec_transform(state_dict) return state_spec @@ -104,7 +107,7 @@ def _make_state_example(self): key = jax.random.PRNGKey(0) keys = jax.random.split(key, self.batch_size.numel()) state = self._vmap_jit_env_reset(jax.numpy.stack(keys)) - state = self._reshape(state) + state = tree_reshape(state, self.batch_size) return state def _init_env(self) -> Optional[int]: @@ -122,8 +125,8 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: self._key, *keys = jax.random.split(self._key, 1 + self.numel()) state = self._vmap_jit_env_reset(jax.numpy.stack(keys)) - state = self._reshape(state) - state = _object_to_tensordict(state, self.device, self.batch_size) + state = tree_reshape(state, self.batch_size) + state = object_to_tensordict(state, self.device, self.batch_size) tensordict_out = TensorDict( source={ @@ -142,14 +145,14 @@ def _step( tensordict: TensorDictBase, ) -> TensorDictBase: - state = _tensordict_to_object(tensordict.get("state"), self._state_example) - action = tensordict.get("action").numpy() + state = tensordict_to_object(tensordict.get("state"), self._state_example) + action = tensor_to_ndarray(tensordict.get("action")) - state = self._flatten(state) - action = self._flatten(action) + state = tree_flatten(state, self.batch_size) + action = tree_flatten(action, self.batch_size) state = self._vmap_jit_env_step(state, action) - state = self._reshape(state) - state = _object_to_tensordict(state, self.device, self.batch_size) + state = tree_reshape(state, self.batch_size) + state = object_to_tensordict(state, self.device, self.batch_size) tensordict_out = TensorDict( source={ @@ -163,14 +166,6 @@ def _step( ) return 
tensordict_out - def _reshape(self, x): - shape, n = self.batch_size, 1 - return jax.tree_util.tree_map(lambda x: x.reshape(shape + x.shape[n:]), x) - - def _flatten(self, x): - shape, n = (self.batch_size.numel(),), len(self.batch_size) - return jax.tree_util.tree_map(lambda x: x.reshape(shape + x.shape[n:]), x) - class BraxEnv(BraxWrapper): """Google Brax environment wrapper. diff --git a/torchrl/envs/libs/jax_utils.py b/torchrl/envs/libs/jax_utils.py new file mode 100644 index 00000000000..a20a70bcc2a --- /dev/null +++ b/torchrl/envs/libs/jax_utils.py @@ -0,0 +1,87 @@ +from typing import Union +import dataclasses + +import numpy as np + +import jax +from jax import numpy as jnp +from jax import dlpack as jax_dlpack + +import torch +from torch.utils import dlpack as torch_dlpack +from tensordict.tensordict import make_tensordict, TensorDictBase + + +def tree_reshape(x, batch_size: torch.Size): + shape, n = batch_size, 1 + return jax.tree_util.tree_map(lambda x: x.reshape(shape + x.shape[n:]), x) + + +def tree_flatten(x, batch_size: torch.Size): + shape, n = (batch_size.numel(),), len(batch_size) + return jax.tree_util.tree_map(lambda x: x.reshape(shape + x.shape[n:]), x) + + +_dtype_conversion = { + np.dtype("uint16"): np.int16, + np.dtype("uint32"): np.int32, + np.dtype("uint64"): np.int64, +} + +def ndarray_to_tensor(value: Union[jnp.ndarray, np.ndarray]) -> torch.Tensor: + # JAX arrays generated by jax.vmap would have Numpy dtypes. + if value.dtype in _dtype_conversion: + value = value.view(_dtype_conversion[value.dtype]) + if isinstance(value, jnp.ndarray): + dlpack_tensor = jax_dlpack.to_dlpack(value) + elif isinstance(value, np.ndarray): + dlpack_tensor = value.__dlpack__() + else: + raise NotImplementedError(f"unsupported data type {type(value)}") + return torch_dlpack.from_dlpack(dlpack_tensor) + + +def tensor_to_ndarray(value: torch.Tensor) -> jnp.ndarray: + return jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(value)) + + +def object_to_dict(obj) -> dict: + if isinstance(obj, tuple) and hasattr(obj, "_fields"): # named tuple + return dict(zip(obj._fields, obj)) + elif dataclasses.is_dataclass(obj): + return { + field.name: getattr(obj, field.name) for field in dataclasses.fields(obj) + } + elif isinstance(obj, dict): + return obj + else: + raise NotImplementedError(f"unsupported data type {type(obj)}") + + +def object_to_tensordict(obj, device, batch_size) -> TensorDictBase: + """Converts a namedtuple or a dataclass to a TensorDict.""" + t = {} + _dict = object_to_dict(obj) + for name, value in _dict.items(): + if isinstance(value, (np.number, int, float)): + t[name] = ndarray_to_tensor(np.asarray([value])).to(device) + elif isinstance(value, (jnp.ndarray, np.ndarray)): + t[name] = ndarray_to_tensor(value).to(device) + else: + t[name] = object_to_tensordict(value, device, batch_size) + return make_tensordict(**t, device=device, batch_size=batch_size) + + +def tensordict_to_object(tensordict: TensorDictBase, object_example): + """Converts a TensorDict to a namedtuple or a dataclass.""" + t = {} + _dict = object_to_dict(object_example) + for name in tensordict.keys(): + example = _dict[name] + value = tensordict[name] + if isinstance(value, TensorDictBase): + t[name] = tensordict_to_object(value, example) + else: + value = jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(value)) + t[name] = value.reshape(example.shape).view(example.dtype) + return type(object_example)(**t) \ No newline at end of file diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index 
77782b675f1..233d1fd8f8d 100644 --- a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -1,9 +1,8 @@ -import dataclasses from typing import Dict, Optional, Union import numpy as np import torch -from tensordict.tensordict import make_tensordict, TensorDict, TensorDictBase +from tensordict.tensordict import TensorDict, TensorDictBase from torchrl.data import ( CompositeSpec, @@ -17,6 +16,13 @@ ) from torchrl.data.utils import numpy_to_torch_dtype_dict from torchrl.envs import GymLikeEnv +from torchrl.envs.libs.jax_utils import ( + tree_reshape, + tree_flatten, + ndarray_to_tensor, + object_to_tensordict, + tensordict_to_object, +) try: import jax @@ -29,64 +35,6 @@ IMPORT_ERR = str(err) -def _ndarray_to_tensor(value: Union["jnp.ndarray", np.ndarray], device) -> torch.Tensor: - # tensor doesn't support conversion from jnp.ndarray. - if isinstance(value, jnp.ndarray): - value = np.asarray(value) - # tensor doesn't support unsigned dtypes. - if value.dtype == np.uint16: - value = value.astype(np.int16) - elif value.dtype == np.uint32: - value = value.astype(np.int32) - elif value.dtype == np.uint64: - value = value.astype(np.int64) - # convert to tensor. - return torch.tensor(value).to(device) - - -def _object_to_dict(obj) -> dict: - if isinstance(obj, tuple) and hasattr(obj, "_fields"): # named tuple - return dict(zip(obj._fields, obj)) - elif dataclasses.is_dataclass(obj): - return { - field.name: getattr(obj, field.name) for field in dataclasses.fields(obj) - } - elif isinstance(obj, dict): - return obj - else: - raise NotImplementedError(f"unsupported data type {type(obj)}") - - -def _object_to_tensordict(obj: Union, device, batch_size) -> TensorDictBase: - """Converts a namedtuple or a dataclass to a TensorDict.""" - t = {} - _dict = _object_to_dict(obj) - for name, value in _dict.items(): - if isinstance(value, (np.number, int, float)): - t[name] = _ndarray_to_tensor(np.asarray([value]), device=device) - elif isinstance(value, (jnp.ndarray, np.ndarray)): - t[name] = _ndarray_to_tensor(value, device=device) - else: - t[name] = _object_to_tensordict(value, device, batch_size) - return make_tensordict(**t, device=device, batch_size=batch_size) - - -def _tensordict_to_object(tensordict: TensorDictBase, object_example): - """Converts a TensorDict to a namedtuple or a dataclass.""" - t = {} - _dict = _object_to_dict(object_example) - for name in tensordict.keys(): - example = _dict[name] - value = tensordict[name] - if isinstance(value, TensorDictBase): - t[name] = _tensordict_to_object(value, example) - else: - t[name] = ( - value.detach().numpy().reshape(example.shape).astype(example.dtype) - ) - return type(object_example)(**t) - - def _jumanji_to_torchrl_spec_transform( spec, dtype: Optional[torch.dtype] = None, @@ -205,13 +153,13 @@ def _make_state_example(self, env): key = jax.random.PRNGKey(0) keys = jax.random.split(key, self.batch_size.numel()) state, _ = jax.vmap(env.reset)(jnp.stack(keys)) - state = self._reshape(state) + state = tree_reshape(state, self.batch_size) return state def _make_state_spec(self, env) -> TensorSpec: key = jax.random.PRNGKey(0) state, _ = env.reset(key) - state_dict = _object_to_tensordict(state, self.device, batch_size=()) + state_dict = object_to_tensordict(state, self.device, batch_size=()) state_spec = _torchrl_data_to_spec_transform(state_dict) return state_spec @@ -265,41 +213,40 @@ def _set_seed(self, seed): self.key = jax.random.PRNGKey(seed) def read_state(self, state): - state_dict = _object_to_tensordict(state, self.device, 
self.batch_size) + state_dict = object_to_tensordict(state, self.device, self.batch_size) return self._state_spec.encode(state_dict) def read_obs(self, obs): if isinstance(obs, (list, jnp.ndarray, np.ndarray)): - obs_dict = _ndarray_to_tensor(obs, self.device) + obs_dict = ndarray_to_tensor(obs).to(self.device) else: - obs_dict = _object_to_tensordict(obs, self.device, self.batch_size) + obs_dict = object_to_tensordict(obs, self.device, self.batch_size) return super().read_obs(obs_dict) def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # prepare inputs - state = _tensordict_to_object(tensordict.get("state"), self._state_example) + state = tensordict_to_object(tensordict.get("state"), self._state_example) action = self.read_action(tensordict.get("action")) reward = self.reward_spec.zero(self.batch_size) # flatten batch size into vector - state = self._flatten(state) - action = self._flatten(action) + state = tree_flatten(state, self.batch_size) + action = tree_flatten(action, self.batch_size) # jax vectorizing map on env.step state, timestep = jax.vmap(self._env.step)(state, action) # reshape batch size from vector - state = self._reshape(state) - timestep = self._reshape(timestep) + state = tree_reshape(state, self.batch_size) + timestep = tree_reshape(timestep, self.batch_size) # collect outputs state_dict = self.read_state(state) obs_dict = self.read_obs(timestep.observation) reward = self.read_reward(reward, np.asarray(timestep.reward)) - done = torch.tensor( - np.asarray(timestep.step_type == self.lib.types.StepType.LAST) - ) + done = timestep.step_type == self.lib.types.StepType.LAST + done = ndarray_to_tensor(done).view(torch.bool).to(self.device) self._is_done = done @@ -326,8 +273,8 @@ def _reset( state, timestep = jax.vmap(self._env.reset)(jnp.stack(keys)) # reshape batch size from vector - state = self._reshape(state) - timestep = self._reshape(timestep) + state = tree_reshape(state, self.batch_size) + timestep = tree_reshape(timestep, self.batch_size) # collect outputs state_dict = self.read_state(state) @@ -347,14 +294,6 @@ def _reset( return tensordict_out - def _reshape(self, x): - shape, n = self.batch_size, 1 - return jax.tree_util.tree_map(lambda x: x.reshape(shape + x.shape[n:]), x) - - def _flatten(self, x): - shape, n = (self.batch_size.numel(),), len(self.batch_size) - return jax.tree_util.tree_map(lambda x: x.reshape(shape + x.shape[n:]), x) - class JumanjiEnv(JumanjiWrapper): """Jumanji environment wrapper. 
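The DLPack hand-off introduced above in torchrl/envs/libs/jax_utils.py can be exercised in isolation. The sketch below is illustrative only (it is not part of the patch series) and assumes nothing beyond importable jax and torch builds; it mirrors the _ndarray_to_tensor/_tensor_to_ndarray helpers, which is what the commit message means by avoiding data movement between devices: the torch tensor and the JAX array are meant to share the same buffer rather than round-tripping through NumPy on the host.

    # Illustrative sketch, not part of the patches: JAX <-> torch exchange via DLPack.
    import jax.numpy as jnp
    import torch
    from jax import dlpack as jax_dlpack
    from torch.utils import dlpack as torch_dlpack

    # JAX array -> torch tensor, reusing the underlying buffer
    x_jax = jnp.arange(6.0).reshape(2, 3)
    x_torch = torch_dlpack.from_dlpack(jax_dlpack.to_dlpack(x_jax))

    # torch tensor -> JAX array, same idea in the other direction
    y_torch = torch.ones(2, 3)
    y_jax = jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(y_torch))

    print(x_torch.dtype, tuple(x_torch.shape))  # torch.float32 (2, 3)
    print(y_jax.dtype, y_jax.shape)             # float32 (2, 3)
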
From 80f926c5a197a9edb3db8580d2d52897a06a0085 Mon Sep 17 00:00:00 2001 From: Ying-Chen Lin Date: Thu, 8 Dec 2022 12:57:17 +0000 Subject: [PATCH 04/12] add autograd support for step method of Brax envs --- test/test_libs.py | 45 +++++++- torchrl/envs/libs/brax.py | 202 ++++++++++++++++++++++++++++----- torchrl/envs/libs/jax_utils.py | 46 ++++---- torchrl/envs/libs/jumanji.py | 36 +++--- 4 files changed, 254 insertions(+), 75 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 80154257aa1..8a910001939 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -26,7 +26,6 @@ from torchrl.envs.libs.gym import _has_gym, _is_from_pixels, GymEnv, GymWrapper from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv -from torchrl.envs.libs.jax_utils import tree_flatten if _has_gym: import gym @@ -384,6 +383,7 @@ def test_jumanji_consistency(self, envname, batch_size): import jax import jax.numpy as jnp import numpy as onp + from torchrl.envs.libs.jax_utils import _tree_flatten env = JumanjiEnv(envname, batch_size=batch_size) obs_keys = list(env.observation_spec.keys(True)) @@ -400,7 +400,7 @@ def test_jumanji_consistency(self, envname, batch_size): for i in range(rollout.shape[-1]): action = rollout[..., i]["action"] # state = env._flatten(state) - action = tree_flatten(env.read_action(action), env.batch_size) + action = _tree_flatten(env.read_action(action), env.batch_size) state, timestep = jax.vmap(base_env.step)(state, action) # state = env._reshape(state) # timesteps.append(timestep) @@ -428,7 +428,7 @@ def test_jumanji_consistency(self, envname, batch_size): @pytest.mark.skipif(not _has_brax, reason="brax not installed") -@pytest.mark.parametrize("envname", ["ant"]) +@pytest.mark.parametrize("envname", ["fast"]) class TestBrax: def test_brax_seeding(self, envname): final_seed = [] @@ -463,6 +463,45 @@ def test_brax_spec_rollout(self, envname, batch_size): env = BraxEnv(envname, batch_size=batch_size) env.set_seed(0) _test_fake_tensordict(env) + + @pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)]) + @pytest.mark.parametrize("requires_grad", [False, True]) + def test_brax_grad(self, envname, batch_size, requires_grad): + import jax + import jax.numpy as jnp + from torchrl.envs.libs.jax_utils import _tree_flatten, _tensor_to_ndarray, _ndarray_to_tensor + + env = BraxEnv(envname, batch_size=batch_size, requires_grad=requires_grad) + env.set_seed(1) + rollout = env.rollout(10) + + env.set_seed(1) + key = env._key + base_env = env._env + key, *keys = jax.random.split(key, np.prod(batch_size) + 1) + state = jax.vmap(base_env.reset)(jnp.stack(keys)) + for i in range(rollout.shape[-1]): + action = rollout[..., i]["action"] + action = _tensor_to_ndarray(action.clone()) + action = _tree_flatten(action, env.batch_size) + state = jax.vmap(base_env.step)(state, action) + t1 = rollout[..., i][("next", "observation")] + t2 = _ndarray_to_tensor(state.obs) + torch.testing.assert_close(t1, t2) + + @pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)]) + def test_brax_grad(self, envname, batch_size): + batch_size = (1,) + env = BraxEnv(envname, batch_size=batch_size, requires_grad=True) + env.set_seed(0) + td1 = env.reset() + action = torch.randn(batch_size + env.action_spec.shape) + action.requires_grad_(True) + td1["action"] = action + td2 = env.step(td1) + td2["reward"].mean().backward() + env.close() + del env if __name__ == "__main__": diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index 
54218747c7d..7f8f06e8c55 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -1,27 +1,28 @@ from typing import Dict, Optional, Union +import torch from tensordict.tensordict import TensorDict, TensorDictBase + from torchrl.data import ( CompositeSpec, NdBoundedTensorSpec, NdUnboundedContinuousTensorSpec, ) from torchrl.envs.common import _EnvWrapper -from torchrl.envs.libs.jumanji import _torchrl_data_to_spec_transform from torchrl.envs.libs.jax_utils import ( - tree_flatten, - tree_reshape, - tensor_to_ndarray, - object_to_tensordict, - tensordict_to_object, + _ndarray_to_tensor, + _object_to_tensordict, + _tensor_to_ndarray, + _tensordict_to_object, + _tree_flatten, + _tree_reshape, ) +from torchrl.envs.libs.jumanji import _torchrl_data_to_spec_transform try: import brax import brax.envs - import brax.io.torch import jax - from brax import jumpy as jp _has_brax = True except ImportError as err: @@ -68,11 +69,13 @@ def _build_env( from_pixels: bool = False, render_kwargs: Optional[dict] = None, pixels_only: bool = False, + requires_grad: bool = False, camera_id: Union[int, str] = 0, **kwargs, ): self.from_pixels = from_pixels self.pixels_only = pixels_only + self.requires_grad = requires_grad if from_pixels: raise NotImplementedError("TODO") @@ -81,7 +84,7 @@ def _build_env( def _make_state_spec(self, env: "brax.envs.env.Env"): key = jax.random.PRNGKey(0) state = env.reset(key) - state_dict = object_to_tensordict(state, self.device, batch_size=()) + state_dict = _object_to_tensordict(state, self.device, batch_size=()) state_spec = _torchrl_data_to_spec_transform(state_dict) return state_spec @@ -107,13 +110,13 @@ def _make_state_example(self): key = jax.random.PRNGKey(0) keys = jax.random.split(key, self.batch_size.numel()) state = self._vmap_jit_env_reset(jax.numpy.stack(keys)) - state = tree_reshape(state, self.batch_size) + state = _tree_reshape(state, self.batch_size) return state def _init_env(self) -> Optional[int]: self._key = None - self._vmap_jit_env_reset = jp.vmap(jax.jit(self._env.reset)) - self._vmap_jit_env_step = jp.vmap(jax.jit(self._env.step)) + self._vmap_jit_env_reset = jax.vmap(jax.jit(self._env.reset)) + self._vmap_jit_env_step = jax.vmap(jax.jit(self._env.step)) self._state_example = self._make_state_example() def _set_seed(self, seed: int): @@ -123,11 +126,17 @@ def _set_seed(self, seed: int): def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: + # generate random keys self._key, *keys = jax.random.split(self._key, 1 + self.numel()) + + # call env reset with jit and vmap state = self._vmap_jit_env_reset(jax.numpy.stack(keys)) - state = tree_reshape(state, self.batch_size) - state = object_to_tensordict(state, self.device, self.batch_size) + # reshape batch size + state = _tree_reshape(state, self.batch_size) + state = _object_to_tensordict(state, self.device, self.batch_size) + + # build result tensordict_out = TensorDict( source={ "observation": state.get("obs"), @@ -140,32 +149,82 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: ) return tensordict_out - def _step( - self, - tensordict: TensorDictBase, - ) -> TensorDictBase: + def _step_without_grad(self, tensordict: TensorDictBase): + + # convert tensors to ndarrays + state = _tensordict_to_object(tensordict.get("state"), self._state_example) + action = _tensor_to_ndarray(tensordict.get("action")) - state = tensordict_to_object(tensordict.get("state"), self._state_example) - action = tensor_to_ndarray(tensordict.get("action")) + # flatten 
batch size + state = _tree_flatten(state, self.batch_size) + action = _tree_flatten(action, self.batch_size) - state = tree_flatten(state, self.batch_size) - action = tree_flatten(action, self.batch_size) - state = self._vmap_jit_env_step(state, action) - state = tree_reshape(state, self.batch_size) - state = object_to_tensordict(state, self.device, self.batch_size) + # call env step with jit and vmap + next_state = self._vmap_jit_env_step(state, action) + # reshape batch size and convert ndarrays to tensors + next_state = _tree_reshape(next_state, self.batch_size) + next_state = _object_to_tensordict(next_state, self.device, self.batch_size) + + # build result tensordict_out = TensorDict( source={ - "observation": state.get("obs"), - "reward": state.get("reward"), - "done": state.get("done").bool(), - "state": state, + "observation": next_state.get("obs"), + "reward": next_state.get("reward"), + "done": next_state.get("done").bool(), + "state": next_state, + }, + batch_size=self.batch_size, + device=self.device, + ) + return tensordict_out + + def _step_with_grad(self, tensordict: TensorDictBase): + + # convert tensors to ndarrays + action = tensordict.get("action") + state = tensordict.get("state") + qp_keys = list(state.get("qp").keys()) + qp_values = list(state.get("qp").values()) + + # call env step with autograd function + next_state_nograd, next_obs, next_reward, *next_qp_values = _BraxEnvStep.apply( + self, state, action, *qp_values + ) + + # extract done values + next_done = next_state_nograd["done"].bool() + self._is_done = next_done + + # merge with tensors with grad function + next_state = next_state_nograd + next_state["obs"] = next_obs + next_state["reward"] = next_reward + next_state["qp"].update(dict(zip(qp_keys, next_qp_values))) + + # build result + tensordict_out = TensorDict( + source={ + "observation": next_obs, + "reward": next_reward, + "done": next_done, + "state": next_state, }, batch_size=self.batch_size, device=self.device, ) return tensordict_out + def _step( + self, + tensordict: TensorDictBase, + ) -> TensorDictBase: + + if self.requires_grad: + return self._step_with_grad(tensordict) + else: + return self._step_without_grad(tensordict) + class BraxEnv(BraxWrapper): """Google Brax environment wrapper. 
@@ -195,10 +254,16 @@ def _build_env( ) from_pixels = kwargs.pop("from_pixels", False) pixels_only = kwargs.pop("pixels_only", True) + requires_grad = kwargs.pop("requires_grad", False) assert not kwargs self.wrapper_frame_skip = 1 env = self.lib.envs.get_environment(env_name, **kwargs) - return super()._build_env(env, pixels_only=pixels_only, from_pixels=from_pixels) + return super()._build_env( + env, + pixels_only=pixels_only, + from_pixels=from_pixels, + requires_grad=requires_grad, + ) @property def env_name(self): @@ -210,3 +275,80 @@ def _check_kwargs(self, kwargs: Dict): def __repr__(self) -> str: return f"{self.__class__.__name__}(env={self.env_name}, batch_size={self.batch_size}, device={self.device})" + + +class _BraxEnvStep(torch.autograd.Function): + @staticmethod + def forward(ctx, env: BraxWrapper, state, action, *qp_values): + + # convert tensors to ndarrays + state = _tensordict_to_object(state, env._state_example) + action = _tensor_to_ndarray(action) + + # flatten batch size + state = _tree_flatten(state, env.batch_size) + action = _tree_flatten(action, env.batch_size) + + # call vjp with jit and vmap + next_state, vjp_fn = jax.vjp(env._vmap_jit_env_step, state, action) + + # reshape batch size + next_state = _tree_reshape(next_state, env.batch_size) + + # convert ndarrays to tensors + next_state = _object_to_tensordict( + next_state, device=env.device, batch_size=env.batch_size + ) + + # save context + ctx.vjp_fn = vjp_fn + ctx.next_state = next_state + ctx.env = env + + return ( + next_state, # no gradient + next_state["obs"], + next_state["reward"], + *next_state["qp"].values(), + ) + + @staticmethod + def backward(ctx, _, grad_next_obs, grad_next_reward, *grad_next_qp_values): + # build gradient tensordict with zeros in fields with no grad + grad_next_state = TensorDict( + source={ + "qp": dict(zip(ctx.next_state["qp"].keys(), grad_next_qp_values)), + "obs": grad_next_obs, + "reward": grad_next_reward, + "done": torch.zeros_like(ctx.next_state["done"]), + "metrics": { + k: torch.zeros_like(v) for k, v in ctx.next_state["metrics"].items() + }, + "info": { + k: torch.zeros_like(v) for k, v in ctx.next_state["info"].items() + }, + }, + device=ctx.env.device, + batch_size=ctx.env.batch_size, + ) + + # convert tensors to ndarrays + grad_next_state = _tensordict_to_object(grad_next_state, ctx.env._state_example) + + # flatten batch size + grad_next_state = _tree_flatten(grad_next_state, ctx.env.batch_size) + + # call vjp to get gradients + grad_state, grad_action = ctx.vjp_fn(grad_next_state) + + # reshape batch size + grad_state = _tree_reshape(grad_state, ctx.env.batch_size) + grad_action = _tree_reshape(grad_action, ctx.env.batch_size) + + # convert ndarrays to tensors + grad_state_qp = _object_to_tensordict( + grad_state.qp, device=ctx.env.device, batch_size=ctx.env.batch_size + ) + grad_action = _ndarray_to_tensor(grad_action) + + return None, None, grad_action, *grad_state_qp.values() diff --git a/torchrl/envs/libs/jax_utils.py b/torchrl/envs/libs/jax_utils.py index a20a70bcc2a..f3b2d1a0fe8 100644 --- a/torchrl/envs/libs/jax_utils.py +++ b/torchrl/envs/libs/jax_utils.py @@ -1,23 +1,20 @@ -from typing import Union import dataclasses - -import numpy as np +from typing import Union import jax -from jax import numpy as jnp -from jax import dlpack as jax_dlpack - +import numpy as np import torch -from torch.utils import dlpack as torch_dlpack +from jax import dlpack as jax_dlpack, numpy as jnp from tensordict.tensordict import make_tensordict, TensorDictBase +from 
torch.utils import dlpack as torch_dlpack -def tree_reshape(x, batch_size: torch.Size): +def _tree_reshape(x, batch_size: torch.Size): shape, n = batch_size, 1 return jax.tree_util.tree_map(lambda x: x.reshape(shape + x.shape[n:]), x) -def tree_flatten(x, batch_size: torch.Size): +def _tree_flatten(x, batch_size: torch.Size): shape, n = (batch_size.numel(),), len(batch_size) return jax.tree_util.tree_map(lambda x: x.reshape(shape + x.shape[n:]), x) @@ -28,7 +25,8 @@ def tree_flatten(x, batch_size: torch.Size): np.dtype("uint64"): np.int64, } -def ndarray_to_tensor(value: Union[jnp.ndarray, np.ndarray]) -> torch.Tensor: + +def _ndarray_to_tensor(value: Union[jnp.ndarray, np.ndarray]) -> torch.Tensor: # JAX arrays generated by jax.vmap would have Numpy dtypes. if value.dtype in _dtype_conversion: value = value.view(_dtype_conversion[value.dtype]) @@ -41,11 +39,11 @@ def ndarray_to_tensor(value: Union[jnp.ndarray, np.ndarray]) -> torch.Tensor: return torch_dlpack.from_dlpack(dlpack_tensor) -def tensor_to_ndarray(value: torch.Tensor) -> jnp.ndarray: +def _tensor_to_ndarray(value: torch.Tensor) -> jnp.ndarray: return jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(value)) - -def object_to_dict(obj) -> dict: + +def _get_object_fields(obj) -> dict: if isinstance(obj, tuple) and hasattr(obj, "_fields"): # named tuple return dict(zip(obj._fields, obj)) elif dataclasses.is_dataclass(obj): @@ -58,30 +56,30 @@ def object_to_dict(obj) -> dict: raise NotImplementedError(f"unsupported data type {type(obj)}") -def object_to_tensordict(obj, device, batch_size) -> TensorDictBase: +def _object_to_tensordict(obj, device, batch_size) -> TensorDictBase: """Converts a namedtuple or a dataclass to a TensorDict.""" t = {} - _dict = object_to_dict(obj) - for name, value in _dict.items(): + _fields = _get_object_fields(obj) + for name, value in _fields.items(): if isinstance(value, (np.number, int, float)): - t[name] = ndarray_to_tensor(np.asarray([value])).to(device) + t[name] = _ndarray_to_tensor(np.asarray([value])).to(device) elif isinstance(value, (jnp.ndarray, np.ndarray)): - t[name] = ndarray_to_tensor(value).to(device) + t[name] = _ndarray_to_tensor(value).to(device) else: - t[name] = object_to_tensordict(value, device, batch_size) + t[name] = _object_to_tensordict(value, device, batch_size) return make_tensordict(**t, device=device, batch_size=batch_size) -def tensordict_to_object(tensordict: TensorDictBase, object_example): +def _tensordict_to_object(tensordict: TensorDictBase, object_example): """Converts a TensorDict to a namedtuple or a dataclass.""" t = {} - _dict = object_to_dict(object_example) + _fields = _get_object_fields(object_example) for name in tensordict.keys(): - example = _dict[name] + example = _fields[name] value = tensordict[name] if isinstance(value, TensorDictBase): - t[name] = tensordict_to_object(value, example) + t[name] = _tensordict_to_object(value, example) else: value = jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(value)) t[name] = value.reshape(example.shape).view(example.dtype) - return type(object_example)(**t) \ No newline at end of file + return type(object_example)(**t) diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index 233d1fd8f8d..aeed31e689a 100644 --- a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -17,11 +17,11 @@ from torchrl.data.utils import numpy_to_torch_dtype_dict from torchrl.envs import GymLikeEnv from torchrl.envs.libs.jax_utils import ( - tree_reshape, - tree_flatten, - ndarray_to_tensor, - 
object_to_tensordict, - tensordict_to_object, + _ndarray_to_tensor, + _object_to_tensordict, + _tensordict_to_object, + _tree_flatten, + _tree_reshape, ) try: @@ -153,13 +153,13 @@ def _make_state_example(self, env): key = jax.random.PRNGKey(0) keys = jax.random.split(key, self.batch_size.numel()) state, _ = jax.vmap(env.reset)(jnp.stack(keys)) - state = tree_reshape(state, self.batch_size) + state = _tree_reshape(state, self.batch_size) return state def _make_state_spec(self, env) -> TensorSpec: key = jax.random.PRNGKey(0) state, _ = env.reset(key) - state_dict = object_to_tensordict(state, self.device, batch_size=()) + state_dict = _object_to_tensordict(state, self.device, batch_size=()) state_spec = _torchrl_data_to_spec_transform(state_dict) return state_spec @@ -213,40 +213,40 @@ def _set_seed(self, seed): self.key = jax.random.PRNGKey(seed) def read_state(self, state): - state_dict = object_to_tensordict(state, self.device, self.batch_size) + state_dict = _object_to_tensordict(state, self.device, self.batch_size) return self._state_spec.encode(state_dict) def read_obs(self, obs): if isinstance(obs, (list, jnp.ndarray, np.ndarray)): - obs_dict = ndarray_to_tensor(obs).to(self.device) + obs_dict = _ndarray_to_tensor(obs).to(self.device) else: - obs_dict = object_to_tensordict(obs, self.device, self.batch_size) + obs_dict = _object_to_tensordict(obs, self.device, self.batch_size) return super().read_obs(obs_dict) def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # prepare inputs - state = tensordict_to_object(tensordict.get("state"), self._state_example) + state = _tensordict_to_object(tensordict.get("state"), self._state_example) action = self.read_action(tensordict.get("action")) reward = self.reward_spec.zero(self.batch_size) # flatten batch size into vector - state = tree_flatten(state, self.batch_size) - action = tree_flatten(action, self.batch_size) + state = _tree_flatten(state, self.batch_size) + action = _tree_flatten(action, self.batch_size) # jax vectorizing map on env.step state, timestep = jax.vmap(self._env.step)(state, action) # reshape batch size from vector - state = tree_reshape(state, self.batch_size) - timestep = tree_reshape(timestep, self.batch_size) + state = _tree_reshape(state, self.batch_size) + timestep = _tree_reshape(timestep, self.batch_size) # collect outputs state_dict = self.read_state(state) obs_dict = self.read_obs(timestep.observation) reward = self.read_reward(reward, np.asarray(timestep.reward)) done = timestep.step_type == self.lib.types.StepType.LAST - done = ndarray_to_tensor(done).view(torch.bool).to(self.device) + done = _ndarray_to_tensor(done).view(torch.bool).to(self.device) self._is_done = done @@ -273,8 +273,8 @@ def _reset( state, timestep = jax.vmap(self._env.reset)(jnp.stack(keys)) # reshape batch size from vector - state = tree_reshape(state, self.batch_size) - timestep = tree_reshape(timestep, self.batch_size) + state = _tree_reshape(state, self.batch_size) + timestep = _tree_reshape(timestep, self.batch_size) # collect outputs state_dict = self.read_state(state) From 62e7b11034023fe032efcedc64ed6e0ed8e07837 Mon Sep 17 00:00:00 2001 From: Ying-Chen Lin Date: Thu, 8 Dec 2022 13:00:18 +0000 Subject: [PATCH 05/12] add autograd support for step method of Brax envs --- test/test_libs.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 8a910001939..4857d60960e 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -463,13 +463,17 @@ def 
test_brax_spec_rollout(self, envname, batch_size): env = BraxEnv(envname, batch_size=batch_size) env.set_seed(0) _test_fake_tensordict(env) - + @pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)]) @pytest.mark.parametrize("requires_grad", [False, True]) - def test_brax_grad(self, envname, batch_size, requires_grad): + def test_brax_consistency(self, envname, batch_size, requires_grad): import jax import jax.numpy as jnp - from torchrl.envs.libs.jax_utils import _tree_flatten, _tensor_to_ndarray, _ndarray_to_tensor + from torchrl.envs.libs.jax_utils import ( + _ndarray_to_tensor, + _tensor_to_ndarray, + _tree_flatten, + ) env = BraxEnv(envname, batch_size=batch_size, requires_grad=requires_grad) env.set_seed(1) From 27a12e33f8cd425844058d3ace519c7c64fe0f3e Mon Sep 17 00:00:00 2001 From: Ying-Chen Lin Date: Thu, 8 Dec 2022 13:23:25 +0000 Subject: [PATCH 06/12] add autograd support for step method of Brax envs --- torchrl/envs/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index ac86cb6d2da..324259e32f5 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -182,7 +182,7 @@ def check_env_specs(env): assert fake_tensordict[key].shape == real_tensordict[key].shape # test dtypes - real_tensordict = env.rollout(3) # keep empty structures, i.e. dict() + real_tensordict = env.rollout(3) # keep empty structures, for example dict() for key, value in real_tensordict.items(): _check_dtype(key, value, env.observation_spec, env.input_spec) From 6efa28e519a725b17139b40f0fb4ce62108bb749 Mon Sep 17 00:00:00 2001 From: Ying-Chen Lin Date: Thu, 8 Dec 2022 13:26:31 +0000 Subject: [PATCH 07/12] fix linter bugs --- test/test_libs.py | 1 - torchrl/envs/utils.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 5539e53d054..c16774de4c2 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -473,7 +473,6 @@ def test_brax_consistency(self, envname, batch_size, requires_grad): _ndarray_to_tensor, _tensor_to_ndarray, _tree_flatten, - _tree_reshape, ) env = BraxEnv(envname, batch_size=batch_size, requires_grad=requires_grad) diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 324259e32f5..ad5f4dadabc 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -182,7 +182,7 @@ def check_env_specs(env): assert fake_tensordict[key].shape == real_tensordict[key].shape # test dtypes - real_tensordict = env.rollout(3) # keep empty structures, for example dict() + real_tensordict = env.rollout(3) # keep empty structures, for example dict() for key, value in real_tensordict.items(): _check_dtype(key, value, env.observation_spec, env.input_spec) From c55e9f3d28bb5cdb9f2a6952c97cd5680ace0f23 Mon Sep 17 00:00:00 2001 From: Ying-Chen Lin Date: Thu, 8 Dec 2022 13:35:16 +0000 Subject: [PATCH 08/12] import jax only when brax is installed --- torchrl/envs/libs/brax.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index 7f8f06e8c55..35db4e8a246 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -9,20 +9,20 @@ NdUnboundedContinuousTensorSpec, ) from torchrl.envs.common import _EnvWrapper -from torchrl.envs.libs.jax_utils import ( - _ndarray_to_tensor, - _object_to_tensordict, - _tensor_to_ndarray, - _tensordict_to_object, - _tree_flatten, - _tree_reshape, -) from torchrl.envs.libs.jumanji import _torchrl_data_to_spec_transform try: import brax import 
brax.envs import jax + from torchrl.envs.libs.jax_utils import ( + _ndarray_to_tensor, + _object_to_tensordict, + _tensor_to_ndarray, + _tensordict_to_object, + _tree_flatten, + _tree_reshape, + ) _has_brax = True except ImportError as err: From 121a74c9b416589c95921db64f18282eed8d2bc5 Mon Sep 17 00:00:00 2001 From: Ying-Chen Lin Date: Thu, 8 Dec 2022 13:43:36 +0000 Subject: [PATCH 09/12] import jax only when brax is installed --- torchrl/envs/libs/jumanji.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index aeed31e689a..3fadabc877f 100644 --- a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -16,18 +16,18 @@ ) from torchrl.data.utils import numpy_to_torch_dtype_dict from torchrl.envs import GymLikeEnv -from torchrl.envs.libs.jax_utils import ( - _ndarray_to_tensor, - _object_to_tensordict, - _tensordict_to_object, - _tree_flatten, - _tree_reshape, -) try: import jax import jumanji from jax import numpy as jnp + from torchrl.envs.libs.jax_utils import ( + _ndarray_to_tensor, + _object_to_tensordict, + _tensordict_to_object, + _tree_flatten, + _tree_reshape, + ) _has_jumanji = True except ImportError as err: From d425db47619a92905309c512210efbfd87d02785 Mon Sep 17 00:00:00 2001 From: Ying-Chen Lin Date: Thu, 8 Dec 2022 14:01:40 +0000 Subject: [PATCH 10/12] compatible with 3.7 --- torchrl/envs/libs/brax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index 35db4e8a246..cffdc5124e1 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -351,4 +351,4 @@ def backward(ctx, _, grad_next_obs, grad_next_reward, *grad_next_qp_values): ) grad_action = _ndarray_to_tensor(grad_action) - return None, None, grad_action, *grad_state_qp.values() + return (None, None, grad_action, *grad_state_qp.values()) From 2cb9ddbb0da9d01972ce67ddae0d40dbad4aeb4a Mon Sep 17 00:00:00 2001 From: Ying-Chen Lin Date: Fri, 9 Dec 2022 10:40:54 +0000 Subject: [PATCH 11/12] add test pipeline --- .circleci/config.yml | 49 +++ .../linux_libs/scripts_brax/environment.yml | 18 + .../linux_libs/scripts_brax/install.sh | 48 +++ .../linux_libs/scripts_brax/post_process.sh | 6 + .../scripts_brax/run-clang-format.py | 356 ++++++++++++++++++ .../linux_libs/scripts_brax/run_test.sh | 32 ++ .../linux_libs/scripts_brax/setup_env.sh | 62 +++ 7 files changed, 571 insertions(+) create mode 100644 .circleci/unittest/linux_libs/scripts_brax/environment.yml create mode 100755 .circleci/unittest/linux_libs/scripts_brax/install.sh create mode 100755 .circleci/unittest/linux_libs/scripts_brax/post_process.sh create mode 100755 .circleci/unittest/linux_libs/scripts_brax/run-clang-format.py create mode 100755 .circleci/unittest/linux_libs/scripts_brax/run_test.sh create mode 100755 .circleci/unittest/linux_libs/scripts_brax/setup_env.sh diff --git a/.circleci/config.yml b/.circleci/config.yml index ed602992ee4..ee24520dc2b 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -498,6 +498,51 @@ jobs: - store_test_results: path: test-results + unittest_linux_brax_gpu: + <<: *binary_common + machine: + image: ubuntu-2004-cuda-11.4:202110-01 + resource_class: gpu.nvidia.medium + environment: + image_name: "pytorch/manylinux-cuda113" + TAR_OPTIONS: --no-same-owner + PYTHON_VERSION: << parameters.python_version >> + CU_VERSION: << parameters.cu_version >> + + steps: + - checkout + - designate_upload_channel + - run: + name: Generate cache 
key + # This will refresh cache on Sundays, nightly build should generate new cache. + command: echo "$(date +"%Y-%U")" > .circleci-weekly + - restore_cache: + keys: + - env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_libs/scripts_brax/environment.yml" }}-{{ checksum ".circleci-weekly" }} + - run: + name: Setup + command: .circleci/unittest/linux_libs/scripts_brax/setup_env.sh + - save_cache: + key: env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_libs/scripts_brax/environment.yml" }}-{{ checksum ".circleci-weekly" }} + paths: + - conda + - env + - run: + name: Install torchrl + command: docker run -t --gpus all -v $PWD:$PWD -w $PWD -e UPLOAD_CHANNEL -e CU_VERSION "${image_name}" .circleci/unittest/linux_libs/scripts_brax/install.sh + - run: + name: Run tests + command: bash .circleci/unittest/linux_libs/scripts_brax/run_test.sh + - run: + name: Codecov upload + command: | + bash <(curl -s https://codecov.io/bash) -Z -F linux-brax + - run: + name: Post Process + command: docker run -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux_libs/scripts_brax/post_process.sh + - store_test_results: + path: test-results + unittest_linux_gym_gpu: <<: *binary_common machine: @@ -928,6 +973,10 @@ workflows: cu_version: cu113 name: unittest_linux_jumanji_gpu_py3.8 python_version: '3.8' + - unittest_linux_brax_gpu: + cu_version: cu113 + name: unittest_linux_brax_gpu_py3.8 + python_version: '3.8' - unittest_linux_gym_gpu: cu_version: cu113 name: unittest_linux_gym_gpu_py3.8 diff --git a/.circleci/unittest/linux_libs/scripts_brax/environment.yml b/.circleci/unittest/linux_libs/scripts_brax/environment.yml new file mode 100644 index 00000000000..1c213df7227 --- /dev/null +++ b/.circleci/unittest/linux_libs/scripts_brax/environment.yml @@ -0,0 +1,18 @@ +channels: + - pytorch + - defaults +dependencies: + - pip + - pip: + - hypothesis + - future + - cloudpickle + - pytest + - pytest-cov + - pytest-mock + - pytest-instafail + - expecttest + - pyyaml + - scipy + - hydra-core + - brax diff --git a/.circleci/unittest/linux_libs/scripts_brax/install.sh b/.circleci/unittest/linux_libs/scripts_brax/install.sh new file mode 100755 index 00000000000..767070f2b25 --- /dev/null +++ b/.circleci/unittest/linux_libs/scripts_brax/install.sh @@ -0,0 +1,48 @@ +#!/usr/bin/env bash + +unset PYTORCH_VERSION +# For unittest, nightly PyTorch is used as the following section, +# so no need to set PYTORCH_VERSION. +# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config. 
+ +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +if [ "${CU_VERSION:-}" == cpu ] ; then + version="cpu" +else + if [[ ${#CU_VERSION} -eq 4 ]]; then + CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}" + elif [[ ${#CU_VERSION} -eq 5 ]]; then + CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}" + fi + echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)" + version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" +fi + +# submodules +git submodule sync && git submodule update --init --recursive + +printf "Installing PyTorch with %s\n" "${CU_VERSION}" +if [ "${CU_VERSION:-}" == cpu ] ; then + # conda install -y pytorch torchvision cpuonly -c pytorch-nightly + # use pip to install pytorch as conda can frequently pick older release +# conda install -y pytorch cpuonly -c pytorch-nightly + pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall +else + pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cu116 --force-reinstall +fi + +# install tensordict +pip install git+https://github.com/pytorch-labs/tensordict + +# smoke test +python -c "import functorch;import tensordict" + +printf "* Installing torchrl\n" +pip3 install -e . + +# smoke test +python -c "import torchrl" diff --git a/.circleci/unittest/linux_libs/scripts_brax/post_process.sh b/.circleci/unittest/linux_libs/scripts_brax/post_process.sh new file mode 100755 index 00000000000..e97bf2a7b1b --- /dev/null +++ b/.circleci/unittest/linux_libs/scripts_brax/post_process.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env diff --git a/.circleci/unittest/linux_libs/scripts_brax/run-clang-format.py b/.circleci/unittest/linux_libs/scripts_brax/run-clang-format.py new file mode 100755 index 00000000000..5783a885d86 --- /dev/null +++ b/.circleci/unittest/linux_libs/scripts_brax/run-clang-format.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python +""" +MIT License + +Copyright (c) 2017 Guillaume Papin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +A wrapper script around clang-format, suitable for linting multiple files +and to use for continuous integration. + +This is an alternative API for the clang-format command line. +It runs over multiple files and directories in parallel. +A diff output is produced and a sensible exit code is returned. 
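+
+Example (the path is illustrative; assumes clang-format is available on PATH):
+
+    run-clang-format.py -r --extensions c,cpp,cu,h path/to/sources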
+ +""" + +import argparse +import difflib +import fnmatch +import multiprocessing +import os +import signal +import subprocess +import sys +import traceback +from functools import partial + +try: + from subprocess import DEVNULL # py3k +except ImportError: + DEVNULL = open(os.devnull, "wb") + + +DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu" + + +class ExitStatus: + SUCCESS = 0 + DIFF = 1 + TROUBLE = 2 + + +def list_files(files, recursive=False, extensions=None, exclude=None): + if extensions is None: + extensions = [] + if exclude is None: + exclude = [] + + out = [] + for file in files: + if recursive and os.path.isdir(file): + for dirpath, dnames, fnames in os.walk(file): + fpaths = [os.path.join(dirpath, fname) for fname in fnames] + for pattern in exclude: + # os.walk() supports trimming down the dnames list + # by modifying it in-place, + # to avoid unnecessary directory listings. + dnames[:] = [ + x + for x in dnames + if not fnmatch.fnmatch(os.path.join(dirpath, x), pattern) + ] + fpaths = [x for x in fpaths if not fnmatch.fnmatch(x, pattern)] + for f in fpaths: + ext = os.path.splitext(f)[1][1:] + if ext in extensions: + out.append(f) + else: + out.append(file) + return out + + +def make_diff(file, original, reformatted): + return list( + difflib.unified_diff( + original, + reformatted, + fromfile=f"{file}\t(original)", + tofile=f"{file}\t(reformatted)", + n=3, + ) + ) + + +class DiffError(Exception): + def __init__(self, message, errs=None): + super().__init__(message) + self.errs = errs or [] + + +class UnexpectedError(Exception): + def __init__(self, message, exc=None): + super().__init__(message) + self.formatted_traceback = traceback.format_exc() + self.exc = exc + + +def run_clang_format_diff_wrapper(args, file): + try: + ret = run_clang_format_diff(args, file) + return ret + except DiffError: + raise + except Exception as e: + raise UnexpectedError(f"{file}: {e.__class__.__name__}: {e}", e) + + +def run_clang_format_diff(args, file): + try: + with open(file, encoding="utf-8") as f: + original = f.readlines() + except OSError as exc: + raise DiffError(str(exc)) + invocation = [args.clang_format_executable, file] + + # Use of utf-8 to decode the process output. + # + # Hopefully, this is the correct thing to do. + # + # It's done due to the following assumptions (which may be incorrect): + # - clang-format will returns the bytes read from the files as-is, + # without conversion, and it is already assumed that the files use utf-8. + # - if the diagnostics were internationalized, they would use utf-8: + # > Adding Translations to Clang + # > + # > Not possible yet! + # > Diagnostic strings should be written in UTF-8, + # > the client can translate to the relevant code page if needed. + # > Each translation completely replaces the format string + # > for the diagnostic. 
+ # > -- http://clang.llvm.org/docs/InternalsManual.html#internals-diag-translation + + try: + proc = subprocess.Popen( + invocation, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + encoding="utf-8", + ) + except OSError as exc: + raise DiffError( + f"Command '{subprocess.list2cmdline(invocation)}' failed to start: {exc}" + ) + proc_stdout = proc.stdout + proc_stderr = proc.stderr + + # hopefully the stderr pipe won't get full and block the process + outs = list(proc_stdout.readlines()) + errs = list(proc_stderr.readlines()) + proc.wait() + if proc.returncode: + raise DiffError( + "Command '{}' returned non-zero exit status {}".format( + subprocess.list2cmdline(invocation), proc.returncode + ), + errs, + ) + return make_diff(file, original, outs), errs + + +def bold_red(s): + return "\x1b[1m\x1b[31m" + s + "\x1b[0m" + + +def colorize(diff_lines): + def bold(s): + return "\x1b[1m" + s + "\x1b[0m" + + def cyan(s): + return "\x1b[36m" + s + "\x1b[0m" + + def green(s): + return "\x1b[32m" + s + "\x1b[0m" + + def red(s): + return "\x1b[31m" + s + "\x1b[0m" + + for line in diff_lines: + if line[:4] in ["--- ", "+++ "]: + yield bold(line) + elif line.startswith("@@ "): + yield cyan(line) + elif line.startswith("+"): + yield green(line) + elif line.startswith("-"): + yield red(line) + else: + yield line + + +def print_diff(diff_lines, use_color): + if use_color: + diff_lines = colorize(diff_lines) + sys.stdout.writelines(diff_lines) + + +def print_trouble(prog, message, use_colors): + error_text = "error:" + if use_colors: + error_text = bold_red(error_text) + print(f"{prog}: {error_text} {message}", file=sys.stderr) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--clang-format-executable", + metavar="EXECUTABLE", + help="path to the clang-format executable", + default="clang-format", + ) + parser.add_argument( + "--extensions", + help=f"comma separated list of file extensions (default: {DEFAULT_EXTENSIONS})", + default=DEFAULT_EXTENSIONS, + ) + parser.add_argument( + "-r", + "--recursive", + action="store_true", + help="run recursively over directories", + ) + parser.add_argument("files", metavar="file", nargs="+") + parser.add_argument("-q", "--quiet", action="store_true") + parser.add_argument( + "-j", + metavar="N", + type=int, + default=0, + help="run N clang-format jobs in parallel (default number of cpus + 1)", + ) + parser.add_argument( + "--color", + default="auto", + choices=["auto", "always", "never"], + help="show colored diff (default: auto)", + ) + parser.add_argument( + "-e", + "--exclude", + metavar="PATTERN", + action="append", + default=[], + help="exclude paths matching the given glob-like pattern(s) from recursive search", + ) + + args = parser.parse_args() + + # use default signal handling, like diff return SIGINT value on ^C + # https://bugs.python.org/issue14229#msg156446 + signal.signal(signal.SIGINT, signal.SIG_DFL) + try: + signal.SIGPIPE + except AttributeError: + # compatibility, SIGPIPE does not exist on Windows + pass + else: + signal.signal(signal.SIGPIPE, signal.SIG_DFL) + + colored_stdout = False + colored_stderr = False + if args.color == "always": + colored_stdout = True + colored_stderr = True + elif args.color == "auto": + colored_stdout = sys.stdout.isatty() + colored_stderr = sys.stderr.isatty() + + version_invocation = [args.clang_format_executable, "--version"] + try: + subprocess.check_call(version_invocation, stdout=DEVNULL) + except subprocess.CalledProcessError as e: + 
print_trouble(parser.prog, str(e), use_colors=colored_stderr) + return ExitStatus.TROUBLE + except OSError as e: + print_trouble( + parser.prog, + f"Command '{subprocess.list2cmdline(version_invocation)}' failed to start: {e}", + use_colors=colored_stderr, + ) + return ExitStatus.TROUBLE + + retcode = ExitStatus.SUCCESS + files = list_files( + args.files, + recursive=args.recursive, + exclude=args.exclude, + extensions=args.extensions.split(","), + ) + + if not files: + return + + njobs = args.j + if njobs == 0: + njobs = multiprocessing.cpu_count() + 1 + njobs = min(len(files), njobs) + + if njobs == 1: + # execute directly instead of in a pool, + # less overhead, simpler stacktraces + it = (run_clang_format_diff_wrapper(args, file) for file in files) + pool = None + else: + pool = multiprocessing.Pool(njobs) + it = pool.imap_unordered(partial(run_clang_format_diff_wrapper, args), files) + while True: + try: + outs, errs = next(it) + except StopIteration: + break + except DiffError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + retcode = ExitStatus.TROUBLE + sys.stderr.writelines(e.errs) + except UnexpectedError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + sys.stderr.write(e.formatted_traceback) + retcode = ExitStatus.TROUBLE + # stop at the first unexpected error, + # something could be very wrong, + # don't process all files unnecessarily + if pool: + pool.terminate() + break + else: + sys.stderr.writelines(errs) + if outs == []: + continue + if not args.quiet: + print_diff(outs, use_color=colored_stdout) + if retcode == ExitStatus.SUCCESS: + retcode = ExitStatus.DIFF + return retcode + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/.circleci/unittest/linux_libs/scripts_brax/run_test.sh b/.circleci/unittest/linux_libs/scripts_brax/run_test.sh new file mode 100755 index 00000000000..02a3279f96f --- /dev/null +++ b/.circleci/unittest/linux_libs/scripts_brax/run_test.sh @@ -0,0 +1,32 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env +apt-get update && apt-get install -y git wget + + +export PYTORCH_TEST_WITH_SLOW='1' +python -m torch.utils.collect_env +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory '*' + +root_dir="$(git rev-parse --show-toplevel)" +env_dir="${root_dir}/env" +lib_dir="${env_dir}/lib" + +# solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir +export MKL_THREADING_LAYER=GNU +# more logging +export MAGNUM_LOG=verbose MAGNUM_GPU_VALIDATION=ON + +#wget https://github.com/openai/mujoco-py/blob/master/vendor/10_nvidia.json +#mv 10_nvidia.json /usr/share/glvnd/egl_vendor.d/10_nvidia.json + +# this workflow only tests the libs +python -c "import brax" + +coverage run -m pytest test/test_libs.py --instafail -v --durations 20 --capture no -k TestBrax +coverage xml -i diff --git a/.circleci/unittest/linux_libs/scripts_brax/setup_env.sh b/.circleci/unittest/linux_libs/scripts_brax/setup_env.sh new file mode 100755 index 00000000000..705bd9a3814 --- /dev/null +++ b/.circleci/unittest/linux_libs/scripts_brax/setup_env.sh @@ -0,0 +1,62 @@ +#!/usr/bin/env bash + +# This script is for setting up environment in which unit test is ran. +# To speed up the CI time, the resulting environment is cached. +# +# Do not install PyTorch and torchvision here, otherwise they also get cached. 
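+# PYTHON_VERSION (e.g. "3.8") is expected to be set by the CircleCI job; it selects
+# the interpreter version of the conda environment created below.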
+ +set -e + +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory '*' +root_dir="$(git rev-parse --show-toplevel)" +conda_dir="${root_dir}/conda" +env_dir="${root_dir}/env" + +cd "${root_dir}" + +case "$(uname -s)" in + Darwin*) os=MacOSX;; + *) os=Linux +esac + +# 1. Install conda at ./conda +if [ ! -d "${conda_dir}" ]; then + printf "* Installing conda\n" + wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" + bash ./miniconda.sh -b -f -p "${conda_dir}" +fi +eval "$(${conda_dir}/bin/conda shell.bash hook)" + +# 2. Create test environment at ./env +printf "python: ${PYTHON_VERSION}\n" +if [ ! -d "${env_dir}" ]; then + printf "* Creating a test environment\n" + conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" +fi +conda activate "${env_dir}" + +## 3. Install mujoco +#printf "* Installing mujoco and related\n" +#mkdir $root_dir/.mujoco +#cd $root_dir/.mujoco/ +#wget https://github.com/deepmind/mujoco/releases/download/2.1.1/mujoco-2.1.1-linux-x86_64.tar.gz +#tar -xf mujoco-2.1.1-linux-x86_64.tar.gz +#wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz +#tar -xf mujoco210-linux-x86_64.tar.gz +#cd $this_dir + +# 4. Install Conda dependencies +printf "* Installing dependencies (except PyTorch)\n" +echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" +cat "${this_dir}/environment.yml" + +pip install pip --upgrade + +conda env update --file "${this_dir}/environment.yml" --prune + +#yum makecache +#yum -y install glfw-devel +#yum -y install libGLEW +#yum -y install gcc-c++ From d04cc382acbdb910ee445434873ee9637129cbfd Mon Sep 17 00:00:00 2001 From: Ying-Chen Lin Date: Fri, 9 Dec 2022 16:36:01 +0000 Subject: [PATCH 12/12] update according to pr feedbacks --- torchrl/envs/libs/brax.py | 41 ++++++++++++++++++++--- torchrl/envs/libs/jax_utils.py | 24 ++++++++++++++ torchrl/envs/libs/jumanji.py | 59 +++++++++++++++++++--------------- 3 files changed, 94 insertions(+), 30 deletions(-) diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index cffdc5124e1..e0e271b5ede 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -9,13 +9,13 @@ NdUnboundedContinuousTensorSpec, ) from torchrl.envs.common import _EnvWrapper -from torchrl.envs.libs.jumanji import _torchrl_data_to_spec_transform try: import brax import brax.envs import jax from torchrl.envs.libs.jax_utils import ( + _extract_spec, _ndarray_to_tensor, _object_to_tensordict, _tensor_to_ndarray, @@ -25,24 +25,52 @@ ) _has_brax = True + IMPORT_ERR = "" except ImportError as err: _has_brax = False IMPORT_ERR = str(err) +def _get_envs(): + if not _has_brax: + return [] + return list(brax.envs._envs.keys()) + + class BraxWrapper(_EnvWrapper): """Google Brax environment wrapper. 
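+
+    Brax is a JAX-based, differentiable physics engine that simulates batches of
+    rigid-body environments on accelerator devices. The wrapper converts the Brax
+    state pytree into a tensordict (kept under the ``"state"`` entry of the output)
+    and back, so that the simulation can be carried over from one ``step`` call to
+    the next.
+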
Examples: >>> env = brax.envs.get_environment("ant") >>> env = BraxWrapper(env) - >>> td = env.rand_step() + >>> env.set_seed(0) + >>> td = env.reset() + >>> td["action"] = env.action_spec.rand() + >>> td = env.step(td) >>> print(td) + TensorDict( + fields={ + action: Tensor(torch.Size([8]), dtype=torch.float32), + done: Tensor(torch.Size([1]), dtype=torch.bool), + next: TensorDict( + fields={ + observation: Tensor(torch.Size([87]), dtype=torch.float32)}, + batch_size=torch.Size([]), + device=cpu, + is_shared=False), + observation: Tensor(torch.Size([87]), dtype=torch.float32), + reward: Tensor(torch.Size([1]), dtype=torch.float32), + state: TensorDict(...)}, + batch_size=torch.Size([]), + device=cpu, + is_shared=False) >>> print(env.available_envs) - + ['acrobot', 'ant', 'fast', 'fetch', ...] """ git_url = "https://github.com/google/brax" + available_envs = _get_envs() + libname = "brax" @property def lib(self): @@ -85,7 +113,7 @@ def _make_state_spec(self, env: "brax.envs.env.Env"): key = jax.random.PRNGKey(0) state = env.reset(key) state_dict = _object_to_tensordict(state, self.device, batch_size=()) - state_spec = _torchrl_data_to_spec_transform(state_dict) + state_spec = _extract_spec(state_dict) return state_spec def _make_specs(self, env: "brax.envs.env.Env") -> None: # noqa: F821 @@ -146,6 +174,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: }, batch_size=self.batch_size, device=self.device, + _run_checks=False, ) return tensordict_out @@ -176,6 +205,7 @@ def _step_without_grad(self, tensordict: TensorDictBase): }, batch_size=self.batch_size, device=self.device, + _run_checks=False, ) return tensordict_out @@ -212,6 +242,7 @@ def _step_with_grad(self, tensordict: TensorDictBase): }, batch_size=self.batch_size, device=self.device, + _run_checks=False, ) return tensordict_out @@ -314,6 +345,7 @@ def forward(ctx, env: BraxWrapper, state, action, *qp_values): @staticmethod def backward(ctx, _, grad_next_obs, grad_next_reward, *grad_next_qp_values): + # build gradient tensordict with zeros in fields with no grad grad_next_state = TensorDict( source={ @@ -330,6 +362,7 @@ def backward(ctx, _, grad_next_obs, grad_next_reward, *grad_next_qp_values): }, device=ctx.env.device, batch_size=ctx.env.batch_size, + _run_checks=False, ) # convert tensors to ndarrays diff --git a/torchrl/envs/libs/jax_utils.py b/torchrl/envs/libs/jax_utils.py index f3b2d1a0fe8..93bd9325300 100644 --- a/torchrl/envs/libs/jax_utils.py +++ b/torchrl/envs/libs/jax_utils.py @@ -7,6 +7,12 @@ from jax import dlpack as jax_dlpack, numpy as jnp from tensordict.tensordict import make_tensordict, TensorDictBase from torch.utils import dlpack as torch_dlpack +from torchrl.data import ( + CompositeSpec, + NdUnboundedContinuousTensorSpec, + NdUnboundedDiscreteTensorSpec, + TensorSpec, +) def _tree_reshape(x, batch_size: torch.Size): @@ -83,3 +89,21 @@ def _tensordict_to_object(tensordict: TensorDictBase, object_example): value = jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(value)) t[name] = value.reshape(example.shape).view(example.dtype) return type(object_example)(**t) + + +def _extract_spec(data: Union[torch.Tensor, TensorDictBase]) -> TensorSpec: + if isinstance(data, torch.Tensor): + if data.dtype in (torch.float, torch.double, torch.half): + return NdUnboundedContinuousTensorSpec( + shape=data.shape, dtype=data.dtype, device=data.device + ) + else: + return NdUnboundedDiscreteTensorSpec( + shape=data.shape, dtype=data.dtype, device=data.device + ) + elif isinstance(data, TensorDictBase): + 
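+        # recurse into the nested tensordict so that nested structures are
+        # mirrored by nested CompositeSpecs, with one leaf spec per tensor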
return CompositeSpec( + **{key: _extract_spec(value) for key, value in data.items()} + ) + else: + raise TypeError(f"Unsupported data type {type(data)}") diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index 3fadabc877f..e430acba5bc 100644 --- a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -22,6 +22,7 @@ import jumanji from jax import numpy as jnp from torchrl.envs.libs.jax_utils import ( + _extract_spec, _ndarray_to_tensor, _object_to_tensordict, _tensordict_to_object, @@ -30,11 +31,18 @@ ) _has_jumanji = True + IMPORT_ERR = "" except ImportError as err: _has_jumanji = False IMPORT_ERR = str(err) +def _get_envs(): + if not _has_jumanji: + return [] + return jumanji.registered_environments() + + def _jumanji_to_torchrl_spec_transform( spec, dtype: Optional[torch.dtype] = None, @@ -87,41 +95,40 @@ def _jumanji_to_torchrl_spec_transform( raise TypeError(f"Unsupported spec type {type(spec)}") -def _torchrl_data_to_spec_transform(data) -> TensorSpec: - if isinstance(data, torch.Tensor): - if data.dtype in (torch.float, torch.double, torch.half): - return NdUnboundedContinuousTensorSpec( - shape=data.shape, dtype=data.dtype, device=data.device - ) - else: - return NdUnboundedDiscreteTensorSpec( - shape=data.shape, dtype=data.dtype, device=data.device - ) - elif isinstance(data, TensorDict): - return CompositeSpec( - **{ - key: _torchrl_data_to_spec_transform(value) - for key, value in data.items() - } - ) - else: - raise TypeError(f"Unsupported data type {type(data)}") - - class JumanjiWrapper(GymLikeEnv): """Jumanji environment wrapper. Examples: - >>> env = jumanju.make("Snake-6x6-v0") + >>> env = jumanji.make("Snake-6x6-v0") >>> env = JumanjiWrapper(env) - >>> td0 = env.reset() - >>> print(td0) - >>> td1 = env.rand_step(td0) + >>> env.set_seed(0) + >>> td = env.reset() + >>> td["action"] = env.action_spec.rand() + >>> td = env.step(td) >>> print(td1) + TensorDict( + fields={ + action: Tensor(torch.Size([1]), dtype=torch.int32), + done: Tensor(torch.Size([1]), dtype=torch.bool), + next: TensorDict( + fields={ + observation: Tensor(torch.Size([6, 6, 5]), dtype=torch.float32)}, + batch_size=torch.Size([]), + device=cpu, + is_shared=False), + observation: Tensor(torch.Size([6, 6, 5]), dtype=torch.float32), + reward: Tensor(torch.Size([1]), dtype=torch.float32), + state: TensorDict(...)}, + batch_size=torch.Size([]), + device=cpu, + is_shared=False) >>> print(env.available_envs) + ['Snake-6x6-v0', 'Snake-12x12-v0', 'TSP50-v0', 'TSP100-v0', ...] """ git_url = "https://github.com/instadeepai/jumanji" + available_envs = _get_envs() + libname = "jumanji" @property def lib(self): @@ -160,7 +167,7 @@ def _make_state_spec(self, env) -> TensorSpec: key = jax.random.PRNGKey(0) state, _ = env.reset(key) state_dict = _object_to_tensordict(state, self.device, batch_size=()) - state_spec = _torchrl_data_to_spec_transform(state_dict) + state_spec = _extract_spec(state_dict) return state_spec def _make_input_spec(self, env) -> TensorSpec: