In [1]:
import pandas as pd
import dgl
import networkx as nx
import numpy as np
import torch

from qtaim_embed.core.molwrapper import MoleculeWrapper
from qtaim_embed.utils.descriptors import get_atom_feats, get_bond_features
from qtaim_embed.data.grapher import HeteroCompleteGraphFromMolWrapper
from qtaim_embed.data.featurizer import (
    BondAsNodeGraphFeaturizerGeneral,
    AtomFeaturizerGraphGeneral,
    GlobalFeaturizerGraph,
)

In [2]:
# NOTE(review): absolute, machine-specific path — this cell will not run on
# another machine as-is; consider a configurable data directory.
file = "/home/santiagovargas/dev/qtaim_embed/data/qm8/molecules_full.pkl"
# pd.read_pickle unpickles arbitrary Python objects: only load files from a
# trusted source. Later cells read the columns `bonds`, `ids`, `names`,
# `molecule_graph` and `molecule` from each row of this frame.
df = pd.read_pickle(file)

In [3]:
def clean(input):
    """Return ``input`` with every digit character removed.

    Used to turn a formula token such as ``"C12"`` into its element symbol
    ``"C"``. Characters are kept in their original order.
    """
    non_digits = filter(lambda ch: not ch.isdigit(), input)
    return "".join(non_digits)


# Per-atom and per-bond QTAIM descriptors pulled from each dataframe row.
atom_keys = [
    "extra_feat_atom_esp_total",
]

bond_keys = [
    "extra_feat_bond_esp_total",
]

# Sentinel compared against the feature-extraction results below.
# NOTE(review): assumes get_atom_feats / get_bond_features both return -1 on
# failure, as the original bond check and comment suggest — confirm in
# qtaim_embed.utils.descriptors.
MISSING_FEATS = -1

element_set = set()
mol_wrappers = []
for index, row in df.iterrows():
    # NOTE(review): charge, free_energy, bonds and id_combined are computed but
    # never passed to MoleculeWrapper below (it receives id=None and
    # free_energy=None); kept so the kernel namespace is unchanged.
    charge = 0
    free_energy = 0
    bonds = row.bonds
    id_combined = str(row.ids) + "_" + row.names
    bonds = {tuple(sorted(b)): None for b in bonds}
    global_features = {}
    atom_feats = get_atom_feats(row, atom_keys)

    bond_feats = get_bond_features(
        row,
        "extra_feat_bond_indices_qtaim",
        keys=bond_keys,
    )
    mol_graph = row.molecule_graph
    pmg_mol = row.molecule

    # Collect element symbols from every row (including rows filtered out
    # below) by stripping the counts from formula tokens, e.g. "C12" -> "C".
    formula = pmg_mol.composition.formula.split()
    elements = [clean(x) for x in formula]
    element_set.update(elements)

    # Skip molecules whose atom or bond features could not be extracted.
    # The original compared atom_feats against 1 instead of -1, so the atom
    # check could never reject anything.
    if bond_feats != MISSING_FEATS and atom_feats != MISSING_FEATS:
        mol_wrapper = MoleculeWrapper(
            mol_graph,
            functional_group=None,
            free_energy=None,
            id=None,
            non_metal_bonds=None,
            atom_features=atom_feats,
            bond_features=bond_feats,
            global_features=global_features,
            original_atom_ind=None,
            original_bond_mapping=None,
        )
        mol_wrappers.append(mol_wrapper)

In [4]:
# Featurizers for the three node types of the heterograph (atom, bond, global).
# The selected keys repeat the descriptor names extracted into the
# MoleculeWrapper objects in the previous cell.
atom_featurizer = AtomFeaturizerGraphGeneral(
    selected_keys=["extra_feat_atom_esp_total"],
    element_set=element_set,
)
bond_featurizer = BondAsNodeGraphFeaturizerGeneral(
    selected_keys=["extra_feat_bond_esp_total"],
    allowed_ring_size=[3, 4, 5, 6, 7],
)
global_featurizer = GlobalFeaturizerGraph()

grapher = HeteroCompleteGraphFromMolWrapper(
    atom_featurizer=atom_featurizer,
    bond_featurizer=bond_featurizer,
    global_featurizer=global_featurizer,
    self_loop=True,
)

# Build and featurize one graph per molecule wrapper.
# `names` is overwritten on every iteration, so after the loop it holds the
# feature names returned for the LAST molecule only; later cells use it as a
# {node type: [feature name, ...]} mapping.
# NOTE(review): presumably identical for every molecule — confirm.
graph_list = []
for mol in mol_wrappers:
    graph = grapher.build_graph(mol)
    graph, names = grapher.featurize(graph, mol, ret_feat_names=True)
    graph_list.append(graph)

In [5]:
# TODO: build dataloader class


class Dataset(torch.utils.data.Dataset):
    """Dataset that splits per-node-type graph features into inputs and labels.

    On construction every graph in ``graphs`` is modified IN PLACE: for each
    node type listed in ``feature_names`` the target columns are removed from
    ``graph.ndata["feat"]`` and stored separately as labels. Pass copies of
    the graphs if the originals must stay intact.
    """

    def __init__(self, molecule_wrappers, graphs, feature_names, target_dict):
        """
        Args:
            molecule_wrappers (list): list of MoleculeWrapper objects
            graphs (list): list of dgl graphs, one per molecule wrapper
            feature_names (dict): node type -> list of feature names, in the
                column order of ``graph.ndata["feat"][node_type]``
            target_dict (dict): dict with node type as keys and target names as value
        """
        self.data = molecule_wrappers
        self.feature_names = feature_names
        self.graphs = graphs
        self.target_dict = target_dict

        self.load()

    def __len__(self):
        """Number of molecule wrappers held by the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(graph, labels)`` where labels is a dict keyed by node type."""
        return self.graphs[idx], self.labels[idx]

    def get_include_exclude_indices(self):
        """Partition feature columns into label columns and input columns.

        Sets ``include_locs`` / ``include_names`` (columns that become labels)
        and ``exclude_locs`` / ``exclude_names`` (columns that stay on the
        graph), each a dict keyed by node type, and prints all four.

        Raises:
            KeyError: a node type in ``target_dict`` is absent from ``feature_names``.
            ValueError: a target name is not among that node type's feature names.
        """
        # Use the names handed to this instance rather than a notebook-level
        # global, so the class also works outside the kernel that defined it.
        names = self.feature_names

        # Column index of each requested target (first match if a name repeats).
        target_locs = {
            node_type: [names[node_type].index(value) for value in value_list]
            for node_type, value_list in self.target_dict.items()
        }

        include_locs = {}
        exclude_locs = {}
        include_names = {}
        exclude_names = {}

        for node_type, value_list in names.items():
            wanted = set(target_locs.get(node_type, ()))
            include_locs[node_type] = []
            exclude_locs[node_type] = []
            include_names[node_type] = []
            exclude_names[node_type] = []

            # Walk columns in feature order so labels keep the feature ordering.
            for i, value in enumerate(value_list):
                if i in wanted:
                    include_locs[node_type].append(i)
                    include_names[node_type].append(value)
                else:
                    exclude_locs[node_type].append(i)
                    exclude_names[node_type].append(value)

        self.include_locs = include_locs
        self.exclude_locs = exclude_locs
        self.include_names = include_names
        self.exclude_names = exclude_names
        print("included in labels")
        print(self.include_locs)
        print(self.include_names)
        print("included in graph features")
        print(self.exclude_locs)
        print(self.exclude_names)

    def load(self):
        """Strip target columns from every graph and collect them in ``self.labels``."""
        self.get_include_exclude_indices()

        label_list = []

        for graph in self.graphs:
            labels = {}
            features_new = {}
            for key, value in graph.ndata["feat"].items():
                if key not in self.include_names:
                    # Node type without registered feature names: leave untouched.
                    continue
                features_new[key] = value[:, self.exclude_locs[key]]
                labels[key] = value[:, self.include_locs[key]]

            # Write the reduced features back once, after iteration has
            # finished, instead of reassigning the mapping being iterated.
            if features_new:
                graph.ndata["feat"] = features_new
            label_list.append(labels)

        self.labels = label_list


# Node type -> feature names to pull out of the graph features and use as
# labels; each name must appear in names[node_type].
target_dict = {
    "atom": ["extra_feat_atom_esp_total"],
    "bond": ["extra_feat_bond_esp_total"],
    "global": ["molecule weight"],
}
# NOTE(review): import placed mid-notebook; consider moving it to the imports cell.
from copy import deepcopy

# Dataset.load() overwrites graph.ndata["feat"] on the graphs it is given, so
# hand it a deep copy and keep graph_list intact for re-running this cell.
graph_list_temp = deepcopy(graph_list)
# `names` is the feature-name mapping left over from the featurize loop above.
train_dataset = Dataset(mol_wrappers, graph_list_temp, names, target_dict=target_dict)

included in labels
{'atom': [3], 'bond': [7], 'global': [2]}
{'atom': ['extra_feat_atom_esp_total'], 'bond': ['extra_feat_bond_esp_total'], 'global': ['molecule weight']}
included in graph features
{'atom': [0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], 'bond': [0, 1, 2, 3, 4, 5, 6], 'global': [0, 1]}
{'atom': ['total degree', 'is in ring', 'total H', 'chemical symbol', 'chemical symbol', 'chemical symbol', 'chemical symbol', 'chemical symbol', 'ring size', 'ring size', 'ring size', 'ring size', 'ring size'], 'bond': ['metal bond', 'ring inclusion', 'ring size', 'ring size', 'ring size', 'ring size', 'ring size'], 'global': ['num atoms', 'num bonds']}


In [13]:
# Print the label tensor shape for each node type of sample 2. The list
# comprehension is used only for its print side effect, which is why the cell
# result below is a list of None values.
[print(v.shape) for k, v in train_dataset.labels[2].items()]

torch.Size([18, 1])
torch.Size([18, 1])
torch.Size([1, 1])


[None, None, None]

In [14]:
# Shape of the atom feature matrix left on graph 2 after Dataset.load() moved
# the target column into the labels.
train_dataset.graphs[2].ndata["feat"]["atom"].shape

torch.Size([18, 13])

In [21]:
# Inspect one sample through the public indexing protocol rather than calling
# the __getitem__ dunder directly; element 0 is the graph, element 1 the labels.
print(type(train_dataset[2][0]))
print(type(train_dataset[2][1]))
print(train_dataset[2][0].ndata["feat"]["atom"].shape)
print(train_dataset[2][1])

<class 'dgl.heterograph.DGLGraph'>
<class 'dict'>
torch.Size([18, 13])
{'atom': tensor([[1.4713e+06],
        [3.5727e+01],
        [3.4055e+01],
        [3.5410e+01],
        [3.4057e+01],
        [6.2590e+06],
        [3.5420e+01],
        [2.0817e+06],
        [2.0817e+06],
        [3.2031e+06],
        [3.3250e+01],
        [3.5410e+01],
        [3.5727e+01],
        [3.5420e+01],
        [3.4055e+01],
        [3.3250e+01],
        [3.3248e+01],
        [3.4057e+01]]), 'bond': tensor([[1.4942],
        [0.8365],
        [0.8364],
        [0.8456],
        [1.7912],
        [2.3291],
        [0.7243],
        [0.6597],
        [0.7020],
        [0.7020],
        [0.8158],
        [0.8209],
        [0.8155],
        [0.7307],
        [0.8452],
        [0.8417],
        [0.8419],
        [0.8451]]), 'global': tensor([[114.1440]])}
