[BUG] Documentation Error: MaskedEnv Example Under ActionMask Transform Throws TypeError #2059

@Jonathanace

Describe the bug

The MaskedEnv example in the documentation for the ActionMask transform (line 6666 of torchrl/envs/transforms/transforms.py) raises TypeError: MaskedEnv._reset() got an unexpected keyword argument 'tensordict'.

To Reproduce

Running the following example, which is taken from the documentation:

import torch
from torchrl.data.tensor_specs import DiscreteTensorSpec, BinaryDiscreteTensorSpec, UnboundedContinuousTensorSpec, CompositeSpec
from torchrl.envs.transforms import ActionMask, TransformedEnv
from torchrl.envs.common import EnvBase
class MaskedEnv(EnvBase):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.action_spec=DiscreteTensorSpec(4)
        self.state_spec = CompositeSpec(action_mask=BinaryDiscreteTensorSpec(4,dtype=torch.bool))
        self.observation_spec=CompositeSpec(obs=UnboundedContinuousTensorSpec(3))
        self.reward_spec=UnboundedContinuousTensorSpec(1)

    def _reset(self, data):
        td = self.observation_spec.rand()
        td.update(torch.ones_like(self.state_spec.rand()))
        return td
    
    def _step(self, data):
        td = self.observation_spec.rand()
        mask = data.get("action_mask")
        print("old mask:", mask)
        action = data.get("action")
        mask = mask.scatter(-1, action.unsqueeze(-1), 0)
        print("new mask:", mask)
        
        td.set("action_mask", mask)
        td.set("reward", self.reward_spec.rand())
        td.set("done", ~mask.any().view(1))
        return td
    
    def _set_seed(self, seed):
        return seed
    
base_env = MaskedEnv()
env = TransformedEnv(base_env, ActionMask())
env.rollout(10)
env = TransformedEnv(base_env, ActionMask())
r = env.rollout(10)
r["action_mask"]

results in:

Traceback (most recent call last):
  File "<USER_PATH>\Projects\test.py", line 36, in <module>
    env.rollout(10)
    ^^^^^^^^^^^^^^^^
  File "<ENV_PATH>\Lib\site-packages\torchrl\envs\common.py", line 2410, in rollout
    tensordict = self.reset()
                 ^^^^^^^^^^^^
  File "<ENV_PATH>\Lib\site-packages\torchrl\envs\common.py", line 2068, in reset
    tensordict_reset = self._reset(tensordict, **kwargs)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<ENV_PATH>\Lib\site-packages\torchrl\envs\transforms\transforms.py", line 768, in _reset
    tensordict_reset = self.base_env._reset(tensordict=tensordict, **kwargs)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: MaskedEnv._reset() got an unexpected keyword argument 'tensordict'

Expected behavior

The documentation indicates that the following output is expected:

tensor([[ True,  True,  True,  True],
        [ True,  True, False,  True],
        [ True,  True, False, False],
        [ True, False, False, False]])

Screenshots

From https://pytorch.org/rl/reference/generated/torchrl.envs.transforms.ActionMask.html:
[Screenshot of the documentation page]

System info

  • Installed with pip in a conda environment
  • Python 3.11.8

Reason and Possible fixes

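Renaming the data parameter of the _reset method to tensordict=None solves the problem. As the traceback shows, the _reset frame in transforms.py calls self.base_env._reset(tensordict=tensordict, **kwargs), passing the argument by keyword, so the base environment's parameter must be named tensordict.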
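
The mismatch can be reproduced without torchrl. The sketch below is a simplified stand-in, not torchrl code: Wrapper, BrokenEnv, FixedEnv and try_reset are hypothetical names that only mimic the call pattern visible in the traceback, where the wrapper forwards tensordict by keyword.

```python
# Simplified stand-ins, NOT torchrl classes. They only mimic the call
# pattern from the traceback: the wrapper passes `tensordict` by keyword.


class Wrapper:
    """Hypothetical stand-in for the transformed environment."""

    def __init__(self, base_env):
        self.base_env = base_env

    def reset(self, tensordict=None, **kwargs):
        # Same shape as the call shown in the traceback.
        return self.base_env._reset(tensordict=tensordict, **kwargs)


class BrokenEnv:
    # Parameter is named `data`, as in the documentation example.
    def _reset(self, data):
        return {"obs": 0}


class FixedEnv:
    # Parameter renamed to `tensordict` with a default of None.
    def _reset(self, tensordict=None, **kwargs):
        return {"obs": 0}


def try_reset(env):
    """Reset through the wrapper and report whether it succeeded."""
    try:
        Wrapper(env).reset()
        return "ok"
    except TypeError as err:
        return f"TypeError: {err}"


broken_result = try_reset(BrokenEnv())
fixed_result = try_reset(FixedEnv())
print(broken_result)
print(fixed_result)
```

Because the wrapper passes the argument by keyword, only the parameter name matters here. A positional call would have accepted either signature, which is probably why the example looked correct when it was written.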

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
