In [1]:
# Enable IPython's autoreload extension in mode 2, so that imported modules are
# re-imported automatically before each cell executes and edits made to their
# source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import glob
import json
import os

import numpy as np
from torch.utils.data import Dataset, DataLoader

In [3]:
import dfs_code
from torch_geometric.data import InMemoryDataset, Data
import pickle
import torch

In [4]:
# Directory holding the CHEMBL*.pkl files. The location can be overridden with
# the CHEMBL100_NOH_DIR environment variable so the notebook is not tied to one
# machine; the original hard-coded location is kept as the default.
# os.path.join(..., "") guarantees a trailing separator, which cells that build
# patterns via plain string concatenation (path + "CHEMBL*.pkl") rely on.
path = os.path.join(
    os.environ.get(
        "CHEMBL100_NOH_DIR",
        "/home/chrisw/Documents/projects/2021/graph-transformer/datasets/ChEMBL100_noH/",
    ),
    "",
)

In [5]:
# All pickled molecule files in the dataset directory. glob returns matches in
# a file-system dependent order, so sort for a reproducible listing;
# os.path.join works whether or not `path` ends with a separator.
fnames = sorted(glob.glob(os.path.join(path, "CHEMBL*.pkl")))

In [6]:
class ChEMBL100NoH(Dataset):
    """ChEMBL dataset of molecules and minimal DFS codes.

    Every sample is read lazily from one ``CHEMBL*.pkl`` file found in ``path``
    and returned as a ``Data`` object.

    Args:
        path: Directory containing the ``CHEMBL*.pkl`` files. A trailing
            separator is optional.
        transform: Optional callable applied to each ``Data`` object before
            ``__getitem__`` returns it. ``None`` (default) disables it.
        max_files: Maximum number of files to use. Defaults to 10, the cap
            that was previously hard-coded; pass ``None`` to use every
            matching file.
    """
    # TODO: create data structure that says which id is in which file...
    def __init__(self, path, transform=None, max_files=10):
        self.path = path
        # Previously accepted but dropped; store it so __getitem__ can apply it.
        self.transform = transform
        # glob order depends on the file system: sort so that sample indices
        # (and the subset selected by max_files) are reproducible.
        matches = sorted(glob.glob(os.path.join(path, "CHEMBL*.pkl")))
        self.fnames = matches if max_files is None else matches[:max_files]

    def __len__(self):
        """Return the number of files (one sample per file)."""
        return len(self.fnames)

    def __getitem__(self, idx):
        """Load the pickle at position ``idx`` and wrap its fields in a ``Data`` object.

        The pickled object is indexed with the keys ``x``, ``z``, ``edge_attr``,
        ``edge_index``, ``name``, ``min_dfs_code``, ``min_dfs_index`` and
        ``smiles``; the index fields are converted to ``torch.long`` tensors.
        """
        # NOTE(security): pickle.load can execute arbitrary code. Only point
        # this dataset at files produced by a trusted pipeline.
        with open(self.fnames[idx], 'rb') as f:
            d = pickle.load(f)
        data = Data(x=torch.tensor(d['x']),
                    z=torch.tensor(d['z']),
                    edge_attr=torch.tensor(d['edge_attr']),
                    edge_index=torch.tensor(d['edge_index'], dtype=torch.long),
                    name=d['name'],
                    min_dfs_code=torch.tensor(d['min_dfs_code']),
                    min_dfs_index=torch.tensor(d['min_dfs_index'], dtype=torch.long),
                    smiles=d['smiles'])
        if self.transform is not None:
            data = self.transform(data)
        return data

In [13]:
def collate_fn(dlist):
    """Collate samples into five parallel Python lists (no padding or stacking).

    For every sample a random DFS code is drawn through
    ``dfs_code.rnd_dfs_code_from_torch_geometric``; the index it returns
    alongside the code is not used here.

    Args:
        dlist: Iterable of samples exposing ``x``, ``z``, ``edge_attr`` and
            ``min_dfs_code`` attributes.

    Returns:
        Tuple ``(rnd_codes, xs, zs, edge_attrs, min_codes)`` with one entry
        per sample in each list, in input order.
    """
    rnd_codes, xs, zs, edge_attrs, min_codes = [], [], [], [], []
    for sample in dlist:
        z_values = sample.z.numpy().tolist()
        # Per-edge column index of the largest edge_attr entry.
        edge_argmax = np.argmax(sample.edge_attr.numpy(), axis=1)
        code, _ = dfs_code.rnd_dfs_code_from_torch_geometric(sample, z_values, edge_argmax)
        xs.append(sample.x)
        zs.append(sample.z)
        edge_attrs.append(sample.edge_attr)
        rnd_codes.append(torch.tensor(code))
        min_codes.append(sample.min_dfs_code)
    return rnd_codes, xs, zs, edge_attrs, min_codes

In [14]:
def collate_fn_(dlist):
    """Collate samples into a list of new ``Data`` objects with a random DFS code.

    Each output object copies the eight stored fields of the input sample and
    adds ``rnd_dfs_code`` and ``rnd_dfs_index`` (the latter as ``torch.long``),
    both obtained from ``dfs_code.rnd_dfs_code_from_torch_geometric``.

    Args:
        dlist: Iterable of samples supporting attribute access (``z``,
            ``edge_attr``) and key access for the copied fields.

    Returns:
        List with one ``Data`` object per input sample, in input order.
    """
    copied_keys = ('x', 'z', 'edge_attr', 'edge_index',
                   'name', 'min_dfs_code', 'min_dfs_index', 'smiles')
    augmented = []
    for sample in dlist:
        z_values = sample.z.numpy().tolist()
        # Per-edge column index of the largest edge_attr entry.
        edge_argmax = np.argmax(sample.edge_attr.numpy(), axis=1)
        code, index = dfs_code.rnd_dfs_code_from_torch_geometric(sample, z_values, edge_argmax)
        fields = {key: sample[key] for key in copied_keys}
        augmented.append(Data(**fields,
                              rnd_dfs_code=torch.tensor(code),
                              rnd_dfs_index=torch.tensor(index, dtype=torch.long)))
    return augmented

In [8]:
# Build the dataset over the CHEMBL*.pkl files located under `path`.
dataset = ChEMBL100NoH(path=path)

In [9]:
# Use collate_fn_, which returns a list of Data objects carrying `name` and
# `rnd_dfs_code`. The loop that consumes this loader reads data[0].name and
# data[0].rnd_dfs_code; collate_fn returns a tuple of plain lists, so with it
# that loop raises AttributeError on a fresh Restart & Run All.
loader = DataLoader(dataset, batch_size=1, shuffle=True, pin_memory=True, num_workers=4, collate_fn=collate_fn_)

In [11]:
# Two passes over the loader: print every batch, and for the first batch of
# each pass additionally print the name and random DFS code of its first element.
for epoch in range(2):
    for idx, data in enumerate(loader):
        print(data)
        if idx != 0:
            continue
        print(data[0].name)
        print(data[0].rnd_dfs_code)



[Data(edge_attr=[58, 4], edge_index=[2, 58], min_dfs_code=[29, 8], min_dfs_index=[25], name="CHEMBL3310257", rnd_dfs_code=[29, 8], rnd_dfs_index=[25], smiles="O=C(CN1CCC[C@@H]2Cc3cc4c(cc3[C@H]21)OCO4)N1CCCCC1", x=[25, 39], z=[25])]
CHEMBL3310257
tensor([[ 0,  1,  8,  1,  6,  0,  0,  1],
        [ 1,  2,  6,  0,  6,  1,  2,  2],
        [ 2,  3,  6,  0,  7,  2,  5,  3],
        [ 3,  4,  7,  0,  6,  3,  8, 15],
        [ 4,  5,  6,  0,  6, 15, 38, 14],
        [ 5,  6,  6,  2,  6, 14, 34, 13],
        [ 6,  7,  6,  2,  6, 13, 31, 12],
        [ 7,  8,  6,  2,  6, 12, 28, 11],
        [ 8,  9,  6,  2,  6, 11, 25, 10],
        [ 9, 10,  6,  2,  6, 10, 23,  9],
        [10,  5,  6,  2,  6,  9, 22, 14],
        [10, 11,  6,  0,  6,  9, 20,  8],
        [11, 12,  6,  0,  6,  8, 18,  7],
        [12,  4,  6,  0,  6,  7, 17, 15],
        [12, 13,  6,  0,  6,  7, 15,  6],
        [13, 14,  6,  0,  6,  6, 13,  5],
        [ 3, 15,  7,  0,  6,  3,  7,  4],
        [15, 14,  6,  0,  6,  4, 10,  5]



[Data(edge_attr=[38, 4], edge_index=[2, 38], min_dfs_code=[19, 8], min_dfs_index=[19], name="CHEMBL3683979", rnd_dfs_code=[19, 8], rnd_dfs_index=[19], smiles="O=C(CCCCCOc1ccccc1O)C(F)(F)F", x=[19, 39], z=[19])]
CHEMBL3683979
tensor([[ 0,  1,  6,  0,  8,  6, 13,  7],
        [ 1,  2,  8,  0,  6,  7, 15,  8],
        [ 2,  3,  6,  2,  6,  8, 17,  9],
        [ 3,  4,  6,  2,  6,  9, 20, 10],
        [ 4,  5,  6,  2,  6, 10, 22, 11],
        [ 5,  6,  6,  2,  6, 11, 24, 12],
        [ 2,  7,  6,  2,  6,  8, 18, 13],
        [ 7,  6,  6,  2,  6, 13, 28, 12],
        [ 7,  8,  6,  0,  8, 13, 29, 14],
        [ 0,  9,  6,  0,  6,  6, 12,  5],
        [ 9, 10,  6,  0,  6,  5, 10,  4],
        [10, 11,  6,  0,  6,  4,  8,  3],
        [11, 12,  6,  0,  6,  3,  6,  2],
        [12, 13,  6,  0,  6,  2,  4,  1],
        [13, 14,  6,  0,  6,  1,  3, 15],
        [14, 15,  6,  0,  9, 15, 32, 16],
        [14, 16,  6,  0,  9, 15, 34, 18],
        [14, 17,  6,  0,  9, 15, 33, 17],
        [13, 18,  6

