In [51]:
%matplotlib inline

from collections import defaultdict as ddict, OrderedDict as odict
from typing import Any, Dict, List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from xgboost import XGBRegressor
from sklearn.neural_network import MLPRegressor
from rdkit.ML.Descriptors.MoleculeDescriptors import MolecularDescriptorCalculator
from rdkit.Chem import PandasTools, AllChem as Chem, Descriptors
from rdkit.Chem.Descriptors import MolWt
from sklearn.ensemble import RandomForestRegressor
import sklearn
from rdkit.Chem.rdmolops import GetFormalCharge
import torch
import torch.nn as nn
import deepchem as dc
import copy
from sklearn.model_selection import train_test_split
from hyperopt import hp
import imp

# Render DataFrame floats in fixed-point notation with three decimals (no scientific notation).
pd.options.display.float_format = lambda value: '%.3f' % value

In [38]:
from modules.data import data_maker, Dataset
from modules.RNN import double_RNN
from modules.fit import Model, fit
from modules.myhyperopt import hyperopt_func
from modules.MPNN import double_MPNN
from modules.MP_utils import MolGraph, BatchMolGraph

In [29]:
import modules

In [3]:
# Load the pKa table and split it into parallel per-sample lists.
data = pd.read_csv('data/full_pka_data.csv')
solute, solvent, pka = (
    data[column].tolist()
    for column in ('Solute SMILES', 'Solvent SMILES', 'pKa (avg)')
)
data_size = len(solute)  # number of (solute, solvent, pKa) rows

In [4]:
# Stratified split of sample indices into a cross-validation part (CV_ids)
# and a held-out part (holdout_ids). Stratifying on the solvent keeps each
# solvent's share approximately equal in both parts.
HOLDOUT_FRACTION = 0.2
SPLIT_SEED = 1

indices = list(range(data_size))
# Only the index split is needed. The split depends solely on the sample count,
# `stratify` and `random_state`, so passing just `indices` gives the same ids as
# splitting (indices, solvent) together. It also avoids assigning to `_`, which
# would shadow IPython's last-output variable.
CV_ids, holdout_ids = train_test_split(
    indices,
    test_size=HOLDOUT_FRACTION,
    random_state=SPLIT_SEED,
    stratify=solvent,
)
CV_datasets = data_maker(solute, solvent, pka, CV_ids)
# NOTE(review): `datasets` is presumably built from all samples (no ids passed),
# holdout included. Confirm that no model evaluated on holdout_ids is fit on it.
datasets = data_maker(solute, solvent, pka)

In [5]:
# Double (solute + solvent) MPNN wrapped in the project's Model container;
# data_type='SMILES' declares that it consumes SMILES inputs.
MPNN = Model(
    name='MPNN',
    model=double_MPNN(atom_messages=True),
    model_type='torch',
    data_type='SMILES',
)

In [64]:
def _target_value(y):
    '''Return a target as a plain number; accepts Python numbers or objects with ``.item()``.'''
    return y.item() if hasattr(y, 'item') else y


def collate_double(batch):
    '''
    Collates double input batches for a torch loader.

    The type of the first sample's solute decides how the whole batch is
    collated:

    * ``MolGraph``  -> solutes and solvents are each merged into a ``BatchMolGraph``.
    * ``str``       -> solutes and solvents are returned as plain lists of strings.
    * anything else -> each entry is converted with ``torch.Tensor`` and the
      sequences are padded with ``nn.utils.rnn.pad_sequence`` (default
      ``batch_first=False``, i.e. shape ``(max_len, batch, ...)``).

    Parameters
    ----------
    batch: List = [(X,y)]
        Samples where ``X`` is a (solute, solvent) pair and ``y`` its target.
        ``y`` may be a Python number or any object with ``.item()``
        (numpy scalar, 0-d tensor, ...).

    Returns
    -------
    [sol_batch, solv_batch, targets]: List
        ``sol_batch`` / ``solv_batch`` as described above; ``targets`` is a
        1-D float ``torch.Tensor`` with one entry per sample.

    Raises
    ------
    ValueError
        If ``batch`` is empty (there is no first sample to dispatch on).
    '''
    if len(batch) == 0:
        raise ValueError('collate_double received an empty batch')

    solutes = [t[0][0] for t in batch]
    solvents = [t[0][1] for t in batch]
    first = solutes[0]

    # isinstance (not ``type(...) ==``) also accepts subclasses, e.g. numpy.str_
    # values coming out of numpy/pandas arrays, which previously fell through
    # to the tensor branch and crashed.
    if isinstance(first, MolGraph):
        sol_batch = BatchMolGraph(solutes)
        solv_batch = BatchMolGraph(solvents)
    elif isinstance(first, str):
        sol_batch = solutes
        solv_batch = solvents
    else:
        sol_batch = nn.utils.rnn.pad_sequence([torch.Tensor(s) for s in solutes])
        solv_batch = nn.utils.rnn.pad_sequence([torch.Tensor(s) for s in solvents])

    targets = torch.Tensor([_target_value(t[1]) for t in batch])
    return [sol_batch, solv_batch, targets]

def collate_double_s(batch):
    '''
    Collate a batch of graph (solute, solvent) pairs for a torch loader.

    Unlike ``collate_double`` there is no type dispatch: both sides are
    always passed to ``BatchMolGraph``.

    Parameters
    ----------
    batch: List = [(X,y)]
        Samples where ``X`` is a (solute, solvent) pair and ``y`` is a target
        exposing ``.item()``.

    Returns
    -------
    [sol_batch, solv_batch, targets]: List
        Two ``BatchMolGraph`` objects followed by a 1-D float ``torch.Tensor``
        of targets.
    '''
    solutes, solvents, values = [], [], []
    for sample in batch:
        pair, target = sample[0], sample[1]
        solutes.append(pair[0])
        solvents.append(pair[1])
        values.append(target.item())
    return [BatchMolGraph(solutes), BatchMolGraph(solvents), torch.Tensor(values)]

def double_loader(data, indices, batch_size=64, shuffle=False, collate_fn=None):
    '''
    torch loader for double inputs.

    Parameters
    ----------
    data : List = [(sol,solv),pka]
        Two-element sequence: ``data[0]`` holds the (solute, solvent) inputs and
        ``data[1]`` the target values; both are passed to ``Dataset``.
    indices : list, np.array
        Indices for selected samples.
    batch_size : int, default 64
        Size of selected batches.
    shuffle : bool, default False
        Whether the DataLoader reshuffles samples every epoch (e.g. for training).
    collate_fn : callable, optional
        Batch collation function; defaults to ``collate_double``.

    Returns
    -------
    loader : torch.utils.data.DataLoader
        Batched dataloader for torch regressors.
    '''
    if collate_fn is None:
        collate_fn = collate_double
    dataset = Dataset(indices, data[0], data[1])
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn
    )

In [65]:
# Sanity check: the 'graphs' dataset should hold MolGraph objects.
# NOTE: this rebinds the global `data` (previously the raw pKa DataFrame).
data = datasets['graphs']
# isinstance also accepts MolGraph subclasses, unlike ``type(...) ==``.
print(isinstance(data[0][0][0], MolGraph))

True


In [67]:
def one_run(key='graphs', batch_size=64):
    '''
    Iterate once over every batch of ``datasets[key]`` and discard the batches.

    Serves as the target of the cProfile cells below. The dataset key is a
    parameter so that other entries of ``datasets`` (e.g. a SMILES variant,
    if ``data_maker`` provides one) can be profiled without editing this cell.

    Parameters
    ----------
    key : str, default 'graphs'
        Entry of ``datasets`` to load.
    batch_size : int, default 64
        Passed through to ``double_loader``.
    '''
    data = datasets[key]
    ids = list(range(len(data[0])))
    # Unpacking still checks that every batch has (solute, solvent, targets).
    for _sol, _solv, _targets in double_loader(data, ids, batch_size=batch_size):
        pass

one_run()

In [14]:
import cProfile

In [36]:
# Profile the 'graphs' path: one pass of one_run(), sorted by cumulative time.
cProfile.run('one_run()', None, 'cumtime')

         526261 function calls (525290 primitive calls) in 1.447 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    1.447    1.447 {built-in method builtins.exec}
        1    0.000    0.000    1.447    1.447 <string>:1(<module>)
        1    0.059    0.059    1.447    1.447 <ipython-input-34-a9af0a8a3f7f>:1(one_run)
  1020/51    0.012    0.000    0.702    0.014 module.py:1045(_call_impl)
       51    0.001    0.000    0.702    0.014 MPNN.py:154(forward)
       52    0.001    0.000    0.685    0.013 dataloader.py:517(__next__)
       52    0.001    0.000    0.682    0.013 dataloader.py:559(_next_data)
       51    0.000    0.000    0.679    0.013 fetch.py:42(fetch)
      102    0.082    0.001    0.678    0.007 MPNN.py:239(forward)
       51    0.005    0.000    0.675    0.013 <ipython-input-33-53a2bb97cc48>:31(collate_double)
      102    0.581    0.006    0.667    0.007 MP_utils.py:235(__init_

In [18]:
# SMILES profile, sorted by cumulative time.
# NOTE(review): one_run() as defined above iterates datasets['graphs'], so
# re-running this cell profiles the same path as the '#graphs' cell. The output
# below appears to come from an earlier notebook state: it reports different
# collate_double line numbers and 'standard name' ordering. Point one_run at the
# SMILES data to reproduce it.
cProfile.run('one_run()', None, 'cumtime')

         2155539 function calls (2154568 primitive calls) in 3.834 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       51    0.001    0.000    0.003    0.000 <ipython-input-12-2e09b19f250d>:1(collate_double)
       51    0.000    0.000    0.000    0.000 <ipython-input-12-2e09b19f250d>:20(<listcomp>)
       51    0.000    0.000    0.000    0.000 <ipython-input-12-2e09b19f250d>:21(<listcomp>)
       51    0.001    0.000    0.002    0.000 <ipython-input-12-2e09b19f250d>:27(<listcomp>)
        1    0.000    0.000    0.000    0.000 <ipython-input-12-2e09b19f250d>:31(double_loader)
        1    0.060    0.060    3.834    3.834 <ipython-input-17-11f0fa9c3269>:1(one_run)
        1    0.000    0.000    3.834    3.834 <string>:1(<module>)
       51    0.001    0.000    3.760    0.074 MPNN.py:154(forward)
      102    0.082    0.001    3.736    0.037 MPNN.py:239(forward)
     6444    0.906    0.000    2.265    0.000 MP_utils.py:130