Skip to content

Commit

Permalink
Allow to configure the depth of the MLP in output modules (#314)
Browse files Browse the repository at this point in the history
* Added an MLP module.
Allow number of hidden MLP layers in Scalar OutputModel to be
configured from the yaml input.

* Make state dictionary compatible with previous checkpoints

* Change name to num_hidden_layers

* Update doc
  • Loading branch information
RaulPPelaez committed Apr 5, 2024
1 parent 74702da commit 6694816
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 33 deletions.
8 changes: 8 additions & 0 deletions torchmdnet/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
activation=args["activation"],
reduce_op=args["reduce_op"],
dtype=dtype,
num_hidden_layers=args.get("output_mlp_num_layers", 0),
)

# combine representation and output network
Expand Down Expand Up @@ -232,6 +233,13 @@ def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs):
model.prior_model[-1].enable = True

state_dict = {re.sub(r"^model\.", "", k): v for k, v in ckpt["state_dict"].items()}
# In ET, before we had output_model.output_network.{0,1}.update_net.[0-9].{weight,bias}
# Now we have output_model.output_network.{0,1}.update_net.layers.[0-9].{weight,bias}
# This change was introduced in https://github.com/torchmd/torchmd-net/pull/314
state_dict = {
re.sub(r"update_net\.(\d+)\.", r"update_net.layers.\1.", k): v
for k, v in state_dict.items()
}
model.load_state_dict(state_dict)
return model.to(device)

Expand Down
78 changes: 54 additions & 24 deletions torchmdnet/models/output_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
from typing import Optional
import torch
from torch import nn
from torchmdnet.models.utils import act_class_mapping, GatedEquivariantBlock, scatter
from torchmdnet.models.utils import (
act_class_mapping,
GatedEquivariantBlock,
scatter,
MLP,
)
from torchmdnet.utils import atomic_masses
from torchmdnet.extensions import is_current_stream_capturing
from warnings import warn
Expand All @@ -20,6 +25,7 @@ class OutputModel(nn.Module, metaclass=ABCMeta):
Derive this class to make custom output models.
As an example, have a look at the :py:mod:`torchmdnet.output_modules.Scalar` output model.
"""

def __init__(self, allow_prior_model, reduce_op):
super(OutputModel, self).__init__()
self.allow_prior_model = allow_prior_model
Expand Down Expand Up @@ -60,24 +66,23 @@ def __init__(
allow_prior_model=True,
reduce_op="sum",
dtype=torch.float,
**kwargs
):
super(Scalar, self).__init__(
allow_prior_model=allow_prior_model, reduce_op=reduce_op
)
act_class = act_class_mapping[activation]
self.output_network = nn.Sequential(
nn.Linear(hidden_channels, hidden_channels // 2, dtype=dtype),
act_class(),
nn.Linear(hidden_channels // 2, 1, dtype=dtype),
self.output_network = MLP(
in_channels=hidden_channels,
out_channels=1,
hidden_channels=hidden_channels // 2,
activation=activation,
num_hidden_layers=kwargs.get("num_layers", 0),
dtype=dtype,
)

self.reset_parameters()

def reset_parameters(self):
nn.init.xavier_uniform_(self.output_network[0].weight)
self.output_network[0].bias.data.fill_(0)
nn.init.xavier_uniform_(self.output_network[2].weight)
self.output_network[2].bias.data.fill_(0)
self.output_network.reset_parameters()

def pre_reduce(self, x, v: Optional[torch.Tensor], z, pos, batch):
return self.output_network(x)
Expand All @@ -91,10 +96,13 @@ def __init__(
allow_prior_model=True,
reduce_op="sum",
dtype=torch.float,
**kwargs
):
super(EquivariantScalar, self).__init__(
allow_prior_model=allow_prior_model, reduce_op=reduce_op
)
if kwargs.get("num_layers", 0) > 0:
warn("num_layers is not used in EquivariantScalar")
self.output_network = nn.ModuleList(
[
GatedEquivariantBlock(
Expand Down Expand Up @@ -125,14 +133,20 @@ def pre_reduce(self, x, v, z, pos, batch):

class DipoleMoment(Scalar):
def __init__(
self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float
self,
hidden_channels,
activation="silu",
reduce_op="sum",
dtype=torch.float,
**kwargs
):
super(DipoleMoment, self).__init__(
hidden_channels,
activation,
allow_prior_model=False,
reduce_op=reduce_op,
dtype=dtype,
**kwargs
)
atomic_mass = torch.from_numpy(atomic_masses).to(dtype)
self.register_buffer("atomic_mass", atomic_mass)
Expand All @@ -152,14 +166,20 @@ def post_reduce(self, x):

class EquivariantDipoleMoment(EquivariantScalar):
def __init__(
self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float
self,
hidden_channels,
activation="silu",
reduce_op="sum",
dtype=torch.float,
**kwargs
):
super(EquivariantDipoleMoment, self).__init__(
hidden_channels,
activation,
allow_prior_model=False,
reduce_op=reduce_op,
dtype=dtype,
**kwargs
)
atomic_mass = torch.from_numpy(atomic_masses).to(dtype)
self.register_buffer("atomic_mass", atomic_mass)
Expand All @@ -180,27 +200,31 @@ def post_reduce(self, x):

class ElectronicSpatialExtent(OutputModel):
def __init__(
self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float
self,
hidden_channels,
activation="silu",
reduce_op="sum",
dtype=torch.float,
**kwargs
):
super(ElectronicSpatialExtent, self).__init__(
allow_prior_model=False, reduce_op=reduce_op
)
act_class = act_class_mapping[activation]
self.output_network = nn.Sequential(
nn.Linear(hidden_channels, hidden_channels // 2, dtype=dtype),
act_class(),
nn.Linear(hidden_channels // 2, 1, dtype=dtype),
self.output_network = MLP(
in_channels=hidden_channels,
out_channels=1,
hidden_channels=hidden_channels // 2,
activation=activation,
num_hidden_layers=kwargs.get("num_layers", 0),
dtype=dtype,
)
atomic_mass = torch.from_numpy(atomic_masses).to(dtype)
self.register_buffer("atomic_mass", atomic_mass)

self.reset_parameters()

def reset_parameters(self):
nn.init.xavier_uniform_(self.output_network[0].weight)
self.output_network[0].bias.data.fill_(0)
nn.init.xavier_uniform_(self.output_network[2].weight)
self.output_network[2].bias.data.fill_(0)
self.output_network.reset_parameters()

def pre_reduce(self, x, v: Optional[torch.Tensor], z, pos, batch):
x = self.output_network(x)
Expand All @@ -219,14 +243,20 @@ class EquivariantElectronicSpatialExtent(ElectronicSpatialExtent):

class EquivariantVectorOutput(EquivariantScalar):
def __init__(
self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float
self,
hidden_channels,
activation="silu",
reduce_op="sum",
dtype=torch.float,
**kwargs
):
super(EquivariantVectorOutput, self).__init__(
hidden_channels,
activation,
allow_prior_model=False,
reduce_op="sum",
dtype=dtype,
**kwargs
)

def pre_reduce(self, x, v, z, pos, batch):
Expand Down
69 changes: 60 additions & 9 deletions torchmdnet/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,58 @@ def forward(self, distances: Tensor) -> Tensor:
return cutoffs


class MLP(nn.Module):
"""A simple multi-layer perceptron with a given number of layers and hidden channels.
The simplest MLP has no hidden layers and is composed of two linear layers with a non-linear activation function in between:
.. math::
\text{MLP}(x) = \text{Linear}_o(\text{act}(\text{Linear}_i(x)))
Where :math:`\text{Linear}_i` has input size :math:`\text{in_channels}` and output size :math:`\text{hidden_channels}` and :math:`\text{Linear}_o` has input size :math:`\text{hidden_channels}` and output size :math:`\text{out_channels}`.
Args:
in_channels (int): Number of input features.
out_channels (int): Number of output features.
hidden_channels (int): Number of hidden features.
activation (str): Activation function to use.
num_hidden_layers (int, optional): Number of hidden layers. Defaults to 0.
dtype (torch.dtype, optional): Data type to use. Defaults to torch.float32.
"""

def __init__(
self,
in_channels,
out_channels,
hidden_channels,
activation,
num_hidden_layers=0,
dtype=torch.float32,
):
super(MLP, self).__init__()
act_class = act_class_mapping[activation]
self.act = act_class()
self.layers = nn.Sequential()
self.layers.append(nn.Linear(in_channels, hidden_channels, dtype=dtype))
self.layers.append(self.act)
for _ in range(num_hidden_layers):
self.layers.append(nn.Linear(hidden_channels, hidden_channels, dtype=dtype))
self.layers.append(self.act)
self.layers.append(nn.Linear(hidden_channels, out_channels, dtype=dtype))

def reset_parameters(self):
for layer in self.layers:
if isinstance(layer, nn.Linear):
nn.init.xavier_uniform_(layer.weight)
layer.bias.data.fill_(0)

def forward(self, x):
x = self.layers(x)
return x


class GatedEquivariantBlock(nn.Module):
"""Gated Equivariant Block as defined in Schütt et al. (2021):
Equivariant message passing for the prediction of tensorial properties and molecular spectra
Expand Down Expand Up @@ -462,21 +514,20 @@ def __init__(
)

act_class = act_class_mapping[activation]
self.update_net = nn.Sequential(
nn.Linear(hidden_channels * 2, intermediate_channels, dtype=dtype),
act_class(),
nn.Linear(intermediate_channels, out_channels * 2, dtype=dtype),
self.update_net = MLP(
in_channels=hidden_channels * 2,
out_channels=out_channels * 2,
hidden_channels=intermediate_channels,
activation=activation,
num_hidden_layers=0,
dtype=dtype,
)

self.act = act_class() if scalar_activation else None

def reset_parameters(self):
nn.init.xavier_uniform_(self.vec1_proj.weight)
nn.init.xavier_uniform_(self.vec2_proj.weight)
nn.init.xavier_uniform_(self.update_net[0].weight)
self.update_net[0].bias.data.fill_(0)
nn.init.xavier_uniform_(self.update_net[2].weight)
self.update_net[2].bias.data.fill_(0)
self.update_net.reset_parameters()

def forward(self, x, v):
vec1_buffer = self.vec1_proj(v)
Expand Down
1 change: 1 addition & 0 deletions torchmdnet/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def get_argparse():
# model architecture
parser.add_argument('--model', type=str, default='graph-network', choices=models.__all_models__, help='Which model to train')
parser.add_argument('--output-model', type=str, default='Scalar', choices=output_modules.__all__, help='The type of output model')
parser.add_argument('--output-mlp-num-layers', type=int, default=0, help='If the output model uses an MLP this will be the number of hidden layers, excluding the input and output layers.')
parser.add_argument('--prior-model', type=str, default=None, help='Which prior model to use. It can be a string, a dict if you want to add arguments for it or a dicts to add more than one prior. e.g. {"Atomref": {"max_z":100}, "Coulomb":{"max_num_neighs"=100, "lower_switch_distance"=4, "upper_switch_distance"=8}', action="extend", nargs="*")

# architectural args
Expand Down

0 comments on commit 6694816

Please sign in to comment.