From 004b7ff6ab018fbb56e78f0ca7f7e8ca47a3b9cd Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 27 Dec 2022 22:43:00 -0500 Subject: [PATCH 1/4] Making abstract --- test/mocking_classes.py | 18 ++++++++---------- torchrl/envs/common.py | 5 +---- torchrl/envs/transforms/transforms.py | 10 ++++++++-- torchrl/envs/vec_env.py | 12 ++++++++++-- 4 files changed, 27 insertions(+), 18 deletions(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 7629d2874e7..5f92e0da52c 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -7,7 +7,6 @@ import torch import torch.nn as nn from tensordict.tensordict import TensorDict, TensorDictBase -from torchrl._utils import seed_generator from torchrl.data.tensor_specs import ( BinaryDiscreteTensorSpec, BoundedTensorSpec, @@ -85,12 +84,9 @@ def __init__( def maxstep(self): return 100 - def set_seed(self, seed: int, static_seed=False) -> int: + def _set_seed(self, seed: Optional[int]): self.seed = seed self.counter = seed % 17 # make counter a small number - if static_seed: - return seed - return seed_generator(seed) def custom_fun(self): return 0 @@ -136,14 +132,11 @@ def __init__(self, device): super(MockSerialEnv, self).__init__(device=device) self.is_closed = False - def set_seed(self, seed: int, static_seed: bool = False) -> int: + def _set_seed(self, seed: Optional[int]): assert seed >= 1 self.seed = seed self.counter = seed % 17 # make counter a small number self.max_val = max(self.counter + 100, self.counter * 2) - if static_seed: - return seed - return seed_generator(seed) def _step(self, tensordict): self.counter += 1 @@ -207,9 +200,14 @@ def __init__(self, device, batch_size=None): super(MockBatchedLockedEnv, self).__init__(device=device, batch_size=batch_size) self.counter = 0 - set_seed = MockSerialEnv.set_seed rand_step = MockSerialEnv.rand_step + def _set_seed(self, seed: Optional[int]): + assert seed >= 1 + self.seed = seed + self.counter = seed % 17 # make counter a small number + self.max_val = max(self.counter + 100, self.counter * 2) + def _step(self, tensordict): self.counter += 1 # We use tensordict.batch_size instead of self.batch_size since this method will also be used by MockBatchedUnLockedEnv diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 9565959ec4a..48a65618880 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -491,6 +491,7 @@ def set_seed(self, seed: int, static_seed: bool = False) -> int: seed = new_seed return seed + @abc.abstractmethod def _set_seed(self, seed: Optional[int]): raise NotImplementedError @@ -832,10 +833,6 @@ def set_seed( seed = new_seed return seed - @abc.abstractmethod - def _set_seed(self, seed: Optional[int]): - raise NotImplementedError - def make_tensordict( env: _EnvWrapper, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 647760ba4cd..59b86267e22 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -453,10 +453,16 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: return tensordict_out - def set_seed(self, seed: int, static_seed: bool = False) -> int: - """Set the seeds of the environment.""" + def set_seed( + self, seed: Optional[int] = None, static_seed: bool = False + ) -> Optional[int]: + # This method is not used in transformed environments return self.base_env.set_seed(seed, static_seed=static_seed) + def _set_seed(self, seed: Optional[int]): + # This method is not used in transformed envs + return + def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs): if tensordict is not None: tensordict = tensordict.clone(recurse=False) diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 06f9821eb8a..89e942c7e01 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -514,6 +514,10 @@ def close(self) -> None: def _shutdown_workers(self) -> None: raise NotImplementedError + def _set_seed(self, seed: Optional[int]): + # This method is not used in batched envs + return + def start(self) -> None: if not self.is_closed: raise RuntimeError("trying to start a environment that is not closed.") @@ -606,7 +610,9 @@ def _shutdown_workers(self) -> None: del self._envs @_check_start - def set_seed(self, seed: int, static_seed: bool = False) -> int: + def set_seed( + self, seed: Optional[int] = None, static_seed: bool = False + ) -> Optional[int]: for env in self._envs: new_seed = env.set_seed(seed, static_seed=static_seed) seed = new_seed @@ -816,7 +822,9 @@ def _shutdown_workers(self) -> None: del self.parent_channels @_check_start - def set_seed(self, seed: int, static_seed: bool = False) -> int: + def set_seed( + self, seed: Optional[int] = None, static_seed: bool = False + ) -> Optional[int]: self._seeds = [] for channel in self.parent_channels: channel.send(("seed", (seed, static_seed))) From b3738c2a3cc883a202e9b37001e7ef6b2e66df4d Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 28 Dec 2022 16:38:19 -0500 Subject: [PATCH 2/4] PR fixes --- torchrl/envs/transforms/transforms.py | 9 ++++----- torchrl/envs/vec_env.py | 6 +++--- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 59b86267e22..9a8a49124fd 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -30,7 +30,6 @@ from torchrl.envs.transforms.utils import check_finite from torchrl.envs.utils import step_mdp - try: from torchvision.transforms.functional import center_crop from torchvision.transforms.functional_tensor import ( @@ -456,12 +455,12 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: def set_seed( self, seed: Optional[int] = None, static_seed: bool = False ) -> Optional[int]: - # This method is not used in transformed environments + """Set the seeds of the environment.""" return self.base_env.set_seed(seed, static_seed=static_seed) def _set_seed(self, seed: Optional[int]): - # This method is not used in transformed envs - return + """This method is not used in transformed envs""" + pass def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs): if tensordict is not None: @@ -2476,4 +2475,4 @@ def __repr__(self) -> str: return ( f"{self.__class__.__name__}(decay={self.decay:4.4f}," f"eps={self.eps:4.4f}, keys={self.in_keys})" - ) + ) \ No newline at end of file diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 89e942c7e01..d81197f63cb 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -515,8 +515,8 @@ def _shutdown_workers(self) -> None: raise NotImplementedError def _set_seed(self, seed: Optional[int]): - # This method is not used in batched envs - return + """ This method is not used in batched envs """ + pass def start(self) -> None: if not self.is_closed: @@ -1081,4 +1081,4 @@ def _run_worker_pipe_shared_mem( child_pipe.send(("_".join([cmd, "done"]), result)) else: # don't send env through pipe - child_pipe.send(("_".join([cmd, "done"]), None)) + child_pipe.send(("_".join([cmd, "done"]), None)) \ No newline at end of file From 2c59bd23d11c7da163fdd918f931935ec4af8ab9 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 28 Dec 2022 16:40:05 -0500 Subject: [PATCH 3/4] Linting --- torchrl/envs/transforms/transforms.py | 4 ++-- torchrl/envs/vec_env.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 9a8a49124fd..ac7858fe119 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -459,7 +459,7 @@ def set_seed( return self.base_env.set_seed(seed, static_seed=static_seed) def _set_seed(self, seed: Optional[int]): - """This method is not used in transformed envs""" + """This method is not used in transformed envs.""" pass def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs): @@ -2475,4 +2475,4 @@ def __repr__(self) -> str: return ( f"{self.__class__.__name__}(decay={self.decay:4.4f}," f"eps={self.eps:4.4f}, keys={self.in_keys})" - ) \ No newline at end of file + ) diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index d81197f63cb..f119e9e2151 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -515,7 +515,7 @@ def _shutdown_workers(self) -> None: raise NotImplementedError def _set_seed(self, seed: Optional[int]): - """ This method is not used in batched envs """ + """This method is not used in batched envs.""" pass def start(self) -> None: @@ -1081,4 +1081,4 @@ def _run_worker_pipe_shared_mem( child_pipe.send(("_".join([cmd, "done"]), result)) else: # don't send env through pipe - child_pipe.send(("_".join([cmd, "done"]), None)) \ No newline at end of file + child_pipe.send(("_".join([cmd, "done"]), None)) From c403278061075b96182a3db4546c7d7e1acb2b79 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 28 Dec 2022 16:44:11 -0500 Subject: [PATCH 4/4] Linting --- torchrl/envs/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 6186a935488..406013d2480 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -844,4 +844,4 @@ def make_tensordict( else: tensordict.set("action", env.action_spec.rand(), inplace=False) tensordict = env.step(tensordict) - return tensordict.zero_() \ No newline at end of file + return tensordict.zero_()