Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import torch
import torch.nn as nn
from tensordict.tensordict import TensorDict, TensorDictBase
from torchrl._utils import seed_generator
from torchrl.data.tensor_specs import (
BinaryDiscreteTensorSpec,
BoundedTensorSpec,
Expand Down Expand Up @@ -85,12 +84,9 @@ def __init__(
def maxstep(self):
return 100

def set_seed(self, seed: int, static_seed=False) -> int:
def _set_seed(self, seed: Optional[int]):
self.seed = seed
self.counter = seed % 17 # make counter a small number
if static_seed:
return seed
return seed_generator(seed)

def custom_fun(self):
return 0
Expand Down Expand Up @@ -136,14 +132,11 @@ def __init__(self, device):
super(MockSerialEnv, self).__init__(device=device)
self.is_closed = False

def set_seed(self, seed: int, static_seed: bool = False) -> int:
def _set_seed(self, seed: Optional[int]):
assert seed >= 1
self.seed = seed
self.counter = seed % 17 # make counter a small number
self.max_val = max(self.counter + 100, self.counter * 2)
if static_seed:
return seed
return seed_generator(seed)

def _step(self, tensordict):
self.counter += 1
Expand Down Expand Up @@ -207,9 +200,14 @@ def __init__(self, device, batch_size=None):
super(MockBatchedLockedEnv, self).__init__(device=device, batch_size=batch_size)
self.counter = 0

set_seed = MockSerialEnv.set_seed
rand_step = MockSerialEnv.rand_step

def _set_seed(self, seed: Optional[int]):
assert seed >= 1
self.seed = seed
self.counter = seed % 17 # make counter a small number
self.max_val = max(self.counter + 100, self.counter * 2)

def _step(self, tensordict):
self.counter += 1
# We use tensordict.batch_size instead of self.batch_size since this method will also be used by MockBatchedUnLockedEnv
Expand Down
5 changes: 1 addition & 4 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ def set_seed(
seed = new_seed
return seed

@abc.abstractmethod
def _set_seed(self, seed: Optional[int]):
raise NotImplementedError

Expand Down Expand Up @@ -824,10 +825,6 @@ def close(self) -> None:
except AttributeError:
pass

@abc.abstractmethod
def _set_seed(self, seed: Optional[int]):
raise NotImplementedError


def make_tensordict(
env: _EnvWrapper,
Expand Down
9 changes: 7 additions & 2 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from torchrl.envs.transforms.utils import check_finite
from torchrl.envs.utils import step_mdp


try:
from torchvision.transforms.functional import center_crop
from torchvision.transforms.functional_tensor import (
Expand Down Expand Up @@ -453,10 +452,16 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:

return tensordict_out

def set_seed(self, seed: int, static_seed: bool = False) -> int:
def set_seed(
self, seed: Optional[int] = None, static_seed: bool = False
) -> Optional[int]:
"""Set the seeds of the environment."""
return self.base_env.set_seed(seed, static_seed=static_seed)

def _set_seed(self, seed: Optional[int]):
"""This method is not used in transformed envs."""
pass

def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs):
if tensordict is not None:
tensordict = tensordict.clone(recurse=False)
Expand Down
12 changes: 10 additions & 2 deletions torchrl/envs/vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,10 @@ def close(self) -> None:
def _shutdown_workers(self) -> None:
raise NotImplementedError

def _set_seed(self, seed: Optional[int]):
"""This method is not used in batched envs."""
pass

def start(self) -> None:
if not self.is_closed:
raise RuntimeError("trying to start a environment that is not closed.")
Expand Down Expand Up @@ -606,7 +610,9 @@ def _shutdown_workers(self) -> None:
del self._envs

@_check_start
def set_seed(self, seed: int, static_seed: bool = False) -> int:
def set_seed(
self, seed: Optional[int] = None, static_seed: bool = False
) -> Optional[int]:
for env in self._envs:
new_seed = env.set_seed(seed, static_seed=static_seed)
seed = new_seed
Expand Down Expand Up @@ -816,7 +822,9 @@ def _shutdown_workers(self) -> None:
del self.parent_channels

@_check_start
def set_seed(self, seed: int, static_seed: bool = False) -> int:
def set_seed(
self, seed: Optional[int] = None, static_seed: bool = False
) -> Optional[int]:
self._seeds = []
for channel in self.parent_channels:
channel.send(("seed", (seed, static_seed)))
Expand Down