<a href="https://colab.research.google.com/github/vinayak2019/ml_for_molecules/blob/main/Pre_trained_models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Using pre-trained models

If a model's hyperparameters and optimized weights are available, the model can be used directly to make predictions or trained further (transfer learning).

Here, we will use the graph neural network from the published [article](https://pubs.rsc.org/en/Content/ArticleLanding/2022/SC/D2SC04676H). We will predict a molecular property, the lowest singlet excitation energy, which could correspond to the absorption maximum.

First, we will download the model.



In [None]:
# download the parameters and weights
! wget https://data.materialsdatafacility.org/mdf_open/ocelotml_2d_v1.2/s0t1_3gen/best_r2.pt
! wget https://data.materialsdatafacility.org/mdf_open/ocelotml_2d_v1.2/s0t1_3gen/params.json

In [None]:
# install the packages
! pip install dgl
! pip install dgllife
! pip install rdkit

This is the code for the MPNN used in the article

In [None]:
import torch
import torch.nn as nn
from dgl.nn.pytorch import Set2Set
from dgllife.model.gnn import MPNNGNN


class MPNN_readout(nn.Module):

    def __init__(self,
                 node_in_feats,
                 edge_in_feats,
                 node_out_feats=64,
                 edge_hidden_feats=128,
                 n_tasks=1,
                 num_step_message_passing=6,
                 num_step_set2set=6,
                 dropout=0,
                 num_layer_set2set=3, descriptor_feats=0):
        super(MPNN_readout, self).__init__()

        self.gnn = MPNNGNN(node_in_feats=node_in_feats,
                           node_out_feats=node_out_feats,
                           edge_in_feats=edge_in_feats,
                           edge_hidden_feats=edge_hidden_feats,
                           num_step_message_passing=num_step_message_passing)
        self.readout = Set2Set(input_dim=node_out_feats,
                               n_iters=num_step_set2set,
                               n_layers=num_layer_set2set)
        self.predict = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(2 * node_out_feats + descriptor_feats, node_out_feats),
            nn.ReLU(),
            nn.BatchNorm1d(node_out_feats),
            nn.Linear(node_out_feats, n_tasks)
        )

    def forward(self, g, node_feats, edge_feats, concat_feats=None):
        node_feats = self.gnn(g, node_feats, edge_feats)
        graph_feats = self.readout(g, node_feats)
        if concat_feats is not None:
            final_feats = torch.cat((graph_feats, concat_feats), dim=1)
        else:
            final_feats = graph_feats
        return self.predict(final_feats)

Getting the model parameters from the params.json file

In [None]:
import json

with open("params.json") as f:
  params = json.load(f)

params

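The file does not include the input feature sizes, so we add them: `node_in_feats` is 74, the length of the feature vector produced by `CanonicalAtomFeaturizer`, and `edge_in_feats` is 12, the length of the feature vector produced by `CanonicalBondFeaturizer`.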

In [None]:
params.update(
    {
      "node_in_feats" : 74,
      "edge_in_feats": 12
    }
)
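Because the dictionary is later unpacked into the constructor with `**params`, a missing or misspelled key only surfaces as a `TypeError` at that point. The sketch below shows one way to check the keys beforehand using the standard-library `inspect` module. `check_kwargs`, `build_model`, and the contents of `loaded` are hypothetical stand-ins for illustration; the real `params.json` may hold different keys and values.

```python
import inspect


def check_kwargs(target, kwargs):
    """Return (missing, unexpected) argument names for calling target(**kwargs)."""
    params = inspect.signature(target).parameters
    named = {
        name: p for name, p in params.items()
        if p.kind in (p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY)
    }
    # Arguments without a default value must be supplied by the caller.
    required = {name for name, p in named.items() if p.default is p.empty}
    missing = sorted(required - set(kwargs))
    unexpected = sorted(set(kwargs) - set(named))
    return missing, unexpected


# Hypothetical stand-in with the same argument names as MPNN_readout.__init__.
def build_model(node_in_feats, edge_in_feats, node_out_feats=64,
                edge_hidden_feats=128, n_tasks=1, num_step_message_passing=6,
                num_step_set2set=6, dropout=0, num_layer_set2set=3,
                descriptor_feats=0):
    return None


# Hypothetical file contents, used only to exercise the check.
loaded = {"node_out_feats": 64, "n_tasks": 1}
before = check_kwargs(build_model, loaded)

loaded.update({"node_in_feats": 74, "edge_in_feats": 12})
after = check_kwargs(build_model, loaded)

print("before update:", before)
print("after update: ", after)
```

Passing the class itself (for example `check_kwargs(MPNN_readout, params)`) should work the same way, since `inspect.signature` on a class reports the constructor arguments without `self`.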

Let's create the model object

In [None]:
model = MPNN_readout(**params)
model

The model now has the right architecture, but its weights are still randomly initialized. Let's load the optimized weights from the pre-trained model.

In [None]:
model.load_state_dict(torch.load("best_r2.pt", map_location=torch.device('cpu')))
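By default `load_state_dict` is strict: it raises an error if the checkpoint keys and the model's parameter names do not match exactly. When debugging a mismatch, passing `strict=False` returns the lists of missing and unexpected keys instead. The sketch below illustrates this on `TinyNet`, a hypothetical stand-in for the pre-trained network, with a deliberately corrupted checkpoint.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the pre-trained network."""

    def __init__(self):
        super().__init__()
        self.body = nn.Linear(4, 8)
        self.head = nn.Linear(8, 1)


source = TinyNet()
state = source.state_dict()

# Simulate a checkpoint saved from a slightly different architecture.
state.pop("head.bias")
state["extra.weight"] = torch.zeros(1)

target = TinyNet()
result = target.load_state_dict(state, strict=False)

print("missing keys:   ", result.missing_keys)
print("unexpected keys:", result.unexpected_keys)
```

Keys that do match are still copied, so `target.body` ends up with the weights of `source.body`. For a checkpoint that is meant to fit the model exactly, the default strict behavior is the safer choice.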

Create the graphs for input to model

In [None]:
# import from rdkit and dgl-lifesci
from rdkit import Chem
from dgllife.utils import (CanonicalAtomFeaturizer, CanonicalBondFeaturizer,
                           mol_to_bigraph)

# create the atom and bond featurizer object
atom_featurizer = CanonicalAtomFeaturizer(atom_data_field="hv")
bond_featurizer = CanonicalBondFeaturizer(bond_data_field="he")

# example smiles - ethane
smiles = "CC"

# mol_to_bigraph requires the RDKit molecule and the featurizers
mol = Chem.MolFromSmiles(smiles)
graph = mol_to_bigraph(mol, node_featurizer=atom_featurizer,
                       edge_featurizer=bond_featurizer)

# display the graph object
graph

Make predictions

In [None]:
model.eval()
node_feats = graph.ndata["hv"]
edge_feats = graph.edata["he"]
model(graph, node_feats, edge_feats)
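Calling `model.eval()` matters here because the prediction head contains `Dropout` and `BatchNorm1d`, which behave differently during training. Wrapping the call in `torch.no_grad()` additionally skips gradient tracking, which is not needed for inference. The sketch below shows both on `head`, a hypothetical stand-in with the same layer types as the `predict` block; the layer sizes and the random input are assumptions for illustration.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in with the same layer types as the `predict` block.
head = nn.Sequential(
    nn.Dropout(p=0.5),
    nn.Linear(8, 4),
    nn.ReLU(),
    nn.BatchNorm1d(4),
    nn.Linear(4, 1),
)

# A batch holding a single "molecule", as in the prediction above.
x = torch.randn(1, 8)

head.eval()                # dropout off, batch norm uses its running statistics
with torch.no_grad():      # no computation graph is built for these calls
    first = head(x)
    second = head(x)

print("tracks gradients:", first.requires_grad)
print("repeatable:      ", torch.equal(first, second))
```

In evaluation mode dropout is inactive, so repeated calls on the same input give the same output.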

## Saving a trained pytorch model

Use the `torch.save` function, passing in the model's `state_dict` and the name of the output file.

In [None]:
torch.save(model.state_dict(), "my_model.pt")
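A saved `state_dict` contains only the weights, so reloading requires building the same architecture first. The sketch below runs a full save and reload round trip in a temporary directory. `make_model` is a hypothetical stand-in for constructing the real network.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model():
    """Hypothetical stand-in for building the real architecture."""
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))


trained = make_model()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "my_model.pt")
    torch.save(trained.state_dict(), path)

    # A fresh model has different random weights until the file is loaded.
    restored = make_model()
    restored.load_state_dict(torch.load(path, map_location="cpu"))

same = all(
    torch.equal(a, b)
    for a, b in zip(trained.state_dict().values(),
                    restored.state_dict().values())
)
print("weights identical after reload:", same)
```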

## Transfer learning

When data for one prediction task is limited, a model trained on a related task with abundant data can serve as the starting point, often giving higher accuracy than training from scratch.

Let's say there is a model that predicts the HOMO-LUMO gap, trained on the QM9 dataset. If your task is now to predict HOMO energies, there is no need to start training from scratch. You can reuse the optimized weights from the HOMO-LUMO gap predictor and keep the inner layers fixed. Only the weights of the final layers then need to be optimized for the HOMO energy prediction model. This process is called transfer learning.

Here, we will freeze the `gnn` and `readout` layers and allow the weights on the `predict` layer to be trainable.

In [None]:
model

Let's look at the parameters (weights).

In [None]:
for param in model.parameters():
  print(param)

Every parameter is printed with `requires_grad=True`, meaning gradients are tracked for it. To freeze the weights, we set `requires_grad` to `False`.

In [None]:
for param in model.parameters():
  param.requires_grad = False

All weights are now frozen. If this model is trained, none of its weights will change, so it cannot learn.

In [None]:
for param in model.parameters():
  print(param)
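Printing full tensors becomes hard to read for larger models. A compact alternative is to list each parameter's name together with its `requires_grad` flag via `named_parameters()`. The sketch below uses `TinyModel`, a hypothetical stand-in that only shares the top-level attribute names (`gnn`, `readout`, `predict`) with the real model; `trainable_summary` is likewise an illustrative helper.

```python
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in sharing only the top-level attribute names."""

    def __init__(self):
        super().__init__()
        self.gnn = nn.Linear(4, 8)
        self.readout = nn.Linear(8, 8)
        self.predict = nn.Linear(8, 1)


def trainable_summary(module):
    """Map each parameter name to (requires_grad, number of elements)."""
    return {
        name: (p.requires_grad, p.numel())
        for name, p in module.named_parameters()
    }


tiny = TinyModel()
for p in tiny.parameters():
    p.requires_grad = False

summary = trainable_summary(tiny)
for name, (flag, count) in summary.items():
    print(f"{name:16s} requires_grad={flag}  numel={count}")
```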

We want the weights of the `predict` layers to be trainable, so we re-enable gradients for those only.

In [None]:
for param in model.predict.parameters():
  param.requires_grad = True

Check the weights again

In [None]:
for param in model.parameters():
  print(param)
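With only the `predict` weights trainable, fine-tuning proceeds like ordinary training. A common pattern is to hand the optimizer only the parameters that still require gradients. The sketch below runs a single training step on `TinyModel`, a hypothetical stand-in with the same top-level attribute names; the random data, the SGD optimizer, the learning rate, and the MSE loss are all assumptions for illustration, not the settings used in the article.

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in sharing only the top-level attribute names."""

    def __init__(self):
        super().__init__()
        self.gnn = nn.Linear(4, 8)
        self.readout = nn.Linear(8, 8)
        self.predict = nn.Linear(8, 1)

    def forward(self, x):
        hidden = torch.relu(self.readout(torch.relu(self.gnn(x))))
        return self.predict(hidden)


torch.manual_seed(0)
tiny = TinyModel()

# Freeze everything, then re-enable gradients for the prediction head.
for p in tiny.parameters():
    p.requires_grad = False
for p in tiny.predict.parameters():
    p.requires_grad = True

frozen_before = tiny.gnn.weight.clone()
head_bias_before = tiny.predict.bias.clone()

# Only the trainable parameters are given to the optimizer.
trainable = [p for p in tiny.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.1)

# Random placeholder data standing in for a new, smaller dataset.
x, y = torch.randn(16, 4), torch.randn(16, 1)
loss = nn.functional.mse_loss(tiny(x), y)

optimizer.zero_grad()
loss.backward()
optimizer.step()

print("gnn unchanged:    ", torch.equal(tiny.gnn.weight, frozen_before))
print("predict unchanged:", torch.equal(tiny.predict.bias, head_bias_before))
```

The frozen layers never receive a gradient, so their `.grad` attribute stays `None` and the optimizer step leaves them untouched.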