<a href="https://colab.research.google.com/github/swansonk14/chemprop-intro/blob/master/lab3/message_passing_neural_network.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Message Passing Neural Network on Graph Structure

In [0]:
# Download the Miniconda installer and make it executable.
!wget -c https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
!chmod +x Miniconda3-latest-Linux-x86_64.sh
# Run the installer non-interactively, targeting /usr/local.
!bash ./Miniconda3-latest-Linux-x86_64.sh -b -f -p /usr/local
# Install RDKit and PyTorch into that prefix, using the rdkit channel.
!conda install -q -y --prefix /usr/local -c rdkit rdkit pytorch

# Make the conda-installed packages importable from this kernel.
# NOTE(review): the path below is pinned to python3.6 while the installer above
# is "Miniconda3-latest" — confirm the installed Python version matches.
import sys
sys.path.append('/usr/local/lib/python3.6/site-packages/')

# Fetch the Delaney train/test CSV files into the working directory.
!wget https://raw.githubusercontent.com/swansonk14/chemprop-intro/master/data/delaney_train.csv
!wget https://raw.githubusercontent.com/swansonk14/chemprop-intro/master/data/delaney_test.csv

In [0]:
import math
import os
import random
from typing import Union, List, Dict

import numpy as np
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from sklearn.metrics import mean_squared_error
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

In [0]:
# Mount Google Drive at /content/gdrive (requires the Colab runtime).
# NOTE(review): nothing else in the visible notebook reads from this mount —
# confirm it is still needed.
from google.colab import drive
drive.mount('/content/gdrive')

In [0]:
class MoleculeDatapoint:
  """One molecule: its SMILES string together with its target values."""

  def __init__(self, smiles: str, targets: List[float]):
    """Store `smiles` and `targets` on the instance without copying them."""
    self.smiles, self.targets = smiles, targets
    
class MoleculeDataset:
  """An ordered collection of MoleculeDatapoint objects.

  Wraps a plain Python list and offers column-style accessors
  (`smiles()`, `targets()`), in-place shuffling, `len()` and indexing.
  """

  def __init__(self, data: List['MoleculeDatapoint']):
    """Keep a reference to `data`; the list is not copied."""
    self.data = data

  def smiles(self) -> List[str]:
    """Return the SMILES string of every datapoint, in dataset order."""
    return [d.smiles for d in self.data]

  def targets(self) -> List[List[float]]:
    """Return the target list of every datapoint, in dataset order."""
    return [d.targets for d in self.data]

  def shuffle(self, seed: Union[int, None] = None):
    """Shuffle the datapoints in place.

    Args:
      seed: When given, the permutation is reproducible for that seed. A
        private `random.Random(seed)` is used so the state of the
        module-level `random` generator is left untouched. When omitted,
        the module-level generator is used.
    """
    if seed is None:
      random.shuffle(self.data)
    else:
      random.Random(seed).shuffle(self.data)

  def __len__(self) -> int:
    """Return the number of datapoints."""
    return len(self.data)

  def __getitem__(self, item) -> Union['MoleculeDatapoint', List['MoleculeDatapoint']]:
    """Index into the underlying list.

    An integer yields a single datapoint; a slice yields a plain list of
    datapoints (not a MoleculeDataset).
    """
    return self.data[item]

In [0]:
def get_data(split: str) -> MoleculeDataset:
  """Load one split of the Delaney data from the working directory.

  Reads `delaney_<split>.csv`, skips the first (header) line, and turns each
  remaining non-blank line into a MoleculeDatapoint: the first comma-separated
  field is the SMILES string and every following field is parsed as a float
  target.

  Args:
    split: Name substituted into the file name, e.g. 'train' or 'test'.

  Returns:
    A MoleculeDataset holding the datapoints in file order.

  Raises:
    ValueError: If a target field cannot be parsed as a float.
  """
  data_path = 'delaney_{}.csv'.format(split)
  data = []
  with open(data_path) as f:
    f.readline()  # skip the header row
    for line in f:
      line = line.strip()
      if not line:
        # A blank line would otherwise become a datapoint with an empty
        # SMILES string and no targets.
        continue
      smiles, *targets = line.split(',')
      data.append(MoleculeDatapoint(smiles, [float(target) for target in targets]))

  return MoleculeDataset(data)

In [0]:
# Load the training split first, then the held-out test split.
train_data = get_data('train')
test_data = get_data('test')

In [0]:
# TODO: convert data to graph representation

In [0]:
# Training hyperparameters.
num_epochs = 30  # number of passes over the training data
batch_size = 50  # datapoints per optimizer step
lr = 0.01        # learning rate handed to the optimizer

In [0]:
class MPN(nn.Module):
  """Message passing network skeleton; layers and forward pass are still TODO."""

  def __init__(self):
    # Zero-argument super() binds to this class, so no class name has to be
    # spelled out (and kept in sync) here.
    super().__init__()
    # TODO

  def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
    """Placeholder forward pass: currently returns `x` unchanged."""
    # TODO

    return x

In [0]:
# Instantiate the network defined in this notebook and an SGD optimizer over
# its parameters.
# NOTE(review): optim.SGD is expected to reject an empty parameter list, so
# MPN needs its layers filled in before this cell can run — confirm.
model = MPN()
optimizer = optim.SGD(model.parameters(), lr=lr)

In [0]:
def param_count(model: nn.Module) -> int:
    """Count the scalar entries of all parameters of `model` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [0]:
# Show the module structure, then the number of parameters that require
# gradients (as counted by param_count).
print(model)
print('Number of parameters = {:,}'.format(param_count(model)))

In [0]:
def train_epoch(model: nn.Module,
                optimizer: optim.Optimizer,
                data: MoleculeDataset,
                batch_size: int,
                epoch: int) -> float:
  """Train `model` for one pass over `data` and return the mean batch RMSE.

  The data is shuffled in place with `epoch` as the seed, split into
  consecutive batches of exactly `batch_size` datapoints (a final, incomplete
  batch is dropped), and one optimizer step is taken per batch using MSE loss.

  Args:
    model: Network to train; switched to training mode.
    optimizer: Optimizer that updates the parameters of `model`.
    data: Training set; its order is changed by the shuffle.
    batch_size: Number of datapoints per batch; must be positive.
    epoch: Epoch index, used as the shuffle seed.

  Returns:
    The average over batches of sqrt(MSE loss of the batch).

  Raises:
    ValueError: If `batch_size` is not positive or `data` holds fewer than
      `batch_size` datapoints, i.e. there is no complete batch to train on.
  """
  if batch_size <= 0:
    raise ValueError('batch_size must be positive, got {}'.format(batch_size))

  num_batches = len(data) // batch_size  # drop final, incomplete batch
  if num_batches == 0:
    raise ValueError(
        'Need at least batch_size={} datapoints to train, got {}'.format(batch_size, len(data)))

  model.train()
  data.shuffle(seed=epoch)

  total_loss = 0.0
  for i in range(0, num_batches * batch_size, batch_size):
    batch = MoleculeDataset(data[i:i + batch_size])
    # NOTE(review): MoleculeDataset as defined in this notebook offers
    # smiles() and targets() but no morgans(); this call has to be replaced
    # once the graph-representation TODO is done.
    morgans, targets = batch.morgans(), batch.targets()

    morgans, targets = torch.FloatTensor(morgans), torch.FloatTensor(targets)

    optimizer.zero_grad()
    preds = model(morgans)
    loss = F.mse_loss(preds, targets)
    loss.backward()
    optimizer.step()

    # Accumulate RMSE (not MSE) so the reported value is in target units.
    total_loss += math.sqrt(loss.item())

  return total_loss / num_batches

In [0]:
# Train for `num_epochs` epochs; the value comes from the hyperparameter cell
# so there is a single place to change it.
for epoch in range(num_epochs):
  train_loss = train_epoch(model, optimizer, train_data, batch_size, epoch)
  print('Epoch {}: Train loss = {:.4f}'.format(epoch, train_loss))

In [0]:
def rmse(targets: List[float], preds: List[float]) -> float:
    """Return the square root of the mean squared error between `targets` and `preds`."""
    mse = mean_squared_error(targets, preds)
    return math.sqrt(mse)

In [0]:
def evaluate(model: nn.Module, data: MoleculeDataset, batch_size: int) -> float:
    """Return the RMSE of `model`'s predictions over all of `data`.

    The model is switched to eval mode and run without gradient tracking on
    consecutive batches of up to `batch_size` datapoints (the final batch may
    be smaller). Predictions from all batches are compared with
    `data.targets()` via `rmse`.

    Args:
        model: Network to evaluate.
        data: Dataset to predict on; its order is not changed.
        batch_size: Maximum number of datapoints per forward pass.

    Returns:
        Root-mean-squared error over the whole dataset.
    """
    model.eval()

    all_preds = []
    with torch.no_grad():
        for i in range(0, len(data), batch_size):
            batch = MoleculeDataset(data[i:i + batch_size])
            # NOTE(review): MoleculeDataset as defined in this notebook offers
            # smiles() and targets() but no morgans(); this call has to be
            # replaced once the graph-representation TODO is done.
            morgans = torch.FloatTensor(batch.morgans())

            preds = model(morgans)
            # Collect plain Python numbers instead of tensor rows so the
            # metric receives ordinary nested lists, like data.targets().
            # Assumes preds is (batch, num_targets) — TODO confirm.
            all_preds.extend(preds.tolist())

    return rmse(data.targets(), all_preds)

In [0]:
# Score the trained model on the held-out test split and report the RMSE.
test_rmse = evaluate(model=model, data=test_data, batch_size=batch_size)
print('Test rmse = {:.4f}'.format(test_rmse))