In [1]:
import os
from copy import deepcopy

import numpy as np
import pandas as pd
import networkx as nx
import torch
import dgl
from tqdm import tqdm

from qtaim_embed.utils.grapher import get_grapher
from qtaim_embed.data.molwrapper import mol_wrappers_from_df
from qtaim_embed.utils.tests import get_data
from qtaim_embed.core.dataset import HeteroGraphNodeLabelDataset

In [2]:
# Location of the pickled QM8 molecule dataframe. Set the QM8_MOLECULES_PKL
# environment variable to point at another copy; when it is unset the original
# hard-coded path is used, so existing setups keep working unchanged.
file = os.environ.get(
    "QM8_MOLECULES_PKL",
    "/home/santiagovargas/dev/qtaim_embed/data/qm8/molecules_full.pkl",
)
# NOTE(review): unpickling can execute arbitrary code -- only load trusted files.
df = pd.read_pickle(file)

In [3]:
# First ten rows of the dataframe, kept as a small sample.
# NOTE(review): the graph-building cell below is fed the full `df`, not this subset.
df_subset = df.head(10)

In [4]:
# Dataframe columns handed to the molecule wrappers and the grapher as
# per-atom and per-bond quantities.
atom_keys = [
    "extra_feat_atom_esp_total",
]
bond_keys = [
    "extra_feat_bond_esp_total",
    "extra_feat_bond_esp_nuc",
    "bond_length",
]

# Wrap every row of the full dataframe; also yields the set of elements seen.
mol_wrappers, element_set = mol_wrappers_from_df(df, atom_keys, bond_keys)

grapher = get_grapher(
    element_set,
    atom_keys=atom_keys,
    bond_keys=bond_keys,
    global_keys=[],
    allowed_ring_size=[3, 4, 5, 6, 7],
    allowed_charges=None,
    self_loop=True,
)

graph_list = []
print("... Building graphs and featurizing")
# `tqdm` comes from the notebook's import cell.
for mol in tqdm(mol_wrappers):
    graph = grapher.build_graph(mol)
    # `names` is rebound on every pass; the value from the last molecule is
    # what later cells pass to the dataset.
    graph, names = grapher.featurize(graph, mol, ret_feat_names=True)
    graph_list.append(graph)

# `names` only exists if the loop ran at least once; fail here with a clear
# message instead of a NameError in a later cell.
assert graph_list, "no molecules were featurized; check the dataframe and the selected keys"

100%|██████████| 21786/21786 [00:02<00:00, 10442.07it/s]


element set {'O', 'N', 'C', 'H', 'F'}
selected keys ['extra_feat_atom_esp_total']


In [7]:
# Sanity check on the third graph: rows of the atom feature matrix vs. the
# number of atom nodes (one value per line).
print(
    graph_list[2].ndata["feat"]["atom"].shape[0],
    graph_list[2].num_nodes("atom"),
    sep="\n",
)

18
18


In [8]:
# TODO: build dataloader class

# Extra metadata stored alongside the dataset.
extra_info = {
    "allowed_ring_size": [3, 4, 5, 6, 7],
    "element_set": element_set,
}

# Feature names, per node type, that the dataset should treat as labels.
target_dict = dict(
    atom=["extra_feat_atom_esp_total"],
    bond=["extra_feat_bond_esp_total"],
)

# Hand the dataset a deep copy so `graph_list` itself is left untouched.
graph_list_temp = deepcopy(graph_list)

train_dataset = HeteroGraphNodeLabelDataset(
    mol_wrappers,
    graph_list_temp,
    names,
    extra_dataset_info=extra_info,
    target_dict=target_dict,
)

included in labels
{'atom': ['extra_feat_atom_esp_total'], 'bond': ['extra_feat_bond_esp_total'], 'global': []}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'ring_size_7', 'chemical_symbol_O', 'chemical_symbol_N', 'chemical_symbol_C', 'chemical_symbol_H', 'chemical_symbol_F'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'ring size_7', 'bond_length', 'extra_feat_bond_esp_nuc'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
... > loaded dataset


In [9]:
# Grab the first item of the dataset for inspection.
graph = train_dataset[0]

In [10]:
# Display the input feature tensors of `graph`, keyed by node type.
graph.ndata["feat"]

{'atom': tensor([[4., 3., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
         [3., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
         [1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
         [3., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
         [2., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
         [2., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
         [3., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
         [1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
         [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
         [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
         [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
         [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]]),
 'bond': tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  1.5053,
          18.3283],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.000

In [11]:
# Display the label tensors attached to `graph`, keyed by node type.
graph.ndata["labels"]

{'atom': tensor([[5.5136e+06],
         [1.3306e+06],
         [1.4497e+06],
         [3.2892e+01],
         [1.4663e+06],
         [1.4663e+06],
         [1.4497e+06],
         [1.3306e+06],
         [6.1929e+06],
         [1.0323e+06],
         [5.5136e+06],
         [2.6634e+01],
         [3.2892e+01]]),
 'bond': tensor([[0.7334],
         [0.8547],
         [0.8573],
         [0.8570],
         [2.3834],
         [0.7249],
         [1.8967],
         [1.7703],
         [1.2688],
         [1.6803],
         [2.4584],
         [0.8952]]),
 'global': tensor([], size=(1, 0))}

In [12]:
from torch.utils.data import DataLoader
import itertools


class DataLoaderMoleculeNodeTask(DataLoader):
    """DataLoader that batches DGL molecule graphs for node-level tasks.

    Each batch is a tuple ``(batched_graph, labels)``: the sampled dataset
    items merged with ``dgl.batch``, and the ``"labels"`` node data of that
    merged graph.

    Args:
        dataset: Dataset whose items are DGL graphs carrying
            ``ndata["labels"]``.
        **kwargs: Forwarded unchanged to ``torch.utils.data.DataLoader``
            (for example ``batch_size`` or ``shuffle``). ``collate_fn`` is
            not accepted because this class supplies its own.

    Raises:
        ValueError: If ``collate_fn`` is passed in ``kwargs``.
    """

    def __init__(self, dataset, **kwargs):
        if "collate_fn" in kwargs:
            raise ValueError(
                "'collate_fn' is provided internally by "
                "'DataLoaderMoleculeNodeTask', you need not to provide one"
            )

        super().__init__(dataset, collate_fn=self._collate, **kwargs)

    @staticmethod
    def _collate(samples):
        """Merge a list of graphs into one batched graph and return it with its labels."""
        batched_graphs = dgl.batch(samples)
        # The labels are read from the batched graph, so they follow the same
        # node ordering as the batch itself.
        batched_labels = batched_graphs.ndata["labels"]
        return batched_graphs, batched_labels

In [13]:
# Loader over the training set: ten graphs per batch, shuffling enabled.
dataloader = DataLoaderMoleculeNodeTask(
    train_dataset,
    batch_size=10,
    shuffle=True,
)

In [14]:
# Pull a single batch from a fresh iterator over the loader.
batch_iterator = iter(dataloader)
batch_graph, batch_label = next(batch_iterator)

In [29]:
# Shapes of the feature tensors served for `graph`, one node type per line ...
for node_type in ("atom", "bond", "global"):
    print(graph.ndata["feat"][node_type].shape)

# ... followed by how many feature names each featurizer reports.
for featurizer in (
    grapher.atom_featurizer,
    grapher.bond_featurizer,
    grapher.global_featurizer,
):
    print(len(featurizer._feature_name))

torch.Size([13, 13])
torch.Size([12, 9])
torch.Size([1, 3])
14
10
3


In [32]:
import dgl.nn.pytorch as dglnn

# Size the convolutions from the feature tensors the dataset actually serves
# rather than from the featurizers' name lists minus a hand-written offset:
# the name lists still count columns that the dataset moved into the labels,
# so a fixed "- 1" only holds while exactly one target per node type is chosen.
sample_feats = train_dataset[0].ndata["feat"]
atom_input_size = sample_feats["atom"].shape[1]
bond_input_size = sample_feats["bond"].shape[1]
global_input_size = sample_feats["global"].shape[1]

# Feature width per node type, keyed by the letter used in the edge-type names.
feat_size = {"a": atom_input_size, "b": bond_input_size, "g": global_input_size}

# Edge types in their original order; each name reads "<source>2<destination>".
edge_types = ["a2b", "b2a", "a2g", "g2a", "b2g", "g2b", "a2a", "b2b", "g2g"]

# One GraphConv per edge type, mapping the source width to the destination
# width; results arriving at the same node type are summed.
conv = dglnn.HeteroGraphConv(
    {
        etype: dglnn.GraphConv(
            in_feats=feat_size[etype[0]], out_feats=feat_size[etype[-1]]
        )
        for etype in edge_types
    },
    aggregate="sum",
)