Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
RewardClipping,
RewardScaling,
SerialEnv,
StepCounter,
ToTensorImage,
VIPTransform,
)
Expand All @@ -62,6 +63,7 @@
UnsqueezeTransform,
)
from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform
from torchrl.envs.utils import check_env_specs

TIMEOUT = 10.0

Expand Down Expand Up @@ -1652,6 +1654,46 @@ def test_insert(self):
assert env._observation_spec is not None
assert env._reward_spec is not None

@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("batch", [[], [4], [6, 4]])
@pytest.mark.parametrize("max_steps", [None, 1, 5, 50])
@pytest.mark.parametrize("reset_workers", [True, False])
def test_step_counter(self, max_steps, device, batch, reset_workers):
    """StepCounter increments ``step_count`` each step, flags ``done`` at
    ``max_steps``, and zeroes only the workers selected by ``reset_workers``."""
    torch.manual_seed(0)
    counter = StepCounter(max_steps)
    data = TensorDict(
        {"done": torch.zeros(*batch, 1, dtype=torch.bool)}, batch, device=device
    )
    if reset_workers:
        # Randomly mark roughly half the workers for reset.
        data.set("reset_workers", torch.randn(*batch, 1) < 0)
    counter.reset(data)
    # A fresh reset leaves every counter entry at zero.
    assert not torch.all(data.get("step_count"))
    steps_taken = 0
    while True:
        if max_steps is not None and steps_taken >= max_steps:
            break
        counter._step(data)
        steps_taken += 1
        assert torch.all(data.get("step_count") == steps_taken)
        if max_steps is None:
            # Unbounded counter: a single step suffices for the check.
            break
    if max_steps is not None:
        # Counter saturates at max_steps and the done flag is raised everywhere.
        assert torch.all(data.get("step_count") == max_steps)
        assert torch.all(data.get("done"))
    counter.reset(data)
    if not reset_workers:
        # Without a mask, reset zeroes every worker.
        assert torch.all(data.get("step_count") == 0)
    else:
        # Masked reset: selected workers are zeroed, the rest keep their count.
        mask = data.get("reset_workers")
        zeroed = torch.masked_select(data.get("step_count"), mask)
        kept = torch.masked_select(data.get("step_count"), ~mask)
        assert torch.all(zeroed == 0)
        assert torch.all(kept == steps_taken)

def test_step_counter_observation_spec(self):
    """An env wrapped with StepCounter must still pass the spec consistency check."""
    env = TransformedEnv(ContinuousActionVecMockEnv(), StepCounter(50))
    check_env_specs(env)
    env.close()


@pytest.mark.skipif(not _has_tv, reason="torchvision not installed")
@pytest.mark.parametrize("device", get_available_devices())
Expand Down
1 change: 1 addition & 0 deletions torchrl/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Resize,
RewardClipping,
RewardScaling,
StepCounter,
TensorDictPrimer,
ToTensorImage,
Transform,
Expand Down
1 change: 1 addition & 0 deletions torchrl/envs/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
RewardClipping,
RewardScaling,
SqueezeTransform,
StepCounter,
TensorDictPrimer,
ToTensorImage,
Transform,
Expand Down
77 changes: 77 additions & 0 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
DEVICE_TYPING,
TensorSpec,
UnboundedContinuousTensorSpec,
UnboundedDiscreteTensorSpec,
)
from torchrl.envs.common import EnvBase, make_tensordict
from torchrl.envs.transforms import functional as F
Expand Down Expand Up @@ -2476,3 +2477,79 @@ def __repr__(self) -> str:
f"{self.__class__.__name__}(decay={self.decay:4.4f},"
f"eps={self.eps:4.4f}, keys={self.in_keys})"
)


class StepCounter(Transform):
    """Counts the steps from a reset and sets the done state to True after a certain number of steps.

    Args:
        max_steps (:obj:`int`, optional): a positive integer that indicates the maximum number of steps to take before
            setting the done state to True. If set to None (the default value), the environment will run indefinitely until
            the done state is manually set by the user or by the environment itself. However, the step count will still be
            incremented on each call to step() into the `step_count` attribute.
    """

    invertible = False
    inplace = True

    def __init__(self, max_steps: Optional[int] = None):
        if max_steps is not None and max_steps < 1:
            raise ValueError("max_steps should have a value greater or equal to one.")
        self.max_steps = max_steps
        super().__init__([])

    def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Zeroes ``step_count`` for the workers flagged in ``reset_workers``.

        When no ``reset_workers`` mask is present, every worker is reset.
        Non-reset workers keep their current count (multiplied by ``~workers``,
        i.e. left untouched).
        """
        workers = tensordict.get(
            "reset_workers",
            # Default: reset everyone.
            default=torch.ones(
                *tensordict.batch_size, 1, dtype=torch.bool, device=tensordict.device
            ),
        )
        tensordict.set(
            "step_count",
            # Zero the masked workers, keep the others' running counts.
            (~workers)
            * tensordict.get(
                "step_count",
                torch.zeros(
                    *tensordict.batch_size,
                    1,
                    dtype=torch.int64,
                    device=tensordict.device,
                ),
            ),
        )
        return tensordict

    def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Increments ``step_count`` and raises ``done`` once it reaches ``max_steps``."""
        next_step_count = (
            tensordict.get(
                "step_count",
                torch.zeros(
                    *tensordict.batch_size,
                    1,
                    dtype=torch.int64,
                    device=tensordict.device,
                ),
            )
            + 1
        )
        tensordict.set("step_count", next_step_count)
        if self.max_steps is not None:
            # BUG FIX: `|` binds tighter than `>=` in Python, so the previous
            # expression `done | next_step_count >= self.max_steps` evaluated as
            # `(done | next_step_count) >= self.max_steps`, comparing the raw
            # step count (OR-ed with done) against the limit instead of OR-ing
            # the done flag with the boolean limit check.
            tensordict.set(
                "done",
                tensordict.get("done") | (next_step_count >= self.max_steps),
            )
        return tensordict

    def transform_observation_spec(
        self, observation_spec: CompositeSpec
    ) -> CompositeSpec:
        """Adds an unbounded non-negative ``step_count`` entry to the observation spec."""
        if not isinstance(observation_spec, CompositeSpec):
            raise ValueError(
                f"observation_spec was expected to be of type CompositeSpec. Got {type(observation_spec)} instead."
            )
        observation_spec["step_count"] = UnboundedDiscreteTensorSpec(
            shape=torch.Size([1]), dtype=torch.int64, device=observation_spec.device
        )
        # Step counts can never be negative.
        observation_spec["step_count"].space.minimum = 0
        return observation_spec