From b54f73d885737d61e72a0c1c6dd0270bcd00620c Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 16 Feb 2024 21:01:02 +0000 Subject: [PATCH 1/5] init --- test/test_modules.py | 38 ++-- torchrl/modules/models/multiagent.py | 285 +++++++++++++++------------ 2 files changed, 182 insertions(+), 141 deletions(-) diff --git a/test/test_modules.py b/test/test_modules.py index 68917a10d16..896ead4f7fb 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -858,24 +858,15 @@ def _get_mock_input_td( @pytest.mark.parametrize("n_agents", [1, 3]) @pytest.mark.parametrize("share_params", [True, False]) @pytest.mark.parametrize("centralised", [True, False]) - @pytest.mark.parametrize( - "batch", - [ - (10,), - ( - 10, - 3, - ), - (), - ], - ) - def test_mlp( + @pytest.mark.parametrize("n_agent_inputs", [6, None]) + @pytest.mark.parametrize("batch", [(10,), (10, 3), ()]) + def test_multiagent_mlp( self, n_agents, centralised, share_params, batch, - n_agent_inputs=6, + n_agent_inputs, n_agent_outputs=2, ): torch.manual_seed(0) @@ -887,6 +878,8 @@ def test_mlp( share_params=share_params, depth=2, ) + if n_agent_inputs is None: + n_agent_inputs = 6 td = self._get_mock_input_td(n_agents, n_agent_inputs, batch=batch) obs = td.get(("agents", "observation")) @@ -924,14 +917,27 @@ def test_mlp( @pytest.mark.parametrize("n_agents", [1, 3]) @pytest.mark.parametrize("share_params", [True, False]) @pytest.mark.parametrize("centralised", [True, False]) + @pytest.mark.parametrize("channels", [3, None]) @pytest.mark.parametrize("batch", [(10,), (10, 3), ()]) - def test_cnn( - self, n_agents, centralised, share_params, batch, x=50, y=50, channels=3 + def test_multiagent_cnn( + self, + n_agents, + centralised, + share_params, + batch, + channels, + x=50, + y=50, ): torch.manual_seed(0) cnn = MultiAgentConvNet( - n_agents=n_agents, centralised=centralised, share_params=share_params + n_agents=n_agents, + centralised=centralised, + share_params=share_params, + in_features=channels, ) + if channels is None: + channels = 3 td = TensorDict( { "agents": TensorDict( diff --git a/torchrl/modules/models/multiagent.py b/torchrl/modules/models/multiagent.py index f6b80ead12c..a1508d4d858 100644 --- a/torchrl/modules/models/multiagent.py +++ b/torchrl/modules/models/multiagent.py @@ -2,20 +2,113 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations +import abc from typing import Optional, Sequence, Tuple, Type, Union import numpy as np import torch -from torch import nn +from tensordict import TensorDict +from torch import nn from torchrl.data.utils import DEVICE_TYPING from torchrl.modules.models import ConvNet, MLP -class MultiAgentMLP(nn.Module): +class MultiAgentNetBase(nn.Module): + """A base class for multi-agent networks.""" + _empty_net: nn.Module + + def __init__( + self, + *, + n_agents: int, + centralised: bool, + share_params: bool, + agent_dim: int, + **kwargs, + ): + super().__init__() + + self.n_agents = n_agents + self.share_params = share_params + self.centralised = centralised + self.agent_dim = agent_dim + + agent_networks = [ + self._build_single_net(**kwargs) + for _ in range(self.n_agents if not self.share_params else 1) + ] + kwargs["device"] = "meta" + self.__dict__["_empty_net"] = self._build_single_net(**kwargs) + if self.share_params: + self.params = TensorDict.from_module(agent_networks[0], as_module=True) + else: + self.params = TensorDict.from_modules(*agent_networks, as_module=True) + + @abc.abstractmethod + def _build_single_net(self, *, device, **kwargs): + ... + + @abc.abstractmethod + def _pre_forward_check(self, inputs): + ... + + @staticmethod + def vmap_func_module(module, *args, **kwargs): + def exec_module(params, *input): + with params.to_module(module): + return module(*input) + + return torch.vmap(exec_module, *args, **kwargs) + + def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor: + if len(inputs) > 1: + inputs = torch.cat([*inputs], -1) + else: + inputs = inputs[0] + + inputs = self._pre_forward_check(inputs) + + # If parameters are not shared, each agent has its own network + if not self.share_params: + if self.centralised: + output = self.vmap_func_module(self._empty_net, (0, None), (-2,))( + self.params, inputs + ) + else: + output = self.vmap_func_module( + self._empty_net, (0, self.agent_dim), (-2,) + )(self.params, inputs) + + # If parameters are shared, agents use the same network + else: + with self.params.to_module(self._empty_net): + output = self._empty_net(inputs) + + if self.centralised: + # If the parameters are shared, and it is centralised, all agents will have the same output + # We expand it to maintain the agent dimension, but values will be the same for all agents + n_agent_outputs = output.shape[-1] + output = output.view(*output.shape[:-1], n_agent_outputs) + output = output.unsqueeze(-2) + output = output.expand( + *output.shape[:-2], self.n_agents, n_agent_outputs + ) + + if output.shape[-2] != (self.n_agents): + raise ValueError( + f"Multi-agent network expected output with shape[-2]={self.n_agents}" + f" but got {output.shape}" + ) + + return output + + +class MultiAgentMLP(MultiAgentNetBase): """Mult-agent MLP. This is an MLP that can be used in multi-agent contexts. @@ -150,87 +243,52 @@ def __init__( activation_class: Optional[Type[nn.Module]] = nn.Tanh, **kwargs, ): - super().__init__() self.n_agents = n_agents self.n_agent_inputs = n_agent_inputs self.n_agent_outputs = n_agent_outputs self.share_params = share_params self.centralised = centralised + self.num_cells = num_cells + self.activation_class = activation_class + self.depth = depth - self.agent_networks = nn.ModuleList( - [ - MLP( - in_features=n_agent_inputs - if not centralised - else n_agent_inputs * n_agents, - out_features=n_agent_outputs, - depth=depth, - num_cells=num_cells, - activation_class=activation_class, - device=device, - **kwargs, - ) - for _ in range(self.n_agents if not self.share_params else 1) - ] + super().__init__( + n_agents=n_agents, + centralised=centralised, + share_params=share_params, + device=device, + agent_dim=-2, + **kwargs, ) - def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor: - if len(inputs) > 1: - inputs = torch.cat([*inputs], -1) - else: - inputs = inputs[0] - - if inputs.shape[-2:] != (self.n_agents, self.n_agent_inputs): + def _pre_forward_check(self, inputs): + if inputs.shape[-2] != self.n_agents: raise ValueError( - f"Multi-agent network expected input with last 2 dimensions {[self.n_agents, self.n_agent_inputs]}," + f"Multi-agent network expected input with shape[-2]={self.n_agents}," f" but got {inputs.shape}" ) - # If the model is centralized, agents have full observability if self.centralised: - inputs = inputs.reshape( - *inputs.shape[:-2], self.n_agents * self.n_agent_inputs - ) - - # If parameters are not shared, each agent has its own network - if not self.share_params: - if self.centralised: - output = torch.stack( - [net(inputs) for i, net in enumerate(self.agent_networks)], - dim=-2, - ) - else: - output = torch.stack( - [ - net(inputs[..., i, :]) - for i, net in enumerate(self.agent_networks) - ], - dim=-2, - ) - # If parameters are shared, agents use the same network - else: - output = self.agent_networks[0](inputs) - - if self.centralised: - # If the parameters are shared, and it is centralised, all agents will have the same output - # We expand it to maintain the agent dimension, but values will be the same for all agents - output = output.view(*output.shape[:-1], self.n_agent_outputs) - output = output.unsqueeze(-2) - output = output.expand( - *output.shape[:-2], self.n_agents, self.n_agent_outputs - ) - - if output.shape[-2:] != (self.n_agents, self.n_agent_outputs): - raise ValueError( - f"Multi-agent network expected output with last 2 dimensions {[self.n_agents, self.n_agent_outputs]}," - f" but got {output.shape}" - ) - - return output + inputs = inputs.flatten(-2, -1) + return inputs + + def _build_single_net(self, *, device, **kwargs): + n_agent_inputs = self.n_agent_inputs + if self.centralised and n_agent_inputs is not None: + n_agent_inputs = self.n_agent_inputs * self.n_agents + return MLP( + in_features=n_agent_inputs, + out_features=self.n_agent_outputs, + depth=self.depth, + num_cells=self.num_cells, + activation_class=self.activation_class, + device=device, + **kwargs, + ) -class MultiAgentConvNet(nn.Module): +class MultiAgentConvNet(MultiAgentNetBase): """Multi-agent CNN. In MARL settings, agents may or may not share the same policy for their actions: we say that the parameters can be shared or not. Similarly, a network may take the entire observation space (across agents) or on a per-agent basis to compute its output, which we refer to as "centralized" and "non-centralized", respectively. @@ -243,6 +301,10 @@ class MultiAgentConvNet(nn.Module): share_params (bool): If ``True``, the same :class:`~torchrl.modules.ConvNet` will be used to make the forward pass for all agents (homogeneous policies). Otherwise, each agent will use a different :class:`~torchrl.modules.ConvNet` to process its input (heterogeneous policies). + + Keyword Args: + in_features (int, optional): the input feature dimension. If left to ``None``, + a lazy module is used. device (str or torch.device, optional): device to create the module on. num_cells (int or Sequence[int], optional): number of cells of every layer in between the input and output. If an integer is provided, every layer will have the same number of cells. If an iterable is provided, @@ -374,36 +436,43 @@ def __init__( n_agents: int, centralised: bool, share_params: bool, - device: Optional[DEVICE_TYPING] = None, - num_cells: Optional[Sequence[int]] = None, + *, + in_features: int | None = None, + device: DEVICE_TYPING | None = None, + num_cells: Sequence[int] | None = None, kernel_sizes: Union[Sequence[Union[int, Sequence[int]]], int] = 5, strides: Union[Sequence, int] = 2, paddings: Union[Sequence, int] = 0, activation_class: Type[nn.Module] = nn.ELU, **kwargs, ): - super().__init__() - - self.n_agents = n_agents - self.centralised = centralised - self.share_params = share_params + self.in_features = in_features + self.num_cells = num_cells + self.strides = strides + self.kernel_sizes = kernel_sizes + self.paddings = paddings + self.activation_class = activation_class + super().__init__( + n_agents=n_agents, + centralised=centralised, + share_params=share_params, + device=device, + agent_dim=-4, + ) - self.agent_networks = nn.ModuleList( - [ - ConvNet( - num_cells=num_cells, - kernel_sizes=kernel_sizes, - strides=strides, - paddings=paddings, - activation_class=activation_class, - device=device, - **kwargs, - ) - for _ in range(self.n_agents if not self.share_params else 1) - ] + def _build_single_net(self, *, device, **kwargs): + return ConvNet( + in_features=self.in_features, + num_cells=self.num_cells, + kernel_sizes=self.kernel_sizes, + strides=self.strides, + paddings=self.paddings, + activation_class=self.activation_class, + device=device, + **kwargs, ) - def forward(self, inputs: torch.Tensor): + def _pre_forward_check(self, inputs): if len(inputs.shape) < 4: raise ValueError( """Multi-agent network expects (*batch_size, agent_index, x, y, channels)""" @@ -412,44 +481,10 @@ def forward(self, inputs: torch.Tensor): raise ValueError( f"""Multi-agent network expects {self.n_agents} but got {inputs.shape[-4]}""" ) - # If the model is centralized, agents have full observability if self.centralised: - shape = ( - *inputs.shape[:-4], - self.n_agents * inputs.shape[-3], - inputs.shape[-2], - inputs.shape[-1], - ) - inputs = torch.reshape(inputs, shape) - - # If the parameters are not shared, each agent has its own network - if not self.share_params: - if self.centralised: - output = torch.stack( - [net(inputs) for net in self.agent_networks], dim=-2 - ) - else: - output = torch.stack( - [ - net(inp) - for i, (net, inp) in enumerate( - zip(self.agent_networks, inputs.unbind(-4)) - ) - ], - dim=-2, - ) - else: - output = self.agent_networks[0](inputs) - if self.centralised: - # If the parameters are shared, and it is centralised all agents will have the same output. - # We expand it to maintain the agent dimension, but values will be the same for all agents - n_agent_outputs = output.shape[-1] - output = output.view(*output.shape[:-1], n_agent_outputs) - output = output.unsqueeze(-2) - output = output.expand( - *output.shape[:-2], self.n_agents, n_agent_outputs - ) - return output + # If the model is centralized, agents have full observability + inputs = torch.flatten(inputs, -4, -3) + return inputs class Mixer(nn.Module): From 5b6ee6e46de1de4451d68ad5cdcdbce2e6508354 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sat, 17 Feb 2024 20:46:26 +0000 Subject: [PATCH 2/5] amend --- torchrl/modules/models/multiagent.py | 40 ++++++++++++++++++++++++++-- 1 file changed, 38 insertions(+), 2 deletions(-) diff --git a/torchrl/modules/models/multiagent.py b/torchrl/modules/models/multiagent.py index a1508d4d858..f6c12672c60 100644 --- a/torchrl/modules/models/multiagent.py +++ b/torchrl/modules/models/multiagent.py @@ -20,6 +20,7 @@ class MultiAgentNetBase(nn.Module): """A base class for multi-agent networks.""" + _empty_net: nn.Module def __init__( @@ -42,8 +43,21 @@ def __init__( self._build_single_net(**kwargs) for _ in range(self.n_agents if not self.share_params else 1) ] + initialized = True + for p in agent_networks[0].parameters(): + if isinstance(p, torch.nn.UninitializedParameter): + initialized = False + break + self.initialized = initialized + if not self.initialized: + self._agents_nets = agent_networks + else: + self._agents_nets = None + self._make_params(agent_networks) kwargs["device"] = "meta" self.__dict__["_empty_net"] = self._build_single_net(**kwargs) + + def _make_params(self, agent_networks): if self.share_params: self.params = TensorDict.from_module(agent_networks[0], as_module=True) else: @@ -53,6 +67,25 @@ def __init__( def _build_single_net(self, *, device, **kwargs): ... + def _check_init(self, inputs): + if self.initialized: + return + if not self.share_params: + if self.centralised: + for model in self._agents_nets: + model(inputs) + else: + for input, model in zip( + inputs.unbind(self.agent_dim), self._agents_nets + ): + model(input) + # If parameters are shared, agents use the same network + else: + self._agents_nets[0](inputs) + self._make_params(self._agents_nets) + del self._agents_nets + self.initialized = True + @abc.abstractmethod def _pre_forward_check(self, inputs): ... @@ -72,7 +105,7 @@ def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor: inputs = inputs[0] inputs = self._pre_forward_check(inputs) - + self._check_init(inputs) # If parameters are not shared, each agent has its own network if not self.share_params: if self.centralised: @@ -461,8 +494,11 @@ def __init__( ) def _build_single_net(self, *, device, **kwargs): + in_features = self.in_features + if self.centralised and in_features is not None: + in_features = in_features * self.n_agents return ConvNet( - in_features=self.in_features, + in_features=in_features, num_cells=self.num_cells, kernel_sizes=self.kernel_sizes, strides=self.strides, From 66343ee5e82d2150fbbc63dbb1fc62adc2629e8f Mon Sep 17 00:00:00 2001 From: vmoens Date: Sat, 17 Feb 2024 20:58:34 +0000 Subject: [PATCH 3/5] amend --- torchrl/modules/models/multiagent.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchrl/modules/models/multiagent.py b/torchrl/modules/models/multiagent.py index f6c12672c60..aa2d5f35320 100644 --- a/torchrl/modules/models/multiagent.py +++ b/torchrl/modules/models/multiagent.py @@ -159,7 +159,8 @@ class MultiAgentMLP(MultiAgentNetBase): Otherwise, each agent will only use its data as input. Args: - n_agent_inputs (int): number of inputs for each agent. + n_agent_inputs (int): number of inputs for each agent. If left to ``None``, + the number of inputs is lazily instantiated during the first call. n_agent_outputs (int): number of outputs for each agent. n_agents (int): number of agents. centralised (bool): If `centralised` is True, each agent will use the inputs of all agents to compute its output From 5c8834f2a0b652dc9038ef83bad5b03a738b3a8b Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 19 Feb 2024 18:34:24 +0000 Subject: [PATCH 4/5] amend --- torchrl/modules/models/multiagent.py | 43 +++++++++------------------- 1 file changed, 14 insertions(+), 29 deletions(-) diff --git a/torchrl/modules/models/multiagent.py b/torchrl/modules/models/multiagent.py index aa2d5f35320..bcb3ce067ee 100644 --- a/torchrl/modules/models/multiagent.py +++ b/torchrl/modules/models/multiagent.py @@ -49,14 +49,16 @@ def __init__( initialized = False break self.initialized = initialized - if not self.initialized: - self._agents_nets = agent_networks - else: - self._agents_nets = None - self._make_params(agent_networks) + self._make_params(agent_networks) kwargs["device"] = "meta" self.__dict__["_empty_net"] = self._build_single_net(**kwargs) + @property + def _vmap_randomness(self): + if self.initialized: + return "error" + return "same" + def _make_params(self, agent_networks): if self.share_params: self.params = TensorDict.from_module(agent_networks[0], as_module=True) @@ -67,25 +69,6 @@ def _make_params(self, agent_networks): def _build_single_net(self, *, device, **kwargs): ... - def _check_init(self, inputs): - if self.initialized: - return - if not self.share_params: - if self.centralised: - for model in self._agents_nets: - model(inputs) - else: - for input, model in zip( - inputs.unbind(self.agent_dim), self._agents_nets - ): - model(input) - # If parameters are shared, agents use the same network - else: - self._agents_nets[0](inputs) - self._make_params(self._agents_nets) - del self._agents_nets - self.initialized = True - @abc.abstractmethod def _pre_forward_check(self, inputs): ... @@ -105,16 +88,18 @@ def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor: inputs = inputs[0] inputs = self._pre_forward_check(inputs) - self._check_init(inputs) # If parameters are not shared, each agent has its own network if not self.share_params: if self.centralised: - output = self.vmap_func_module(self._empty_net, (0, None), (-2,))( - self.params, inputs - ) + output = self.vmap_func_module( + self._empty_net, (0, None), (-2,), randomness=self._vmap_randomness + )(self.params, inputs) else: output = self.vmap_func_module( - self._empty_net, (0, self.agent_dim), (-2,) + self._empty_net, + (0, self.agent_dim), + (-2,), + randomness=self._vmap_randomness, )(self.params, inputs) # If parameters are shared, agents use the same network From 030d0dd11e277d35131b554826e9095b2160cb22 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 20 Feb 2024 13:12:15 -0800 Subject: [PATCH 5/5] amend --- test/test_helpers.py | 2 +- test/test_modules.py | 72 ++++++++++++++++++++++++++++ torchrl/modules/models/multiagent.py | 4 +- 3 files changed, 75 insertions(+), 3 deletions(-) diff --git a/test/test_helpers.py b/test/test_helpers.py index eb9620001c7..e39e6cc6082 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -512,7 +512,7 @@ def test_initialize_stats_from_observation_norms(device, keys, composed, initial with pytest.raises( ValueError, match="Attempted to use an uninitialized parameter" ): - pre_init_state_dict = t_env.transform.state_dict() + t_env.transform.state_dict() return pre_init_state_dict = t_env.transform.state_dict() initialize_observation_norm_transforms( diff --git a/test/test_modules.py b/test/test_modules.py index 896ead4f7fb..3d01fd04768 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -914,6 +914,39 @@ def test_multiagent_mlp( # same input different output assert not torch.allclose(out[..., i, :], out[..., j, :]) + def test_multiagent_mlp_lazy(self): + mlp = MultiAgentMLP( + n_agent_inputs=None, + n_agent_outputs=6, + n_agents=3, + centralised=True, + share_params=False, + depth=2, + ) + optim = torch.optim.Adam(mlp.parameters()) + for p in mlp.parameters(): + if isinstance(p, torch.nn.parameter.UninitializedParameter): + break + else: + raise AssertionError("No UninitializedParameter found") + for p in optim.param_groups[0]["params"]: + if isinstance(p, torch.nn.parameter.UninitializedParameter): + break + else: + raise AssertionError("No UninitializedParameter found") + for _ in range(2): + td = self._get_mock_input_td(3, 4, batch=(10,)) + obs = td.get(("agents", "observation")) + out = mlp(obs) + out.mean().backward() + optim.step() + for p in mlp.parameters(): + if isinstance(p, torch.nn.parameter.UninitializedParameter): + raise AssertionError("UninitializedParameter found") + for p in optim.param_groups[0]["params"]: + if isinstance(p, torch.nn.parameter.UninitializedParameter): + raise AssertionError("UninitializedParameter found") + @pytest.mark.parametrize("n_agents", [1, 3]) @pytest.mark.parametrize("share_params", [True, False]) @pytest.mark.parametrize("centralised", [True, False]) @@ -979,6 +1012,45 @@ def test_multiagent_cnn( # same input different output assert not torch.allclose(out[..., i, :], out[..., j, :]) + def test_multiagent_cnn_lazy(self): + cnn = MultiAgentConvNet( + n_agents=5, + centralised=False, + share_params=False, + in_features=None, + ) + optim = torch.optim.Adam(cnn.parameters()) + for p in cnn.parameters(): + if isinstance(p, torch.nn.parameter.UninitializedParameter): + break + else: + raise AssertionError("No UninitializedParameter found") + for p in optim.param_groups[0]["params"]: + if isinstance(p, torch.nn.parameter.UninitializedParameter): + break + else: + raise AssertionError("No UninitializedParameter found") + for _ in range(2): + td = TensorDict( + { + "agents": TensorDict( + {"observation": torch.randn(10, 5, 3, 50, 50)}, + [10, 5], + ) + }, + batch_size=[10], + ) + obs = td[("agents", "observation")] + out = cnn(obs) + out.mean().backward() + optim.step() + for p in cnn.parameters(): + if isinstance(p, torch.nn.parameter.UninitializedParameter): + raise AssertionError("UninitializedParameter found") + for p in optim.param_groups[0]["params"]: + if isinstance(p, torch.nn.parameter.UninitializedParameter): + raise AssertionError("UninitializedParameter found") + @pytest.mark.parametrize("n_agents", [1, 3]) @pytest.mark.parametrize( "batch", diff --git a/torchrl/modules/models/multiagent.py b/torchrl/modules/models/multiagent.py index bcb3ce067ee..6229aa30fe3 100644 --- a/torchrl/modules/models/multiagent.py +++ b/torchrl/modules/models/multiagent.py @@ -144,7 +144,7 @@ class MultiAgentMLP(MultiAgentNetBase): Otherwise, each agent will only use its data as input. Args: - n_agent_inputs (int): number of inputs for each agent. If left to ``None``, + n_agent_inputs (int or None): number of inputs for each agent. If ``None``, the number of inputs is lazily instantiated during the first call. n_agent_outputs (int): number of outputs for each agent. n_agents (int): number of agents. @@ -251,7 +251,7 @@ class MultiAgentMLP(MultiAgentNetBase): def __init__( self, - n_agent_inputs: int, + n_agent_inputs: int | None, n_agent_outputs: int, n_agents: int, centralised: bool,