From a1010a07b44101136dd37473e6f35c479fc49a9e Mon Sep 17 00:00:00 2001 From: Waris Radji Date: Wed, 21 Dec 2022 03:26:38 +0100 Subject: [PATCH 1/5] [Feature] Add Step Counter transform --- test/test_transforms.py | 22 +++++++++++++ torchrl/envs/__init__.py | 1 + torchrl/envs/transforms/__init__.py | 1 + torchrl/envs/transforms/transforms.py | 47 +++++++++++++++++++++++++++ 4 files changed, 71 insertions(+) diff --git a/test/test_transforms.py b/test/test_transforms.py index 8748a9cb1a6..1fd5859cf8f 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -43,6 +43,7 @@ RewardClipping, RewardScaling, SerialEnv, + StepCounter, ToTensorImage, VIPTransform, ) @@ -1652,6 +1653,27 @@ def test_insert(self): assert env._observation_spec is not None assert env._reward_spec is not None + @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize("batch", [[], [4], [6, 4]]) + @pytest.mark.parametrize("max_steps", [None, 0, 5]) + def test_step_counter(self, max_steps, device, batch): + torch.manual_seed(0) + step_counter = StepCounter(max_steps) + td = TensorDict( + {"done": torch.zeros(*batch, 1, dtype=torch.bool)}, batch, device=device + ) + step_counter.reset(td) + assert not torch.all(td.get("step_count")) + i = 0 + while not td.get("done").all(): + step_counter._step(td) + i += 1 + assert torch.all(td.get("step_count") == i) + if max_steps is None or i == max_steps: + break + if max_steps is not None: + assert torch.all(td.get("step_count") == max_steps) + @pytest.mark.skipif(not _has_tv, reason="torchvision not installed") @pytest.mark.parametrize("device", get_available_devices()) diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index 04be8e7320a..1bc51074197 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -26,6 +26,7 @@ Resize, RewardClipping, RewardScaling, + StepCounter, TensorDictPrimer, ToTensorImage, Transform, diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index 6c682c6210b..61ee5c20d7d 100644 --- a/torchrl/envs/transforms/__init__.py +++ b/torchrl/envs/transforms/__init__.py @@ -24,6 +24,7 @@ RewardClipping, RewardScaling, SqueezeTransform, + StepCounter, TensorDictPrimer, ToTensorImage, Transform, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 92185d64768..5000c6dd74b 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -2476,3 +2476,50 @@ def __repr__(self) -> str: f"{self.__class__.__name__}(decay={self.decay:4.4f}," f"eps={self.eps:4.4f}, keys={self.in_keys})" ) + + +class StepCounter(Transform): + """Counts the steps from a reset and sets the done state to True after a certain number of steps. + + Args: + max_steps (:obj:`int`, optional): the maximum number of steps to take before setting the done state to True. + """ + + invertible = False + inplace = True + + def __init__(self, max_steps: Optional[int]): + self.max_steps = max_steps + super().__init__([]) + + def reset(self, tensordict: TensorDictBase) -> TensorDictBase: + tensordict.set( + "step_count", torch.zeros(*tensordict.batch_size, 1, dtype=torch.int64) + ) + if self.max_steps is not None and self.max_steps <= 0: + tensordict.fill_("done", True) + return tensordict + + def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + next_step_count = tensordict.get("step_count") + 1 + tensordict.set("step_count", next_step_count) + if self.max_steps is not None: + tensordict.set( + "done", + torch.where( + next_step_count < self.max_steps, tensordict.get("done"), True + ), + ) + return tensordict + + def _transform_spec(self, spec: TensorSpec) -> None: + if isinstance(spec, CompositeSpec): + for key in spec: + self._transform_spec(spec[key]) + else: + spec.dtype = torch.int64 + + @_apply_to_composite + def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + self._transform_spec(observation_spec) + return observation_spec From a2e80a08a672888e4f25268d2ae7f629b52770e8 Mon Sep 17 00:00:00 2001 From: Waris Radji Date: Thu, 22 Dec 2022 23:11:20 +0100 Subject: [PATCH 2/5] Adapt StepCounter to reset_workers --- test/test_transforms.py | 20 +++++++++++--- torchrl/envs/transforms/transforms.py | 38 ++++++++++++++++----------- 2 files changed, 40 insertions(+), 18 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 1fd5859cf8f..2279395a77b 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1656,23 +1656,37 @@ def test_insert(self): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("batch", [[], [4], [6, 4]]) @pytest.mark.parametrize("max_steps", [None, 0, 5]) - def test_step_counter(self, max_steps, device, batch): + @pytest.mark.parametrize("reset_workers", [True, False]) + def test_step_counter(self, max_steps, device, batch, reset_workers): torch.manual_seed(0) step_counter = StepCounter(max_steps) td = TensorDict( {"done": torch.zeros(*batch, 1, dtype=torch.bool)}, batch, device=device ) + if reset_workers: + td.set("reset_workers", torch.randn(*batch, 1) < 0) step_counter.reset(td) assert not torch.all(td.get("step_count")) i = 0 - while not td.get("done").all(): + while max_steps is None or i < max_steps: step_counter._step(td) i += 1 assert torch.all(td.get("step_count") == i) - if max_steps is None or i == max_steps: + if max_steps is None: break if max_steps is not None: assert torch.all(td.get("step_count") == max_steps) + assert torch.all(td.get("done")) + step_counter.reset(td) + if reset_workers: + assert torch.all( + torch.masked_select(td.get("step_count"), td.get("reset_workers")) == 0 + ) + assert torch.all( + torch.masked_select(td.get("step_count"), ~td.get("reset_workers")) == i + ) + else: + assert torch.all(td.get("step_count") == 0) @pytest.mark.skipif(not _has_tv, reason="torchvision not installed") diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 5000c6dd74b..49ba104ef59 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -23,6 +23,7 @@ DEVICE_TYPING, TensorSpec, UnboundedContinuousTensorSpec, + UnboundedDiscreteTensorSpec, ) from torchrl.envs.common import EnvBase, make_tensordict from torchrl.envs.transforms import functional as F @@ -2482,44 +2483,51 @@ class StepCounter(Transform): """Counts the steps from a reset and sets the done state to True after a certain number of steps. Args: - max_steps (:obj:`int`, optional): the maximum number of steps to take before setting the done state to True. + max_steps (:obj:`int`, optional): the maximum number of steps to take before setting the done state to + True. If set to None (the default value), the environment will run indefinitely until the done state is manually + set by the user or by the environment itself. However, the step count will still be incremented on each call to + step(). """ invertible = False inplace = True - def __init__(self, max_steps: Optional[int]): + def __init__(self, max_steps: Optional[int] = None): self.max_steps = max_steps super().__init__([]) def reset(self, tensordict: TensorDictBase) -> TensorDictBase: + workers = tensordict.get( + "reset_workers", + default=torch.ones(*tensordict.batch_size, 1, dtype=torch.bool), + ) tensordict.set( - "step_count", torch.zeros(*tensordict.batch_size, 1, dtype=torch.int64) + "step_count", + (~workers) + * tensordict.get( + "step_count", torch.zeros(*tensordict.batch_size, 1, dtype=torch.int64) + ), ) if self.max_steps is not None and self.max_steps <= 0: tensordict.fill_("done", True) return tensordict def _step(self, tensordict: TensorDictBase) -> TensorDictBase: - next_step_count = tensordict.get("step_count") + 1 + next_step_count = ( + tensordict.get( + "step_count", torch.zeros(*tensordict.batch_size, 1, dtype=torch.int64) + ) + + 1 + ) tensordict.set("step_count", next_step_count) if self.max_steps is not None: tensordict.set( "done", - torch.where( - next_step_count < self.max_steps, tensordict.get("done"), True - ), + tensordict.get("done") | next_step_count >= self.max_steps, ) return tensordict - def _transform_spec(self, spec: TensorSpec) -> None: - if isinstance(spec, CompositeSpec): - for key in spec: - self._transform_spec(spec[key]) - else: - spec.dtype = torch.int64 - @_apply_to_composite def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: - self._transform_spec(observation_spec) + observation_spec["step_count"] = UnboundedDiscreteTensorSpec(dtype=torch.int64) return observation_spec From cb03df7766a9bbb8481b471fc27fc31ff2b6bc6d Mon Sep 17 00:00:00 2001 From: Waris Radji Date: Fri, 23 Dec 2022 13:27:34 +0100 Subject: [PATCH 3/5] Create tensor on the right device --- test/test_transforms.py | 6 +++++ torchrl/envs/transforms/transforms.py | 34 ++++++++++++++++++++++----- 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 2279395a77b..32bcf4e0f59 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -63,6 +63,7 @@ UnsqueezeTransform, ) from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform +from torchrl.envs.utils import check_env_specs TIMEOUT = 10.0 @@ -1688,6 +1689,11 @@ def test_step_counter(self, max_steps, device, batch, reset_workers): else: assert torch.all(td.get("step_count") == 0) + def test_step_counter_observation_spec(self): + env = TransformedEnv(GymEnv("Pendulum-v1"), StepCounter(50)) + check_env_specs(GymEnv("Pendulum-v1")) + check_env_specs(env) + @pytest.mark.skipif(not _has_tv, reason="torchvision not installed") @pytest.mark.parametrize("device", get_available_devices()) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 49ba104ef59..8b91d9d9494 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -2499,13 +2499,21 @@ def __init__(self, max_steps: Optional[int] = None): def reset(self, tensordict: TensorDictBase) -> TensorDictBase: workers = tensordict.get( "reset_workers", - default=torch.ones(*tensordict.batch_size, 1, dtype=torch.bool), + default=torch.ones( + *tensordict.batch_size, 1, dtype=torch.bool, device=tensordict.device + ), ) tensordict.set( "step_count", (~workers) * tensordict.get( - "step_count", torch.zeros(*tensordict.batch_size, 1, dtype=torch.int64) + "step_count", + torch.zeros( + *tensordict.batch_size, + 1, + dtype=torch.int64, + device=tensordict.device, + ), ), ) if self.max_steps is not None and self.max_steps <= 0: @@ -2515,7 +2523,13 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: def _step(self, tensordict: TensorDictBase) -> TensorDictBase: next_step_count = ( tensordict.get( - "step_count", torch.zeros(*tensordict.batch_size, 1, dtype=torch.int64) + "step_count", + torch.zeros( + *tensordict.batch_size, + 1, + dtype=torch.int64, + device=tensordict.device, + ), ) + 1 ) @@ -2527,7 +2541,15 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: ) return tensordict - @_apply_to_composite - def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: - observation_spec["step_count"] = UnboundedDiscreteTensorSpec(dtype=torch.int64) + def transform_observation_spec( + self, observation_spec: CompositeSpec + ) -> CompositeSpec: + if not isinstance(observation_spec, CompositeSpec): + raise ValueError( + f"observation_spec was expected to be of type CompositeSpec. Got {type(observation_spec)} instead." + ) + observation_spec["step_count"] = UnboundedDiscreteTensorSpec( + dtype=torch.int64, device=observation_spec.device + ) + observation_spec["step_count"].space.minimum = 0 return observation_spec From d06db953995cf3f26d43797687aa8ef32d417878 Mon Sep 17 00:00:00 2001 From: Waris Radji Date: Fri, 23 Dec 2022 13:45:25 +0100 Subject: [PATCH 4/5] Constraint max_steps to be >= 1 --- test/test_transforms.py | 2 +- torchrl/envs/transforms/transforms.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 32bcf4e0f59..cd52e215431 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1656,7 +1656,7 @@ def test_insert(self): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("batch", [[], [4], [6, 4]]) - @pytest.mark.parametrize("max_steps", [None, 0, 5]) + @pytest.mark.parametrize("max_steps", [None, 1, 5, 50]) @pytest.mark.parametrize("reset_workers", [True, False]) def test_step_counter(self, max_steps, device, batch, reset_workers): torch.manual_seed(0) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 8b91d9d9494..10f496d9bbd 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -2483,16 +2483,18 @@ class StepCounter(Transform): """Counts the steps from a reset and sets the done state to True after a certain number of steps. Args: - max_steps (:obj:`int`, optional): the maximum number of steps to take before setting the done state to - True. If set to None (the default value), the environment will run indefinitely until the done state is manually - set by the user or by the environment itself. However, the step count will still be incremented on each call to - step(). + max_steps (:obj:`int`, optional): a positive integer that indicates the maximum number of steps to take before + setting the done state to True. If set to None (the default value), the environment will run indefinitely until + the done state is manually set by the user or by the environment itself. However, the step count will still be + incremented on each call to step(). """ invertible = False inplace = True def __init__(self, max_steps: Optional[int] = None): + if max_steps is not None and max_steps < 1: + raise ValueError("max_steps should have a value greater or equal to one.") self.max_steps = max_steps super().__init__([]) @@ -2516,8 +2518,6 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: ), ), ) - if self.max_steps is not None and self.max_steps <= 0: - tensordict.fill_("done", True) return tensordict def _step(self, tensordict: TensorDictBase) -> TensorDictBase: From 922d61a3f8c4dbaf313f079351738b9c55d0815b Mon Sep 17 00:00:00 2001 From: Waris Radji Date: Fri, 30 Dec 2022 00:30:05 +0100 Subject: [PATCH 5/5] Change UnboundedDiscreteTensorSpec to NdUnboundedDiscreteTensorSpec for step_count attr in observation_spec --- test/test_transforms.py | 6 +++--- torchrl/envs/transforms/transforms.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index cd52e215431..090d0837585 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1690,9 +1690,9 @@ def test_step_counter(self, max_steps, device, batch, reset_workers): assert torch.all(td.get("step_count") == 0) def test_step_counter_observation_spec(self): - env = TransformedEnv(GymEnv("Pendulum-v1"), StepCounter(50)) - check_env_specs(GymEnv("Pendulum-v1")) - check_env_specs(env) + transformed_env = TransformedEnv(ContinuousActionVecMockEnv(), StepCounter(50)) + check_env_specs(transformed_env) + transformed_env.close() @pytest.mark.skipif(not _has_tv, reason="torchvision not installed") diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 10f496d9bbd..11587053295 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -2486,7 +2486,7 @@ class StepCounter(Transform): max_steps (:obj:`int`, optional): a positive integer that indicates the maximum number of steps to take before setting the done state to True. If set to None (the default value), the environment will run indefinitely until the done state is manually set by the user or by the environment itself. However, the step count will still be - incremented on each call to step(). + incremented on each call to step() into the `step_count` attribute. """ invertible = False @@ -2549,7 +2549,7 @@ def transform_observation_spec( f"observation_spec was expected to be of type CompositeSpec. Got {type(observation_spec)} instead." ) observation_spec["step_count"] = UnboundedDiscreteTensorSpec( - dtype=torch.int64, device=observation_spec.device + shape=torch.Size([1]), dtype=torch.int64, device=observation_spec.device ) observation_spec["step_count"].space.minimum = 0 return observation_spec