Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,7 +1038,7 @@ def _run_processes(self) -> None:
`env = ParallelEnv(N, EnvCreator(my_lambda_function))`.
This will not only ensure that your lambda function is cloud-pickled once, but
also that the state dict is synchronised across processes if needed."""
)
) from err
pipe_child.close()
self.procs.append(proc)
self.pipes.append(pipe_parent)
Expand Down
2 changes: 1 addition & 1 deletion torchrl/envs/libs/brax.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def _build_env(
f"brax not found, unable to create {env_name}. "
f"Consider downloading and installing brax from"
f" {self.git_url}"
)
) from IMPORT_ERR
from_pixels = kwargs.pop("from_pixels", False)
pixels_only = kwargs.pop("pixels_only", True)
requires_grad = kwargs.pop("requires_grad", False)
Expand Down
8 changes: 6 additions & 2 deletions torchrl/envs/libs/jumanji.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ def lib(self):
return jumanji

def __init__(self, env: "jumanji.env.Environment" = None, **kwargs):
if not _has_jumanji:
raise ImportError(
"jumanji is not installed or importing it failed. Consider checking your installation."
) from IMPORT_ERR
if env is not None:
kwargs["env"] = env
super().__init__(**kwargs)
Expand Down Expand Up @@ -331,8 +335,8 @@ def _build_env(
raise RuntimeError(
f"jumanji not found, unable to create {env_name}. "
f"Consider installing jumanji. More info:"
f" {self.git_url}. (Original error message during import: {IMPORT_ERR})."
)
f" {self.git_url}."
) from IMPORT_ERR
from_pixels = kwargs.pop("from_pixels", False)
pixels_only = kwargs.pop("pixels_only", True)
assert not kwargs
Expand Down
4 changes: 2 additions & 2 deletions torchrl/envs/libs/vmas.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,8 @@ def __init__(
if not _has_vmas:
raise ImportError(
f"vmas python package was not found. Please install this dependency. "
f"More info: {self.git_url} (ImportError: {IMPORT_ERR})"
)
f"More info: {self.git_url}."
) from IMPORT_ERR
kwargs["scenario_name"] = scenario_name
kwargs["num_envs"] = num_envs
kwargs["continuous_actions"] = continuous_actions
Expand Down
8 changes: 2 additions & 6 deletions torchrl/envs/vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,9 +982,7 @@ def _run_worker_pipe_shared_mem(
try:
cmd, data = child_pipe.recv()
except EOFError as err:
raise EOFError(
f"proc {pid} failed, last command: {cmd}. " f"\nErr={str(err)}"
)
raise EOFError(f"proc {pid} failed, last command: {cmd}.") from err
if cmd == "seed":
if not initialized:
raise RuntimeError("call 'init' before closing")
Expand Down Expand Up @@ -1087,9 +1085,7 @@ def _run_worker_pipe_shared_mem(
else:
result = attr
except Exception as err:
raise RuntimeError(
f"querying {err_msg} resulted in the following error: " f"{err}"
)
raise RuntimeError(f"querying {err_msg} resulted in an error.") from err
if cmd not in ("to"):
child_pipe.send(("_".join([cmd, "done"]), result))
else:
Expand Down
8 changes: 4 additions & 4 deletions torchrl/modules/tensordict_module/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def _check_all_str(list_of_str, first_level=True):
return [_check_all_str(item, False) for item in list_of_str]
except Exception as err:
raise TypeError(
f"Expected a list of strings but got: {list_of_str} that raised the following error: {err}."
)
f"Expected a list of strings but got: {list_of_str}."
) from err


def _forward_hook_safe_action(module, tensordict_in, tensordict_out):
Expand Down Expand Up @@ -89,8 +89,8 @@ def _forward_hook_safe_action(module, tensordict_in, tensordict_out):
):
# "_is_stateless" in module.__dict__ and module._is_stateless:
raise RuntimeError(
f"vmap cannot be used with safe=True, consider turning the safe mode off. (original error message: {err})"
)
"vmap cannot be used with safe=True, consider turning the safe mode off."
) from err
else:
raise err

Expand Down
4 changes: 1 addition & 3 deletions torchrl/objectives/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,7 @@ def __init__(
gSDE: bool = False,
):
if not _has_functorch:
raise ImportError(
f"Failed to import functorch with error message:\n{FUNCTORCH_ERR}"
)
raise ImportError("Failed to import functorch.") from FUNCTORCH_ERR
super().__init__()
self.convert_to_functional(
actor_network,
Expand Down
4 changes: 1 addition & 3 deletions torchrl/objectives/redq.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,7 @@ def __init__(
gSDE: bool = False,
):
if not _has_functorch:
raise ImportError(
f"Failed to import functorch with error message:\n{FUNCTORCH_ERR}"
)
raise ImportError("Failed to import functorch.") from FUNCTORCH_ERR

super().__init__()
self.convert_to_functional(
Expand Down
6 changes: 2 additions & 4 deletions torchrl/objectives/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
err = ""
except ImportError as err:
_has_functorch = False
FUNCTORCH_ERROR = str(err)
FUNCTORCH_ERROR = err


class SACLoss(LossModule):
Expand Down Expand Up @@ -94,9 +94,7 @@ def __init__(
delay_value: bool = False,
) -> None:
if not _has_functorch:
raise ImportError(
f"Failed to import functorch with error message:\n{FUNCTORCH_ERROR}"
)
raise ImportError("Failed to import functorch.") from FUNCTORCH_ERROR
super().__init__()

# Actor
Expand Down