-
Notifications
You must be signed in to change notification settings - Fork 418
Open
Description
torch.export seems to not work on the multiagent models
I have distilled the issue from facebookresearch/BenchMARL#188 into this minimal reproducing script:
import torch
from tensordict import TensorDict
from tensordict.nn import NormalParamExtractor, TensorDictModule
from torch import nn
from torchrl.modules import MultiAgentMLP, ProbabilisticActor, TanhNormal
from torchrl.envs.utils import ExplorationType, set_exploration_type
# Problem sizes used by the reproduction.
n_actions = 3
n_obs = 5
n_agents = 2
batch = 4

# Multi-agent MLP wrapped in a TensorDictModule: it reads the nested key
# ("agents", "observation") and writes its result to ("agents", "out").
policy = TensorDictModule(
    MultiAgentMLP(
        n_agent_inputs=n_obs,
        # NOTE(review): 2 * n_actions presumably holds loc and scale per
        # action (the distribution head is not part of this minimal script).
        n_agent_outputs=2 * n_actions,
        n_agents=n_agents,
        centralised=False,
        share_params=True,
        device="cpu",
        depth=2,
        num_cells=256,
        activation_class=nn.Tanh,
    ),
    in_keys=[("agents", "observation")],
    out_keys=[
        ("agents", "out"),
    ],
)

# Random observations of shape (batch, n_agents, n_obs), stored in a nested
# TensorDict under "agents".
obs = TensorDict(
    {
        "agents": TensorDict(
            {"observation": torch.randn((batch, n_agents, n_obs))},
            batch_size=[batch, n_agents],
        )
    },
    batch_size=[batch],
)

# Calling the module eagerly works.
print(policy(obs))  # Success

# Exporting the same module fails with:
#   torch._dynamo.exc.Unsupported: isinstance(GetAttrVariable(UnspecializedBuiltinNNModuleVariable(Linear), __dict__), BuiltinVariable(dict)): can't determine type of GetAttrVariable(UnspecializedBuiltinNNModuleVariable(Linear), __dict__)
with set_exploration_type(ExplorationType.DETERMINISTIC):
    exported_policy = torch.export.export(
        policy.select_out_keys(("agents", "out")),
        args=(),
        # The observation tensor is supplied as the keyword argument
        # "agents_observation".
        kwargs={"agents_observation": obs["agents", "observation"]},
        strict=True,
    )  # Fail
Metadata
Assignees
Labels
No labels