Skip to content

Commit d68d129

Browse files
authored
Revert "Minor cleaning in BaseEnv classes (#767)"
This reverts commit 7c28895.
1 parent 7c28895 commit d68d129

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

torchrl/envs/common.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@ def reset(
470470
def numel(self) -> int:
471471
return prod(self.batch_size)
472472

473-
def set_seed(self, seed: Optional[int] = None, static_seed: bool = False) -> Optional[int]:
473+
def set_seed(self, seed: int, static_seed: bool = False) -> int:
474474
"""Sets the seed of the environment and returns the next seed to be used (which is the input seed if a single environment is present).
475475
476476
Args:
@@ -491,7 +491,6 @@ def set_seed(self, seed: Optional[int] = None, static_seed: bool = False) -> Opt
491491
seed = new_seed
492492
return seed
493493

494-
@abc.abstractmethod
495494
def _set_seed(self, seed: Optional[int]):
496495
raise NotImplementedError
497496

@@ -788,7 +787,6 @@ def __getattr__(self, attr: str) -> Any:
788787
f"env not set in {self.__class__.__name__}, cannot access {attr}"
789788
)
790789

791-
@abc.abstractmethod
792790
def _init_env(self) -> Optional[int]:
793791
"""Runs all the necessary steps such that the environment is ready to use.
794792
@@ -823,7 +821,22 @@ def close(self) -> None:
823821
except AttributeError:
824822
pass
825823

826-
824+
def set_seed(
825+
self, seed: Optional[int] = None, static_seed: bool = False
826+
) -> Optional[int]:
827+
if seed is not None:
828+
torch.manual_seed(seed)
829+
self._set_seed(seed)
830+
if seed is not None and not static_seed:
831+
new_seed = seed_generator(seed)
832+
seed = new_seed
833+
return seed
834+
835+
@abc.abstractmethod
836+
def _set_seed(self, seed: Optional[int]):
837+
raise NotImplementedError
838+
839+
827840
def make_tensordict(
828841
env: _EnvWrapper,
829842
policy: Optional[Callable[[TensorDictBase, ...], TensorDictBase]] = None,

0 commit comments

Comments
 (0)