In [1]:
import pandas as pd
import numpy as np
import rdkit
from rdkit import Chem
import torch
from torch.nn import functional as F
from torch import nn
from torch_geometric.nn import GCNConv,TopKPooling,global_mean_pool
from torch_geometric.nn import global_mean_pool as gmp, global_max_pool as gap

import sklearn
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader,random_split
import tqdm
from torch_geometric.data import DataLoader, Data

In [2]:
df = pd.read_csv("drug.csv")

In [3]:
drug_id = df["ID"]

In [4]:
smiles_df = pd.read_csv("smiles.csv")

In [5]:
smiles = smiles_df["Smiles"]

In [6]:
drug_cleaned=[]
smiles_cleaned=[]
for i in range(len(smiles)):
       if isinstance(smiles[i], str) and not pd.isnull(smiles[i]):
        mol_obj = Chem.MolFromSmiles(smiles[i])
        if mol_obj is not None:
            # If successful, add drug_id and smiles to the cleaned lists
            drug_cleaned.append(drug_id[i])
            smiles_cleaned.append(smiles[i])



In [7]:
print(len(smiles))
print(len(drug_cleaned))

1048575
1036742


In [7]:
 def get_node_features(mol):
        """
        This will return a matrix / 2d array of the shape
        [Number of Nodes, Node Feature size]
        """
        all_node_feats = []

        for atom in mol.GetAtoms():
            node_feats = []
            # Feature 1: Atomic number
            node_feats.append(atom.GetAtomicNum())
            # Feature 2: Atom degree
            node_feats.append(atom.GetDegree())
            # Feature 3: Formal charge
            node_feats.append(atom.GetFormalCharge())
            # Feature 4: Hybridization
            node_feats.append(atom.GetHybridization())
            # Feature 5: Aromaticity
            node_feats.append(atom.GetIsAromatic())
            # Feature 6: Total Num Hs
            node_feats.append(atom.GetTotalNumHs())
            # Feature 7: Radical Electrons
            node_feats.append(atom.GetNumRadicalElectrons())
            # Feature 8: In Ring
            node_feats.append(atom.IsInRing())
            # Feature 9: Chirality
            node_feats.append(atom.GetChiralTag())

            # Append node features to matrix
            all_node_feats.append(node_feats)

        all_node_feats = np.asarray(all_node_feats)
        return torch.tensor(all_node_feats, dtype=torch.float)


In [8]:
def get_edge_features(mol):
        """
        This will return a matrix / 2d array of the shape
        [Number of edges, Edge Feature size]
        """
        all_edge_feats = []

        for bond in mol.GetBonds():
            edge_feats = []
            # Feature 1: Bond type (as double)
            edge_feats.append(bond.GetBondTypeAsDouble())
            # Feature 2: Rings
            edge_feats.append(bond.IsInRing())
            # Append node features to matrix (twice, per direction)
            all_edge_feats += [edge_feats, edge_feats]

        all_edge_feats = np.asarray(all_edge_feats)
        return torch.tensor(all_edge_feats, dtype=torch.float)


In [9]:
def get_adjacency_info(mol):
        """
        We could also use rdmolops.GetAdjacencyMatrix(mol)
        but we want to be sure that the order of the indices
        matches the order of the edge features
        """
        edge_indices = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_indices += [[i, j], [j, i]]

        edge_indices = torch.tensor(edge_indices)
        edge_indices = edge_indices.t().to(torch.long).view(2, -1)
        return edge_indices

In [10]:
node_features = []
edge_features = []
edge_index = []
for i in range(len(smiles_cleaned)):
    mol_obj = Chem.MolFromSmiles(smiles_cleaned[i])
    node_feature = get_node_features(mol_obj)
    node_features.append(node_feature)
    edge_feature = get_edge_features(mol_obj)
    edge_features.append(edge_feature)
    edge_adjacency = get_adjacency_info(mol_obj)
    edge_index.append(edge_adjacency)



In [11]:
from ProtFlash.pretrain import load_prot_flash_base
from ProtFlash.utils import batchConverter
data = [
    ("protein1", "KIYIVLRRRRKRVNT"),
]
ids, batch_token, lengths = batchConverter(data)
model = load_prot_flash_base()
with torch.no_grad():
    token_embedding = model(batch_token, lengths)
# Generate per-sequence representations via averaging
sequence_representations = []
for i, (_, seq) in enumerate(data):
    sequence_representations.append(token_embedding[i, 0: len(seq) + 1].mean(0))

In [12]:
seq_sq = sequence_representations[0]
seq_sq = seq_sq.unsqueeze(0)
seq_sq = seq_sq.unsqueeze(0)

sequence_list=[]
for i in range(len(node_features)):
  sequence_list.append(seq_sq)


In [13]:
batch_size = 16
num_nodes = len(node_features)
batch_indices = [i // batch_size for i in range(num_nodes)]

In [14]:
data_list = []
for i in range(num_nodes):
    node_features_data = node_features[i]
    edge_index_data = torch.tensor(edge_index[i])
    batch_idx = torch.tensor(batch_indices[i])
    edge_features_data = torch.tensor(edge_features[i])
    pr_seq = sequence_list[i]

    data = Data(x=node_features_data, batch= batch_idx, edge_index = edge_index_data, sequence = pr_seq)
    data_list.append(data)

  edge_index_data = torch.tensor(edge_index[i])
  edge_features_data = torch.tensor(edge_features[i])


In [15]:
test_loader =  DataLoader(data_list, batch_size = batch_size, shuffle=False, drop_last=True)



In [16]:
class channel_attention(nn.Module):
    def __init__(self):
        super(channel_attention,self).__init__()
        self.maxpooling = nn.AdaptiveMaxPool1d(1)
        self.avgpooling = nn.AdaptiveAvgPool1d(1)
        self.mlp = nn.Sequential(
            nn.Linear(in_features = 32, out_features = 16, bias = False),
            nn.Linear(in_features = 16, out_features = 32, bias = False),
            nn.ReLU(inplace = True)
            




        )
            
        self.activation = nn.Sigmoid()


    
    def forward(self,x):
        x1 = self.maxpooling(x)
        #print(x1.shape)
        x2 = self.avgpooling(x)
        #print(x2.shape)
        x1_mlp = self.mlp(x1.squeeze(-1))
        x2_mlp = self.mlp(x2.squeeze(-1))
        feats = x1_mlp + x2_mlp
        feats = self.activation(feats)
        #print(feats.shape)
        channel_refined_feats = x * feats.unsqueeze(-1)
        return(channel_refined_feats)
        

In [17]:
class SAM(nn.Module):
    def __init__(self):
        super(SAM,self).__init__()

        
        self.convlayer = nn.Sequential(
            nn.Conv1d(in_channels = 2, out_channels = 1, kernel_size = 3, padding = 1),
        )
        self.activation = nn.Sigmoid()

        
    def forward(self,x):
        x_mean = torch.mean(x, dim = 1, keepdim = True)
        x_max,_ = torch.max(x,dim = 1, keepdim = True)
        x_cat = torch.cat((x_mean,x_max), dim = 1)
        x_conv = self.convlayer(x_cat)
        spatial_feats = self.activation(x_conv)
        refined_spatial_feats = x * spatial_feats
        return(refined_spatial_feats)
        
        

In [18]:
class CBAM(nn.Module):
    def __init__(self):
        super(CBAM,self).__init__()
        self.channel_attention = channel_attention()
        self.spatial_attention = SAM()
    def forward(self,x):
        channel_attention_layer = self.channel_attention(x)
        spatial_attention_layer = self.spatial_attention(channel_attention_layer)
        return(spatial_attention_layer)
        

In [19]:
embedding_size = 64
cnn_hidden_layer = 64
linear_embeddings = 32
output_embeddings = 16
out_features_cnn = 16

class GCN(nn.Module):
    def __init__(self):
        super(GCN,self).__init__()

        #GCN_layers
        self.initial_GCN = GCNConv(9,embedding_size)
        self.GCN1 = GCNConv(embedding_size,embedding_size)
        self.GCN2 = GCNConv(embedding_size,embedding_size)
        self.GCN_output_layer = nn.Linear(in_features = embedding_size*2, out_features = 768)

        self.bn1 = nn.LayerNorm(embedding_size)
        self.bn2 = nn.LayerNorm(embedding_size)



        #CNN_layer
        self.CNN_layers = nn.Sequential(
            nn.Conv1d(in_channels=2, out_channels=32, kernel_size = 3, stride=1,padding=1),
            nn.Conv1d(in_channels=32, out_channels=32, kernel_size = 3, stride=1,padding=1),

            nn.Conv1d(in_channels=32, out_channels=32, kernel_size = 3, stride=1,padding=1),







            
        )
        self.pooling = nn.MaxPool1d(kernel_size = 2)
        self.flatten = nn.Flatten()
        self.cbam = CBAM()

        self.CNN_linear = nn.Sequential(
            nn.Linear(in_features= 12288, out_features = linear_embeddings),
            nn.Softplus(),







            nn.Linear(in_features= linear_embeddings, out_features = linear_embeddings),
            nn.Softplus(),






            nn.Linear(in_features= linear_embeddings,out_features = 1),
            nn.Softplus()



            )
        self.dropout = nn.Dropout(0.6)


    
        

      

     

      

    def forward(self,drug, edge_index, batch_index, protein):
        GCN_layer_1 = self.initial_GCN(drug, edge_index)
        GCN_layer_1 = F.tanh(GCN_layer_1)

        GCN_layer_2 = self.GCN2(GCN_layer_1, edge_index)
        GCN_layer_2 = F.tanh(GCN_layer_2)


        GCN_layer_3 = self.GCN2(GCN_layer_2,edge_index)


        hidden = torch.cat([gmp(GCN_layer_3,batch_index), gap(GCN_layer_3,batch_index)],dim = 1)


        GCN_output_layer = self.GCN_output_layer(hidden)
        GCN_output_layer = GCN_output_layer.unsqueeze(0)


        GCN_output_layer = GCN_output_layer.permute(1,0,2)
        #print(f"GCN_output : {GCN_output_layer.shape}")

        combined = torch.cat((GCN_output_layer, protein), dim = 1)
        #print(f"Combined_shape : {combined.shape}")

        CNN_layer = self.CNN_layers(combined)
        CNN_layer = self.dropout(CNN_layer)

        #print(f" CNN_layer : {CNN_layer.shape}")

        attention_layer = self.cbam(x = CNN_layer)
        pooled_attention_layer = self.pooling(attention_layer)
        flattened_attention_layer = self.flatten(pooled_attention_layer)
        


        #print(flattened_attention_layer.shape)

        CNN_output = self.CNN_linear(flattened_attention_layer)
        #print(f"CNN_output.shape : {CNN_output.shape}")
        ###CNN_output = CNN_output.unsqueeze(0)
        #print(f"CNN_output : {CNN_output}")
        
        #print(f" CNN_output shape : {CNN_output.shape}")

        ###combined_layer = torch.cat((GCN_output_layer, CNN_output), dim =0)
        #print(f" Combined_layer_shape: {combined_layer.shape}")
        ###combined_layer = combined_layer.permute(1,0,2)


        ###combined_layer = torch.mean(combined_layer, dim = 1)
        ###combined_layer = combined_layer.unsqueeze(1)
        #print(f" Combined_layer_shape: {combined_layer.shape}")


        ###combined_CNN_layer = self.combined_CNN_layers (combined_layer)
        ###output_layer = self.combined_CNN_linear(combined_CNN_layer)


        #print(combined_CNN.shape)
        #print(model_layer.shape)
        return(CNN_output)

In [20]:
cnn = GCN()

In [21]:
cnn.load_state_dict(torch.load("attention_model.pth"))

<All keys matched successfully>

In [22]:
device = "cuda"

In [23]:
cnn.to(device)

GCN(
  (initial_GCN): GCNConv(9, 64)
  (GCN1): GCNConv(64, 64)
  (GCN2): GCNConv(64, 64)
  (GCN_output_layer): Linear(in_features=128, out_features=768, bias=True)
  (bn1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  (bn2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  (CNN_layers): Sequential(
    (0): Conv1d(2, 32, kernel_size=(3,), stride=(1,), padding=(1,))
    (1): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))
    (2): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))
  )
  (pooling): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (cbam): CBAM(
    (channel_attention): channel_attention(
      (maxpooling): AdaptiveMaxPool1d(output_size=1)
      (avgpooling): AdaptiveAvgPool1d(output_size=1)
      (mlp): Sequential(
        (0): Linear(in_features=32, out_features=16, bias=False)
        (1): Linear(in_features=16, out_features=32, bias=False)
        (2): ReLU(i

In [24]:
cnn.eval()

GCN(
  (initial_GCN): GCNConv(9, 64)
  (GCN1): GCNConv(64, 64)
  (GCN2): GCNConv(64, 64)
  (GCN_output_layer): Linear(in_features=128, out_features=768, bias=True)
  (bn1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  (bn2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  (CNN_layers): Sequential(
    (0): Conv1d(2, 32, kernel_size=(3,), stride=(1,), padding=(1,))
    (1): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))
    (2): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))
  )
  (pooling): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (cbam): CBAM(
    (channel_attention): channel_attention(
      (maxpooling): AdaptiveMaxPool1d(output_size=1)
      (avgpooling): AdaptiveAvgPool1d(output_size=1)
      (mlp): Sequential(
        (0): Linear(in_features=32, out_features=16, bias=False)
        (1): Linear(in_features=16, out_features=32, bias=False)
        (2): ReLU(i

In [25]:
output_values_test = []

with torch.no_grad():
    for idx in test_loader:
        idx.x = idx.x.to(device)
        idx.batch = idx.batch.to(device)
        idx.edge_index = idx.edge_index.to(device)
        idx.embeddings_batch = idx.sequence.to(device)

        idx.edge_index = idx.edge_index.to(torch.int64)

        output_test = cnn(drug=idx.x, edge_index=idx.edge_index, batch_index=idx.batch, protein= idx.embeddings_batch)
        output_values_test.append(output_test.cpu().numpy())

# Flatten the output and ground truth arrays
output_value_model_test = np.concatenate(output_values_test).flatten()

In [55]:
output_model = []
for i in range(len(output_value_model_test)):
    output = torch.tensor(output_value_model_test[i])
    output_model.append(output)

In [57]:
len(output_log)

1036736

In [58]:
log_sort = np.argsort(output_log)

In [66]:
len(output_value_model_test)

1036736

In [67]:
output_log[0]

tensor(1.5324)

In [41]:
output_value_model_test

array([4.629501 , 7.072581 , 3.2077932, ..., 5.571663 , 7.2749705,
       3.0805588], dtype=float32)

In [26]:
output_sort = np.argsort(output_value_model_test)

In [27]:
output_sort

array([857613, 408800, 917404, ..., 876393, 108486, 166486])

In [29]:
sorted_ki = output_sort[:20]

In [59]:
sorted_log = log_sort[:30000]

In [30]:
values = []
smiles_sorted = []
drug_ids = []
for i in range(len(sorted_ki)):
    index = sorted_ki[i]
    values.append(output_value_model_test[index])
    smiles_sorted.append(smiles_cleaned[index])
    drug_ids.append(drug_cleaned[index])

In [31]:
dataframe = pd.DataFrame({"CHEMBL ID" : drug_ids, "Smiles" : smiles_sorted, "Ki_values(predicted)" : values})

In [32]:
dataframe.to_csv("attention_model_predictions.csv")

In [68]:
values_log = []
smiles_sorted = []
drug_ids = []
for i in range(len(sorted_log)):
    index = sorted_log[i]
    values_log.append(output_log[index])
    smiles_sorted.append(smiles_cleaned[index])
    drug_ids.append(drug_cleaned[index])

In [69]:
dataframe = pd.DataFrame({"CHEMBL ID" : drug_ids, "Smiles" : smiles_sorted, "Ki_values(predicted)" : values_log})

In [70]:
dataframe.to_csv("dopamine_drug_predictions_log.csv")