In [1]:
%load_ext autoreload
%autoreload 2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.datasets.qm9 import QM9
import torch_geometric.datasets.qm9 as qm9
from torch_geometric.data import DataLoader
import torch_geometric.nn as tgnn
from torch_geometric.data import Data
from torch_scatter import scatter
import torch_geometric
import tqdm
import numpy as np
import wandb
import random
import pandas as pd

import rdkit
from rdkit import Chem
from rdkit.Chem.rdchem import HybridizationType
from rdkit.Chem.rdchem import BondType as BT
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

from matplotlib import pyplot as plt
import networkx as nx

# Seed every random number generator used in this notebook (Python's `random`,
# PyTorch and NumPy) with the same value so that a re-run is reproducible.
manualSeed = 43
for seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    seed_fn(manualSeed)
print("Random Seed: ", manualSeed)

Random Seed:  43


In [2]:
# Load the semicolon-separated ChEMBL small-molecule table.
# low_memory=False makes pandas infer each column's dtype from the whole file
# instead of chunk by chunk, which avoids mixed-dtype columns (and the
# DtypeWarning that pandas emits for them) on a file of this size.
df = pd.read_csv(
    '../../datasets/ChEMBL/small_molecules.csv',
    delimiter=';',
    low_memory=False,
)

  exec(code_obj, self.user_global_ns, self.user_ns)


In [3]:
# Preview the raw SMILES column (pandas shows the head and tail of the Series).
df['Smiles']

0          CCC(=O)O[C@H]1[C@H](C)O[C@@H](O[C@@H]2[C@@H](C...
1                  COc1ccc(Cl)cc1-c1cc(Nc2ccc(Cl)cc2)nc(N)n1
2          COc1ccc(/C=C/C2=NN(CCCN)C(=O)N(Cc3ccc4ccccc4c3...
3                    Cc1ccc(CCN2CC(C(=O)NCc3cccnc3)CC2=O)cc1
4                                               CCCCC(=O)O.N
                                 ...                        
1920134    COc1ccnc(C[S+]([O-])c2nc3cc(OC(F)F)ccc3[n-]2)c...
1920135      CN1c2c(oc(=O)n(-c3ccccn3)c2=O)-c2ccccc2S1(=O)=O
1920136                                                  NaN
1920137                                                  NaN
1920138    COc1cc([C@@H]2c3cc4c(cc3[C@H](O)[C@H]3COC(=O)[...
Name: Smiles, Length: 1920139, dtype: object

In [4]:
# Discard rows that have no SMILES string. The index is NOT renumbered, so the
# original row labels stay valid for the label-based lookups further down.
df = df.dropna(subset=["Smiles"])

In [5]:
# Re-inspect the column after the NaN rows were dropped; the index labels keep
# their original values, so gaps appear where rows were removed.
df['Smiles']

0          CCC(=O)O[C@H]1[C@H](C)O[C@@H](O[C@@H]2[C@@H](C...
1                  COc1ccc(Cl)cc1-c1cc(Nc2ccc(Cl)cc2)nc(N)n1
2          COc1ccc(/C=C/C2=NN(CCCN)C(=O)N(Cc3ccc4ccccc4c3...
3                    Cc1ccc(CCN2CC(C(=O)NCc3cccnc3)CC2=O)cc1
4                                               CCCCC(=O)O.N
                                 ...                        
1920132    Oc1ccc2c(c1)[C@@]13CCCC[C@@]1(O)[C@@H](C2)N(CC...
1920133                       Oc1ccc(/C=C/c2cc(O)cc(O)c2)cc1
1920134    COc1ccnc(C[S+]([O-])c2nc3cc(OC(F)F)ccc3[n-]2)c...
1920135      CN1c2c(oc(=O)n(-c3ccccn3)c2=O)-c2ccccc2S1(=O)=O
1920138    COc1cc([C@@H]2c3cc4c(cc3[C@H](O)[C@H]3COC(=O)[...
Name: Smiles, Length: 1914648, dtype: object

In [6]:
# List the columns available in the ChEMBL table.
df.columns

Index(['ChEMBL ID', 'Name', 'Synonyms', 'Type', 'Max Phase',
       'Molecular Weight', 'Targets', 'Bioactivities', 'AlogP', 'PSA', 'HBA',
       'HBD', '#RO5 Violations', '#Rotatable Bonds', 'Passes Ro3',
       'QED Weighted', 'CX ApKa', 'CX BpKa', 'CX LogP', 'CX LogD',
       'Aromatic Rings', 'Structure Type', 'Inorganic Flag', 'Heavy Atoms',
       'HBA Lipinski', 'HBD Lipinski', '#RO5 Violations (Lipinski)',
       'Molecular Weight (Monoisotopic)', 'Molecular Species',
       'Molecular Formula', 'Smiles'],
      dtype='object')

In [7]:
# Sanity check: parse the SMILES of one arbitrary row (index label 111) with RDKit.
smiles = df.loc[111, 'Smiles']
mol = Chem.MolFromSmiles(smiles)

In [8]:
# Select the SMILES entry of one specific compound by its ChEMBL ID.
# The result is a Series (one element per matching row), not a plain string.
smiles = df.loc[df['ChEMBL ID'] == 'CHEMBL4099974', 'Smiles']

In [9]:
# Show the Series selected by ChEMBL ID in the previous cell.
smiles

2    COc1ccc(/C=C/C2=NN(CCCN)C(=O)N(Cc3ccc4ccccc4c3...
Name: Smiles, dtype: object

In [10]:
# Parse the selected compound with RDKit. `smiles` is the Series produced by the
# ChEMBL-ID lookup; take its first element by position so the cell does not
# depend on the row label that the compound happens to have in the CSV.
mol = Chem.MolFromSmiles(smiles.iloc[0])

In [11]:
# Write only the SMILES column (with its header line, without the index) to a
# separate file that RDKit's SmilesMolSupplier reads in the next cell.
df['Smiles'].to_csv(
    '../../datasets/ChEMBL/smiles.csv',
    index=False,
)

In [12]:
# Lazily parse the exported SMILES file into RDKit molecules.
# sanitize=False skips RDKit's sanitization step for every molecule.
supplier = Chem.SmilesMolSupplier(
    '../../datasets/ChEMBL/smiles.csv',
    delimiter='',
    sanitize=False,
)

# Determine all atom and bond types that occur

In [13]:
# Vocabularies (atom symbol / bond type -> integer index) and per-molecule size
# statistics; all four containers are filled by the scan over the dataset below.
atom_types, bond_types = {}, {}
n_atoms_list, n_bonds_list = [], []

In [None]:
# Scan the whole dataset once to
#   * assign an integer index to every atom symbol and bond type that occurs, and
#   * record the number of atoms / bonds of each molecule (explicit hydrogens included).
max_atoms = 0
max_bonds = 0
for mol_idx, mol in enumerate(tqdm.tqdm(supplier)):
    # The None check has to come BEFORE AddHs: AddHs cannot be called on None,
    # so checking afterwards (as before) could never trigger.
    if mol is None:
        # smiles.csv was written from df['Smiles'] in row order, so the position
        # in the supplier identifies the offending SMILES string.
        print(df['Smiles'].iloc[mol_idx], 'yields None')
        continue
    mol = Chem.rdmolops.AddHs(mol)

    # New symbols / bond types get the next free index. Deriving it from the
    # dict size (rather than a separate counter) keeps indices unique even if
    # this cell is re-run while the dicts already contain entries.
    for atom in mol.GetAtoms():
        atom_types.setdefault(atom.GetSymbol(), len(atom_types))
    for bond in mol.GetBonds():
        bond_types.setdefault(bond.GetBondType(), len(bond_types))

    n_atoms = mol.GetNumAtoms()
    n_bonds = len(mol.GetBonds())
    max_atoms = max(max_atoms, n_atoms)
    max_bonds = max(max_bonds, n_bonds)
    n_atoms_list.append(n_atoms)
    n_bonds_list.append(n_bonds)

# Kept for backward compatibility: the next free vocabulary indices.
curr_atom_type = len(atom_types)
curr_bond_type = len(bond_types)

 40%|████████████▉                   | 773674/1914648 [03:54<05:42, 3327.25it/s]

In [None]:
# Distribution of molecule sizes; the counts include the hydrogens added by AddHs.
fig, ax = plt.subplots()
ax.hist(n_atoms_list, bins='rice')
ax.set(
    xlabel='Atoms per molecule (incl. hydrogens)',
    ylabel='Number of molecules',
    title='Atom count distribution',
);

In [None]:
# Distribution of bond counts; bonds to the hydrogens added by AddHs are included.
fig, ax = plt.subplots()
ax.hist(n_bonds_list, bins='rice')
ax.set(
    xlabel='Bonds per molecule (incl. bonds to hydrogens)',
    ylabel='Number of molecules',
    title='Bond count distribution',
);

In [None]:
# Show the atom-symbol -> integer-index vocabulary collected by the dataset scan.
atom_types

In [None]:
# Show the bond-type -> integer-index vocabulary collected by the dataset scan.
bond_types

In [None]:
# Create a fresh supplier over the same SMILES file for the second pass
# (graph construction); sanitization is again skipped.
supplier = Chem.SmilesMolSupplier(
    '../../datasets/ChEMBL/smiles.csv',
    delimiter='',
    sanitize=False,
)

In [None]:
# Convert every molecule into a torch_geometric `Data` graph and collect the
# graphs in `data_list`.
#
# Node features `x` (one row per atom, explicit hydrogens included):
#   [ one-hot atom type | atomic number, aromatic, sp, sp2, sp3, #H neighbours ]
# Edge features `edge_attr`: one-hot bond type; every bond is stored in both directions.
# Only the largest connected component of each molecule is kept.
#
# NOTE(review): the supplier is created with sanitize=False. RDKit normally assigns
# hybridization during sanitization, so the sp/sp2/sp3 flags may be 0 for every
# atom - TODO confirm that this is intended.
types = atom_types
bonds = bond_types
data_list = []

# smiles.csv was written from df['Smiles'] in row order, so zipping the supplier
# with the dataframe columns pairs each molecule with its ChEMBL ID and SMILES
# (assumes the supplier skips the CSV header line, RDKit's default - TODO confirm).
mol_iter = zip(df['ChEMBL ID'], df['Smiles'], supplier)
for i, (mol_id, mol_smiles, mol) in tqdm.tqdm(enumerate(mol_iter), total=len(df)):
    if mol is None:
        # Report the SMILES of *this* row (previously a stale global was printed).
        print(mol_smiles, 'yields None')
        continue

    mol = Chem.rdmolops.AddHs(mol)
    N = mol.GetNumAtoms()
    if N == 0:
        # Nothing to build a graph from; the component search below needs >= 1 node.
        print(mol_smiles, 'has no atoms')
        continue

    # ---- per-atom features --------------------------------------------------
    type_idx, atomic_number, aromatic = [], [], []
    sp, sp2, sp3 = [], [], []
    for atom in mol.GetAtoms():
        type_idx.append(types[atom.GetSymbol()])
        atomic_number.append(atom.GetAtomicNum())
        aromatic.append(1 if atom.GetIsAromatic() else 0)
        hybridization = atom.GetHybridization()
        sp.append(1 if hybridization == HybridizationType.SP else 0)
        sp2.append(1 if hybridization == HybridizationType.SP2 else 0)
        sp3.append(1 if hybridization == HybridizationType.SP3 else 0)

    z = torch.tensor(atomic_number, dtype=torch.long)

    # ---- edges: each bond is added in both directions -----------------------
    row, col, edge_type = [], [], []
    for bond in mol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        row += [start, end]
        col += [end, start]
        edge_type += 2 * [bonds[bond.GetBondType()]]

    edge_index = torch.tensor([row, col], dtype=torch.long)   # shape (2, E), also for E == 0
    edge_type = torch.tensor(edge_type, dtype=torch.long)

    # Sort edges by (source, target) so the edge order is canonical.
    perm = (edge_index[0] * N + edge_index[1]).argsort()
    edge_index = edge_index[:, perm]
    edge_type = edge_type[perm]
    edge_attr = F.one_hot(edge_type, num_classes=len(bonds)).to(torch.float)

    # Number of hydrogen neighbours per atom: sum the "is hydrogen" flag of every
    # edge's source atom into the edge's target atom.
    row, col = edge_index
    hs = (z == 1).to(torch.float)
    num_hs = scatter(hs[row], col, dim_size=N).tolist()

    x1 = F.one_hot(torch.tensor(type_idx, dtype=torch.long), num_classes=len(types))
    # `aromatic` is now really the aromaticity flag; before, the atomic numbers
    # were copied into this feature column by mistake.
    x2 = torch.tensor([atomic_number, aromatic, sp, sp2, sp3, num_hs],
                      dtype=torch.float).t().contiguous()
    x = torch.cat([x1.to(torch.float), x2], dim=-1)

    # ---- only keep largest connected component ------------------------------
    g = nx.Graph()
    g.add_nodes_from(range(N))
    g.add_edges_from(edge_index.t().tolist())

    # max() returns the first largest component, like the previous argmax did.
    largest_cc = max(nx.connected_components(g), key=len)
    # Sorted ids give a deterministic node order and keep the edge order sorted
    # after relabelling (the old -> new mapping is monotonic).
    node_ids = sorted(largest_cc)

    x = x[node_ids]
    z = z[node_ids]

    # Relabel nodes to 0..len(node_ids)-1; atoms outside the component map to -1.
    new_id = torch.full((N,), -1, dtype=torch.long)
    new_id[node_ids] = torch.arange(len(node_ids))
    # Both endpoints of an edge always lie in the same component, so testing the
    # source endpoint is sufficient.
    keep = new_id[edge_index[0]] >= 0
    edge_index = new_id[edge_index[:, keep]]   # stays (2, E') even when E' == 0
    edge_attr = edge_attr[keep]

    name = mol_id

    data = Data(x=x, z=z, pos=None, edge_index=edge_index,
                edge_attr=edge_attr, y=None, name=name, idx=i)

    data_list.append(data)



In [None]:
# Number of shards the preprocessed dataset is split into before saving.
n_splits = 64

In [None]:
# Partition the positions 0..len(data_list)-1 into `n_splits` contiguous,
# nearly equal-sized index arrays (one per output shard).
all_positions = np.arange(len(data_list))
splits = np.array_split(all_positions, n_splits)

In [None]:
# Collate every shard into a single (data, slices) object and save it to its own file.
# tqdm wraps the list itself (not the enumerate object) so the bar knows the total.
for split_idx, split in enumerate(tqdm.tqdm(splits)):
    if len(split) == 0:
        # array_split yields empty shards when there are fewer graphs than
        # n_splits; collate cannot handle an empty list, so skip them.
        continue
    # `j` instead of re-using the shard index name inside the comprehension.
    dlist = [data_list[j] for j in split]
    collated = torch_geometric.data.InMemoryDataset.collate(dlist)
    torch.save(collated, '../../datasets/ChEMBL/preprocessedPlusHs_split%d.pt' % split_idx)