-
Notifications
You must be signed in to change notification settings - Fork 418
Closed
pytorch/tensordict
#689
Labels
bug: Something isn't working
Description
Bug that prevents restoring multiagent networks after #1921
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.modules.models.multiagent import MultiAgentMLP
# Minimal reproduction: build a MultiAgentMLP with share_params=False and
# centralised=False, wrap it in a TensorDictModule, and capture its state
# dict so that it can be loaded straight back into the same module.
if __name__ == "__main__":
    actor_net = MultiAgentMLP(
        n_agent_inputs=4,
        n_agent_outputs=6,
        n_agents=2,
        centralised=False,
        share_params=False,
        device="cpu",
        depth=2,
        num_cells=256,
        activation_class=nn.Tanh,
    )
    # Nested keys: the module reads ("agents", "observation") and writes
    # ("agents", "action").
    policy_module = TensorDictModule(
        actor_net,
        in_keys=[("agents", "observation")],
        out_keys=[("agents", "action")],
    )
    # NOTE(review): `dict` shadows the builtin. The name is kept unchanged
    # because the statement that follows and the reported traceback both
    # quote it verbatim.
    dict = policy_module.state_dict()
policy_module.load_state_dict(dict)

Traceback (most recent call last):
File "/Users/Matteo/PycharmProjects/torchrl/examples/multiagent/prova.py", line 25, in <module>
policy_module.load_state_dict(dict)
File "/opt/homebrew/Caskroom/miniforge/base/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2138, in load_state_dict
load(self, state_dict)
File "/opt/homebrew/Caskroom/miniforge/base/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2126, in load
load(child, child_state_dict, child_prefix)
File "/opt/homebrew/Caskroom/miniforge/base/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2126, in load
load(child, child_state_dict, child_prefix)
File "/opt/homebrew/Caskroom/miniforge/base/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2120, in load
module._load_from_state_dict(
File "/Users/Matteo/PycharmProjects/tensordict/tensordict/nn/params.py", line 994, in _load_from_state_dict
TensorDict(
File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 2455, in get
return self._get_tuple(key, default=default)
File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 1647, in _get_tuple
first = self._get_str(key[0], default)
File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 1643, in _get_str
return self._default_get(first_key, default)
File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 2433, in _default_get
raise KeyError(
KeyError: 'key "module.params" not found in TensorDict with keys [\'module\']'

Metadata
Assignees
Labels
bug: Something isn't working