[BUG] Restoring multiagent nets #1960

@matteobettini

Description

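Since #1921, multiagent networks can no longer be restored. Reloading a `TensorDictModule` that wraps a `MultiAgentMLP` from its own `state_dict()` raises a `KeyError`. Minimal reproduction: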

```python
from tensordict.nn import TensorDictModule
from torch import nn

from torchrl.modules.models.multiagent import MultiAgentMLP

if __name__ == "__main__":
    actor_net = MultiAgentMLP(
        n_agent_inputs=4,
        n_agent_outputs=6,
        n_agents=2,
        centralised=False,
        share_params=False,
        device="cpu",
        depth=2,
        num_cells=256,
        activation_class=nn.Tanh,
    )

    policy_module = TensorDictModule(
        actor_net,
        in_keys=[("agents", "observation")],
        out_keys=[("agents", "action")],
    )
    dict = policy_module.state_dict()
    policy_module.load_state_dict(dict)
```

```
Traceback (most recent call last):
  File "/Users/Matteo/PycharmProjects/torchrl/examples/multiagent/prova.py", line 25, in <module>
    policy_module.load_state_dict(dict)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2138, in load_state_dict
    load(self, state_dict)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2126, in load
    load(child, child_state_dict, child_prefix)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2126, in load
    load(child, child_state_dict, child_prefix)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2120, in load
    module._load_from_state_dict(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/nn/params.py", line 994, in _load_from_state_dict
    TensorDict(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 2455, in get
    return self._get_tuple(key, default=default)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 1647, in _get_tuple
    first = self._get_str(key[0], default)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 1643, in _get_str
    return self._default_get(first_key, default)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 2433, in _default_get
    raise KeyError(
KeyError: 'key "module.params" not found in TensorDict with keys [\'module\']'
```
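
One plausible reading of the traceback follows. `torch.nn.Module.load_state_dict` recurses through child modules and passes each one a dotted prefix. Judging from the error message, the prefix for the parameters held by `MultiAgentMLP` is `module.params.`. Meanwhile, the `TensorDict` built inside `_load_from_state_dict` is nested, with `'module'` as its only top-level key. Looking up `"module.params"` as a single key therefore fails. The pure-Python sketch below models that mismatch with plain dicts. `unflatten`, `get_nested` and the `0.weight` leaf name are hypothetical stand-ins, not tensordict or torchrl code, and the actual fix in tensordict may differ.

```python
# Hypothetical model of the key mismatch in the traceback above.
# Plain dicts stand in for TensorDict; this is NOT tensordict's code.


def unflatten(flat):
    """Turn {"a.b.c": v} into {"a": {"b": {"c": v}}}."""
    nested = {}
    for dotted_key, value in flat.items():
        *parents, leaf = dotted_key.split(".")
        node = nested
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


def get_nested(tree, key):
    """Look up a tuple key one level at a time, like a nested TensorDict get."""
    node = tree
    for part in key:
        if part not in node:
            raise KeyError(f'key "{part}" not found in mapping with keys {list(node)}')
        node = node[part]
    return node


# Hypothetical flat state dict for TensorDictModule -> MultiAgentMLP -> params;
# "0.weight" is a placeholder leaf name, not a real MultiAgentMLP key.
state_dict = {"module.params.0.weight": [1.0, 2.0]}
tree = unflatten(state_dict)  # top-level keys: ['module']
prefix = "module.params."  # dotted prefix built by the recursive load() calls

# Failing lookup: the whole dotted prefix is treated as one key.
try:
    get_nested(tree, (prefix.rstrip("."),))
except KeyError as err:
    print("lookup failed:", err)

# Working lookup: split the prefix into its nested components first.
params = get_nested(tree, tuple(prefix.rstrip(".").split(".")))
print(params)  # {'0': {'weight': [1.0, 2.0]}}
```

Splitting the dotted prefix into its components before the nested lookup is one way to reconcile the two key formats. This issue does not state how the fix was actually implemented.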

Labels: bug (Something isn't working)