In [1]:
import numpy as np
import pandas as pd
from copy import deepcopy
import networkx as nx
from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch import nn

import dgl
import dgl.nn.pytorch as dglnn
from dgl import apply_each


from qtaim_embed.utils.grapher import get_grapher
from qtaim_embed.data.molwrapper import mol_wrappers_from_df
from qtaim_embed.utils.tests import get_data
from qtaim_embed.core.dataset import HeteroGraphNodeLabelDataset
from qtaim_embed.models.layers import GraphConvDropoutBatch
from qtaim_embed.core.dataset import HeteroGraphNodeLabelDataset, Subset

In [2]:
# df = pd.read_pickle(
#    "/home/santiagovargas/dev/qtaim_embed/data/xyz_qm8/molecules_qtaim.pkl"
# )
# df.keys()

In [3]:
from qtaim_embed.utils.data import get_default_node_level_config

config = get_default_node_level_config()
config[
    "train_dataset_loc"
] = "/home/santiagovargas/dev/qtaim_embed/data/xyz_qm8/molecules_qtaim.pkl"
config["log_scale_features"] = True
config["log_scale_targets"] = True
config["debug"] = True

In [4]:
train_dataset = HeteroGraphNodeLabelDataset(
    # file="/home/santiagovargas/dev/qtaim_embed/data/xyz_qm8/molecules_full.pkl",
    file="/home/santiagovargas/dev/qtaim_embed/data/xyz_qm8/molecules_qtaim.pkl",
    allowed_ring_size=[3, 4, 5, 6, 7],
    allowed_charges=None,
    self_loop=True,
    extra_keys={
        "atom": ["extra_feat_atom_esp_total"],
        "bond": [
            "extra_feat_bond_esp_total",
            "extra_feat_bond_ellip_e_dens",
            "extra_feat_bond_eta",
            "bond_length",
        ],
        "global": [],
    },
    target_dict={
        "atom": ["extra_feat_atom_esp_total"],
        "bond": [
            "extra_feat_bond_esp_total",
            "extra_feat_bond_ellip_e_dens",
            "extra_feat_bond_eta",
        ],
    },
    extra_dataset_info={},
    debug=True,
    log_scale_features=True,
    log_scale_targets=True,
    standard_scale_features=True,
    standard_scale_targets=True,
)

... > running in debug mode
... > creating MoleculeWrapper objects


100%|██████████| 100/100 [00:00<00:00, 8435.00it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'N', 'H', 'C', 'O'}
selected atomic keys ['extra_feat_atom_esp_total']
selected bond keys ['extra_feat_bond_esp_total', 'extra_feat_bond_ellip_e_dens', 'extra_feat_bond_eta', 'bond_length']
selected global keys []
... > Building graphs and featurizing


100%|██████████| 100/100 [00:00<00:00, 316.92it/s]


included in labels
{'atom': ['extra_feat_atom_esp_total'], 'bond': ['extra_feat_bond_esp_total', 'extra_feat_bond_ellip_e_dens', 'extra_feat_bond_eta']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'ring_size_7', 'chemical_symbol_N', 'chemical_symbol_H', 'chemical_symbol_C', 'chemical_symbol_O'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'ring size_7', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  dict_keys(['atom', 'bond'])
... > parsing labels and features in graphs


100%|██████████| 100/100 [00:00<00:00, 13402.90it/s]


original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['atom', 'bond'])
... > Log scaling features
... > Log scaling features complete
... > Scaling features
Standard deviation for feature 0 is 0.0, smaller than 0.001. You may want to exclude this feature.
... > Scaling features complete
... > feature mean(s): 
 {'atom': tensor([1.0335, 0.2958, 0.1931, 0.0230, 0.0399, 0.0681, 0.0430, 0.0191, 0.0456,
        0.3555, 0.2365, 0.0556]), 'bond': tensor([0.0000, 0.1958, 0.0242, 0.0442, 0.0701, 0.0459, 0.0208, 0.8152]), 'global': tensor([2.8185, 2.8375, 4.6976])}
... > feature std(s):  
 {'atom': tensor([0.3915, 0.4605, 0.3108, 0.1242, 0.1615, 0.2064, 0.1671, 0.1135, 0.1718,
        0.3465, 0.3286, 0.1882]), 'bond': tensor([0.0010, 0.3121, 0.1273, 0.1693, 0.2090, 0.1723, 0.1183, 0.0937]), 'global': tensor([0.1632, 0.1720, 0.0635])}
... > Log scaling targets
... > Log scaling targets complete
... > Scaling targets
... > Scaling targets complet

In [5]:
# TODO: build dataloader class

test_subset = Subset(train_dataset, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
test_subset.feature_size()

{'atom': 12, 'bond': 8, 'global': 3}

In [10]:
len_dict = {}
for key, value in test_subset.dataset.exclude_names.items():
    len_dict[key] = len(value)
len_dict

{'atom': 12, 'bond': 8, 'global': 3}

In [6]:
from qtaim_embed.data.dataloader import DataLoaderMoleculeNodeTask

In [7]:
dataloader = DataLoaderMoleculeNodeTask(train_dataset, batch_size=100, shuffle=True)

In [8]:
batch_graph, batch_label = next(iter(dataloader))

In [9]:
len_dict = train_dataset.feature_size()
atom_input_size = len_dict["atom"]
bond_input_size = len_dict["bond"]
global_input_size = len_dict["global"]
bond_output_size = 3
print(atom_input_size, bond_input_size, global_input_size)


n_heads = 1
hidden_feats = 64


class testmodel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = dglnn.HeteroGraphConv(
            {
                "a2b": GraphConvDropoutBatch(
                    in_feats=atom_input_size, out_feats=bond_input_size, dropout=0.2
                ),
                "b2a": GraphConvDropoutBatch(
                    in_feats=bond_input_size, out_feats=atom_input_size, dropout=0.2
                ),
                "a2g": GraphConvDropoutBatch(
                    in_feats=atom_input_size, out_feats=global_input_size, dropout=0.2
                ),
                "g2a": GraphConvDropoutBatch(
                    in_feats=global_input_size, out_feats=atom_input_size, dropout=0.2
                ),
                "b2g": GraphConvDropoutBatch(
                    in_feats=bond_input_size, out_feats=global_input_size, dropout=0.2
                ),
                "g2b": GraphConvDropoutBatch(
                    in_feats=global_input_size, out_feats=bond_input_size, dropout=0.2
                ),
                "a2a": GraphConvDropoutBatch(
                    in_feats=atom_input_size, out_feats=atom_input_size, dropout=0.2
                ),
                "b2b": GraphConvDropoutBatch(
                    in_feats=bond_input_size, out_feats=bond_input_size, dropout=0.2
                ),
                "g2g": GraphConvDropoutBatch(
                    in_feats=global_input_size, out_feats=global_input_size, dropout=0.2
                ),
            },
            aggregate="sum",
        )

        self.conv2 = dglnn.HeteroGraphConv(
            {
                "a2b": GraphConvDropoutBatch(
                    in_feats=atom_input_size,
                    out_feats=bond_output_size,
                ),
                "b2a": GraphConvDropoutBatch(
                    in_feats=bond_input_size,
                    out_feats=atom_input_size,
                ),
                "a2g": GraphConvDropoutBatch(
                    in_feats=atom_input_size,
                    out_feats=global_input_size,
                ),
                "g2a": GraphConvDropoutBatch(
                    in_feats=global_input_size,
                    out_feats=atom_input_size,
                ),
                "b2g": GraphConvDropoutBatch(
                    in_feats=bond_input_size,
                    out_feats=global_input_size,
                ),
                "g2b": GraphConvDropoutBatch(
                    in_feats=global_input_size,
                    out_feats=bond_output_size,
                ),
                "a2a": GraphConvDropoutBatch(
                    in_feats=atom_input_size,
                    out_feats=atom_input_size,
                ),
                "b2b": GraphConvDropoutBatch(
                    in_feats=bond_input_size,
                    out_feats=bond_output_size,
                ),
                "g2g": GraphConvDropoutBatch(
                    in_feats=global_input_size,
                    out_feats=global_input_size,
                ),
            },
            aggregate="sum",
        )
        """
        # get max feature length 
        self.max_feature_len = max([atom_input_size, bond_input_size, global_input_size])
        
        self.conv3 = dglnn.HeteroGraphConv(
            {
                "b2a": dglnn.GATConv(
                    in_feats=self.max_feature_len, 
                    out_feats=hidden_feats,
                    num_heads = n_heads, 
                    #aggregator_type="lstm",
                ),
                "g2a": dglnn.GATConv(
                    in_feats=self.max_feature_len, 
                    out_feats=hidden_feats,
                    num_heads = n_heads, 
                    #aggregator_type="lstm",
                ),
                "a2a": dglnn.GATConv(
                    in_feats=self.max_feature_len, 
                    out_feats=hidden_feats,
                    num_heads = n_heads, 
                    #aggregator_type="lstm",
                ),
                "g2b": dglnn.GATConv(
                    in_feats=self.max_feature_len,
                    out_feats=hidden_feats,
                    num_heads = n_heads, 
                    #aggregator_type="lstm",
                ),
                "a2b": dglnn.GATConv(
                    in_feats=self.max_feature_len,
                    out_feats=hidden_feats ,
                    num_heads = n_heads, 
                    #aggregator_type="lstm",
                ),
                "b2b": dglnn.GATConv(
                    in_feats=self.max_feature_len,
                    out_feats=hidden_feats,
                    num_heads = n_heads, 
                    #aggregator_type="lstm",
                ),

                "b2g": dglnn.GATConv(
                    in_feats=self.max_feature_len,
                    out_feats=hidden_feats,
                    num_heads = n_heads, 
                    #aggregator_type="lstm",
                ),
                "g2g": dglnn.GATConv(
                    in_feats=self.max_feature_len,
                    out_feats=hidden_feats,
                    num_heads = n_heads, 
                    #aggregator_type="lstm",
                ),
                "a2g": dglnn.GATConv(
                    in_feats=self.max_feature_len,
                    out_feats= hidden_feats,
                    num_heads = n_heads, 
                    #aggregator_type="lstm",
                ),
            },
            aggregate="sum",
        )

        self.conv4 = dglnn.HeteroGraphConv(
            {
                "b2a": dglnn.GraphConv(
                    in_feats=hidden_feats * n_heads, 
                    out_feats=atom_input_size,
                    #aggregator_type="lstm",
                ),
                "g2a": dglnn.GraphConv(
                    in_feats=hidden_feats * n_heads, 
                    out_feats=atom_input_size,
                    #aggregator_type="lstm",
                ),
                "a2a": dglnn.GraphConv(
                    in_feats=hidden_feats * n_heads, 
                    out_feats=atom_input_size,
                    #aggregator_type="lstm",
                ),
                "g2b": dglnn.GraphConv(
                    in_feats=hidden_feats * n_heads,
                    out_feats=bond_output_size,
                    #aggregator_type="lstm",
                ),
                "a2b": dglnn.GraphConv(
                    in_feats=hidden_feats * n_heads,
                    out_feats=bond_output_size,
                    #aggregator_type="lstm",
                ),
                "b2b": dglnn.GraphConv(
                    in_feats=hidden_feats * n_heads,
                    out_feats=bond_output_size,
                    #aggregator_type="lstm",
                ),

                "b2g": dglnn.GraphConv(
                    in_feats=hidden_feats * n_heads,
                    out_feats=global_input_size,
                    #aggregator_type="lstm",
                ),
                "g2g": dglnn.GraphConv(
                    in_feats=hidden_feats * n_heads,
                    out_feats=global_input_size,
                    #aggregator_type="lstm",
                ),
                "a2g": dglnn.GraphConv(
                    in_feats=hidden_feats * n_heads,
                    out_feats= global_input_size,
                ),
            },
            aggregate="sum",
        )
        """

    def forward(self, graph, inputs):
        """
        # feats = self.conv(graph, inputs)
        # feats = self.conv2(graph, feats)
        #print(inputs.keys())
        #for i in inputs.keys():
        #    print(i)
        #    print(inputs[i].shape)
        # zero pad with typed linear to get max feature length for each feature
        inputs = {k: F.pad(v, (0, self.max_feature_len - v.shape[1])) for k, v in inputs.items()}
        feats = self.conv3(graph, inputs)
        #print("pass" * 30)
        feats = {k: torch.reshape(F.relu(v), (v.shape[0], -1)) for k, v in feats.items()}
        feats = self.conv4(graph, feats)
        return feats
        """

        # print(inputs)
        feats = self.conv(graph, inputs)
        # print(feats)
        feats = self.conv2(graph, feats)
        return feats


model = testmodel()

12 8 3


In [10]:
graph = batch_graph
forward_out = model(graph, graph.ndata["feat"])

In [11]:
from qtaim_embed.models.node_level.base_gnn import GCNNodePred

model_imported = GCNNodePred(
    atom_input_size=atom_input_size,
    bond_input_size=bond_input_size,
    global_input_size=global_input_size,
    n_conv_layers=2,
    target_dict={
        "atom": ["extra_feat_atom_esp_total"],
        "bond": [
            "extra_feat_bond_esp_total",
            "extra_feat_bond_ellip_e_dens",
            "extra_feat_bond_eta",
        ],
    },
    dropout=0.2,
)

number of output dims 4


In [14]:
dataloader = DataLoaderMoleculeNodeTask(train_dataset, batch_size=100, shuffle=True)
batch_graph, batch_label = next(iter(dataloader))
model_imported.forward(batch_graph, batch_label)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1597x1 and 12x12)

In [16]:
from torch.nn import functional as F
from sklearn.metrics import r2_score

# from tqdm import tqdm
import tqdm.notebook as tq

opt = torch.optim.Adam(model.parameters(), lr=0.01)


for epoch in range(50):
    training_loss_list = []
    with tqdm(dataloader) as tq:
        model_imported.train()
        r2_list = []
        tq.set_description(f"Epoch {epoch+1}")
        training_loss = 0
        target_type = "bond"
        for step, (batch_graph, batch_label) in enumerate(tq):
            # forward propagation by using all nodes and extracting the user embeddings
            batch_graph, batch_label = next(iter(dataloader))
            labels = batch_label[target_type]
            logits = model_imported(batch_graph, batch_graph.ndata["feat"])[target_type]
            # print(logits.shape)
            # print(labels.shape)
            # compute loss
            loss = F.mse_loss(logits, labels)
            training_loss_list.append(loss.item())
            # loss_mae = F.l1_loss(logits, labels)
            # compute r2 score
            r2 = r2_score(logits.detach().numpy(), labels.detach().numpy())
            r2_list.append(r2)
            # Compute validation accuracy.  Omitted in this example.
            # backward propagation
            opt.zero_grad()
            loss.backward()
            opt.step()
            training_loss += loss.item()
            # tq.set_postfix({"Step": step, "MSE": loss.item()})

        r2_mean = np.mean(r2_list)
        loss = np.mean(training_loss_list)
        tq.set_postfix({"final_t_loss": training_loss, "R_2": r2_mean})
        print(r2_mean, loss)

        # tq.update()
        tq.close()

Epoch 1: 100%|██████████| 1/1 [00:00<00:00, 45.06it/s]


-18.491557544946957 64.81822204589844


Epoch 2: 100%|██████████| 1/1 [00:00<00:00, 53.65it/s]


-19.652633210662124 64.9208755493164


Epoch 3: 100%|██████████| 1/1 [00:00<00:00, 51.62it/s]


-18.196400837495926 64.71204376220703


Epoch 4: 100%|██████████| 1/1 [00:00<00:00, 50.30it/s]


-19.481222898194748 65.23275756835938


Epoch 5: 100%|██████████| 1/1 [00:00<00:00, 45.16it/s]


-18.454386870128573 65.0399398803711


Epoch 6: 100%|██████████| 1/1 [00:00<00:00, 48.93it/s]


-18.000030814527026 64.60310363769531


Epoch 7: 100%|██████████| 1/1 [00:00<00:00, 55.58it/s]


-18.619782404310477 65.09677124023438


Epoch 8: 100%|██████████| 1/1 [00:00<00:00, 57.58it/s]


-17.948384887704083 64.5608901977539


Epoch 9: 100%|██████████| 1/1 [00:00<00:00, 58.70it/s]


-18.35725478182663 64.4888687133789


Epoch 10: 100%|██████████| 1/1 [00:00<00:00, 59.72it/s]


-19.125196694766217 65.395263671875


Epoch 11: 100%|██████████| 1/1 [00:00<00:00, 56.30it/s]


-18.822992729676475 64.20320129394531


Epoch 12: 100%|██████████| 1/1 [00:00<00:00, 58.67it/s]


-18.468047646042773 64.5150146484375


Epoch 13: 100%|██████████| 1/1 [00:00<00:00, 57.85it/s]


-18.838857720977654 64.35306549072266


Epoch 14: 100%|██████████| 1/1 [00:00<00:00, 54.94it/s]


-18.93165797709132 64.61466217041016


Epoch 15: 100%|██████████| 1/1 [00:00<00:00, 56.64it/s]


-19.11850721946044 64.75078582763672


Epoch 16: 100%|██████████| 1/1 [00:00<00:00, 56.09it/s]


-19.164263035646393 64.55276489257812


Epoch 17: 100%|██████████| 1/1 [00:00<00:00, 54.47it/s]


-18.726980210173092 64.60494232177734


Epoch 18: 100%|██████████| 1/1 [00:00<00:00, 54.13it/s]


-18.866725932893157 64.32341766357422


Epoch 19: 100%|██████████| 1/1 [00:00<00:00, 55.63it/s]


-18.885121214486833 64.25521850585938


Epoch 20: 100%|██████████| 1/1 [00:00<00:00, 53.47it/s]


-18.734664372676836 64.48884582519531


Epoch 21: 100%|██████████| 1/1 [00:00<00:00, 53.69it/s]


-18.566862350797052 64.1678237915039


Epoch 22: 100%|██████████| 1/1 [00:00<00:00, 53.84it/s]


-19.078504733061454 64.2514877319336


Epoch 23: 100%|██████████| 1/1 [00:00<00:00, 44.06it/s]


-18.63568252431443 65.0409927368164


Epoch 24: 100%|██████████| 1/1 [00:00<00:00, 45.45it/s]


-18.200540997031805 65.32804870605469


Epoch 25: 100%|██████████| 1/1 [00:00<00:00, 51.38it/s]


-18.049153704118282 64.82474517822266


Epoch 26: 100%|██████████| 1/1 [00:00<00:00, 49.32it/s]


-19.686220342164813 64.52757263183594


Epoch 27: 100%|██████████| 1/1 [00:00<00:00, 46.97it/s]


-18.57613378938301 64.51982116699219


Epoch 28: 100%|██████████| 1/1 [00:00<00:00, 47.90it/s]


-19.124851959287994 64.7703628540039


Epoch 29: 100%|██████████| 1/1 [00:00<00:00, 47.15it/s]


-18.227306757037834 64.6273422241211


Epoch 30: 100%|██████████| 1/1 [00:00<00:00, 47.34it/s]


-18.494942144287858 64.94515991210938


Epoch 31: 100%|██████████| 1/1 [00:00<00:00, 52.62it/s]


-18.669486245843185 64.63151550292969


Epoch 32: 100%|██████████| 1/1 [00:00<00:00, 52.13it/s]


-19.602800871321477 64.83536529541016


Epoch 33: 100%|██████████| 1/1 [00:00<00:00, 47.28it/s]


-19.173129154976145 64.57640838623047


Epoch 34: 100%|██████████| 1/1 [00:00<00:00, 45.80it/s]


-18.156369739786083 64.44548034667969


Epoch 35: 100%|██████████| 1/1 [00:00<00:00, 48.47it/s]


-18.89881080634547 64.65231323242188


Epoch 36: 100%|██████████| 1/1 [00:00<00:00, 53.00it/s]


-19.237002269548878 64.60893249511719


Epoch 37: 100%|██████████| 1/1 [00:00<00:00, 57.50it/s]


-18.618874592243944 64.29723358154297


Epoch 38: 100%|██████████| 1/1 [00:00<00:00, 56.16it/s]


-20.1967910121662 64.67274475097656


Epoch 39: 100%|██████████| 1/1 [00:00<00:00, 48.54it/s]


-18.294574929804543 64.34037780761719


Epoch 40: 100%|██████████| 1/1 [00:00<00:00, 44.84it/s]


-18.45202848532042 64.7392349243164


Epoch 41: 100%|██████████| 1/1 [00:00<00:00, 42.74it/s]


-19.804325623581757 64.83012390136719


Epoch 42: 100%|██████████| 1/1 [00:00<00:00, 51.85it/s]


-18.60340671432751 64.28572845458984


Epoch 43: 100%|██████████| 1/1 [00:00<00:00, 55.00it/s]


-19.142018545051993 64.08544921875


Epoch 44: 100%|██████████| 1/1 [00:00<00:00, 50.10it/s]


-18.78630368772168 65.11734008789062


Epoch 45: 100%|██████████| 1/1 [00:00<00:00, 51.44it/s]


-18.388329644873 64.81556701660156


Epoch 46: 100%|██████████| 1/1 [00:00<00:00, 55.17it/s]


-17.734725996436474 64.4045639038086


Epoch 47: 100%|██████████| 1/1 [00:00<00:00, 55.63it/s]


-19.0674816800603 63.88050842285156


Epoch 48: 100%|██████████| 1/1 [00:00<00:00, 55.20it/s]


-18.350620250137013 64.65076446533203


Epoch 49: 100%|██████████| 1/1 [00:00<00:00, 54.05it/s]


-19.139153756222623 65.30514526367188


Epoch 50: 100%|██████████| 1/1 [00:00<00:00, 51.83it/s]


-18.529322300103342 65.0848617553711


In [43]:
from torch.nn import functional as F
from sklearn.metrics import r2_score

# from tqdm import tqdm
import tqdm.notebook as tq

opt = torch.optim.Adam(model.parameters(), lr=0.01)


for epoch in range(50):
    training_loss_list = []
    with tqdm(dataloader) as tq:
        model.train()
        r2_list = []
        tq.set_description(f"Epoch {epoch+1}")
        training_loss = 0
        target_type = "bond"
        for step, (batch_graph, batch_label) in enumerate(tq):
            # forward propagation by using all nodes and extracting the user embeddings
            batch_graph, batch_label = next(iter(dataloader))
            labels = batch_label[target_type]
            logits = model(batch_graph, batch_graph.ndata["feat"])[target_type]
            # print(logits.shape)
            # print(labels.shape)
            # compute loss
            loss = F.mse_loss(logits, labels)
            training_loss_list.append(loss.item())
            # loss_mae = F.l1_loss(logits, labels)
            # compute r2 score
            r2 = r2_score(logits.detach().numpy(), labels.detach().numpy())
            r2_list.append(r2)
            # Compute validation accuracy.  Omitted in this example.
            # backward propagation
            opt.zero_grad()
            loss.backward()
            opt.step()
            training_loss += loss.item()
            # tq.set_postfix({"Step": step, "MSE": loss.item()})

        r2_mean = np.mean(r2_list)
        loss = np.mean(training_loss_list)
        tq.set_postfix({"final_t_loss": training_loss, "R_2": r2_mean})
        print(r2_mean, loss)

        # tq.update()
        tq.close()

Epoch 1: 100%|██████████| 1/1 [00:00<00:00, 23.45it/s]


-3.7205302931984163 14.68019962310791


Epoch 2: 100%|██████████| 1/1 [00:00<00:00, 31.44it/s]


-3.9273690115274005 14.689446449279785


Epoch 3: 100%|██████████| 1/1 [00:00<00:00, 21.65it/s]


-3.718336389191752 14.703181266784668


Epoch 4: 100%|██████████| 1/1 [00:00<00:00, 21.07it/s]


-3.780627066867511 14.21032428741455


Epoch 5: 100%|██████████| 1/1 [00:00<00:00, 25.68it/s]


-3.7316928396393387 13.965476036071777


Epoch 6: 100%|██████████| 1/1 [00:00<00:00, 29.49it/s]


-3.457277128817072 13.535135269165039


Epoch 7: 100%|██████████| 1/1 [00:00<00:00, 30.38it/s]


-3.8366539567068187 14.601213455200195


Epoch 8: 100%|██████████| 1/1 [00:00<00:00, 26.72it/s]


-3.747202690314173 13.935641288757324


Epoch 9: 100%|██████████| 1/1 [00:00<00:00, 25.05it/s]


-3.9264098983236377 14.054888725280762


Epoch 10: 100%|██████████| 1/1 [00:00<00:00, 27.90it/s]


-3.7085638304547572 13.922420501708984


Epoch 11: 100%|██████████| 1/1 [00:00<00:00, 27.32it/s]


-3.656725223457279 13.761919975280762


Epoch 12: 100%|██████████| 1/1 [00:00<00:00, 30.60it/s]


-3.633013457811136 13.758620262145996


Epoch 13: 100%|██████████| 1/1 [00:00<00:00, 21.38it/s]


-3.657660576856241 13.473957061767578


Epoch 14: 100%|██████████| 1/1 [00:00<00:00, 23.97it/s]


-3.4632583343516097 12.75365161895752


Epoch 15: 100%|██████████| 1/1 [00:00<00:00, 29.13it/s]


-3.4518971826937594 12.83767032623291


Epoch 16: 100%|██████████| 1/1 [00:00<00:00, 30.63it/s]


-3.7413966232441367 13.785242080688477


Epoch 17: 100%|██████████| 1/1 [00:00<00:00, 29.79it/s]


-3.4143438067061944 12.265037536621094


Epoch 18: 100%|██████████| 1/1 [00:00<00:00, 28.99it/s]


-3.5086722726556574 12.938258171081543


Epoch 19: 100%|██████████| 1/1 [00:00<00:00, 26.56it/s]


-3.367915942005236 13.06378173828125


Epoch 20: 100%|██████████| 1/1 [00:00<00:00, 26.17it/s]


-3.4021242200252733 12.947981834411621


Epoch 21: 100%|██████████| 1/1 [00:00<00:00, 28.69it/s]


-3.5004966946170217 13.00985050201416


Epoch 22: 100%|██████████| 1/1 [00:00<00:00, 29.46it/s]


-3.384844404794242 13.28810977935791


Epoch 23: 100%|██████████| 1/1 [00:00<00:00, 29.80it/s]


-3.2244815688494857 12.474140167236328


Epoch 24: 100%|██████████| 1/1 [00:00<00:00, 29.52it/s]


-3.4690267312767773 12.989046096801758


Epoch 25: 100%|██████████| 1/1 [00:00<00:00, 26.23it/s]


-3.293178347443883 12.189672470092773


Epoch 26: 100%|██████████| 1/1 [00:00<00:00, 23.78it/s]


-2.944777341708965 12.390623092651367


Epoch 27: 100%|██████████| 1/1 [00:00<00:00, 27.04it/s]


-3.2695326133095257 13.350651741027832


Epoch 28: 100%|██████████| 1/1 [00:00<00:00, 28.84it/s]


-3.14580230252143 13.373291015625


Epoch 29: 100%|██████████| 1/1 [00:00<00:00, 28.49it/s]


-2.8243291473717504 11.924619674682617


Epoch 30: 100%|██████████| 1/1 [00:00<00:00, 30.56it/s]


-2.725302388567427 12.089298248291016


Epoch 31: 100%|██████████| 1/1 [00:00<00:00, 24.59it/s]


-2.9819645963794534 13.01147747039795


Epoch 32: 100%|██████████| 1/1 [00:00<00:00, 29.29it/s]


-2.6896620456346993 12.334114074707031


Epoch 33: 100%|██████████| 1/1 [00:00<00:00, 30.06it/s]


-2.7278466857037436 12.893320083618164


Epoch 34: 100%|██████████| 1/1 [00:00<00:00, 26.91it/s]


-2.616202893829377 12.038164138793945


Epoch 35: 100%|██████████| 1/1 [00:00<00:00, 26.45it/s]


-2.6583123846388705 12.17487621307373


Epoch 36: 100%|██████████| 1/1 [00:00<00:00, 28.45it/s]


-2.606168197946101 12.371857643127441


Epoch 37: 100%|██████████| 1/1 [00:00<00:00, 27.63it/s]


-2.4723748630432736 11.408025741577148


Epoch 38: 100%|██████████| 1/1 [00:00<00:00, 23.07it/s]


-2.5818821454983367 12.557168960571289


Epoch 39: 100%|██████████| 1/1 [00:00<00:00, 27.95it/s]


-2.654099973147433 13.04965591430664


Epoch 40: 100%|██████████| 1/1 [00:00<00:00, 30.16it/s]


-2.4009333073236205 12.38421630859375


Epoch 41: 100%|██████████| 1/1 [00:00<00:00, 28.58it/s]


-2.433594242284192 12.333194732666016


Epoch 42: 100%|██████████| 1/1 [00:00<00:00, 31.24it/s]


-2.296300125466566 11.27975845336914


Epoch 43: 100%|██████████| 1/1 [00:00<00:00, 29.14it/s]


-2.21173228676126 11.798844337463379


Epoch 44: 100%|██████████| 1/1 [00:00<00:00, 27.92it/s]


-2.338361850020438 12.338628768920898


Epoch 45: 100%|██████████| 1/1 [00:00<00:00, 25.97it/s]


-2.377102747756087 12.205547332763672


Epoch 46: 100%|██████████| 1/1 [00:00<00:00, 27.85it/s]


-2.3047446336976147 12.556431770324707


Epoch 47: 100%|██████████| 1/1 [00:00<00:00, 30.74it/s]


-2.3145441427097544 12.91015911102295


Epoch 48: 100%|██████████| 1/1 [00:00<00:00, 30.64it/s]


-2.2101164541132303 11.880921363830566


Epoch 49: 100%|██████████| 1/1 [00:00<00:00, 26.01it/s]


-2.021475935807809 11.253496170043945


Epoch 50: 100%|██████████| 1/1 [00:00<00:00, 24.18it/s]

-2.181490068655195 12.423982620239258





In [32]:
label_list = []
predictions_list = []

with tqdm(dataloader) as tq, torch.no_grad():
    for step, (batch_graph, batch_label) in enumerate(tq):
        batch_graph, batch_label = next(iter(dataloader))
        labels = batch_label[target_type]
        logits = model(batch_graph, batch_graph.ndata["feat"])[target_type]
        label_list.append(labels.cpu().numpy())
        predictions_list.append(logits.cpu().numpy())


cat_labels = np.concatenate(label_list)
cat_preds = np.concatenate(predictions_list)

100%|██████████| 1/1 [00:00<00:00, 28.51it/s]


In [16]:
print(cat_labels.shape)
print(cat_preds.shape)

(1632, 1)
(1632, 1)


In [33]:
r2 = r2_score(cat_labels[:, 0], cat_preds[:, 0])
print(r2)

0.8445300518290597


In [34]:
r2 = r2_score(cat_labels[:, 1], cat_preds[:, 1])
print(r2)

0.6478962999038074


In [35]:
r2 = r2_score(cat_labels[:, 2], cat_preds[:, 2])
print(r2)

0.43115008510261954


In [5]:
from qtaim_embed.core.datamodule import QTAIMNodeTaskDataModule
from qtaim_embed.models.node_level.base_gcn import GCNNodePred

dm = QTAIMNodeTaskDataModule(config=config)

In [6]:
import pytorch_lightning as pl

torch.set_float32_matmul_precision("high")
len_dict = dm.setup(stage="fit")
atom_input_size = len_dict["atom"]
bond_input_size = len_dict["bond"]
global_input_size = len_dict["global"]

... > running in debug mode
... > creating MoleculeWrapper objects


100%|██████████| 100/100 [00:00<00:00, 7746.14it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'N', 'H', 'C', 'O'}
selected atomic keys ['extra_feat_atom_esp_total']
selected bond keys ['extra_feat_bond_esp_total', 'extra_feat_bond_ellip_e_dens', 'extra_feat_bond_eta', 'bond_length']
selected global keys []
... > Building graphs and featurizing


100%|██████████| 100/100 [00:00<00:00, 335.27it/s]


included in labels
{'atom': ['extra_feat_atom_esp_total'], 'bond': ['extra_feat_bond_esp_total', 'extra_feat_bond_ellip_e_dens', 'extra_feat_bond_eta']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'ring_size_7', 'chemical_symbol_N', 'chemical_symbol_H', 'chemical_symbol_C', 'chemical_symbol_O'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'ring size_7', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  dict_keys(['atom', 'bond'])
... > parsing labels and features in graphs


100%|██████████| 100/100 [00:00<00:00, 14270.22it/s]


original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['atom', 'bond'])
... > Log scaling features
... > Log scaling features complete
... > Scaling features
Standard deviation for feature 0 is 0.0, smaller than 0.001. You may want to exclude this feature.
... > Scaling features complete
... > feature mean(s): 
 {'atom': tensor([1.0335, 0.2958, 0.1931, 0.0230, 0.0399, 0.0681, 0.0430, 0.0191, 0.0456,
        0.3555, 0.2365, 0.0556]), 'bond': tensor([0.0000, 0.1958, 0.0242, 0.0442, 0.0701, 0.0459, 0.0208, 0.8152]), 'global': tensor([2.8185, 2.8375, 4.6976])}
... > feature std(s):  
 {'atom': tensor([0.3915, 0.4605, 0.3108, 0.1242, 0.1615, 0.2064, 0.1671, 0.1135, 0.1718,
        0.3465, 0.3286, 0.1882]), 'bond': tensor([0.0010, 0.3121, 0.1273, 0.1693, 0.2090, 0.1723, 0.1183, 0.0937]), 'global': tensor([0.1632, 0.1720, 0.0635])}
... > Log scaling targets
... > Log scaling targets complete
... > Scaling targets
... > Scaling targets complet

In [7]:
print(atom_input_size, bond_input_size, global_input_size)

12 8 3


In [8]:
model = GCNNodePred(
    atom_input_size=atom_input_size,
    bond_input_size=bond_input_size,
    global_input_size=global_input_size,
    target_dict={
        "atom": ["extra_feat_atom_esp_total"],
        "bond": [
            "extra_feat_bond_esp_total",
            "extra_feat_bond_ellip_e_dens",
            "extra_feat_bond_eta",
        ],
    },
    activation="ReLU",
    conv_fn="ResidualBlock",
    resid_n_graph_convs=3,
    dropout=0.2,
    bias=True,
    batch_norm=True,
    n_conv_layers=8,
    lr_plateau_patience=10,
    lr=0.02,
    weight_decay=0.0001,
)

triggered normal outer layer
triggered normal outer layer
triggered normal outer layer
triggered normal outer layer
triggered normal outer layer
triggered normal outer layer
triggered output_layer args
triggered early stop condition!!!
target_dict {'atom': ['extra_feat_atom_esp_total'], 'bond': ['extra_feat_bond_esp_total', 'extra_feat_bond_ellip_e_dens', 'extra_feat_bond_eta']}
triggered separate intermediate layer
triggered separate intermediate layer
triggered separate outer layer
number of output dims 4


In [10]:
trainer_transfer = pl.Trainer(
    max_epochs=100,
    accelerator="gpu",
    devices=1,
    enable_progress_bar=True,
    gradient_clip_val=3.0,
    default_root_dir="./test/",
    precision="32",
    log_every_n_steps=10,
)

trainer_transfer.fit(model, dm)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

   | Name            | Type               | Params
--------------------------------------------------------
0  | conv_layers     | ModuleList         | 6.1 K 
1  | loss            | MultioutputWrapper | 0     
2  | train_r2        | MultioutputWrapper | 0     
3  | train_torch_l1  | MultioutputWrapper | 0     
4  | train_torch_mse | MultioutputWrapper | 0     
5  | val_r2          | MultioutputWrapper | 0     
6  | val_torch_l1    | MultioutputWrapper | 0     
7  | val_torch_mse   | MultioutputWrapper | 0     
8  | test_r2         | MultioutputWrapper | 0     
9  | test_torch_l1   | MultioutputWrapper | 0     
10 | test_torch_mse  | MultioutputWrapper | 0     
--------------------------------------------------------
6.1 K     Trainable params
0         Non-trainable params
6.1 K     Total par

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 00012: reducing learning rate of group 0 to 1.0000e-02.


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 00023: reducing learning rate of group 0 to 5.0000e-03.


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 00034: reducing learning rate of group 0 to 2.5000e-03.


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 00045: reducing learning rate of group 0 to 1.2500e-03.


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 00056: reducing learning rate of group 0 to 6.2500e-04.


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 00067: reducing learning rate of group 0 to 3.1250e-04.


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 00078: reducing learning rate of group 0 to 1.5625e-04.


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 00089: reducing learning rate of group 0 to 7.8125e-05.


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 00100: reducing learning rate of group 0 to 3.9063e-05.


In [14]:
dm.train_dataloader().dataset[0].ndata["labels"]["atom"].shape

torch.Size([22, 1])

In [15]:
# get example output from model
model.eval()
out = model(
    dm.train_dataloader().dataset[0], dm.train_dataloader().dataset[0].ndata["feat"]
)

RuntimeError: The size of tensor a (3) must match the size of tensor b (8) at non-singleton dimension 1

In [13]:
out["bond"].shape

torch.Size([22, 8])

In [12]:
out["atom"].shape

torch.Size([22, 12])

In [54]:
dm.train_dataloader().dataset[0].ndata["feat"]["atom"].shape

torch.Size([22, 12])