<a href="https://colab.research.google.com/github/yala/deeplearning_bootcamp/blob/master/lab4/bonus_property_prediction_solution_message_passing_neural_network.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction to Message Passing Neural Networks

In this tutorial, we'll take you through developing a message passing neural network (MPNN), which is a type of neural network that operates on graphs. We'll then show you how to use an MPNN to predict the properties of molecules.

Let's get started!

# Preliminaries

The next few sections will set up the necessary components of the tutorial, including:


1.   Installing PyTorch and RDKit
2.   Importing dependencies
3.   Downloading and processing data
4.   Defining training and evaluation procedures

## Download PyTorch

In [0]:
# http://pytorch.org/
from os import path
from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
# Build the platform tag (implementation + version, then ABI) that is
# substituted into the wheel file name in the download URL below
platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())

# Pick the 'cu80' wheel directory when nvidia-smi exists at this path,
# otherwise fall back to the 'cpu' directory
accelerator = 'cu80' if path.exists('/opt/bin/nvidia-smi') else 'cpu'

# Install the pinned torch 0.4.0 wheel for this platform plus torchvision 0.2.0
!pip install -q http://download.pytorch.org/whl/{accelerator}/torch-0.4.0-{platform}-linux_x86_64.whl torchvision==0.2.0
# Sanity check: print the installed version and whether CUDA is available
import torch
print(torch.__version__)
print(torch.cuda.is_available())

## Download RDKit

RDKit is a Python cheminformatics package which makes it easy to work with molecules.

In [0]:
!wget -c https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
!chmod +x Miniconda3-latest-Linux-x86_64.sh
!bash ./Miniconda3-latest-Linux-x86_64.sh -b -f -p /usr/local
!conda install -q -y --prefix /usr/local -c conda-forge rdkit rdkit scikit-learn

## Imports

In [0]:
import math
import os
import pickle
import random
import re
import sys
from typing import List, Optional, Tuple

sys.path.append('/usr/local/lib/python3.7/site-packages/')

import numpy as np
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from sklearn.metrics import mean_squared_error
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import trange

## Download and Process Data

In [0]:
# Download data
# Fetches the train / validation / test CSV splits; get_data() later opens
# them by the file names delaney_train.csv, delaney_val.csv, delaney_test.csv
!wget https://raw.githubusercontent.com/yala/introML_chem/master/lab1/data/chem/delaney_train.csv
!wget https://raw.githubusercontent.com/yala/introML_chem/master/lab1/data/chem/delaney_val.csv
!wget https://raw.githubusercontent.com/yala/introML_chem/master/lab1/data/chem/delaney_test.csv

In [0]:
# Define Datapoint and Dataset classes
class MoleculeDatapoint:
  """A single molecule: its SMILES string together with its target values."""

  def __init__(self, smiles: str, targets: List[float]):
    self.smiles, self.targets = smiles, targets

  def __str__(self):
    # Render as one comma-separated row: the SMILES first, then every target
    joined_targets = ','.join(map(str, self.targets))
    return f'{self.smiles},{joined_targets}'
    
class MoleculeDataset:
  """An ordered, in-memory collection of MoleculeDatapoint objects."""

  def __init__(self, data: List['MoleculeDatapoint']):
    self.data = data

  def smiles(self) -> List[str]:
    """Return the SMILES string of every datapoint, in dataset order."""
    return [d.smiles for d in self.data]

  def targets(self) -> List[List[float]]:
    """Return the target list of every datapoint, in dataset order."""
    return [d.targets for d in self.data]

  def shuffle(self, seed: Optional[int] = None):
    """Shuffle the datapoints in place.

    With a seed, the permutation is reproducible and is drawn from a private
    generator, so the module-level `random` state is left untouched. Without
    a seed, the shared module-level generator is used.
    """
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(self.data)

  def __len__(self) -> int:
    return len(self.data)

  def __getitem__(self, item):
    """Index into the underlying list.

    An integer index yields one datapoint; a slice yields a plain list of
    datapoints (not a MoleculeDataset).
    """
    return self.data[item]
  
def get_data(split: str) -> MoleculeDataset:
  """Load one data split from `delaney_<split>.csv` in the working directory.

  The first line is treated as a header and skipped. In every other row the
  first comma-separated field is the SMILES string and all remaining fields
  are parsed as float targets. Blank lines are ignored.

  Raises:
    OSError: if the file cannot be opened.
    ValueError: if a target field cannot be parsed as a float.
  """
  data_path = f'delaney_{split}.csv'
  data = []
  with open(data_path) as f:
    f.readline()  # discard the header row
    for line in f:
      line = line.strip()
      if not line:
        # A blank (often trailing) line would otherwise become a datapoint
        # with an empty SMILES string and no targets
        continue
      smiles, *targets = line.split(',')
      data.append(MoleculeDatapoint(smiles, [float(target) for target in targets]))

  return MoleculeDataset(data)

In [0]:
# Load the three splits in the order train, val, test
train_data, val_data, test_data = (get_data(split) for split in ('train', 'val', 'test'))

# Report how many molecules each split holds
print(f'Num train = {len(train_data):,}')
print(f'Num val = {len(val_data):,}')
print(f'Num test = {len(test_data):,}')
print()
# Show the first training datapoint as a sample row
print('Example data point', train_data[0], sep='\n')

## Model and Training Settings

In [0]:
batch_size = 50  # molecules per batch, used for both training and evaluation
num_epochs = 10  # number of passes over the training split
lr = 1e-3  # learning rate handed to the Adam optimizer
weight_decay = 1e-4  # weight decay handed to the Adam optimizer
hidden_size = 300  # width of the MPNN's hidden linear layers
depth = 3  # MPNN depth; its forward pass runs depth - 1 message passing updates
output_size = 1  # do not modify
dropout = 0.0  # dropout probability used inside the MPNN

## Utility Functions

In [0]:
def rmse(targets: List[float], preds: List[float]) -> float:
    """Return the root-mean-squared error between `targets` and `preds`."""
    mse = mean_squared_error(targets, preds)
    return math.sqrt(mse)
  
def param_count(model: nn.Module) -> int:
    """Count the scalar entries of every parameter of `model` that requires gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

## Training Procedure

In [0]:
def train_epoch(model: nn.Module,
                optimizer: optim.Optimizer,
                data: MoleculeDataset,
                batch_size: int,
                epoch: int) -> float:
  """Train `model` for one pass over `data` and return the mean per-batch RMSE.

  The dataset is shuffled in place with `epoch` as the seed, then consumed in
  consecutive batches of exactly `batch_size` molecules; a trailing batch with
  fewer molecules is dropped.

  Args:
    model: network called on the output of mol2graph for each batch.
    optimizer: optimizer stepped once per batch.
    data: training dataset (shuffled in place as a side effect).
    batch_size: number of molecules per batch.
    epoch: epoch number, used as the shuffle seed.

  Returns:
    The average over batches of sqrt(MSE loss), or NaN when the dataset holds
    fewer than `batch_size` molecules and therefore no batch is run.
  """
  model.train()
  data.shuffle(seed=epoch)

  data_size = len(data) // batch_size * batch_size  # drop final, incomplete batch
  if data_size == 0:
    # No complete batch exists, so there is no loss to average; returning NaN
    # avoids dividing by a zero batch count below
    return float('nan')

  total_loss = 0.0
  num_batches = 0

  for i in trange(0, data_size, batch_size):
    # 1) Get batch (slicing the dataset yields a plain list of datapoints)
    batch = MoleculeDataset(data[i:i + batch_size])

    # Convert SMILES strings to molecular graphs
    mol_graph, targets = mol2graph(batch.smiles()), batch.targets()

    # Cast targets to FloatTensor
    targets = torch.FloatTensor(targets)

    # Reset gradient data to 0
    optimizer.zero_grad()

    # Get prediction for batch
    preds = model(mol_graph)

    # 2) Compute loss
    loss = F.mse_loss(preds, targets)

    # 3) Do backprop
    loss.backward()

    # 4) Update model
    optimizer.step()

    # Do book-keeping to track loss (square root turns the batch MSE into RMSE)
    total_loss += math.sqrt(loss.item())
    num_batches += 1

  return total_loss / num_batches

## Evaluation Procedure

In [0]:
def evaluate(model: nn.Module, data: MoleculeDataset, batch_size: int) -> float:
  """Return the RMSE of `model`'s predictions over every molecule in `data`.

  The model is switched to eval mode and run without gradient tracking on
  consecutive batches of up to `batch_size` molecules. Unlike training, the
  final, smaller batch is included.
  """
  model.eval()

  all_preds = []
  with torch.no_grad():
    for i in range(0, len(data), batch_size):
      batch = MoleculeDataset(data[i:i + batch_size])
      mol_graph = mol2graph(batch.smiles())

      preds = model(mol_graph)
      # Collect plain Python floats, one row per molecule, so the metric
      # receives nested lists shaped like data.targets() rather than a list
      # of tensor rows
      all_preds.extend(preds.tolist())

  return rmse(data.targets(), all_preds)

# Message Passing Neural Networks (MPNNs)

Message passing neural networks (MPNNs) are neural networks that are defined to operate on graph input. This makes them ideal for working with molecules, which can be represented as a graph where atoms are nodes and bonds are edges.

<img src="https://github.com/yala/introML_chem/raw/master/lab4/message_passing.png">

MPNNs work in two phases: a *message passing phase* and a *readout phase*.

**Message Passing Phase**

During the message passing phase, each bond's representation (or "message") is updated based on the output of a neural network applied to the sum of the neighboring bonds' representations. This process is repeated a number of times, causing information to flow across the graph and allowing each bond to become aware of the surrounding bonds and its local chemistry.

<img src="https://github.com/yala/introML_chem/raw/master/lab4/bond_message_passing.png">

**Readout Phase**

During the readout phase, the final bond representations are combined to produce a single representation for the entire molecule. This molecular representation is then passed through a feed-forward neural network to make the final property prediction.

## Molecular Featurization

In order to run an MPNN on a molecule, we first need to featurize the individual atoms and bonds. The code below loops through each atom and bond and extracts features such as atomic number or bond type.

Then, it collects all the atom and bond features and combines them into PyTorch tensors. The features for different molecules are then concatenated to create a tensor with a batch of atom and bond features across multiple molecules.

In [0]:
# Define enumerations of possible features
# One-hot slots for (atomic number - 1), i.e. the values 0..99
ELEM_LIST = range(100)
# Hybridization states that get their own one-hot slot
HYBRID_LIST = [
    Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
    Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
    Chem.rdchem.HybridizationType.SP3D2
]
# Length of an atom feature vector. The terms are the sizes of the one-hot
# groups concatenated in atom_features: 100 element slots, the hybridization
# slots, 6 degrees, 5 formal charges, 4 chiral tags, 7 implicit valences,
# 5 hydrogen counts, 3 radical-electron counts, and 1 aromaticity flag
ATOM_FDIM = 100 + len(HYBRID_LIST) + 6 + 5 + 4 + 7 + 5 + 3 + 1
# Length of a bond feature vector as built in bond_features:
# 6 type/conjugation/ring flags + 6 stereo slots
BOND_FDIM = 6 + 6
# Number of neighbor slots per row in the agraph / bgraph index tensors
# allocated by mol2graph
MAX_NB = 12

# Creates a one-hot list of booleans over allowable_set; a value that is not
# in the set is encoded as the set's last entry
def onek_encoding_unk(x: int, allowable_set: List[int]):
    # Unknown values share the final slot with the last allowable value
    value = x if x in allowable_set else allowable_set[-1]
    return [value == candidate for candidate in allowable_set]

# Creates a feature vector for an atom
def atom_features(atom: Chem.rdchem.Atom) -> torch.Tensor:
    # One-hot groups, concatenated in exactly this order to form the vector
    groups = [
        onek_encoding_unk(atom.GetAtomicNum() - 1, ELEM_LIST),
        onek_encoding_unk(atom.GetDegree(), [0, 1, 2, 3, 4, 5]),
        onek_encoding_unk(atom.GetFormalCharge(), [-1, -2, 1, 2, 0]),
        onek_encoding_unk(int(atom.GetChiralTag()), [0, 1, 2, 3]),
        onek_encoding_unk(int(atom.GetImplicitValence()), [0, 1, 2, 3, 4, 5, 6]),
        onek_encoding_unk(int(atom.GetTotalNumHs()), [0, 1, 2, 3, 4]),
        onek_encoding_unk(int(atom.GetHybridization()), HYBRID_LIST),
        onek_encoding_unk(int(atom.GetNumRadicalElectrons()), [0, 1, 2]),
        [atom.GetIsAromatic()],  # single aromaticity flag
    ]

    # Flatten the groups into one list before converting to a tensor
    features = []
    for group in groups:
        features.extend(group)
    return torch.Tensor(features)

# Creates a feature vector for a bond
def bond_features(bond: Chem.rdchem.Bond) -> torch.Tensor:
    kind = bond.GetBondType()
    stereo_code = int(bond.GetStereo())

    # One flag per recognised bond type
    bond_types = Chem.rdchem.BondType
    type_flags = [
        kind == candidate
        for candidate in (bond_types.SINGLE, bond_types.DOUBLE,
                          bond_types.TRIPLE, bond_types.AROMATIC)
    ]
    # Conjugation and ring-membership flags
    topology_flags = [bond.GetIsConjugated(), bond.IsInRing()]
    # One-hot stereo code; codes outside 0..5 fall back to the last slot
    stereo_flags = onek_encoding_unk(stereo_code, [0, 1, 2, 3, 4, 5])

    return torch.Tensor(type_flags + topology_flags + stereo_flags)

# Creates feature vectors and adjacency matrices for a batch of molecules
def mol2graph(mol_batch: List[str]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[Tuple[int, int]]]:
    """Featurize a batch of SMILES strings as one combined graph.

    Atoms of all molecules share one index space (offset by the running atom
    count). Every chemical bond is stored as two directed bonds, one per
    direction.

    Returns:
        fatoms: stacked atom feature vectors, one row per atom in the batch.
        fbonds: stacked directed-bond feature vectors (features of the bond's
            source atom concatenated with the bond features); row 0 is an
            all-zero padding row.
        agraph: long tensor of shape (total_atoms, MAX_NB); row a holds the
            indices of the directed bonds that end at atom a, zero-padded.
        bgraph: long tensor of shape (total_bonds, MAX_NB); row b holds the
            indices of directed bonds that end at b's source atom, except the
            reverse of b itself; all other slots stay 0.
        scope: one (first_atom_index, n_atoms) pair per molecule.
    """
    padding = torch.zeros(ATOM_FDIM + BOND_FDIM)
    fatoms,fbonds = [],[padding]  # Ensure bond is 1-indexed
    in_bonds,all_bonds = [],[(-1,-1)]  # Ensure bond is 1-indexed
    scope = []
    total_atoms = 0

    for smiles in mol_batch:
        # NOTE(review): if MolFromSmiles cannot parse the string it presumably
        # returns None and the next line fails — confirm and validate upstream
        mol = Chem.MolFromSmiles(smiles)
        n_atoms = mol.GetNumAtoms()
        for atom in mol.GetAtoms():
            fatoms.append( atom_features(atom) )
            in_bonds.append([])  # incoming directed-bond indices for this atom

        for bond in mol.GetBonds():
            a1 = bond.GetBeginAtom()
            a2 = bond.GetEndAtom()
            # Shift per-molecule atom indices into the batch-wide index space
            x = a1.GetIdx() + total_atoms
            y = a2.GetIdx() + total_atoms

            # Directed bond x -> y: carries x's atom features, is incoming to y
            b = len(all_bonds) 
            all_bonds.append((x,y))
            fbonds.append( torch.cat([fatoms[x], bond_features(bond)], 0) )
            in_bonds[y].append(b)

            # Directed bond y -> x: carries y's atom features, is incoming to x
            b = len(all_bonds)
            all_bonds.append((y,x))
            fbonds.append( torch.cat([fatoms[y], bond_features(bond)], 0) )
            in_bonds[x].append(b)

        scope.append((total_atoms,n_atoms))
        total_atoms += n_atoms

    total_bonds = len(all_bonds)  # includes the padding entry at index 0
    fatoms = torch.stack(fatoms, 0)
    fbonds = torch.stack(fbonds, 0)
    # Index 0 doubles as "no neighbor" because both tensors start as zeros and
    # bond index 0 is the padding entry
    agraph = torch.zeros(total_atoms,MAX_NB).long()
    bgraph = torch.zeros(total_bonds,MAX_NB).long()

    # agraph[a, i] = index of the i-th directed bond ending at atom a.
    # An atom with more than MAX_NB incoming bonds would index past the row width.
    for a in range(total_atoms):
        for i,b in enumerate(in_bonds[a]):
            agraph[a,i] = b

    # For each real directed bond b1 = (x, y), record the bonds ending at x,
    # skipping the one that starts at y (the reverse of b1); the skipped
    # slot keeps its 0 value
    for b1 in range(1, total_bonds):
        x,y = all_bonds[b1]
        for i,b2 in enumerate(in_bonds[x]):
            if all_bonds[b2][0] != y:
                bgraph[b1,i] = b2

    return fatoms, fbonds, agraph, bgraph, scope

## Define MPNN

In [0]:
class MPNN(nn.Module):
  """Message passing network over directed bonds with a per-molecule readout.

  Bond features are embedded, refined by `depth - 1` rounds of message
  passing, gathered onto atoms, averaged per molecule and finally mapped to
  `output_size` values by a linear layer.
  """

  def __init__(self, hidden_size, depth, output_size, dropout):
    super().__init__()
    self.depth = depth

    # Bond embedding: (atom features + bond features) -> hidden, no bias
    bond_input_size = ATOM_FDIM + BOND_FDIM
    self.W_i = nn.Linear(bond_input_size, hidden_size, bias=False)

    # Transform applied to summed neighbor messages in every round, no bias
    self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)

    # Readout layer: (atom features + gathered messages) -> hidden
    self.W_o = nn.Linear(ATOM_FDIM + hidden_size, hidden_size)

    # Final prediction layer
    self.W_pred = nn.Linear(hidden_size, output_size)

    # Shared activation and dropout modules
    self.relu = nn.ReLU()
    self.dropout = nn.Dropout(dropout)

  def index_select_ND(self, source: torch.Tensor, dim: int, index: torch.Tensor):
    """Gather rows of `source` for every entry of `index`.

    The result has the shape of `index` followed by the trailing dimensions
    of `source` (everything after its first dimension).
    """
    flat_index = index.view(-1)
    gathered = source.index_select(dim, flat_index)
    return gathered.view(index.size() + source.size()[1:])

  def forward(self, mol_graph) -> torch.FloatTensor:
    """Return one prediction row per molecule for the 5-tuple `mol_graph`."""
    fatoms, fbonds, agraph, bgraph, scope = mol_graph

    # Initial bond embedding; it is re-added in every message passing round
    bond_embedding = self.W_i(fbonds)
    messages = self.relu(bond_embedding)

    # --- Message passing phase ---
    for _ in range(self.depth - 1):
      # Sum the messages of the bonds listed in each bond's bgraph row
      incoming = self.index_select_ND(messages, 0, bgraph).sum(dim=1)
      # Transform the sum, add the initial embedding, activate, apply dropout
      messages = self.dropout(self.relu(bond_embedding + self.W_h(incoming)))

    # --- Readout phase ---
    # Sum the messages of the bonds listed in each atom's agraph row
    atom_messages = self.index_select_ND(messages, 0, agraph).sum(dim=1)
    atom_input = torch.cat([fatoms, atom_messages], dim=1)
    atom_hiddens = self.dropout(self.relu(self.W_o(atom_input)))

    # Average the atom vectors of each molecule (sum divided by atom count)
    mol_vecs = torch.stack(
        [atom_hiddens.narrow(0, start, size).sum(dim=0) / size
         for start, size in scope],
        dim=0,
    )

    return self.W_pred(mol_vecs)

## Build MPNN

In [0]:
# Instantiate the network from the notebook-level settings
model = MPNN(hidden_size=hidden_size, depth=depth, output_size=output_size, dropout=dropout)

# Show the layer structure and the number of trainable parameters
print(model)
print(f'Number of parameters = {param_count(model):,}')

# Adam over all model parameters with the configured learning rate and weight decay
optimizer = optim.Adam(params=model.parameters(), lr=lr, weight_decay=weight_decay)

## Train MPNN

In [0]:
# Run num_epochs training passes; after each one, score the model on the
# validation split and report both numbers
for epoch in range(1, num_epochs + 1):
  # The epoch number is also forwarded as the shuffle seed inside train_epoch
  train_loss = train_epoch(model, optimizer, train_data, batch_size, epoch)
  val_rmse = evaluate(model, val_data, batch_size)
  print(f'Epoch {epoch}: Train loss = {train_loss:.4f}, Val rmse = {val_rmse:.4f}')

## Test MPNN

In [0]:
# Score the trained model on the test split and report its RMSE
test_rmse = evaluate(model, test_data, batch_size)
print(f'Test rmse = {test_rmse:.4f}')