In [None]:
# import useful packages and functions
from spektral_package import *

In [None]:
import networkx as nx
import numpy as np
from scipy import sparse as sp

try:
    from rdkit import Chem as rdc
    from rdkit.Chem import Draw
    from rdkit import rdBase as rdb

    # Turn off RDKit's 'rdApp.error' log channel (it is explicitly disabled here).
    rdb.DisableLog('rdApp.error')
    Draw.DrawingOptions.dblBondOffset = .1
    # Integer bond code (as read from channel 0 of the numpy edge features in
    # numpy_to_rdkit) -> RDKit bond type.
    BOND_MAP = dict(enumerate(
        getattr(rdc.rdchem.BondType, bond_name)
        for bond_name in ('ZERO', 'SINGLE', 'DOUBLE', 'TRIPLE', 'AROMATIC')
    ))
except ImportError:
    # RDKit is optional: every RDKit helper below raises ImportError when rdc is None.
    rdc = None
    rdb = None

# Element symbols listed in atomic-number order, H (1) through Og (118);
# enumerate(..., start=1) turns the position into the atomic number.
NUM_TO_SYMBOL = dict(enumerate(
    ('H He Li Be B C N O F Ne '
     'Na Mg Al Si P S Cl Ar '
     'K Ca Sc Ti V Cr Mn Fe Co Ni Cu Zn Ga Ge As Se Br Kr '
     'Rb Sr Y Zr Nb Mo Tc Ru Rh Pd Ag Cd In Sn Sb Te I Xe '
     'Cs Ba La Ce Pr Nd Pm Sm Eu Gd Tb Dy Ho Er Tm Yb Lu '
     'Hf Ta W Re Os Ir Pt Au Hg Tl Pb Bi Po At Rn '
     'Fr Ra Ac Th Pa U Np Pu Am Cm Bk Cf Es Fm Md No Lr '
     'Rf Db Sg Bh Hs Mt Ds Rg Cn Nh Fl Mc Lv Ts Og').split(),
    start=1))
# Inverse lookup: element symbol -> atomic number.
SYMBOL_TO_NUM = {symbol: number for number, symbol in NUM_TO_SYMBOL.items()}


def numpy_to_rdkit(adj, nf, ef, sanitize=False):
    """
    Converts a molecule from numpy to RDKit format.
    Nodes whose atomic number is not positive (e.g. zero padding) are not
    added as atoms. RDKit numbers the remaining atoms consecutively, so bond
    endpoints are translated from numpy node indices to RDKit atom indices,
    and bonds that touch a skipped node are dropped.
    :param adj: binary numpy array of shape (N, N)
    :param nf: numpy array of shape (N, F) holding a single value per node,
        the atomic number (i.e. F == 1); shape (N,) is also accepted
    :param ef: numpy array of shape (N, N, S); channel 0 holds the bond type
        as a key of BOND_MAP
    :param sanitize: whether to sanitize the molecule after conversion
    :return: an RDKit molecule
    :raises ImportError: if RDKit is not installed
    """
    if rdc is None:
        raise ImportError('`numpy_to_rdkit` requires RDKit.')
    mol = rdc.RWMol()
    # numpy node index -> RDKit atom index (they diverge once a node is skipped).
    node_to_atom = {}
    for node_idx, nf_ in enumerate(nf):
        # .item() accepts both a scalar and a single-element feature row.
        atomic_num = int(np.asarray(nf_).item())
        if atomic_num > 0:
            node_to_atom[node_idx] = mol.AddAtom(rdc.Atom(atomic_num))

    # k=1: strict upper triangle, so every unordered pair is visited exactly
    # once and self-loops are ignored.
    for i, j in zip(*np.triu_indices(adj.shape[-1], k=1)):
        i, j = int(i), int(j)
        if i not in node_to_atom or j not in node_to_atom:
            continue
        if not (adj[i, j] == adj[j, i] == 1):
            continue
        bond_type_1 = BOND_MAP[int(ef[i, j, 0])]
        bond_type_2 = BOND_MAP[int(ef[j, i, 0])]
        # A bond whose type differs between (i, j) and (j, i) is ambiguous; skip it.
        if bond_type_1 == bond_type_2:
            mol.AddBond(node_to_atom[i], node_to_atom[j], bond_type_1)

    mol = mol.GetMol()
    if sanitize:
        rdc.SanitizeMol(mol)
    return mol


def numpy_to_smiles(adj, nf, ef):
    """
    Builds the SMILES string of a molecule given as numpy arrays.
    The arrays are first turned into an RDKit molecule with numpy_to_rdkit
    (no sanitization) and then written out with rdkit_to_smiles.
    :param adj: binary numpy array of shape (N, N)
    :param nf: numpy array of shape (N, F)
    :param ef: numpy array of shape (N, N, S)
    :return: the SMILES string of the molecule
    :raises ImportError: if RDKit is not installed
    """
    if rdc is not None:
        return rdkit_to_smiles(numpy_to_rdkit(adj, nf, ef))
    raise ImportError('`numpy_to_smiles` requires RDkit.')


def rdkit_to_smiles(mol):
    """
    Serializes an RDKit molecule to SMILES via rdc.MolToSmiles.
    :param mol: an RDKit molecule
    :return: the SMILES string of the molecule
    :raises ImportError: if RDKit is not installed
    """
    if rdc is not None:
        return rdc.MolToSmiles(mol)
    raise ImportError('`rdkit_to_smiles` requires RDkit.')


def sdf_to_nx(sdf, keep_hydrogen=False):
    """
    Builds one networkx Graph per molecule in the internal SDF format.
    Every atom dict becomes a node keyed by atom['index'] (all atom fields are
    stored as node attributes); every bond dict becomes an edge between
    'start_atom' and 'end_atom' (all bond fields stored as edge attributes).
    Unless keep_hydrogen is set, atoms with atomic_num <= 1 and any bond that
    touches one are left out.
    :param sdf: a list of molecules (or individual molecule) in SDF format.
    :param keep_hydrogen: whether to include hydrogen in the representation.
    :return: a single nx.Graph when exactly one molecule is converted,
        otherwise a list of nx.Graphs.
    """
    molecules = sdf if isinstance(sdf, list) else [sdf]

    graphs = []
    for molecule in molecules:
        atoms = molecule['atoms']
        graph = nx.Graph()

        for atom in atoms:
            if atom['atomic_num'] > 1 or keep_hydrogen:
                graph.add_node(atom['index'], **atom)

        for bond in molecule['bonds']:
            start, end = bond['start_atom'], bond['end_atom']
            # Bond endpoints are positions in the molecule's 'atoms' list.
            start_num = atoms[start]['atomic_num']
            end_num = atoms[end]['atomic_num']
            if (start_num > 1 and end_num > 1) or keep_hydrogen:
                graph.add_edge(start, end, **bond)

        graphs.append(graph)

    return graphs[0] if len(graphs) == 1 else graphs


def nx_to_sdf(graphs):
    """
    Converts nx.Graphs back into the internal SDF dict format.
    Node attribute dicts become the 'atoms' list and edge attribute dicts the
    'bonds' list; the other SDF fields get placeholder values (empty strings,
    [''] for 'data', -1 for the atom/bond counts, [] for 'properties').
    :param graphs: an nx.Graph or a list of nx.Graphs.
    :return: list of molecules in the internal SDF format (always a list).
    """
    graph_list = [graphs] if isinstance(graphs, nx.Graph) else graphs
    return [
        {
            'atoms': [atom for _, atom in g.nodes.items()],
            'bonds': [bond for _, bond in g.edges.items()],
            'comment': '',
            'data': [''],
            'details': '',
            'n_atoms': -1,
            'n_bonds': -1,
            'name': '',
            'properties': [],
        }
        for g in graph_list
    ]


def validate_rdkit_mol(mol):
    """
    Checks whether a single RDKit molecule is chemically valid.
    A molecule counts as valid when it is one connected fragment and
    rdc.SanitizeMol succeeds on it. Sanitization is applied to `mol` itself
    (no copy is made).
    :param mol: an RDKit molecule
    :return: True if the molecule is chemically valid, False otherwise
    :raises ImportError: if RDKit is not installed
    """
    if rdc is None:
        raise ImportError('`validate_rdkit_mol` requires RDkit.')
    is_single_fragment = len(rdc.GetMolFrags(mol)) <= 1
    if not is_single_fragment:
        return False
    try:
        rdc.SanitizeMol(mol)
    except ValueError:
        return False
    else:
        return True


def validate_rdkit(mol):
    """
    Validates one RDKit molecule or a collection of them.
    :param mol: an RDKit molecule, or a list / np.ndarray of molecules
    :return: a bool for a single molecule, or a boolean np.ndarray with one
        entry per molecule (True if chemically valid)
    :raises ImportError: if RDKit is not installed
    """
    if rdc is None:
        raise ImportError('`validate_rdkit` requires RDkit.')
    if not isinstance(mol, (list, np.ndarray)):
        return validate_rdkit_mol(mol)
    return np.array([validate_rdkit_mol(single) for single in mol])

In [None]:
import tensorflow as tf
import rdkit
import numpy as np
from rdkit import *
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem   

In [None]:
# Read the tab-separated table of SMILES strings and MoA labels.
import pandas as pd
df = pd.read_table('top_20_MOAs.txt')

In [None]:
# Map each of the 20 mechanisms of action (MoAs) to an integer class id 0-19
# (entries listed in class-id order).
MOA_class_dictionary = {
    'topoisomerase inhibitor': 0,
    'acetylcholine receptor agonist': 1,
    'serotonin receptor antagonist': 2,
    'phosphodiesterase inhibitor': 3,
    'acetylcholine receptor antagonist': 4,
    'calcium channel blocker': 5,
    'cyclooxygenase inhibitor': 6,
    'serotonin receptor agonist': 7,
    'EGFR inhibitor': 8,
    'glucocorticoid receptor agonist': 9,
    'benzodiazepine receptor agonist': 10,
    'sodium channel blocker': 11,
    'dopamine receptor antagonist': 12,
    'PI3K inhibitor': 13,
    'bacterial cell wall synthesis inhibitor': 14,
    'adrenergic receptor antagonist': 15,
    'HDAC inhibitor': 16,
    'histamine receptor antagonist': 17,
    'adrenergic receptor agonist': 18,
    'glutamate receptor antagonist': 19,
}

In [None]:
# add classes column: map each MoA name to its integer class id.
# Assumes (as the original row loop did) that the MoA names are in the second
# column of df. Assigning by column name, rather than with the positional
# df.iloc[i, 2], always writes to 'classes', whatever number of columns the
# input file has.
moa_names = df.iloc[:, 1]
unknown_moas = set(moa_names) - set(MOA_class_dictionary)
if unknown_moas:
    # A per-row dict lookup would raise KeyError here; keep failing loudly
    # instead of letting .map() silently produce NaN labels.
    raise KeyError('MoAs without a class id: {}'.format(unknown_moas))
df['classes'] = moa_names.map(MOA_class_dictionary)

In [None]:
# SMILES strings in df row order.
structures = df['SMILES'].tolist()

In [None]:
# Create a MOL block for each SMILES string, in df row order.
# Chem.MolFromSmiles returns None for SMILES that RDKit cannot parse. Fail with
# the offending row instead of an opaque error from MolToMolBlock; silently
# skipping the row would misalign the molecules with the labels in df.
molecule_blocks = []
for row_position, structure in enumerate(structures):
    molecule = Chem.MolFromSmiles(structure)
    if molecule is None:
        raise ValueError('Invalid SMILES at row {}: {!r}'.format(row_position, structure))
    molecule_blocks.append(Chem.MolToMolBlock(molecule))

In [None]:
# Write all MOL blocks into one text file, each record terminated by a '$$$$'
# line, so that load_sdf can read them back afterwards.
# The context manager closes (and flushes) the file even if a write raises.
with open("molecule_textfile.txt", "w") as molecule_textfile:
    for molecule_block in molecule_blocks:
        molecule_textfile.write(molecule_block + "\n" + '$$$$' + "\n")

In [None]:
# Parse the file written above back into the SDF dict format used by sdf_to_nx.
from spektral.utils.io import load_sdf
# NOTE(review): amount=None presumably means "no limit on molecules" — confirm in spektral docs.
molecule_sdf_loaded = load_sdf("molecule_textfile.txt", amount=None)

In [None]:
import networkx as nx
# Molecules -> networkx graphs (hydrogens kept) -> numpy arrays.
molecule_sdf_nx = sdf_to_nx(molecule_sdf_loaded, keep_hydrogen=True)
# The third return value (edge features) is not used below.
molecule_sdf_adj, molecule_sdf_node, _ = nx_to_numpy(
    molecule_sdf_nx, nf_keys=['atomic_num'], ef_keys=['type'])
# Every distinct node-feature value seen across all molecules.
per_molecule_values = (np.unique(features) for features in molecule_sdf_node)
uniq_node = np.unique([value for values in per_molecule_values for value in values])
# One-hot encode each molecule's node features over that shared vocabulary.
node = [label_to_one_hot(features, uniq_node) for features in molecule_sdf_node]

In [None]:
# Integer class label of every molecule, in df row order.
y = list(map(int, df['classes'].tolist()))

In [None]:
# Sizes used to build the GCN: nodes per graph and features per node (the last
# two axes of a one-hot node matrix), and the number of distinct labels.
n_nodes, n_features = node[0].shape[-2:]
n_classes = len(set(y))

In [None]:
# Hold out a stratified 10% test set with a fixed seed. The split is over df's
# index labels, which are used below as positions into the per-molecule arrays.
from sklearn.model_selection import train_test_split
X = list(df.index)
class_labels = df.classes.tolist()
x_train_valid, x_test, y_train_valid, y_test = train_test_split(
    X, class_labels,
    test_size=10/100, stratify=class_labels, shuffle=True, random_state=1000)

In [None]:
# Stratified 9-fold split of the train+valid molecules (shuffle is not enabled).
# Stores the positional train/valid indices of every fold.
from sklearn.model_selection import StratifiedKFold
skf = StratifiedKFold(n_splits=9)
kfold_ids = np.array(list(x_train_valid))
kfold_labels = np.array(list(y_train_valid))
fold_splits = list(skf.split(kfold_ids, kfold_labels))
train_index_list = [fold_train for fold_train, _ in fold_splits]
valid_index_list = [fold_valid for _, fold_valid in fold_splits]

In [None]:
number_of_kfold = 0  # Change this from 0 - 8 to get 9 shuffles
# Select the molecule ids and labels of the chosen fold.
train_valid_ids = np.array(list(x_train_valid))
train_valid_labels = np.array(list(y_train_valid))
fold_train_idx = train_index_list[number_of_kfold]
fold_valid_idx = valid_index_list[number_of_kfold]
x_train = list(train_valid_ids[fold_train_idx])
x_valid = list(train_valid_ids[fold_valid_idx])
y_train = list(train_valid_labels[fold_train_idx])
y_valid = list(train_valid_labels[fold_valid_idx])
x_test = list(x_test)
y_test = list(y_test)

In [None]:
# Get the adjacency matrices of each split.
# GCSConv expects a *normalized* adjacency matrix, not the raw 0/1 matrices
# produced by nx_to_numpy. GCSConv.preprocess applies that normalization to a
# single (N, N) matrix, so it is applied molecule by molecule.
from spektral.layers import GCSConv


def _normalize_adjacency(adjacency_batch):
    """Return a stacked array with GCSConv.preprocess applied to each matrix."""
    return np.array([GCSConv.preprocess(adjacency) for adjacency in adjacency_batch])


A_train = _normalize_adjacency(molecule_sdf_adj[x_train])
A_valid = _normalize_adjacency(molecule_sdf_adj[x_valid])
A_test = _normalize_adjacency(molecule_sdf_adj[x_test])

In [None]:
# One-hot node-feature matrices of the training molecules.
X_train = [node[i] for i in x_train]

In [None]:
# One-hot node-feature matrices of the validation molecules.
X_valid = [node[i] for i in x_valid]

In [None]:
# One-hot node-feature matrices of the test molecules.
X_test = [node[i] for i in x_test]

In [None]:
# Stack each split's per-molecule node matrices into one numpy array for Keras.
X_train, X_valid, X_test = (np.asarray(split) for split in (X_train, X_valid, X_test))

In [None]:
# build GCN: three GCSConv layers (each followed by dropout), global attention
# pooling, and a softmax classifier over the MoA classes.
from tensorflow.keras.regularizers import l2
from spektral.layers import GCSConv, GlobalAttentionPool
# Use tf.keras layers (not standalone `keras`) so every layer comes from the
# same Keras implementation as the spektral layers and the tf.keras Model.
from tensorflow.keras.layers import Dropout, Input, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

units = 128          # hidden units per graph-convolution layer
drop = 0.5           # dropout rate after each graph convolution
batch_size = 8
l2_reg = 5e-4        # L2 penalty on the convolution kernels
n_gcs_layers = 3

X_in = Input((n_nodes, n_features))
A_in = Input((n_nodes, n_nodes))
layer = X_in
for gcs_layer_idx in range(n_gcs_layers):
    layer = GCSConv(units, activation='relu', kernel_regularizer=l2(l2_reg))([layer, A_in])
    layer = Dropout(drop)(layer)
layer = GlobalAttentionPool(units * 2)(layer)
output = Dense(n_classes, activation='softmax')(layer)
model0 = Model(inputs=[X_in, A_in], outputs=output)

# compile the model
# The output layer already applies softmax, so the loss receives probabilities
# (from_logits=False); from_logits=True would apply softmax a second time.
# `learning_rate` replaces the deprecated `lr` argument.
optimizer = Adam(learning_rate=1e-3)
model0.compile(optimizer=optimizer,
               loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
               metrics=['accuracy'])

In [None]:
# Create 'balanced' class weights so under-represented MoAs weigh more in the loss.
from sklearn.utils import class_weight
y_unique = np.unique(np.array(y_train))
class_weights = class_weight.compute_class_weight(class_weight='balanced', classes=y_unique,
                                                  y=np.array(y_train))
# Key every weight by its class label. compute_class_weight returns the weights
# in the order of `classes`, whereas dict(enumerate(...)) keyed them by position,
# which only matches when the training labels are exactly 0..K-1 with none missing.
class_weights_dict45 = {int(label): float(weight)
                        for label, weight in zip(y_unique, class_weights)}

In [None]:
# set the checkpoint: save only the weights with the best validation accuracy.
import os
# tf.keras callback to match the tf.keras model (standalone `keras` callbacks
# are a different implementation and may not work with it).
from tensorflow.keras.callbacks import ModelCheckpoint
filepath_gcn = './content/GCN_top_' + str(n_classes) + '_MOA_weights.hdf5'
# Saving an HDF5 file fails if its directory does not exist yet.
os.makedirs(os.path.dirname(filepath_gcn), exist_ok=True)
checkpoint_gcn = ModelCheckpoint(filepath_gcn, save_weights_only=True, monitor='val_accuracy',
                                 verbose=0, save_best_only=True, mode='max')

In [None]:
# train the model
from tensorflow.keras.callbacks import EarlyStopping
# Stop once val_loss has not improved for 80 epochs.
earlyStopping = EarlyStopping(monitor='val_loss', patience=80, verbose=1, mode='min')
# Multiply the learning rate by 0.98 whenever val_loss stalls for 5 epochs.
# NOTE(review): min_delta=1e-54 is effectively 0 (any decrease counts as an improvement).
reduce_lr_loss = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
                                                      factor=0.98, patience=5, verbose=0,
                                                      min_delta=1e-54, mode='min')
# The model has two inputs, so x is the flat list [node_features, adjacency]
# (not [[X, A]]), and validation_data is an (x, y) tuple.
history = model0.fit([X_train, A_train], np.array(y_train), batch_size=batch_size,
                     validation_data=([X_valid, A_valid], np.array(y_valid)),
                     shuffle=True, class_weight=class_weights_dict45, epochs=300,
                     verbose=2, callbacks=[earlyStopping, checkpoint_gcn, reduce_lr_loss])

In [None]:
# Restore the checkpointed weights (best val_accuracy seen during training).
model0.load_weights(filepath=filepath_gcn)

In [None]:
# Evaluate the model on the held-out test set.
from sklearn.metrics import classification_report
# Sanity check: the seeded train/test split must be the expected one.
assert list(y_test)[:5] == [14, 12, 6, 13, 14]
test_probabilities = model0.predict([X_test, A_test])
y_pred = np.array(test_probabilities.argmax(-1))
print(classification_report(y_test, y_pred))

In [None]:
# Training curves
import matplotlib.pyplot as plt


def _plot_curves(train_key, valid_key, title, ylabel):
    """Plot one training and one validation series from `history` and show the figure."""
    plt.plot(history.history[train_key])
    plt.plot(history.history[valid_key])
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel('epoch')
    plt.legend(['train', 'valid'], loc='upper left')
    plt.show()


_plot_curves('accuracy', 'val_accuracy', 'model accuracy', 'accuracy')
_plot_curves('loss', 'val_loss', 'model loss', 'loss')

In [None]:
# References 
# https://codesuche.com/view-source/python/danielegrattarola/spektral/ 
# https://github.com/Discngine/dng_dl_speknn 
# https://github.com/danielegrattarola/spektral