diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 9565959ec4a..c2cca0a5c23 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -470,7 +470,7 @@ def reset( def numel(self) -> int: return prod(self.batch_size) - def set_seed(self, seed: int, static_seed: bool = False) -> int: + def set_seed(self, seed: Optional[int] = None, static_seed: bool = False) -> Optional[int]: """Sets the seed of the environment and returns the next seed to be used (which is the input seed if a single environment is present). Args: @@ -491,6 +491,7 @@ def set_seed(self, seed: int, static_seed: bool = False) -> int: seed = new_seed return seed + @abc.abstractmethod def _set_seed(self, seed: Optional[int]): raise NotImplementedError @@ -787,6 +788,7 @@ def __getattr__(self, attr: str) -> Any: f"env not set in {self.__class__.__name__}, cannot access {attr}" ) + @abc.abstractmethod def _init_env(self) -> Optional[int]: """Runs all the necessary steps such that the environment is ready to use. @@ -821,22 +823,7 @@ def close(self) -> None: except AttributeError: pass - def set_seed( - self, seed: Optional[int] = None, static_seed: bool = False - ) -> Optional[int]: - if seed is not None: - torch.manual_seed(seed) - self._set_seed(seed) - if seed is not None and not static_seed: - new_seed = seed_generator(seed) - seed = new_seed - return seed - - @abc.abstractmethod - def _set_seed(self, seed: Optional[int]): - raise NotImplementedError - - + def make_tensordict( env: _EnvWrapper, policy: Optional[Callable[[TensorDictBase, ...], TensorDictBase]] = None,