In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import numpy as np
from torch.utils.data import Dataset, DataLoader
import glob

In [3]:
import dfs_code
from torch_geometric.data import InMemoryDataset, Data
import pickle
import torch
import tqdm
# Use the 'file_system' strategy for tensors shared between processes.
# NOTE(review): presumably required to avoid "too many open files" errors when
# DataLoader workers hand back many small tensors per batch — confirm.
torch.multiprocessing.set_sharing_strategy('file_system') # this is important
# Related shell setting noted by the author (raises the open-file limit): ulimit -n 500000
def set_worker_sharing_strategy(worker_id: int) -> None:
    """DataLoader ``worker_init_fn`` hook that selects the 'file_system' strategy.

    Intended to be run inside each worker process so that the worker uses the
    same tensor-sharing strategy as the one set in the main process.

    Args:
        worker_id: index of the worker being initialised (not used).
    """
    strategy = 'file_system'
    torch.multiprocessing.set_sharing_strategy(strategy)

In [4]:
# Show which tensor-sharing strategies torch.multiprocessing supports on this system.
torch.multiprocessing.get_all_sharing_strategies()

{'file_descriptor', 'file_system'}

In [5]:
# Dataset root directory (machine-specific absolute path — adjust for your setup).
# Keep the trailing slash: ChEMBL100NoH appends "CHEMBL*.pkl" to this string as-is.
path = "/mnt/ssd/datasets/ChEMBL/ChEMBL100_noH/"

In [6]:
class ChEMBL100NoH(Dataset):
    """ChEMBL dataset of molecules and minimal DFS codes.

    Every file matching ``CHEMBL*.pkl`` directly under ``path`` is one dataset
    item. Files are only listed at construction time; each one is unpickled
    lazily when it is indexed.

    Args:
        path: directory prefix that is concatenated with ``"CHEMBL*.pkl"`` to
            build the glob pattern, so it must end with a path separator.
        transform: optional callable applied to every ``Data`` object before
            it is returned from ``__getitem__``; ``None`` disables it.
    """
    # TODO: create data structure that says which id is in which file...
    def __init__(self, path, transform=None):
        self.path = path
        self.transform = transform
        # glob returns files in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.fnames = sorted(glob.glob(path + "CHEMBL*.pkl"))

    def __len__(self):
        """Return the number of pickle files found under ``path``."""
        return len(self.fnames)

    def __getitem__(self, idx):
        """Load item ``idx`` from disk and wrap its fields in a ``Data`` object.

        The pickled dict must provide the keys ``x``, ``z``, ``edge_attr``,
        ``edge_index``, ``name``, ``min_dfs_code``, ``min_dfs_index`` and
        ``smiles``. Index fields are converted to ``torch.long``; the other
        array fields keep whatever dtype ``torch.tensor`` infers.
        """
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # this dataset at files you generated yourself or otherwise trust.
        with open(self.fnames[idx], 'rb') as f:
            d = pickle.load(f)
        data = Data(x=torch.tensor(d['x']),
                    z=torch.tensor(d['z']),
                    edge_attr=torch.tensor(d['edge_attr']),
                    edge_index=torch.tensor(d['edge_index'], dtype=torch.long),
                    name=d['name'],
                    min_dfs_code=torch.tensor(d['min_dfs_code']),
                    min_dfs_index=torch.tensor(d['min_dfs_index'], dtype=torch.long),
                    smiles=d['smiles'])
        # Previously the transform argument was accepted but never applied.
        if self.transform is not None:
            data = self.transform(data)
        return data

In [7]:
def collate_fn(dlist):
    """Collate graphs into parallel per-graph lists, adding a random DFS code.

    For every graph a fresh code is requested from
    ``dfs_code.rnd_dfs_code_from_torch_geometric``; the index it also returns
    is discarded here.

    Args:
        dlist: iterable of graph objects exposing ``x``, ``z``, ``edge_attr``
            and ``min_dfs_code`` tensors.

    Returns:
        Tuple ``(rnd_codes, xs, zs, edge_attrs, min_codes)`` of lists holding
        one entry per graph, in input order. Nothing is padded or stacked.
    """
    node_feats, atomic_nums, bond_feats = [], [], []
    random_codes, minimal_codes = [], []
    for graph in dlist:
        vertex_labels = graph.z.numpy().tolist()
        # One label per edge: position of the largest entry in each edge_attr row
        # (presumably a one-hot bond-type encoding — TODO confirm).
        edge_labels = np.argmax(graph.edge_attr.numpy(), axis=1)
        code, _ = dfs_code.rnd_dfs_code_from_torch_geometric(graph, vertex_labels, edge_labels)
        node_feats.append(graph.x)
        atomic_nums.append(graph.z)
        bond_feats.append(graph.edge_attr)
        random_codes.append(torch.tensor(code))
        minimal_codes.append(graph.min_dfs_code)
    return random_codes, node_feats, atomic_nums, bond_feats, minimal_codes

In [8]:
def collate_fn_(dlist):
    """Collate graphs into a list of ``Data`` objects extended with a random DFS code.

    Each output object carries the input graph's existing fields unchanged plus
    ``rnd_dfs_code`` and ``rnd_dfs_index`` (the latter as ``torch.long``), both
    obtained from ``dfs_code.rnd_dfs_code_from_torch_geometric``.

    Args:
        dlist: iterable of graph objects supporting attribute access for ``z``
            and ``edge_attr`` and key access for the copied fields.

    Returns:
        List with one new ``Data`` object per input graph, in input order.
    """
    copied_keys = ('x', 'z', 'edge_attr', 'edge_index', 'name',
                   'min_dfs_code', 'min_dfs_index', 'smiles')
    collated = []
    for graph in dlist:
        vertex_labels = graph.z.numpy().tolist()
        # One label per edge: position of the largest entry in each edge_attr row
        # (presumably a one-hot bond-type encoding — TODO confirm).
        edge_labels = np.argmax(graph.edge_attr.numpy(), axis=1)
        code, index = dfs_code.rnd_dfs_code_from_torch_geometric(graph, vertex_labels, edge_labels)
        fields = {key: graph[key] for key in copied_keys}
        fields['rnd_dfs_code'] = torch.tensor(code)
        fields['rnd_dfs_index'] = torch.tensor(index, dtype=torch.long)
        collated.append(Data(**fields))
    return collated

In [9]:
# Number of GPUs to use; 0 selects the CPU even when CUDA is available.
ngpu = 1
if torch.cuda.is_available() and ngpu > 0:
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [10]:
# Index every CHEMBL*.pkl file under `path`; files are only opened on item access.
dataset = ChEMBL100NoH(path)

In [11]:
# Batches of 128 graphs are assembled by collate_fn in 4 worker processes; every
# worker sets the 'file_system' sharing strategy when it starts.
loader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=False,  # shuffle False -> huge speedup (observed by the author)
    num_workers=4,
    pin_memory=False,
    collate_fn=collate_fn,
    worker_init_fn=set_worker_sharing_strategy,
)

In [12]:
def to_cuda(T):
    """Lazily call ``.cuda()`` on every element of ``T``.

    Args:
        T: iterable of objects exposing a ``cuda()`` method (e.g. tensors).

    Returns:
        A single-use ``map`` iterator; nothing is moved until it is consumed.
        Wrap it in ``list(...)`` or ``tuple(...)`` to materialise the results.
    """
    # Defined with ``def`` rather than as a lambda bound to a name (PEP 8), so
    # the function has a real __name__ in tracebacks and profiler output.
    return map(lambda t: t.cuda(), T)

In [13]:
%prun next(iter(loader))

 

In [14]:
%prun next(iter(loader))

 

In [15]:
import copy

In [16]:
def f(batches=None):
    """Scan all batches and report the largest value found in the ``z`` tensors.

    Each batch is expected to be laid out like the tuple returned by
    ``collate_fn``, i.e. element 2 is the list of per-graph ``z`` tensors.
    A ``KeyboardInterrupt`` stops the scan early; the maximum seen so far is
    still printed and returned.

    Args:
        batches: iterable of batches to scan. Defaults to the notebook-level
            ``loader`` so that a bare ``f()`` keeps working.

    Returns:
        The largest value seen, or 0 if nothing was scanned. The value is also
        printed.
    """
    if batches is None:
        batches = loader
    atomic_number_max = 0
    try:
        # Wrap the iterable itself (not enumerate(...)) so tqdm can read its
        # length and show a total / ETA; the batch index was never used.
        for data in tqdm.tqdm(batches):
            # Copy the z tensors and drop the batch straight away; presumably so
            # that the tensors shared by the workers can be released promptly
            # — TODO confirm this is still needed.
            z = copy.deepcopy(data[2])
            del data
            z_tensor = torch.cat(z)
            atomic_number_max = max(z_tensor.max().item(), atomic_number_max)
    except KeyboardInterrupt:
        print('interrupt')
    print(atomic_number_max)
    return atomic_number_max
    

In [17]:
%prun f()

14908it [13:58, 17.78it/s]

53
 


