diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 14a8e9e1a02..e25ed08ebed 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -6676,7 +6676,7 @@ class ActionMask(Transform): ... self.observation_spec = CompositeSpec(obs=UnboundedContinuousTensorSpec(3)) ... self.reward_spec = UnboundedContinuousTensorSpec(1) ... - ... def _reset(self, data): + ... def _reset(self, tensordict=None): ... td = self.observation_spec.rand() ... td.update(torch.ones_like(self.state_spec.rand())) ... return td