In [1]:
%load_ext autoreload
%autoreload 2


2021-08-23 14:19:57.602743: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/opt/cuda/extras/CUPTI/lib64/:/opt/intel/lib:/opt/intel/mkl/lib/intel64:/opt/intel:/opt/ibm/ILOG/CPLEX_Studio1210/cplex/bin/x86-64_linux:/opt/ibm/ILOG/CPLEX_Studio1210/cplex/python/3.7/x86-64_linux:/opt/intel/clck_latest/lib:/opt/intel/daal/lib:/opt/intel/intelpython3/lib:/opt/intel/ipp/lib:/opt/intel/itac_2019/lib:/opt/intel/itac_latest/lib:/opt/intel/mkl/lib:/opt/intel/mkl_/lib:/opt/intel/mpirt/lib:/opt/intel/tbb/lib:/opt/intel/clck/2019.0/lib:/opt/intel/compilers_and_libraries_2019/linux/lib:/opt/intel/compilers_and_libraries/linux/lib:/opt/intel/itac/2019.0.018/lib:/opt/intel/itac_2019/intel64/lib:/opt/intel/itac_latest/intel64/lib:/opt/intel/parallel_studio_xe_2019.0.045/clck_2019/lib:/opt/intel/parallel_studio_xe_2019.0.045/itac_2019/l

In [26]:
import deepchem as dc
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter
import numpy as np

import dfs_code


import rdkit
from rdkit import Chem
from rdkit.Chem.rdchem import HybridizationType
from rdkit.Chem.rdchem import BondType as BT
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
import networkx as nx


# Atom-symbol vocabulary: maps an element symbol to the column it occupies in
# the one-hot part of the node features. The position in the list below is
# the index, so new symbols must only ever be appended at the end.
types = {
    symbol: index
    for index, symbol in enumerate([
        'C', 'O', 'N', 'Cl', 'S', 'F', 'P', 'Se', 'Br', 'I', 'Na',
        'B', 'K', 'Li', 'H', 'Si', 'Ca', 'Rb', 'Te', 'Zn', 'Mg', 'As',
        'Al', 'Ba', 'Be', 'Sr', 'Ag', 'Bi', 'Ra', 'Kr', 'Cs', 'Xe', 'He',
    ])
}
# Bond-type vocabulary: maps an RDKit bond type to the column it occupies in
# the one-hot edge features. The position in the tuple below is the index.
bonds = {
    bond_type: code
    for code, bond_type in enumerate((
        rdkit.Chem.rdchem.BondType.SINGLE,
        rdkit.Chem.rdchem.BondType.DOUBLE,
        rdkit.Chem.rdchem.BondType.AROMATIC,
        rdkit.Chem.rdchem.BondType.TRIPLE,
    ))
}

class Deepchem2TorchGeometric(Dataset):
    """Turn a DeepChem SMILES dataset into a list of torch_geometric ``Data`` graphs.

    Every SMILES string in ``deepchem_smiles_dataset.X`` is parsed with RDKit and
    converted eagerly (inside ``__init__``) into one ``Data`` object holding:

    * ``x``: node features, a one-hot over ``types`` followed by six scalar
      columns ``[atomic_number, aromatic, sp, sp2, sp3, num_hs]``;
    * ``z``: atomic numbers;
    * ``edge_index`` / ``edge_attr``: both directions of every bond, with a
      one-hot over ``bonds`` as edge feature;
    * ``y``: the label row of the molecule the graph was built from;
    * ``min_dfs_code`` / ``min_dfs_index``: the result of
      ``dfs_code.min_dfs_code_from_torch_geometric`` for the graph.

    Only the largest connected component of each molecule is kept.

    Args:
        deepchem_smiles_dataset: object exposing ``X`` (SMILES strings), ``y``
            (labels) and ``w`` (weights), indexable per molecule.
        useHs: if true, explicit hydrogens are added with RDKit before the
            graph is built.
        precompute_min_dfs: stored on the instance. NOTE(review): the flag is
            currently not consulted; minimum DFS codes are always computed
            because ``collate_fn`` in this notebook reads ``min_dfs_code``.

    Raises:
        ValueError: if RDKit cannot parse one of the SMILES strings.
        KeyError: if a molecule contains an atom symbol missing from ``types``
            or a bond type missing from ``bonds``.
    """

    def __init__(self, deepchem_smiles_dataset, useHs=False, precompute_min_dfs=True):
        self.deepchem = deepchem_smiles_dataset
        self.smiles = deepchem_smiles_dataset.X
        self.labels = deepchem_smiles_dataset.y
        self.w = deepchem_smiles_dataset.w
        self.useHs = useHs
        self.precompute_min_dfs = precompute_min_dfs
        self.data = []
        self.prepare()

    def prepare(self):
        """Build ``self.data`` from ``self.smiles``; safe to call more than once."""
        # Start from scratch so a second call does not append duplicates.
        self.data = []
        for idx in range(len(self.smiles)):
            smiles = self.smiles[idx]
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                # RDKit signals a parse failure by returning None; fail with a
                # message that names the offending entry.
                raise ValueError(
                    f"Could not parse SMILES at index {idx}: {smiles!r}")
            if self.useHs:
                mol = Chem.rdmolops.AddHs(mol)
            N = mol.GetNumAtoms()

            # ---- per-atom features -------------------------------------
            type_idx, atomic_number, aromatic = [], [], []
            sp, sp2, sp3 = [], [], []
            for atom in mol.GetAtoms():
                type_idx.append(types[atom.GetSymbol()])
                atomic_number.append(atom.GetAtomicNum())
                aromatic.append(1 if atom.GetIsAromatic() else 0)
                hybridization = atom.GetHybridization()
                sp.append(1 if hybridization == HybridizationType.SP else 0)
                sp2.append(1 if hybridization == HybridizationType.SP2 else 0)
                sp3.append(1 if hybridization == HybridizationType.SP3 else 0)

            z = torch.tensor(atomic_number, dtype=torch.long)

            # ---- bonds, stored once per direction ----------------------
            row, col, edge_type = [], [], []
            for bond in mol.GetBonds():
                start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                row += [start, end]
                col += [end, start]
                edge_type += 2 * [bonds[bond.GetBondType()]]

            edge_index = torch.tensor([row, col], dtype=torch.long)
            edge_type = torch.tensor(edge_type, dtype=torch.long)
            edge_attr = F.one_hot(edge_type,
                                  num_classes=len(bonds)).to(torch.float)

            # Sort edges by (source, target) so the ordering is canonical.
            perm = (edge_index[0] * N + edge_index[1]).argsort()
            edge_index = edge_index[:, perm]
            edge_attr = edge_attr[perm]

            # Number of explicit hydrogen neighbours of every atom.
            row, col = edge_index
            hs = (z == 1).to(torch.float)
            num_hs = scatter(hs[row], col, dim_size=N).tolist()

            x1 = F.one_hot(torch.tensor(type_idx, dtype=torch.long),
                           num_classes=len(types))
            # `aromatic` here is the real aromaticity flag; it used to be
            # overwritten with a copy of the atomic numbers.
            x2 = torch.tensor([atomic_number, aromatic, sp, sp2, sp3, num_hs],
                              dtype=torch.float).t().contiguous()
            x = torch.cat([x1.to(torch.float), x2], dim=-1)

            # ---- only keep largest connected component -----------------
            edges_coo = edge_index.detach().cpu().numpy().T.tolist()
            g = nx.Graph()
            g.add_nodes_from(range(N))
            g.add_edges_from(edges_coo)

            # max() returns the first of several equally large components.
            largest_cc = max(nx.connected_components(g), key=len)
            node_ids = np.asarray(list(largest_cc))

            x = x[node_ids]
            z = z[node_ids]
            old2new = {int(old): new for new, old in enumerate(node_ids)}
            edges_cc = []
            edge_feats = []
            # A dedicated loop variable keeps the molecule index `idx` intact
            # for the label lookup below; membership is tested on the set.
            for edge_pos, (u, v) in enumerate(edges_coo):
                if u in largest_cc and v in largest_cc:
                    edges_cc.append([old2new[u], old2new[v]])
                    edge_feats.append(edge_attr[edge_pos].tolist())
            # The reshapes keep the shapes [2, E] and [E, len(bonds)] when E == 0.
            edge_index = torch.tensor(
                edges_cc, dtype=torch.long).reshape(-1, 2).t().contiguous()
            edge_attr = torch.tensor(
                edge_feats, dtype=torch.float).reshape(-1, len(bonds))

            y = torch.tensor(self.labels[idx])
            graph_kwargs = dict(x=x, z=z, pos=None, edge_index=edge_index,
                                edge_attr=edge_attr, y=y)

            # The helper gets its own Data object and the stored graph is built
            # afterwards, in case the helper modifies its argument.
            probe = Data(**graph_kwargs)
            min_code, min_index = dfs_code.min_dfs_code_from_torch_geometric(
                probe,
                z.numpy().tolist(),
                np.argmax(edge_attr.numpy(), axis=1))
            self.data.append(Data(min_dfs_code=torch.tensor(min_code),
                                  min_dfs_index=torch.tensor(min_index),
                                  **graph_kwargs))

    def __len__(self):
        """Number of converted graphs (one per input SMILES)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the pre-built ``Data`` object at position ``idx``."""
        return self.data[idx]


In [27]:
def collate_fn(dlist):
    """Collate a list of graphs into DFS-code lists plus a label batch.

    For every graph a new DFS code is requested from
    ``dfs_code.rnd_dfs_code_from_torch_geometric`` each time a batch is built;
    the precomputed ``min_dfs_code`` stored on the graph is passed through.

    Args:
        dlist: sequence of graph objects exposing ``z``, ``edge_attr``, ``y``
            and ``min_dfs_code`` tensors.

    Returns:
        A tuple ``(rnd_code_batch, min_code_batch, y)`` where the first two are
        lists with one tensor per graph (codes can differ in length, so they
        are not stacked) and ``y`` stacks the labels of all graphs along a new
        leading dimension. For an empty ``dlist`` the lists are empty and ``y``
        is an empty tensor.
    """
    rnd_code_batch = []
    min_code_batch = []
    y_batch = []
    for d in dlist:
        # The second return value of the helper is not needed here.
        rnd_code, _ = dfs_code.rnd_dfs_code_from_torch_geometric(
            d,
            d.z.numpy().tolist(),
            np.argmax(d.edge_attr.numpy(), axis=1))
        rnd_code_batch.append(torch.tensor(rnd_code))
        min_code_batch.append(d.min_dfs_code)
        y_batch.append(d.y)
    # Return the labels of every sample, not only those of the last one.
    y = torch.stack(y_batch) if y_batch else torch.empty(0)
    return rnd_code_batch, min_code_batch, y

In [28]:
# Load the BBBP benchmark through DeepChem's MoleculeNet loader. The result is
# unpacked into task names, a tuple of dataset splits and the transformers.
# NOTE(review): reload=False presumably bypasses DeepChem's on-disk cache so the
# data is re-featurized on every run - confirm, and switch to reload=True if
# repeated featurization is not wanted.
tasks, datasets, transformers = dc.molnet.load_bbbp(
    reload=False,
    featurizer=dc.feat.RawFeaturizer(True),
)

Failed to featurize datapoint 59, None. Appending empty array
Exception message: Python argument types in
    rdkit.Chem.rdmolfiles.CanonicalRankAtoms(NoneType)
did not match C++ signature:
    CanonicalRankAtoms(RDKit::ROMol mol, bool breakTies=True, bool includeChirality=True, bool includeIsotopes=True)
Failed to featurize datapoint 61, None. Appending empty array
Exception message: Python argument types in
    rdkit.Chem.rdmolfiles.CanonicalRankAtoms(NoneType)
did not match C++ signature:
    CanonicalRankAtoms(RDKit::ROMol mol, bool breakTies=True, bool includeChirality=True, bool includeIsotopes=True)
Failed to featurize datapoint 391, None. Appending empty array
Exception message: Python argument types in
    rdkit.Chem.rdmolfiles.CanonicalRankAtoms(NoneType)
did not match C++ signature:
    CanonicalRankAtoms(RDKit::ROMol mol, bool breakTies=True, bool includeChirality=True, bool includeIsotopes=True)
Failed to featurize datapoint 614, None. Appending empty array
Exception messa

In [37]:
# Inspect the dataset splits returned by the loader (shapes and task names).
datasets

(<DiskDataset X.shape: (1631,), y.shape: (1631, 1), w.shape: (1631, 1), task_names: ['p_np']>,
 <DiskDataset X.shape: (204,), y.shape: (204, 1), w.shape: (204, 1), ids: ['[H+].C2=C1C(OC(=NC1=CC=C2Cl)NCC)(C3=CC=CC=C3)C.[Cl-]'
  'C1=CC=CC2=C1C(C3=C(N(C)C2=O)C=CC=C3)OCC'
  'C1=C(Cl)C=CC3=C1C(C2=CC=CC=C2)SC(=N3)NCC' ...
  'CC1(C)S[C@@H]2[C@H](NC(=O)C34C[C@H]5C[C@H](CC(N)(C5)C3)C4)C(=O)N2[C@H]1C(O)=O'
  'COc1cccc2C(=O)c3c(O)c4CC(O)(CC(O)c4c(O)c3C(=O)c12)C(=O)CO'
  'CC[C@@]1(O)C[C@H](OC2CC(C(OC3CC(O)C(OC4CCC(=O)C(C)O4)C(C)O3)C(C)O2)N(C)C)c5c(O)c6C(=O)c7c(O)cccc7C(=O)c6cc5[C@H]1C(=O)OC'], task_names: ['p_np']>,
 <DiskDataset X.shape: (204,), y.shape: (204, 1), w.shape: (204, 1), ids: ['CCOC(=O)[C@H](CCc1ccccc1)N[C@@H](C)C(=O)N2CCC[C@H]2C(O)=O'
  'FC(F)(F)[C@]1(OC(=O)Nc2ccc(Cl)cc12)C#CC3CC3'
  'CN1C[C@@H](C[C@H]2[C@H]1Cc3c[nH]c4cccc2c34)C(=O)N[C@]5(C)O[C@@]6(O)[C@@H]7CCCN7C(=O)[C@H](Cc8ccccc8)N6C5=O'
  ... 'Cn1c2CCC(Cn3ccnc3C)C(=O)c2c4ccccc14'
  'CN(C)[C@H]1[C@@H]2C[C@H]3C(=C(O)c4c(O)cccc4[C@@

In [30]:
# Convert the first split into torch_geometric graphs. The conversion, including
# the minimum-DFS-code computation for every molecule, runs in the constructor.
trainset = Deepchem2TorchGeometric(datasets[0])

In [31]:
# Peek at the first converted graph to check the tensor shapes.
trainset[0]

Data(edge_attr=[46, 4], edge_index=[2, 46], min_dfs_code=[23, 8], min_dfs_index=[23], x=[23, 39], y=[1], z=[23])

In [32]:
# Collect the size of every graph in the split: number of atoms (length of z)
# and number of stored edges (rows of edge_attr).
n_nodes, n_edges = [], []
for d in trainset:
    n_nodes.append(len(d.z))
    n_edges.append(len(d.edge_attr))

In [33]:
# Largest edge count in the split. Every bond is stored once per direction,
# so this is the number of directed edges, i.e. twice the number of bonds.
max(n_edges)

136

In [34]:
# Batch the graphs with the custom collate function, which calls
# dfs_code.rnd_dfs_code_from_torch_geometric for each graph whenever a batch
# is assembled.
loader = DataLoader(
    trainset,
    batch_size=256,
    shuffle=True,
    pin_memory=False,
    collate_fn=collate_fn,
)

In [35]:
import tqdm

In [36]:
# Throughput check: run through the loader repeatedly without using the
# batches, so the progress bar shows how fast batches are produced.
# 1000 passes is a long run; interrupt the cell once the rate is clear.
for _ in range(1000):
    for d in tqdm.tqdm(loader):
        pass

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.64it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 12.01it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.23it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.68it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████

KeyboardInterrupt: 