In [92]:
import pandas as pd
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

from rdkit import Chem
from dgllife.utils import smiles_to_complete_graph, smiles_to_bigraph
from dgllife.utils import AttentiveFPAtomFeaturizer, AttentiveFPBondFeaturizer
from functools import partial
from dgl.data import DGLDataset
from dgl.dataloading import GraphDataLoader

In [37]:

df = pd.read_csv("../data/drugs/drugbank_5.1.10/structure links.csv")
df = df[["DrugBank ID","SMILES"]]
df = df.dropna()
df

Unnamed: 0,DrugBank ID,SMILES
0,DB00006,CC[C@H](C)[C@H](NC(=O)[C@H](CCC(O)=O)NC(=O)[C@...
1,DB00007,CCNC(=O)[C@@H]1CCCN1C(=O)[C@H](CCCNC(N)=N)NC(=...
2,DB00014,CC(C)C[C@H](NC(=O)[C@@H](COC(C)(C)C)NC(=O)[C@H...
3,DB00027,CC(C)C[C@@H](NC(=O)CNC(=O)[C@@H](NC=O)C(C)C)C(...
4,DB00035,NC(=O)CC[C@@H]1NC(=O)[C@H](CC2=CC=CC=C2)NC(=O)...
...,...,...
12221,DB17379,CC(C)C1=C(O)C(O)=C(C=O)C2=C(O)C(=C(C)C=C12)C1=...
12223,DB17383,CN1CCN(CC2=CC=C(NC(=O)C3=NNC=C3NC3=C4C=CNC4=NC...
12224,DB17384,CC1=C2N=C(C3=CC=CC=C3Cl)C3=C(NC2=NN1)C=C(N=C3)...
12225,DB17385,CC[C@@]1(OC(=O)C(C)ON=C2C3=C(C4=C2C=C(C=C4[N+]...


In [57]:
# read more here: https://lifesci.dgl.ai/generated/dgllife.utils.AttentiveFPAtomFeaturizer.html

def featurize_atoms(mol):
    feats = []
    atom_featurizer = AttentiveFPAtomFeaturizer(atom_data_field='feat')
    return(atom_featurizer(mol))

def featurize_edges(mol, add_self_loop=False):
    bond_featurizer = AttentiveFPBondFeaturizer(bond_data_field='feat', self_loop=add_self_loop)    
    return(bond_featurizer(mol))

In [56]:
#test with a simple SMILES mol
add_self_loop = True
g = smiles_to_bigraph(
        'CCO', add_self_loop=add_self_loop, node_featurizer=featurize_atoms,
        edge_featurizer=partial(featurize_edges, add_self_loop=add_self_loop),explicit_hydrogens=True)

In [69]:
g_list = []

for idx, row in df.iterrows():
    mol = row["SMILES"]
    try:
        g = smiles_to_bigraph(
            mol, add_self_loop=add_self_loop, node_featurizer=featurize_atoms,
            edge_featurizer=partial(featurize_edges, add_self_loop=add_self_loop),explicit_hydrogens=True)
        g_list.append(g)
    except:
        print(row["DrugBank ID"])

print(len(g_list),"/",len(df),"done.")
    

Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
11583 / 11583 done.


In [82]:

class AEDataset(DGLDataset):
    def __init__(self, graph_list):
        super().__init__(name="drugmols")
        self.graphs = graph_list
        self.labels = graph_list

    def __getitem__(self, i):
        return self.graphs[i], self.labels[i]

    def __len__(self):
        return len(self.graphs)


In [83]:
dataset = AEDataset(g_list)

In [88]:
from torch.utils.data.sampler import SubsetRandomSampler

num_examples = len(dataset)
num_train = int(num_examples * 0.8)

train_sampler = SubsetRandomSampler(torch.arange(num_train))
test_sampler = SubsetRandomSampler(torch.arange(num_train, num_examples))

train_dataloader = GraphDataLoader(
    dataset, sampler=train_sampler, batch_size=5, drop_last=False
)
test_dataloader = GraphDataLoader(
    dataset, sampler=test_sampler, batch_size=5, drop_last=False
)

In [89]:
it = iter(train_dataloader)
batch = next(it)
print(batch)

[Graph(num_nodes=360, num_edges=1106,
      ndata_schemes={'feat': Scheme(shape=(39,), dtype=torch.float32)}
      edata_schemes={'feat': Scheme(shape=(11,), dtype=torch.float32)}), Graph(num_nodes=360, num_edges=1106,
      ndata_schemes={'feat': Scheme(shape=(39,), dtype=torch.float32)}
      edata_schemes={'feat': Scheme(shape=(11,), dtype=torch.float32)})]


In [90]:
batched_graph, labels = batch
print(
    "Number of nodes for each graph element in the batch:",
    batched_graph.batch_num_nodes(),
)
print(
    "Number of edges for each graph element in the batch:",
    batched_graph.batch_num_edges(),
)

# Recover the original graph elements from the minibatch
graphs = dgl.unbatch(batched_graph)
print("The original graphs in the minibatch:")
print(graphs)

Number of nodes for each graph element in the batch: tensor([ 45,  56,  29, 181,  49])
Number of edges for each graph element in the batch: tensor([139, 174,  89, 557, 147])
The original graphs in the minibatch:
[Graph(num_nodes=45, num_edges=139,
      ndata_schemes={'feat': Scheme(shape=(39,), dtype=torch.float32)}
      edata_schemes={'feat': Scheme(shape=(11,), dtype=torch.float32)}), Graph(num_nodes=56, num_edges=174,
      ndata_schemes={'feat': Scheme(shape=(39,), dtype=torch.float32)}
      edata_schemes={'feat': Scheme(shape=(11,), dtype=torch.float32)}), Graph(num_nodes=29, num_edges=89,
      ndata_schemes={'feat': Scheme(shape=(39,), dtype=torch.float32)}
      edata_schemes={'feat': Scheme(shape=(11,), dtype=torch.float32)}), Graph(num_nodes=181, num_edges=557,
      ndata_schemes={'feat': Scheme(shape=(39,), dtype=torch.float32)}
      edata_schemes={'feat': Scheme(shape=(11,), dtype=torch.float32)}), Graph(num_nodes=49, num_edges=147,
      ndata_schemes={'feat': Scheme(

In [93]:
from dgl.nn import GraphConv

class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        g.ndata["h"] = h
        return dgl.mean_nodes(g, "h")

In [None]:
model = GCN(dataset.dim_nfeats, 16, dataset.gclasses)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(20):
    for batched_graph, labels in train_dataloader:
        pred = model(batched_graph, batched_graph.ndata["attr"].float())
        loss = F.cross_entropy(pred, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

num_correct = 0
num_tests = 0
for batched_graph, labels in test_dataloader:
    pred = model(batched_graph, batched_graph.ndata["attr"].float())
    num_correct += (pred.argmax(1) == labels).sum().item()
    num_tests += len(labels)

print("Test accuracy:", num_correct / num_tests)