# Serotonin 3D GNN Project


This project builds upon research done by Łapińska et al. (2024): https://doi.org/10.3390/pharmaceutics16030349

Data used: https://ftp.ebi.ac.uk/pub/databases/chembl/ChEMBLdb/releases/chembl_35/

Move the unpacked chembl_35_sqlite.tar.gz file into the data/ dir.

The research linked above presents two Quantitative Structure-Activity Relationship (QSAR) models to predict serotonergic binding affinity and selectivity, respectively, using Mordred molecular 2D descriptors. Specifically, one model classifies compounds binarily as "active" or "inactive", with a cutoff of pKi = 7. Another model does multiclass classification to predict the serotonergic selectivity of compounds previously classified as "active".

I am following a similar approach, but using 3D molecular graph representations instead of 2D molecular descriptors as input modality and using only the ChEMBL database, not ZINC.


## Google Colab Setup


### Configuration


In [1]:
from pathlib import Path

IN_COLAB = False

PATH_NOTEBOOK = (
    Path("/content/drive/MyDrive/Colab Notebooks/serotonin-3d-gnn.ipynb")
    if IN_COLAB
    else Path(
        "/Users/paul/Library/CloudStorage/GoogleDrive-unoutsch@gmail.com/My Drive/Colab Notebooks/serotonin-3d-gnn.ipynb"
    )
)
PATH_REPO = (
    Path("/content/drive/MyDrive/Repositories/serotonin-3d-gnn")
    if IN_COLAB
    else Path.cwd()
)
PATH_DATA = PATH_REPO / "data"

### Syncing Google Drive with Google Colab Content


In [2]:
if IN_COLAB:
    from google.colab import drive

    drive.mount("/content/drive")

### Installing Requirements


In [3]:
%pip install -r "$PATH_REPO/requirements.txt"

Note: you may need to restart the kernel to use updated packages.


## Imports


In [4]:
import os
import pandas as pd
import pickle
from rdkit import Chem
from rdkit.Chem import AllChem
import shutil
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

## Utils


### Syncing this file between Colab and local Git repo


Make sure the paths exist.


In [5]:
if IN_COLAB:
    shutil.copyfile(PATH_NOTEBOOK, PATH_REPO / "serotonin-3d-gnn.ipynb")
else:
    shutil.copyfile(PATH_REPO / "serotonin-3d-gnn.ipynb", PATH_NOTEBOOK)

## Data


### Note on Data Aquisition from chembl_35.db

In order to collect the desired data from the ChEMBL SQL database and transform it into a .csv file, I undertook the steps detailed in `data/README.md`.


### Loading the Data


If the pickled torch_data_list already exists, load it.


In [6]:
pickle_file_path = PATH_DATA / "torch_data_list.pkl"

if os.path.exists(pickle_file_path):
    torch_data_list = pickle.load(open(pickle_file_path, "rb"))
    print("Loaded torch_data_list from pickle file")
else:
    print("Creating torch_data_list from scratch")

Loaded torch_data_list from pickle file


#### Load dataframe from .csv file


In [7]:
# if df is None:
df = pd.read_csv(PATH_DATA / "serotonin_binding_summary.csv")

df.describe()

Unnamed: 0,molecule_id,Serotonin (5-HT) receptor,Serotonin 1 (5-HT1) receptor,Serotonin 1 receptors; 5-HT1B & 5-HT1D,Serotonin 1a (5-HT1a) receptor,Serotonin 1b (5-HT1b) receptor,Serotonin 1d (5-HT1d) receptor,Serotonin 1e (5-HT1e) receptor,Serotonin 1f (5-HT1f) receptor,Serotonin 2 (5-HT2) receptor,...,Serotonin 2b (5-HT2b) receptor,Serotonin 2c (5-HT2c) receptor,Serotonin 3 (5-HT3) receptor,Serotonin 3a (5-HT3a) receptor,Serotonin 3b (5-HT3b) receptor,Serotonin 4 (5-HT4) receptor,Serotonin 5a (5-HT5a) receptor,Serotonin 5b (5-HT5b) receptor,Serotonin 6 (5-HT6) receptor,Serotonin 7 (5-HT7) receptor
count,23456.0,90.0,252.0,1.0,9462.0,1492.0,1472.0,91.0,127.0,1469.0,...,2337.0,4343.0,939.0,1040.0,8.0,1009.0,422.0,1.0,4221.0,3100.0
mean,1003325.0,6.081759,6.683902,6.2,7.258523,6.952528,7.554968,5.791172,7.458423,7.053201,...,6.603829,6.81021,7.625768,7.04752,7.2035,7.645809,6.573801,7.17,7.311171,6.977487
std,898658.3,0.926906,1.104283,,1.152004,1.226482,1.36588,0.652239,0.85924,1.159567,...,0.981462,1.032874,1.225942,1.535413,1.735342,1.179482,1.089819,,1.143388,1.016128
min,97.0,4.39,4.1,6.2,4.0,4.0,4.0,4.8,5.14,4.03,...,4.19,4.0,4.01,4.0,5.46,5.0,4.07,7.17,4.12,4.0
25%,229157.0,5.385,5.8,6.2,6.48,6.05,6.47,5.36,7.0,6.24,...,5.9,6.05,6.8,5.7,5.49,6.81,5.8,7.17,6.47,6.285
50%,575761.5,5.925,6.6,6.2,7.28,6.85,7.64,5.73,7.72,6.92,...,6.523333,6.74,7.7,7.185,7.015,7.64,6.39,7.17,7.36,6.99
75%,1965967.0,6.6575,7.585,6.2,8.06,7.85,8.7,6.175,8.02,8.0,...,7.21,7.5125,8.58,8.41,8.3935,8.4,7.06875,7.17,8.11,7.7
max,2881244.0,9.8,9.3,6.2,11.0,10.0,10.7,8.2,8.8,10.3,...,10.1,10.7,10.42,10.4,9.604,10.8,9.17,7.17,10.4,10.0


In [8]:
df_targets = df.drop(columns=["molecule_id", "canonical_smiles"])
print(f"Number of targets: {len(df_targets.columns)}")

# compute non-nan counts per column and create a boolean mask to filter targets to contain at least n_threshold non-nan values
n_threshold = 9000
non_nan_counts = torch.tensor(df_targets.notna().sum().values, dtype=torch.long)
mask = non_nan_counts >= n_threshold
valid_column_indices = torch.nonzero(mask, as_tuple=True)[0]

print("Boolean mask:", mask)
print(f"Included targets ({len(valid_column_indices)}):", valid_column_indices)

Number of targets: 21
Boolean mask: tensor([False, False, False,  True, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False])
Included targets (1): tensor([3])


In [None]:
def create_torch_data(smiles: str, targets: torch.Tensor) -> Data:
    # getting RDKit molecule object
    mol = Chem.MolFromSmiles(smiles)

    if mol is None:
        return None

    # add explicit hydrogen atoms to the molecule (are not included in the SMILES string) so that its 3D structure is complete
    mol = Chem.AddHs(mol)

    # EmbedMolecule positions atoms of mol in 3D space stochastically; if it fails (returning -1) return None
    if AllChem.EmbedMolecule(mol, randomSeed=42) == -1:
        return None

    # optimize the 3D structure using Universal Force Field (UFF) to lower mol's energy
    AllChem.UFFOptimizeMolecule(mol)

    # conformer contains 3D coordinates for mol's atoms
    conformer = mol.GetConformer()

    # atom-level features and 3D positions
    atom_features, positions = [], []
    for atom in mol.GetAtoms():
        features = [
            atom.GetAtomicNum(),  # atomic number: uid of element (e.g., 6 for carbon, 8 for oxygen)
            atom.GetDegree(),  # degree: number of bonds connecting the atom
            atom.GetFormalCharge(),  # formal charge: atom's electrical charge
            int(
                atom.GetHybridization()
            ),  # hybridization: type of atom's orbital hybridization (e.g., sp, sp2) as int
        ]
        atom_features.append(features)

        # 3D coordinates of atom from conformer
        pos = conformer.GetAtomPosition(atom.GetIdx())
        positions.append([pos.x, pos.y, pos.z])

    # transform to PyTorch tensors
    x = torch.tensor(atom_features, dtype=torch.float)
    pos = torch.tensor(positions, dtype=torch.float)

    # bonds between atoms – indices of connected atoms as well as types and conjugation
    edge_index, edge_attr = [], []
    for bond in mol.GetBonds():
        # indices of bonded atoms
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()

        bond_feat = [
            bond.GetBondTypeAsDouble(),  # bond type as float (e.g., 1.0 for single, 2.0 for double bonds)
            (
                1.0 if bond.GetIsConjugated() else 0.0
            ),  # conjugation flag indicating whether the bond is conjugated (1.0 if true, else 0.0) (conjugated means that electrons are delocalized, moving freely between multiple atoms)
        ]
        # for undirected graph, add bond in both directions
        edge_index += [[i, j], [j, i]]
        edge_attr += [bond_feat, bond_feat]

    # transform to PyTorch tensors
    # edge_index tensor is transposed to fit torch_geometric's expected shape (2, number_of_edges).
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(edge_attr, dtype=torch.float)

    # graph as PyTorch Geometric Data object
    # x: atom features, [atomic number, degree, formal charge, hybridization]
    # pos: 3D positions of atoms, [x, y, z]
    # edge_index: connectivity indices between atoms, [[i, j], [j, i]]
    # edge_attr: features per bond, [[bond type, conjugation], [bond type, conjugation]]
    return Data(x=x, pos=pos, edge_index=edge_index, edge_attr=edge_attr, y=targets)

In [None]:
torch_data_list = [
    create_torch_data(
        row["canonical_smiles"],
        torch.tensor(df_targets.loc[row.name].values, dtype=torch.float),
    )
    for _, row in df.iterrows()
]

In [None]:
pickle_file_path = PATH_DATA / "torch_data_list.pkl"

with open(pickle_file_path, "wb") as f:
    pickle.dump(torch_data_list, f)

print(f"Saved torch_data_list to {pickle_file_path}")

Saved torch_data_list to /Users/paul/My Drive/Repositories/serotonin-3d-gnn/data/torch_data_list.pkl


Create training and test sets.


In [27]:
filtered_torch_data_list = [d.clone() for d in torch_data_list if d is not None]
print(
    f"Number of items in filtered_torch_data_list: {len(filtered_torch_data_list)} / {len(torch_data_list)}"
)  # still retaining original torch_data_list for reference to df later
print(f"filtered_torch_data_list[0]: {filtered_torch_data_list[0]}")

new_filtered_torch_data_list = []

# only include targets with at least n_threshold non-nan values
for d in filtered_torch_data_list:
    d.y = d.y[valid_column_indices]
    if not torch.isnan(d.y).all():
        new_filtered_torch_data_list.append(d)

print(
    f"Number of items in new_filtered_torch_data_list: {len(new_filtered_torch_data_list)} / {len(filtered_torch_data_list)}"
)
filtered_torch_data_list = new_filtered_torch_data_list

split_idx = int(0.8 * len(filtered_torch_data_list))

data_graph_train = DataLoader(
    filtered_torch_data_list[:split_idx], batch_size=32, shuffle=True
)
data_graph_test = DataLoader(
    filtered_torch_data_list[split_idx:], batch_size=32, shuffle=False
)

print(
    f"# training batches: {len(data_graph_train)}\n# test batches: {len(data_graph_test)}"
)
print(f"Example data point: {filtered_torch_data_list[190]}")

Number of items in filtered_torch_data_list: 23439 / 23456
filtered_torch_data_list[0]: Data(x=[49, 4], edge_index=[2, 104], edge_attr=[104, 2], y=[21], pos=[49, 3])
Number of items in new_filtered_torch_data_list: 9460 / 23439
# training batches: 237
# test batches: 60
Example data point: Data(x=[54, 4], edge_index=[2, 114], edge_attr=[114, 2], y=[1], pos=[54, 3])


## Models


### Model 1: PyTorch Implementation of a 3D GCN

In this section, a 3D graph convolutional network is created using PyTorch. The model takes as input a 3D molecular graph and outputs predictions of the serotonergic binding affinity of the molecule.

Information about the graph input the model will receive and process:

-   The feature matrix H contains the node (atom) features. Each row corresponds to a node, and each column corresponds to a feature.
-   The adjacency matrix A is built from the edge_index tensor, which contains the indices of the edges in the graph. The matrix A is built under the hood of the GCNConv class.


#### Model Architecture


In [10]:
from torch_geometric.nn import NNConv, GCNConv, global_mean_pool
from torch.nn import Linear
import torch.nn.functional as F


class SeroGCN(torch.nn.Module):
    def __init__(self, n_in, n_hidden, n_out, n_edge_attr):
        super(SeroGCN, self).__init__()

        edge_network = torch.nn.Sequential(
            torch.nn.Linear(n_edge_attr, n_hidden * n_in),
            torch.nn.ReLU(),
            torch.nn.Linear(n_hidden * n_in, n_hidden * n_in),
        )

        self.pos_lin = Linear(3, n_in)  # simple positional encoding layer

        self.conv1 = NNConv(n_in, n_hidden, edge_network, aggr="mean")
        self.conv2 = GCNConv(n_hidden, n_hidden)

        self.fc = Linear(n_hidden, n_out)

    def forward(self, mol_batch) -> torch.Tensor:
        x, pos, edge_index, edge_attr = (
            mol_batch.x,
            mol_batch.pos,
            mol_batch.edge_index,
            mol_batch.edge_attr,
        )

        pos_feat = self.pos_lin(pos)
        x = x + pos_feat  # simple positional encoding

        x = self.conv1(x, edge_index, edge_attr)
        x = F.relu(x)

        x = self.conv2(x, edge_index)
        x = F.relu(x)

        x = global_mean_pool(
            x, mol_batch.batch
        )  # global mean pooling aggregates node features, returning a single graph-level vectorial representation

        x = self.fc(x)

        return x

### Model 2: Pretrained 3D GNN (...)


## Training


In [11]:
if torch.cuda.is_available():
    print("Using CUDA")
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    print("Using MPS")
    device = torch.device("mps")
else:
    print("Using CPU")
    device = torch.device("cpu")

Using MPS


In [12]:
epochs = 10
n_hidden = 32

n_in = data_graph_train.dataset[0].num_features
n_out = len(valid_column_indices)
n_edge_attr = data_graph_train.dataset[0].edge_attr.shape[1]

print(f"Node features: {n_in}, targets: {n_out}, edge attributes: {n_edge_attr}")

Node features: 4, targets: 1, edge attributes: 2


In [13]:
def masked_mse_loss(pred, target):
    # mask of non-nan targets
    mask = ~torch.isnan(target)
    if mask.sum() == 0:
        # return 0 loss, so that it doesn't affect the gradient
        return torch.tensor(0.0, requires_grad=True, device=target.device)
    # squared error for entries that are valid
    loss = (pred[mask] - target[mask]) ** 2
    return loss.mean()

In [21]:
import time


def train(
    model: torch.nn.Module, data_loader: DataLoader, optimizer: torch.optim.Optimizer
):
    model.train()
    for epoch in range(epochs):
        for data in data_loader:
            print(data.y)
            data = data.to(device)
            optimizer.zero_grad()
            out = model(data)
            # masked loss to ignore nans
            start_time = time.now()
            loss = masked_mse_loss(out, data.y.view(-1, n_out))
            total_time = time.now() - start_time
            print(f"Time taken for loss calculation: {total_time}")
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch+1}: Loss = {loss.item()}")

In [22]:
sero_gcn = SeroGCN(
    n_in=n_in, n_hidden=n_hidden, n_out=n_out, n_edge_attr=n_edge_attr
).to(device)
sero_gcn_optimizer = torch.optim.Adam(sero_gcn.parameters(), lr=0.01)
# sero_gcn_criterion = torch.nn.MSELoss()

In [23]:
train(sero_gcn, data_graph_train, sero_gcn_optimizer)

tensor([   nan,    nan,    nan, 7.4000,    nan,    nan,    nan,    nan,    nan,
        6.3500,    nan, 8.7600,    nan, 5.0800, 8.4600,    nan,    nan,    nan,
           nan, 8.7700,    nan, 5.1800, 6.5100,    nan, 5.6600,    nan,    nan,
        5.7173,    nan, 8.7000,    nan,    nan])


AttributeError: module 'time' has no attribute 'now'