In [1]:
from rdkit import Chem
import pandas as pd
import numpy as np
import math
from copy import deepcopy
from concurrent.futures import ThreadPoolExecutor, as_completed

import torch
from torch import nn
import torch.nn.functional as F

import torch_geometric as tg
from torch_geometric.data import Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import global_add_pool

from sklearn.metrics import mean_absolute_error, root_mean_squared_error
from sklearn.model_selection import train_test_split

In [None]:
def one_hot_encoding_unk(value, choices):
    # One hot encoding with unknown value handling
    # If the value is in choices, it puts a 1 at the corresponding index
    # Otherwise, it puts a 1 at the last index (unknown)
    encoding = [0] * (len(choices) + 1)
    index = choices.index(value) if value in choices else -1
    encoding[index] = 1
    return encoding


def get_atom_features(atom):
    # Returns a feature list for the atom
    # Concatenates the one-hot encodings into a single list
    features = [
        one_hot_encoding_unk(atom.GetSymbol(), ['B','Be','Br','C','Cl','F','I','N','Nb','O','P','S','Se','Si','V','W']),
        one_hot_encoding_unk(atom.GetTotalDegree(), [0, 1, 2, 3, 4, 5]),
        one_hot_encoding_unk(atom.GetFormalCharge(), [-1, -2, 1, 2, 0]),
        one_hot_encoding_unk(int(atom.GetTotalNumHs()), [0, 1, 2, 3, 4]),
        one_hot_encoding_unk(int(atom.GetHybridization()),[
                                                        Chem.rdchem.HybridizationType.SP,
                                                        Chem.rdchem.HybridizationType.SP2,
                                                        Chem.rdchem.HybridizationType.SP3,
                                                        Chem.rdchem.HybridizationType.SP3D,
                                                        Chem.rdchem.HybridizationType.SP3D2
                                                        ]),
        [1 if atom.GetIsAromatic() else 0],
        [atom.GetMass() * 0.01]
    ]
    return sum(features, []) # Flatten the list into a single list


def get_bond_features(bond):
    # Returns a one-hot encoded feature list for the bond
    bond_fdim = 7

    if bond is None:
        fbond = [1] + [0] * (bond_fdim - 1)
    else:
        bt = bond.GetBondType()
        fbond = [
            0,  # Zeroth index indicates if bond is None
            bt == Chem.rdchem.BondType.SINGLE,
            bt == Chem.rdchem.BondType.DOUBLE,
            bt == Chem.rdchem.BondType.TRIPLE,
            bt == Chem.rdchem.BondType.AROMATIC,
            (bond.GetIsConjugated() if bt is not None else 0),
            (bond.IsInRing() if bt is not None else 0)
        ]
    return fbond

In [None]:
class MolGraph:
    # Returns a custom molecular graph for a given SMILES string
    # Contains atom, bond features and node connectivity
    def __init__(self, smiles):
        self.smiles = smiles
        self.f_atoms = []
        self.f_bonds = []
        self.edge_index = []

        mol = Chem.MolFromSmiles(self.smiles)
        n_atoms=mol.GetNumAtoms()

        for atom_1 in range(n_atoms):
            self.f_atoms.append(get_atom_features(mol.GetAtomWithIdx(atom_1)))

            for atom_2 in range(atom_1 + 1, n_atoms):
                bond = mol.GetBondBetweenAtoms(atom_1, atom_2)
                if bond is None:
                    continue
                f_bond = get_bond_features(bond)
                self.f_bonds.append(f_bond)
                self.f_bonds.append(f_bond) # Bond features are added twice for both directions
                self.edge_index.extend([(atom_1, atom_2), (atom_2, atom_1)]) # Edge index list with tuples of connected nodes instead of adjacency matrix

In [None]:
class ChemDataset(Dataset):
    def __init__(self, smiles, labels, precompute=False):
        super(ChemDataset, self).__init__()
        self.smiles = smiles
        self.labels = labels
        self.cache = {}

        # Precomputing the dataset so the get method is faster, and the GPU doesn't have to wait for the CPU
        if precompute:
            print(f"Precomputing the dataset in parellel...")
            with ThreadPoolExecutor(max_workers=8) as executor:
                futures = [
                    executor.submit(self.process_key , idx)
                    for idx in range(len(self.smiles))
                ]

                for future in as_completed(futures):
                    future.result()

            print(f"Precomputation finished. {len(self.cache)} molecules cached.")

    def process_key(self, key):
        smiles = self.smiles[key]
        if smiles in self.cache.keys():
            mol = self.cache[smiles]
        else:
            molgraph = MolGraph(smiles)
            mol = self.molgraph2data(molgraph, key)
            self.cache[smiles] = mol
        return mol

    def molgraph2data(self, molgraph, key):
        data = tg.data.Data()

        # Coverting all features and labels to tensors
        # And adding it to the data object
        data.x = torch.tensor(molgraph.f_atoms, dtype=torch.float)
        data.edge_index = torch.tensor(molgraph.edge_index, dtype=torch.long).t().contiguous()
        data.edge_attr = torch.tensor(molgraph.f_bonds, dtype=torch.float)
        data.y = torch.tensor([self.labels[key]], dtype=torch.float)

        data.smiles = self.smiles[key]

        return data

    def get(self,key):
        return self.process_key(key)
    
    def __getitem__(self, key):
        return self.process_key(key)

    def len(self):
        return len(self.smiles)

    def __len__(self):
        return len(self.smiles)

In [None]:
def construct_loader(data_df, smi_column, target_column, shuffle=True, batch_size=64):    
    smiles = data_df[smi_column].values
    labels = data_df[target_column].values.astype(np.float32)  
    
    dataset = ChemDataset(smiles, labels, precompute=True)
    loader = DataLoader(dataset=dataset,
                            batch_size=batch_size,
                            shuffle=shuffle,
                            pin_memory=True
                       )
    return loader

In [None]:
class DMPNNConv(MessagePassing): # Extending the MessagePassing class from PyG
    def __init__(self, hidden_size):
        super(DMPNNConv, self).__init__(aggr='add') # Sum aggregation function
        self.lin = nn.Linear(hidden_size, hidden_size)

    def forward(self, edge_index, edge_attr):
        row, _ = edge_index
        # Since each edge is bidirectional, we do two message passings, one for each direction
        aggregated_message = self.propagate(edge_index, x=None, edge_attr=edge_attr)
        reversed_message = torch.flip(edge_attr.view(edge_attr.size(0) // 2, 2, -1), dims=[1]).view(edge_attr.size(0), -1)

        return aggregated_message, self.lin(aggregated_message[row] - reversed_message)

    def message(self, edge_attr):
        return edge_attr

In [None]:
class GNN(nn.Module):
    def __init__(self, num_node_features, num_edge_features):
        super(GNN, self).__init__()

        self.depth = 3
        self.hidden_size = 300
        self.dropout = 0.02

        # Encoder
        self.edge_init = nn.Linear(num_node_features + num_edge_features, self.hidden_size)
        self.convs = torch.nn.ModuleList()

        for _ in range(self.depth):
            self.convs.append(DMPNNConv(self.hidden_size))
            
        self.edge_to_node = nn.Linear(num_node_features + self.hidden_size, self.hidden_size)
        self.pool = global_add_pool

        # Prediction head
        self.ffn1 = nn.Linear(self.hidden_size, self.hidden_size)
        self.ffn2 = nn.Linear(self.hidden_size, 1)

    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch

        # initial edge features
        row, _ = edge_index
        h_0 = F.relu(self.edge_init(torch.cat([x[row], edge_attr], dim=1)))
        h = h_0

        # convolutions
        for l in range(self.depth):
            _, h = self.convs[l](edge_index, h)
            h += h_0
            h = F.dropout(F.relu(h), self.dropout, training=self.training)

        # dmpnn edge -> node aggregation
        s, _ = self.convs[l](edge_index, h) #only use for summing
        try:
            q  = torch.cat([x,s], dim=1)
        except:
            print(data)
            print(data.smiles)
            print(data.x)
            q  = torch.cat([x,s], dim=1)
        h = F.relu(self.edge_to_node(q))

        return self.ffn2(F.dropout(F.relu(self.ffn1(self.pool(h, batch))))).squeeze(-1)

In [8]:
class Standardizer:
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, x, rev=False):
        if rev:
            return (x * self.std) + self.mean
        return (x - self.mean) / self.std

In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [10]:
def train_epoch(model, loader, optimizer, loss, stdzer):
    model.train()
    loss_all = 0

    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()

        out = model(data)
        result = loss(out, stdzer(data.y))
        result.backward()

        optimizer.step()
        loss_all += loss(stdzer(out, rev=True), data.y)

    return math.sqrt(loss_all / len(loader.dataset))


def pred(model, loader, loss, stdzer):
    model.eval()

    preds, ys = [], []
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data)
            pred = stdzer(out, rev=True)
            preds.extend(pred.cpu().detach().tolist())

    return preds


def train():
    torch.manual_seed(0)
    data_df = pd.read_csv("AqSolDBc.csv")
    #Drop single atoms    
    idx_single = [i for i,s in enumerate(data_df['SmilesCurated']) if Chem.MolFromSmiles(s).GetNumAtoms()==1 or '.' in s]
    data_df = data_df.drop(idx_single)
    if len(idx_single)>0:
        print(f"Removing {idx_single} due to single atoms")

        
    test_df = pd.read_csv("OChemUnseen.csv")
    #Drop single atoms
    idx_nonetype = [i for i,s in enumerate(test_df['SMILES']) if Chem.MolFromSmiles(s) is None] # Got an error for a SMILES which was None
    test_df = test_df.drop(idx_nonetype)
    if len(idx_nonetype)>0:
        print(f"Removing {idx_nonetype} due to Nonetypes")

    idx_single = [i for i,s in enumerate(test_df['SMILES']) if Chem.MolFromSmiles(s).GetNumAtoms()==1 or '.' in s]
    test_df = test_df.drop(idx_single)
    if len(idx_single)>0:
        print(f"Removing {idx_single} due to single atoms")

        
    train_df, val_df = train_test_split(data_df, test_size=0.2, random_state=0)
    train_loader = construct_loader(train_df, 'SmilesCurated', 'ExperimentalLogS', shuffle=True)
    val_loader = construct_loader(val_df, 'SmilesCurated', 'ExperimentalLogS', shuffle=False)
    test_loader = construct_loader(test_df, 'SMILES', 'LogS', shuffle=False)
    mean = np.mean(train_loader.dataset.labels)
    std = np.std(train_loader.dataset.labels)
    stdzer = Standardizer(mean, std)

    model = GNN(train_loader.dataset.num_node_features, train_loader.dataset.num_edge_features).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss = nn.MSELoss(reduction='sum').to(device)
    print(model)
    best_model = model
    best_val_loss = 1000000
    for epoch in range(0, 30):
        train_loss = train_epoch(model, train_loader, optimizer, loss, stdzer)
        preds = pred(model, val_loader, loss, stdzer)
        val_loss = root_mean_squared_error(preds,val_loader.dataset.labels)
        print("Epoch",epoch,"  Train RMSE", train_loss,"   Val RMSE", val_loss)
        if val_loss < best_val_loss:
            best_model = deepcopy(model)
            best_val_loss = val_loss

    preds = pred(best_model, test_loader, loss, stdzer)
    print("Test RMSE", root_mean_squared_error(preds,test_loader.dataset.labels))
    print("Test MAE", mean_absolute_error(preds,test_loader.dataset.labels))

In [None]:
train()

Removing [1263, 1444, 3605, 3702] due to single atoms


[14:24:10] Explicit valence for atom # 1 P, 6, is greater than permitted


Removing [667] due to Nonetypes
Removing [471, 503, 589, 591, 592, 593, 594, 610, 613, 641, 643, 647, 649, 652, 653, 654, 656, 658, 676, 681, 693, 744, 759, 763, 769, 773, 777, 807, 809, 811, 813, 869, 902, 969, 998] due to single atoms
Precomputing the dataset in parellel...
Precomputation finished. 6434 molecules cached.
Precomputing the dataset in parellel...
Precomputation finished. 1609 molecules cached.
Precomputing the dataset in parellel...
Precomputation finished. 2215 molecules cached.
GNN(
  (edge_init): Linear(in_features=51, out_features=300, bias=True)
  (convs): ModuleList(
    (0-2): 3 x DMPNNConv()
  )
  (edge_to_node): Linear(in_features=344, out_features=300, bias=True)
  (ffn1): Linear(in_features=300, out_features=300, bias=True)
  (ffn2): Linear(in_features=300, out_features=1, bias=True)
)
