From 93fd6b4573d637ed1de62ff2c9472a7f41c60d76 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 11 Jan 2023 17:27:46 +0000 Subject: [PATCH] init --- torchrl/collectors/collectors.py | 2 +- torchrl/envs/libs/brax.py | 2 +- torchrl/envs/libs/jumanji.py | 8 ++++++-- torchrl/envs/libs/vmas.py | 4 ++-- torchrl/envs/vec_env.py | 8 ++------ torchrl/modules/tensordict_module/common.py | 8 ++++---- torchrl/objectives/deprecated.py | 4 +--- torchrl/objectives/redq.py | 4 +--- torchrl/objectives/sac.py | 6 ++---- 9 files changed, 20 insertions(+), 26 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index e7742b70c98..193d60dba76 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -1038,7 +1038,7 @@ def _run_processes(self) -> None: `env = ParallelEnv(N, EnvCreator(my_lambda_function))`. This will not only ensure that your lambda function is cloud-pickled once, but also that the state dict is synchronised across processes if needed.""" - ) + ) from err pipe_child.close() self.procs.append(proc) self.pipes.append(pipe_parent) diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index 529742a328a..4d1852a704c 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -301,7 +301,7 @@ def _build_env( f"brax not found, unable to create {env_name}. " f"Consider downloading and installing brax from" f" {self.git_url}" - ) + ) from IMPORT_ERR from_pixels = kwargs.pop("from_pixels", False) pixels_only = kwargs.pop("pixels_only", True) requires_grad = kwargs.pop("requires_grad", False) diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index 99e99b76326..a4ae6200e94 100644 --- a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -140,6 +140,10 @@ def lib(self): return jumanji def __init__(self, env: "jumanji.env.Environment" = None, **kwargs): + if not _has_jumanji: + raise ImportError( + "jumanji is not installed or importing it failed. Consider checking your installation." + ) from IMPORT_ERR if env is not None: kwargs["env"] = env super().__init__(**kwargs) @@ -331,8 +335,8 @@ def _build_env( raise RuntimeError( f"jumanji not found, unable to create {env_name}. " f"Consider installing jumanji. More info:" - f" {self.git_url}. (Original error message during import: {IMPORT_ERR})." - ) + f" {self.git_url}." + ) from IMPORT_ERR from_pixels = kwargs.pop("from_pixels", False) pixels_only = kwargs.pop("pixels_only", True) assert not kwargs diff --git a/torchrl/envs/libs/vmas.py b/torchrl/envs/libs/vmas.py index 42eba338509..e3880504660 100644 --- a/torchrl/envs/libs/vmas.py +++ b/torchrl/envs/libs/vmas.py @@ -380,8 +380,8 @@ def __init__( if not _has_vmas: raise ImportError( f"vmas python package was not found. Please install this dependency. " - f"More info: {self.git_url} (ImportError: {IMPORT_ERR})" - ) + f"More info: {self.git_url}." + ) from IMPORT_ERR kwargs["scenario_name"] = scenario_name kwargs["num_envs"] = num_envs kwargs["continuous_actions"] = continuous_actions diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index b5e26035453..c21733f484f 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -982,9 +982,7 @@ def _run_worker_pipe_shared_mem( try: cmd, data = child_pipe.recv() except EOFError as err: - raise EOFError( - f"proc {pid} failed, last command: {cmd}. " f"\nErr={str(err)}" - ) + raise EOFError(f"proc {pid} failed, last command: {cmd}.") from err if cmd == "seed": if not initialized: raise RuntimeError("call 'init' before closing") @@ -1087,9 +1085,7 @@ def _run_worker_pipe_shared_mem( else: result = attr except Exception as err: - raise RuntimeError( - f"querying {err_msg} resulted in the following error: " f"{err}" - ) + raise RuntimeError(f"querying {err_msg} resulted in an error.") from err if cmd not in ("to"): child_pipe.send(("_".join([cmd, "done"]), result)) else: diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index 2d95e539fda..641301e9159 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -50,8 +50,8 @@ def _check_all_str(list_of_str, first_level=True): return [_check_all_str(item, False) for item in list_of_str] except Exception as err: raise TypeError( - f"Expected a list of strings but got: {list_of_str} that raised the following error: {err}." - ) + f"Expected a list of strings but got: {list_of_str}." + ) from err def _forward_hook_safe_action(module, tensordict_in, tensordict_out): @@ -89,8 +89,8 @@ def _forward_hook_safe_action(module, tensordict_in, tensordict_out): ): # "_is_stateless" in module.__dict__ and module._is_stateless: raise RuntimeError( - f"vmap cannot be used with safe=True, consider turning the safe mode off. (original error message: {err})" - ) + "vmap cannot be used with safe=True, consider turning the safe mode off." + ) from err else: raise err diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index 9f493525b8b..4ebb559c550 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -81,9 +81,7 @@ def __init__( gSDE: bool = False, ): if not _has_functorch: - raise ImportError( - f"Failed to import functorch with error message:\n{FUNCTORCH_ERR}" - ) + raise ImportError("Failed to import functorch.") from FUNCTORCH_ERR super().__init__() self.convert_to_functional( actor_network, diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index 30ddfa50953..fc4cae4f6c9 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -85,9 +85,7 @@ def __init__( gSDE: bool = False, ): if not _has_functorch: - raise ImportError( - f"Failed to import functorch with error message:\n{FUNCTORCH_ERR}" - ) + raise ImportError("Failed to import functorch.") from FUNCTORCH_ERR super().__init__() self.convert_to_functional( diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index c5d31e19f5a..0d26dfab80e 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -27,7 +27,7 @@ err = "" except ImportError as err: _has_functorch = False - FUNCTORCH_ERROR = str(err) + FUNCTORCH_ERROR = err class SACLoss(LossModule): @@ -94,9 +94,7 @@ def __init__( delay_value: bool = False, ) -> None: if not _has_functorch: - raise ImportError( - f"Failed to import functorch with error message:\n{FUNCTORCH_ERROR}" - ) + raise ImportError("Failed to import functorch.") from FUNCTORCH_ERROR super().__init__() # Actor