In [1]:
import pandas as pd

import torch
import dgl
from tqdm import tqdm

from qtaim_embed.core.dataset import HeteroGraphNodeLabelDataset
from qtaim_embed.data.dataloader import DataLoaderMoleculeNodeTask
from qtaim_embed.models.node_level.base_gcn import GCNNodePred

In [15]:
# df = pd.read_pickle(
#    "/home/santiagovargas/dev/qtaim_embed/data/xyz_qm8/molecules_qtaim_labelled.pkl"
# )

dataset = HeteroGraphNodeLabelDataset(
    # file="/home/santiagovargas/dev/qtaim_embed/data/xyz_qm8/molecules_full.pkl",
    file="/home/santiagovargas/dev/qtaim_embed/data/splits_1205/train_qm8_qtaim_1205_labelled.pkl",
    allowed_ring_size=[3, 4, 5, 6, 7],
    allowed_charges=None,
    self_loop=True,
    extra_keys={
        "atom": ["extra_feat_atom_esp_total"],
        "bond": [
            "bond_length",
            "extra_feat_bond_esp_total",
           "extra_feat_bond_ellip_e_dens",
           "extra_feat_bond_eta",
        ],
        #"global": ["E1-CC2", "E2-CC2"],
    },
    target_dict={
        "atom": ["extra_feat_atom_esp_total"],
        "bond": [
           "extra_feat_bond_esp_total",
           "extra_feat_bond_ellip_e_dens",
           "extra_feat_bond_eta",
        ],
        #"global": ["E1-CC2", "E2-CC2"],
    },
    extra_dataset_info={},
    debug=False,
    log_scale_features=False,
    log_scale_targets=True,
    standard_scale_features=True,
    standard_scale_targets=True,
    
)

bond = [
    "extra_feat_bond_esp_total",
    "extra_feat_bond_ellip_e_dens",
    "extra_feat_bond_eta",
]

... > creating MoleculeWrapper objects


100%|██████████| 19607/19607 [00:02<00:00, 8688.51it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'N', 'C', 'O', 'F', 'H'}
selected atomic keys ['extra_feat_atom_esp_total']
selected bond keys ['bond_length', 'extra_feat_bond_esp_total', 'extra_feat_bond_ellip_e_dens', 'extra_feat_bond_eta']
selected global keys []
... > Building graphs and featurizing


100%|██████████| 19607/19607 [00:33<00:00, 582.16it/s]


included in labels
{'atom': ['extra_feat_atom_esp_total'], 'bond': ['extra_feat_bond_esp_total', 'extra_feat_bond_ellip_e_dens', 'extra_feat_bond_eta']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'ring_size_7', 'chemical_symbol_N', 'chemical_symbol_C', 'chemical_symbol_O', 'chemical_symbol_F', 'chemical_symbol_H'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'ring size_7', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  dict_keys(['atom', 'bond'])
... > parsing labels and features in graphs


100%|██████████| 19607/19607 [00:01<00:00, 14412.44it/s]


original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['atom', 'bond'])
... > Scaling features
mean [1.40693510e+00 3.21409331e-01 5.23967372e-02 9.02539660e-03
 2.37487086e-02 1.49895105e-02 3.90107557e-03 7.32045862e-04
 5.92925458e-02 3.43693948e-01 7.86172890e-02 1.39120404e-03
 5.17005013e-01]
std [1.5807023  0.64451751 0.22282576 0.0945724  0.15226525 0.1215106
 0.06233664 0.02704644 0.23617142 0.47494044 0.2691405  0.03727289
 0.49971075]
mean [0.00000000e+00 7.05700608e-02 1.21337385e-02 3.20213561e-02
 2.01889451e-02 5.44936000e-03 9.80375514e-04 1.55522426e+00]
std [0.         0.25610531 0.10948293 0.17605678 0.14064619 0.07361837
 0.0312956  0.7714357 ]
Standard deviation for feature 0 is 0.0, smaller than 0.001. You may want to exclude this feature.
mean [ 16.09394604  12.01734075 108.859255  ]
std [2.90372676 1.74248899 8.05426511]
... > Scaling features complete
... > feature mean(s): 
 {'atom': tensor([1.4069e+00, 3.2141

In [16]:
# split dataset into train and val
from torch.utils.data import random_split

generator = torch.Generator().manual_seed(42)
train_size = int(0.5 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=generator
)


train_dataloader = DataLoaderMoleculeNodeTask(
    train_dataset, batch_size=100, shuffle=True, num_workers=4, 
)

batch_graph, batch_label = next(iter(train_dataloader))
val_dataloader = DataLoaderMoleculeNodeTask(
    val_dataset, batch_size=100, shuffle=True, num_workers=4
)

In [17]:
print(batch_label["atom"].shape)
print(batch_label["bond"].shape)

torch.Size([1588, 1])
torch.Size([1194, 3])


In [18]:
len_dict = dataset.feature_size()
atom_input_size = len_dict["atom"]
bond_input_size = len_dict["bond"]
global_input_size = len_dict["global"]

model = GCNNodePred(
    atom_input_size=atom_input_size,
    bond_input_size=bond_input_size,
    global_input_size=global_input_size,
    target_dict={
        "atom": ["extra_feat_atom_esp_total"],
        "bond": [
           "extra_feat_bond_esp_total",
           "extra_feat_bond_ellip_e_dens",
           "extra_feat_bond_eta",
        ],
    },
    activation="ReLU",
    conv_fn="ResidualBlock",
    resid_n_graph_convs=3,
    hidden_size_gat=64,
    num_heads_gat=3,
    dropout_feat_gat=0.2,
    dropout_attn_gat=0.2,
    dropout=0.2,
    bias=True,
    batch_norm=True,
    n_conv_layers=6,
    lr_plateau_patience=10,
    lr=0.02,
    embedding_size=64,
    weight_decay=0.0001,
)

triggered output_layer args
number of output dims 4


In [19]:
batch_graph, batch_label = next(iter(train_dataloader))
feats = model.forward(batch_graph, batch_graph.ndata["feat"])

In [20]:
"""
optimizer, lr_scheduler = model.configure_optimizers()
optimizer = optimizer[0]
lr_scheduler = lr_scheduler[0]
"""

'\noptimizer, lr_scheduler = model.configure_optimizers()\noptimizer = optimizer[0]\nlr_scheduler = lr_scheduler[0]\n'

In [21]:
import pytorch_lightning as pl

torch.set_float32_matmul_precision("high")

trainer_transfer = pl.Trainer(
    max_epochs=100,
    accelerator="gpu",
    devices=1,
    enable_progress_bar=True,
    gradient_clip_val=30.0,
    default_root_dir="./test/",
    precision="bf16",
    log_every_n_steps=10,
)

trainer_transfer.fit(model, train_dataloader, val_dataloader)

  rank_zero_warn(
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

   | Name            | Type               | Params
--------------------------------------------------------
0  | embedding       | UnifySize          | 1.5 K 
1  | conv_layers     | ModuleList         | 193 K 
2  | loss            | MultioutputWrapper | 0     
3  | train_r2        | MultioutputWrapper | 0     
4  | train_torch_l1  | MultioutputWrapper | 0     
5  | train_torch_mse | MultioutputWrapper | 0     
6  | val_r2          | MultioutputWrapper | 0     
7  | val_torch_l1    | MultioutputWrapper | 0     
8  | val_torch_mse   | MultioutputWrapper | 0     
9  | test_r2         | MultioutputWrapper | 0     
10 | test_torch_l1   | MultioutputWrapper | 0     
11 | test_torch_mse  | MultioutputWrapper | 0     
------------------

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 00012: reducing learning rate of group 0 to 1.0000e-02.


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 00023: reducing learning rate of group 0 to 5.0000e-03.


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 00034: reducing learning rate of group 0 to 2.5000e-03.


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 00045: reducing learning rate of group 0 to 1.2500e-03.


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 00056: reducing learning rate of group 0 to 6.2500e-04.


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 00067: reducing learning rate of group 0 to 3.1250e-04.


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 00078: reducing learning rate of group 0 to 1.5625e-04.


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 00089: reducing learning rate of group 0 to 7.8125e-05.


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 00100: reducing learning rate of group 0 to 3.9063e-05.


In [22]:
val_dataloader = DataLoaderMoleculeNodeTask(
    val_dataset, batch_size=100, shuffle=True, num_workers=4
)


#batch_graph, batch_label = next(iter(train_dataloader))
#feats = model.forward(batch_graph, batch_graph.ndata["feat"])

batch_graph, batch_label = next(iter(val_dataloader))
scaler_list = dataset.label_scalers
model.cpu()
#model.evaluate_manually(batch_graph, batch_label, scaler_list)

GCNNodePred(
  (embedding): UnifySize(
    (linears): ModuleDict(
      (atom): Linear(in_features=13, out_features=64, bias=False)
      (bond): Linear(in_features=8, out_features=64, bias=False)
      (global): Linear(in_features=3, out_features=64, bias=False)
    )
  )
  (conv_layers): ModuleList(
    (0): ResidualBlock(
      (layers): ModuleList(
        (0-2): 3 x HeteroGraphConv(
          (mods): ModuleDict(
            (a2b): GraphConvDropoutBatch(
              (graph_conv): GraphConv(
                in=64, out=64, normalization=both
                (_activation): ReLU()
              )
              (dropout): Dropout(p=0.2, inplace=False)
              (batch_norm): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
            (b2a): GraphConvDropoutBatch(
              (graph_conv): GraphConv(
                in=64, out=64, normalization=both
                (_activation): ReLU()
              )
              (dropout): Dropout

In [23]:
from copy import deepcopy
import torchmetrics
from torchmetrics.wrappers import MultioutputWrapper


def evaluate_manually(model, batch_graph, batched_label, scaler_list):
    """
    Evaluate a set of data manually
    Takes
        feats: dict, dictionary of batched features
        scaler_list: list, list of scalers
    """
    

    #batch_graph, batch_label = batch
    preds = model.forward(batch_graph, batch_graph.ndata["feat"])
    for k, v in preds.items():
        preds[k] = v.cpu().detach()
        
    preds_unscaled = deepcopy(preds)
    labels_unscaled = deepcopy(batched_label)
    
    for scaler in scaler_list:
        labels_unscaled = scaler.inverse_feats(labels_unscaled)
        preds_unscaled = scaler.inverse_feats(preds_unscaled)

    r2_eval = MultioutputWrapper(
        torchmetrics.R2Score(), num_outputs=model.hparams.output_dims
    )
    mae_eval = MultioutputWrapper(
        torchmetrics.MeanAbsoluteError(), num_outputs=model.hparams.output_dims
    )
    mse_eval = MultioutputWrapper(
        torchmetrics.MeanSquaredError(squared=False),
        num_outputs=model.hparams.output_dims,
    )

    r2_dict = {}
    mae_dict = {}
    #max_nodes = -1
    for target_type, target_list in model.hparams.target_dict.items():
        if target_list != [None] and len(target_list) > 0:
                        
            r2_eval = MultioutputWrapper(
                torchmetrics.R2Score(), num_outputs=len(model.hparams.target_dict[target_type])
            )
            mae_eval = MultioutputWrapper(
                torchmetrics.MeanAbsoluteError(), num_outputs=len(model.hparams.target_dict[target_type])
            )
            logits = preds[target_type]
            labels = batched_label[target_type]

            r2_eval.update(logits, labels)
            mae_eval.update(logits, labels)
            
            r2_val = r2_eval.compute()
            mae_val = mae_eval.compute()

            r2_dict[target_type] = r2_val
            mae_dict[target_type] = mae_val
    
    return r2_dict, mae_dict
            

    
r2, mae = evaluate_manually(model, batch_graph, batch_label, scaler_list)

In [24]:
print(r2)

{'atom': tensor([0.7550]), 'bond': tensor([0.7377, 0.5981, 0.5113])}


In [25]:
print(mae)

{'atom': tensor([0.3381]), 'bond': tensor([0.3045, 0.3326, 0.4291])}


In [26]:
preds = model.forward(batch_graph, batch_graph.ndata["feat"])
for k, v in preds.items():
    preds[k] = v.cpu().detach()
print(preds["bond"].shape)
print(batch_label["bond"].shape)

torch.Size([1204, 3])
torch.Size([1204, 3])


In [27]:
batch_label.keys()
print(preds["atom"].shape)
print(batch_label["atom"].shape)

torch.Size([1628, 1])
torch.Size([1628, 1])


In [28]:
df = pd.read_pickle(
    "/home/santiagovargas/dev/qtaim_embed/data/xyz_qm8/molecules_qtaim.pkl"
)

FileNotFoundError: [Errno 2] No such file or directory: '/home/santiagovargas/dev/qtaim_embed/data/xyz_qm8/molecules_qtaim.pkl'

In [None]:
df.names

In [None]:
root_xyz = "../../../data/xyz_qm8/xyz/"
csv = "../../../data/xyz_qm8/qm8.sdf.csv"
# open csv file
df_labels = pd.read_csv(csv)
# print(df.head())
data_added = [[] for i in range(len(df))]
col_names_labels = list(df_labels.columns)
for i in tqdm(range(len(df))):
    name = df.names[i]
    # get second line of xyz file
    with open(root_xyz + name, "r") as f:
        line = f.readlines()[1]
    id = float(line.split()[1])
    row_hit = df_labels[df_labels["gdb9_index"] == id]

    # convert to list
    row_hit = row_hit.values.tolist()[0]
    # append to data_added
    data_added[i] = row_hit

# add to df with column names with prefix "extra_feat_global_" and replace . with _ and - with _
for i in range(len(col_names_labels)):
    df[
        "extra_feat_global_" + col_names_labels[i].replace(".", "_").replace("-", "_")
    ] = [row[i] for row in data_added]

In [None]:
df.columns

In [None]:
df.to_pickle(
    "/home/santiagovargas/dev/qtaim_embed/data/xyz_qm8/molecules_qtaim_labelled.pkl"
)

In [4]:
import pandas as pd

df = pd.read_pickle(
    "/home/santiagovargas/dev/qtaim_embed/data/xyz_qm8/molecules_qtaim_labelled.pkl"
)
df = df.head(100)
df.to_pickle("/home/santiagovargas/dev/qtaim_embed/tests/data/labelled_data.pkl")