In [45]:
import pandas as pd
from rdkit import Chem
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.data import Data

In [46]:
df = pd.read_csv('D:/brpr/dataset/brenda_rm.csv')

# Convert SMILES to RDKit Mol objects
df['substrate_mol'] = df['smiles'].apply(Chem.MolFromSmiles)
df['product_mol'] = df['product_smile'].apply(Chem.MolFromSmiles)
unique_ec_numbers = df['EC Number'].unique()
ec_to_index = {ec: idx for idx, ec in enumerate(unique_ec_numbers)}
df['EC Index'] = df['EC Number'].map(ec_to_index)

In [47]:
def atom_feature(atom):
    return [
        atom.GetAtomicNum(),
        atom.GetTotalValence(),
        atom.GetDegree(),
        atom.IsInRing(),
        atom.GetFormalCharge(),
        atom.GetNumRadicalElectrons()
    ]

def bond_feature(bond):
    bond_type = bond.GetBondType()
    return [
        bond_type == Chem.rdchem.BondType.SINGLE,
        bond_type == Chem.rdchem.BondType.DOUBLE,
        bond_type == Chem.rdchem.BondType.TRIPLE,
        bond_type == Chem.rdchem.BondType.AROMATIC,
        bond.IsInRing()
    ]

def molecule_to_graph(mol):
    atom_features = [atom_feature(atom) for atom in mol.GetAtoms()]
    bond_indices = []
    bond_features = []
    for bond in mol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        bond_indices.append((start, end))
        bond_indices.append((end, start))
        bond_features.extend([bond_feature(bond)] * 2)

    x = torch.tensor(atom_features, dtype=torch.float)
    edge_index = torch.tensor(bond_indices, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(bond_features, dtype=torch.float)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

df['substrate_graph'] = df['substrate_mol'].apply(molecule_to_graph)
df['product_graph'] = df['product_mol'].apply(molecule_to_graph)

In [63]:
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

class MolecularDataset(Dataset):
    def __init__(self, dataframe):
        self.data = dataframe

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        substrate_graph = self.data.iloc[idx]['substrate_graph']
        enzyme_index = self.data.iloc[idx]['EC Index']
        product_graph = self.data.iloc[idx]['product_graph']
        
        return substrate_graph, enzyme_index, product_graph

# Splitting the data into training and test sets
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)

train_dataset = MolecularDataset(train_df)
test_dataset = MolecularDataset(test_df)

In [64]:
from torch_geometric.data import Batch

def collate_fn(batch):
    substrate_data, enzyme_data, true_product_data = zip(*batch)
    
    # Collate substrate and product data using Batch.from_data_list
    substrate_data = Batch.from_data_list(list(substrate_data))
    true_product_data = Batch.from_data_list(list(true_product_data))
    
    # Convert enzyme_data to tensor
    enzyme_data = torch.tensor(enzyme_data, dtype=torch.long)
    
    return substrate_data, enzyme_data, true_product_data

In [67]:
batch_size = 32

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

## Model

In [121]:
class TransformationEncoderDecoder(torch.nn.Module):
    def __init__(self, num_node_features, hidden_channels, transformation_dim, heads, num_enzymes, enzyme_embedding_dim):
        super(TransformationEncoderDecoder, self).__init__()

        # Enzyme embedding
        self.enzyme_embedding = torch.nn.Embedding(num_enzymes, enzyme_embedding_dim)

        # Encoder
        self.encoder_conv1 = GATConv(num_node_features, hidden_channels, heads=heads, dropout=0.6)
        self.encoder_conv2 = GATConv(hidden_channels * heads, hidden_channels, heads=heads, dropout=0.6)
        self.encoder_fc = torch.nn.Linear(hidden_channels * heads + enzyme_embedding_dim, 256)  # Adjusted for enzyme embedding

        # Decoder
        self.decoder_fc = torch.nn.Linear(256, 256)
        
    def encode_molecule(self, molecule_data, enzyme_data):
        x, edge_index = molecule_data.x, molecule_data.edge_index
        x = F.elu(self.encoder_conv1(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = F.elu(self.encoder_conv2(x, edge_index))
        global_representation = torch.mean(x, dim=0, keepdim=True)
        enzyme_embed = self.enzyme_embedding(enzyme_data)

        # Repeat the global_representation to match the batch size of enzyme_embed
        global_representation = global_representation.repeat(enzyme_embed.size(0), 1)

        combined_representation = torch.cat([global_representation, enzyme_embed], dim=1)
        transformation_representation = self.encoder_fc(combined_representation)
        return transformation_representation


    def forward(self, substrate_data, enzyme_data):
        # Encoder
        x, edge_index = substrate_data.x, substrate_data.edge_index
        x = F.elu(self.encoder_conv1(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = F.elu(self.encoder_conv2(x, edge_index))
        global_representation = torch.mean(x, dim=0, keepdim=True)  # Global pooling
        #enzyme embed
        enzyme_embed = self.enzyme_embedding(enzyme_data)
        global_representation = global_representation.repeat(enzyme_embed.size(0), 1)
        # Concatenate enzyme embedding with substrate representation
        combined_representation = torch.cat([global_representation, enzyme_embed], dim=1)

        transformation_representation = self.encoder_fc(combined_representation)

        # Decoder
        predicted_product = self.decoder_fc(transformation_representation)

        return predicted_product

In [122]:
num_enzymes = len(unique_ec_numbers)
enzyme_embedding_dim = 32

## Train

In [123]:
# Hyperparameters
learning_rate = 0.001
epochs = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Initialize the model, criterion, and optimizer
model = TransformationEncoderDecoder(num_node_features=6,  # Assuming 6 features per node
                                    hidden_channels=64,
                                    transformation_dim=128,
                                    heads=4,
                                    num_enzymes=num_enzymes,
                                    enzyme_embedding_dim=32).to(device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = torch.nn.MSELoss()

# Training loop
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for substrate_data, enzyme_data, true_product_data in train_loader:
        substrate_data, enzyme_data, true_product_data = substrate_data.to(device), enzyme_data.to(device), true_product_data.to(device)
        
        optimizer.zero_grad()
        predicted_product = model(substrate_data, enzyme_data)
        true_product_representation = model.encode_molecule(true_product_data, enzyme_data)
        loss = criterion(predicted_product, true_product_representation)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_loader)}")

Epoch 1/100, Loss: 0.0006558131486973142
Epoch 2/100, Loss: 1.829552584037256e-05
Epoch 3/100, Loss: 1.4227179618572512e-05
Epoch 4/100, Loss: 8.075902338517595e-06
Epoch 5/100, Loss: 3.844186295026517e-06
Epoch 6/100, Loss: 1.7774302896262657e-06
Epoch 7/100, Loss: 6.79058255036624e-07
Epoch 8/100, Loss: 2.6039528443067773e-07
Epoch 9/100, Loss: 4.9518285040296775e-05
Epoch 10/100, Loss: 1.2193840172204423e-06
Epoch 11/100, Loss: 4.5149615713912974e-07
Epoch 12/100, Loss: 2.965127977951899e-07
Epoch 13/100, Loss: 1.5811131291847205e-07
Epoch 14/100, Loss: 8.789290360117238e-08
Epoch 15/100, Loss: 4.3865475555180425e-08
Epoch 16/100, Loss: 1.5095450043501728e-05
Epoch 17/100, Loss: 1.6041215626228587e-07
Epoch 18/100, Loss: 1.1528275722328844e-07
Epoch 19/100, Loss: 1.0832541134067416e-07
Epoch 20/100, Loss: 2.4614259034426577e-08
Epoch 21/100, Loss: 4.177422168117011e-05
Epoch 22/100, Loss: 3.7546256061768175e-07
Epoch 23/100, Loss: 2.61807917587423e-07
Epoch 24/100, Loss: 6.330181521

# Evaluation

In [125]:
model.eval()
total_loss = 0
with torch.no_grad():
    for substrate_data, enzyme_data, true_product_data in test_loader:
        substrate_data, enzyme_data, true_product_data = substrate_data.to(device), enzyme_data.to(device), true_product_data.to(device)
        
        predicted_product = model(substrate_data, enzyme_data)
        true_product_representation = model.encode_molecule(true_product_data, enzyme_data)
        loss = criterion(predicted_product, true_product_representation)
        
        total_loss += loss.item()

print(f"Test Loss: {total_loss/len(test_loader)}")

Test Loss: 2.5817973505802655e-08
