Skip to content

[BUG] Failing to export multi-agent models #2902

@matteobettini

Description

@matteobettini

torch.export does not seem to work on the multi-agent models.

I have distilled the issue from facebookresearch/BenchMARL#188 into this minimal reproducing script:

import torch
from tensordict import TensorDict
from tensordict.nn import NormalParamExtractor, TensorDictModule
from torch import nn

from torchrl.modules import MultiAgentMLP, ProbabilisticActor, TanhNormal
from torchrl.envs.utils import ExplorationType, set_exploration_type

# Problem sizes used throughout the reproduction.
n_actions, n_obs, n_agents, batch = 3, 5, 2, 4


# Decentralised MLP whose parameters are shared by all agents. Each agent
# maps its n_obs-dim observation to 2 * n_actions outputs (presumably
# loc/scale pairs for a NormalParamExtractor, which this repro does not use).
mlp = MultiAgentMLP(
    n_agent_inputs=n_obs,
    n_agent_outputs=n_actions * 2,
    n_agents=n_agents,
    centralised=False,
    share_params=True,
    device="cpu",
    depth=2,
    num_cells=256,
    activation_class=nn.Tanh,
)

# Wrap the MLP so it reads ("agents", "observation") and writes
# ("agents", "out") on a TensorDict.
policy = TensorDictModule(
    mlp,
    in_keys=[("agents", "observation")],
    out_keys=[("agents", "out")],
)

# Nested input: the "agents" sub-tensordict carries the per-agent batch
# dimension, while the root only carries the environment batch dimension.
# NOTE: the MLP above is constructed first so that random-number
# consumption matches the original ordering of weight init vs. randn.
agents_td = TensorDict(
    {"observation": torch.randn((batch, n_agents, n_obs))},
    batch_size=[batch, n_agents],
)
obs = TensorDict({"agents": agents_td}, batch_size=[batch])

# The eager forward pass works.
print(policy(obs))  # Success

# Exporting the same module fails under strict (dynamo) tracing.
# "agents_observation" is presumably the nested in_key flattened with "_"
# by tensordict's keyword dispatch -- confirm against tensordict docs.
export_kwargs = {"agents_observation": obs["agents", "observation"]}
with set_exploration_type(ExplorationType.DETERMINISTIC):
    exported_policy = torch.export.export(
        policy.select_out_keys(("agents", "out")),
        args=(),
        kwargs=export_kwargs,
        strict=True,
    )  # Fail
torch._dynamo.exc.Unsupported: isinstance(GetAttrVariable(UnspecializedBuiltinNNModuleVariable(Linear), __dict__), BuiltinVariable(dict)): can't determine type of GetAttrVariable(UnspecializedBuiltinNNModuleVariable(Linear), __dict__)

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions