In [1]:
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from rdkit import Chem
from rdkit.Chem import AllChem
import numpy as np
from collections import defaultdict
import torch.optim as optim
from torch_geometric.nn import TransformerConv
from sklearn.metrics import accuracy_score
from torch import tensor
from torch_geometric.data import Data, DataLoader
from collections import defaultdict

import torch.utils.data as data

In [2]:



def load_data(dataset, device):
    data_file = f"./original_datasets/{dataset}/{dataset}_train"
    file = open(data_file, "r")
    node_types = set()
    label_types = set()
    tr_len = 0
    for line in file:
        tr_len += 1
        smiles = line.split("\t")[1]
        s = []
        mol = AllChem.MolFromSmiles(smiles)
        for atom in mol.GetAtoms():
            s.append(atom.GetAtomicNum())
        node_types |= set(s)
        label = line.split("\t")[2][:-1]
        label_types.add(label)
    file.close()

    te_len = 0
    data_file = f"./original_datasets/{dataset}/{dataset}_test"
    file = open(data_file, "r")
    for line in file:
        te_len += 1
        smiles = line.split("\t")[1]
        s = []
        mol = AllChem.MolFromSmiles(smiles)
        for atom in mol.GetAtoms():
            s.append(atom.GetAtomicNum())
        node_types |= set(s)
        label = line.split("\t")[2][:-1]
        label_types.add(label)
    file.close()

    #print(tr_len)
    #print(te_len)

    node2index = {n: i for i, n in enumerate(node_types)}
    label2index = {l: i for i, l in enumerate(label_types)}

    #print(node2index)
    #print(label2index)

    data_file = f"./original_datasets/{dataset}/{dataset}_train"
    file = open(data_file, "r")
    train_adjlists = []
    train_features = []
    train_sequence = []
    train_labels = torch.zeros(tr_len)
    for line in file:
        smiles = line.split("\t")[1]
        label = line.split("\t")[2][:-1]
        mol = AllChem.MolFromSmiles(smiles)
        feature = torch.zeros(len(mol.GetAtoms()), len(node_types))

        l = 0
        smiles_seq = []
        for atom in mol.GetAtoms():
            feature[l, node2index[atom.GetAtomicNum()]] = 1
            smiles_seq.append(node2index[atom.GetAtomicNum()])
            l += 1
        adj_list = defaultdict(list)
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            typ = bond.GetBondType()
            adj_list[i].append(j)
            adj_list[j].append(i)
            if typ == Chem.rdchem.BondType.DOUBLE:
                adj_list[i].append(j)
                adj_list[j].append(i)
            elif typ == Chem.rdchem.BondType.TRIPLE:
                adj_list[i].append(j)
                adj_list[j].append(i)
                adj_list[i].append(j)
                adj_list[j].append(i)

        train_labels[len(train_adjlists)]= int(label2index[label])
        #print(train_labels)
        train_adjlists.append(adj_list)
        train_features.append(torch.FloatTensor(feature).to(device))
        train_sequence.append(torch.tensor(smiles_seq))
    file.close()

    data_file = f"./original_datasets/{dataset}/{dataset}_test"
    file = open(data_file, "r")
    test_adjlists = []
    test_features = []
    test_sequence = []
    test_labels = np.zeros(te_len)
    for line in file:
        smiles = line.split("\t")[1]
        # print(smiles)
        label = line.split("\t")[2][:-1]
        mol = AllChem.MolFromSmiles(smiles)
        feature = torch.zeros(len(mol.GetAtoms()), len(node_types))
        l = 0
        smiles_seq = []
        for atom in mol.GetAtoms():
            feature[l, node2index[atom.GetAtomicNum()]] = 1
            smiles_seq.append(node2index[atom.GetAtomicNum()])
            l += 1
        adj_list = defaultdict(list)
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            typ = bond.GetBondType()
            adj_list[i].append(j)
            adj_list[j].append(i)
            if typ == Chem.rdchem.BondType.DOUBLE:
                adj_list[i].append(j)
                adj_list[j].append(i)
            elif typ == Chem.rdchem.BondType.TRIPLE:
                adj_list[i].append(j)
                adj_list[j].append(i)
                adj_list[i].append(j)
                adj_list[j].append(i)

        test_labels[len(test_adjlists)] = int(label2index[label])
        test_adjlists.append(adj_list)
        test_features.append(torch.FloatTensor(feature).to(device))
        test_sequence.append(torch.tensor(smiles_seq))
    file.close()
    train_data = {}
    train_data['adj_lists'] = train_adjlists
    train_data['features'] = train_features
    # Pad train_sequence to length 100
    padded_train_sequence = []
    for seq in train_sequence:
      padded_seq = torch.nn.functional.pad(seq, (0, 100 - len(seq)), 'constant', 0)
      padded_train_sequence.append(padded_seq)
      train_data['sequence'] = padded_train_sequence
    test_data = {}
    test_data['adj_lists'] = test_adjlists
    test_data['features'] = test_features
    padded_test_sequence = []
    for seq in test_sequence:
      padded_seq = torch.nn.functional.pad(seq, (0, 100 - len(seq)), 'constant', 0)
      padded_test_sequence.append(padded_seq)
      test_data['sequence'] = padded_test_sequence

    return train_data, train_labels, test_data, test_labels

In [3]:
train_data, train_labels, test_data, test_labels=load_data("bace",device='cpu')

In [4]:
logp_input_dim_train = train_data['features'][0].size(-1)
logp_input_dim_test = test_data['features'][0].size(-1)
def adj_list_to_adj_matrix(adj_list):
    num_nodes = len(adj_list)
    adj_matrix = torch.zeros((num_nodes, num_nodes), dtype=torch.float)
    for node, neighbors in adj_list.items():
        for neighbor in neighbors:
            adj_matrix[node][neighbor] = 1.0
            adj_matrix[neighbor][node] = 1.0
    return adj_matrix
logp_adj_matrices_train = [adj_list_to_adj_matrix(adj_list) for adj_list in train_data['adj_lists']]
logp_adj_matrices_test = [adj_list_to_adj_matrix(adj_list) for adj_list in test_data['adj_lists']]

from torch.utils.data import DataLoader, TensorDataset

# Combine data and sequence into one tensor for both train and test data
logp_data_sequence_train = torch.stack(train_data['sequence'])
logp_data_sequence_test = torch.stack(test_data['sequence'])

logp_data_list_train = [Data(x=torch.tensor(features, dtype=torch.float),
                  edge_index=torch.nonzero(adj_matrix, as_tuple=False).t().contiguous(),
                  y=torch.tensor(label, dtype=torch.float))
             for features, adj_matrix, label in zip(train_data['features'], logp_adj_matrices_train, train_labels)]
logp_data_list_test= [Data(x=torch.tensor(features, dtype=torch.float),
                  edge_index=torch.nonzero(adj_matrix, as_tuple=False).t().contiguous(),
                  y=torch.tensor(label, dtype=torch.float))
             for features, adj_matrix, label in zip(test_data['features'], logp_adj_matrices_test, train_labels)]


  logp_data_list_train = [Data(x=torch.tensor(features, dtype=torch.float),
  y=torch.tensor(label, dtype=torch.float))
  logp_data_list_test= [Data(x=torch.tensor(features, dtype=torch.float),
  y=torch.tensor(label, dtype=torch.float))


In [5]:
from torch.utils.data import DataLoader

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, data_list, sequence_list):
        self.data_list = data_list
        self.sequence_list = sequence_list


    def __getitem__(self, index):
        data = self.data_list[index]
        sequence = self.sequence_list[index]
        return data, sequence

    def __len__(self):
        return len(self.data_list)


train_dataset = CustomDataset(logp_data_list_train, train_data['sequence'])
test_dataset = CustomDataset(logp_data_list_test, test_data['sequence'])




In [6]:
class TransformerDecoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(TransformerDecoder, self).__init__()
        self.transformer = nn.TransformerDecoderLayer(d_model=input_dim, nhead=1)
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x, memory):
        x = self.transformer(x,memory)
        x = self.fc(x)
        return x

class TransformerModel(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_encoder_layers, dim_feedforward, max_length=100):
        super(TransformerModel, self).__init__()
        self.device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.transformer_encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward), num_layers=num_encoder_layers)
        self.fc = nn.Linear(d_model, 1)
        self.decoder = TransformerDecoder(d_model, d_model, vocab_size)
        
        
    def forward(self, x):
        x = self.embedding(x).to(self.device)
        x = torch.transpose(x, 0, 1).to(self.device)
        encoded = self.transformer_encoder(x).to(self.device)
        x = torch.mean(encoded, dim=0).to(self.device)
        x = self.fc(x).to(self.device)
        x = x.mean(dim=0, keepdim=True).to(self.device)
        reconstructed = self.decoder(encoded,encoded).to(self.device)
        reconstructed = torch.mean(reconstructed, dim=1).to(self.device)
        reconstructed = self.fc(reconstructed).to(self.device)
        
        return x, reconstructed

In [30]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
recon_losses = []
vocab_size = 100
d_model = 100
nhead = 4
num_encoder_layers = 3
dim_feedforward = 512
criterion = nn.MSELoss()
model = TransformerModel(vocab_size=logp_input_dim_train, d_model=100, nhead=4, num_encoder_layers=3, dim_feedforward=512)
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
# Training loop
for epoch in range(4):
    recon_losses = []
    for index,data_batch in enumerate(test_dataset):
        graph_data_batch = data_batch[0]
        sequence_inputs = data_batch[1].to(device)
        sequence_targets = graph_data_batch.y.to(device)
        
        # Forward pass
        outputs, reconstructed = model(sequence_inputs)
        # reconstructed = reconstructed.to(device)
        # outputs = outputs.to(device)
        
        # Calculate reconstruction loss
        recon_loss = criterion(reconstructed, sequence_targets)
        recon_losses.append((index,sequence_inputs, recon_loss.item()))
        
        # Backward pass and optimization
        optimizer.zero_grad()
        recon_loss.backward()
        optimizer.step()




  return F.mse_loss(input, target, reduction=self.reduction)


In [31]:
# Output list of smiles and their loss
for index,smile, loss in recon_losses:
    print(f"Index: {index} Smile: {smile}, Recon Loss: {loss}")

Index: 0 Smile: tensor([4, 1, 1, 1, 1, 1, 4, 1, 1, 1, 2, 1, 3, 1, 1, 3, 1, 2, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0], device='cuda:0'), Recon Loss: 0.10280755907297134
Index: 1 Smile: tensor([5, 3, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 3, 2, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 3, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 4, 4, 4, 1, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0], device='cuda:0'), Recon Loss: 0.04914477840065956
Index: 2 Smile: tensor([5, 3, 3, 1, 1, 1, 1, 1, 1, 3, 1, 1, 4, 4, 4, 1, 4, 4, 4, 1, 2, 1, 4, 1,
        1, 3, 1, 2, 1, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0,

In [32]:
threshold = 0.35
list_of_smiles = []                             
for index, _, recon_loss in recon_losses:
    if recon_loss <=threshold:
        list_of_smiles.append(index)
print(f"Number of Smiles with loss less than {threshold}: {len(list_of_smiles)}")


Number of Smiles with loss less than 0.35: 109


In [33]:
len(list_of_smiles)

109

In [35]:
s=0

for data_batch in test_dataset:
        graph_data_batch = data_batch[0]
        sequence_inputs = data_batch[1].to(device)
        sequence_targets = graph_data_batch.y.to(device)
        s+=1
s

152

In [36]:
import os
# Specify the path to the input file
input_file_path = 'original_datasets/bace/bace_test'

output_file = "original_datasets/bace_recon/bace_test"


output_directory = os.path.dirname(output_file)
os.makedirs(output_directory, exist_ok=True)
# Open the input file in read mode
count=0
with open(input_file_path, 'r') as input_file:
    for index,line in enumerate(input_file.readlines()):
        
        if index in list_of_smiles:
            smille=line.split("\t")[1]
            label = line.split("\t")[2]
            
            data = f"{count}\t{smille}\t{label}"
            print(data)
            with open(output_file, 'a') as output_data:
                output_data.write(data)
            count+=1


0	Fc1cc(cc(F)c1)C[C@H](NC(=O)C)[C@H](O)C[NH2+]C1(CCCCC1)c1cc(-n2nccc2)ccc1	1

1	S(=O)(=O)(N(c1ccccc1)c1cc(cnc1)C(=O)NC(Cc1ccccc1)C(O)C[NH2+]Cc1cc(ccc1)C(F)(F)F)C	1

2	S1(=O)(=O)CC(Cc2cc(OC(C(F)(F)F)C(F)(F)F)c(N)c(F)c2)C(O)C([NH2+]Cc2noc(c2)CC(C)(C)C)C1	1

3	S1(=O)(=O)CC(Cc2cc(OC3CCC3)c(N)c(F)c2)C(O)C([NH2+]Cc2cc(ccc2)C(C)(C)C)C1	1

4	Clc1cc(-c2cc3c(Oc4c(cc(OC)cc4)[C@@]34N=C(OC4)N)cc2)c(F)cc1	0

5	S(=O)(=O)(N(c1cc(ncc1)C(=O)NC(Cc1ccccc1)C(O)C[NH2+]Cc1cc(ccc1)C(F)(F)F)c1ccccc1)C	0

6	Clc1cc2CC(N=C(NC(Cc3ccccc3)C=3NC(=O)c4c(N=3)cccc4)c2cc1)(C)C	0

7	S1(=O)(=O)C[C@@H](Cc2cc(C[C@@H]3N(CCC)C(OC3)=O)c(O)cc2)[C@H](O)[C@@H]([NH2+]Cc2cc(ccc2)C(C)C)C1	0

8	O=C1N[C@@H](C[C@@H](CCCCCCCC(=O)N[C@H]1C)C)[C@@H](O)C[C@H](C(=O)NCCCC)C	0

9	O1[C@@H]2COCC[C@@]2(N=C1N)c1cc(ccc1)-c1cncnc1	0

10	Fc1cc(cc(F)c1)CC(NC(=O)c1c2cccnc2n(c1)C(=O)N(CCCC)C)C(O)C[NH2+]C1(CC1)c1cc(ccc1)CC	1

11	FC(F)(F)c1cc(ccc1)C[NH2+]C[C@@H](O)[C@@H](NC(=O)c1cc(N2CCCC2=O)cc(NCC)c1)Cc1ccccc1	1

12	Fc1ncccc1-c1cc(ccc1)[C@]1([NH+]=C(N2C1=