
## Using pre-trained models

If a model's parameters and optimized weights are available, the model can be used to make predictions or be trained further (transfer learning).

Here, we will use the graph neural network from a published [article](https://pubs.rsc.org/en/Content/ArticleLanding/2022/SC/D2SC04676H). We will predict the lowest singlet excitation energy, a molecular property that could correspond to absorption maxima.

First, we will download the model.



In [1]:
# download the parameters and weights
! wget https://data.materialsdatafacility.org/mdf_open/ocelotml_2d_v1.2/s0t1_3gen/best_r2.pt
! wget https://data.materialsdatafacility.org/mdf_open/ocelotml_2d_v1.2/s0t1_3gen/params.json

--2024-01-09 17:14:33--  https://data.materialsdatafacility.org/mdf_open/ocelotml_2d_v1.2/s0t1_3gen/best_r2.pt
Resolving data.materialsdatafacility.org (data.materialsdatafacility.org)... 141.142.218.119
Connecting to data.materialsdatafacility.org (data.materialsdatafacility.org)|141.142.218.119|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified
Saving to: ‘best_r2.pt’

best_r2.pt              [           <=>      ]  45.22M  21.5MB/s    in 2.1s    

2024-01-09 17:14:36 (21.5 MB/s) - ‘best_r2.pt’ saved [47415343]

--2024-01-09 17:14:36--  https://data.materialsdatafacility.org/mdf_open/ocelotml_2d_v1.2/s0t1_3gen/params.json
Resolving data.materialsdatafacility.org (data.materialsdatafacility.org)... 141.142.218.119
Connecting to data.materialsdatafacility.org (data.materialsdatafacility.org)|141.142.218.119|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [application/json]
Saving to: ‘params.json’

params.json    

In [2]:
# install the packages
! pip install dgl
! pip install dgllife
! pip install rdkit

Collecting dgl
  Downloading dgl-1.1.3-cp310-cp310-manylinux1_x86_64.whl (6.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.5/6.5 MB[0m [31m13.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: dgl
Successfully installed dgl-1.1.3
Collecting dgllife
  Downloading dgllife-0.3.2-py3-none-any.whl (226 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m226.1/226.1 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: dgllife
Successfully installed dgllife-0.3.2
Collecting rdkit
  Downloading rdkit-2023.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (34.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m34.4/34.4 MB[0m [31m16.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: rdkit
Successfully installed rdkit-2023.9.4


This is the code for the message-passing neural network (MPNN) used in the article.

In [3]:
import torch
import torch.nn as nn
from dgl.nn.pytorch import Set2Set
from dgllife.model.gnn import MPNNGNN


class MPNN_readout(nn.Module):

    def __init__(self,
                 node_in_feats,
                 edge_in_feats,
                 node_out_feats=64,
                 edge_hidden_feats=128,
                 n_tasks=1,
                 num_step_message_passing=6,
                 num_step_set2set=6,
                 dropout=0,
                 num_layer_set2set=3, descriptor_feats=0):
        super(MPNN_readout, self).__init__()

        self.gnn = MPNNGNN(node_in_feats=node_in_feats,
                           node_out_feats=node_out_feats,
                           edge_in_feats=edge_in_feats,
                           edge_hidden_feats=edge_hidden_feats,
                           num_step_message_passing=num_step_message_passing)
        self.readout = Set2Set(input_dim=node_out_feats,
                               n_iters=num_step_set2set,
                               n_layers=num_layer_set2set)
        self.predict = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(2 * node_out_feats + descriptor_feats, node_out_feats),
            nn.ReLU(),
            nn.BatchNorm1d(node_out_feats),
            nn.Linear(node_out_feats, n_tasks)
        )

    def forward(self, g, node_feats, edge_feats, concat_feats=None):
        node_feats = self.gnn(g, node_feats, edge_feats)
        graph_feats = self.readout(g, node_feats)
        if concat_feats is not None:
            final_feats = torch.cat((graph_feats, concat_feats), dim=1)
        else:
            final_feats = graph_feats
        return self.predict(final_feats)

DGL backend not selected or invalid.  Assuming PyTorch for now.


Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)


Get the model parameters from the `params.json` file.

In [4]:
import json

with open("params.json") as f:
  params = json.load(f)

params

{'dropout': 0.7512145577910623,
 'edge_hidden_feats': 195,
 'node_out_feats': 234,
 'num_layer_set2set': 1,
 'num_step_message_passing': 7,
 'num_step_set2set': 4}

We also need to add `node_in_feats`, the length of the atom feature vector produced by `CanonicalAtomFeaturizer` (74), and `edge_in_feats`, the length of the bond feature vector produced by `CanonicalBondFeaturizer` (12).

In [5]:
params.update(
    {
      "node_in_feats" : 74,
      "edge_in_feats": 12
    }
)
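Because the dictionary is later unpacked with `**params`, its keys have to match the constructor's argument names exactly. The sketch below shows one way to check this before building the model. It uses a hypothetical stand-in class, `ToyModel`, and two helper functions that are not part of the notebook or of any library; the same helpers could be pointed at `MPNN_readout`.

```python
import inspect


def unexpected_keys(cls, kwargs):
    """Return keys in kwargs that cls.__init__ does not accept."""
    accepted = set(inspect.signature(cls.__init__).parameters) - {"self"}
    return sorted(set(kwargs) - accepted)


def missing_required(cls, kwargs):
    """Return required __init__ arguments that kwargs does not supply."""
    sig = inspect.signature(cls.__init__)
    return sorted(
        name
        for name, p in sig.parameters.items()
        if name != "self"
        and p.default is inspect.Parameter.empty
        and name not in kwargs
    )


class ToyModel:
    # Hypothetical stand-in: two required arguments, like MPNN_readout
    def __init__(self, node_in_feats, edge_in_feats, node_out_feats=64, dropout=0):
        pass


loaded = {"node_out_feats": 234, "dropout": 0.75}   # as if read from a JSON file
print(missing_required(ToyModel, loaded))   # ['edge_in_feats', 'node_in_feats']

loaded.update({"node_in_feats": 74, "edge_in_feats": 12})
print(missing_required(ToyModel, loaded))   # []
print(unexpected_keys(ToyModel, loaded))    # []
```

A check like this turns a `TypeError` raised deep inside the constructor call into an explicit list of the offending keys.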

Let's create the model object

In [6]:
model = MPNN_readout(**params)
model

MPNN_readout(
  (gnn): MPNNGNN(
    (project_node_feats): Sequential(
      (0): Linear(in_features=74, out_features=234, bias=True)
      (1): ReLU()
    )
    (gnn_layer): NNConv(
      (edge_func): Sequential(
        (0): Linear(in_features=12, out_features=195, bias=True)
        (1): ReLU()
        (2): Linear(in_features=195, out_features=54756, bias=True)
      )
    )
    (gru): GRU(234, 234)
  )
  (readout): Set2Set(
    n_iters=4
    (lstm): LSTM(468, 234)
  )
  (predict): Sequential(
    (0): Dropout(p=0.7512145577910623, inplace=False)
    (1): Linear(in_features=468, out_features=234, bias=True)
    (2): ReLU()
    (3): BatchNorm1d(234, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): Linear(in_features=234, out_features=1, bias=True)
  )
)
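Several sizes in this printout are derived from the parameters rather than set directly. The short calculation below reproduces them, assuming (as the printout suggests) that the edge network outputs one `node_out_feats` x `node_out_feats` matrix per edge and that `Set2Set` doubles the node feature size.

```python
# Values taken from params.json and the featurizer sizes above
node_out_feats = 234
edge_hidden_feats = 195
descriptor_feats = 0   # default in MPNN_readout

# Last layer of the edge network: one (node_out_feats x node_out_feats) matrix per edge
edge_net_out = node_out_feats * node_out_feats
print(edge_net_out)        # 54756

# Set2Set concatenates two vectors of size node_out_feats
set2set_out = 2 * node_out_feats
print(set2set_out)         # 468

# Input size of the first Linear layer in `predict`
predict_in = 2 * node_out_feats + descriptor_feats
print(predict_in)          # 468

# Weights plus biases in the last edge-network layer alone
edge_last_layer_params = edge_hidden_feats * edge_net_out + edge_net_out
print(edge_last_layer_params)   # 10732176
```

The last number shows that most of the model's parameters sit in the edge network, which is worth keeping in mind when deciding which layers to freeze later.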

The model architecture is now set, but its weights are still randomly initialized. Let's load the optimized weights from the pre-trained model.

In [7]:
model.load_state_dict(torch.load("best_r2.pt", map_location=torch.device('cpu')))

<All keys matched successfully>

Create the graph that serves as input to the model.

In [8]:
# import from rdkit and dgl-lifesci
from rdkit import Chem
from dgllife.utils import CanonicalAtomFeaturizer, CanonicalBondFeaturizer, \
mol_to_bigraph

# create the atom and bond featurizer object
atom_featurizer = CanonicalAtomFeaturizer(atom_data_field="hv")
bond_featurizer = CanonicalBondFeaturizer(bond_data_field="he")

# example smiles - ethane
smiles = "CC"

# mol_to_bigraph requires the RDKit molecule and featurizers
mol = Chem.MolFromSmiles(smiles)
graph = mol_to_bigraph(mol, node_featurizer=atom_featurizer,
                       edge_featurizer=bond_featurizer)

# display the graph object
graph

Graph(num_nodes=2, num_edges=2,
      ndata_schemes={'hv': Scheme(shape=(74,), dtype=torch.float32)}
      edata_schemes={'he': Scheme(shape=(12,), dtype=torch.float32)})
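The printout shows two nodes and two edges for ethane. RDKit keeps hydrogens implicit when parsing SMILES, so only the two carbon atoms become nodes, and a bidirected graph stores each bond as two directed edges. The pure-Python sketch below illustrates that convention; `bonds_to_directed_edges` is a hypothetical helper written for this illustration, not a DGL-LifeSci function.

```python
def bonds_to_directed_edges(bonds):
    """Turn undirected bonds (u, v) into source/destination lists
    containing both directions, u -> v and v -> u."""
    src, dst = [], []
    for u, v in bonds:
        src.extend([u, v])
        dst.extend([v, u])
    return src, dst


# Ethane "CC": two heavy atoms, one C-C bond
ethane_src, ethane_dst = bonds_to_directed_edges([(0, 1)])
print(ethane_src, ethane_dst)   # [0, 1] [1, 0]
print(len(ethane_src))          # 2 directed edges

# Propane "CCC": three heavy atoms, two C-C bonds
propane_src, propane_dst = bonds_to_directed_edges([(0, 1), (1, 2)])
print(len(propane_src))         # 4 directed edges
```

Storing both directions lets messages flow from each atom to its neighbor and back during message passing.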

Make predictions

In [9]:
model.eval()
node_feats = graph.ndata["hv"]
edge_feats = graph.edata["he"]
model(graph, node_feats, edge_feats)

tensor([[4.2528]], grad_fn=<AddmmBackward0>)
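The `grad_fn` in the output shows that PyTorch still tracked gradients for this prediction. For inference only, a common pattern is to combine `model.eval()` with `torch.no_grad()`. The sketch below demonstrates this on a small hypothetical stand-in network (`toy`) whose layers mirror the `predict` block; it is not the pre-trained model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in with the same layer types as the `predict` block
toy = nn.Sequential(
    nn.Dropout(p=0.5),
    nn.Linear(4, 3),
    nn.ReLU(),
    nn.BatchNorm1d(3),
    nn.Linear(3, 1),
)

# eval() disables dropout and makes BatchNorm use its running statistics
toy.eval()

x = torch.randn(1, 4)   # a single sample, like a single molecule

# no_grad() skips building the autograd graph, saving memory and time
with torch.no_grad():
    y1 = toy(x)
    y2 = toy(x)

print(y1.requires_grad)        # False
print(torch.equal(y1, y2))     # True: predictions are deterministic in eval mode
print(y1.item())               # plain Python float
```

Calling `model.eval()` matters here for two reasons: dropout would otherwise randomly zero features on every call, and `BatchNorm1d` in training mode cannot compute batch statistics from a single sample.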

## Saving a trained PyTorch model

Use the `torch.save` function, passing in the model's `state_dict` and a file name.

In [10]:
torch.save(model.state_dict(), "my_model.pt")
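A saved `state_dict` contains only the weights, so loading it back requires first rebuilding a model with the same architecture. The round trip can be sketched as follows, using a small hypothetical network (`make_model`) and a temporary file instead of the pre-trained model.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model():
    # Hypothetical small network; any nn.Module works the same way
    return nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1))


torch.manual_seed(0)
trained = make_model()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_model.pt")

    # Save only the weights (the state_dict), not the model object
    torch.save(trained.state_dict(), path)

    # Rebuild the same architecture, then load the weights into it
    restored = make_model()
    result = restored.load_state_dict(torch.load(path, map_location="cpu"))

print(result)   # <All keys matched successfully>

# Every tensor in the restored model equals the original one
same = all(
    torch.equal(a, b)
    for a, b in zip(trained.state_dict().values(), restored.state_dict().values())
)
print(same)     # True
```

For the model in this notebook, that means keeping `params.json` (plus the two featurizer sizes) alongside the `.pt` file, since the weights alone do not describe the architecture.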

## Transfer learning

When data for one prediction task is limited, a model trained on a related task with abundant data can often serve as a starting point and yield higher accuracy than training from scratch.

Let's say there is a model trained on the QM9 dataset to predict the HOMO-LUMO gap. If your task is now to predict HOMO energies, there is no need to train a model from scratch. You can reuse the optimized weights of the HOMO-LUMO gap predictor and keep them fixed for the inner layers. Only the weights of the final layers then need to be optimized for the HOMO energy prediction model. This process is called transfer learning.

Here, we will freeze the `gnn` and `readout` layers and allow the weights on the `predict` layer to be trainable.

In [11]:
model

MPNN_readout(
  (gnn): MPNNGNN(
    (project_node_feats): Sequential(
      (0): Linear(in_features=74, out_features=234, bias=True)
      (1): ReLU()
    )
    (gnn_layer): NNConv(
      (edge_func): Sequential(
        (0): Linear(in_features=12, out_features=195, bias=True)
        (1): ReLU()
        (2): Linear(in_features=195, out_features=54756, bias=True)
      )
    )
    (gru): GRU(234, 234)
  )
  (readout): Set2Set(
    n_iters=4
    (lstm): LSTM(468, 234)
  )
  (predict): Sequential(
    (0): Dropout(p=0.7512145577910623, inplace=False)
    (1): Linear(in_features=468, out_features=234, bias=True)
    (2): ReLU()
    (3): BatchNorm1d(234, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): Linear(in_features=234, out_features=1, bias=True)
  )
)

Let's look at the parameters (weights).

In [12]:
for param in model.parameters():
  print(param)

Parameter containing:
tensor([[-0.0500, -0.0636,  0.1109,  ..., -0.0038, -0.1115,  0.0940],
        [-0.0169, -0.0582,  0.0980,  ..., -0.0561,  0.0750, -0.0510],
        [ 0.0817, -0.1027, -0.0433,  ..., -0.1036, -0.0646,  0.0198],
        ...,
        [ 0.0446,  0.0874, -0.0721,  ..., -0.1003, -0.0632, -0.0235],
        [-0.0862, -0.1119, -0.0681,  ...,  0.0903,  0.0373,  0.0556],
        [-0.1281, -0.1137, -0.1008,  ..., -0.0791,  0.0059,  0.0099]],
       requires_grad=True)
Parameter containing:
tensor([-0.0618,  0.0876,  0.0888, -0.0437,  0.0954,  0.0804,  0.0605,  0.0722,
        -0.0998,  0.0821,  0.0808, -0.0087,  0.0784,  0.0878, -0.0134, -0.0404,
         0.0734, -0.1135,  0.0478, -0.0202,  0.0908, -0.0439,  0.1129, -0.0337,
         0.0517,  0.1183,  0.0234, -0.0877,  0.0410, -0.0884, -0.0137,  0.0063,
        -0.0712, -0.0483, -0.0751, -0.0818, -0.0586, -0.0901, -0.0355,  0.0873,
        -0.0924, -0.0659, -0.0328, -0.0723,  0.1102, -0.0954,  0.0691, -0.0870,
        -0.0829

Every parameter has `requires_grad=True`, so all of them would be updated during training. To freeze the weights, we set `requires_grad` to `False`.

In [13]:
for param in model.parameters():
  param.requires_grad = False

All weights are now frozen. If the model were trained in this state, none of its weights would change, so it would not learn anything.

In [14]:
for param in model.parameters():
  print(param)

Parameter containing:
tensor([[-0.0500, -0.0636,  0.1109,  ..., -0.0038, -0.1115,  0.0940],
        [-0.0169, -0.0582,  0.0980,  ..., -0.0561,  0.0750, -0.0510],
        [ 0.0817, -0.1027, -0.0433,  ..., -0.1036, -0.0646,  0.0198],
        ...,
        [ 0.0446,  0.0874, -0.0721,  ..., -0.1003, -0.0632, -0.0235],
        [-0.0862, -0.1119, -0.0681,  ...,  0.0903,  0.0373,  0.0556],
        [-0.1281, -0.1137, -0.1008,  ..., -0.0791,  0.0059,  0.0099]])
Parameter containing:
tensor([-0.0618,  0.0876,  0.0888, -0.0437,  0.0954,  0.0804,  0.0605,  0.0722,
        -0.0998,  0.0821,  0.0808, -0.0087,  0.0784,  0.0878, -0.0134, -0.0404,
         0.0734, -0.1135,  0.0478, -0.0202,  0.0908, -0.0439,  0.1129, -0.0337,
         0.0517,  0.1183,  0.0234, -0.0877,  0.0410, -0.0884, -0.0137,  0.0063,
        -0.0712, -0.0483, -0.0751, -0.0818, -0.0586, -0.0901, -0.0355,  0.0873,
        -0.0924, -0.0659, -0.0328, -0.0723,  0.1102, -0.0954,  0.0691, -0.0870,
        -0.0829, -0.0715,  0.0151, -0.0989

We want the weights of the `predict` layers to be trainable, so let's unfreeze those.

In [15]:
for param in model.predict.parameters():
  param.requires_grad = True
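Printing every tensor makes it hard to see which layers are trainable. A more compact check is to list parameter names by their `requires_grad` flag and to confirm that one optimizer step changes only the unfrozen weights. The sketch below does this on a hypothetical two-part network (`ToyNet`, with a `backbone` standing in for `gnn` and `readout`); the optimizer, learning rate, and random data are illustrative assumptions only.

```python
import torch
import torch.nn as nn


class ToyNet(nn.Module):
    # Hypothetical stand-in: `backbone` plays the role of gnn + readout
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 8)
        self.predict = nn.Sequential(nn.ReLU(), nn.Linear(8, 1))

    def forward(self, x):
        return self.predict(self.backbone(x))


torch.manual_seed(0)
net = ToyNet()

# Freeze everything, then unfreeze only the prediction head
for p in net.parameters():
    p.requires_grad = False
for p in net.predict.parameters():
    p.requires_grad = True

trainable = [name for name, p in net.named_parameters() if p.requires_grad]
print(trainable)   # ['predict.1.weight', 'predict.1.bias']

n_trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in net.parameters())
print(n_trainable, n_total)   # 9 49

# Give the optimizer only the trainable parameters
optimizer = torch.optim.Adam(
    [p for p in net.parameters() if p.requires_grad], lr=1e-2
)

backbone_before = net.backbone.weight.clone()
head_bias_before = net.predict[1].bias.clone()

# One training step on random data
x, y = torch.randn(16, 4), torch.randn(16, 1)
loss = nn.functional.mse_loss(net(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

backbone_unchanged = torch.equal(net.backbone.weight, backbone_before)
head_changed = not torch.equal(net.predict[1].bias, head_bias_before)
print(backbone_unchanged, head_changed)
```

The frozen `backbone` weights receive no gradient and stay identical, while the head is updated. The same pattern applies to the pre-trained model by filtering `model.parameters()` in the same way.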

Check the weights again

In [16]:
for param in model.parameters():
  print(param)

Parameter containing:
tensor([[-0.0500, -0.0636,  0.1109,  ..., -0.0038, -0.1115,  0.0940],
        [-0.0169, -0.0582,  0.0980,  ..., -0.0561,  0.0750, -0.0510],
        [ 0.0817, -0.1027, -0.0433,  ..., -0.1036, -0.0646,  0.0198],
        ...,
        [ 0.0446,  0.0874, -0.0721,  ..., -0.1003, -0.0632, -0.0235],
        [-0.0862, -0.1119, -0.0681,  ...,  0.0903,  0.0373,  0.0556],
        [-0.1281, -0.1137, -0.1008,  ..., -0.0791,  0.0059,  0.0099]])
Parameter containing:
tensor([-0.0618,  0.0876,  0.0888, -0.0437,  0.0954,  0.0804,  0.0605,  0.0722,
        -0.0998,  0.0821,  0.0808, -0.0087,  0.0784,  0.0878, -0.0134, -0.0404,
         0.0734, -0.1135,  0.0478, -0.0202,  0.0908, -0.0439,  0.1129, -0.0337,
         0.0517,  0.1183,  0.0234, -0.0877,  0.0410, -0.0884, -0.0137,  0.0063,
        -0.0712, -0.0483, -0.0751, -0.0818, -0.0586, -0.0901, -0.0355,  0.0873,
        -0.0924, -0.0659, -0.0328, -0.0723,  0.1102, -0.0954,  0.0691, -0.0870,
        -0.0829, -0.0715,  0.0151, -0.0989