diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index c2cca0a5c23..9565959ec4a 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -470,7 +470,7 @@ def reset( def numel(self) -> int: return prod(self.batch_size) - def set_seed(self, seed: Optional[int] = None, static_seed: bool = False) -> Optional[int]: + def set_seed(self, seed: int, static_seed: bool = False) -> int: """Sets the seed of the environment and returns the next seed to be used (which is the input seed if a single environment is present). Args: @@ -491,7 +491,6 @@ def set_seed(self, seed: Optional[int] = None, static_seed: bool = False) -> Opt seed = new_seed return seed - @abc.abstractmethod def _set_seed(self, seed: Optional[int]): raise NotImplementedError @@ -788,7 +787,6 @@ def __getattr__(self, attr: str) -> Any: f"env not set in {self.__class__.__name__}, cannot access {attr}" ) - @abc.abstractmethod def _init_env(self) -> Optional[int]: """Runs all the necessary steps such that the environment is ready to use. @@ -823,7 +821,22 @@ def close(self) -> None: except AttributeError: pass - + def set_seed( + self, seed: Optional[int] = None, static_seed: bool = False + ) -> Optional[int]: + if seed is not None: + torch.manual_seed(seed) + self._set_seed(seed) + if seed is not None and not static_seed: + new_seed = seed_generator(seed) + seed = new_seed + return seed + + @abc.abstractmethod + def _set_seed(self, seed: Optional[int]): + raise NotImplementedError + + def make_tensordict( env: _EnvWrapper, policy: Optional[Callable[[TensorDictBase, ...], TensorDictBase]] = None,