In [1]:
import numpy as np
import pandas as pd
from copy import deepcopy
import networkx as nx

import torch
import dgl
from tqdm import tqdm
from qtaim_embed.utils.grapher import get_grapher
from qtaim_embed.data.molwrapper import mol_wrappers_from_df
from qtaim_embed.utils.tests import get_data
from qtaim_embed.core.dataset import HeteroGraphNodeLabelDataset
from qtaim_embed.data.dataloader import DataLoaderMoleculeNodeTask
from qtaim_embed.models.node_level.base_gnn import GCNNodePred

In [2]:
df = pd.read_pickle("/home/santiagovargas/dev/qtaim_embed/data/xyz_qm8/molecules_qtaim.pkl")

train_dataset = HeteroGraphNodeLabelDataset(
    #file="/home/santiagovargas/dev/qtaim_embed/data/xyz_qm8/molecules_full.pkl",
    file="/home/santiagovargas/dev/qtaim_embed/data/xyz_qm8/molecules_qtaim.pkl",
    allowed_ring_size=[3, 4, 5, 6, 7],
    allowed_charges=None,
    self_loop=True,
    extra_keys={
        "atom": ["extra_feat_atom_esp_total"],
        "bond": [
            "extra_feat_bond_esp_total",
            'extra_feat_bond_ellip_e_dens', 'extra_feat_bond_eta',
            "bond_length",
        ],
        "global": [],
    },
    target_dict={
        "atom": ["extra_feat_atom_esp_total"],
        "bond": ["extra_feat_bond_esp_total", 'extra_feat_bond_ellip_e_dens', 'extra_feat_bond_eta',],
    },
    extra_dataset_info={},
    debug=False,
    log_scale_targets=False,
    standard_scale_targets=True,
)


... > creating MoleculeWrapper objects


100%|██████████| 21786/21786 [00:06<00:00, 3557.16it/s]


element set {'N', 'F', 'O', 'H', 'C'}
selected atomic keys ['extra_feat_atom_esp_total']
selected bond keys ['extra_feat_bond_esp_total', 'extra_feat_bond_ellip_e_dens', 'extra_feat_bond_eta', 'bond_length']
selected global keys []
... > Building graphs and featurizing


100%|██████████| 21785/21785 [02:26<00:00, 148.83it/s]


included in labels
{'atom': ['extra_feat_atom_esp_total'], 'bond': ['extra_feat_bond_esp_total', 'extra_feat_bond_ellip_e_dens', 'extra_feat_bond_eta'], 'global': []}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'ring_size_7', 'chemical_symbol_N', 'chemical_symbol_F', 'chemical_symbol_O', 'chemical_symbol_H', 'chemical_symbol_C'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'ring size_7', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
... > parsing labels and features in graphs


100%|██████████| 21785/21785 [00:09<00:00, 2412.01it/s]


... > Scaling features
Standard deviation for feature 0 is 0.0, smaller than 0.001. You may want to exclude this feature.
... > Scaling features complete
... > mean: 
 {'atom': tensor([2.0525e+00, 5.2508e-01, 2.5562e-01, 5.0526e-02, 5.3274e-02, 9.2300e-02,
        5.0783e-02, 8.7325e-03, 5.9473e-02, 1.4606e-03, 7.8644e-02, 5.1693e-01,
        3.4349e-01]), 'bond': tensor([0.0000, 0.2584, 0.0523, 0.0571, 0.0956, 0.0543, 0.0101, 1.2687]), 'global': tensor([ 16.0904,  16.5128, 108.8643])}
... > std:  
 {'atom': tensor([1.2716, 0.8735, 0.4362, 0.2190, 0.2246, 0.2894, 0.2196, 0.0930, 0.2365,
        0.0382, 0.2692, 0.4997, 0.4749]), 'bond': tensor([0.0010, 0.4377, 0.2226, 0.2320, 0.2941, 0.2267, 0.1002, 0.2146]), 'global': tensor([2.9101, 3.1555, 8.0615])}
... > Scaling targets
... > Scaling targets complete
... > mean: 
 {'atom': tensor([2260655.2500]), 'bond': tensor([0.9732, 0.0857, 1.4807])}
... > std:  
 {'atom': tensor([63750496.]), 'bond': tensor([0.4235, 0.2129, 0.4297])}
... > load

In [8]:

len_dict = train_dataset.featuze_size()
atom_input_size = len_dict["atom"]
bond_input_size = len_dict["bond"]
global_input_size = len_dict["global"]

model = GCNNodePred(
    atom_input_size=atom_input_size,
    bond_input_size=bond_input_size,
    global_input_size=global_input_size,
    n_conv_layers=2,
    target_dict={
        "atom": ["extra_feat_atom_esp_total"],
        "bond": ["extra_feat_bond_esp_total", 'extra_feat_bond_ellip_e_dens', 'extra_feat_bond_eta',],
    },
    activation="ReLU",
    dropout=0.2,
    lr_plateau_patience=10,
    lr=0.01
)

number of output dims 4


In [9]:
dataloader = DataLoaderMoleculeNodeTask(train_dataset, batch_size=100, shuffle=True)
batch_graph, batch_label = next(iter(dataloader))

In [10]:

"""
feats = batch_graph.ndata["feat"]
for layer in model.conv_layers:
    feats = layer(batch_graph, feats)


for key in ["atom", "bond", "global"]:
    feats[key] = _split_batched_output(batch_graph, feats[key], key)


def get_targets(self, targets_feats):
    targets = {}
    for k, v in self.target_dict.items():
        # if v is None or [] skip 
        if not (v is None or len(v) == 0):
            targets[k] = targets_feats[k]
    #[print(i) for i in list(targets.values())]
    # concat dict of tensors into one tensor
    list(targets.values())
    targets = torch.cat(list(targets.values()), dim=1)
    return targets    

targets = get_targets(model, feats)
"""

'\nfeats = batch_graph.ndata["feat"]\nfor layer in model.conv_layers:\n    feats = layer(batch_graph, feats)\n\n\nfor key in ["atom", "bond", "global"]:\n    feats[key] = _split_batched_output(batch_graph, feats[key], key)\n\n\ndef get_targets(self, targets_feats):\n    targets = {}\n    for k, v in self.target_dict.items():\n        # if v is None or [] skip \n        if not (v is None or len(v) == 0):\n            targets[k] = targets_feats[k]\n    #[print(i) for i in list(targets.values())]\n    # concat dict of tensors into one tensor\n    list(targets.values())\n    targets = torch.cat(list(targets.values()), dim=1)\n    return targets    \n\ntargets = get_targets(model, feats)\n'

In [11]:
feats = model.forward(batch_graph, batch_graph.ndata["feat"])

In [12]:
feats

{'atom': tensor([[ 4.3852],
         [ 5.0082],
         [ 9.4321],
         ...,
         [-0.6390],
         [-1.1658],
         [-0.7524]], grad_fn=<SumBackward1>),
 'bond': tensor([[ 0.2367,  0.2887,  0.4309],
         [-1.1728, -0.0552, -1.5208],
         [-0.9402, -1.5262, -1.6328],
         ...,
         [ 1.1535, -1.5262, -1.6328],
         [-0.4272,  3.1284,  1.2104],
         [ 0.0921, -0.7483, -1.6328]], grad_fn=<SumBackward1>),
 'global': tensor([[-1.8912e+00, -2.7998e-01,  2.4384e+00],
         [-4.6431e-01,  3.7792e+00, -5.4131e-01],
         [ 2.3665e-01, -1.1738e+00,  1.9950e+00],
         [ 2.9629e+00, -1.4526e+00,  3.5168e+00],
         [ 1.8827e+00, -1.5585e+00,  6.7825e-01],
         [-5.7771e-01, -1.5585e+00,  1.4813e+00],
         [ 3.5461e+00,  1.3334e+00,  2.7724e+00],
         [-2.8486e-01,  6.9395e-01, -1.1819e+00],
         [-2.3902e-01, -1.2880e+00, -1.8501e+00],
         [ 2.9359e+00,  7.3711e-01,  3.5059e+00],
         [-1.3242e+00, -1.5585e+00, -1.8501e+0

In [13]:
"""
optimizer, lr_scheduler = model.configure_optimizers()
optimizer = optimizer[0]
lr_scheduler = lr_scheduler[0]
"""

'\noptimizer, lr_scheduler = model.configure_optimizers()\noptimizer = optimizer[0]\nlr_scheduler = lr_scheduler[0]\n'

In [14]:
import pytorch_lightning as pl

trainer_transfer = pl.Trainer(
    max_epochs=100,
    accelerator="gpu",
    devices=1,
    enable_progress_bar=True,
    gradient_clip_val=3.0,
    default_root_dir="./test/",
    precision=32,
)

trainer_transfer.fit(model, dataloader)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


  rank_zero_warn(
You are using a CUDA device ('NVIDIA GeForce RTX 3070 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name            | Type               | Params
--------------------------------------------------------
0  | conv_layers     | ModuleList         | 1.0 K 
1  | loss            | MultioutputWrapper | 0     
2  | train_r2        | MultioutputWrapper | 0     
3  | train_torch_l1  | MultioutputWrapper | 0     
4  | train_torch_mse | MultioutputWrapper | 0     
5  | val_r2          | MultioutputWrapper | 0     
6  | val_torch_l1    | MultioutputWrapper | 0     
7  | val_torch_mse   | MultioutputWrapper | 0     
8  | test_r2         | MultioutputWrapper | 0   

Epoch 0:   0%|          | 0/218 [00:00<?, ?it/s] 

In [None]:
#opt = torch.optim.Adam(model.parameters(), lr=0.01)
from sklearn.metrics import r2_score
# import F 
import torch.nn.functional as F
for epoch in range(50):
    training_loss_list = []
    with tqdm(dataloader) as tq:
        model.train()
        r2_list = []
        tq.set_description(f"Epoch {epoch+1}")
        training_loss = 0
        target_type = "bond" 

        for step, (batch_graph, batch_label) in enumerate(tq):
            # forward propagation by using all nodes and extracting the user embeddings
            batch_graph, batch_label = next(iter(dataloader))
            labels = batch_label[target_type]
            logits = model(batch_graph, batch_graph.ndata["feat"])[target_type]
            
            logits_list = []
            labels_list = []
            max_nodes = -1
            for target_type in ["bond", "atom"]:
                
                #print(logits.shape)
                #print(labels.shape)
                if max_nodes < logits.shape[0]:
                    max_nodes = logits.shape[0]
                logits_list.append(logits)
                labels_list.append(labels)

                # compute loss
            # zero pad logits and labels to max_nodes size
            logits_list = [F.pad(i, (0, 0, 0, max_nodes - i.shape[0])) for i in logits_list]
            labels_list = [F.pad(i, (0, 0, 0, max_nodes - i.shape[0])) for i in labels_list]
            logits = torch.cat(logits_list, dim=1)
            labels = torch.cat(labels_list, dim=1)

            loss = F.mse_loss(logits, labels)
            training_loss_list.append(loss.item())
            # loss_mae = F.l1_loss(logits, labels)
            # compute r2 score
            r2 = r2_score(logits.detach().numpy(), labels.detach().numpy())
            r2_list.append(r2)
            #print(r2)
            # Compute validation accuracy.  Omitted in this example.
            # backward propagation

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            training_loss += loss.item()
            tq.set_postfix({"Step": step, "MSE": loss.item()})

        # show mean training loss
        #tq.set_postfix({"final_t_loss": training_loss, "R_2": r2})
        r2_mean = np.mean(r2_list)
        loss = np.mean(training_loss_list)
        #tq.set_postfix({"final_t_loss": training_loss, "R_2": r2_mean})
        #print(r2_mean, loss)

        # tq.update()
        tq.close()

Epoch 1: 100%|██████████| 1/1 [00:00<00:00, 24.94it/s, Step=0, MSE=14.5]
Epoch 2: 100%|██████████| 1/1 [00:00<00:00, 27.28it/s, Step=0, MSE=14.2]
Epoch 3: 100%|██████████| 1/1 [00:00<00:00, 28.43it/s, Step=0, MSE=14.6]
Epoch 4: 100%|██████████| 1/1 [00:00<00:00, 24.78it/s, Step=0, MSE=14.1]
Epoch 5: 100%|██████████| 1/1 [00:00<00:00, 25.01it/s, Step=0, MSE=13.9]
Epoch 6: 100%|██████████| 1/1 [00:00<00:00, 24.79it/s, Step=0, MSE=14.4]
Epoch 7: 100%|██████████| 1/1 [00:00<00:00, 24.11it/s, Step=0, MSE=14]
Epoch 8: 100%|██████████| 1/1 [00:00<00:00, 16.64it/s, Step=0, MSE=14.2]
Epoch 9: 100%|██████████| 1/1 [00:00<00:00, 28.94it/s, Step=0, MSE=14.8]
Epoch 10: 100%|██████████| 1/1 [00:00<00:00, 28.79it/s, Step=0, MSE=13.7]
Epoch 11: 100%|██████████| 1/1 [00:00<00:00, 25.83it/s, Step=0, MSE=13.9]
Epoch 12: 100%|██████████| 1/1 [00:00<00:00, 27.60it/s, Step=0, MSE=14.5]
Epoch 13: 100%|██████████| 1/1 [00:00<00:00, 29.53it/s, Step=0, MSE=14.4]
Epoch 14: 100%|██████████| 1/1 [00:00<00:00, 28.1