In [1]:
import dgl
import torch
import pandas as pd
from dgllife.data import MoleculeCSVDataset
from dgllife.utils import SMILESToBigraph
from dgllife.utils import CanonicalAtomFeaturizer

In [2]:
def load_dataset(df, smi_title: str, label_title: str, out_graph_path: str,
                 load: bool = True, n_jobs: int = 1):
    """Build a single-task ``MoleculeCSVDataset`` of DGL graphs from a DataFrame.

    Molecules are converted with ``SMILESToBigraph`` using self-loops,
    ``CanonicalAtomFeaturizer`` node features and no edge features.

    Parameters
    ----------
    df
        Table forwarded to ``MoleculeCSVDataset`` as ``df``.
    smi_title : str
        Name of the column holding the SMILES strings.
    label_title : str
        Name of the label column; it is used as the only task.
    out_graph_path : str
        Path of the graph cache file (``cache_file_path``).
    load : bool, default True
        Forwarded to ``MoleculeCSVDataset(load=...)``. With the default, a
        cache already present at ``out_graph_path`` is reused, so pass
        ``load=False`` after changing ``df`` or the featurization to avoid
        training on stale graphs.
    n_jobs : int, default 1
        Forwarded to ``MoleculeCSVDataset(n_jobs=...)``.

    Returns
    -------
    MoleculeCSVDataset
        The constructed dataset.
    """
    smiles_to_g = SMILESToBigraph(add_self_loop=True,
                                  node_featurizer=CanonicalAtomFeaturizer(),
                                  edge_featurizer=None)
    return MoleculeCSVDataset(df=df,
                              smiles_to_graph=smiles_to_g,
                              smiles_column=smi_title,
                              cache_file_path=out_graph_path,
                              task_names=[label_title],
                              load=load,
                              n_jobs=n_jobs)

In [3]:
# Read the curated table, then turn it into a graph dataset.
# SMILES come from 'Flatten_SMILES'; the single task label is 'target_binary'.
df = pd.read_csv('data/caco2ab_curated1805_trtecv.csv')
dataset = load_dataset(
    df,
    smi_title='Flatten_SMILES',
    label_title='target_binary',
    out_graph_path='data/caco2ab_curated1805_graphs.bin',
)

Loading previously saved dgl graphs...


In [4]:
# Show the dataset object itself (only its default repr is displayed).
dataset

<dgllife.data.csv_dataset.MoleculeCSVDataset at 0x7fa13942aca0>

In [5]:
from torch.utils.data import DataLoader

In [6]:
# Number of items the dataset reports.
len(dataset)

1805

In [7]:
# Inspect the first item. As the output shows, an item is a 4-tuple
# (SMILES string, graph, label tensor, mask tensor), so `g` holds the whole
# sample rather than only the graph.
g = dataset[0]
g

('COCCOc1cc(NC(=O)C2(NC(=O)c3ccc4c(C5CCCC5)c(-c5ncc(Cl)cn5)n(C)c4c3)CCC2)ccc1C=CC(=O)O',
 Graph(num_nodes=48, num_edges=154,
       ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
       edata_schemes={}),
 tensor([1.]),
 tensor([1.]))

In [8]:
def collate_molgraphs(data):
    """Merge a list of dataset samples into one mini-batch.

    Each sample in ``data`` is unpacked as ``(smiles, graph, label, mask)``.

    Returns
    -------
    tuple
        ``(smiles, bg, labels, masks)`` where ``smiles`` is a list, ``bg`` is
        the result of ``dgl.batch`` over the sample graphs, ``labels`` is the
        per-sample labels stacked along a new leading dimension, and ``masks``
        is the list of per-sample masks, left unstacked.
    """
    # Transpose the samples into four per-field lists.
    smiles, mol_graphs, label_list, masks = (list(field) for field in zip(*data))

    # Merge the molecule graphs and set zero initializers for node and edge
    # features on the batched graph.
    batched = dgl.batch(mol_graphs)
    zero_init = dgl.init.zero_initializer
    batched.set_n_initializer(zero_init)
    batched.set_e_initializer(zero_init)

    # NOTE(review): labels are stacked into one tensor while masks stay a
    # plain list of tensors; confirm downstream code expects that asymmetry.
    return smiles, batched, torch.stack(label_list, dim=0), masks

In [9]:
# Mini-batch loader over the molecule dataset, collated by collate_molgraphs.
# With drop_last=True the final incomplete batch is discarded: 1805 samples at
# batch_size=8 give 225 batches, leaving 5 samples unused each epoch.
data_loader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_molgraphs,
    num_workers=0,
    # Seeded RNG for the shuffle, so a fresh kernel (Restart & Run All)
    # yields the same batch order instead of a different one on every run.
    generator=torch.Generator().manual_seed(42),
)

In [10]:
# Show the loader object itself (only its default repr is displayed).
data_loader

<torch.utils.data.dataloader.DataLoader at 0x7fa1594ebbb0>

In [11]:
# Number of batches per pass over the loader (1805 // 8 with drop_last=True).
len(data_loader)

225

In [12]:
# Pull a single mini-batch for inspection. next(iter(...)) raises
# StopIteration on an empty loader, whereas a for/break loop would silently
# leave `smiles`, `graphs`, `labels` and `masks` undefined.
data = next(iter(data_loader))
smiles, graphs, labels, masks = data

In [13]:
# SMILES strings of the sampled batch (one per molecule).
smiles

['CN1Cc2c(Cl)cc(Cl)cc2C(c2ccc(S(=O)(=O)Nc3ccc(CP(=O)(O)O)cc3)cc2)C1',
 'C=CC(=O)NC1CCN(c2ccc(C(N)=O)c(Nc3ccc(C(=O)N4CCOCC4)cc3)n2)C1',
 'NCCCC(C#Cc1cccc(C(N)=O)c1)CC1OC(n2cnc3c(N)ncnc32)C(O)C1O',
 'Cc1ccc(NS(C)(=O)=O)c(C(=O)N2CCCCC2c2cc3nc(C4CC4)ccn3n2)c1',
 'COc1cc(C(=O)N2CCC(O)C(C)C2)cc2nc(NCc3cncc(Cl)c3)oc12',
 'CN(C)CC1CN(c2ccc(Nc3nc(-c4ccc5cn[nH]c5c4)cn4ccnc34)cc2)CCO1',
 'COc1ccc(-c2cc3c(C)nc(N)nc3n(C3CCC(OCCO)CC3)c2=O)cn1',
 'N#Cc1ccnc(N2C(=O)CCC2CC(=O)N(c2cc(F)cc(F)c2)C(C(=O)NC2CC(F)(F)C2)c2ccccc2Cl)c1']

In [14]:
# Despite the plural name this is one batched graph: collate_molgraphs
# merges the batch's molecule graphs with dgl.batch.
graphs

Graph(num_nodes=273, num_edges=875,
      ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
      edata_schemes={})

In [15]:
# Labels stacked by collate_molgraphs into a single tensor, one row per molecule.
labels

tensor([[0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [1.],
        [1.]])

In [16]:
# Masks are returned by collate_molgraphs as a Python list of per-sample
# tensors (they are not stacked the way the labels are).
masks

[tensor([1.]),
 tensor([1.]),
 tensor([1.]),
 tensor([1.]),
 tensor([1.]),
 tensor([1.]),
 tensor([1.]),
 tensor([1.])]