This notebook demonstrates a pipeline for training a Graph Neural Network to predict permeability across the blood-brain barrier (BBB), based on the molecular graphs of small drug molecules.

We demonstrate how to run the modules, show some outputs, and explain some design choices in the accompanying markdown text. See the imported modules for the detailed implementation.

The 5 sections of the notebook are:

1. Curate and process the molecule graphs
2. Define and initialize Message Passing Neural Network (MPNN)
3. Train the MPNN, interpret results
4. Cluster analysis of the molecules based on training output
5. Discussion and Next Steps

In [None]:
import torch
import torch.nn as nn
from torch_geometric.loader import DataLoader  # torch_geometric.data.DataLoader is deprecated since PyG 2.0
import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE
import umap
import os

# Custom imports (my own modules)
from DataLoading import MoleculeGraphProcessor, MPNNDataset     # <-- dataset class
from MPNN import MPNNModel                                      # <-- MPNN architecture
from Training import MPNNTrainer                                # <-- trainer class
from ClusterAnalysis import MoleculeClusterAnalysis             # <-- t-SNE / UMAP class

1. Curate and process the molecule graphs

The PAMPA assay is a reliable predictor of blood-brain barrier (BBB) permeability for small molecules (see https://doi.org/10.3389/fphar.2023.1291246). From the PubChem database, we obtain a dataset of PAMPA measurements for 438 small molecules, together with their SMILES descriptors (PubChem AID: 1845228).

We use RDKit to convert the SMILES into molecular graphs, recording 7 node features for each atom and 4 edge features for each bond. The node features include: atomic number, hybridization, formal charge, aromaticity, number of attached hydrogen atoms, and total valence. The edge features are: bond type (as double), conjugation, ring membership, and aromaticity.
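
To make the graph layout concrete, here is a hand-written sketch for ethanol (SMILES "CCO"), using only the six node features named above and the four edge features. It does not call RDKit or `MoleculeGraphProcessor`. The integer encoding of hybridization and the choice to store each bond as two directed edges are assumptions of this sketch, and the actual processor may encode things differently.

```python
# Hand-written graph for ethanol (SMILES "CCO"), heavy atoms only.
# Node feature order (assumed for this sketch):
# [atomic number, hybridization, formal charge, aromatic, attached H, total valence]
# Hybridization is encoded as an integer here (3 = sp3); this is an assumption.
node_features = [
    [6, 3, 0, 0, 3, 4],  # CH3 carbon
    [6, 3, 0, 0, 2, 4],  # CH2 carbon
    [8, 3, 0, 0, 1, 2],  # hydroxyl oxygen
]

# Each bond as (atom i, atom j, [bond type, conjugated, in ring, aromatic]).
bonds = [
    (0, 1, [1.0, 0, 0, 0]),  # C-C single bond
    (1, 2, [1.0, 0, 0, 0]),  # C-O single bond
]

# Store every undirected bond as two directed edges so that
# information can later flow in both directions along the bond.
edge_index = [[], []]  # row 0: source atoms, row 1: target atoms
edge_features = []
for i, j, feats in bonds:
    for src, dst in ((i, j), (j, i)):
        edge_index[0].append(src)
        edge_index[1].append(dst)
        edge_features.append(list(feats))

print(edge_index)          # [[0, 1, 1, 2], [1, 0, 2, 1]]
print(len(edge_features))  # 4
```

With this layout, a molecule with N atoms and B bonds gives a node feature table with N rows and an edge feature table with 2B rows.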

For each molecular graph, we also attach the PAMPA permeability measurement as the training target. We print a summary of the dataset after processing.

In [None]:
# Name of the folder containing the .csv files of PAMPA assays
folder_path = "PAMPAassays"  # This folder should be in the same directory as the script

# Initialize the processor and process the molecules
processor = MoleculeGraphProcessor(folder_path)
molecule_graphs, targets = processor.process_molecules()

# Create a dataset for training MPNN and save it
dataset = MPNNDataset(molecule_graphs, targets)
dataset.save("PAMPA_dataset.pt")

print(f"Processed {len(molecule_graphs)} molecules into molecular graphs, and saved PAMPA permeability metric.\n")
dataset.summary()

2. Define and initialize MPNN

A Message Passing Neural Network (MPNN) is a type of Graph Neural Network that uses both node features and edge features, which makes it suitable for distinguishing and learning the various chemical bonds in molecular graphs. Message passing refers to passing information from all neighbors of a node to that node; it is an aggregation step, similar to convolution.
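
As a rough illustration of one round of message passing, the sketch below uses scalar node and edge features on a three-atom chain. The `message_fn` and `update_fn` arguments are simple stand-ins for the two MLPs of a real MPNN layer, and sum aggregation is an assumption of this sketch, not necessarily what `MPNNModel` does.

```python
def message_passing_step(h, edge_index, edge_attr, message_fn, update_fn):
    """One round of message passing with sum aggregation.

    h          : list of node states (scalars here, vectors in a real model)
    edge_index : [sources, targets], one entry per directed edge
    edge_attr  : one feature per directed edge
    """
    aggregated = [0.0] * len(h)
    sources, targets = edge_index
    for src, dst, e in zip(sources, targets, edge_attr):
        # The message depends on the sending node and on the edge (bond).
        aggregated[dst] += message_fn(h[src], e)
    # Each node combines its old state with the sum of incoming messages.
    return [update_fn(h_i, m_i) for h_i, m_i in zip(h, aggregated)]


# Chain 0 - 1 - 2, each bond stored as two directed edges.
h = [1.0, 2.0, 3.0]
edge_index = [[0, 1, 1, 2], [1, 0, 2, 1]]
edge_attr = [1.0, 1.0, 2.0, 2.0]

h_new = message_passing_step(
    h, edge_index, edge_attr,
    message_fn=lambda h_src, e: h_src * e,  # stand-in for the message MLP
    update_fn=lambda h_old, m: h_old + m,   # stand-in for the update MLP
)
print(h_new)  # [3.0, 9.0, 7.0]
```

Node 1 receives messages from both neighbors (1.0 from node 0 and 6.0 from node 2), which is why its state changes the most. Because the edge feature enters the message, two bonds of different types contribute differently even when the neighboring atoms are identical.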

The architecture of our MPNN is as follows:

> linear layers for node and edge embedding

> MPNN layer 1, containing one MLP (multilayer perceptron) for generating the messages from neighbors and another MLP for updating the node embedding

> MPNN layer 2, same weights as MPNN layer 1

> MPNN layer 3, same weights as MPNN layer 1

> graph-level pooling (summarizing all nodes into a single graph embedding; we use global_add_pool here)

> MLP (a small MLP that reduces the graph embedding to 1 output)

An MPNN model is initialized through our MPNNModel class. The hidden_dim parameter is the dimension of the node and edge embeddings, which are learned high-dimensional vectors describing the properties of nodes and edges. The num_layers parameter sets how many rounds of message passing are performed, and the out_dim parameter is the dimension of the output, which here is the single predicted permeability value.
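
The graph-level pooling step in the architecture above sums the node embeddings of each graph into one vector per graph. A minimal pure-Python sketch of that idea follows; the helper name `add_pool` is hypothetical, and the real model relies on `global_add_pool` from PyTorch Geometric.

```python
def add_pool(node_embeddings, batch):
    """Sum node embeddings per graph.

    node_embeddings : one embedding (list of floats) per node
    batch           : batch[i] is the index of the graph that node i belongs to
    """
    num_graphs = max(batch) + 1
    dim = len(node_embeddings[0])
    pooled = [[0.0] * dim for _ in range(num_graphs)]
    for emb, g in zip(node_embeddings, batch):
        for k, value in enumerate(emb):
            pooled[g][k] += value
    return pooled


# Two graphs in one batch: nodes 0-2 belong to graph 0, node 3 to graph 1.
node_embeddings = [[1.0, 0.0], [2.0, 1.0], [0.5, 0.5], [1.5, 2.5]]
batch = [0, 0, 0, 1]
print(add_pool(node_embeddings, batch))  # [[3.5, 1.5], [1.5, 2.5]]
```

Unlike mean pooling, sum pooling keeps information about molecule size, since larger molecules add more terms to the sum.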

In [None]:
# we first need to get the input dimensions of node and edges, i.e. the number of features
graph = dataset[0] # read the first graph from dataset 

node_in_dim = graph.x.shape[1]          # 7 in this case
edge_in_dim = graph.edge_attr.shape[1]  # 4 in this case

model = MPNNModel(node_in_dim, edge_in_dim, hidden_dim=64, num_layers=3, out_dim=1)

We do a test run of the initialised MPNN by passing one batch of graphs into the model, which generates one batch of permeability predictions.

In [None]:
loader = DataLoader(dataset, batch_size=4, shuffle=True)

model.eval() # run the model in evaluation mode, dropout disabled
for batch in loader:
    out = model(batch)
    print(out.shape)  # [batch_size, 1]
    print(out)
    break

3. Train the MPNN, interpret results

A trainer class MPNNTrainer contains all training steps. Here are some technical details: 

We split the dataset 70%/15%/15% into train, validation and test sets. We train on the train set and monitor the loss on the validation set to detect when overfitting begins, at which point training ends. The test set is held out and used only to examine the final model performance.
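
A minimal sketch of such a split on dataset indices is given below. The function name `split_indices`, the fixed seed, and the rounding rule (leftover samples go to the test set) are assumptions of this sketch; `MPNNTrainer` may split differently.

```python
import random


def split_indices(n, train_frac=0.70, val_frac=0.15, seed=0):
    """Shuffle indices 0..n-1 and cut them into train / validation / test."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # seeded for reproducibility
    n_train = int(train_frac * n)
    n_val = int(val_frac * n)
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]  # whatever remains
    return train, val, test


# 438 is the number of molecules quoted for the PAMPA dataset.
train_idx, val_idx, test_idx = split_indices(438)
print(len(train_idx), len(val_idx), len(test_idx))  # 306 65 67
```

Shuffling before cutting avoids any ordering in the source files leaking into the split, and fixing the seed keeps the held-out set the same across runs.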

We use the Huber loss (closely related to the Smooth L1 loss) as our loss function. It combines the strengths of Mean Squared Error (MSE) and Mean Absolute Error (MAE): for small errors it behaves like MSE, giving smooth gradients, and for large errors it behaves like MAE, making it robust to outliers. The Huber_beta parameter sets the threshold between small and large errors. Apart from the Huber loss, we also report a few other metrics on the validation dataset for reference: MSE, RMSE (Root Mean Squared Error) and R^2 (Coefficient of Determination, which measures how much of the variance in the true values is explained by the model's predictions).
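
The loss and metrics can be written out in a few lines of plain Python. This sketch uses the standard Huber definition with threshold `beta`; the Smooth L1 variant is the same quantity divided by `beta`. Which of the two the trainer uses is not shown here, so treat the exact scaling as an assumption.

```python
import math


def huber(error, beta):
    """Huber loss for a single error: quadratic below beta, linear above."""
    a = abs(error)
    if a <= beta:
        return 0.5 * error ** 2
    return beta * (a - 0.5 * beta)


def smooth_l1(error, beta):
    """Smooth L1 variant: the Huber loss scaled by 1 / beta."""
    return huber(error, beta) / beta


def regression_metrics(y_true, y_pred):
    """MSE, RMSE and R^2 for two equally long lists of numbers."""
    n = len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    mean_true = sum(y_true) / n
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    mse = ss_res / n
    return {"mse": mse, "rmse": math.sqrt(mse), "r2": 1.0 - ss_res / ss_tot}


print(huber(1.0, beta=2.0))  # 0.5, quadratic regime
print(huber(4.0, beta=2.0))  # 6.0, linear regime
print(regression_metrics([1.0, 2.0, 3.0, 4.0], [1.5, 2.0, 2.5, 4.0]))
```

With `beta=2`, an error of 4 costs 6.0 instead of the 8.0 that half the squared error would give, which is how the loss limits the influence of outliers.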

We use dropout and batch normalisation in the MLPs to improve the robustness and generalization of the model. We apply an exponentially decaying learning rate (lr parameter) as the model approaches the optimal solution. We set early stopping with a patience of 30 epochs to prevent the model from overfitting to the train dataset.
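
The learning-rate schedule and the early-stopping rule can be sketched as follows. The decay factor `gamma=0.98`, the class name `EarlyStopping`, and the toy validation losses are hypothetical; a patience of 3 is used instead of 30 to keep the example short.

```python
def exponential_lr(initial_lr, gamma, epoch):
    """Learning rate after `epoch` epochs of exponential decay."""
    return initial_lr * gamma ** epoch


class EarlyStopping:
    """Stop when the validation loss has not improved for `patience` epochs."""

    def __init__(self, patience):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one validation loss; return True when training should stop."""
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=3)
val_losses = [1.0, 0.8, 0.9, 0.85, 0.81, 0.7]  # toy values
stopped_at = None
for epoch, loss in enumerate(val_losses):
    lr = exponential_lr(5e-4, gamma=0.98, epoch=epoch)
    if stopper.step(loss):
        stopped_at = epoch
        break

print(stopped_at)  # 4
```

In this toy run the best loss (0.8) occurs at epoch 1, and three epochs without improvement trigger the stop at epoch 4, before the final value is ever seen. A larger patience trades longer training for less risk of stopping on a temporary plateau.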

In the following, we perform training.

In [None]:
trainer = MPNNTrainer(model, dataset, epochs=150, batch_size=32, lr=5e-4, Huber_beta=2, toy_dataset=False)
trainer.train()

We plot the training histories of the train and validation losses, as well as the validation metrics.

In [None]:
trainer.plot_training_history()

Finally, after adjusting the hyperparameters and being satisfied with the model, we test the model with held-out test data.

In [None]:
model.eval()
trainer.test()

4. Cluster analysis of the molecules based on training output

The trained model can return a graph embedding of size [1, hidden_dim], which is a descriptor of the molecular graph. We apply two dimensionality reduction techniques, t-SNE and UMAP, to reduce the graph embedding from hidden_dim dimensions to 2, and plot the result as the x and y coordinates of a point. We color each point by its PAMPA permeability measurement to check whether the embedding is correlated with permeability.

We first perform clustering analysis on the whole dataset, and plot the t-SNE and UMAP embeddings.

In [None]:
# `model` is the trained MPNN and `dataset` is the full MPNNDataset
cluster_analysis = MoleculeClusterAnalysis(model, dataset)

# Extract graph embeddings
embeddings = cluster_analysis.extract_embeddings()

# Compute 2D projection
cluster_analysis.compute_tsne()
cluster_analysis.compute_umap()

# Plot t-SNE colored by Permeability
cluster_analysis.plot_embeddings(color_by='Permeability')

From the cluster scatter plots, we can see a reasonably clear permeability gradient across the space, indicating that the graph embedding is correlated with the permeability measure.

However, since the model has seen the permeability of 70% of the whole dataset during training, the cluster analysis above is biased towards the training molecules. To assess how well the embedding from the trained model generalizes, we perform clustering analysis on the test dataset, which the model never saw.

Below is an unbiased assessment of the graph embedding, in which the high-permeability molecules are less well separated from the low-permeability ones. In particular, many high-permeability molecules are located near low-permeability ones, indicating that more training is needed to learn an embedding that separates them further.

In [None]:
cluster_analysis = MoleculeClusterAnalysis(model, dataset, test_data=True)
embeddings = cluster_analysis.extract_embeddings()
cluster_analysis.compute_tsne()
cluster_analysis.compute_umap()
cluster_analysis.plot_embeddings(color_by='Permeability')

5. Discussion and Next Steps

The current model improves in permeability prediction accuracy over the course of training, as evidenced by the reduction in Huber loss and MSE, but it is not yet good enough to serve as a functional BBB permeability predictor.

To improve the MPNN model, we could try incorporating molecule-level features into the computation of the graph embedding. Molecular features such as TPSA (Topological Polar Surface Area) and molecular weight could be very relevant to BBB permeability. Additionally, we could incorporate molecular fingerprints from RDKit, which provide substructure information and could be useful if the presence of certain substructures is relevant to BBB permeability.