In [11]:
import torch as th
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim
import pytorch_lightning as pl
from pytorch_lightning import LightningModule

import torch_geometric as thg
import torch_geometric.nn as gnn
import torch_geometric.nn.functional as gf
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils.convert import from_scipy_sparse_matrix as keys

import os 
import numpy as np 
import pandas as pd 
import rdkit, rdkit.Chem, rdkit.Chem.rdDepictor, rdkit.Chem.Draw

import warnings 
warnings.filterwarnings("ignore")

In [12]:
# Directory holding the `*.smiles` data files. The original location is kept as
# the default; set the TOXIC_DATA_DIR environment variable to point elsewhere.
# os.path.join(..., '') guarantees a trailing separator, because later cells
# build file names with plain string concatenation (PATH + name).
PATH = os.path.join(
    os.environ.get("TOXIC_DATA_DIR", r'/Users/suyashsachdeva/Desktop/GyanBhandar/toxic/'),
    '',
)

# os.listdir returns entries in arbitrary order; sort so that slices such as
# smiles[:1] select the same file on every run and on every machine.
SMILE = sorted(os.listdir(PATH))

# Keep only the entries whose name ends with "smiles".
smiles = [name for name in SMILE if name.endswith("smiles")]

In [13]:
# Atomic number -> element symbol for every element the featuriser supports.
# The insertion order of the keys defines the one-hot column order used when
# building node features, so entries must not be reordered.
my_elements = {6: "C", 8: "O", 1: "H", 17: "Cl", 7: "N", 9: "F", 16: "S",
               35: "Br", 14: "Si", 11: "Na", 53: "I", 80: "Hg", 5: "B",
               19: "K", 15: "P", 79: "Au", 24: "Cr", 50: "Sn", 20: "Ca",
               48: "Cd", 30: "Zn", 23: "V", 33: "As", 3: "Li", 29: "Cu",
               27: "Co", 47: "Ag", 34: "Se", 78: "Pt", 83: "Bi", 26: "Fe",
               51: "Sb", 12: "Mg", 13: "Al", 81: "Tl", 56: "Ba", 22: "Ti",  # 22 is titanium, not vanadium
               40: "Zr", 38: "Sr", 28: "Ni", 49: "In", 32: "Ge"}

In [14]:
def smiles2graph(sml):
    """Convert a SMILES string into node features and a bond-order adjacency tensor.

    Hydrogens are added explicitly (``AddHs``) before featurisation, so they
    appear as nodes of the graph.

    Args:
        sml: a SMILES string.

    Returns:
        A tuple ``(nodes, adj)`` of numpy arrays:
          * ``nodes`` has shape ``(N, len(my_elements))``; row ``i`` is the
            one-hot encoding of atom ``i``'s element, with columns in the key
            order of ``my_elements``.
          * ``adj`` has shape ``(N, N, 5)`` and is symmetric in its first two
            axes; ``adj[u, v, k] == 1`` marks a bond between ``u`` and ``v``
            with k = 1 (single), 2 (double), 3 (triple) or 4 (aromatic).
            Channel 0 is never written.

    Raises:
        ValueError: if RDKit cannot parse ``sml``, or the molecule contains an
            element that is not a key of ``my_elements``.
        Warning: if a bond has a type other than single/double/triple/aromatic.
    """
    m = rdkit.Chem.MolFromSmiles(sml)
    if m is None:
        # MolFromSmiles signals a parse failure by returning None; fail here
        # with a clear message instead of an opaque error inside AddHs.
        raise ValueError("Could not parse SMILES: " + repr(sml))
    m = rdkit.Chem.AddHs(m)
    order_string = {
        rdkit.Chem.rdchem.BondType.SINGLE: 1,
        rdkit.Chem.rdchem.BondType.DOUBLE: 2,
        rdkit.Chem.rdchem.BondType.TRIPLE: 3,
        rdkit.Chem.rdchem.BondType.AROMATIC: 4,
    }
    N = m.GetNumAtoms()
    nodes = np.zeros((N, len(my_elements)))
    # Atomic number -> one-hot column, built once instead of a list.index()
    # scan for every atom.
    column_of = {atomic_num: col for col, atomic_num in enumerate(my_elements)}
    for atom in m.GetAtoms():
        atomic_num = atom.GetAtomicNum()
        if atomic_num not in column_of:
            raise ValueError("Unsupported element with atomic number " + str(atomic_num))
        nodes[atom.GetIdx(), column_of[atomic_num]] = 1

    adj = np.zeros((N, N, 5))
    for bond in m.GetBonds():
        u = min(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx())
        v = max(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx())
        bond_type = bond.GetBondType()
        if bond_type not in order_string:
            # The bond type is not a str, so it must be converted explicitly;
            # concatenating it directly raised TypeError instead of Warning.
            raise Warning("Unsupported bond order " + str(bond_type))
        order = order_string[bond_type]
        # Undirected graph: mark both (u, v) and (v, u).
        adj[u, v, order] = 1
        adj[v, u, order] = 1
    return nodes, adj

In [15]:
# Build one DataLoader per data file. Each line of a file is expected to be
# tab-separated with the SMILES string first and the label last; the label is
# read from the third column (index 2), as before.
X = []
Y = []
for smile in smiles[:1]:
    # Context manager so the file handle is always closed.
    with open(PATH + smile, 'r') as file:
        rows = file.read().split('\n')
    xtrain = []
    ytrain = []

    # Number of positive rows (last column == '1').
    val1 = sum(1 for row in rows if row != '' and row.split("\t")[-1] == '1')

    # Besides every positive row, keep every `factor`-th row of the file.
    # Clamp to >= 1: the old expression could yield 0 (every non-positive row
    # then hit a swallowed ZeroDivisionError) and crashed when the file had no
    # positive rows at all.
    factor = max(1, int(len(rows) / float(val1)) - 1) if val1 else 1

    skipped = 0
    for idx, row in enumerate(rows):
        if row == '':
            continue
        fields = row.split('\t')
        if not (fields[-1] == '1' or idx % factor == 0):
            continue
        try:
            nodes, adj = smiles2graph(fields[0])
            # Collapse the bond-order channels and add self-loops.
            adj_mat = np.sum(adj, axis=-1) + np.eye(adj.shape[0])
            degree = np.sum(adj_mat, axis=-1)
            # Degree-normalised neighbourhood average of the one-hot features.
            new_nodes = np.einsum("i,ij,jk->ik", 1 / degree, adj_mat, nodes)
            graph = Data(
                x=th.tensor(new_nodes, dtype=th.float32),
                edge_index=th.tensor(adj_mat, dtype=th.long).nonzero().t().contiguous(),
            )
            label = th.tensor(float(fields[2]), dtype=th.float32)
        except Exception:
            # Best effort: rows that cannot be parsed or featurised (invalid
            # SMILES, unsupported element, missing label column, ...) are
            # dropped, but counted so the loss of data is visible. Unlike a
            # bare `except`, this does not swallow KeyboardInterrupt.
            skipped += 1
            continue
        xtrain.append([graph, label])
        ytrain.append(label)
    print(len(xtrain), val1)
    if skipped:
        print("skipped", skipped, "rows that could not be converted to a graph")
    # NOTE(review): the loader does not shuffle, so samples are seen in file
    # order during training -- confirm that this is intended.
    dataset = DataLoader(xtrain, batch_size=1)
    X.append(dataset)
    Y.append(ytrain)

[16:51:50] Explicit valence for atom # 2 Cl, 2, is greater than permitted
[16:51:52] Explicit valence for atom # 3 Si, 8, is greater than permitted


2333 1098


In [16]:
class GraphAtt(pl.LightningModule):
    """Graph-attention network producing one sigmoid score per graph.

    The model is a stack of ``num`` GATConv -> BatchNorm -> LeakyReLU layers,
    followed by mean pooling over nodes, a dense layer and a single-unit
    sigmoid output trained with binary cross-entropy.

    Args:
        traindata: object returned unchanged by ``train_dataloader``.
        trainbatch: stored on the instance; not otherwise used by this class.
        valbatch: stored on the instance; not otherwise used by this class.
        lr: learning rate passed to Adam.
        num: number of GAT layers.
        indim: width of the input node features.
        hidden_dim: output width of the first GAT layer.
        dense: width of the dense layer that follows pooling.
        factor: multiplier applied to the layer width after every GAT layer.
    """

    def __init__(self, traindata, trainbatch=1, valbatch=1, lr=1e-4, num=5, indim=42, hidden_dim=64, dense = 256, factor=2):
        super().__init__()
        self.gatconv = nn.ModuleList()
        self.norm = nn.ModuleList()
        in_width, out_width = indim, hidden_dim
        for _ in range(num):
            self.gatconv.append(gnn.GATConv(in_width, out_width, heads=1, dropout=0.2))
            # NOTE(review): with PyTorch-style batch norm a momentum of 0.99
            # weights the current batch statistics almost exclusively; this
            # looks like a Keras-style value -- confirm the intent.
            self.norm.append(gnn.BatchNorm(out_width, momentum=0.99))
            in_width = out_width
            out_width = int(out_width * factor)

        self.act = nn.LeakyReLU(0.1)
        self.pool = gnn.MeanAggregation()
        # in_width is the output width of the last GAT layer (or indim when
        # num == 0). The previous `hidden_dim // 2` only matched it for
        # factor == 2.
        self.dense = nn.Linear(in_width, dense)
        self.classify = nn.Linear(dense, 1)
        self.drop = nn.Dropout(0.2)

        self.traindata = traindata
        self.trainbatch = trainbatch
        self.valbatch = valbatch
        self.lr = lr
        self.loss = nn.BCELoss()

    def forward(self, g, h, batch=None):
        """Score the graph(s) described by ``g`` and ``h``.

        Args:
            g: edge-index tensor, passed as the second argument of every
                GATConv layer.
            h: node-feature tensor.
            batch: optional tensor assigning each node to a graph; it is
                handed to the mean pooling as its index. When omitted, all
                nodes are pooled into a single graph (the previous behaviour).

        Returns:
            Sigmoid outputs of the final linear layer, one row per pooled graph.
        """
        for gat, norm in zip(self.gatconv, self.norm):
            h = self.act(norm(gat(h, g)))
        h = self.drop(self.pool(h, batch))
        out = th.sigmoid(self.classify(self.act(self.drop(self.dense(h)))))
        return out

    def train_dataloader(self):
        """Return the training data object supplied at construction time."""
        return self.traindata

    # Validation is currently disabled.
    # def val_dataloader(self):
    #     return DataLoader(self.valdata, batch_size=self.valbatch, shuffle=True)

    def configure_optimizers(self):
        """Create Adam (learning rate ``self.lr``) with a StepLR schedule."""
        # Honour the constructor's `lr`; it used to be ignored in favour of a
        # hard-coded 1e-4 (which is still the default value).
        optimizer = optim.Adam(self.parameters(), lr=self.lr)
        sch = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
        # No "monitor" entry: StepLR does not consume a metric, and the
        # previously named "train_loss" is never logged by this module.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": sch,
            },
        }

    def training_step(self, batch, batch_idx):
        """Run one optimisation step on a ``(graphs, labels)`` batch.

        Logs the loss and the accuracy to the progress bar and returns the loss.
        """
        graphs = batch[0]
        target = batch[1].reshape(-1, 1)
        # .long() converts without the copy-construct warning that
        # th.tensor(<tensor>) emits.
        edge_index = graphs.edge_index.long()
        # Pass the node -> graph assignment (when the batch object carries
        # one) so that every graph is pooled separately; without it a batch
        # of several graphs collapses into a single prediction.
        node_to_graph = getattr(graphs, "batch", None)
        y_pred = self(edge_index, graphs.x, node_to_graph)
        loss = self.loss(y_pred, target)
        accuracy = self.get_accuracy(target, y_pred)
        self.log_dict({"traning_loss": loss, "Accuracy": accuracy}, on_step=True, on_epoch=True, prog_bar=True, logger=False)
        return loss

    # def validation_step(self, batch, batch_idx):
    #     xtr, ytr = batch
    #     h = xtr.x
    #     g = th.tensor(xtr.edge_index, dtype=th.long)
    #     y_pred = self.forward(g, h)
    #     loss = self.loss(y_pred, ytr.reshape(-1,1))
    #     accuracy = self.get_accuracy(ytr, y_pred)
    #     self.log_dict({"traning_loss": loss, "Accuracy": accuracy}, on_step=True, on_epoch=True, prog_bar=True)
    #     return loss

    def get_accuracy(self, y_true, y_prob):
        """Return the fraction of rows whose prediction (``y_prob > 0.5``) equals ``y_true``.

        Both tensors must have the same size; the result is a Python float.
        """
        assert y_true.size() == y_prob.size()
        y_prob = y_prob > 0.5
        return (y_true == y_prob).sum().item() / y_true.size(0)

In [17]:
# Instantiate the network on the DataLoader built above, then train it for
# between 10 and 30 epochs with the progress bar enabled.
model = GraphAtt(traindata=dataset)
trainer = pl.Trainer(
    min_epochs=10,
    max_epochs=30,
    enable_progress_bar=True,
)
trainer.fit(model)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name     | Type            | Params
---------------------------------------------
0 | gatconv  | ModuleList      | 704 K 
1 | norm     | ModuleList      | 4.0 K 
2 | act      | LeakyReLU       | 0     
3 | pool     | MeanAggregation | 0     
4 | dense    | Linear          | 262 K 
5 | classify | Linear          | 257   
6 | drop     | Dropout         | 0     
7 | loss     | BCELoss         | 0     
---------------------------------------------
971 K     Trainable params
0         Non-trainable params
971 K     Total params
3.886     Total estimated model params size (MB)


Epoch 0:  10%|▉         | 224/2333 [00:06<01:00, 34.91it/s, loss=0.718, v_num=5, traning_loss_step=0.799, Accuracy_step=0.000]

In [None]:
from tqdm.auto import trange

# Smoke test for the progress-bar widget: iterate 1000 times doing nothing.
progress = trange(1000)
for _ in progress:
    pass

100%|██████████| 1000/1000 [00:00<00:00, 3666349.65it/s]
