In [1]:
import numpy as np
import pandas as pd
from copy import deepcopy
import networkx as nx

import torch
import dgl

from qtaim_embed.utils.grapher import get_grapher
from qtaim_embed.data.molwrapper import mol_wrappers_from_df
from qtaim_embed.utils.tests import get_data
from qtaim_embed.core.dataset import HeteroGraphNodeLabelDataset

In [2]:
file = "/home/santiagovargas/dev/qtaim_embed/data/qm8/molecules_full.pkl"
df = pd.read_pickle(file)

In [3]:
df_subset = df.iloc[:10]

In [4]:
atom_keys = [
    "extra_feat_atom_esp_total",
]
bond_keys = [
    "extra_feat_bond_esp_total",
    "extra_feat_bond_esp_nuc",
    "bond_length",
]

mol_wrappers, element_set = mol_wrappers_from_df(df, atom_keys, bond_keys)

grapher = get_grapher(
    element_set,
    atom_keys=atom_keys,
    bond_keys=bond_keys,
    global_keys=[],
    allowed_ring_size=[3, 4, 5, 6, 7],
    allowed_charges=None,
    self_loop=True,
)

graph_list = []
print("... Building graphs and featurizing")
for mol in tqdm(mol_wrappers):
    graph = grapher.build_graph(mol)
    graph, names = grapher.featurize(graph, mol, ret_feat_names=True)
    graph_list.append(graph)

100%|██████████| 21786/21786 [00:02<00:00, 10442.07it/s]


element set {'O', 'N', 'C', 'H', 'F'}
selected keys ['extra_feat_atom_esp_total']


In [8]:
# TODO: build dataloader class


target_dict = {
    "atom": ["extra_feat_atom_esp_total"],
    "bond": ["extra_feat_bond_esp_total"],
}

graph_list_temp = deepcopy(graph_list)
extra_info = {
    "allowed_ring_size": [3, 4, 5, 6, 7],
    "element_set": element_set,
}
train_dataset = HeteroGraphNodeLabelDataset(
    mol_wrappers,
    graph_list_temp,
    names,
    target_dict=target_dict,
    extra_dataset_info=extra_info,
)

included in labels
{'atom': ['extra_feat_atom_esp_total'], 'bond': ['extra_feat_bond_esp_total'], 'global': []}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'ring_size_7', 'chemical_symbol_O', 'chemical_symbol_N', 'chemical_symbol_C', 'chemical_symbol_H', 'chemical_symbol_F'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'ring size_7', 'bond_length', 'extra_feat_bond_esp_nuc'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
... > loaded dataset


In [12]:
from torch.utils.data import DataLoader
import itertools


class DataLoaderMoleculeNodeTask(DataLoader):
    """ """

    def __init__(self, dataset, **kwargs):
        if "collate_fn" in kwargs:
            raise ValueError(
                "'collate_fn' provided internally by 'bondnet.data', you need not to "
                "provide one"
            )

        def collate(samples):
            graphs = samples

            count_label_atom = 0
            for i in graphs:
                count_label_atom = count_label_atom + i.ndata["labels"]["bond"].shape[0]
            batched_graphs = dgl.batch(graphs)
            batched_labels = batched_graphs.ndata["labels"]
            return batched_graphs, batched_labels

        super(DataLoaderMoleculeNodeTask, self).__init__(
            dataset, collate_fn=collate, **kwargs
        )

In [100]:
dataloader = DataLoaderMoleculeNodeTask(train_dataset, batch_size=20, shuffle=True)

In [101]:
batch_graph, batch_label = next(iter(dataloader))

In [102]:
print(graph.ndata["feat"]["atom"].shape)
print(graph.ndata["feat"]["bond"].shape)
print(graph.ndata["feat"]["global"].shape)
print(len(grapher.atom_featurizer._feature_name))
print(len(grapher.bond_featurizer._feature_name))
print(len(grapher.global_featurizer._feature_name))

torch.Size([13, 13])
torch.Size([12, 9])
torch.Size([1, 3])
14
10
3


In [94]:
import dgl.nn.pytorch as dglnn

atom_input_size = int(len(grapher.atom_featurizer._feature_name)) - 1
bond_input_size = int(len(grapher.bond_featurizer._feature_name)) - 1
global_input_size = 3


class testmodel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = dglnn.HeteroGraphConv(
            {
                "a2b": dglnn.GraphConv(
                    in_feats=atom_input_size,
                    out_feats=bond_input_size,
                ),
                "b2a": dglnn.GraphConv(
                    in_feats=bond_input_size,
                    out_feats=atom_input_size,
                ),
                "a2g": dglnn.GraphConv(
                    in_feats=atom_input_size,
                    out_feats=global_input_size,
                ),
                "g2a": dglnn.GraphConv(
                    in_feats=global_input_size,
                    out_feats=atom_input_size,
                ),
                "b2g": dglnn.GraphConv(
                    in_feats=bond_input_size,
                    out_feats=global_input_size,
                ),
                "g2b": dglnn.GraphConv(
                    in_feats=global_input_size,
                    out_feats=bond_input_size,
                ),
                "a2a": dglnn.GraphConv(
                    in_feats=atom_input_size,
                    out_feats=atom_input_size,
                ),
                "b2b": dglnn.GraphConv(
                    in_feats=bond_input_size,
                    out_feats=bond_input_size,
                ),
                "g2g": dglnn.GraphConv(
                    in_feats=global_input_size,
                    out_feats=global_input_size,
                ),
            },
            aggregate="sum",
        )

        self.conv2 = dglnn.HeteroGraphConv(
            {
                "a2b": dglnn.GraphConv(
                    in_feats=atom_input_size,
                    out_feats=bond_input_size,
                ),
                "b2a": dglnn.GraphConv(
                    in_feats=bond_input_size,
                    out_feats=atom_input_size,
                ),
                "a2g": dglnn.GraphConv(
                    in_feats=atom_input_size,
                    out_feats=global_input_size,
                ),
                "g2a": dglnn.GraphConv(
                    in_feats=global_input_size,
                    out_feats=atom_input_size,
                ),
                "b2g": dglnn.GraphConv(
                    in_feats=bond_input_size,
                    out_feats=global_input_size,
                ),
                "g2b": dglnn.GraphConv(
                    in_feats=global_input_size,
                    out_feats=bond_input_size,
                ),
                "a2a": dglnn.GraphConv(
                    in_feats=atom_input_size,
                    out_feats=atom_input_size,
                ),
                "b2b": dglnn.GraphConv(
                    in_feats=bond_input_size,
                    out_feats=bond_input_size,
                ),
                "g2g": dglnn.GraphConv(
                    in_feats=global_input_size,
                    out_feats=global_input_size,
                ),
            },
            aggregate="sum",
        )

        self.conv3 = dglnn.HeteroGraphConv(
            {
                "b2a": dglnn.GraphConv(in_feats=bond_input_size, out_feats=1),
                "g2a": dglnn.GraphConv(in_feats=global_input_size, out_feats=1),
                "a2a": dglnn.GraphConv(in_feats=atom_input_size, out_feats=1),
                "b2b": dglnn.GraphConv(
                    in_feats=bond_input_size,
                    out_feats=bond_input_size,
                ),
                "g2g": dglnn.GraphConv(
                    in_feats=global_input_size,
                    out_feats=global_input_size,
                ),
                "b2g": dglnn.GraphConv(
                    in_feats=bond_input_size,
                    out_feats=global_input_size,
                ),
                "g2b": dglnn.GraphConv(
                    in_feats=global_input_size,
                    out_feats=bond_input_size,
                ),
                "a2b": dglnn.GraphConv(
                    in_feats=atom_input_size,
                    out_feats=bond_input_size,
                ),
                "a2g": dglnn.GraphConv(
                    in_feats=atom_input_size,
                    out_feats=global_input_size,
                ),
            },
            aggregate="sum",
        )

    def forward(self, graph, inputs):
        feats = self.conv(graph, inputs)
        feats = self.conv2(graph, feats)
        feats = self.conv3(graph, feats)
        return feats


testmodel = testmodel()

In [95]:
forward_out = testmodel(graph, graph.ndata["feat"])

In [139]:
from torch.nn import functional as F
from sklearn.metrics import r2_score

# from tqdm import tqdm
import tqdm.notebook as tq

opt = torch.optim.Adam(testmodel.parameters(), lr=0.01)


for epoch in range(1000):
    with tqdm(dataloader) as tq:
        testmodel.train()
        r2_list = []
        tq.set_description(f"Epoch {epoch+1}")
        training_loss = 0
        for step, (batch_graph, batch_label) in enumerate(tq):
            # forward propagation by using all nodes and extracting the user embeddings
            batch_graph, batch_label = next(iter(dataloader))
            labels = batch_label["atom"]
            logits = testmodel(batch_graph, batch_graph.ndata["feat"])["atom"]

            # compute loss
            loss = F.mse_loss(logits, labels)
            # loss_mae = F.l1_loss(logits, labels)
            # compute r2 score
            r2 = r2_score(logits.detach().numpy(), labels.detach().numpy())
            r2_list.append(r2)
            # Compute validation accuracy.  Omitted in this example.
            # backward propagation
            opt.zero_grad()
            loss.backward()
            opt.step()
            training_loss += loss.item()
            # tq.set_postfix({"Step": step, "MSE": loss.item()})

        r2_mean = np.mean(r2_list)
        tq.set_postfix({"final_t_loss": training_loss, "R_2": r2_mean})
        print(r2_mean)

        # tq.update()
        tq.close()

Epoch 1: 100%|██████████| 1085/1085 [00:13<00:00, 81.05it/s]


-464.0406035106687


Epoch 2: 100%|██████████| 1085/1085 [00:13<00:00, 80.99it/s]


-3549.6973953827423


Epoch 3: 100%|██████████| 1085/1085 [00:13<00:00, 81.32it/s]


-1779.1195338412172


Epoch 4: 100%|██████████| 1085/1085 [00:13<00:00, 80.73it/s]


-1557.6191949635447


Epoch 5: 100%|██████████| 1085/1085 [00:13<00:00, 80.44it/s]


-744.0677690230102


Epoch 6: 100%|██████████| 1085/1085 [00:13<00:00, 82.13it/s]


-910.7186760316185


Epoch 7: 100%|██████████| 1085/1085 [00:13<00:00, 81.71it/s]


-780.6937679869957


Epoch 8: 100%|██████████| 1085/1085 [00:13<00:00, 80.58it/s]


-1862.3083102921316


Epoch 9: 100%|██████████| 1085/1085 [00:13<00:00, 81.33it/s]


-421.9139219835718


Epoch 10: 100%|██████████| 1085/1085 [00:13<00:00, 81.10it/s]


-344.26083965853707


Epoch 11: 100%|██████████| 1085/1085 [00:13<00:00, 82.18it/s]


-180.26670257604817


Epoch 12:  84%|████████▍ | 914/1085 [00:11<00:02, 81.08it/s]


KeyboardInterrupt: 

In [123]:
label_list = []
predictions_list = []

with tqdm(dataloader) as tq, torch.no_grad():
    for step, (batch_graph, batch_label) in enumerate(tq):
        batch_graph, batch_label = next(iter(dataloader))
        labels = batch_label["atom"]
        logits = testmodel(batch_graph, batch_graph.ndata["feat"])["atom"]

        label_list.append(labels.cpu().numpy())
        predictions_list.append(logits.cpu().numpy())


cat_labels = np.concatenate(label_list)
cat_preds = np.concatenate(predictions_list)

100%|██████████| 1085/1085 [00:09<00:00, 109.67it/s]


In [124]:
r2 = r2_score(cat_labels, cat_preds)
print(r2)

0.0006468550770712955
