In [105]:
from rdkit import Chem
import pandas as pd
import numpy as np
import math
from copy import deepcopy
from concurrent.futures import ThreadPoolExecutor, as_completed

import torch
from torch import nn
import torch.nn.functional as F

import torch_geometric as tg
from torch_geometric.data import Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import global_add_pool

from sklearn.metrics import mean_absolute_error, root_mean_squared_error
from sklearn.model_selection import train_test_split

In [106]:
def one_hot_encoding_unk(value, choices):
    # One hot encoding with unknown value handling
    # If the value is in choices, it puts a 1 at the corresponding index
    # Otherwise, it puts a 1 at the last index (unknown)
    encoding = [0] * (len(choices) + 1)
    index = choices.index(value) if value in choices else -1
    encoding[index] = 1
    return encoding


def get_atom_features(atom):
    # Returns a feature list for the atom
    # Concatenates the one-hot encodings into a single list
    features = [
        one_hot_encoding_unk(atom.GetSymbol(), ['B','Be','Br','C','Cl','F','I','N','Nb','O','P','S','Se','Si','V','W']),
        one_hot_encoding_unk(atom.GetTotalDegree(), [0, 1, 2, 3, 4, 5]),
        one_hot_encoding_unk(atom.GetFormalCharge(), [-1, -2, 1, 2, 0]),
        one_hot_encoding_unk(int(atom.GetTotalNumHs()), [0, 1, 2, 3, 4]),
        one_hot_encoding_unk(int(atom.GetHybridization()),[
                                                        Chem.rdchem.HybridizationType.SP,
                                                        Chem.rdchem.HybridizationType.SP2,
                                                        Chem.rdchem.HybridizationType.SP3,
                                                        Chem.rdchem.HybridizationType.SP3D,
                                                        Chem.rdchem.HybridizationType.SP3D2
                                                        ]),
        [1 if atom.GetIsAromatic() else 0],
        [atom.GetMass() * 0.01]
    ]
    return sum(features, []) # Flatten the list into a single list


def get_bond_features(bond):
    # Returns a one-hot encoded feature list for the bond
    bond_fdim = 7

    if bond is None:
        fbond = [1] + [0] * (bond_fdim - 1)
    else:
        bt = bond.GetBondType()
        fbond = [
            0,  # Zeroth index indicates if bond is None
            bt == Chem.rdchem.BondType.SINGLE,
            bt == Chem.rdchem.BondType.DOUBLE,
            bt == Chem.rdchem.BondType.TRIPLE,
            bt == Chem.rdchem.BondType.AROMATIC,
            (bond.GetIsConjugated() if bt is not None else 0),
            (bond.IsInRing() if bt is not None else 0)
        ]
    return fbond

In [107]:
class MolGraph:
    # Returns a custom molecular graph for a given SMILES string
    # Contains atom, bond features and node connectivity
    def __init__(self, smiles):
        self.smiles = smiles
        self.f_atoms = []
        self.f_bonds = []
        self.edge_index = []

        mol = Chem.MolFromSmiles(self.smiles)
        n_atoms=mol.GetNumAtoms()

        for atom_1 in range(n_atoms):
            self.f_atoms.append(get_atom_features(mol.GetAtomWithIdx(atom_1)))

            for atom_2 in range(atom_1 + 1, n_atoms):
                bond = mol.GetBondBetweenAtoms(atom_1, atom_2)
                if bond is None:
                    continue
                f_bond = get_bond_features(bond)
                self.f_bonds.append(f_bond)
                self.f_bonds.append(f_bond) # Bond features are added twice for both directions
                self.edge_index.extend([(atom_1, atom_2), (atom_2, atom_1)]) # Edge index list with tuples of connected nodes instead of adjacency matrix

In [108]:
class ChemDataset(Dataset):
    def __init__(self, smiles, labels, noise_std=0.1, precompute=True):
        super(ChemDataset, self).__init__()
        self.smiles = smiles
        self.labels = labels
        self.cache = {}
        self.noise_std = noise_std
        self.precompute = precompute

        # Precomputing the dataset so the get method is faster, and the GPU doesn't have to wait for the CPU
        if precompute:
            print(f"Precomputing data...")
            with ThreadPoolExecutor(max_workers=8) as executor:
                futures = [
                    executor.submit(self.process_key , idx)
                    for idx in range(len(self.smiles))
                ]

                for future in as_completed(futures):
                    future.result()

            print(f"Precomputation finished. {len(self.cache)} molecules cached.")

    def process_key(self, key):
        smiles = self.smiles[key]
        if smiles in self.cache.keys():
            mol = self.cache[smiles]
        else:
            molgraph = MolGraph(smiles)
            mol = self.molgraph2data(molgraph, key)
            self.cache[smiles] = mol
        return mol

    def molgraph2data(self, molgraph, key):
        data = tg.data.Data()

        # Coverting all features and labels to tensors
        # And adding it to the data object
        data.x = torch.tensor(molgraph.f_atoms, dtype=torch.float)
        data.edge_index = torch.tensor(molgraph.edge_index, dtype=torch.long).t().contiguous()
        data.edge_attr = torch.tensor(molgraph.f_bonds, dtype=torch.float)
        data.y = torch.tensor([self.labels[key]], dtype=torch.float)
        data.smiles = self.smiles[key]

        if self.noise_std > 0:
            x_noisy = data.x + torch.randn_like(data.x) * self.noise_std
            data.x_noisy = x_noisy

        return data

    def get(self,key):
        return self.process_key(key)
    
    def __getitem__(self, key):
        return self.process_key(key)

    def len(self):
        return len(self.smiles)

    def __len__(self):
        return len(self.smiles)

In [109]:
def construct_loader(data_df, smi_column, target_column, shuffle=True, batch_size=64):    
    smiles = data_df[smi_column].values
    labels = data_df[target_column].values.astype(np.float32)  
    
    dataset = ChemDataset(smiles, labels, precompute=True)
    loader = DataLoader(dataset=dataset,
                            batch_size=batch_size,
                            shuffle=shuffle,
                            pin_memory=True
                       )
    return loader

In [110]:
class DMPNNConv(MessagePassing): # Extending the MessagePassing class from PyG
    def __init__(self, hidden_size):
        super(DMPNNConv, self).__init__(aggr='add') # Sum aggregation function
        self.lin = nn.Linear(hidden_size, hidden_size)

    def forward(self, edge_index, edge_attr):
        row, _ = edge_index
        # Since each edge is bidirectional, we do two message passings, one for each direction
        aggregated_message = self.propagate(edge_index, x=None, edge_attr=edge_attr)
        reversed_message = torch.flip(edge_attr.view(edge_attr.size(0) // 2, 2, -1), dims=[1]).view(edge_attr.size(0), -1)

        return aggregated_message, self.lin(aggregated_message[row] - reversed_message)

    def message(self, edge_attr):
        return edge_attr

In [111]:
class GNNEncoder(nn.Module):
    def __init__(self, num_node_features, num_edge_features, hidden_size, mode, depth, dropout):
        super().__init__()
        self.depth = depth
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.mode = mode

        # Encoder layers
        self.edge_init = nn.Linear(num_node_features + num_edge_features, hidden_size)
        self.convs = nn.ModuleList([DMPNNConv(hidden_size) for _ in range(depth)])
        self.edge_to_node = nn.Linear(num_node_features + hidden_size, hidden_size)
        self.pool = global_add_pool  # Not learnable

    def forward(self, data):
        edge_index, edge_attr, batch = data.edge_index, data.edge_attr, data.batch
        
        if self.mode == 'denoise':
            x = data.x_noisy
        elif self.mode == 'predict':
            x = data.x
        else:
            raise ValueError("Invalid mode. Choose 'denoise' or 'predict'.")

        # Edge initialization
        row, _ = edge_index
        h_0 = F.relu(self.edge_init(torch.cat([x[row], edge_attr], dim=1)))
        h = h_0

        # DMPNN Conv layers
        for layer in self.convs:
            _, h = layer(edge_index, h)
            h += h_0
            h = F.dropout(F.relu(h), self.dropout, training=self.training)

        # Edge to node aggregation
        # Re-using the last layer's results for s
        s, _ = self.convs[-1](edge_index, h)
        q = torch.cat([x, s], dim=1)
        h = F.relu(self.edge_to_node(q))

        # Global pooling for the final node embeddings
        return self.pool(h, batch) # Embedding is size (batch_size, hidden_size)

In [112]:
class GNNDecoder(nn.Module):
    # Decoder for self-supervised denoising task
    # Only reconstructing the node features from the embeddings
    def __init__(self, hidden_size, num_node_features):
        super().__init__()
        self.lin1 = nn.Linear(hidden_size, hidden_size)
        self.lin2 = nn.Linear(hidden_size, num_node_features)

    def forward(self, graph_embedding, batch):
        # The embedding size is (batch_size, hidden_size)
        # We need to expand it per node. First, get how many nodes per graph:
        batch_size = graph_embedding.size(0)
        node_counts = torch.bincount(batch)  # number of nodes in each graph

        # Expand each graph embedding "node_counts[g]" times
        expanded = []
        for g in range(batch_size):
            repeated = graph_embedding[g].unsqueeze(0).repeat(node_counts[g], 1)  # shape = (num_nodes_in_g, hidden_size)
            expanded.append(repeated)

        # Concatenate along node dimension
        expanded = torch.cat(expanded, dim=0)  # total_nodes x hidden_size

        # Decode the node features
        x_hat = F.relu(self.lin1(expanded))
        x_hat = self.lin2(x_hat)  # shape = (total_nodes, num_node_features)

        return x_hat

In [113]:
class GNNHead(nn.Module):
    def __init__(self, hidden_size, dropout):
        super().__init__()
        self.ffn1 = nn.Linear(hidden_size, hidden_size)
        self.ffn2 = nn.Linear(hidden_size, 1)
        self.dropout = dropout

    def forward(self, graph_embedding):
        x = F.relu(self.ffn1(graph_embedding))
        x = F.dropout(x, self.dropout, training=self.training)
        return self.ffn2(x).squeeze(-1)

In [None]:
class GNN(nn.Module):
    def __init__(self, num_node_features: int, num_edge_features: int, hidden_size: int=300, depth: int=3, mode: str='denoise', dropout: float=0.02):
        super().__init__()
        self.encoder = GNNEncoder(num_node_features, num_edge_features, hidden_size=hidden_size, mode=mode, depth=depth, dropout=dropout)
        self.head = GNNHead(hidden_size=hidden_size, dropout=dropout)
        self.decoder = GNNDecoder(hidden_size=hidden_size, num_node_features=num_node_features)
        self.mode = mode

    def set_mode(self, mode: str):
        # Update the mode in the encoder
        # So the encoder knows if it needs to read noisy or noise-free data
        self.encoder.mode = mode
        self.mode = mode

    def forward(self, data):
        graph_embedding = self.encoder(data)

        if self.encoder.mode == 'predict':
            prediction = self.head(graph_embedding)
            return prediction
        
        elif self.encoder.mode == 'denoise':
            node_features = self.decoder(graph_embedding, data.batch)
            return node_features
    
        else:
            raise ValueError("Invalid mode. Choose 'predict' or 'denoise'.")

In [115]:
class Standardizer:
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, x, rev=False):
        if rev:
            return (x * self.std) + self.mean
        return (x - self.mean) / self.std

In [116]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [None]:
def train_epoch(model, loader, optimizer, loss, mode, stdzer=None):
    if mode == 'denoise':
        model.set_mode('denoise')
        model.train()
        loss_all = 0

        for data in loader:
            data = data.to(device)
            optimizer.zero_grad()

            out = model(data)
            result = loss(out, data.x)
            result.backward()

            optimizer.step()
            loss_all += loss(out, data.x)

        return math.sqrt(loss_all / len(loader.dataset))

    elif mode == 'predict':
        model.set_mode('predict')
        model.train()
        loss_all = 0

        for data in loader:
            data = data.to(device)
            optimizer.zero_grad()

            out = model(data)
            result = loss(out, stdzer(data.y))
            result.backward()

            optimizer.step()
            loss_all += loss(stdzer(out, rev=True), data.y)

        return math.sqrt(loss_all / len(loader.dataset))
    
    else:
        raise ValueError("Invalid mode. Choose 'denoise' or 'predict'.")


def pred(model, loader, mode, stdzer=None):
    if mode == 'denoise':
        model.eval()

        preds = []
        with torch.no_grad():
            for data in loader:
                data = data.to(device)
                out = model(data)
                preds.extend(out.cpu().detach().flatten().tolist())
                
        return preds

    elif mode == 'predict':
        model.set_mode('predict')
        model.eval()

        preds = []
        with torch.no_grad():
            for data in loader:
                data = data.to(device)
                out = model(data)
                pred = stdzer(out, rev=True)
                preds.extend(pred.cpu().detach().tolist())

        return preds
    
    else:
        raise ValueError("Invalid mode. Choose 'denoise' or 'predict'.")


In [118]:
torch.manual_seed(0)
data_df = pd.read_csv("AqSolDBc.csv")
# Drop single atoms    
idx_single = [i for i,s in enumerate(data_df['SmilesCurated']) if Chem.MolFromSmiles(s).GetNumAtoms()==1 or '.' in s]
data_df = data_df.drop(idx_single)
if len(idx_single) > 0:
    print(f"Removing {idx_single} due to single atoms")

test_df = pd.read_csv("OChemUnseen.csv")
# Drop single atoms
idx_nonetype = [i for i,s in enumerate(test_df['SMILES']) if Chem.MolFromSmiles(s) is None] # Got an error for a SMILES which was None
test_df = test_df.drop(idx_nonetype)
if len(idx_nonetype) > 0:
    print(f"Removing {idx_nonetype} due to Nonetypes")

idx_single = [i for i,s in enumerate(test_df['SMILES']) if Chem.MolFromSmiles(s).GetNumAtoms()==1 or '.' in s]
test_df = test_df.drop(idx_single)
if len(idx_single) > 0:
    print(f"Removing {idx_single} due to single atoms")

Removing [1263, 1444, 3605, 3702] due to single atoms


[18:04:08] Explicit valence for atom # 1 P, 6, is greater than permitted


Removing [667] due to Nonetypes
Removing [471, 503, 589, 591, 592, 593, 594, 610, 613, 641, 643, 647, 649, 652, 653, 654, 656, 658, 676, 681, 693, 744, 759, 763, 769, 773, 777, 807, 809, 811, 813, 869, 902, 969, 998] due to single atoms


In [119]:
batch_size = 256

train_df, val_df = train_test_split(data_df, test_size=0.1, random_state=0)
train_loader = construct_loader(train_df, 'SmilesCurated', 'ExperimentalLogS', shuffle=True, batch_size=batch_size)
val_loader = construct_loader(val_df, 'SmilesCurated', 'ExperimentalLogS', shuffle=False, batch_size=batch_size)
test_loader = construct_loader(test_df, 'SMILES', 'LogS', shuffle=False, batch_size=batch_size)
print(f"Train size: {len(train_loader.dataset)}, Val size: {len(val_loader.dataset)}, Test size: {len(test_loader.dataset)}")

# Normalizer for the labels
mean = np.mean(train_loader.dataset.labels)
std = np.std(train_loader.dataset.labels)
stdzer = Standardizer(mean, std)

Precomputing data...
Precomputation finished. 7238 molecules cached.
Precomputing data...
Precomputation finished. 805 molecules cached.
Precomputing data...
Precomputation finished. 2215 molecules cached.
Train size: 7238, Val size: 805, Test size: 2215


In [120]:
model = GNN(train_loader.dataset.num_node_features, train_loader.dataset.num_edge_features).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss = nn.MSELoss(reduction='sum').to(device)
print('\n', model, '\n')


 GNN(
  (encoder): GNNEncoder(
    (edge_init): Linear(in_features=51, out_features=300, bias=True)
    (convs): ModuleList(
      (0-2): 3 x DMPNNConv()
    )
    (edge_to_node): Linear(in_features=344, out_features=300, bias=True)
  )
  (head): GNNHead(
    (ffn1): Linear(in_features=300, out_features=300, bias=True)
    (ffn2): Linear(in_features=300, out_features=1, bias=True)
  )
  (decoder): GNNDecoder(
    (lin1): Linear(in_features=300, out_features=300, bias=True)
    (lin2): Linear(in_features=300, out_features=44, bias=True)
  )
) 



In [121]:
# Unfreeze the encoder
for param in model.encoder.parameters():
    param.requires_grad = True

# Unfreeze the decoder
for param in model.decoder.parameters():
    param.requires_grad = True

# Freeze the prediction head
for param in model.head.parameters():
    param.requires_grad = False


best_denoise_model = model
best_denoise_val_loss = 1000000

for epoch in range(0, 20):
    train_loss = train_epoch(model, train_loader, optimizer, loss, mode='denoise')
    preds = pred(model, val_loader, mode='denoise')
    node_feature_targets = [feature for batch in val_loader for feature in batch.x.cpu().flatten().tolist()]
    val_loss = root_mean_squared_error(preds, node_feature_targets)
    print(f"Epoch {epoch}  Train RMSE: {train_loss:.3g}   Val RMSE: {val_loss:.3g}")

    if val_loss < best_denoise_val_loss:
        best_denoise_model = deepcopy(model)
        best_denoise_val_loss = val_loss

Epoch 0  Train RMSE: 12.3   Val RMSE: 0.258
Epoch 1  Train RMSE: 7.03   Val RMSE: 0.245
Epoch 2  Train RMSE: 6.64   Val RMSE: 0.238
Epoch 3  Train RMSE: 6.39   Val RMSE: 0.234
Epoch 4  Train RMSE: 6.25   Val RMSE: 0.238
Epoch 5  Train RMSE: 6.21   Val RMSE: 0.225
Epoch 6  Train RMSE: 6.1   Val RMSE: 0.223
Epoch 7  Train RMSE: 6.07   Val RMSE: 0.22
Epoch 8  Train RMSE: 6.01   Val RMSE: 0.22
Epoch 9  Train RMSE: 5.99   Val RMSE: 0.219
Epoch 10  Train RMSE: 5.97   Val RMSE: 0.219
Epoch 11  Train RMSE: 5.97   Val RMSE: 0.22
Epoch 12  Train RMSE: 5.95   Val RMSE: 0.218
Epoch 13  Train RMSE: 5.93   Val RMSE: 0.217
Epoch 14  Train RMSE: 5.93   Val RMSE: 0.217
Epoch 15  Train RMSE: 5.92   Val RMSE: 0.217
Epoch 16  Train RMSE: 5.92   Val RMSE: 0.218
Epoch 17  Train RMSE: 5.92   Val RMSE: 0.217
Epoch 18  Train RMSE: 5.9   Val RMSE: 0.217
Epoch 19  Train RMSE: 5.9   Val RMSE: 0.216


In [122]:
# Freeze the encoder
for param in model.encoder.parameters():
    param.requires_grad = False

# Freeze the decoder
for param in model.decoder.parameters():
    param.requires_grad = False

# unfreeze the prediction head
for param in model.head.parameters():
    param.requires_grad = True

best_pred_model = best_denoise_model
best_val_loss = 1000000

for epoch in range(0, 30):
    train_loss = train_epoch(model, train_loader, optimizer, loss, mode='predict', stdzer=stdzer)
    preds = pred(model, val_loader, mode='predict', stdzer=stdzer)
    val_loss = root_mean_squared_error(preds, val_loader.dataset.labels)
    print(f"Epoch {epoch}  Train RMSE: {train_loss:.3g}   Val RMSE: {val_loss:.3g}")
    if val_loss < best_val_loss:
        best_pred_model = deepcopy(model)
        best_val_loss = val_loss

preds = pred(best_pred_model, test_loader, mode='predict', stdzer=stdzer)  
print(f"Test RMSE: {root_mean_squared_error(preds, test_loader.dataset.labels):.3g}")
print(f"Test MAE: {mean_absolute_error(preds, test_loader.dataset.labels):.3g}")

predict
Epoch 0  Train RMSE: 1.85   Val RMSE: 1.48
predict
Epoch 1  Train RMSE: 1.43   Val RMSE: 1.34
predict
Epoch 2  Train RMSE: 1.32   Val RMSE: 1.27
predict
Epoch 3  Train RMSE: 1.27   Val RMSE: 1.39
predict
Epoch 4  Train RMSE: 1.23   Val RMSE: 1.22
predict
Epoch 5  Train RMSE: 1.19   Val RMSE: 1.15
predict
Epoch 6  Train RMSE: 1.15   Val RMSE: 1.12
predict
Epoch 7  Train RMSE: 1.11   Val RMSE: 1.11
predict
Epoch 8  Train RMSE: 1.1   Val RMSE: 1.11
predict
Epoch 9  Train RMSE: 1.08   Val RMSE: 1.11
predict
Epoch 10  Train RMSE: 1.08   Val RMSE: 1.13
predict
Epoch 11  Train RMSE: 1.07   Val RMSE: 1.07
predict
Epoch 12  Train RMSE: 1.05   Val RMSE: 1.07
predict
Epoch 13  Train RMSE: 1.04   Val RMSE: 1.05
predict
Epoch 14  Train RMSE: 1.02   Val RMSE: 1.07
predict
Epoch 15  Train RMSE: 1.04   Val RMSE: 1.21
predict
Epoch 16  Train RMSE: 1.05   Val RMSE: 1.13
predict
Epoch 17  Train RMSE: 1.02   Val RMSE: 1.11
predict
Epoch 18  Train RMSE: 1   Val RMSE: 1.05
predict
Epoch 19  Train RM