<a href="https://colab.research.google.com/github/rsaran-BioAI/AGILE/blob/main/SMILES_GraphRep_AGILE_VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive into the Colab runtime at /content/drive so that files
# kept in Drive can be opened through ordinary file paths.

from google.colab import drive
drive.mount('/content/drive')

## This is the code from AGILE for preparing data (dataset.py)

In [None]:
import codecs
import csv
import math
import os
import random
import time
from copy import deepcopy

import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

# from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

In [None]:
!pip install torch_scatter

In [None]:
!pip install torch_geometric

In [None]:
from torch_scatter import scatter
from torch_geometric.data import Data, Dataset, DataLoader
from torch.utils.data import Dataset

import rdkit
from rdkit import Chem
from rdkit.Chem.rdchem import HybridizationType
from rdkit.Chem.rdchem import BondType as BT
from rdkit.Chem import AllChem

In [None]:
# Lookup tables that turn RDKit atom/bond properties into integer features:
# a property's feature value is its position in the matching list (see the
# `.index(...)` calls in MoleculeDataset.__getitem__). A property value that is
# missing from its list makes `.index` raise ValueError.

# Atomic numbers 1..118.
ATOM_LIST = list(range(1, 119))
# Atom chirality tags.
CHIRALITY_LIST = [
    Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
    Chem.rdchem.ChiralType.CHI_OTHER,
]
# Bond types.
BOND_LIST = [BT.SINGLE, BT.DOUBLE, BT.TRIPLE, BT.AROMATIC]
# Bond direction tags.
BONDDIR_LIST = [
    Chem.rdchem.BondDir.NONE,
    Chem.rdchem.BondDir.ENDUPRIGHT,
    Chem.rdchem.BondDir.ENDDOWNRIGHT,
]

In [None]:
def read_smiles(data_path, smiles_col="smiles"):
    """Read the SMILES column of a CSV file.

    The first row is treated as the header and is used to locate
    ``smiles_col``. Every following non-blank row contributes one SMILES
    string: the cell is stripped of surrounding whitespace and then passed
    through the ``unicode_escape`` codec, so escape sequences that are written
    literally in the file (for example a doubled backslash) are resolved.

    Args:
        data_path: Path of the CSV file to read.
        smiles_col: Header name of the column that holds the SMILES strings.

    Returns:
        The decoded SMILES strings, in file order.

    Raises:
        ValueError: If the file has no header row or ``smiles_col`` is not one
            of the header names.
    """
    smiles_data = []
    # newline="" leaves newline handling to the csv module, as its
    # documentation requires for file objects handed to csv.reader.
    with open(data_path, "r", newline="") as f:
        reader = csv.reader(f)
        headers = next(reader, None)
        if headers is None:
            raise ValueError(f"'{data_path}' is empty; expected a CSV header row")
        try:
            smiles_idx = headers.index(smiles_col)
        except ValueError:
            raise ValueError(
                f"column '{smiles_col}' not found in header {headers} of '{data_path}'"
            ) from None
        for row in reader:
            if not row:
                # csv.reader yields an empty list for a blank line (commonly a
                # trailing one); indexing it would raise IndexError.
                continue
            smiles = codecs.decode(row[smiles_idx].strip(), "unicode_escape")
            smiles_data.append(smiles)
    return smiles_data

In [None]:
class MoleculeDataset(Dataset):
    """Dataset of molecules served as pairs of randomly masked graphs.

    SMILES strings are read from a CSV file with ``read_smiles``; only those
    that RDKit parses into a molecule with at least one atom are kept.
    Indexing builds the molecular graph and returns two independently masked
    copies of it.

    Attributes:
        smiles_data (list[str]): The validated SMILES strings.
    """

    def __init__(self, data_path):
        """Load and validate the SMILES strings stored in ``data_path``."""
        super().__init__()
        self.smiles_data = read_smiles(data_path)
        self.smiles_data = self.validate_smiles()
        print(f"Number of valid smiles: {len(self.smiles_data)}")
        if self.smiles_data:
            # Guarded: with no valid entry there is no example to show.
            print(f"Valid example smiles: '{self.smiles_data[0]}'")

    def validate_smiles(self):
        """Return the entries of ``self.smiles_data`` that are usable.

        An entry is usable when RDKit parses it and the resulting molecule has
        at least one atom (``__getitem__`` always masks at least one atom, so
        an atom-less molecule cannot be served). Rejected entries are printed
        with their index.
        """
        valid_smiles = []
        for i, smiles in enumerate(self.smiles_data):
            mol = Chem.MolFromSmiles(smiles)
            if mol is not None and mol.GetNumAtoms() > 0:
                valid_smiles.append(smiles)
            else:
                print(f"Invalid smiles: at index {i} '{smiles}'")
        return valid_smiles

    @staticmethod
    def _mask_graph(x, edge_index, edge_attr, mask_nodes, mask_edges, mask_token):
        """Return masked copies of the graph tensors; the inputs are not modified.

        Args:
            x: ``(num_nodes, 2)`` integer node features.
            edge_index: ``(2, num_edges)`` directed edge list.
            edge_attr: ``(num_edges, 2)`` integer edge features.
            mask_nodes: Node indices whose features become ``[mask_token, 0]``.
            mask_edges: Edge columns to drop.
            mask_token: Value written into the first feature of masked nodes.

        Returns:
            ``(x_masked, edge_index_kept, edge_attr_kept)``; kept edges retain
            their original relative order.
        """
        x_masked = x.clone()
        for atom_idx in mask_nodes:
            x_masked[atom_idx, :] = torch.tensor([mask_token, 0])

        # Set membership keeps the filter linear in the number of edges.
        dropped = set(mask_edges)
        kept = torch.tensor(
            [col for col in range(edge_index.size(1)) if col not in dropped],
            dtype=torch.long,
        )
        return (
            x_masked,
            edge_index.index_select(1, kept),
            edge_attr.index_select(0, kept),
        )

    def __getitem__(self, index):
        """Build two randomly masked graph views of the molecule at ``index``.

        Node features have two integer columns: the position of the atomic
        number in ``ATOM_LIST`` and the position of the chirality tag in
        ``CHIRALITY_LIST``. Every bond is stored as two directed edges in
        consecutive columns ``2k`` and ``2k + 1``, each with features
        ``[position in BOND_LIST, position in BONDDIR_LIST]``.

        For each view, ``floor(0.25 * num_atoms)`` atoms (at least one) get
        their features replaced by ``[len(ATOM_LIST), 0]`` and
        ``floor(0.25 * num_bonds)`` bonds are removed in both directions.

        Returns:
            tuple: ``(data_i, data_j)``, two ``Data`` objects with ``x``,
            ``edge_index`` and ``edge_attr``.

        Raises:
            ValueError: If the SMILES cannot be parsed, or an atom/bond
                property is missing from its lookup list.
        """
        mol = Chem.MolFromSmiles(self.smiles_data[index])
        # mol = Chem.AddHs(mol)
        if mol is None:
            raise ValueError(f"Invalid SMILES: at index {index}")

        N = mol.GetNumAtoms()
        M = mol.GetNumBonds()

        type_idx = []
        chirality_idx = []
        for atom in mol.GetAtoms():
            type_idx.append(ATOM_LIST.index(atom.GetAtomicNum()))
            chirality_idx.append(CHIRALITY_LIST.index(atom.GetChiralTag()))

        x1 = torch.tensor(type_idx, dtype=torch.long).view(-1, 1)
        x2 = torch.tensor(chirality_idx, dtype=torch.long).view(-1, 1)
        x = torch.cat([x1, x2], dim=-1)

        row, col, edge_feat = [], [], []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            feat = [
                BOND_LIST.index(bond.GetBondType()),
                BONDDIR_LIST.index(bond.GetBondDir()),
            ]
            # Both directions of the bond, in adjacent columns.
            row += [start, end]
            col += [end, start]
            edge_feat += [feat, feat]

        edge_index = torch.tensor([row, col], dtype=torch.long)
        # reshape keeps the (num_edges, 2) layout even for a bond-less molecule.
        edge_attr = torch.tensor(edge_feat, dtype=torch.long).reshape(-1, 2)

        # random mask a subgraph of the molecule
        # (the four samples are drawn in a fixed order: nodes i, nodes j,
        # edges i, edges j)
        num_mask_nodes = max(1, math.floor(0.25 * N))
        num_mask_edges = max(0, math.floor(0.25 * M))
        mask_nodes_i = random.sample(range(N), num_mask_nodes)
        mask_nodes_j = random.sample(range(N), num_mask_nodes)
        mask_edges_i_single = random.sample(range(M), num_mask_edges)
        mask_edges_j_single = random.sample(range(M), num_mask_edges)
        # Bond k occupies edge columns 2k and 2k + 1.
        mask_edges_i = [2 * i for i in mask_edges_i_single] + [
            2 * i + 1 for i in mask_edges_i_single
        ]
        mask_edges_j = [2 * i for i in mask_edges_j_single] + [
            2 * i + 1 for i in mask_edges_j_single
        ]

        mask_token = len(ATOM_LIST)
        x_i, edge_index_i, edge_attr_i = self._mask_graph(
            x, edge_index, edge_attr, mask_nodes_i, mask_edges_i, mask_token
        )
        data_i = Data(x=x_i, edge_index=edge_index_i, edge_attr=edge_attr_i)

        x_j, edge_index_j, edge_attr_j = self._mask_graph(
            x, edge_index, edge_attr, mask_nodes_j, mask_edges_j, mask_token
        )
        data_j = Data(x=x_j, edge_index=edge_index_j, edge_attr=edge_attr_j)

        return data_i, data_j

    def __len__(self):
        """Return the number of validated SMILES strings."""
        return len(self.smiles_data)
class MoleculeDatasetWrapper:
    """Build training and validation loaders over one ``MoleculeDataset``.

    The constructor only stores its settings; the CSV file is read when
    ``get_data_loaders`` is called.

    Args:
        batch_size: Batch size used by both loaders.
        num_workers: Number of worker processes used by both loaders.
        valid_size: Fraction of the dataset assigned to the validation split.
        data_path: Path of the CSV file handed to ``MoleculeDataset``.
    """

    def __init__(self, batch_size, num_workers, valid_size, data_path):
        super().__init__()
        self.data_path = data_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.valid_size = valid_size

    def get_data_loaders(self):
        """Create the dataset and return ``(train_loader, valid_loader)``."""
        train_dataset = MoleculeDataset(data_path=self.data_path)
        train_loader, valid_loader = self.get_train_validation_data_loaders(
            train_dataset
        )
        return train_loader, valid_loader

    def get_train_validation_data_loaders(self, train_dataset):
        """Split ``train_dataset`` at random and wrap both parts in loaders.

        Indices are shuffled with ``np.random.shuffle``; the first
        ``floor(valid_size * len(train_dataset))`` of them form the validation
        split and the rest the training split. Both loaders draw from the same
        dataset object through a ``SubsetRandomSampler`` and use
        ``drop_last=True``, so a split smaller than ``batch_size`` yields no
        batches.

        Returns:
            tuple: ``(train_loader, valid_loader)``.
        """
        num_train = len(train_dataset)
        indices = list(range(num_train))
        np.random.shuffle(indices)

        split = int(np.floor(self.valid_size * num_train))
        train_idx, valid_idx = indices[split:], indices[:split]

        def make_loader(subset_indices):
            # Same settings for both splits; only the sampled indices differ.
            return DataLoader(
                train_dataset,
                batch_size=self.batch_size,
                sampler=SubsetRandomSampler(subset_indices),
                num_workers=self.num_workers,
                drop_last=True,
            )

        train_loader = make_loader(train_idx)
        valid_loader = make_loader(valid_idx)
        return train_loader, valid_loader

In [None]:
# CSV file on the mounted Drive; its "smiles" column is what read_smiles loads by default.
data_path = '/content/drive/MyDrive/AGILE2/SMILES_features.csv'

In [None]:
# Settings handed to MoleculeDatasetWrapper.
# NOTE: both loaders are created with drop_last=True, so a split that holds
# fewer than batch_size molecules produces no batches.
batch_size = 32  # Example size, adjust as needed
num_workers = 2  # Adjust based on your machine's capabilities
valid_size = 0.2  # 20% of the data for validation, adjust as needed

In [None]:
# Only stores the settings; the CSV file is not read until get_data_loaders() is called.
dataset_wrapper = MoleculeDatasetWrapper(batch_size, num_workers, valid_size, data_path)

In [None]:
# Reads and validates the SMILES file, then splits it at random into training and validation loaders.
train_loader, valid_loader = dataset_wrapper.get_data_loaders()

# This code is from AGILE for building the VAE

In [None]:
from torch_geometric.nn import GCNConv
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool

In [None]:
# Sizes of the embedding tables created in GINEConv and AGILE_Encoder.

# 118 entries of ATOM_LIST plus index len(ATOM_LIST), which
# MoleculeDataset writes into masked atoms.
num_atom_type = 119  # including the extra mask tokens
# NOTE(review): CHIRALITY_LIST has 4 entries, so an atom tagged CHI_OTHER
# (index 3) would be out of range for an embedding of size 3 -- confirm.
num_chirality_tag = 3

# 4 entries of BOND_LIST plus index 4, which GINEConv assigns to self-loops.
num_bond_type = 5  # including aromatic and self-loop edge
num_bond_direction = 3


class GINEConv(MessagePassing):
    """Message-passing layer that adds learned edge embeddings to messages.

    Each message is the sending node's embedding plus the sum of two edge
    embeddings (bond type and bond direction). A self-loop is added for every
    node. The messages aggregated by ``MessagePassing`` (default aggregation
    scheme) are passed through a two-layer MLP.

    Args:
        emb_dim (int): Size of the node and edge embeddings.
    """

    def __init__(self, emb_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, 2 * emb_dim), nn.ReLU(), nn.Linear(2 * emb_dim, emb_dim)
        )
        self.edge_embedding1 = nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = nn.Embedding(num_bond_direction, emb_dim)
        nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        nn.init.xavier_uniform_(self.edge_embedding2.weight.data)

    def forward(self, x, edge_index, edge_attr):
        """Run one round of message passing.

        Args:
            x: Node embeddings, one row per node.
            edge_index: ``(2, num_edges)`` edge list.
            edge_attr: ``(num_edges, 2)`` integer edge features
                ``[bond type index, bond direction index]``.

        Returns:
            The updated node embeddings.
        """
        num_nodes = x.size(0)

        # add self loops in the edge space
        edge_index = add_self_loops(edge_index, num_nodes=num_nodes)[0]

        # add features corresponding to self-loop edges.
        # new_zeros allocates directly with edge_attr's dtype and device, so
        # no CPU float tensor has to be created and converted on each call.
        self_loop_attr = edge_attr.new_zeros((num_nodes, 2))
        self_loop_attr[:, 0] = 4  # bond type for self-loop edge
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)

        edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(
            edge_attr[:, 1]
        )

        return self.propagate(edge_index, x=x, edge_attr=edge_embeddings)

    def message(self, x_j, edge_attr):
        """Return the message sent along an edge: sender embedding + edge embedding."""
        return x_j + edge_attr

    def update(self, aggr_out):
        """Transform the aggregated messages with the layer's MLP."""
        return self.mlp(aggr_out)


In [None]:
class AGILE_Encoder(nn.Module):
    """
    Modified AGILE class to serve as an encoder in a VAE.

    Node features are embedded, refined by a stack of ``GINEConv`` layers with
    batch normalisation, pooled per graph, and projected to the parameters of
    the latent distribution.

    Args:
        num_layer (int): Number of GNN layers.
        emb_dim (int): Dimensionality of embeddings.
        feat_dim (int): Dimensionality of feature vector before producing mu and logvar.
        latent_dim (int): Dimensionality of the latent space (output of the encoder).
        drop_ratio (float): Dropout rate.
        pool (str): Graph-level pooling, one of ``"mean"``, ``"max"`` or ``"add"``.
    Output:
        mu (torch.Tensor): Mean of the latent space distribution.
        logvar (torch.Tensor): Log variance of the latent space distribution.
    Raises:
        ValueError: If ``pool`` is not one of the supported names.
    """

    def __init__(self, num_layer=5, emb_dim=300, feat_dim=256, latent_dim=128,
                 drop_ratio=0, pool="mean"):
        super().__init__()
        # Fail early instead of leaving self.pool unset for an unknown name.
        if pool not in ("mean", "max", "add"):
            raise ValueError(
                f"Unknown pool '{pool}'; expected 'mean', 'max' or 'add'"
            )

        self.num_layer = num_layer
        self.emb_dim = emb_dim
        self.feat_dim = feat_dim
        self.latent_dim = latent_dim
        self.drop_ratio = drop_ratio

        self.x_embedding1 = nn.Embedding(num_atom_type, emb_dim)
        self.x_embedding2 = nn.Embedding(num_chirality_tag, emb_dim)
        nn.init.xavier_uniform_(self.x_embedding1.weight.data)
        nn.init.xavier_uniform_(self.x_embedding2.weight.data)

        # One GINEConv layer per GNN layer
        self.gnns = nn.ModuleList()
        for layer in range(num_layer):
            self.gnns.append(GINEConv(emb_dim))

        # List of batchnorms
        self.batch_norms = nn.ModuleList()
        for layer in range(num_layer):
            self.batch_norms.append(nn.BatchNorm1d(emb_dim))

        pool_functions = {
            "mean": global_mean_pool,
            "max": global_max_pool,
            "add": global_add_pool,
        }
        self.pool = pool_functions[pool]

        # Replacing the output layer with two separate layers for mu and logvar
        self.feat_lin = nn.Linear(self.emb_dim, self.feat_dim)
        self.mu_lin = nn.Linear(self.feat_dim, self.latent_dim)
        self.logvar_lin = nn.Linear(self.feat_dim, self.latent_dim)

    def forward(self, data):
        """Encode a batch of graphs into ``(mu, logvar)``.

        Args:
            data: Object exposing ``x`` (two integer columns: atom type index
                and chirality index), ``edge_index``, ``edge_attr`` and
                ``batch`` (graph id of every node) -- presumably a
                torch_geometric batch; confirm against the caller.

        Returns:
            tuple: ``(mu, logvar)``, each with one row per graph and
            ``latent_dim`` columns.
        """
        h = self.x_embedding1(data.x[:, 0]) + self.x_embedding2(data.x[:, 1])

        last_layer = self.num_layer - 1
        for layer, (gnn, batch_norm) in enumerate(zip(self.gnns, self.batch_norms)):
            h = batch_norm(gnn(h, data.edge_index, data.edge_attr))
            if layer != last_layer:
                # No activation after the final GNN layer.
                h = F.relu(h)
            h = F.dropout(h, self.drop_ratio, training=self.training)

        h = self.pool(h, data.batch)
        # Non-linearity so the two heads are not a plain composition of linear maps.
        h = F.relu(self.feat_lin(h))
        return self.mu_lin(h), self.logvar_lin(h)

In [None]:
class GNNVAE(nn.Module):
    """VAE that pairs an ``AGILE_Encoder`` with a single linear decoder.

    NOTE(review): a second ``class GNNVAE`` is defined right after this one in
    the same cell and replaces this definition under the same name -- confirm
    which of the two is meant to be used.
    """

    def __init__(self):
        super(GNNVAE, self).__init__()

        # Encoder: Using AGILE_Encoder
        # NOTE(review): num_layer, emb_dim, feat_dim, latent_dim, drop_ratio and
        # pool are read as module-level names here; none of them is assigned in
        # the visible part of the notebook -- confirm where they come from.
        self.encoder = AGILE_Encoder(num_layer=num_layer, emb_dim=emb_dim,
                                     feat_dim=feat_dim, latent_dim=latent_dim,
                                     drop_ratio=drop_ratio, pool=pool) # expected to return the pair (mu, logvar), see forward()

        # Decoder
        # NOTE(review): num_classes is likewise not assigned in the visible
        # notebook; the original remark asks for the output to have the same
        # dimension as the encoded data -- confirm the intended size.
        self.decoder = torch.nn.Linear(latent_dim, num_classes) # maps a latent vector to num_classes outputs

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` with ``std = exp(0.5 * logvar)`` and ``eps`` drawn by ``torch.randn_like``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def decode(self, z):
        """Apply the linear decoder followed by a sigmoid."""
        return torch.sigmoid(self.decoder(z))

    def forward(self, data):
        """Encode ``data``, sample a latent vector, and return ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encoder(data) # data is passed to the encoder unchanged
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

class GNNVAE(torch.nn.Module):
    """Variational autoencoder with a two-layer GCN encoder.

    The encoder maps node features to a per-node code, from which two linear
    heads produce the mean and log-variance of the latent distribution. A
    single linear layer followed by a sigmoid decodes a latent sample.

    Args:
        num_features: Width of the input node features.
        num_classes: Width of the decoder output.
        latent_dim: Width of the latent space.
    """

    def __init__(self, num_features, num_classes, latent_dim):
        super().__init__()
        hidden_dim = 2 * latent_dim

        # Encoder: two graph convolutions, num_features -> hidden -> latent.
        self.conv1 = GCNConv(num_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, latent_dim)

        # Heads producing the parameters of the latent distribution.
        self.fc_mu = torch.nn.Linear(latent_dim, latent_dim)
        self.fc_logvar = torch.nn.Linear(latent_dim, latent_dim)

        # Decoder: latent -> num_classes.
        self.decoder = torch.nn.Linear(latent_dim, num_classes)

    def encode(self, x, edge_index):
        """Return ``(mu, logvar)`` computed from the node features and edges."""
        hidden = self.conv1(x, edge_index)
        hidden = F.relu(hidden)
        node_code = self.conv2(hidden, edge_index)
        mu = self.fc_mu(node_code)
        logvar = self.fc_logvar(node_code)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``mu + noise * std`` where ``std = exp(0.5 * logvar)``."""
        std = (0.5 * logvar).exp()
        noise = torch.randn_like(std)
        return noise * std + mu

    def decode(self, z):
        """Map a latent sample through the linear decoder and a sigmoid."""
        logits = self.decoder(z)
        return logits.sigmoid()

    def forward(self, data):
        """Return ``(reconstruction, mu, logvar)`` for ``data.x`` / ``data.edge_index``."""
        mu, logvar = self.encode(data.x, data.edge_index)
        latent = self.reparameterize(mu, logvar)
        reconstruction = self.decode(latent)
        return reconstruction, mu, logvar

In [None]:
def loss_function(recon_x, x, mu, logvar):
    """Return the summed VAE loss: reconstruction term plus KL term.

    Args:
        recon_x: Reconstruction, compared with ``x`` by binary cross-entropy.
        x: Reconstruction target.
        mu: Mean of the latent distribution.
        logvar: Log-variance of the latent distribution.

    Returns:
        Scalar tensor: binary cross-entropy summed over all elements plus
        ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))``.
    """
    reconstruction = F.binary_cross_entropy(recon_x, x, reduction='sum')
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * torch.sum(kl_terms)
    return reconstruction + kl_divergence

In [None]:
# NOTE(review): the GNNVAE definition that is in effect at this point (the
# later of the two in the notebook) takes (num_features, num_classes,
# latent_dim); this call passes no arguments -- confirm which class is meant.
model = GNNVAE()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop: one optimizer step per batch, loss accumulated per epoch.
# NOTE(review): `epochs` is not assigned anywhere in the visible notebook.
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for data in train_loader:  # NOTE(review): MoleculeDataset.__getitem__ returns a (data_i, data_j) pair, so `data` is presumably a pair of batches rather than one batch -- confirm
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        # NOTE(review): MoleculeDataset builds x as integer feature indices
        # (torch.long), while loss_function feeds it to binary_cross_entropy
        # as the target -- confirm the intended reconstruction target.
        loss = loss_function(recon_batch, data.x, mu, logvar)
        loss.backward()
        total_loss += loss.item()
        optimizer.step()

    # train_loader.dataset is the whole dataset object (the validation loader
    # draws from the same one), so the divisor is the full dataset length.
    print('Epoch: {}, Loss: {:.4f}'.format(epoch, total_loss / len(train_loader.dataset)))

NameError: ignored