Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 4 additions & 17 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ def reset(
def numel(self) -> int:
return prod(self.batch_size)

def set_seed(self, seed: int, static_seed: bool = False) -> int:
def set_seed(self, seed: Optional[int] = None, static_seed: bool = False) -> Optional[int]:
"""Sets the seed of the environment and returns the next seed to be used (which is the input seed if a single environment is present).

Args:
Expand All @@ -491,6 +491,7 @@ def set_seed(self, seed: int, static_seed: bool = False) -> int:
seed = new_seed
return seed

@abc.abstractmethod
def _set_seed(self, seed: Optional[int]):
raise NotImplementedError

Expand Down Expand Up @@ -787,6 +788,7 @@ def __getattr__(self, attr: str) -> Any:
f"env not set in {self.__class__.__name__}, cannot access {attr}"
)

@abc.abstractmethod
def _init_env(self) -> Optional[int]:
"""Runs all the necessary steps such that the environment is ready to use.

Expand Down Expand Up @@ -821,22 +823,7 @@ def close(self) -> None:
except AttributeError:
pass

def set_seed(
self, seed: Optional[int] = None, static_seed: bool = False
) -> Optional[int]:
if seed is not None:
torch.manual_seed(seed)
self._set_seed(seed)
if seed is not None and not static_seed:
new_seed = seed_generator(seed)
seed = new_seed
return seed

@abc.abstractmethod
def _set_seed(self, seed: Optional[int]):
raise NotImplementedError



def make_tensordict(
env: _EnvWrapper,
policy: Optional[Callable[[TensorDictBase, ...], TensorDictBase]] = None,
Expand Down