In [18]:
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from rdkit import Chem
from rdkit.Chem import AllChem
import numpy as np
from collections import defaultdict
import torch.optim as optim
from torch_geometric.nn import TransformerConv
from sklearn.metrics import accuracy_score
from torch import tensor
from torch_geometric.data import Data, DataLoader
from collections import defaultdict

import torch.utils.data as data

In [19]:



def load_data(dataset, device):
    data_file = f"./original_datasets/{dataset}/{dataset}_train"
    file = open(data_file, "r")
    node_types = set()
    label_types = set()
    tr_len = 0
    for line in file:
        tr_len += 1
        smiles = line.split("\t")[1]
        s = []
        mol = AllChem.MolFromSmiles(smiles)
        for atom in mol.GetAtoms():
            s.append(atom.GetAtomicNum())
        node_types |= set(s)
        label = line.split("\t")[2][:-1]
        label_types.add(label)
    file.close()

    te_len = 0
    data_file = f"./original_datasets/{dataset}/{dataset}_test"
    file = open(data_file, "r")
    for line in file:
        te_len += 1
        smiles = line.split("\t")[1]
        s = []
        mol = AllChem.MolFromSmiles(smiles)
        for atom in mol.GetAtoms():
            s.append(atom.GetAtomicNum())
        node_types |= set(s)
        label = line.split("\t")[2][:-1]
        label_types.add(label)
    file.close()

    #print(tr_len)
    #print(te_len)

    node2index = {n: i for i, n in enumerate(node_types)}
    label2index = {l: i for i, l in enumerate(label_types)}

    #print(node2index)
    #print(label2index)

    data_file = f"./original_datasets/{dataset}/{dataset}_train"
    file = open(data_file, "r")
    train_adjlists = []
    train_features = []
    train_sequence = []
    train_labels = torch.zeros(tr_len)
    for line in file:
        smiles = line.split("\t")[1]
        label = line.split("\t")[2][:-1]
        mol = AllChem.MolFromSmiles(smiles)
        feature = torch.zeros(len(mol.GetAtoms()), len(node_types))

        l = 0
        smiles_seq = []
        for atom in mol.GetAtoms():
            feature[l, node2index[atom.GetAtomicNum()]] = 1
            smiles_seq.append(node2index[atom.GetAtomicNum()])
            l += 1
        adj_list = defaultdict(list)
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            typ = bond.GetBondType()
            adj_list[i].append(j)
            adj_list[j].append(i)
            if typ == Chem.rdchem.BondType.DOUBLE:
                adj_list[i].append(j)
                adj_list[j].append(i)
            elif typ == Chem.rdchem.BondType.TRIPLE:
                adj_list[i].append(j)
                adj_list[j].append(i)
                adj_list[i].append(j)
                adj_list[j].append(i)

        train_labels[len(train_adjlists)]= int(label2index[label])
        #print(train_labels)
        train_adjlists.append(adj_list)
        train_features.append(torch.FloatTensor(feature).to(device))
        train_sequence.append(torch.tensor(smiles_seq))
    file.close()

    data_file = f"./original_datasets/{dataset}/{dataset}_test"
    file = open(data_file, "r")
    test_adjlists = []
    test_features = []
    test_sequence = []
    test_labels = np.zeros(te_len)
    for line in file:
        smiles = line.split("\t")[1]
        # print(smiles)
        label = line.split("\t")[2][:-1]
        mol = AllChem.MolFromSmiles(smiles)
        feature = torch.zeros(len(mol.GetAtoms()), len(node_types))
        l = 0
        smiles_seq = []
        for atom in mol.GetAtoms():
            feature[l, node2index[atom.GetAtomicNum()]] = 1
            smiles_seq.append(node2index[atom.GetAtomicNum()])
            l += 1
        adj_list = defaultdict(list)
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            typ = bond.GetBondType()
            adj_list[i].append(j)
            adj_list[j].append(i)
            if typ == Chem.rdchem.BondType.DOUBLE:
                adj_list[i].append(j)
                adj_list[j].append(i)
            elif typ == Chem.rdchem.BondType.TRIPLE:
                adj_list[i].append(j)
                adj_list[j].append(i)
                adj_list[i].append(j)
                adj_list[j].append(i)

        test_labels[len(test_adjlists)] = int(label2index[label])
        test_adjlists.append(adj_list)
        test_features.append(torch.FloatTensor(feature).to(device))
        test_sequence.append(torch.tensor(smiles_seq))
    file.close()
    train_data = {}
    train_data['adj_lists'] = train_adjlists
    train_data['features'] = train_features
    # Pad train_sequence to length 100
    padded_train_sequence = []
    for seq in train_sequence:
      padded_seq = torch.nn.functional.pad(seq, (0, 100 - len(seq)), 'constant', 0)
      padded_train_sequence.append(padded_seq)
      train_data['sequence'] = padded_train_sequence
    test_data = {}
    test_data['adj_lists'] = test_adjlists
    test_data['features'] = test_features
    padded_test_sequence = []
    for seq in test_sequence:
      padded_seq = torch.nn.functional.pad(seq, (0, 100 - len(seq)), 'constant', 0)
      padded_test_sequence.append(padded_seq)
      test_data['sequence'] = padded_test_sequence

    return train_data, train_labels, test_data, test_labels

In [20]:
train_data, train_labels, test_data, test_labels=load_data("logp",device='cpu')

In [21]:
train_data['sequence'][2]

tensor([1, 2, 1, 1, 1, 1, 1, 1, 3, 2, 2, 1, 1, 1, 1, 1, 1, 9, 1, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0])

In [22]:
logp_input_dim_train = train_data['features'][0].size(-1)
logp_input_dim_test = test_data['features'][0].size(-1)
def adj_list_to_adj_matrix(adj_list):
    num_nodes = len(adj_list)
    adj_matrix = torch.zeros((num_nodes, num_nodes), dtype=torch.float)
    for node, neighbors in adj_list.items():
        for neighbor in neighbors:
            adj_matrix[node][neighbor] = 1.0
            adj_matrix[neighbor][node] = 1.0
    return adj_matrix
logp_adj_matrices_train = [adj_list_to_adj_matrix(adj_list) for adj_list in train_data['adj_lists']]
logp_adj_matrices_test = [adj_list_to_adj_matrix(adj_list) for adj_list in test_data['adj_lists']]

from torch.utils.data import DataLoader, TensorDataset

# Combine data and sequence into one tensor for both train and test data
logp_data_sequence_train = torch.stack(train_data['sequence'])
logp_data_sequence_test = torch.stack(test_data['sequence'])

logp_data_list_train = [Data(x=torch.tensor(features, dtype=torch.float),
                  edge_index=torch.nonzero(adj_matrix, as_tuple=False).t().contiguous(),
                  y=torch.tensor(label, dtype=torch.float))
             for features, adj_matrix, label in zip(train_data['features'], logp_adj_matrices_train, train_labels)]
logp_data_list_test= [Data(x=torch.tensor(features, dtype=torch.float),
                  edge_index=torch.nonzero(adj_matrix, as_tuple=False).t().contiguous(),
                  y=torch.tensor(label, dtype=torch.float))
             for features, adj_matrix, label in zip(test_data['features'], logp_adj_matrices_test, train_labels)]


  logp_data_list_train = [Data(x=torch.tensor(features, dtype=torch.float),
  y=torch.tensor(label, dtype=torch.float))
  logp_data_list_test= [Data(x=torch.tensor(features, dtype=torch.float),
  y=torch.tensor(label, dtype=torch.float))


In [23]:
from torch.utils.data import DataLoader

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, data_list, sequence_list):
        self.data_list = data_list
        self.sequence_list = sequence_list


    def __getitem__(self, index):
        data = self.data_list[index]
        sequence = self.sequence_list[index]
        return data, sequence

    def __len__(self):
        return len(self.data_list)


def custom_collate(batch):
  if isinstance(batch[0], Data):
    # Convert each Data object in the batch to a dictionary.
    batch = [data.__dict__ for data in batch]
    # Return a new Data object with the dictionaries as its attributes.
    return Data(**batch[0])
  else:
    # Handle other types of data as needed.
    pass
# Create datasets
train_dataset = CustomDataset(logp_data_list_train, train_data['sequence'])
test_dataset = CustomDataset(logp_data_list_test, test_data['sequence'])

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False,collate_fn=custom_collate)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False,collate_fn=custom_collate)


In [24]:
train_dataset[1]

(Data(x=[34, 12], edge_index=[2, 72], y=1.0),
 tensor([1, 3, 1, 3, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 3, 2, 1, 1,
         1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]))

In [25]:
third_tensor = train_dataset[2][1]
third_tensor

tensor([1, 2, 1, 1, 1, 1, 1, 1, 3, 2, 2, 1, 1, 1, 1, 1, 1, 9, 1, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0])

In [26]:
data=train_dataset[0][0]
data.y

tensor(1.)

In [27]:
class GCNConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels, heads, edge_weight=None):
        super(GCNConvLayer, self).__init__()
        self.conv = TransformerConv(in_channels, out_channels, heads)
        self.edge_weight = edge_weight

    def forward(self, x, edge_index):
        return self.conv(x, edge_index)


class GCN_2l(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads, dropout_rate=0.5):
        super(GCN_2l, self).__init__()
        self.conv1 = GCNConvLayer(in_channels, hidden_channels, heads)
        self.conv2 = GCNConvLayer(hidden_channels*heads, out_channels, heads)
        self.linear=nn.Linear(heads,1)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        #print("before conv1: ", x.shape)
        x = self.conv1(x, edge_index)
        #print("after conv1: ", x.shape)
        x = torch.relu(x)
        x = self.dropout(x)
        #print("before conv2: ", x.shape)
        x = self.conv2(x, edge_index)
        #print("after conv2: ", x.shape)
        x=self.linear(x)
        #print(x)
        #print(x.mean(dim=0, keepdim=True))
        return x.mean(dim=0, keepdim=True)  # Aggregating to a single output


In [28]:
class TransformerModel(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_encoder_layers, dim_feedforward, max_length=100):
        super(TransformerModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.transformer_encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward), num_layers=num_encoder_layers)
        self.fc = nn.Linear(d_model, 1)

    def forward(self, x):
        x = self.embedding(x)
        x = torch.transpose(x, 0, 1)
        x = self.transformer_encoder(x)
        x = torch.mean(x, dim=0)
        x = self.fc(x)
        return  x.mean(dim=0, keepdim=True)


# Define hyperparameters
vocab_size = 100
d_model = 100
nhead = 4
num_encoder_layers = 3
dim_feedforward = 512
max_length = 100
batch_size = 1
num_epochs = 100



In [29]:
class FusionModel(nn.Module):
    def __init__(self, graph_model, sequence_model):
        super(FusionModel, self).__init__()
        self.graph_model = graph_model
        self.sequence_model = sequence_model
        # Attention layer
        self.attention = nn.Linear(graph_model.linear.out_features + sequence_model.fc.out_features, 1)
        # Fusion layer
        self.fusion_linear = nn.Linear(graph_model.linear.out_features + sequence_model.fc.out_features, 1)

    def forward(self, graph_data, sequence_data):
        graph_embedding = self.graph_model(graph_data)
        sequence_embedding = self.sequence_model(sequence_data)

        # Reshape sequence embedding to have two dimensions
        sequence_embedding = sequence_embedding.unsqueeze(0)


        # Concatenate embeddings
        combined_embedding = torch.cat((graph_embedding, sequence_embedding), dim=1)

        # Compute attention weights
        attention_weights = torch.sigmoid(self.attention(combined_embedding))

        # Apply attention to sequence embedding
        weighted_sequence_embedding = attention_weights * sequence_embedding

        # Fusion
        fused_embedding = torch.cat((graph_embedding, weighted_sequence_embedding), dim=1)

        # Apply fusion layer
        output = self.fusion_linear(fused_embedding)
        return output



In [30]:
graph_model = GCN_2l(in_channels=logp_input_dim_train, hidden_channels=64, out_channels=1, heads=4)
sequence_model = TransformerModel(vocab_size, d_model, nhead, num_encoder_layers, dim_feedforward)
fusion_model = FusionModel(graph_model, sequence_model)



In [31]:
import torch.optim as optim
# Define the loss function
criterion = nn.BCEWithLogitsLoss()

# Define the optimizer
optimizer = optim.Adam(fusion_model.parameters(), lr=0.001)

# Training loop
for epoch in range(100):
    total_correct = 0
    total_samples = 0
    for data_batch in train_dataset:
        graph_data_batch = data_batch[0]
        sequence_inputs  = data_batch[1]
        sequence_targets=graph_data_batch.y

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        output = fusion_model(graph_data_batch, sequence_inputs)

        # Compute binary predictions
        binary_predictions = (output > 0.5).float()

        # Compute batch accuracy
        batch_correct = (binary_predictions == sequence_targets).sum().item()
        total_correct += batch_correct
        total_samples += 1

        sequence_targets = sequence_targets.view(-1, 1)
        # Compute loss
        loss = criterion(output, sequence_targets)

        # Backward pass
        loss.backward()

        # Update weights
        optimizer.step()


    # Compute epoch accuracy
    epoch_accuracy = total_correct / total_samples
    print(f"Epoch {epoch+1}/{100}, Epoch Accuracy: {epoch_accuracy:.4f}")


Epoch 1/100, Epoch Accuracy: 0.7855
Epoch 2/100, Epoch Accuracy: 0.8538


KeyboardInterrupt: 