From f6471253b95687e81cc21ab7f5f6214b5632f6ee Mon Sep 17 00:00:00 2001 From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com> Date: Tue, 27 Dec 2022 01:21:35 +0000 Subject: [PATCH 1/3] Added abstract annotations and del repeated funcs Added some abstract annotations and removed the reimplementation of `set_seed()' and '_set_seed()' in `_EnvWrapper` --- torchrl/envs/common.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 9565959ec4a..c2ca4d6346c 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -491,6 +491,7 @@ def set_seed(self, seed: int, static_seed: bool = False) -> int: seed = new_seed return seed + @abc.abstractmethod def _set_seed(self, seed: Optional[int]): raise NotImplementedError @@ -787,6 +788,7 @@ def __getattr__(self, attr: str) -> Any: f"env not set in {self.__class__.__name__}, cannot access {attr}" ) + @abc.abstractmethod def _init_env(self) -> Optional[int]: """Runs all the necessary steps such that the environment is ready to use. @@ -821,22 +823,7 @@ def close(self) -> None: except AttributeError: pass - def set_seed( - self, seed: Optional[int] = None, static_seed: bool = False - ) -> Optional[int]: - if seed is not None: - torch.manual_seed(seed) - self._set_seed(seed) - if seed is not None and not static_seed: - new_seed = seed_generator(seed) - seed = new_seed - return seed - - @abc.abstractmethod - def _set_seed(self, seed: Optional[int]): - raise NotImplementedError - - + def make_tensordict( env: _EnvWrapper, policy: Optional[Callable[[TensorDictBase, ...], TensorDictBase]] = None, From d495b4217b4cf6069b6c4ce9a2c37a412ea41105 Mon Sep 17 00:00:00 2001 From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com> Date: Tue, 27 Dec 2022 01:25:05 +0000 Subject: [PATCH 2/3] Added optianal type hint in `set_seed' Added optianal type hint in `set_seed' in `EnvBase` as it was in `_EnvWrapper` --- torchrl/envs/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index c2ca4d6346c..13dc69125ee 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -470,7 +470,7 @@ def reset( def numel(self) -> int: return prod(self.batch_size) - def set_seed(self, seed: int, static_seed: bool = False) -> int: + def set_seed(self, seed: Optional[int] = None, static_seed: bool = False) -> Optional[int]: """Sets the seed of the environment and returns the next seed to be used (which is the input seed if a single environment is present). Args: From 0e6de632fecd1026305f50ca138baf5f6cc18d1d Mon Sep 17 00:00:00 2001 From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com> Date: Tue, 27 Dec 2022 01:30:24 +0000 Subject: [PATCH 3/3] Fix indentation --- torchrl/envs/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 13dc69125ee..c2cca0a5c23 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -470,7 +470,7 @@ def reset( def numel(self) -> int: return prod(self.batch_size) - def set_seed(self, seed: Optional[int] = None, static_seed: bool = False) -> Optional[int]: + def set_seed(self, seed: Optional[int] = None, static_seed: bool = False) -> Optional[int]: """Sets the seed of the environment and returns the next seed to be used (which is the input seed if a single environment is present). Args: