<a href="https://colab.research.google.com/github/swansonk14/chemprop-intro/blob/master/lab3/message_passing_neural_network.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Message Passing Neural Network on Graph Structure

In [1]:
# Environment setup for Colab: download the Miniconda installer, install it
# into /usr/local, then use conda to install RDKit and PyTorch there.
!wget -c https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
!chmod +x Miniconda3-latest-Linux-x86_64.sh
!bash ./Miniconda3-latest-Linux-x86_64.sh -b -f -p /usr/local
!conda install -q -y --prefix /usr/local -c rdkit rdkit pytorch

# Make the conda-installed packages importable from the running kernel.
# NOTE(review): the path is hard-coded to python3.6, while the installer log
# below reports "Python 3.7.0" - confirm which interpreter version conda ends up
# with after the rdkit install, and pin the installer/package versions accordingly.
import sys
sys.path.append('/usr/local/lib/python3.6/site-packages/')

# Fetch the Delaney train/test CSV splits into the working directory.
!wget https://raw.githubusercontent.com/swansonk14/chemprop-intro/master/data/delaney_train.csv
!wget https://raw.githubusercontent.com/swansonk14/chemprop-intro/master/data/delaney_test.csv

--2018-12-28 00:31:44--  https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
Resolving repo.anaconda.com (repo.anaconda.com)... 104.17.107.77, 104.17.109.77, 104.17.108.77, ...
Connecting to repo.anaconda.com (repo.anaconda.com)|104.17.107.77|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 62574861 (60M) [application/x-sh]
Saving to: ‘Miniconda3-latest-Linux-x86_64.sh’


2018-12-28 00:31:45 (111 MB/s) - ‘Miniconda3-latest-Linux-x86_64.sh’ saved [62574861/62574861]

PREFIX=/usr/local
reinstalling: python-3.7.0-hc3d631a_0 ...
Python 3.7.0
reinstalling: ca-certificates-2018.03.07-0 ...
reinstalling: conda-env-2.6.0-1 ...
reinstalling: libgcc-ng-8.2.0-hdf63c60_1 ...
reinstalling: libstdcxx-ng-8.2.0-hdf63c60_1 ...
reinstalling: libffi-3.2.1-hd88cf55_4 ...
reinstalling: ncurses-6.1-hf484d3e_0 ...
reinstalling: openssl-1.0.2p-h14c3975_0 ...
reinstalling: xz-5.2.4-h14c3975_4 ...
reinstalling: yaml-0.1.7-had09818_2 ...
reinstalling: zlib-1.2.11-ha838b

In [0]:
import math
import os
import random
from typing import Union, List, Dict

import numpy as np
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from sklearn.metrics import mean_squared_error
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

In [3]:
# Mount Google Drive at /content/gdrive (prompts for an authorization code).
# NOTE(review): nothing else in the visible notebook reads from /content/gdrive -
# the CSVs are opened from the working directory - so confirm this cell is needed.
from google.colab import drive
drive.mount('/content/gdrive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/gdrive


In [0]:
class MoleculeDatapoint:
  """One molecule: a SMILES string paired with its list of target values."""

  def __init__(self, smiles: str, targets: List[float]):
    # Both fields are stored exactly as given (no copy, no validation).
    self.smiles, self.targets = smiles, targets
    
class MoleculeDataset:
  """A thin, list-backed collection of MoleculeDatapoint objects."""

  def __init__(self, data: List["MoleculeDatapoint"]):
    # The list is kept by reference (not copied), so shuffle() also reorders
    # the list object the caller passed in.
    self.data = data

  def smiles(self) -> List[str]:
    """Return the SMILES string of every datapoint, in dataset order."""
    return [d.smiles for d in self.data]

  def targets(self) -> List[List[float]]:
    """Return each datapoint's target list, in dataset order (one list per molecule)."""
    return [d.targets for d in self.data]

  def shuffle(self, seed: Union[int, None] = None):
    """Shuffle the datapoints in place.

    Args:
      seed: if given, the shuffle is deterministic for that seed. A private
        `random.Random(seed)` is used so the process-wide `random` state is left
        untouched; the resulting order is the same as seeding the global
        generator and calling `random.shuffle`. If None, the global generator
        is used.
    """
    if seed is None:
      random.shuffle(self.data)
    else:
      random.Random(seed).shuffle(self.data)

  def __len__(self) -> int:
    return len(self.data)

  def __getitem__(self, item) -> Union["MoleculeDatapoint", List["MoleculeDatapoint"]]:
    # Delegates to the underlying list: an int index yields one datapoint,
    # a slice yields a plain list of datapoints (not a MoleculeDataset).
    return self.data[item]

In [0]:
def get_data(split: str) -> MoleculeDataset:
  """Load `delaney_<split>.csv` from the working directory into a MoleculeDataset.

  The first line of the file is treated as a header and skipped. Every other
  non-blank line is split on commas: the first field is the SMILES string and
  all remaining fields are parsed as float targets.

  Args:
    split: name substituted into the file name, e.g. 'train' or 'test'.

  Returns:
    A MoleculeDataset with one MoleculeDatapoint per data row, in file order.

  Raises:
    ValueError: if a target field cannot be parsed as a float.
  """
  data_path = 'delaney_{}.csv'.format(split)
  data = []
  with open(data_path) as f:
    f.readline()  # skip the header row
    for line in f:
      line = line.strip()
      if not line:
        # A blank line would otherwise become a datapoint with an empty
        # SMILES string and no targets.
        continue
      smiles, *targets = line.split(',')
      data.append(MoleculeDatapoint(smiles, [float(target) for target in targets]))

  return MoleculeDataset(data)

In [0]:
train_data, test_data = get_data('train'), get_data('test')

In [0]:
# Allowed values for the element one-hot in atom_features (atomic number - 1).
ELEM_LIST = range(100)
# Hybridization states that get their own one-hot slot.
HYBRID_LIST = [
    Chem.rdchem.HybridizationType.SP,
    Chem.rdchem.HybridizationType.SP2,
    Chem.rdchem.HybridizationType.SP3,
    Chem.rdchem.HybridizationType.SP3D,
    Chem.rdchem.HybridizationType.SP3D2,
]
# Width of the atom feature vector: the sizes of the one-hot groups built in
# atom_features, plus one slot for the aromaticity flag.
ATOM_FDIM = sum((
    len(ELEM_LIST),    # element
    len(HYBRID_LIST),  # hybridization
    6,                 # degree
    5,                 # formal charge
    4,                 # chiral tag
    7,                 # implicit valence
    5,                 # total number of hydrogens
    3,                 # radical electrons
    1,                 # is-aromatic flag
))
# Width of the bond-only feature vector: 6 bond flags + 6 stereo one-hot slots.
BOND_FDIM = 6 + 6
# Number of neighbour slots per row in the agraph/bgraph index tensors.
MAX_NB = 12

def onek_encoding_unk(x, allowable_set):
    """One-hot encode `x` over `allowable_set`, as a list of booleans.

    A value that is not in `allowable_set` is encoded as if it were the last
    element of the set, so the last slot doubles as the "unknown" bucket.
    """
    value = x if x in allowable_set else allowable_set[-1]
    return [value == candidate for candidate in allowable_set]

def atom_features(atom):
    """Build the feature vector of an RDKit atom as a 1-D torch.Tensor.

    The vector is the concatenation, in order, of one-hot encodings of the
    element, degree, formal charge, chiral tag, implicit valence, total number
    of hydrogens, hybridization and radical-electron count, followed by a
    single is-aromatic flag. Values outside a group's allowed set land in that
    group's last slot (see onek_encoding_unk).
    """
    # (value, allowed values) for each one-hot group, in output order.
    one_hot_specs = [
        (atom.GetAtomicNum() - 1, ELEM_LIST),
        (atom.GetDegree(), [0, 1, 2, 3, 4, 5]),
        (atom.GetFormalCharge(), [-1, -2, 1, 2, 0]),
        (int(atom.GetChiralTag()), [0, 1, 2, 3]),
        (int(atom.GetImplicitValence()), [0, 1, 2, 3, 4, 5, 6]),
        (int(atom.GetTotalNumHs()), [0, 1, 2, 3, 4]),
        (int(atom.GetHybridization()), HYBRID_LIST),
        (int(atom.GetNumRadicalElectrons()), [0, 1, 2]),
    ]

    features = []
    for value, allowable_set in one_hot_specs:
        features.extend(onek_encoding_unk(value, allowable_set))
    features.append(atom.GetIsAromatic())

    return torch.Tensor(features)

def bond_features(bond):
    """Build the feature vector of an RDKit bond as a 1-D torch.Tensor.

    Layout: four bond-type flags (single, double, triple, aromatic), the
    is-conjugated flag, the in-ring flag, then a 6-slot one-hot of the stereo
    code.
    """
    bond_type = bond.GetBondType()
    known_types = (
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC,
    )
    type_flags = [bond_type == known for known in known_types]
    topology_flags = [bond.GetIsConjugated(), bond.IsInRing()]
    stereo_one_hot = onek_encoding_unk(int(bond.GetStereo()), [0, 1, 2, 3, 4, 5])
    return torch.Tensor(type_flags + topology_flags + stereo_one_hot)

def mol2graph(mol_batch):
    """Convert a batch of SMILES strings into one batched graph for the MPN.

    All molecules are packed into shared tensors; atom indices are offset by
    the number of atoms in the preceding molecules. Every chemical bond is
    stored twice, once per direction. Bond index 0 is reserved for an all-zero
    padding entry, so real directed bonds are numbered from 1 and a 0 in
    `agraph`/`bgraph` means "no neighbour".

    Args:
        mol_batch: iterable of SMILES strings.

    Returns:
        A tuple `(fatoms, fbonds, agraph, bgraph, scope)`:
        fatoms: stacked atom feature vectors, one row per atom in the batch.
        fbonds: stacked directed-bond features (source-atom features
            concatenated with the bond features); row 0 is the zero padding.
        agraph: long tensor (total_atoms, MAX_NB); row `a` lists the directed
            bonds that end at atom `a`.
        bgraph: long tensor (total_bonds, MAX_NB); row `b` for bond x->y lists
            the directed bonds ending at x, except the reverse bond y->x.
        scope: list of `(start_atom_index, n_atoms)`, one entry per molecule.

    Raises:
        ValueError: if a SMILES string cannot be parsed by RDKit.
    """
    padding = torch.zeros(ATOM_FDIM + BOND_FDIM)
    fatoms, fbonds = [], [padding]  # Ensure bond is 1-indexed
    in_bonds, all_bonds = [], [(-1, -1)]  # Ensure bond is 1-indexed
    scope = []
    total_atoms = 0

    for smiles in mol_batch:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # MolFromSmiles signals a parse failure by returning None; fail
            # here with the offending string instead of an AttributeError below.
            raise ValueError('Could not parse SMILES: {!r}'.format(smiles))
        n_atoms = mol.GetNumAtoms()
        for atom in mol.GetAtoms():
            fatoms.append(atom_features(atom))
            in_bonds.append([])

        for bond in mol.GetBonds():
            # Batch-wide indices of the two atoms this bond connects.
            x = bond.GetBeginAtom().GetIdx() + total_atoms
            y = bond.GetEndAtom().GetIdx() + total_atoms
            bond_feats = bond_features(bond)

            # Directed bond x -> y: incoming to y.
            b = len(all_bonds)
            all_bonds.append((x, y))
            fbonds.append(torch.cat([fatoms[x], bond_feats], 0))
            in_bonds[y].append(b)

            # Directed bond y -> x: incoming to x.
            b = len(all_bonds)
            all_bonds.append((y, x))
            fbonds.append(torch.cat([fatoms[y], bond_feats], 0))
            in_bonds[x].append(b)

        scope.append((total_atoms, n_atoms))
        total_atoms += n_atoms

    total_bonds = len(all_bonds)
    fatoms = torch.stack(fatoms, 0)
    fbonds = torch.stack(fbonds, 0)
    # Zero-initialised, so unused neighbour slots point at the padding bond.
    agraph = torch.zeros(total_atoms, MAX_NB).long()
    bgraph = torch.zeros(total_bonds, MAX_NB).long()

    for a in range(total_atoms):
        for i, b in enumerate(in_bonds[a]):
            agraph[a, i] = b

    for b1 in range(1, total_bonds):
        x, y = all_bonds[b1]
        for i, b2 in enumerate(in_bonds[x]):
            # Skip the reverse bond y -> x; its slot keeps the padding index 0.
            if all_bonds[b2][0] != y:
                bgraph[b1, i] = b2

    return fatoms, fbonds, agraph, bgraph, scope

def index_select_ND(source, dim, index):
    """Gather slices of `source` along `dim` using a multi-dimensional `index`.

    Convenience method for selecting indices, used in MPN.

    Args:
        source: tensor to select from.
        dim: dimension of `source` to index into (negative values count from
            the end).
        index: long tensor of indices, of any shape.

    Returns:
        A tensor in which dimension `dim` of `source` is replaced by the
        dimensions of `index`, i.e. with shape
        `source.shape[:dim] + index.shape + source.shape[dim + 1:]`.
        For `dim == 0` this is `index.shape + source.shape[1:]`.
    """
    if dim < 0:
        dim += source.dim()
    source_size = source.size()
    # Keep the dimensions before `dim` as well as those after it, so the
    # reshape is valid for any `dim`, not just `dim == 0`.
    final_size = source_size[:dim] + index.size() + source_size[dim + 1:]
    target = source.index_select(dim, index.view(-1))
    return target.view(final_size)

In [0]:
# Training hyperparameters.
# Number of passes over the training set made by the training loop.
num_epochs = 30
# Molecules per batch, used for both training and evaluation.
batch_size = 50
# Learning rate handed to the Adam optimizer.
lr = .001

In [0]:
class MPN(nn.Module):
  """Message passing network over directed bonds with a single-value output.

  Messages live on directed bonds. After `depth - 1` update rounds, the
  messages arriving at each atom are summed, combined with the atom features
  into an atom hidden vector, averaged per molecule and mapped to one
  prediction per molecule.

  Args:
    hidden_size: width of bond messages and atom hidden vectors.
    depth: message passing depth; `depth - 1` update rounds follow the
      initial message computation.
    dropout: dropout probability applied to messages and atom hidden
      vectors during training; 0 disables dropout.
  """

  def __init__(self, hidden_size: int = 300, depth: int = 3, dropout: float = 0):
    super(MPN, self).__init__()
    self.hidden_size = hidden_size
    self.depth = depth
    self.dropout = dropout

    # W_i and W_h have no bias, so the all-zero padding bond (index 0) keeps a
    # zero message through every round and contributes nothing to the sums.
    self.W_i = nn.Linear(ATOM_FDIM + BOND_FDIM, self.hidden_size, bias=False)
    self.W_h = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
    self.W_o = nn.Linear(ATOM_FDIM + self.hidden_size, self.hidden_size)
    self.W_pred = nn.Linear(self.hidden_size, 1)

  def forward(self, mol_graph) -> torch.FloatTensor:
    """Predict one value per molecule.

    Args:
      mol_graph: the `(fatoms, fbonds, agraph, bgraph, scope)` tuple produced
        by mol2graph.

    Returns:
      Tensor of shape (number of molecules, 1).
    """
    fatoms,fbonds,agraph,bgraph,scope = mol_graph

    # Initial message of every directed bond.
    binput = self.W_i(fbonds)
    message = F.relu(binput)

    for i in range(self.depth - 1):
        # For each bond, sum the messages of the bonds listed in its bgraph row.
        nei_message = index_select_ND(message, 0, bgraph)
        nei_message = nei_message.sum(dim=1)
        nei_message = self.W_h(nei_message)
        # Skip connection back to the initial bond input.
        message = F.relu(binput + nei_message)
        if self.dropout > 0:
            message = F.dropout(message, self.dropout, self.training)

    # For each atom, sum the messages of its incoming bonds (agraph row).
    nei_message = index_select_ND(message, 0, agraph)
    nei_message = nei_message.sum(dim=1)
    ainput = torch.cat([fatoms, nei_message], dim=1)
    atom_hiddens = F.relu(self.W_o(ainput))
    if self.dropout > 0:
        atom_hiddens = F.dropout(atom_hiddens, self.dropout, self.training)

    # Readout: mean of the atom hidden vectors of each molecule.
    mol_vecs = []
    for st,le in scope:
        if le == 0:
            # A molecule without atoms would give 0 / 0 = NaN; use a zero vector.
            mol_vec = atom_hiddens.new_zeros(self.hidden_size)
        else:
            mol_vec = atom_hiddens.narrow(0, st, le).sum(dim=0) / le
        mol_vecs.append(mol_vec)

    mol_vecs = torch.stack(mol_vecs, dim=0)
    return self.W_pred(mol_vecs)
    

In [0]:
# Build the network with its default settings and an Adam optimizer over all
# of its parameters, using the learning rate `lr` defined earlier.
model = MPN()
optimizer = optim.Adam(model.parameters(), lr=lr)

In [0]:
def param_count(model: nn.Module) -> int:
    """Return the number of scalar parameters of `model` that require gradients."""
    total = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad == False) are not counted.
        if param.requires_grad:
            total += param.numel()
    return total

In [34]:
# Show the layer structure and the number of trainable parameters.
print(model)
print('Number of parameters = {:,}'.format(param_count(model)))

MPN(
  (W_i): Linear(in_features=148, out_features=300, bias=False)
  (W_h): Linear(in_features=300, out_features=300, bias=False)
  (W_o): Linear(in_features=436, out_features=300, bias=True)
  (W_pred): Linear(in_features=300, out_features=1, bias=True)
)
Number of parameters = 265,801


In [0]:
def train_epoch(model: nn.Module,
                optimizer: optim.Optimizer,
                data: MoleculeDataset,
                batch_size: int,
                epoch: int) -> float:
  """Run one optimisation pass over `data` and return the mean per-batch RMSE.

  The dataset is shuffled in place with `epoch` as the seed, then consumed in
  consecutive batches of exactly `batch_size` molecules; a final, incomplete
  batch is dropped. Each batch is trained with MSE loss.

  Args:
    model: network called on the output of mol2graph.
    optimizer: optimizer stepping the parameters of `model`.
    data: training molecules; reordered in place by the shuffle.
    batch_size: number of molecules per batch.
    epoch: epoch index, used as the shuffle seed.

  Returns:
    The average over batches of sqrt(batch MSE). This is a mean of per-batch
    RMSEs, not the RMSE over all molecules.

  Raises:
    ValueError: if `batch_size` is not positive or `data` holds fewer than
      `batch_size` molecules, so that no full batch exists.
  """
  # Validate before touching the model or the data order, so a bad call has no
  # side effects (previously this surfaced as a ZeroDivisionError at the end).
  if batch_size <= 0:
    raise ValueError('batch_size must be positive, got {}'.format(batch_size))
  num_batches = len(data) // batch_size  # drop final, incomplete batch
  if num_batches == 0:
    raise ValueError(
        'Need at least batch_size={} molecules to train, got {}'.format(batch_size, len(data)))

  model.train()
  data.shuffle(seed=epoch)

  total_loss = 0
  for i in range(0, num_batches * batch_size, batch_size):
    batch = MoleculeDataset(data[i:i + batch_size])
    mol_graph, targets = mol2graph(batch.smiles()), batch.targets()

    targets = torch.FloatTensor(targets)

    optimizer.zero_grad()
    preds = model(mol_graph)
    loss = F.mse_loss(preds, targets)
    loss.backward()
    optimizer.step()

    # Accumulate the batch RMSE rather than the raw MSE.
    total_loss += math.sqrt(loss.item())

  return total_loss / num_batches

In [42]:
# Train for `num_epochs` epochs. Both `num_epochs` and `batch_size` come from
# the hyperparameter cell; they are not redefined here so that cell stays the
# single place to change them.
for epoch in range(num_epochs):
  train_loss = train_epoch(model, optimizer, train_data, batch_size, epoch)
  print('Epoch {}: Train loss = {:.4f}'.format(epoch, train_loss))

Epoch 0: Train loss = 2.2651
Epoch 1: Train loss = 1.8405
Epoch 2: Train loss = 1.6957
Epoch 3: Train loss = 1.5715
Epoch 4: Train loss = 1.4402
Epoch 5: Train loss = 1.3467
Epoch 6: Train loss = 1.2430
Epoch 7: Train loss = 1.1616
Epoch 8: Train loss = 1.0977
Epoch 9: Train loss = 1.0230
Epoch 10: Train loss = 0.9596
Epoch 11: Train loss = 0.9158
Epoch 12: Train loss = 0.8961
Epoch 13: Train loss = 0.9190
Epoch 14: Train loss = 0.9347
Epoch 15: Train loss = 0.8680
Epoch 16: Train loss = 0.8183
Epoch 17: Train loss = 0.8056
Epoch 18: Train loss = 0.8221
Epoch 19: Train loss = 0.8608
Epoch 20: Train loss = 0.8132
Epoch 21: Train loss = 0.7522
Epoch 22: Train loss = 0.7444
Epoch 23: Train loss = 0.7356
Epoch 24: Train loss = 0.7350
Epoch 25: Train loss = 0.7550
Epoch 26: Train loss = 0.6763
Epoch 27: Train loss = 0.6779
Epoch 28: Train loss = 0.6808
Epoch 29: Train loss = 0.7232


In [0]:
def rmse(targets: List[float], preds: List[float]) -> float:
    """Return the root-mean-squared error between `targets` and `preds`."""
    mse = mean_squared_error(targets, preds)
    return math.sqrt(mse)

In [0]:
def evaluate(model: nn.Module, data: MoleculeDataset, batch_size: int) -> float:
    """Compute the RMSE of `model` on `data`.

    The model is switched to eval mode and run without gradient tracking over
    consecutive batches of up to `batch_size` molecules. Unlike training, the
    final, incomplete batch is included.

    Args:
        model: network called on the output of mol2graph.
        data: molecules to evaluate on.
        batch_size: maximum number of molecules per forward pass.

    Returns:
        RMSE between the dataset targets and the predictions.
    """
    model.eval()

    all_preds = []
    with torch.no_grad():
        for i in range(0, len(data), batch_size):
            batch = MoleculeDataset(data[i:i + batch_size])
            mol_graph = mol2graph(batch.smiles())

            preds = model(mol_graph)
            # Store plain Python floats (one list per molecule, matching the
            # layout of data.targets()) instead of a list of 1-element tensors
            # that the metric would have to coerce implicitly.
            all_preds.extend(preds.tolist())

    return rmse(data.targets(), all_preds)

In [45]:
# Score the trained model on the held-out test split.
test_rmse = evaluate(model, test_data, batch_size)
print('Test rmse = {:.4f}'.format(test_rmse))

Test rmse = 0.8346
