In [1]:
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from qtaim_embed.models.utils import load_graph_level_model_from_config
from qtaim_embed.data.dataloader import DataLoaderMoleculeGraphTask
from qtaim_embed.core.dataset import HeteroGraphGraphLabelDataset
from qtaim_embed.models.utils import get_test_train_preds_as_df, test_and_predict_libe

In [2]:
# "High" QTAIM feature selection, grouped by graph node type
# (atom / bond / global). Filled in one node type at a time so each
# key sits on its own line and is easy to diff.
qtaim_keys_high = {}

# Per-atom descriptors.
qtaim_keys_high["atom"] = [
    "extra_feat_atom_Lagrangian_K",
    "extra_feat_atom_Hamiltonian_K",
    "extra_feat_atom_e_density",
    "extra_feat_atom_lap_e_density",
    "extra_feat_atom_e_loc_func",
    "extra_feat_atom_ave_loc_ion_E",
    "extra_feat_atom_delta_g_promolecular",
    "extra_feat_atom_delta_g_hirsh",
    "extra_feat_atom_esp_nuc",
    "extra_feat_atom_esp_e",
    "extra_feat_atom_esp_total",
    "extra_feat_atom_grad_norm",
    "extra_feat_atom_lap_norm",
    "extra_feat_atom_eig_hess",
    "extra_feat_atom_det_hessian",
    "extra_feat_atom_ellip_e_dens",
    "extra_feat_atom_eta",
    "extra_feat_atom_energy_density",
]

# Per-bond descriptors: the bond length, the bond counterparts of the
# atom descriptors above, and one bond-only key ("lol").
qtaim_keys_high["bond"] = [
    "bond_length",
    "extra_feat_bond_Lagrangian_K",
    "extra_feat_bond_Hamiltonian_K",
    "extra_feat_bond_e_density",
    "extra_feat_bond_lap_e_density",
    "extra_feat_bond_e_loc_func",
    "extra_feat_bond_ave_loc_ion_E",
    "extra_feat_bond_delta_g_promolecular",
    "extra_feat_bond_delta_g_hirsh",
    "extra_feat_bond_esp_nuc",
    "extra_feat_bond_esp_e",
    "extra_feat_bond_esp_total",
    "extra_feat_bond_grad_norm",
    "extra_feat_bond_lap_norm",
    "extra_feat_bond_eig_hess",
    "extra_feat_bond_det_hessian",
    "extra_feat_bond_ellip_e_dens",
    "extra_feat_bond_eta",
    "extra_feat_bond_energy_density",
    "extra_feat_bond_lol",
]

# Molecule-level keys.
qtaim_keys_high["global"] = ["corrected_E", "spin", "charge"]


# Locations of the test and train split pickles, relative to this notebook.
libe_loc = (
    "../../../../data/splits_1205/test_libe_qtaim_1205_labelled_corrected.pkl"
)
libe_loc_train = "../../../../data/splits_1205/train_libe_qtaim_1205_labelled_corrected.pkl"

# Baseline feature selection: no extra atom keys, only the bond length on
# bonds, and the same global keys as the QTAIM feature set.
base_dict = {"atom": [], "bond": ["bond_length"], "global": ["corrected_E", "spin", "charge"]}

def _load_libe_dataset(file, extra_keys):
    """Build one dataset with the settings shared by every split and feature set.

    Args:
        file: Path to the pickled split to load.
        extra_keys: Feature-selection dict with "atom", "bond" and "global"
            lists (e.g. ``base_dict`` or ``qtaim_keys_high``).

    Returns:
        The constructed ``HeteroGraphGraphLabelDataset``.
    """
    # Literals are created inside the function so each dataset receives its
    # own list/dict objects, exactly as with the original per-call literals.
    #
    # NOTE(review): every dataset is built with standard_scale_features /
    # standard_scale_targets enabled on its own file. If the scaler is fit
    # per dataset, the test split is scaled with test statistics rather than
    # the train statistics -- confirm this is intended.
    return HeteroGraphGraphLabelDataset(
        file=file,
        allowed_ring_size=[3, 4, 5, 6, 7],
        allowed_charges=[-1, 0, 1],
        allowed_spins=[1, 2, 3],
        self_loop=True,
        extra_keys=extra_keys,
        target_list=["corrected_E"],
        extra_dataset_info={},
        debug=False,
        log_scale_features=True,
        log_scale_targets=False,
        standard_scale_features=True,
        standard_scale_targets=True,
    )


# Each feature set is paired with its name explicitly, so adding an entry
# cannot silently reuse a stale name and overwrite an earlier dataset.
feature_sets = [
    ("base", base_dict),
    ("qtaim_full", qtaim_keys_high),
]
# Kept under its original name in case other cells refer to it.
keys_list = [feat_dict for _, feat_dict in feature_sets]

# Maps "<feature set>_test" / "<feature set>_train" to the loaded dataset.
dataset_dict = {}
for key, feat_dict in feature_sets:
    dataset_dict[key + "_test"] = _load_libe_dataset(libe_loc, feat_dict)
    dataset_dict[key + "_train"] = _load_libe_dataset(libe_loc_train, feat_dict)


... > creating MoleculeWrapper objects


100%|██████████| 1716/1716 [00:00<00:00, 13984.60it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'C', 'F', 'O', 'N', 'S', 'H', 'P', 'Li'}
selected atomic keys []
selected bond keys ['bond_length']
selected global keys ['corrected_E', 'spin', 'charge']
... > Building graphs and featurizing


100%|██████████| 1716/1716 [00:02<00:00, 640.22it/s]


included in labels
{'global': ['corrected_E']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'ring_size_7', 'chemical_symbol_C', 'chemical_symbol_F', 'chemical_symbol_O', 'chemical_symbol_N', 'chemical_symbol_S', 'chemical_symbol_H', 'chemical_symbol_P', 'chemical_symbol_Li'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'ring size_7', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight', 'charge one hot', 'charge one hot', 'charge one hot', 'spin one hot', 'spin one hot', 'spin one hot']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  dict_keys(['global'])
... > parsing labels and features in graphs


100%|██████████| 1716/1716 [00:00<00:00, 30714.00it/s]


original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
... > Log scaling features
... > Log scaling features complete
... > Scaling features
mean [0.82051049 0.12242161 0.07909777 0.00656382 0.05048244 0.01205826
 0.00859197 0.00140127 0.20550669 0.04406612 0.16778314 0.00151189
 0.00413005 0.21188614 0.01191076 0.0463524 ]
std [0.5301781  0.31105153 0.22038588 0.06713131 0.1801202  0.09062421
 0.07669211 0.0311339  0.31656498 0.16912269 0.29689599 0.03233694
 0.05334485 0.3193314  0.09007798 0.17314875]
mean [0.         0.09420027 0.00786088 0.05984692 0.01498344 0.01068385
 0.00182407 0.86012772]
std [0.         0.23753096 0.07339587 0.19468197 0.10080292 0.08538931
 0.03551088 0.34877978]
Standard deviation for feature 0 is 0.0, smaller than 0.001. You may want to exclude this feature.
mean [2.41164055 2.24761734 4.63867565 0.24397488 0.24478275 0.20438955
 0.29971749 0.29244671 0.10098298]
std [0.37903647 0.42249919 0.41

100%|██████████| 15441/15441 [00:16<00:00, 919.30it/s] 


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'C', 'F', 'O', 'N', 'S', 'H', 'P', 'Li'}
selected atomic keys []
selected bond keys ['bond_length']
selected global keys ['corrected_E', 'spin', 'charge']
... > Building graphs and featurizing


  2%|▏         | 364/15441 [00:18<09:37, 26.12it/s] 