diff --git a/test/test_transforms.py b/test/test_transforms.py
index 8748a9cb1a6..090d0837585 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -43,6 +43,7 @@
     RewardClipping,
     RewardScaling,
     SerialEnv,
+    StepCounter,
     ToTensorImage,
     VIPTransform,
 )
@@ -62,6 +63,7 @@
     UnsqueezeTransform,
 )
 from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform
+from torchrl.envs.utils import check_env_specs
 
 TIMEOUT = 10.0
 
@@ -1652,6 +1654,46 @@ def test_insert(self):
         assert env._observation_spec is not None
         assert env._reward_spec is not None
 
+    @pytest.mark.parametrize("device", get_available_devices())
+    @pytest.mark.parametrize("batch", [[], [4], [6, 4]])
+    @pytest.mark.parametrize("max_steps", [None, 1, 5, 50])
+    @pytest.mark.parametrize("reset_workers", [True, False])
+    def test_step_counter(self, max_steps, device, batch, reset_workers):
+        torch.manual_seed(0)
+        step_counter = StepCounter(max_steps)
+        td = TensorDict(
+            {"done": torch.zeros(*batch, 1, dtype=torch.bool)}, batch, device=device
+        )
+        if reset_workers:
+            td.set("reset_workers", torch.randn(*batch, 1) < 0)
+        step_counter.reset(td)
+        assert not torch.all(td.get("step_count"))
+        i = 0
+        while max_steps is None or i < max_steps:
+            step_counter._step(td)
+            i += 1
+            assert torch.all(td.get("step_count") == i)
+            if max_steps is None:
+                break
+        if max_steps is not None:
+            assert torch.all(td.get("step_count") == max_steps)
+            assert torch.all(td.get("done"))
+        step_counter.reset(td)
+        if reset_workers:
+            assert torch.all(
+                torch.masked_select(td.get("step_count"), td.get("reset_workers")) == 0
+            )
+            assert torch.all(
+                torch.masked_select(td.get("step_count"), ~td.get("reset_workers")) == i
+            )
+        else:
+            assert torch.all(td.get("step_count") == 0)
+
+    def test_step_counter_observation_spec(self):
+        transformed_env = TransformedEnv(ContinuousActionVecMockEnv(), StepCounter(50))
+        check_env_specs(transformed_env)
+        transformed_env.close()
+
 
 @pytest.mark.skipif(not _has_tv, reason="torchvision not installed")
 @pytest.mark.parametrize("device", get_available_devices())
diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py
index 04be8e7320a..1bc51074197 100644
--- a/torchrl/envs/__init__.py
+++ b/torchrl/envs/__init__.py
@@ -26,6 +26,7 @@
     Resize,
     RewardClipping,
     RewardScaling,
+    StepCounter,
     TensorDictPrimer,
     ToTensorImage,
     Transform,
diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py
index 6c682c6210b..61ee5c20d7d 100644
--- a/torchrl/envs/transforms/__init__.py
+++ b/torchrl/envs/transforms/__init__.py
@@ -24,6 +24,7 @@
     RewardClipping,
     RewardScaling,
     SqueezeTransform,
+    StepCounter,
     TensorDictPrimer,
     ToTensorImage,
     Transform,
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 92185d64768..11587053295 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -23,6 +23,7 @@
     DEVICE_TYPING,
     TensorSpec,
     UnboundedContinuousTensorSpec,
+    UnboundedDiscreteTensorSpec,
 )
 from torchrl.envs.common import EnvBase, make_tensordict
 from torchrl.envs.transforms import functional as F
@@ -2476,3 +2477,79 @@ def __repr__(self) -> str:
             f"{self.__class__.__name__}(decay={self.decay:4.4f},"
             f"eps={self.eps:4.4f}, keys={self.in_keys})"
         )
+
+
+class StepCounter(Transform):
+    """Counts the steps from a reset and sets the done state to True after a certain number of steps.
+
+    Args:
+        max_steps (:obj:`int`, optional): a positive integer indicating the maximum number of steps to take before
+            setting the done state to ``True``. If ``None`` (the default), the environment will run indefinitely until
+            the done state is set by the user or by the environment itself. The step count is still incremented on
+            each call to :meth:`step` and written to the ``"step_count"`` entry of the tensordict.
+    """
+
+    invertible = False
+    inplace = True
+
+    def __init__(self, max_steps: Optional[int] = None):
+        if max_steps is not None and max_steps < 1:
+            raise ValueError("max_steps should have a value greater than or equal to one.")
+        self.max_steps = max_steps
+        super().__init__([])
+
+    def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
+        workers = tensordict.get(
+            "reset_workers",
+            default=torch.ones(
+                *tensordict.batch_size, 1, dtype=torch.bool, device=tensordict.device
+            ),
+        )
+        tensordict.set(
+            "step_count",
+            (~workers)
+            * tensordict.get(
+                "step_count",
+                torch.zeros(
+                    *tensordict.batch_size,
+                    1,
+                    dtype=torch.int64,
+                    device=tensordict.device,
+                ),
+            ),
+        )
+        return tensordict
+
+    def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
+        next_step_count = (
+            tensordict.get(
+                "step_count",
+                torch.zeros(
+                    *tensordict.batch_size,
+                    1,
+                    dtype=torch.int64,
+                    device=tensordict.device,
+                ),
+            )
+            + 1
+        )
+        tensordict.set("step_count", next_step_count)
+        if self.max_steps is not None:
+            tensordict.set(
+                "done",
+                tensordict.get("done") | (next_step_count >= self.max_steps),
+            )
+        return tensordict
+
+    def transform_observation_spec(
+        self, observation_spec: CompositeSpec
+    ) -> CompositeSpec:
+        if not isinstance(observation_spec, CompositeSpec):
+            raise ValueError(
+                f"observation_spec was expected to be of type CompositeSpec. Got {type(observation_spec)} instead."
+            )
+        observation_spec["step_count"] = UnboundedDiscreteTensorSpec(
+            shape=torch.Size([1]), dtype=torch.int64, device=observation_spec.device
+        )
+        observation_spec["step_count"].space.minimum = 0
+        return observation_spec
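Usage sketch (illustrative, not part of the diff): the snippet below drives the transform by hand on a standalone `TensorDict`, mirroring `test_step_counter` above; inside a `TransformedEnv` the `reset`/`_step` calls happen automatically. The `tensordict` import path is an assumption; older torchrl releases expose `TensorDict` from `torchrl.data`.

```python
import torch
# Assumption: recent torchrl versions ship TensorDict in the standalone
# `tensordict` package; older releases expose it as torchrl.data.TensorDict.
from tensordict import TensorDict

from torchrl.envs.transforms import StepCounter

# A counter that flips `done` to True after 3 steps, for a batch of 2 workers.
t = StepCounter(max_steps=3)
td = TensorDict({"done": torch.zeros(2, 1, dtype=torch.bool)}, [2])

t.reset(td)  # initializes "step_count" to zeros
assert (td["step_count"] == 0).all()

for _ in range(3):
    t._step(td)  # invoked automatically by TransformedEnv during env.step()

assert (td["step_count"] == 3).all()
assert td["done"].all()  # max_steps reached, done flipped to True

# Partial reset: only the worker flagged in "reset_workers" restarts its count.
td.set("reset_workers", torch.tensor([[True], [False]]))
t.reset(td)
assert (td["step_count"] == torch.tensor([[0], [3]])).all()
```

The partial-reset behavior is the reason `reset` multiplies the old count by `(~workers)` rather than zeroing it outright: workers not being reset keep their running count, which is what `test_step_counter` checks with `torch.masked_select`.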