In [11]:
import numpy as np
import pandas as pd
from copy import deepcopy
import networkx as nx

import torch
import dgl
from tqdm import tqdm
from qtaim_embed.utils.grapher import get_grapher
from qtaim_embed.data.molwrapper import mol_wrappers_from_df
from qtaim_embed.utils.tests import get_data
from qtaim_embed.core.dataset import HeteroGraphNodeLabelDataset

In [12]:
train_dataset = HeteroGraphNodeLabelDataset(
    # file="/home/santiagovargas/dev/qtaim_embed/data/qm8/molecules_full.pkl",
    file="/home/santiagovargas/dev/qtaim_generator/data/xyz_qm8/molecules_qtaim.pkl",
    allowed_ring_size=[3, 4, 5, 6, 7],
    allowed_charges=None,
    self_loop=True,
    extra_keys={
        "atom": ["extra_feat_atom_esp_total"],
        "bond": [
            "extra_feat_bond_esp_total",
            "bond_length",
        ],
        "global": [],
    },
    target_dict={
        "atom": ["extra_feat_atom_esp_total"],
        "bond": ["extra_feat_bond_esp_total"],
    },
    extra_dataset_info={},
    debug=False,
    log_scale_targets=False,
    standard_scale_targets=True,
)

... > creating MoleculeWrapper objects


100%|██████████| 21786/21786 [00:02<00:00, 10888.08it/s]


element set {'H', 'C', 'O', 'F', 'N'}
selected atomic keys ['extra_feat_atom_esp_total']
selected bond keys ['extra_feat_bond_esp_total', 'bond_length']
selected global keys []
... > Building graphs and featurizing


100%|██████████| 21785/21785 [01:03<00:00, 344.05it/s]


included in labels
{'atom': ['extra_feat_atom_esp_total'], 'bond': ['extra_feat_bond_esp_total'], 'global': []}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'ring_size_7', 'chemical_symbol_H', 'chemical_symbol_C', 'chemical_symbol_O', 'chemical_symbol_F', 'chemical_symbol_N'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'ring size_7', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
... > parsing labels and features in graphs


100%|██████████| 21785/21785 [00:05<00:00, 3676.94it/s]


... > Scaling features
Standard deviation for feature 0 is 0.0, smaller than 0.001. You may want to exclude this feature.
... > Scaling features complete
... > mean: 
 {'atom': tensor([2.0525e+00, 5.2508e-01, 2.5562e-01, 5.0526e-02, 5.3274e-02, 9.2300e-02,
        5.0783e-02, 8.7325e-03, 5.1693e-01, 3.4349e-01, 7.8644e-02, 1.4606e-03,
        5.9473e-02]), 'bond': tensor([0.0000, 0.2584, 0.0523, 0.0571, 0.0956, 0.0543, 0.0101, 1.2687]), 'global': tensor([ 16.0904,  16.5128, 108.8643])}
... > std:  
 {'atom': tensor([1.2716, 0.8735, 0.4362, 0.2190, 0.2246, 0.2894, 0.2196, 0.0930, 0.4997,
        0.4749, 0.2692, 0.0382, 0.2365]), 'bond': tensor([0.0010, 0.4377, 0.2226, 0.2320, 0.2941, 0.2267, 0.1002, 0.2146]), 'global': tensor([2.9101, 3.1555, 8.0615])}
... > Scaling targets
... > Scaling targets complete
... > mean: 
 {'atom': tensor([2260655.2500]), 'bond': tensor([0.9732])}
... > std:  
 {'atom': tensor([63750496.]), 'bond': tensor([0.4235])}
... > loaded dataset


In [13]:
# TODO: build dataloader class
from qtaim_embed.core.dataset import HeteroGraphNodeLabelDataset, Subset

test_subset = Subset(train_dataset, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
test_subset.feature_names()

{'atom': ['total_degree',
  'total_H',
  'is_in_ring',
  'ring_size_3',
  'ring_size_4',
  'ring_size_5',
  'ring_size_6',
  'ring_size_7',
  'chemical_symbol_H',
  'chemical_symbol_C',
  'chemical_symbol_O',
  'chemical_symbol_F',
  'chemical_symbol_N'],
 'bond': ['metal bond',
  'ring inclusion',
  'ring size_3',
  'ring size_4',
  'ring size_5',
  'ring size_6',
  'ring size_7',
  'bond_length'],
 'global': ['num atoms', 'num bonds', 'molecule weight']}

In [14]:
len_dict = {}
for key, value in test_subset.dataset.exclude_names.items():
    len_dict[key] = len(value)
len_dict

{'atom': 13, 'bond': 8, 'global': 3}

In [15]:
from torch.utils.data import DataLoader
import itertools


class DataLoaderMoleculeNodeTask(DataLoader):
    """ """

    def __init__(self, dataset, **kwargs):
        if "collate_fn" in kwargs:
            raise ValueError(
                "'collate_fn' provided internally by 'bondnet.data', you need not to "
                "provide one"
            )

        def collate(samples):
            graphs = samples

            # count_label_atom = 0
            # for i in graphs:
            #    count_label_atom = count_label_atom + i.ndata["labels"]["bond"].shape[0]
            batched_graphs = dgl.batch(graphs)
            batched_labels = batched_graphs.ndata["labels"]
            return batched_graphs, batched_labels

        super(DataLoaderMoleculeNodeTask, self).__init__(
            dataset, collate_fn=collate, **kwargs
        )

In [16]:
dataloader = DataLoaderMoleculeNodeTask(train_dataset, batch_size=20, shuffle=True)

In [17]:
batch_graph, batch_label = next(iter(dataloader))

In [18]:
import dgl.nn.pytorch as dglnn

len_dict = train_dataset.featuze_size()
atom_input_size = len_dict["atom"]
bond_input_size = len_dict["bond"]
global_input_size = len_dict["global"]

print(atom_input_size, bond_input_size, global_input_size)


class testmodel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        """
        self.conv = dglnn.HeteroGraphConv(
            {
                "a2b": dglnn.GraphConv(
                    in_feats=atom_input_size,
                    out_feats=bond_input_size,
                ),
                "b2a": dglnn.GraphConv(
                    in_feats=bond_input_size,
                    out_feats=atom_input_size,
                ),
                "a2g": dglnn.GraphConv(
                    in_feats=atom_input_size,
                    out_feats=global_input_size,
                ),
                "g2a": dglnn.GraphConv(
                    in_feats=global_input_size,
                    out_feats=atom_input_size,
                ),
                "b2g": dglnn.GraphConv(
                    in_feats=bond_input_size,
                    out_feats=global_input_size,
                ),
                "g2b": dglnn.GraphConv(
                    in_feats=global_input_size,
                    out_feats=bond_input_size,
                ),
                "a2a": dglnn.GraphConv(
                    in_feats=atom_input_size,
                    out_feats=atom_input_size,
                ),
                "b2b": dglnn.GraphConv(
                    in_feats=bond_input_size,
                    out_feats=bond_input_size,
                ),
                "g2g": dglnn.GraphConv(
                    in_feats=global_input_size,
                    out_feats=global_input_size,
                ),
            },
            aggregate="sum",
        )
        
        self.conv2 = dglnn.HeteroGraphConv(
            {
                "a2b": dglnn.GraphConv(
                    in_feats=atom_input_size,
                    out_feats=bond_input_size,
                ),
                "b2a": dglnn.GraphConv(
                    in_feats=bond_input_size,
                    out_feats=atom_input_size,
                ),
                "a2g": dglnn.GraphConv(
                    in_feats=atom_input_size,
                    out_feats=global_input_size,
                ),
                "g2a": dglnn.GraphConv(
                    in_feats=global_input_size,
                    out_feats=atom_input_size,
                ),
                "b2g": dglnn.GraphConv(
                    in_feats=bond_input_size,
                    out_feats=global_input_size,
                ),
                "g2b": dglnn.GraphConv(
                    in_feats=global_input_size,
                    out_feats=bond_input_size,
                ),
                "a2a": dglnn.GraphConv(
                    in_feats=atom_input_size,
                    out_feats=atom_input_size,
                ),
                "b2b": dglnn.GraphConv(
                    in_feats=bond_input_size,
                    out_feats=bond_input_size,
                ),
                "g2g": dglnn.GraphConv(
                    in_feats=global_input_size,
                    out_feats=global_input_size,
                ),
            },
            aggregate="sum",
        )
        """
        self.conv3 = dglnn.HeteroGraphConv(
            {
                "b2a": dglnn.GraphConv(
                    in_feats=bond_input_size, out_feats=atom_input_size
                ),
                "g2a": dglnn.GraphConv(
                    in_feats=global_input_size, out_feats=atom_input_size
                ),
                "a2a": dglnn.GraphConv(
                    in_feats=atom_input_size, out_feats=atom_input_size
                ),
                "b2b": dglnn.GraphConv(
                    in_feats=bond_input_size,
                    out_feats=1,
                ),
                "g2g": dglnn.GraphConv(
                    in_feats=global_input_size,
                    out_feats=global_input_size,
                ),
                "b2g": dglnn.GraphConv(
                    in_feats=bond_input_size,
                    out_feats=global_input_size,
                ),
                "g2b": dglnn.GraphConv(
                    in_feats=global_input_size,
                    out_feats=1,
                ),
                "a2b": dglnn.GraphConv(
                    in_feats=atom_input_size,
                    out_feats=1,
                ),
                "a2g": dglnn.GraphConv(
                    in_feats=atom_input_size,
                    out_feats=global_input_size,
                ),
            },
            aggregate="sum",
        )

    def forward(self, graph, inputs):
        # feats = self.conv(graph, inputs)
        # feats = self.conv2(graph, feats)
        feats = self.conv3(graph, inputs)
        return feats


testmodel = testmodel()

13 8 3


In [19]:
# forward_out = testmodel(graph, graph.ndata["feat"])

In [21]:
from torch.nn import functional as F
from sklearn.metrics import r2_score

# from tqdm import tqdm
import tqdm.notebook as tq

opt = torch.optim.Adam(testmodel.parameters(), lr=0.02)


for epoch in range(1000):
    with tqdm(dataloader) as tq:
        testmodel.train()
        r2_list = []
        tq.set_description(f"Epoch {epoch+1}")
        training_loss = 0
        target_type = "bond"
        for step, (batch_graph, batch_label) in enumerate(tq):
            # forward propagation by using all nodes and extracting the user embeddings
            batch_graph, batch_label = next(iter(dataloader))
            labels = batch_label[target_type]
            logits = testmodel(batch_graph, batch_graph.ndata["feat"])[target_type]
            # print(logits.shape)
            # print(labels.shape)
            # compute loss
            loss = F.mse_loss(logits, labels)
            # loss_mae = F.l1_loss(logits, labels)
            # compute r2 score
            r2 = r2_score(logits.detach().numpy(), labels.detach().numpy())
            r2_list.append(r2)
            # Compute validation accuracy.  Omitted in this example.
            # backward propagation
            opt.zero_grad()
            loss.backward()
            opt.step()
            training_loss += loss.item()
            # tq.set_postfix({"Step": step, "MSE": loss.item()})

        r2_mean = np.mean(r2_list)
        tq.set_postfix({"final_t_loss": training_loss, "R_2": r2_mean})
        print(r2_mean)

        # tq.update()
        tq.close()

Epoch 1: 100%|██████████| 1090/1090 [00:05<00:00, 185.59it/s]


0.7427236132021471


Epoch 2: 100%|██████████| 1090/1090 [00:05<00:00, 186.64it/s]


0.853391413574738


Epoch 3: 100%|██████████| 1090/1090 [00:05<00:00, 186.33it/s]


0.8561014642033615


Epoch 4: 100%|██████████| 1090/1090 [00:05<00:00, 184.49it/s]


0.7964085067857632


Epoch 5: 100%|██████████| 1090/1090 [00:05<00:00, 187.31it/s]


0.8648924246156215


Epoch 6: 100%|██████████| 1090/1090 [00:05<00:00, 184.54it/s]


0.8560766812986726


Epoch 7: 100%|██████████| 1090/1090 [00:05<00:00, 186.25it/s]


0.7489454548787985


Epoch 8: 100%|██████████| 1090/1090 [00:05<00:00, 186.36it/s]


0.7221625750138655


Epoch 9: 100%|██████████| 1090/1090 [00:05<00:00, 185.97it/s]


0.8042726870907742


Epoch 10: 100%|██████████| 1090/1090 [00:05<00:00, 182.86it/s]


0.8696785244373213


Epoch 11: 100%|██████████| 1090/1090 [00:05<00:00, 184.55it/s]


0.8590227890760046


Epoch 12: 100%|██████████| 1090/1090 [00:05<00:00, 186.23it/s]


0.8557021661209665


Epoch 13: 100%|██████████| 1090/1090 [00:05<00:00, 184.78it/s]


0.7939983025859828


Epoch 14: 100%|██████████| 1090/1090 [00:05<00:00, 185.05it/s]


0.749473668718486


Epoch 15: 100%|██████████| 1090/1090 [00:05<00:00, 190.90it/s]


0.7289840153706296


Epoch 16: 100%|██████████| 1090/1090 [00:05<00:00, 190.22it/s]


0.8167637425104097


Epoch 17: 100%|██████████| 1090/1090 [00:05<00:00, 190.94it/s]


0.8082830181251383


Epoch 18: 100%|██████████| 1090/1090 [00:05<00:00, 191.44it/s]


0.8017221977719801


Epoch 19: 100%|██████████| 1090/1090 [00:05<00:00, 189.55it/s]


0.8601777549926136


Epoch 20: 100%|██████████| 1090/1090 [00:05<00:00, 189.59it/s]


0.8666194436998685


Epoch 21: 100%|██████████| 1090/1090 [00:05<00:00, 190.89it/s]


0.8073748379334775


Epoch 22: 100%|██████████| 1090/1090 [00:05<00:00, 181.84it/s]


0.8186835782749826


Epoch 23: 100%|██████████| 1090/1090 [00:05<00:00, 184.76it/s]


0.8467949922671845


Epoch 24: 100%|██████████| 1090/1090 [00:05<00:00, 189.66it/s]


0.8724163208836805


Epoch 25:  50%|█████     | 546/1090 [00:02<00:02, 189.65it/s]


KeyboardInterrupt: 

In [22]:
label_list = []
predictions_list = []

with tqdm(dataloader) as tq, torch.no_grad():
    for step, (batch_graph, batch_label) in enumerate(tq):
        batch_graph, batch_label = next(iter(dataloader))
        labels = batch_label[target_type]
        logits = testmodel(batch_graph, batch_graph.ndata["feat"])[target_type]
        label_list.append(labels.cpu().numpy())
        predictions_list.append(logits.cpu().numpy())


cat_labels = np.concatenate(label_list)
cat_preds = np.concatenate(predictions_list)

100%|██████████| 1090/1090 [00:04<00:00, 223.05it/s]


In [23]:
print(cat_labels.shape)
print(cat_preds.shape)

(360755, 1)
(360755, 1)


In [24]:
r2 = r2_score(cat_labels, cat_preds)
print(r2)

0.7884027468940493
