In [1]:
# import packages

# general tools
import numpy as np
import pandas as pd
from statistics import mean

# RDkit
from rdkit import Chem, RDLogger
from rdkit.Chem.rdmolops import GetAdjacencyMatrix
from rdkit.Chem import PandasTools

# Pytorch and Pytorch Geometric
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
from torch_geometric.data import Data
# from torch.utils.data import DataLoader
from torch_geometric.loader import DataLoader

from torch_geometric.nn import GAE, VGAE, GCNConv, global_mean_pool

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
SDFFile = "../data/drugs/drugbank_5.1.10/structures.sdf"
struct_df = PandasTools.LoadSDF(SDFFile)

[15:48:37] Explicit valence for atom # 13 Cl, 5, is greater than permitted
[15:48:37] ERROR: Could not sanitize molecule ending on line 298551
[15:48:37] ERROR: Explicit valence for atom # 13 Cl, 5, is greater than permitted
[15:48:38] Explicit valence for atom # 19 O, 3, is greater than permitted
[15:48:38] ERROR: Could not sanitize molecule ending on line 412786
[15:48:38] ERROR: Explicit valence for atom # 19 O, 3, is greater than permitted
[15:48:40] Explicit valence for atom # 1 N, 4, is greater than permitted
[15:48:40] ERROR: Could not sanitize molecule ending on line 540739
[15:48:40] ERROR: Explicit valence for atom # 1 N, 4, is greater than permitted
[15:48:41] Explicit valence for atom # 1 N, 4, is greater than permitted
[15:48:41] ERROR: Could not sanitize molecule ending on line 598037
[15:48:41] ERROR: Explicit valence for atom # 1 N, 4, is greater than permitted
[15:48:42] Explicit valence for atom # 12 N, 4, is greater than permitted
[15:48:42] ERROR: Could not sanitize

In [3]:
df = struct_df[['DATABASE_ID','SMILES', 'JCHEM_POLAR_SURFACE_AREA']]
df = df.dropna()
df

Unnamed: 0,DATABASE_ID,SMILES,JCHEM_POLAR_SURFACE_AREA
0,DB00006,CC[C@H](C)[C@H](NC(=O)[C@H](CCC(O)=O)NC(=O)[C@...,901.5700000000002
1,DB00007,CCNC(=O)[C@@H]1CCCN1C(=O)[C@H](CCCNC(N)=N)NC(=...,429.04
2,DB00014,CC(C)C[C@H](NC(=O)[C@@H](COC(C)(C)C)NC(=O)[C@H...,495.8899999999997
3,DB00027,CC(C)C[C@@H](NC(=O)CNC(=O)[C@@H](NC=O)C(C)C)C(...,519.8899999999996
4,DB00035,NC(=O)CC[C@@H]1NC(=O)[C@H](CC2=CC=CC=C2)NC(=O)...,435.40999999999985
...,...,...,...
11577,DB17309,[Na+].[Na+].[H][C@@]12C[C@@H]3O[C@@]33[C@H](OC...,145.54
11578,DB17312,[212Pb++].[H][C@](C)(O)[C@]([H])(N=C(O)[C@]1([...,551.3800000000003
11579,DB17342,CCCCCCCCCCCCCC[N+](C)(C)C,0.0
11580,DB17352,CSCC[C@H](NC(=O)[C@@H](N)CC1=CC=C(O)C=C1)C(=O)...,362.09


In [4]:
x_map = {
    'atomic_num': list(range(0, 119)),
    'chirality': ['CHI_UNSPECIFIED','CHI_TETRAHEDRAL_CW','CHI_TETRAHEDRAL_CCW','CHI_OTHER','CHI_TETRAHEDRAL','CHI_ALLENE','CHI_SQUAREPLANAR','CHI_TRIGONALBIPYRAMIDAL','CHI_OCTAHEDRAL',],
    'degree':list(range(0, 11)),
    'formal_charge':list(range(-5, 7)),
    'num_hs':list(range(0, 9)),
    'num_radical_electrons':list(range(0, 5)),
    'hybridization': ['UNSPECIFIED','S','SP','SP2','SP3','SP3D','SP3D2','OTHER',],
    'is_aromatic': [False, True],
    'is_in_ring': [False, True],
}

e_map = {
    'bond_type': ['UNSPECIFIED','SINGLE','DOUBLE','TRIPLE','QUADRUPLE','QUINTUPLE','HEXTUPLE','ONEANDAHALF','TWOANDAHALF','THREEANDAHALF','FOURANDAHALF','FIVEANDAHALF','AROMATIC','IONIC','HYDROGEN','THREECENTER','DATIVEONE','DATIVE','DATIVEL','DATIVER','OTHER','ZERO',],
    'stereo': ['STEREONONE','STEREOANY','STEREOZ','STEREOE','STEREOCIS','STEREOTRANS',],
    'is_conjugated': [False, True],
}


def from_smiles(smiles: str, with_hydrogen: bool = False, kekulize: bool = False) -> 'Data':
    r"""Converts a SMILES string to a :class:`torch_geometric.data.Data`
    instance.

    Args:
        smiles (str): The SMILES string.
        with_hydrogen (bool, optional): If set to :obj:`True`, will store
            hydrogens in the molecule graph. (default: :obj:`False`)
        kekulize (bool, optional): If set to :obj:`True`, converts aromatic
            bonds to single/double bonds. (default: :obj:`False`)
    """
    
    RDLogger.DisableLog('rdApp.*')

    mol = Chem.MolFromSmiles(smiles)

    if mol is None:
        mol = Chem.MolFromSmiles('')
    if with_hydrogen:
        mol = Chem.AddHs(mol)
    if kekulize:
        Chem.Kekulize(mol)

    xs = []
    for atom in mol.GetAtoms():
        x = []
        x.append(x_map['atomic_num'].index(atom.GetAtomicNum()))
        x.append(x_map['chirality'].index(str(atom.GetChiralTag())))
        x.append(x_map['degree'].index(atom.GetTotalDegree()))
        x.append(x_map['formal_charge'].index(atom.GetFormalCharge()))
        x.append(x_map['num_hs'].index(atom.GetTotalNumHs()))
        x.append(x_map['num_radical_electrons'].index(atom.GetNumRadicalElectrons()))
        x.append(x_map['hybridization'].index(str(atom.GetHybridization())))
        x.append(x_map['is_aromatic'].index(atom.GetIsAromatic()))
        x.append(x_map['is_in_ring'].index(atom.IsInRing()))
        xs.append(x)

    x = torch.tensor(xs, dtype=torch.long).view(-1, 9)

    edge_indices, edge_attrs = [], []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()

        e = []
        e.append(e_map['bond_type'].index(str(bond.GetBondType())))
        e.append(e_map['stereo'].index(str(bond.GetStereo())))
        e.append(e_map['is_conjugated'].index(bond.GetIsConjugated()))

        edge_indices += [[i, j], [j, i]]
        edge_attrs += [e, e]

    edge_index = torch.tensor(edge_indices)
    edge_index = edge_index.t().to(torch.long).view(2, -1)
    edge_attr = torch.tensor(edge_attrs, dtype=torch.long).view(-1, 3)

    if edge_index.numel() > 0:  # Sort indices.
        perm = (edge_index[0] * x.size(0) + edge_index[1]).argsort()
        edge_index, edge_attr = edge_index[:, perm], edge_attr[perm]

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, smiles=smiles)


In [8]:
def create_pytorch_geometric_graph_data_list_from_smiles(x_smiles):
    data = []
    for smile in x_smiles:
        data.append(from_smiles(smile, True))
    return data

In [33]:

data_list = create_pytorch_geometric_graph_data_list_from_smiles(df['SMILES'][:100])
dataloader = DataLoader(data_list, batch_size=1)

In [61]:
class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(data_list[0].num_node_features, 64)
        self.conv2 = GCNConv(64, 256)
        # self.conv3 = GCNConv(512, 512)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        # x = x.relu()
        # x = self.conv3(x, edge_index)
        # x = global_mean_pool(x, batch)
        return x

encoder = GCN()
print(encoder)

GCN(
  (conv1): GCNConv(9, 64)
  (conv2): GCNConv(64, 256)
)


In [62]:
model = GAE(encoder)

In [63]:
# loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

In [64]:
g = from_smiles('CCCCCCCCCCCCCC[N+](C)(C)C')

In [65]:
for epoch in range(1, 100 + 1):
    model.train()
    optimizer.zero_grad()
    loss_list = []
    # for g in data_list:
    
    z = model.encode(g.x.to(torch.float32), g.edge_index, g.batch)
    loss = model.recon_loss(z, g.edge_index)
    loss.backward()
    optimizer.step()
    loss_list.append(float(loss))
        # break
    print(f'Epoch: {epoch:03d}, Loss: {mean(loss_list)}')
    # auc, ap = train(dataloader)
    # print(f'Epoch: {epoch:03d}, AUC: {auc:.4f}, AP: {ap:.4f}')

Epoch: 001, Loss: 34.53877639770508
Epoch: 002, Loss: 34.53877639770508
Epoch: 003, Loss: 34.53877639770508
Epoch: 004, Loss: 34.53877639770508
Epoch: 005, Loss: 34.53877639770508
Epoch: 006, Loss: 34.53877639770508
Epoch: 007, Loss: 34.53877639770508
Epoch: 008, Loss: 34.53877639770508
Epoch: 009, Loss: 34.53877639770508
Epoch: 010, Loss: 34.53877639770508
Epoch: 011, Loss: 34.53877639770508
Epoch: 012, Loss: 34.53877639770508
Epoch: 013, Loss: 34.53877639770508
Epoch: 014, Loss: 34.53877639770508
Epoch: 015, Loss: 34.53877639770508
Epoch: 016, Loss: 34.53877639770508
Epoch: 017, Loss: 34.53877639770508
Epoch: 018, Loss: 34.53877639770508
Epoch: 019, Loss: 34.53877639770508
Epoch: 020, Loss: 34.53877639770508
Epoch: 021, Loss: 34.53877639770508
Epoch: 022, Loss: 34.53877639770508
Epoch: 023, Loss: 34.53877639770508
Epoch: 024, Loss: 34.53877639770508
Epoch: 025, Loss: 34.53877639770508
Epoch: 026, Loss: 34.53877639770508
Epoch: 027, Loss: 34.53877639770508
Epoch: 028, Loss: 34.5387763

In [45]:
g.batch