In [1]:
from rdkit import Chem
from dgllife.utils import mol_to_bigraph, CanonicalAtomFeaturizer, CanonicalBondFeaturizer

def SmilesToGraph(smi):
    """Turn a SMILES string into a featurized DGL graph.

    The molecule is parsed with RDKit and handed to ``mol_to_bigraph`` together
    with a ``CanonicalAtomFeaturizer`` (node features) and a
    ``CanonicalBondFeaturizer`` (edge features). In this notebook the resulting
    graph carries node data under ``'h'`` (74 values per atom) and edge data
    under ``'e'`` (12 values per bond).

    Parameters
    ----------
    smi : str
        SMILES representation of one molecule.

    Returns
    -------
    The graph object returned by ``mol_to_bigraph``.

    Raises
    ------
    ValueError
        If RDKit cannot parse ``smi`` into a molecule.
    """
    mol = Chem.MolFromSmiles(smi)
    # RDKit reports a parse failure by returning None instead of raising.
    # Stop here with a message that names the offending input, rather than
    # letting a None molecule propagate into graph construction or the model.
    if mol is None:
        raise ValueError("Could not parse SMILES string: {!r}".format(smi))

    graph = mol_to_bigraph(mol,
                           node_featurizer=CanonicalAtomFeaturizer(),
                           edge_featurizer=CanonicalBondFeaturizer())
    return graph

In [7]:
import torch
import torch.nn as nn

import dgl
import dgllife
from dgllife.model import MPNNGNN, WeightedSumAndMax    

class MyGNN(nn.Module):
    """Graph-level predictor: message passing, graph readout, linear head.

    Parameters
    ----------
    node_in_feats : int
        Size of each input node feature vector (default 74, the size of the
        ``'h'`` node features shown for the graphs in this notebook).
    edge_in_feats : int
        Size of each input edge feature vector (default 12, the size of the
        ``'e'`` edge features shown for the graphs in this notebook).
    node_hidden_dim : int
        Passed to ``MPNNGNN`` as ``node_out_feats`` and to ``WeightedSumAndMax``
        as its feature size.
    edge_hidden_feats : int
        Forwarded to ``MPNNGNN``.
    num_step_message_passing : int
        Forwarded to ``MPNNGNN``.
    n_classes : int
        Number of output values produced per graph by the linear head.
    """

    def __init__(self,
                 node_in_feats=74,
                 edge_in_feats=12,
                 node_hidden_dim=64,
                 edge_hidden_feats=32,
                 num_step_message_passing=3,
                 n_classes=100):
        super().__init__()

        # Collect the message-passing settings in one place before building the layer.
        mpnn_settings = dict(
            node_in_feats=node_in_feats,
            node_out_feats=node_hidden_dim,
            edge_in_feats=edge_in_feats,
            edge_hidden_feats=edge_hidden_feats,
            num_step_message_passing=num_step_message_passing,
        )
        self.mpnn = MPNNGNN(**mpnn_settings)

        self.readout = WeightedSumAndMax(node_hidden_dim)

        # The head is sized for a readout vector twice as wide as the hidden
        # node features.
        self.ff = nn.Linear(2 * node_hidden_dim, n_classes)

    def forward(self, graph):
        """Compute per-graph outputs for a single or batched graph.

        Reads node features from ``graph.ndata['h']`` and edge features from
        ``graph.edata['e']``. In this notebook the result has one row per graph
        and ``n_classes`` columns.
        """
        atom_h = graph.ndata['h']
        bond_e = graph.edata['e']

        # Update node representations, pool them per graph, then apply the head.
        atom_h = self.mpnn(graph, atom_h, bond_e)
        pooled = self.readout(graph, atom_h)
        return self.ff(pooled)

In [8]:
import rdkit
from rdkit import Chem
from dgllife.utils import mol_to_bigraph, CanonicalAtomFeaturizer, CanonicalBondFeaturizer

# Featurize one small molecule; `graph` is reused by the forward-pass cell below.
smi = 'CC(O)CN'
graph = SmilesToGraph(smi)
# Bare last expression so the notebook displays the graph summary
# (node/edge counts and the feature schemes stored under 'h' and 'e').
graph

Graph(num_nodes=5, num_edges=8,
      ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
      edata_schemes={'e': Scheme(shape=(12,), dtype=torch.float32)})

In [9]:
# Build the model with its default hyperparameters and run it on the single graph.
model = MyGNN()
output = model(graph)
# For one input graph the output has a single row with one column per class.
output.shape

torch.Size([1, 100])

# Batch graph

In [10]:
# Two molecules of different size, each featurized separately and then merged
# into a single batched graph for one forward pass.
smi1, smi2 = 'CC(O)CN', 'CC(O)CNOO'
batch_graph = dgl.batch([SmilesToGraph(smiles) for smiles in (smi1, smi2)])
# Display the batched graph summary (node and edge counts cover both molecules).
batch_graph

Graph(num_nodes=12, num_edges=20,
      ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
      edata_schemes={'e': Scheme(shape=(12,), dtype=torch.float32)})

In [12]:
# Run the same model on the batched graph; note that this rebinds `output`.
output = model(batch_graph)
# The leading dimension now matches the number of graphs in the batch (2 here).
output.shape

torch.Size([2, 100])

# Tasks
### 1. Change `node_hidden_dim` and add a ReLU to the classification layer
### 2. Inspect the shape of each tensor (`node_feats`, `readout`, `output`)
### 3. Create a `models.py` file and import the model from it