In [17]:
# import packages
# RDkit
from rdkit import Chem
from rdkit.Chem import GraphDescriptors
from rdkit.Chem.rdmolops import GetAdjacencyMatrix

import pandas as pd, numpy as np

import torch
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
df_data = pd.read_csv('data/kinase_JAK.csv')
df_data.head()

Unnamed: 0,SMILES,measurement_type,measurement_value,Kinase_name
0,C#CCCOC(=O)N1CCC(n2cc(C(N)=O)c(Nc3ccc(F)cc3)n2...,pIC50,6.81,JAK2
1,C#CCCOC(=O)N1CCC(n2cc(C(N)=O)c(Nc3ccc(F)cc3)n2...,pIC50,8.05,JAK1
2,C#CCN(c1ccc(C#N)cn1)C1CCN(c2ncnc3[nH]ccc23)C1,pIC50,10.26,JAK2
3,C#CCN(c1ccc(C#N)cn1)C1CCN(c2ncnc3[nH]ccc23)C1,pIC50,10.26,JAK1
4,C#CCNCC1CCC(c2nnn3cnc4[nH]ccc4c23)CC1,pIC50,7.36,JAK2


In [4]:
def onehot_encode(x, features:list):
    """
    Maps input elements x not in features to the last element
    """
    if x not in features: x = features[-1]
    binary_encoding = [int(bool_val) for bool_val in list(
       map(lambda s: x == s, features))]
    return binary_encoding
  
class onehot_encodings:
  ''' encoding class for one hot features'''
  def __init__(self, atom_info_func, features):
    self.atom_info_func = atom_info_func
    self.features = features
  
  #@property
  def onehot_encodings(self, atom):
    return onehot_encode(self.atom_info_func(atom), self.features)
  
  def __len__(self): return len(self.features)
                            
class FeaturesArgs:
  # encodings information
  available_atoms = [
        'C','N','O','S','F','Si','P','Cl','Br','Mg','Na','Ca','Fe','As','Al','I',
        'B','V','K','Tl','Yb','Sb','Sn','Ag','Pd','Co','Se','Ti','Zn', 'Li','Ge',
        'Cu','Au','Ni','Cd','In','Mn','Zr','Cr','Pt','Hg','Pb','Unknown'
  ]
  chirality = ["CHI_UNSPECIFIED","CHI_TETRAHEDRAL_CW","CHI_TETRAHEDRAL_CCW","CHI_OTHER"]
  num_hydrogens = [0, 1, 2, 3, 4, "MoreThanFour"]
  n_heavy_atoms = num_hydrogens
  formal_charges = [-3, -2, -1, 0, 1, 2, 3, "Extreme"]
  hybridisation_type = ["S", "SP", "SP2", "SP3", "SP3D", "SP3D2", "OTHER"]
  # Atoms
    # atom encodings
  atom_encoding_lambdas = {
    'available_atoms': onehot_encodings(
      lambda atom: str(atom.GetSymbol()), available_atoms),
    'chirality_type_enc': onehot_encodings(
      lambda atom: str(atom.GetChiralTag()), chirality),
    'hydrogens_implicit': onehot_encodings(
      lambda atom: int(atom.GetTotalNumHs()), num_hydrogens),
    'n_heavy_atoms': onehot_encodings(
      lambda atom: int(atom.GetDegree()), n_heavy_atoms),
    'formal_charge': onehot_encodings(
      lambda atom: int(atom.GetFormalCharge()), formal_charges),
    'hybridisation_type': onehot_encodings(
      lambda atom: str(atom.GetHybridization()), hybridisation_type)
  }
    # atom info
  atom_info_lambdas = {
    'is_in_a_ring_enc': lambda atom: [int(atom.IsInRing())],
    'is_aromatic_enc': lambda atom: [int(atom.GetIsAromatic())],
    'atomic_mass_scaled': lambda atom: [float((atom.GetMass() - 10.812)/116.092)],
    'vdw_radius_scaled': lambda atom: [float(
      (Chem.GetPeriodicTable().GetRvdw(atom.GetAtomicNum()) - 1.5)/0.6)],
    'covalent_radius_scaled': lambda atom:[float(
      (Chem.GetPeriodicTable().GetRcovalent(atom.GetAtomicNum()) - 0.64)/0.76)]
  }
    # compute node feature length
  n_node_features = sum(map(len, atom_encoding_lambdas.values()))
  n_node_features += len(atom_info_lambdas)
  
  # Bonds encoding info
  bond_types = [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
                Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
  stereo_types = ["STEREOZ", "STEREOE", "STEREOANY", "STEREONONE"]
    # bond encodings
  bond_encoding_lambdas = {
    'bond_types': onehot_encodings(
      lambda bond: bond.GetBondType(), bond_types),
    'stereo_types': onehot_encodings(
      lambda bond: str(bond.GetStereo()), stereo_types)
  }
    # bond quantity
  bond_info_lambas = {
    'bond_is_conj_enc': lambda bond: [int(bond.GetIsConjugated())],
    'bond_is_in_ring_enc': lambda bond: [int(bond.IsInRing())]
  }
  n_edge_features = sum(map(len, bond_encoding_lambdas.values()))
  n_edge_features += len(bond_info_lambas)
  # Molecule
  lambda mol: GraphDescriptors.BalabanJ(mol)

In [5]:
def get_atom_features(atom, 
                      available_atoms:list = FeaturesArgs.available_atoms,
                      atom_encode_lambdas: dict = FeaturesArgs.atom_encoding_lambdas,
                      atom_info_lambdas: dict = FeaturesArgs.atom_info_lambdas,
                      debug=False
                     ):
  """
  Takes an RDKit atom object as input and gives a 1d-numpy array of atom features as output.
  """  
  if 'hydrogens_implicit' in atom_encode_lambdas:
    available_atoms = ['H'] + available_atoms
  atom_feature_vector = []
  # compute atom features
  for name, atom_encoding in atom_encode_lambdas.items():
    encoding = atom_encoding.onehot_encodings(atom)
    atom_feature_vector += encoding
    if debug: print(f"atom encoding length ({name}): {len(encoding)}")
    # boolean features
  for name, info_func in atom_info_lambdas.items():
    atom_feature_vector += info_func(atom)
    if debug: print(f"atom info ({name}): {info_func(atom)}")
  if debug: print(f"full atom feature:{len(atom_feature_vector)}")
  return torch.Tensor(atom_feature_vector)

def get_bond_features(bond, 
                      bond_encoding_lambdas: dict = FeaturesArgs.bond_encoding_lambdas,
                      bond_info_lambdas: dict = FeaturesArgs.bond_info_lambas,
                      debug=False):
  """
  Takes an RDKit bond object as input and gives a 1d-numpy array of bond features as output.
  """
  bond_feature_vector = []
    # compute bond features
  for name, bond_encoding in bond_encoding_lambdas.items():
    encoding = bond_encoding.onehot_encodings(bond)
    bond_feature_vector += encoding
    if debug: print(f"bond encoding length ({name}): {len(bond_feature_vector)}")
    # boolean features
  for name, info_func in bond_info_lambdas.items():
    bond_feature_vector += info_func(bond)
    if debug: print(f"bond info ({name}): {info_func(bond)}")
    return torch.Tensor(bond_feature_vector)

In [18]:
mol = Chem.MolFromSmiles(df_data['SMILES'][0])
atom = mol.GetAtoms()[0]
atom_features = get_atom_features(atom, debug=True)
print(len(atom_features))

atom encoding length (available_atoms): 43
atom encoding length (chirality_type_enc): 4
atom encoding length (hydrogens_implicit): 6
atom encoding length (n_heavy_atoms): 6
atom encoding length (formal_charge): 8
atom encoding length (hybridisation_type): 7
atom info (is_in_a_ring_enc): [0]
atom info (is_aromatic_enc): [0]
atom info (atomic_mass_scaled): [0.010328015711676944]
atom info (vdw_radius_scaled): [0.33333333333333326]
atom info (covalent_radius_scaled): [0.05263157894736847]
full atom feature:79
79


In [19]:
bond = mol.GetBondBetweenAtoms(0, 1)
bond_features = get_bond_features(bond, debug=True)
print(bond_features)

bond encoding length (bond_types): 4
bond encoding length (stereo_types): 8
bond info (bond_is_conj_enc): [0]
tensor([0., 0., 1., 0., 0., 0., 0., 1., 0.])


In [20]:
def smile_to_data(smiles, y_val):
  # convert SMILES to RDKit mol object
  mol = Chem.MolFromSmiles(smiles)
  # get feature dimensions
  n_nodes = mol.GetNumAtoms()
  n_edges = 2*mol.GetNumBonds()

  # construct node feature matrix X of shape (n_nodes, n_node_features)
  for n, atom in enumerate(mol.GetAtoms()):
    atom_features = get_atom_features(atom)
    if n==0: X = torch.zeros((n_nodes, len(atom_features)), dtype=torch.float)
    X[atom.GetIdx(), :] = atom_features
    
  # construct edge index array E of shape (2, n_edges)
  E_ij = torch.stack(
    list(map(
      lambda arr: torch.Tensor(arr).to(torch.long),
      np.nonzero(GetAdjacencyMatrix(mol))))
  )
  # construct edge feature array EF of shape (n_edges, n_edge_features)
  EF = torch.stack([
    get_bond_features(mol.GetBondBetweenAtoms(i.item(), j.item())) for i, j in zip(
      E_ij[0], E_ij[1])]
  )
  # construct label tensor
  y_tensor = torch.tensor(np.array([y_val]), dtype = torch.float)
  return X, E_ij, EF, y_tensor
        
def graph_datalist_from_smiles_and_labels(x_smiles, y):
    """
    Inputs:
      x_smiles [list]: SMILES strings
      y [list]: numerial labels for the SMILES strings
    Outputs:
      data_list [list]: torch_geometric.data.Data objects which represent labeled molecular graphs that can readily be used for machine learning
    
    """
    data_list = []
    for (smiles, y_val) in zip(x_smiles, y):
        X, E, EF, y_tensor = smile_to_data(smiles, y_val)
        #print(list(map(type, [X, E, EF, y_tensor])))
        # construct Pytorch Geometric data object and append to data list
        data_list.append(
          Data(x = X, edge_index = E, edge_attr = EF, y = y_tensor))
    return data_list

In [26]:
# canonical training loop for a Pytorch Geometric GNN model gnn_model
x_smiles, y = df_data.to_numpy()[:,0], df_data.to_numpy()[:,-2]
# create list of molecular graph objects from list of SMILES x_smiles and list of labels y
data_list = graph_datalist_from_smiles_and_labels(x_smiles, y)

PyGeometric dataset

In [33]:
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data, InMemoryDataset, download_url

In [34]:
class Molecule_pKa(InMemoryDataset):
    def __init__(self, root, data_list, transform=None):
        self.data_list = data_list
        super().__init__(root, transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        return 'data.pt'

    def process(self):
        torch.save(self.collate(self.data_list), self.processed_paths[0])

In [35]:
dataset = Molecule_pKa('.', data_list)

Model

In [36]:
from postera_gnn.model.models import *

In [382]:
model = GIN
gnn_model = model(FeaturesArgs.n_edge_features + FeaturesArgs.n_node_features, 1)

### Training

In [409]:
class TrainArgs:
  batch_size=2**7

In [410]:
# create dataloader for training
dataloader = DataLoader(dataset = data_list, batch_size = TrainArgs.batch_size)
# define loss function
loss_function = nn.MSELoss()
# define optimiser
optimiser = torch.optim.Adam(gnn_model.parameters(), lr = 1e-3)

In [405]:
# loop over 10 training epochs
for epoch in range(10):
    # set model to training mode
    gnn_model.train()
    # loop over minibatches for training
    for (k, batch) in enumerate(dataloader):
        # compute current value of loss function via forward pass
        output = gnn_model(batch)
        loss_function_value = loss_function(output[:,0], torch.tensor(batch.y, dtype = torch.float32))
        # set past gradient to zero
        optimiser.zero_grad()
        # compute current gradient via backward pass
        loss_function_value.backward()
        # update model weights using gradient and optimisation method
        optimiser.step()

TypeError: forward() missing 1 required positional argument: 'edge_index'