Status: Closed
Labels: bug (Something isn't working)
Description
Describe the bug
The `MaskedEnv` example in the documentation of the `ActionMask` transform (line 6666 of torchrl/envs/transforms/transforms.py) raises `TypeError: MaskedEnv._reset() got an unexpected keyword argument 'tensordict'`.
To Reproduce
Run the following example, taken from the documentation:
```python
import torch
from torchrl.data.tensor_specs import DiscreteTensorSpec, BinaryDiscreteTensorSpec, UnboundedContinuousTensorSpec, CompositeSpec
from torchrl.envs.transforms import ActionMask, TransformedEnv
from torchrl.envs.common import EnvBase

class MaskedEnv(EnvBase):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.action_spec = DiscreteTensorSpec(4)
        self.state_spec = CompositeSpec(action_mask=BinaryDiscreteTensorSpec(4, dtype=torch.bool))
        self.observation_spec = CompositeSpec(obs=UnboundedContinuousTensorSpec(3))
        self.reward_spec = UnboundedContinuousTensorSpec(1)

    def _reset(self, data):
        td = self.observation_spec.rand()
        td.update(torch.ones_like(self.state_spec.rand()))
        return td

    def _step(self, data):
        td = self.observation_spec.rand()
        mask = data.get("action_mask")
        print("old mask:", mask)
        action = data.get("action")
        mask = mask.scatter(-1, action.unsqueeze(-1), 0)
        print("new mask:", mask)
        td.set("action_mask", mask)
        td.set("reward", self.reward_spec.rand())
        td.set("done", ~mask.any().view(1))
        return td

    def _set_seed(self, seed):
        return seed

base_env = MaskedEnv()
env = TransformedEnv(base_env, ActionMask())
env.rollout(10)
env = TransformedEnv(base_env, ActionMask())
r = env.rollout(10)
r["action_mask"]
```
results in:
```
Traceback (most recent call last):
  File "<USER_PATH>\Projects\test.py", line 36, in <module>
    env.rollout(10)
    ^^^^^^^^^^^^^^^^
  File "<ENV_PATH>\Lib\site-packages\torchrl\envs\common.py", line 2410, in rollout
    tensordict = self.reset()
                 ^^^^^^^^^^^^
  File "<ENV_PATH>\Lib\site-packages\torchrl\envs\common.py", line 2068, in reset
    tensordict_reset = self._reset(tensordict, **kwargs)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<ENV_PATH>\Lib\site-packages\torchrl\envs\transforms\transforms.py", line 768, in _reset
    tensordict_reset = self.base_env._reset(tensordict=tensordict, **kwargs)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: MaskedEnv._reset() got an unexpected keyword argument 'tensordict'
```
Expected behavior
The documentation indicates that `r["action_mask"]` should produce the following output:
```
tensor([[ True,  True,  True,  True],
        [ True,  True, False,  True],
        [ True,  True, False, False],
        [ True, False, False, False]])
```
Screenshots
From https://pytorch.org/rl/reference/generated/torchrl.envs.transforms.ActionMask.html (screenshot not preserved).
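The expected output follows from the masking logic in `_step`: at each step `ActionMask` restricts the sampled action to entries where `action_mask` is `True`, `_step` sets the chosen entry to `False` with `scatter`, and `done` becomes `True` once no entry is left. The plain-Python sketch below mimics that loop without torch or TorchRL. `simulate_masked_rollout` is a hypothetical helper, not a library function, and the sketch assumes, as in TorchRL's default, that `rollout` stops at the first `done`. It shows why each recorded mask has one more `False` than the previous one and why only four rows appear even though `rollout(10)` was requested. Which column turns `False` depends on the sampled action, so the positions will differ from the documented tensor.

```python
import random


def simulate_masked_rollout(num_actions=4, max_steps=10, seed=0):
    """Hypothetical plain-Python mimic of the MaskedEnv + ActionMask loop (not TorchRL code)."""
    rng = random.Random(seed)
    mask = [True] * num_actions  # like _reset: every action starts allowed
    recorded = []
    for _ in range(max_steps):
        recorded.append(list(mask))  # the mask the policy sees at this step
        # like ActionMask: only sample among currently allowed actions
        allowed = [i for i, ok in enumerate(mask) if ok]
        action = rng.choice(allowed)
        # like _step: mask.scatter(-1, action.unsqueeze(-1), 0) disables the chosen action
        mask[action] = False
        # like _step: done = ~mask.any(); the rollout stops at the first done
        if not any(mask):
            break
    return recorded


for row in simulate_masked_rollout():
    print(row)
```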

System info
- Installed with pip in a conda environment
- Python 3.11.8
Reason and Possible fixes
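Renaming the `data` parameter of `_reset` to `tensordict` with a default of `None` (`def _reset(self, tensordict=None):`) solves the problem. As the last frame of the traceback shows, `TransformedEnv` forwards the call to the base environment as `self.base_env._reset(tensordict=tensordict, **kwargs)`, so the override has to accept a keyword argument named `tensordict`.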
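The sketch below reproduces that keyword-forwarding convention with hypothetical stand-in classes (`BaseEnvStub`, `DocExampleEnv`, `FixedEnv` and `WrapperStub` are illustrative only, not TorchRL classes), so the failure and the fix can be observed without installing TorchRL. It assumes the only relevant detail is that the wrapper passes `tensordict=` by keyword, as in the `transforms.py` frame of the traceback.

```python
class BaseEnvStub:
    """Hypothetical stand-in for an EnvBase subclass; not TorchRL code."""

    def _reset(self, tensordict=None, **kwargs):
        raise NotImplementedError


class DocExampleEnv(BaseEnvStub):
    # Signature as written in the documentation example: parameter named `data`.
    def _reset(self, data):
        return {"action_mask": [True] * 4}


class FixedEnv(BaseEnvStub):
    # Signature suggested in this issue: parameter named `tensordict`, defaulting to None.
    def _reset(self, tensordict=None):
        return {"action_mask": [True] * 4}


class WrapperStub:
    """Hypothetical stand-in for TransformedEnv: forwards reset by keyword."""

    def __init__(self, base_env):
        self.base_env = base_env

    def _reset(self, tensordict=None, **kwargs):
        # Mirrors the call in the traceback: base_env._reset(tensordict=tensordict, **kwargs)
        return self.base_env._reset(tensordict=tensordict, **kwargs)


def try_reset(env_cls):
    """Return the reset output, or the TypeError message if the call fails."""
    try:
        return WrapperStub(env_cls())._reset(None)
    except TypeError as err:
        return str(err)


print(try_reset(DocExampleEnv))
print(try_reset(FixedEnv))
```

Renaming the parameter is the smallest change; an override that also accepts `**kwargs` would tolerate any extra keyword arguments the wrapper forwards.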
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)