From 6694816860598ad787a74215c23af509330783f6 Mon Sep 17 00:00:00 2001
From: Raul
Date: Fri, 5 Apr 2024 15:01:47 +0200
Subject: [PATCH] Allow to configure the depth of the MLP in output modules
 (#314)

* Added an MLP module. Allow number of hidden MLP layers in Scalar
  OutputModel to be configured from the yaml input.

* Make state dictionary compatible with previous checkpoints

* Change name to num_hidden_layers

* Update doc

---
 torchmdnet/models/model.py          |  8 +++
 torchmdnet/models/output_modules.py | 78 ++++++++++++++++++++---------
 torchmdnet/models/utils.py          | 69 +++++++++++++++++++++----
 torchmdnet/scripts/train.py         |  1 +
 4 files changed, 123 insertions(+), 33 deletions(-)

diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py
index 913693043..bfb218fee 100644
--- a/torchmdnet/models/model.py
+++ b/torchmdnet/models/model.py
@@ -127,6 +127,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
         activation=args["activation"],
         reduce_op=args["reduce_op"],
         dtype=dtype,
+        num_hidden_layers=args.get("output_mlp_num_layers", 0),
     )
 
     # combine representation and output network
@@ -232,6 +233,13 @@ def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs):
             model.prior_model[-1].enable = True
 
     state_dict = {re.sub(r"^model\.", "", k): v for k, v in ckpt["state_dict"].items()}
+    # In ET, before we had output_model.output_network.{0,1}.update_net.[0-9].{weight,bias}
+    # Now we have output_model.output_network.{0,1}.update_net.layers.[0-9].{weight,bias}
+    # This change was introduced in https://github.com/torchmd/torchmd-net/pull/314
+    state_dict = {
+        re.sub(r"update_net\.(\d+)\.", r"update_net.layers.\1.", k): v
+        for k, v in state_dict.items()
+    }
     model.load_state_dict(state_dict)
     return model.to(device)
 
diff --git a/torchmdnet/models/output_modules.py b/torchmdnet/models/output_modules.py
index 07271d132..bf408aa36 100644
--- a/torchmdnet/models/output_modules.py
+++ b/torchmdnet/models/output_modules.py
@@ -6,7 +6,12 @@
 from typing import Optional
 import torch
 from torch import nn
-from torchmdnet.models.utils import act_class_mapping, GatedEquivariantBlock, scatter
+from torchmdnet.models.utils import (
+    act_class_mapping,
+    GatedEquivariantBlock,
+    scatter,
+    MLP,
+)
 from torchmdnet.utils import atomic_masses
 from torchmdnet.extensions import is_current_stream_capturing
 from warnings import warn
@@ -20,6 +25,7 @@ class OutputModel(nn.Module, metaclass=ABCMeta):
     Derive this class to make custom output models.
     As an example, have a look at the :py:mod:`torchmdnet.output_modules.Scalar` output model.
     """
+
     def __init__(self, allow_prior_model, reduce_op):
         super(OutputModel, self).__init__()
         self.allow_prior_model = allow_prior_model
@@ -60,24 +66,23 @@ def __init__(
         allow_prior_model=True,
         reduce_op="sum",
         dtype=torch.float,
+        **kwargs
     ):
         super(Scalar, self).__init__(
             allow_prior_model=allow_prior_model, reduce_op=reduce_op
         )
-        act_class = act_class_mapping[activation]
-        self.output_network = nn.Sequential(
-            nn.Linear(hidden_channels, hidden_channels // 2, dtype=dtype),
-            act_class(),
-            nn.Linear(hidden_channels // 2, 1, dtype=dtype),
+        self.output_network = MLP(
+            in_channels=hidden_channels,
+            out_channels=1,
+            hidden_channels=hidden_channels // 2,
+            activation=activation,
+            num_hidden_layers=kwargs.get("num_layers", 0),
+            dtype=dtype,
         )
-
         self.reset_parameters()
 
     def reset_parameters(self):
-        nn.init.xavier_uniform_(self.output_network[0].weight)
-        self.output_network[0].bias.data.fill_(0)
-        nn.init.xavier_uniform_(self.output_network[2].weight)
-        self.output_network[2].bias.data.fill_(0)
+        self.output_network.reset_parameters()
 
     def pre_reduce(self, x, v: Optional[torch.Tensor], z, pos, batch):
         return self.output_network(x)
@@ -91,10 +96,13 @@ def __init__(
         allow_prior_model=True,
         reduce_op="sum",
         dtype=torch.float,
+        **kwargs
     ):
         super(EquivariantScalar, self).__init__(
             allow_prior_model=allow_prior_model, reduce_op=reduce_op
         )
+        if kwargs.get("num_layers", 0) > 0:
+            warn("num_layers is not used in EquivariantScalar")
         self.output_network = nn.ModuleList(
             [
                 GatedEquivariantBlock(
@@ -125,7 +133,12 @@ def pre_reduce(self, x, v, z, pos, batch):
 
 class DipoleMoment(Scalar):
     def __init__(
-        self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float
+        self,
+        hidden_channels,
+        activation="silu",
+        reduce_op="sum",
+        dtype=torch.float,
+        **kwargs
     ):
         super(DipoleMoment, self).__init__(
             hidden_channels,
@@ -133,6 +146,7 @@ def __init__(
             allow_prior_model=False,
             reduce_op=reduce_op,
             dtype=dtype,
+            **kwargs
         )
         atomic_mass = torch.from_numpy(atomic_masses).to(dtype)
         self.register_buffer("atomic_mass", atomic_mass)
@@ -152,7 +166,12 @@ def post_reduce(self, x):
 
 class EquivariantDipoleMoment(EquivariantScalar):
     def __init__(
-        self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float
+        self,
+        hidden_channels,
+        activation="silu",
+        reduce_op="sum",
+        dtype=torch.float,
+        **kwargs
     ):
         super(EquivariantDipoleMoment, self).__init__(
             hidden_channels,
@@ -160,6 +179,7 @@ def __init__(
             allow_prior_model=False,
             reduce_op=reduce_op,
             dtype=dtype,
+            **kwargs
         )
         atomic_mass = torch.from_numpy(atomic_masses).to(dtype)
         self.register_buffer("atomic_mass", atomic_mass)
@@ -180,16 +200,23 @@ def post_reduce(self, x):
 
 class ElectronicSpatialExtent(OutputModel):
     def __init__(
-        self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float
+        self,
+        hidden_channels,
+        activation="silu",
+        reduce_op="sum",
+        dtype=torch.float,
+        **kwargs
     ):
         super(ElectronicSpatialExtent, self).__init__(
             allow_prior_model=False, reduce_op=reduce_op
         )
-        act_class = act_class_mapping[activation]
-        self.output_network = nn.Sequential(
-            nn.Linear(hidden_channels, hidden_channels // 2, dtype=dtype),
-            act_class(),
-            nn.Linear(hidden_channels // 2, 1, dtype=dtype),
+        self.output_network = MLP(
+            in_channels=hidden_channels,
+            out_channels=1,
+            hidden_channels=hidden_channels // 2,
+            activation=activation,
+            num_hidden_layers=kwargs.get("num_layers", 0),
+            dtype=dtype,
         )
         atomic_mass = torch.from_numpy(atomic_masses).to(dtype)
         self.register_buffer("atomic_mass", atomic_mass)
@@ -197,10 +224,7 @@ def __init__(
         self.reset_parameters()
 
     def reset_parameters(self):
-        nn.init.xavier_uniform_(self.output_network[0].weight)
-        self.output_network[0].bias.data.fill_(0)
-        nn.init.xavier_uniform_(self.output_network[2].weight)
-        self.output_network[2].bias.data.fill_(0)
+        self.output_network.reset_parameters()
 
     def pre_reduce(self, x, v: Optional[torch.Tensor], z, pos, batch):
         x = self.output_network(x)
@@ -219,7 +243,12 @@ class EquivariantElectronicSpatialExtent(ElectronicSpatialExtent):
 
 class EquivariantVectorOutput(EquivariantScalar):
     def __init__(
-        self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float
+        self,
+        hidden_channels,
+        activation="silu",
+        reduce_op="sum",
+        dtype=torch.float,
+        **kwargs
     ):
         super(EquivariantVectorOutput, self).__init__(
             hidden_channels,
@@ -227,6 +256,7 @@ def __init__(
             allow_prior_model=False,
             reduce_op="sum",
             dtype=dtype,
+            **kwargs
         )
 
     def pre_reduce(self, x, v, z, pos, batch):
diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py
index fde3d3098..c94d4188f 100644
--- a/torchmdnet/models/utils.py
+++ b/torchmdnet/models/utils.py
@@ -434,6 +434,58 @@ def forward(self, distances: Tensor) -> Tensor:
         return cutoffs
 
 
+class MLP(nn.Module):
+    """A simple multi-layer perceptron with a given number of layers and hidden channels.
+
+    The simplest MLP has no hidden layers and is composed of two linear layers with a non-linear activation function in between:
+
+    .. math::
+
+        \text{MLP}(x) = \text{Linear}_o(\text{act}(\text{Linear}_i(x)))
+
+    Where :math:`\text{Linear}_i` has input size :math:`\text{in_channels}` and output size :math:`\text{hidden_channels}` and :math:`\text{Linear}_o` has input size :math:`\text{hidden_channels}` and output size :math:`\text{out_channels}`.
+
+
+    Args:
+        in_channels (int): Number of input features.
+        out_channels (int): Number of output features.
+        hidden_channels (int): Number of hidden features.
+        activation (str): Activation function to use.
+        num_hidden_layers (int, optional): Number of hidden layers. Defaults to 0.
+        dtype (torch.dtype, optional): Data type to use. Defaults to torch.float32.
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        hidden_channels,
+        activation,
+        num_hidden_layers=0,
+        dtype=torch.float32,
+    ):
+        super(MLP, self).__init__()
+        act_class = act_class_mapping[activation]
+        self.act = act_class()
+        self.layers = nn.Sequential()
+        self.layers.append(nn.Linear(in_channels, hidden_channels, dtype=dtype))
+        self.layers.append(self.act)
+        for _ in range(num_hidden_layers):
+            self.layers.append(nn.Linear(hidden_channels, hidden_channels, dtype=dtype))
+            self.layers.append(self.act)
+        self.layers.append(nn.Linear(hidden_channels, out_channels, dtype=dtype))
+
+    def reset_parameters(self):
+        for layer in self.layers:
+            if isinstance(layer, nn.Linear):
+                nn.init.xavier_uniform_(layer.weight)
+                layer.bias.data.fill_(0)
+
+    def forward(self, x):
+        x = self.layers(x)
+        return x
+
+
 class GatedEquivariantBlock(nn.Module):
     """Gated Equivariant Block as defined in Schütt et al. (2021):
     Equivariant message passing for the prediction of tensorial properties and molecular spectra
@@ -462,21 +514,20 @@ def __init__(
         )
 
         act_class = act_class_mapping[activation]
-        self.update_net = nn.Sequential(
-            nn.Linear(hidden_channels * 2, intermediate_channels, dtype=dtype),
-            act_class(),
-            nn.Linear(intermediate_channels, out_channels * 2, dtype=dtype),
+        self.update_net = MLP(
+            in_channels=hidden_channels * 2,
+            out_channels=out_channels * 2,
+            hidden_channels=intermediate_channels,
+            activation=activation,
+            num_hidden_layers=0,
+            dtype=dtype,
         )
-
         self.act = act_class() if scalar_activation else None
 
     def reset_parameters(self):
         nn.init.xavier_uniform_(self.vec1_proj.weight)
         nn.init.xavier_uniform_(self.vec2_proj.weight)
-        nn.init.xavier_uniform_(self.update_net[0].weight)
-        self.update_net[0].bias.data.fill_(0)
-        nn.init.xavier_uniform_(self.update_net[2].weight)
-        self.update_net[2].bias.data.fill_(0)
+        self.update_net.reset_parameters()
 
     def forward(self, x, v):
         vec1_buffer = self.vec1_proj(v)
diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py
index a51cfe45f..2e69212b4 100644
--- a/torchmdnet/scripts/train.py
+++ b/torchmdnet/scripts/train.py
@@ -74,6 +74,7 @@ def get_argparse():
     # model architecture
     parser.add_argument('--model', type=str, default='graph-network', choices=models.__all_models__, help='Which model to train')
     parser.add_argument('--output-model', type=str, default='Scalar', choices=output_modules.__all__, help='The type of output model')
+    parser.add_argument('--output-mlp-num-layers', type=int, default=0, help='If the output model uses an MLP this will be the number of hidden layers, excluding the input and output layers.')
    parser.add_argument('--prior-model', type=str, default=None, help='Which prior model to use. It can be a string, a dict if you want to add arguments for it or a dicts to add more than one prior. e.g. {"Atomref": {"max_z":100}, "Coulomb":{"max_num_neighs"=100, "lower_switch_distance"=4, "upper_switch_distance"=8}', action="extend", nargs="*")
 
     # architectural args
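
Usage note (illustrative sketch, not part of the patch): the snippet below exercises the
MLP module added to torchmdnet/models/utils.py above, assuming torchmd-net is installed
with this change. With num_hidden_layers=0 the module reproduces the previous two-layer
output head (Linear -> activation -> Linear); each additional hidden layer appends one
more Linear(hidden_channels, hidden_channels) plus activation before the final projection.

    import torch
    from torchmdnet.models.utils import MLP  # class added by this patch

    # Default depth: Linear(128, 64) -> SiLU -> Linear(64, 1), the same shape as the
    # nn.Sequential head it replaces.
    head = MLP(in_channels=128, out_channels=1, hidden_channels=64,
               activation="silu", num_hidden_layers=0)
    assert sum(isinstance(m, torch.nn.Linear) for m in head.layers) == 2

    # Two extra hidden layers: Linear(64, 64) -> SiLU inserted twice before the output.
    deep_head = MLP(in_channels=128, out_channels=1, hidden_channels=64,
                    activation="silu", num_hidden_layers=2)
    assert sum(isinstance(m, torch.nn.Linear) for m in deep_head.layers) == 4

    x = torch.randn(10, 128)
    print(head(x).shape, deep_head(x).shape)  # torch.Size([10, 1]) for both

From the training script the depth is exposed as --output-mlp-num-layers (and, presumably,
as the matching output_mlp_num_layers key in a YAML configuration); per its help text it
sets the number of hidden layers of the output MLP, excluding the input and output layers,
and create_model forwards it to the output model constructor. EquivariantScalar warns that
the value is not used there, as shown in the hunks above.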