In [1]:
import numpy as np
import pandas as pd
import duckdb
from rdkit import Chem
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from tqdm import tqdm
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import MessagePassing, global_max_pool
from torch.nn import BCEWithLogitsLoss
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from typing import List

In [2]:
train_path = 'train.parquet'
test_path = 'test.parquet'

con = duckdb.connect()

In [3]:
hsa_df = con.query(f"""(SELECT *
                        FROM parquet_scan('{train_path}')
                        WHERE protein_name = 'HSA' AND binds = 1)
                        UNION ALL
                        (SELECT *
                        FROM parquet_scan('{train_path}')
                        WHERE protein_name = 'HSA' AND binds = 0
                        ORDER BY random()
                        LIMIT 408410)
                        """).df()

brd4_df = con.query(f"""(SELECT *
                        FROM parquet_scan('{train_path}')
                        WHERE protein_name = 'BRD4' AND binds = 1)
                        UNION ALL
                        (SELECT *
                        FROM parquet_scan('{train_path}')
                        WHERE protein_name = 'BRD4' AND binds = 0
                        ORDER BY random()
                        LIMIT 456964)
                        """).df()

seh_df = con.query(f"""(SELECT *
                        FROM parquet_scan('{train_path}')
                        WHERE protein_name = 'sEH' AND binds = 1)
                        UNION ALL
                        (SELECT *
                        FROM parquet_scan('{train_path}')
                        WHERE protein_name = 'sEH' AND binds = 0
                        ORDER BY random()
                        LIMIT 724532)
                        """).df()


In [4]:
print(hsa_df['binds'].value_counts())
print(brd4_df['binds'].value_counts())
print(seh_df['binds'].value_counts())

binds
1    408410
0    408410
Name: count, dtype: int64
binds
1    456964
0    456964
Name: count, dtype: int64
binds
1    724532
0    724532
Name: count, dtype: int64


In [5]:
train_hsa, test_hsa = train_test_split(hsa_df, test_size=0.2, random_state=42, shuffle=True)
train_brd4, test_brd4 = train_test_split(brd4_df, test_size=0.2, random_state=42, shuffle=True)
train_seh, test_seh = train_test_split(seh_df, test_size=0.2, random_state=42, shuffle=True)
#remember to separate the labels from the test dfs after featurization

In [6]:
test_hsa_x = test_hsa.drop(columns= ['binds'])
test_hsa_y = test_hsa['binds']

test_brd4_x = test_brd4.drop(columns= ['binds'])
test_brd4_y = test_brd4['binds']

test_seh_x = test_seh.drop(columns= ['binds'])
test_seh_y = test_seh['binds']

In [7]:
hsa_data = {
    'train': train_hsa,
    'test_x': test_hsa_x,
}

brd4_data = {
    'train': train_brd4,
    'test_x': test_brd4_x,
}

seh_data = {
    'train': train_seh,
    'test_x': test_seh_x,
}

In [9]:
def get_torch_data_object(smiles, ids, labels=None):

    def _one_hot_encoding(element, permitted_elements):
        """
        Maps input elements element which are not in the permitted list to the last element of the permitted list
        """
        if element not in permitted_elements:
            element = permitted_elements[-1]

        binary_encoding = [int(boolean_value) for boolean_value in list(map(lambda s: element==s , permitted_elements))]

        return binary_encoding


    def _get_atom_features(atom):
        
        #Define a simplified list of atom types
        permitted_atom_types = [
            'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg',
            'Na', 'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl',
            'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H',
            'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr',
            'Pt', 'Hg', 'Pb', 'Dy', 'Unknown'
        ]
        atom_type = atom.GetSymbol() if atom.GetSymbol() in permitted_atom_types else 'Unknown'
        atom_type_enc = _one_hot_encoding(atom_type, permitted_atom_types)

        hybridization_type = [
            Chem.rdchem.HybridizationType.S,
            Chem.rdchem.HybridizationType.SP,
            Chem.rdchem.HybridizationType.SP2,
            Chem.rdchem.HybridizationType.SP3,
            Chem.rdchem.HybridizationType.SP3D
        ]
        atom_hybridization_type = _one_hot_encoding(atom.GetHybridization(), hybridization_type)

        atom_degree = _one_hot_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5])

        is_in_ring = [int(atom.IsInRing())]

        total_hs = _one_hot_encoding(atom.GetTotalNumHs(), [0, 1, 2, 3, 4, 5])

        implicit_valence = _one_hot_encoding(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5])

        chirality = _one_hot_encoding(str(atom.GetChiralTag()), ["CHI_UNSPECIFIED", "CHI_TETRAHEDRAL_CW", "CHI_TETRAHEDRAL_CCW", "CHI_OTHER"])
        
        atom_features = atom_type_enc + atom_degree + is_in_ring + total_hs + implicit_valence + atom_hybridization_type + chirality
        
        return np.array(atom_features, dtype=np.float32)


    def _get_bond_features(bond):

        bond_type = bond.GetBondType()

        features = [
            int(bond_type == Chem.rdchem.BondType.SINGLE),
            int(bond_type == Chem.rdchem.BondType.DOUBLE),
            int(bond_type == Chem.rdchem.BondType.TRIPLE),
            int(bond_type == Chem.rdchem.BondType.AROMATIC),
            int(bond.IsInRing()),
            int(bond.GetIsConjugated()),
        ]
        
        return np.array(features, dtype=np.float32)
    

    data_list = []
    
    for index, smile in enumerate(smiles):
        mol = Chem.MolFromSmiles(smile)
        
        if not mol:  # Skip invalid SMILES strings
            continue
        
        # Node features
        atom_features = [_get_atom_features(atom) for atom in mol.GetAtoms()]
        x = torch.tensor(atom_features, dtype=torch.float)
        
        # Edge features
        edge_index = []
        edge_features = []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            edge_index += [(start, end), (end, start)]  # Undirected graph
            bond_feature = _get_bond_features(bond)
            edge_features += [bond_feature, bond_feature]  # Same features in both directions
        
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_features, dtype=torch.float)
        
        # Creating the Data object
        data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
        data.molecule_id = ids[index]

        if labels is not None:
            data.y = torch.tensor([labels[index]], dtype=torch.float)
        
        data_list.append(data)
    
    return data_list


def featurize_data_in_batches(ids_array, smiles_array, labels_array, batch_size = 2**10):
    data_list = []
    # Define tqdm progress bar
    pbar = tqdm(total=len(smiles_array), desc="Featurizing data")
    for i in range(0, len(smiles_array), batch_size):
        smiles_batch = smiles_array[i:i+batch_size]
        ids_batch = ids_array[i:i+batch_size]
        labels_batch = labels_array[i:i+batch_size] if labels_array is not None else None
        batch_data_list = get_torch_data_object(smiles_batch, ids_batch, labels_batch)
        data_list.extend(batch_data_list)
        pbar.update(len(smiles_batch))
        
    pbar.close()
    return data_list


def get_featurized_data(protein_data: dict):

    featurized_data = {}

    for key in protein_data.keys():
        if 'molecule_smiles' in protein_data[key]:
            smiles_array = np.array(protein_data[key]['molecule_smiles'])
            ids_array = np.array(protein_data[key]['id'])

            labels_array = None
            
            if 'binds' in protein_data[key]:
                labels_array = np.array(protein_data[key]['binds'])
        
        featurized_data[key] = featurize_data_in_batches(ids_array, smiles_array, labels_array)
        
    return featurized_data

In [10]:
hsa_featurized_data = get_featurized_data(hsa_data)
brd4_featurized_data = get_featurized_data(brd4_data)
seh_featurized_data = get_featurized_data(seh_data)

  x = torch.tensor(atom_features, dtype=torch.float)
Featurizing data: 100%|██████████| 653456/653456 [14:02<00:00, 775.79it/s]
Featurizing data: 100%|██████████| 163364/163364 [03:28<00:00, 783.43it/s]
Featurizing data: 100%|██████████| 731142/731142 [15:50<00:00, 768.93it/s]
Featurizing data: 100%|██████████| 182786/182786 [03:59<00:00, 762.37it/s]
Featurizing data: 100%|██████████| 1159251/1159251 [24:50<00:00, 778.02it/s]
Featurizing data: 100%|██████████| 289813/289813 [06:05<00:00, 793.66it/s]


In [11]:
torch.save(hsa_featurized_data, 'hsa_featurized_data')
torch.save(brd4_featurized_data, 'brd4_featurized_data')
torch.save(seh_featurized_data, 'seh_featurized_data')

In [12]:
class CustomGNNLayer(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(CustomGNNLayer, self).__init__(aggr='max')
        self.lin = nn.Linear(in_channels + 6, out_channels)

    def forward(self, x, edge_index, edge_attr):
        # Start propagating messages
        return MessagePassing.propagate(self, edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_attr):
        combined = torch.cat((x_j, edge_attr), dim=1)
        return combined

    def update(self, aggr_out):
        return self.lin(aggr_out)


class GNNModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, dropout_rate):
        super(GNNModel, self).__init__()
        self.num_layers = num_layers
        self.convs = nn.ModuleList([CustomGNNLayer(input_dim if i == 0 else hidden_dim, hidden_dim) for i in range(num_layers)])
        self.dropout = nn.Dropout(dropout_rate)
        self.bns = nn.ModuleList([nn.BatchNorm1d(hidden_dim) for _ in range(num_layers)])
        self.lin = nn.Linear(hidden_dim, 1)
        
    def forward(self, data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index, edge_attr)
            x = self.bns[i](x)
            x = F.relu(x)
            x = self.dropout(x)


        x = global_max_pool(x, data.batch) # Global pooling to get a graph-level representation
        x = self.lin(x)
        return x

In [13]:
def train_model(
        loader,
        num_epochs,
        input_dim,
        hidden_dim,
        num_layers,
        dropout_rate,
        lr,
        save_path
    ):
    model = GNNModel(input_dim, hidden_dim, num_layers, dropout_rate)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    criterion = BCEWithLogitsLoss()

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for batch in loader:
            optimizer.zero_grad()
            out = model(batch)
            loss = criterion(out, batch.y.view(-1,1).float())
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(loader)}')
    
    torch.save(model, save_path)

def predict_with_model(model, test_loader):
    #model.eval()
    predictions = []
    molecule_ids = []

    with torch.no_grad():
        for data in test_loader:
            output = torch.sigmoid(model(data))
            predictions.extend(output.view(-1).tolist())
            molecule_ids.extend(data.molecule_id)
    
    return molecule_ids, predictions

In [14]:
# Create DataLoaders for the current protein
hsa_train_loader = DataLoader(hsa_featurized_data['train'], batch_size=32, shuffle=True)
hsa_test_loader = DataLoader(hsa_featurized_data['test_x'], batch_size=32, shuffle=False)

In [15]:
# Train model
input_dim = hsa_train_loader.dataset[0].num_node_features
hidden_dim = 64
num_epochs = 11
num_layers = 4 #can be modified
dropout_rate = 0.3
lr = 0.001
save_path = 'hsa_trained_model.pt'
train_model(hsa_train_loader,num_epochs, input_dim, hidden_dim,num_layers, dropout_rate, lr, save_path)

Epoch 1/11, Loss: 0.41445273210845784
Epoch 2/11, Loss: 0.37483675144753253
Epoch 3/11, Loss: 0.3640241125902013
Epoch 4/11, Loss: 0.35805150209816844
Epoch 5/11, Loss: 0.354398698350935
Epoch 6/11, Loss: 0.35250538794950337
Epoch 7/11, Loss: 0.3508440396946824
Epoch 8/11, Loss: 0.34916548866403363
Epoch 9/11, Loss: 0.34850780542956866
Epoch 10/11, Loss: 0.3480572686311512
Epoch 11/11, Loss: 0.34701293485138723


In [16]:
model = torch.load('hsa_trained_model.pt')
type(model)

  model = torch.load('hsa_trained_model.pt')


__main__.GNNModel

In [17]:
# Predict
molecule_ids, predictions = predict_with_model(model, hsa_test_loader)

# Collect predictions
hsa_predictions = pd.DataFrame({
    'id': molecule_ids,
    'binds': predictions,
})

In [21]:
#Evaluate predictions 
BINDING_THRESHOLD = 0.8
testhsax = hsa_predictions['binds'].apply(lambda x: 1 if x >= BINDING_THRESHOLD else 0)

print(classification_report(testhsax, test_hsa_y))

              precision    recall  f1-score   support

           0       0.95      0.75      0.84    103957
           1       0.68      0.94      0.79     59407

    accuracy                           0.82    163364
   macro avg       0.82      0.84      0.81    163364
weighted avg       0.85      0.82      0.82    163364



In [22]:
hsa_predictions['binds']

0         0.996935
1         0.484345
2         0.917875
3         0.780834
4         0.988841
            ...   
163359    0.552897
163360    0.774370
163361    0.936678
163362    0.976615
163363    0.318769
Name: binds, Length: 163364, dtype: float64

In [23]:
# Train model
input_dim = hsa_train_loader.dataset[0].num_node_features
hidden_dim = 64
num_epochs = 11
num_layers = 4 #can be modified
dropout_rate = 0.3
lr = 0.003
save_path = 'hsa_trained_model_2.pt'
train_model(hsa_train_loader,num_epochs, input_dim, hidden_dim,num_layers, dropout_rate, lr, save_path)

KeyboardInterrupt: 

In [68]:
model_2 = torch.load('hsa_trained_model_2.pt')
# Predict
molecule_ids, predictions = predict_with_model(model, hsa_test_loader)

# Collect predictions
hsa_predictions = pd.DataFrame({
    'id': molecule_ids,
    'binds': predictions,
})

#Evaluate predictions 
BINDING_THRESHOLD = 0.8
hsa_predictions['binds'] = hsa_predictions['binds'].apply(lambda x: 1 if x > BINDING_THRESHOLD else 0)

print(classification_report(hsa_predictions['binds'], test_hsa_y))

  model_2 = torch.load('hsa_trained_model_2.pt')


              precision    recall  f1-score   support

           0       0.96      0.72      0.82    109244
           1       0.63      0.94      0.75     54120

    accuracy                           0.79    163364
   macro avg       0.79      0.83      0.79    163364
weighted avg       0.85      0.79      0.80    163364



In [24]:
# Create DataLoaders for the current protein
brd4_train_loader = DataLoader(brd4_featurized_data['train'], batch_size=32, shuffle=True)
brd4_test_loader = DataLoader(brd4_featurized_data['test_x'], batch_size=32, shuffle=False)

In [25]:
# Train model
input_dim = brd4_train_loader.dataset[0].num_node_features
hidden_dim = 64
num_epochs = 11
num_layers = 4 #can be modified
dropout_rate = 0.3
lr = 0.001
save_path = 'brd4_trained_model_2.pt'
train_model(brd4_train_loader,num_epochs, input_dim, hidden_dim,num_layers, dropout_rate, lr, save_path)

Epoch 1/11, Loss: 0.3241531908028866
Epoch 2/11, Loss: 0.2774880223373698
Epoch 3/11, Loss: 0.2658897035027368
Epoch 4/11, Loss: 0.259332908411439
Epoch 5/11, Loss: 0.25706945616999566
Epoch 6/11, Loss: 0.2554396822979463
Epoch 7/11, Loss: 0.25418353400633076
Epoch 8/11, Loss: 0.2530095343909223
Epoch 9/11, Loss: 0.25223422029232023
Epoch 10/11, Loss: 0.2517589669030792
Epoch 11/11, Loss: 0.25102477654184735


In [31]:
model_6 = torch.load('brd4_trained_model_2.pt')
# Predict
molecule_ids, predictions = predict_with_model(model_6, brd4_test_loader)

# Collect predictions
brd4_predictions = pd.DataFrame({
    'id': molecule_ids,
    'binds': predictions,
})


  model_6 = torch.load('brd4_trained_model_2.pt')


In [32]:
#Evaluate predictions 
BINDING_THRESHOLD = 0.7
brd4_predictions_binds = brd4_predictions['binds'].apply(lambda x: 1 if x > BINDING_THRESHOLD else 0)

print(classification_report(brd4_predictions_binds, test_brd4_y))

              precision    recall  f1-score   support

           0       0.95      0.83      0.89    104358
           1       0.81      0.95      0.87     78428

    accuracy                           0.88    182786
   macro avg       0.88      0.89      0.88    182786
weighted avg       0.89      0.88      0.88    182786



In [33]:
# Create DataLoaders for the current protein
seh_train_loader = DataLoader(seh_featurized_data['train'], batch_size=32, shuffle=True)
seh_test_loader = DataLoader(seh_featurized_data['test_x'], batch_size=32, shuffle=False)

In [35]:
# Train model
input_dim = seh_train_loader.dataset[0].num_node_features
hidden_dim = 64
num_epochs = 11
num_layers = 4 #can be modified
dropout_rate = 0.3
lr = 0.001
save_path = 'seh_trained_model.pt'
train_model(seh_train_loader,num_epochs, input_dim, hidden_dim,num_layers, dropout_rate, lr, save_path)

Epoch 1/11, Loss: 0.19211339599879126
Epoch 2/11, Loss: 0.16134811337898122
Epoch 3/11, Loss: 0.15541458889830928
Epoch 4/11, Loss: 0.15288605839218689
Epoch 5/11, Loss: 0.15138296992422245
Epoch 6/11, Loss: 0.15018682266923408
Epoch 7/11, Loss: 0.1487006692093019
Epoch 8/11, Loss: 0.14760054968567723
Epoch 9/11, Loss: 0.14777189984626987
Epoch 10/11, Loss: 0.14661013126019756
Epoch 11/11, Loss: 0.146275563309343


In [38]:
model_5 = torch.load('seh_trained_model.pt')
# Predict
molecule_ids, predictions = predict_with_model(model_5, seh_test_loader)

# Collect predictions
seh_predictions = pd.DataFrame({
    'id': molecule_ids,
    'binds': predictions,
})

  model_5 = torch.load('seh_trained_model.pt')


In [41]:
#Evaluate predictions 
BINDING_THRESHOLD = 0.6
seh_predictions_binds = seh_predictions['binds'].apply(lambda x: 1 if x > BINDING_THRESHOLD else 0)

print(classification_report(seh_predictions_binds, test_seh_y))

              precision    recall  f1-score   support

           0       0.96      0.93      0.95    149851
           1       0.93      0.96      0.94    139962

    accuracy                           0.94    289813
   macro avg       0.94      0.94      0.94    289813
weighted avg       0.94      0.94      0.94    289813

