@@ -470,7 +470,7 @@ def reset(
470470 def numel (self ) -> int :
471471 return prod (self .batch_size )
472472
473- def set_seed (self , seed : Optional [ int ] = None , static_seed : bool = False ) -> Optional [ int ] :
473+ def set_seed (self , seed : int , static_seed : bool = False ) -> int :
474474 """Sets the seed of the environment and returns the next seed to be used (which is the input seed if a single environment is present).
475475
476476 Args:
@@ -491,7 +491,6 @@ def set_seed(self, seed: Optional[int] = None, static_seed: bool = False) -> Opt
491491 seed = new_seed
492492 return seed
493493
494- @abc .abstractmethod
495494 def _set_seed (self , seed : Optional [int ]):
496495 raise NotImplementedError
497496
@@ -788,7 +787,6 @@ def __getattr__(self, attr: str) -> Any:
788787 f"env not set in { self .__class__ .__name__ } , cannot access { attr } "
789788 )
790789
791- @abc .abstractmethod
792790 def _init_env (self ) -> Optional [int ]:
793791 """Runs all the necessary steps such that the environment is ready to use.
794792
@@ -823,7 +821,22 @@ def close(self) -> None:
823821 except AttributeError :
824822 pass
825823
826-
824+ def set_seed (
825+ self , seed : Optional [int ] = None , static_seed : bool = False
826+ ) -> Optional [int ]:
827+ if seed is not None :
828+ torch .manual_seed (seed )
829+ self ._set_seed (seed )
830+ if seed is not None and not static_seed :
831+ new_seed = seed_generator (seed )
832+ seed = new_seed
833+ return seed
834+
835+ @abc .abstractmethod
836+ def _set_seed (self , seed : Optional [int ]):
837+ raise NotImplementedError
838+
839+
827840def make_tensordict (
828841 env : _EnvWrapper ,
829842 policy : Optional [Callable [[TensorDictBase , ...], TensorDictBase ]] = None ,
0 commit comments