26 changes: 3 additions & 23 deletions test/test_modules.py
@@ -855,20 +855,10 @@ def _get_mock_input_td(
)
return td

@pytest.mark.parametrize("n_agents", [1, 3])
@pytest.mark.parametrize("n_agents", [3, 1])
@pytest.mark.parametrize("share_params", [True, False])
@pytest.mark.parametrize("centralised", [True, False])
@pytest.mark.parametrize(
"batch",
[
(10,),
(
10,
3,
),
(),
],
)
@pytest.mark.parametrize("batch", [(10,), (10, 3), ()])
def test_mlp(
self,
n_agents,
@@ -974,17 +964,7 @@ def test_cnn(
assert not torch.allclose(out[..., i, :], out[..., j, :])

@pytest.mark.parametrize("n_agents", [1, 3])
@pytest.mark.parametrize(
"batch",
[
(10,),
(
10,
3,
),
(),
],
)
@pytest.mark.parametrize("batch", [(10,), (10, 3), ()])
def test_vdn(self, n_agents, batch):
torch.manual_seed(0)
mixer = VDNMixer(n_agents=n_agents, device="cpu")
133 changes: 70 additions & 63 deletions torchrl/modules/models/multiagent.py
@@ -8,6 +8,9 @@
import numpy as np

import torch

from tensordict import TensorDict
from tensordict.nn import make_functional, TensorDictParams
from torch import nn

from torchrl.data.utils import DEVICE_TYPING
@@ -18,9 +21,9 @@
class MultiAgentMLP(nn.Module):
"""Mult-agent MLP.

This is an MLP that can be used in multi-agent contexts.
For example, as a policy or as a value function.
See `examples/multiagent` for examples.
A MultiAgentMLP is an MLP that can be used in multi-agent contexts
(e.g., as a policy or as a value function).
See ``examples/multiagent`` for examples.

It expects inputs with shape ``(*B, n_agents, n_agent_inputs)``
and returns outputs with shape ``(*B, n_agents, n_agent_outputs)``.
Expand All @@ -36,24 +39,29 @@ class MultiAgentMLP(nn.Module):
n_agent_inputs (int): number of inputs for each agent.
n_agent_outputs (int): number of outputs for each agent.
n_agents (int): number of agents.
centralised (bool): If `centralised` is True, each agent will use the inputs of all agents to compute its output
(n_agent_inputs * n_agents will be the number of inputs for one agent).
centralised (bool): If ``True``, each agent will use the inputs of
all agents to compute its output
(``n_agent_inputs * n_agents`` will be the number of inputs for one agent).
Otherwise, each agent will only use its data as input.
share_params (bool): If `share_params` is True, the same MLP will be used to make the forward pass
for all agents (homogeneous policies). Otherwise, each agent will use a different MLP to process
its input (heterogeneous policies).
share_params (bool): If ``True``, the same MLP will be used to make the forward pass
for all agents (homogeneous policies). Otherwise, each agent will
use a different MLP to process its input (heterogeneous policies).
device (str or torch.device, optional): device to create the module on.
depth (int, optional): depth of the network. A depth of 0 will produce a single linear layer network with the
desired input and output size. A length of 1 will create 2 linear layers etc. If no depth is indicated,
the depth information should be contained in the num_cells argument (see below). If num_cells is an
iterable and depth is indicated, both should match: len(num_cells) must be equal to depth.
default: 3.
num_cells (int or Sequence[int], optional): number of cells of every layer in between the input and output. If
an integer is provided, every layer will have the same number of cells. If an iterable is provided,
the linear layers out_features will match the content of num_cells.
default: 32.
activation_class (Type[nn.Module]): activation class to be used.
default: nn.Tanh.
depth (int, optional): depth of the network. A depth of ``0`` will produce
a single linear layer network with the desired input and output size.
A depth of ``1`` will create 2 linear layers, etc. If no depth is indicated,
the depth information must be contained in the ``num_cells`` argument (see below).
If ``num_cells`` is an iterable and ``depth`` is indicated, both should match:
``len(num_cells)`` must be equal to ``depth``.
Defaults to ``3``.
num_cells (int or Sequence[int], optional): number of cells of every
layer in between the input and output. If an integer is provided,
every layer will have the same number of cells. If an iterable is provided,
the linear layers' ``out_features`` will match the content of ``num_cells``.
Defaults to ``32``.
activation_class (Type[nn.Module], optional): activation class to be used.
Defaults to :class:`torch.nn.Tanh`.
**kwargs: additional keyword arguments that can be passed to :class:`torchrl.modules.models.MLP` to customize the MLPs.

Examples:
@@ -63,8 +71,8 @@ class MultiAgentMLP(nn.Module):
>>> n_agent_inputs=3
>>> n_agent_outputs=2
>>> batch = 64
>>> obs = torch.zeros(batch, n_agents, n_agent_inputs
First let's instantiate a local network shared by all agents (e.g. a parameter-shared policy)
>>> obs = torch.zeros(batch, n_agents, n_agent_inputs)
>>> # First let's instantiate a local network shared by all agents (e.g. a parameter-shared policy)
>>> mlp = MultiAgentMLP(
... n_agent_inputs=n_agent_inputs,
... n_agent_outputs=n_agent_outputs,
@@ -86,7 +94,7 @@ class MultiAgentMLP(nn.Module):
)
)
>>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs)
Now let's instantiate a centralised network shared by all agents (e.g. a centalised value function)
>>> # centralised network shared by all agents (e.g. a centralised value function)
>>> mlp = MultiAgentMLP(
... n_agent_inputs=n_agent_inputs,
... n_agent_outputs=n_agent_outputs,
@@ -107,12 +115,9 @@ class MultiAgentMLP(nn.Module):
)
)
)
We can see that the input to the first layer is n_agents * n_agent_inputs,
this is because in the case the net acts as a centralised mlp (like a single huge agent)
>>> # The input to the first layer is `n_agents * n_agent_inputs`.
>>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs)
Outputs will be identical for all agents.
Now we can do both examples just shown but with an independent set of parameters for each agent
Let's show the centralised=False case.
>>> # share_params=False will create independent params for each sub-MLP
>>> mlp = MultiAgentMLP(
... n_agent_inputs=n_agent_inputs,
... n_agent_outputs=n_agent_outputs,
@@ -121,6 +126,7 @@ class MultiAgentMLP(nn.Module):
... share_params=False,
... depth=2,
... )
>>> # we now have 6 MLPs, one per agent
>>> print(mlp)
MultiAgentMLP(
(agent_networks): ModuleList(
@@ -133,7 +139,6 @@ class MultiAgentMLP(nn.Module):
)
)
)
We can see that this is the same as in the first example, but now we have 6 MLPs, one per agent!
>>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs)
"""

@@ -158,22 +163,42 @@ def __init__(
self.share_params = share_params
self.centralised = centralised

self.agent_networks = nn.ModuleList(
[
MLP(
in_features=n_agent_inputs
if not centralised
else n_agent_inputs * n_agents,
out_features=n_agent_outputs,
depth=depth,
num_cells=num_cells,
activation_class=activation_class,
device=device,
**kwargs,
)
for _ in range(self.n_agents if not self.share_params else 1)
]
)
def make_net():
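# one MLP blueprint; when centralised, each copy consumes the concatenated inputs of all agents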
return MLP(
in_features=n_agent_inputs
if not centralised
else n_agent_inputs * n_agents,
out_features=n_agent_outputs,
depth=depth,
num_cells=num_cells,
activation_class=activation_class,
device=device,
**kwargs,
)

if not self.share_params:
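# heterogeneous parameters: build one MLP per agent and stack their parameter TensorDicts along a leading agent dim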
agent_networks = [make_net() for _ in range(self.n_agents)]
self.params = TensorDictParams(
torch.stack(
[TensorDict.from_module(mod) for mod in agent_networks], 0
).contiguous(),
no_convert=True,
)
net = agent_networks[0]
else:
net = make_net()
self.params = TensorDictParams(TensorDict.from_module(net), no_convert=True)
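# make `net` functional so that forward can pass the parameters explicitly (see self.net_call below)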
make_functional(net)
self.net = net
if self.share_params:
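# a single shared parameter set: call the functional net directly with self.params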
self.net_call = self.net
else:
if self.centralised:
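# vmap over the stacked parameters only (dim 0); the flattened joint input is shared by all agents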
self.net_call = torch.vmap(self.net, in_dims=(None, 0), out_dims=(-2,))
else:
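# vmap over both the agent dim of the input (-2) and the stacked parameters (dim 0)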
self.net_call = torch.vmap(self.net, in_dims=(-2, 0), out_dims=(-2,))

print("self.net_call", self.net_call)

def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor:
if len(inputs) > 1:
@@ -186,31 +211,13 @@ def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor:
f"Multi-agent network expected input with last 2 dimensions {[self.n_agents, self.n_agent_inputs]},"
f" but got {inputs.shape}"
)

# If the model is centralized, agents have full observability
if self.centralised:
inputs = inputs.reshape(
*inputs.shape[:-2], self.n_agents * self.n_agent_inputs
)

# If parameters are not shared, each agent has its own network
if not self.share_params:
if self.centralised:
output = torch.stack(
[net(inputs) for i, net in enumerate(self.agent_networks)],
dim=-2,
)
else:
output = torch.stack(
[
net(inputs[..., i, :])
for i, net in enumerate(self.agent_networks)
],
dim=-2,
)
# If parameters are shared, agents use the same network
else:
output = self.agent_networks[0](inputs)
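# a single call covers both the shared and the per-agent (vmapped) parameter cases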
output = self.net_call(inputs, self.params)

if self.centralised:
# If the parameters are shared, and it is centralised, all agents will have the same output
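For readers skimming the diff: the refactor above replaces the per-agent loop over an `nn.ModuleList` with a single functional MLP whose stacked parameters live in a `TensorDictParams` and are dispatched through `torch.vmap`. Below is a minimal, self-contained sketch of that pattern, not the PR's code: it uses `torch.func.stack_module_state` and `functional_call` (assuming torch >= 2.0) in place of tensordict's `make_functional`/`TensorDictParams`, and the names (`make_net`, `one_agent`) and sizes are illustrative only.

```python
import copy

import torch
from torch import nn
from torch.func import functional_call, stack_module_state

n_agents, n_inputs, n_outputs, batch = 3, 4, 2, 8


def make_net():
    # same structure for every agent; only the parameters differ
    return nn.Sequential(
        nn.Linear(n_inputs, 32), nn.Tanh(), nn.Linear(32, n_outputs)
    )


# share_params=False: one parameter set per agent, stacked along a leading dim
agent_nets = [make_net() for _ in range(n_agents)]
params, buffers = stack_module_state(agent_nets)

# a "stateless" template that only provides the architecture
template = copy.deepcopy(agent_nets[0]).to("meta")


def one_agent(p, b, x):
    # run the template with an explicit per-agent parameter set
    return functional_call(template, (p, b), (x,))


# centralised=False: map over the agent dim of the parameters (0) and of the
# input (1 here; MultiAgentMLP uses -2 so arbitrary leading batch dims work)
net_call = torch.vmap(one_agent, in_dims=(0, 0, 1), out_dims=1)

obs = torch.zeros(batch, n_agents, n_inputs)
out = net_call(params, buffers, obs)
assert out.shape == (batch, n_agents, n_outputs)
```

The `share_params=True` variant skips the stacking and calls the template with a single parameter set, while the `centralised=True` variant flattens the per-agent inputs and maps over the parameters alone, mirroring the `in_dims` choices in `__init__` above.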