diff --git a/test/test_modules.py b/test/test_modules.py
index 68917a10d16..1240f5dbfe2 100644
--- a/test/test_modules.py
+++ b/test/test_modules.py
@@ -855,20 +855,10 @@ def _get_mock_input_td(
         )
         return td
 
-    @pytest.mark.parametrize("n_agents", [1, 3])
+    @pytest.mark.parametrize("n_agents", [3, 1])
     @pytest.mark.parametrize("share_params", [True, False])
     @pytest.mark.parametrize("centralised", [True, False])
-    @pytest.mark.parametrize(
-        "batch",
-        [
-            (10,),
-            (
-                10,
-                3,
-            ),
-            (),
-        ],
-    )
+    @pytest.mark.parametrize("batch", [(10,), (10, 3), ()])
     def test_mlp(
         self,
         n_agents,
@@ -974,17 +964,7 @@ def test_cnn(
         assert not torch.allclose(out[..., i, :], out[..., j, :])
 
     @pytest.mark.parametrize("n_agents", [1, 3])
-    @pytest.mark.parametrize(
-        "batch",
-        [
-            (10,),
-            (
-                10,
-                3,
-            ),
-            (),
-        ],
-    )
+    @pytest.mark.parametrize("batch", [(10,), (10, 3), ()])
     def test_vdn(self, n_agents, batch):
         torch.manual_seed(0)
         mixer = VDNMixer(n_agents=n_agents, device="cpu")
diff --git a/torchrl/modules/models/multiagent.py b/torchrl/modules/models/multiagent.py
index f6b80ead12c..60395d76568 100644
--- a/torchrl/modules/models/multiagent.py
+++ b/torchrl/modules/models/multiagent.py
@@ -8,6 +8,9 @@
 import numpy as np
 
 import torch
+
+from tensordict import TensorDict
+from tensordict.nn import make_functional, TensorDictParams
 from torch import nn
 
 from torchrl.data.utils import DEVICE_TYPING
@@ -18,9 +21,9 @@
 class MultiAgentMLP(nn.Module):
     """Mult-agent MLP.
 
-    This is an MLP that can be used in multi-agent contexts.
-    For example, as a policy or as a value function.
-    See `examples/multiagent` for examples.
+    A MultiAgentMLP is an MLP that can be used in multi-agent contexts
+    (e.g., as a policy or as a value function).
+    See ``examples/multiagent`` for examples.
 
     It expects inputs with shape (*B, n_agents, n_agent_inputs)
     It returns outputs with shape (*B, n_agents, n_agent_outputs)
@@ -36,24 +39,29 @@ class MultiAgentMLP(nn.Module):
         n_agent_inputs (int): number of inputs for each agent.
         n_agent_outputs (int): number of outputs for each agent.
         n_agents (int): number of agents.
-        centralised (bool): If `centralised` is True, each agent will use the inputs of all agents to compute its output
-            (n_agent_inputs * n_agents will be the number of inputs for one agent).
+        centralised (bool): If ``True``, each agent will use the inputs of
+            all agents to compute its output
+            (``n_agent_inputs * n_agents`` will be the number of inputs for one agent).
             Otherwise, each agent will only use its data as input.
-        share_params (bool): If `share_params` is True, the same MLP will be used to make the forward pass
-            for all agents (homogeneous policies). Otherwise, each agent will use a different MLP to process
-            its input (heterogeneous policies).
+        share_params (bool): If ``True``, the same MLP will be used to make the forward pass
+            for all agents (homogeneous policies). Otherwise, each agent will
+            use a different MLP to process its input (heterogeneous policies).
         device (str or toech.device, optional): device to create the module on.
-        depth (int, optional): depth of the network. A depth of 0 will produce a single linear layer network with the
-            desired input and output size. A length of 1 will create 2 linear layers etc. If no depth is indicated,
-            the depth information should be contained in the num_cells argument (see below). If num_cells is an
-            iterable and depth is indicated, both should match: len(num_cells) must be equal to depth.
-            default: 3.
-        num_cells (int or Sequence[int], optional): number of cells of every layer in between the input and output. If
-            an integer is provided, every layer will have the same number of cells. If an iterable is provided,
-            the linear layers out_features will match the content of num_cells.
-            default: 32.
-        activation_class (Type[nn.Module]): activation class to be used.
-            default: nn.Tanh.
+        depth (int, optional): depth of the network. A depth of ``0`` will produce
+            a single linear layer network with the desired input and output size.
+            A depth of ``1`` will create 2 linear layers, etc. If no depth is
+            indicated, the depth information must be contained in the
+            ``num_cells`` argument. If ``num_cells`` is an iterable and depth is
+            indicated, both should match: ``len(num_cells)`` must be equal to
+            ``depth``.
+            Defaults to ``3``.
+        num_cells (int or Sequence[int], optional): number of cells of every
+            layer in between the input and output. If an integer is provided,
+            every layer will have the same number of cells. If an iterable is provided,
+            the linear layers ``out_features`` will match the content of num_cells.
+            Defaults to ``32``.
+        activation_class (Type[nn.Module], optional): activation class to be used.
+            Defaults to :class:`torch.nn.Tanh`.
         **kwargs: for :class:`torchrl.modules.models.MLP`
             can be passed to customize the MLPs.
     Examples:
@@ -63,8 +71,8 @@ class MultiAgentMLP(nn.Module):
         >>> n_agent_inputs=3
         >>> n_agent_outputs=2
         >>> batch = 64
-        >>> obs = torch.zeros(batch, n_agents, n_agent_inputs
-        First let's instantiate a local network shared by all agents (e.g. a parameter-shared policy)
+        >>> obs = torch.zeros(batch, n_agents, n_agent_inputs)
+        >>> # First let's instantiate a local network shared by all agents (e.g. a parameter-shared policy)
         >>> mlp = MultiAgentMLP(
         ...     n_agent_inputs=n_agent_inputs,
         ...     n_agent_outputs=n_agent_outputs,
@@ -86,7 +94,7 @@ class MultiAgentMLP(nn.Module):
           )
         )
         >>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs)
-        Now let's instantiate a centralised network shared by all agents (e.g. a centalised value function)
+        >>> # centralised network shared by all agents (e.g. a centralised value function)
         >>> mlp = MultiAgentMLP(
         ...     n_agent_inputs=n_agent_inputs,
         ...     n_agent_outputs=n_agent_outputs,
@@ -107,12 +115,9 @@ class MultiAgentMLP(nn.Module):
             )
           )
         )
-        We can see that the input to the first layer is n_agents * n_agent_inputs,
-        this is because in the case the net acts as a centralised mlp (like a single huge agent)
+        >>> # The input to the first layer is ``n_agents * n_agent_inputs``.
         >>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs)
-        Outputs will be identical for all agents.
-        Now we can do both examples just shown but with an independent set of parameters for each agent
-        Let's show the centralised=False case.
+        >>> # share_params=False will create independent params for each sub-MLP
         >>> mlp = MultiAgentMLP(
         ...     n_agent_inputs=n_agent_inputs,
         ...     n_agent_outputs=n_agent_outputs,
@@ -121,6 +126,7 @@ class MultiAgentMLP(nn.Module):
         ...     share_params=False,
         ...     depth=2,
         ... )
+        >>> # we now have 6 MLPs, one per agent
         >>> print(mlp)
         MultiAgentMLP(
           (agent_networks): ModuleList(
@@ -133,7 +139,6 @@ class MultiAgentMLP(nn.Module):
             )
           )
         )
-        We can see that this is the same as in the first example, but now we have 6 MLPs, one per agent!
         >>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs)
     """
 
@@ -158,22 +163,40 @@ def __init__(
         self.share_params = share_params
         self.centralised = centralised
 
-        self.agent_networks = nn.ModuleList(
-            [
-                MLP(
-                    in_features=n_agent_inputs
-                    if not centralised
-                    else n_agent_inputs * n_agents,
-                    out_features=n_agent_outputs,
-                    depth=depth,
-                    num_cells=num_cells,
-                    activation_class=activation_class,
-                    device=device,
-                    **kwargs,
-                )
-                for _ in range(self.n_agents if not self.share_params else 1)
-            ]
-        )
+        def make_net():
+            return MLP(
+                in_features=n_agent_inputs
+                if not centralised
+                else n_agent_inputs * n_agents,
+                out_features=n_agent_outputs,
+                depth=depth,
+                num_cells=num_cells,
+                activation_class=activation_class,
+                device=device,
+                **kwargs,
+            )
+
+        if not self.share_params:
+            agent_networks = [make_net() for _ in range(self.n_agents)]
+            self.params = TensorDictParams(
+                torch.stack(
+                    [TensorDict.from_module(mod) for mod in agent_networks], 0
+                ).contiguous(),
+                no_convert=True,
+            )
+            net = agent_networks[0]
+        else:
+            net = make_net()
+            self.params = TensorDictParams(TensorDict.from_module(net), no_convert=True)
+        make_functional(net)
+        self.net = net
+        if self.share_params:
+            self.net_call = self.net
+        else:
+            if self.centralised:
+                self.net_call = torch.vmap(self.net, in_dims=(None, 0), out_dims=(-2,))
+            else:
+                self.net_call = torch.vmap(self.net, in_dims=(-2, 0), out_dims=(-2,))
 
     def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor:
         if len(inputs) > 1:
@@ -186,31 +211,13 @@ def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor:
                 f"Multi-agent network expected input with last 2 dimensions {[self.n_agents, self.n_agent_inputs]},"
                 f" but got {inputs.shape}"
             )
 
-        # If the model is centralized, agents have full observability
         if self.centralised:
             inputs = inputs.reshape(
                 *inputs.shape[:-2], self.n_agents * self.n_agent_inputs
             )
 
-        # If parameters are not shared, each agent has its own network
-        if not self.share_params:
-            if self.centralised:
-                output = torch.stack(
-                    [net(inputs) for i, net in enumerate(self.agent_networks)],
-                    dim=-2,
-                )
-            else:
-                output = torch.stack(
-                    [
-                        net(inputs[..., i, :])
-                        for i, net in enumerate(self.agent_networks)
-                    ],
-                    dim=-2,
-                )
-        # If parameters are shared, agents use the same network
-        else:
-            output = self.agent_networks[0](inputs)
+        output = self.net_call(inputs, self.params)
 
         if self.centralised:
             # If the parameters are shared, and it is centralised, all agents will have the same output
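
Note on the refactor: the patch replaces the per-agent Python loop over an `nn.ModuleList` with a single batched call. The parameters of all agent MLPs are stacked into a `TensorDictParams` container, the template MLP is made functional, and the stack is evaluated through `torch.vmap`. The snippet below is a minimal, self-contained sketch of the same vmap-over-stacked-parameters idea, written with plain `torch.func` utilities rather than `tensordict`; the toy networks, sizes and helper names (`agent_forward`, `central_forward`, `base`) are illustrative only and are not part of the patch.

```python
# Minimal sketch (not the torchrl implementation): stack per-agent
# parameters and evaluate all agents with a single vmap call.
import copy

import torch
from torch import nn
from torch.func import functional_call, stack_module_state

n_agents, n_in, n_out, batch = 3, 4, 2, 10

# One small MLP per agent (the share_params=False, heterogeneous case).
nets = [
    nn.Sequential(nn.Linear(n_in, 32), nn.Tanh(), nn.Linear(32, n_out))
    for _ in range(n_agents)
]

# Stack the per-agent parameters along a new leading dimension.
params, buffers = stack_module_state(nets)

# A stateless "template" copy, used only to describe the architecture.
base = copy.deepcopy(nets[0]).to("meta")

def agent_forward(p, b, x):
    # Run the template with externally supplied (stacked) parameters.
    return functional_call(base, (p, b), (x,))

obs = torch.randn(batch, n_agents, n_in)

# Decentralised: map over the parameter stack (dim 0) and over the agent
# dimension of the input (dim -2); re-insert the agent dim at -2.
out = torch.vmap(agent_forward, in_dims=(0, 0, -2), out_dims=-2)(params, buffers, obs)
assert out.shape == (batch, n_agents, n_out)

# Centralised: every agent sees the same flattened joint observation,
# so the input is not mapped over (in_dims=None for x).
joint = obs.reshape(batch, n_agents * n_in)
nets_c = [nn.Linear(n_agents * n_in, n_out) for _ in range(n_agents)]
params_c, buffers_c = stack_module_state(nets_c)
base_c = copy.deepcopy(nets_c[0]).to("meta")

def central_forward(p, b, x):
    return functional_call(base_c, (p, b), (x,))

out_c = torch.vmap(central_forward, in_dims=(0, 0, None), out_dims=-2)(
    params_c, buffers_c, joint
)
assert out_c.shape == (batch, n_agents, n_out)
```

This mirrors the two vmap configurations added in `__init__`: the decentralised, non-shared case uses `in_dims=(-2, 0)` (map over the agent dimension of the input and over the stacked parameters), the centralised case uses `in_dims=(None, 0)` (broadcast the flattened joint input to every agent's parameters), and `out_dims=(-2,)` puts the agent dimension back just before the feature dimension.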