# Message Passing network
This code is based on the Keras example at https://keras.io/examples/graph/mpnn-molecular-graphs/; it has since been refactored and updated.

We implement an MPNN based on the original paper [Neural Message Passing for Quantum Chemistry](https://arxiv.org/abs/1704.01212) and on [DeepChem's MPNNModel](https://deepchem.readthedocs.io/en/latest/api_reference/models.html#mpnnmodel).

In [None]:
import sys
import os
import numpy as np
import warnings
from rdkit import RDLogger
from rdkit.Chem.Draw import IPythonConsole

# Silence noisy output for the rest of the session (nothing below re-enables
# it): all Python warnings and every RDKit "rdApp" log channel.
warnings.simplefilter("ignore")
RDLogger.DisableLog("rdApp.*")
# Silence TensorFlow's C++ logging. This presumably only takes effect if set
# before TensorFlow is first imported (likely via the src.models imports in
# the next cell) -- keep this cell above them.
os.environ.update(TF_CPP_MIN_LOG_LEVEL="3")

# Seed NumPy's global RNG so the data shuffle below is reproducible on a
# top-to-bottom run, and make the project's `src` package (one directory up)
# importable.
np.random.seed(42)
sys.path.extend(["../"])

In [None]:
from src.data import load_data
from src.features import smiles_bio
from src.models.mpnn import MPNNModel
from src.models.mpnn_trainer import MPNNTrainer
from src.visualisation.smiles_vis import visualise_molecule
from src.validation.mpnn_val import look_up

# Data

In [None]:
# Load the dataset used throughout the notebook. Later cells read its `smiles`
# column (model input) and `p_np` column (target label).
df = load_data.load_csv_bbbp()

In [None]:
# Quick sanity check of the loaded table: columns, dtypes and non-null counts
# (assumes df is a pandas DataFrame, as its .iloc/.shape usage suggests).
df.info()

In [None]:
# Preview rows at positions 55-59 (iloc slicing is positional and end-exclusive);
# the bare expression makes Jupyter render the result.
df.iloc[55:60]

# Features

## SMILES to graph

In [None]:
# Allowed values for each atom-level feature. How values outside these sets
# are encoded is decided inside smiles_bio.AtomFeaturizer (not visible here).
atom_featurizer = smiles_bio.AtomFeaturizer(
    allowable_sets=dict(
        symbol={"B", "Br", "C", "Ca", "Cl", "F", "H", "I", "N", "Na", "O", "P", "S"},
        n_valence=set(range(7)),
        n_hydrogens=set(range(5)),
        hybridization={"s", "sp", "sp2", "sp3"},
    )
)

# Allowed values for each bond-level feature: bond type and conjugation flag.
bond_featurizer = smiles_bio.BondFeaturizer(
    allowable_sets=dict(
        bond_type={"single", "double", "triple", "aromatic"},
        conjugated={True, False},
    )
)

In [None]:
# Shuffle the row positions 0 .. len(df) - 1 to drive the train/valid/test
# split in the cells below.
# A dedicated RandomState seeded with 42 (the same seed passed to the global
# np.random.seed in the setup cell) makes this cell idempotent: re-running it
# reproduces the same permutation. Drawing from the already-advanced global
# RNG would instead reshuffle, and re-running only some of the split cells
# would then let train/valid/test overlap.
# NOTE(review): on a clean top-to-bottom run this matches the old global-RNG
# draw only if nothing between the seeding cell and here consumed np.random.
permuted_indices = np.random.RandomState(42).permutation(df.shape[0])

In [None]:
# Training split: the first 80 % of the shuffled row positions.
train_index = permuted_indices[: int(len(df) * 0.8)]
train_rows = df.iloc[train_index]
x_train = smiles_bio.graphs_from_smiles(train_rows.smiles, atom_featurizer, bond_featurizer)
y_train = train_rows.p_np

In [None]:
# Validation split: shuffled positions from the 80 % mark up to the 99 % mark
# (about 19 % of the rows).
valid_index = permuted_indices[int(len(df) * 0.8) : int(len(df) * 0.99)]
valid_rows = df.iloc[valid_index]
x_valid = smiles_bio.graphs_from_smiles(valid_rows.smiles, atom_featurizer, bond_featurizer)
y_valid = valid_rows.p_np

In [None]:
# Test split: the remaining ~1 % of shuffled positions (from the 99 % mark to
# the end). NOTE: this is very small, so metrics on it will be noisy; below it
# is only passed to look_up.
test_index = permuted_indices[int(len(df) * 0.99) :]
test_rows = df.iloc[test_index]
x_test = smiles_bio.graphs_from_smiles(test_rows.smiles, atom_featurizer, bond_featurizer)
y_test = test_rows.p_np

In [None]:
# Build an example molecule from df for inspection; it is reused below by
# graph_from_molecule. NOTE(review): confirm whether 100 is a row position
# or an index label in visualise_molecule.
molecule = visualise_molecule(df, 100)
# Bare expression so Jupyter renders the molecule.
molecule

In [None]:
# Featurize the example molecule and report the shape of each graph component
# (index 0: atom features, 1: bond features, 2: pair indices).
graph = smiles_bio.graph_from_molecule(molecule, atom_featurizer, bond_featurizer)
print("Graph (including self-loops):")
for position, label in enumerate(("atom features", "bond features", "pair indices")):
    print(f"\t{label}\t", graph[position].shape)

## tf.data.Dataset

In [None]:
# Wrap each (graphs, labels) split in an MPNNDataset, in train/valid/test
# order (per the section heading, presumably a tf.data.Dataset).
train_dataset, valid_dataset, test_dataset = (
    smiles_bio.MPNNDataset(features, labels)
    for features, labels in ((x_train, y_train), (x_valid, y_valid), (x_test, y_test))
)

# Model

In [None]:
# Build the trainer from the training graphs. Presumably x_train is only used
# to infer input feature sizes -- confirm in src.models.mpnn_trainer.
trainer = MPNNTrainer(x_train)

In [None]:
# Print the architecture of the model wrapped by the trainer
# (presumably a Keras model, given the .summary() call).
trainer.model.summary()

In [None]:
# Train for 40 epochs; valid_dataset is passed for validation. `history` is
# presumably a Keras History object -- check MPNNTrainer.train.
history = trainer.train(train_dataset, valid_dataset, epochs=40)

# Validation

In [None]:
# Inspect the trained model on the held-out test split: look_up returns the
# test molecules, their legends and a grid image (presumably predictions vs.
# labels -- see src.validation.mpnn_val).
molecules, legends, grid = look_up(trainer, df, test_dataset, test_index)
# Bare expression so Jupyter renders the grid.
grid

In [None]:
# Write the grid image to molecules.png in the current working directory.
# NOTE(review): requires `grid` to expose .save (e.g. a PIL Image); if
# look_up returns an SVG/IPython display object this call will fail -- confirm.
grid.save("molecules.png")