In [1]:
import os
import numpy as np

import torch
from torchmetrics.wrappers import MultioutputWrapper
import torchmetrics

from qtaim_embed.models.utils import load_graph_level_model_from_config
from qtaim_embed.data.dataloader import DataLoaderMoleculeGraphTask
from qtaim_embed.core.dataset import HeteroGraphGraphLabelClassifierDataset


In [59]:
tox21_test_loc = "../../../../data/splits_1205/test_tox21_qtaim_1205_labelled.pkl"
tox21_train_loc = "../../../../data/splits_1205/train_tox21_qtaim_1205_labelled.pkl"

bl_keys = {
    "atom": [],
    "bond": ["bond_length"],
    "global": [
        "NR-AR",
        "NR-AR-LBD",
        "NR-AhR",
        "NR-Aromatase",
        "NR-ER",
        "NR-ER-LBD",
        "NR-PPAR-gamma",
        "SR-ARE",
        "SR-ATAD5",
        "SR-HSE",
        "SR-MMP",
        "SR-p53",
    ],
}

qtaim_dict = {
    "atom": [
        "extra_feat_atom_Hamiltonian_K",
        "extra_feat_atom_e_density",
        "extra_feat_atom_lap_e_density",
        "extra_feat_atom_esp_total",
        "extra_feat_atom_det_hessian",
        "extra_feat_atom_eta",
        "extra_feat_atom_energy_density"
    ],
    "bond": [
        "extra_feat_bond_Lagrangian_K",
        "bond_length",
        "extra_feat_bond_Hamiltonian_K",
        "extra_feat_bond_e_density",
        "extra_feat_bond_lap_e_density",
        "extra_feat_bond_e_loc_func",
        "extra_feat_bond_esp_e",
        "extra_feat_bond_esp_total",
        "extra_feat_bond_grad_norm",
        "extra_feat_bond_lap_norm",
        "extra_feat_bond_ellip_e_dens",
        "extra_feat_bond_eta",
    ],
    "global": [
        "NR-AR",
        "NR-AR-LBD",
        "NR-AhR",
        "NR-Aromatase",
        "NR-ER",
        "NR-ER-LBD",
        "NR-PPAR-gamma",
        "SR-ARE",
        "SR-ATAD5",
        "SR-HSE",
        "SR-MMP",
        "SR-p53",
    ],
}



dataset_bl_train = HeteroGraphGraphLabelClassifierDataset(
    file=tox21_train_loc,
    allowed_ring_size=[3, 4, 5, 6],
    allowed_charges=None,
    allowed_spins=None,
    self_loop=True,
    extra_keys=bl_keys,
    target_list=[
        "NR-AR",
        "NR-AR-LBD",
        "NR-AhR",
        "NR-Aromatase",
        "NR-ER",
        "NR-ER-LBD",
        "NR-PPAR-gamma",
        "SR-ARE",
        "SR-ATAD5",
        "SR-HSE",
        "SR-MMP",
        "SR-p53",
    ],
    extra_dataset_info={},
    debug=False,
    impute=False,
    element_set={
        "Fe",
        "H",
        "Cu",
        "Cr",
        "Ge",
        "Na",
        "P",
        "N",
        "C",
        "Br",
        "S",
        "V",
        "F",
        "Se",
        "B",
        "Cl",
        "Zn",
        "Ti",
        "O",
        "Si",
        "Ni",
        "Ca",
        "Al",
        "As",
    },
    log_scale_features=False,
    standard_scale_features=True,
)

dataset_bl_test = HeteroGraphGraphLabelClassifierDataset(
    file=tox21_test_loc,
    allowed_ring_size=[3, 4, 5, 6],
    allowed_charges=None,
    allowed_spins=None,
    self_loop=True,
    extra_keys=bl_keys,
    target_list=[
        "NR-AR",
        "NR-AR-LBD",
        "NR-AhR",
        "NR-Aromatase",
        "NR-ER",
        "NR-ER-LBD",
        "NR-PPAR-gamma",
        "SR-ARE",
        "SR-ATAD5",
        "SR-HSE",
        "SR-MMP",
        "SR-p53",
    ],
    extra_dataset_info={},
    debug=False,
    impute=False,
    element_set={
        "Fe",
        "H",
        "Cu",
        "Cr",
        "Ge",
        "Na",
        "P",
        "N",
        "C",
        "Br",
        "S",
        "V",
        "F",
        "Se",
        "B",
        "Cl",
        "Zn",
        "Ti",
        "O",
        "Si",
        "Ni",
        "Ca",
        "Al",
        "As",
    },
    log_scale_features=False,
    standard_scale_features=True,
)

dataset_train_qtaim = HeteroGraphGraphLabelClassifierDataset(
    file=tox21_train_loc,
    allowed_ring_size=[3, 4, 5, 6],
    allowed_charges=None,
    allowed_spins=None,
    self_loop=True,
    extra_keys=qtaim_dict,
    target_list=[
        "NR-AR",
        "NR-AR-LBD",
        "NR-AhR",
        "NR-Aromatase",
        "NR-ER",
        "NR-ER-LBD",
        "NR-PPAR-gamma",
        "SR-ARE",
        "SR-ATAD5",
        "SR-HSE",
        "SR-MMP",
        "SR-p53",
    ],
    extra_dataset_info={},
    debug=False,
    impute=False,
    element_set={
        "Fe",
        "H",
        "Cu",
        "Cr",
        "Ge",
        "Na",
        "P",
        "N",
        "C",
        "Br",
        "S",
        "V",
        "F",
        "Se",
        "B",
        "Cl",
        "Zn",
        "Ti",
        "O",
        "Si",
        "Ni",
        "Ca",
        "Al",
        "As",
    },
    log_scale_features=False,
    standard_scale_features=True,
)

dataset_test_qtaim = HeteroGraphGraphLabelClassifierDataset(
    file=tox21_test_loc,
    allowed_ring_size=[3, 4, 5, 6],
    allowed_charges=None,
    allowed_spins=None,
    self_loop=True,
    extra_keys=qtaim_dict,
    target_list=[
        "NR-AR",
        "NR-AR-LBD",
        "NR-AhR",
        "NR-Aromatase",
        "NR-ER",
        "NR-ER-LBD",
        "NR-PPAR-gamma",
        "SR-ARE",
        "SR-ATAD5",
        "SR-HSE",
        "SR-MMP",
        "SR-p53",
    ],
    extra_dataset_info={},
    debug=False,
    impute=False,
    element_set={
        "Fe",
        "H",
        "Cu",
        "Cr",
        "Ge",
        "Na",
        "P",
        "N",
        "C",
        "Br",
        "S",
        "V",
        "F",
        "Se",
        "B",
        "Cl",
        "Zn",
        "Ti",
        "O",
        "Si",
        "Ni",
        "Ca",
        "Al",
        "As",
    },
    log_scale_features=True,
    standard_scale_features=True,
)


# make a dataset for each task
dict_datasets = {}
dict_datasets["qtaim"] = {}
dict_datasets["bl"] = {}
dict_datasets["qtaim"]["test"] = dataset_test_qtaim
dict_datasets["bl"]["test"] = dataset_bl_test
dict_datasets["qtaim"]["single_list"] = []
dict_datasets["bl"]["single_list"] = []

for task in qtaim_dict["global"]:
    qtaim_dict_temp = qtaim_dict.copy()
    qtaim_dict_temp["global"] = [task]
    base_dict_bl_temp = bl_keys.copy()
    base_dict_bl_temp["global"] = [task]

    dict_datasets["qtaim"]["single_list"].append(
        HeteroGraphGraphLabelClassifierDataset(
            file=tox21_test_loc,
            allowed_ring_size=[3, 4, 5, 6],
            allowed_charges=None,
            allowed_spins=None,
            self_loop=True,
            extra_keys=qtaim_dict_temp,
            target_list=[task],
            extra_dataset_info={},
            debug=False,
            impute=False,
            element_set={
                "Fe",
                "H",
                "Cu",
                "Cr",
                "Ge",
                "Na",
                "P",
                "N",
                "C",
                "Br",
                "S",
                "V",
                "F",
                "Se",
                "B",
                "Cl",
                "Zn",
                "Ti",
                "O",
                "Si",
                "Ni",
                "Ca",
                "Al",
                "As",
            },
            log_scale_features=False,
            standard_scale_features=True,
        )
    )

    dict_datasets["bl"]["single_list"].append(
        HeteroGraphGraphLabelClassifierDataset(
            file=tox21_test_loc,
            allowed_ring_size=[3, 4, 5, 6],
            allowed_charges=None,
            allowed_spins=None,
            self_loop=True,
            extra_keys=base_dict_bl_temp,
            target_list=[task],
            extra_dataset_info={},
            debug=False,
            impute=False,
            element_set={
                "Fe",
                "H",
                "Cu",
                "Cr",
                "Ge",
                "Na",
                "P",
                "N",
                "C",
                "Br",
                "S",
                "V",
                "F",
                "Se",
                "B",
                "Cl",
                "Zn",
                "Ti",
                "O",
                "Si",
                "Ni",
                "Ca",
                "Al",
                "As",
            },
            log_scale_features=False,
            standard_scale_features=True,
        )
    )


... > creating MoleculeWrapper objects



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 6692/6692 [00:01<00:00, 6598.59it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'Ca', 'Na', 'Al', 'V', 'Ti', 'B', 'Cr', 'Cl', 'Cu', 'F', 'Fe', 'As', 'N', 'H', 'P', 'Se', 'Si', 'O', 'C', 'Ni', 'Zn', 'Ge', 'S'}
selected atomic keys []
selected bond keys ['bond_length']
selected global keys ['NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD', 'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53']
... > Building graphs and featurizing



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 6692/6692 [00:31<00:00, 211.48it/s]


included in labels
{'global': ['NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD', 'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_Ca', 'chemical_symbol_Na', 'chemical_symbol_Al', 'chemical_symbol_V', 'chemical_symbol_Ti', 'chemical_symbol_B', 'chemical_symbol_Cr', 'chemical_symbol_Cl', 'chemical_symbol_Cu', 'chemical_symbol_F', 'chemical_symbol_Fe', 'chemical_symbol_As', 'chemical_symbol_N', 'chemical_symbol_H', 'chemical_symbol_P', 'chemical_symbol_Se', 'chemical_symbol_Si', 'chemical_symbol_O', 'chemical_symbol_C', 'chemical_symbol_Ni', 'chemical_symbol_Zn', 'chemical_symbol_Ge', 'chemical_symbol_S'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight'


[A
[A
6692it [00:00, 28692.78it/s]


... > number of categories for each label: 
...... > label  NR-AR :  2 with distribution:  [2608, 51]
...... > label  NR-AR-LBD :  2 with distribution:  [2631, 28]
...... > label  NR-AhR :  2 with distribution:  [2527, 132]
...... > label  NR-Aromatase :  2 with distribution:  [2612, 47]
...... > label  NR-ER :  2 with distribution:  [2451, 208]
...... > label  NR-ER-LBD :  2 with distribution:  [2603, 56]
...... > label  NR-PPAR-gamma :  2 with distribution:  [2634, 25]
...... > label  SR-ARE :  2 with distribution:  [2496, 163]
...... > label  SR-ATAD5 :  2 with distribution:  [2648, 11]
...... > label  SR-HSE :  2 with distribution:  [2616, 43]
...... > label  SR-MMP :  2 with distribution:  [2539, 120]
...... > label  SR-p53 :  2 with distribution:  [2635, 24]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  4033
... > Scaling features
mean [1.56201885e+00 3.80881765e-01 5.58098750e-02 9


[A
100%|██████████| 744/744 [00:00<00:00, 5899.22it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'Ca', 'Na', 'Al', 'V', 'Ti', 'B', 'Cr', 'Cl', 'Cu', 'F', 'Fe', 'As', 'N', 'H', 'P', 'Se', 'Si', 'O', 'C', 'Ni', 'Zn', 'Ge', 'S'}
selected atomic keys []
selected bond keys ['bond_length']
selected global keys ['NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD', 'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53']
... > Building graphs and featurizing



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 744/744 [00:02<00:00, 307.73it/s]


included in labels
{'global': ['NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD', 'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_Ca', 'chemical_symbol_Na', 'chemical_symbol_Al', 'chemical_symbol_V', 'chemical_symbol_Ti', 'chemical_symbol_B', 'chemical_symbol_Cr', 'chemical_symbol_Cl', 'chemical_symbol_Cu', 'chemical_symbol_F', 'chemical_symbol_Fe', 'chemical_symbol_As', 'chemical_symbol_N', 'chemical_symbol_H', 'chemical_symbol_P', 'chemical_symbol_Se', 'chemical_symbol_Si', 'chemical_symbol_O', 'chemical_symbol_C', 'chemical_symbol_Ni', 'chemical_symbol_Zn', 'chemical_symbol_Ge', 'chemical_symbol_S'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight'


744it [00:00, 28204.14it/s]


... > number of categories for each label: 
...... > label  NR-AR :  2 with distribution:  [306, 8]
...... > label  NR-AR-LBD :  2 with distribution:  [308, 6]
...... > label  NR-AhR :  2 with distribution:  [294, 20]
...... > label  NR-Aromatase :  2 with distribution:  [306, 8]
...... > label  NR-ER :  2 with distribution:  [291, 23]
...... > label  NR-ER-LBD :  2 with distribution:  [306, 8]
...... > label  NR-PPAR-gamma :  2 with distribution:  [314, 0]
...... > label  SR-ARE :  2 with distribution:  [291, 23]
...... > label  SR-ATAD5 :  2 with distribution:  [314, 0]
...... > label  SR-HSE :  2 with distribution:  [309, 5]
...... > label  SR-MMP :  2 with distribution:  [296, 18]
...... > label  SR-p53 :  2 with distribution:  [311, 3]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  430
... > Scaling features
mean [1.57018674e+00 3.71538957e-01 5.53766903e-02 1.28783001e-03
 9.87336338


[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 6692/6692 [00:03<00:00, 2030.39it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'Ca', 'Na', 'Al', 'V', 'Ti', 'B', 'Cr', 'Cl', 'Cu', 'F', 'Fe', 'As', 'N', 'H', 'P', 'Se', 'Si', 'O', 'C', 'Ni', 'Zn', 'Ge', 'S'}
selected atomic keys ['extra_feat_atom_Hamiltonian_K', 'extra_feat_atom_e_density', 'extra_feat_atom_lap_e_density', 'extra_feat_atom_esp_total', 'extra_feat_atom_det_hessian', 'extra_feat_atom_eta', 'extra_feat_atom_energy_density']
selected bond keys ['extra_feat_bond_Lagrangian_K', 'bond_length', 'extra_feat_bond_Hamiltonian_K', 'extra_feat_bond_e_density', 'extra_feat_bond_lap_e_density', 'extra_feat_bond_e_loc_func', 'extra_feat_bond_esp_e', 'extra_feat_bond_esp_total', 'extra_feat_bond_grad_norm', 'extra_feat_bond_lap_norm', 'extra_feat_bond_ellip_e_dens', 'extra_feat_bond_eta']
selected global keys ['NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD', 'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53']
... > Building graphs and featuri


[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 6692/6692 [00:28<00:00, 234.62it/s]


included in labels
{'global': ['NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD', 'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_Ca', 'chemical_symbol_Na', 'chemical_symbol_Al', 'chemical_symbol_V', 'chemical_symbol_Ti', 'chemical_symbol_B', 'chemical_symbol_Cr', 'chemical_symbol_Cl', 'chemical_symbol_Cu', 'chemical_symbol_F', 'chemical_symbol_Fe', 'chemical_symbol_As', 'chemical_symbol_N', 'chemical_symbol_H', 'chemical_symbol_P', 'chemical_symbol_Se', 'chemical_symbol_Si', 'chemical_symbol_O', 'chemical_symbol_C', 'chemical_symbol_Ni', 'chemical_symbol_Zn', 'chemical_symbol_Ge', 'chemical_symbol_S', 'extra_feat_atom_Hamiltonian_K', 'extra_feat_atom_e_density', 'extra_feat_atom_lap_e_density', 'extra_feat_atom_esp_total', 'extra_feat_atom_det_hessian', 'extra_feat_atom_


[A
[A
6692it [00:00, 27262.01it/s]


... > number of categories for each label: 
...... > label  NR-AR :  2 with distribution:  [2608, 51]
...... > label  NR-AR-LBD :  2 with distribution:  [2631, 28]
...... > label  NR-AhR :  2 with distribution:  [2527, 132]
...... > label  NR-Aromatase :  2 with distribution:  [2612, 47]
...... > label  NR-ER :  2 with distribution:  [2451, 208]
...... > label  NR-ER-LBD :  2 with distribution:  [2603, 56]
...... > label  NR-PPAR-gamma :  2 with distribution:  [2634, 25]
...... > label  SR-ARE :  2 with distribution:  [2496, 163]
...... > label  SR-ATAD5 :  2 with distribution:  [2648, 11]
...... > label  SR-HSE :  2 with distribution:  [2616, 43]
...... > label  SR-MMP :  2 with distribution:  [2539, 120]
...... > label  SR-p53 :  2 with distribution:  [2635, 24]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  4033
... > Scaling features
mean [ 1.56201885e+00  3.80881765e-01  5.58098750e-0


[A
[A
[A
100%|██████████| 744/744 [00:00<00:00, 1948.04it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'Ca', 'Na', 'Al', 'V', 'Ti', 'B', 'Cr', 'Cl', 'Cu', 'F', 'Fe', 'As', 'N', 'H', 'P', 'Se', 'Si', 'O', 'C', 'Ni', 'Zn', 'Ge', 'S'}
selected atomic keys ['extra_feat_atom_Hamiltonian_K', 'extra_feat_atom_e_density', 'extra_feat_atom_lap_e_density', 'extra_feat_atom_esp_total', 'extra_feat_atom_det_hessian', 'extra_feat_atom_eta', 'extra_feat_atom_energy_density']
selected bond keys ['extra_feat_bond_Lagrangian_K', 'bond_length', 'extra_feat_bond_Hamiltonian_K', 'extra_feat_bond_e_density', 'extra_feat_bond_lap_e_density', 'extra_feat_bond_e_loc_func', 'extra_feat_bond_esp_e', 'extra_feat_bond_esp_total', 'extra_feat_bond_grad_norm', 'extra_feat_bond_lap_norm', 'extra_feat_bond_ellip_e_dens', 'extra_feat_bond_eta']
selected global keys ['NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD', 'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53']
... > Building graphs and featuri


[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 744/744 [00:02<00:00, 299.49it/s]


included in labels
{'global': ['NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD', 'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_Ca', 'chemical_symbol_Na', 'chemical_symbol_Al', 'chemical_symbol_V', 'chemical_symbol_Ti', 'chemical_symbol_B', 'chemical_symbol_Cr', 'chemical_symbol_Cl', 'chemical_symbol_Cu', 'chemical_symbol_F', 'chemical_symbol_Fe', 'chemical_symbol_As', 'chemical_symbol_N', 'chemical_symbol_H', 'chemical_symbol_P', 'chemical_symbol_Se', 'chemical_symbol_Si', 'chemical_symbol_O', 'chemical_symbol_C', 'chemical_symbol_Ni', 'chemical_symbol_Zn', 'chemical_symbol_Ge', 'chemical_symbol_S', 'extra_feat_atom_Hamiltonian_K', 'extra_feat_atom_e_density', 'extra_feat_atom_lap_e_density', 'extra_feat_atom_esp_total', 'extra_feat_atom_det_hessian', 'extra_feat_atom_


744it [00:00, 27943.50it/s]


... > number of categories for each label: 
...... > label  NR-AR :  2 with distribution:  [306, 8]
...... > label  NR-AR-LBD :  2 with distribution:  [308, 6]
...... > label  NR-AhR :  2 with distribution:  [294, 20]
...... > label  NR-Aromatase :  2 with distribution:  [306, 8]
...... > label  NR-ER :  2 with distribution:  [291, 23]
...... > label  NR-ER-LBD :  2 with distribution:  [306, 8]
...... > label  NR-PPAR-gamma :  2 with distribution:  [314, 0]
...... > label  SR-ARE :  2 with distribution:  [291, 23]
...... > label  SR-ATAD5 :  2 with distribution:  [314, 0]
...... > label  SR-HSE :  2 with distribution:  [309, 5]
...... > label  SR-MMP :  2 with distribution:  [296, 18]
...... > label  SR-p53 :  2 with distribution:  [311, 3]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  430
... > Log scaling features
... > Log scaling features complete
... > Scaling features
mean [ 7.51226


[A
[A
[A
100%|██████████| 744/744 [00:00<00:00, 2069.97it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'Ca', 'Na', 'Al', 'V', 'Ti', 'B', 'Cr', 'Cl', 'Cu', 'F', 'Fe', 'As', 'N', 'H', 'P', 'Se', 'Si', 'O', 'C', 'Ni', 'Zn', 'Ge', 'S'}
selected atomic keys ['extra_feat_atom_Hamiltonian_K', 'extra_feat_atom_e_density', 'extra_feat_atom_lap_e_density', 'extra_feat_atom_esp_total', 'extra_feat_atom_det_hessian', 'extra_feat_atom_eta', 'extra_feat_atom_energy_density']
selected bond keys ['extra_feat_bond_Lagrangian_K', 'bond_length', 'extra_feat_bond_Hamiltonian_K', 'extra_feat_bond_e_density', 'extra_feat_bond_lap_e_density', 'extra_feat_bond_e_loc_func', 'extra_feat_bond_esp_e', 'extra_feat_bond_esp_total', 'extra_feat_bond_grad_norm', 'extra_feat_bond_lap_norm', 'extra_feat_bond_ellip_e_dens', 'extra_feat_bond_eta']
selected global keys ['NR-AR']
... > Building graphs and featurizing



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 744/744 [00:02<00:00, 304.68it/s]


included in labels
{'global': ['NR-AR']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_Ca', 'chemical_symbol_Na', 'chemical_symbol_Al', 'chemical_symbol_V', 'chemical_symbol_Ti', 'chemical_symbol_B', 'chemical_symbol_Cr', 'chemical_symbol_Cl', 'chemical_symbol_Cu', 'chemical_symbol_F', 'chemical_symbol_Fe', 'chemical_symbol_As', 'chemical_symbol_N', 'chemical_symbol_H', 'chemical_symbol_P', 'chemical_symbol_Se', 'chemical_symbol_Si', 'chemical_symbol_O', 'chemical_symbol_C', 'chemical_symbol_Ni', 'chemical_symbol_Zn', 'chemical_symbol_Ge', 'chemical_symbol_S', 'extra_feat_atom_Hamiltonian_K', 'extra_feat_atom_e_density', 'extra_feat_atom_lap_e_density', 'extra_feat_atom_esp_total', 'extra_feat_atom_det_hessian', 'extra_feat_atom_eta', 'extra_feat_atom_energy_density'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ri


744it [00:00, 30730.53it/s]


... > number of categories for each label: 
...... > label  NR-AR :  2 with distribution:  [667, 25]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  52
... > Scaling features
mean [ 1.65871103e+00  3.99349084e-01  5.44439087e-02  9.64320154e-04
  1.16522019e-02  2.72018644e-02  1.46255223e-02  7.63420122e-04
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  4.01800064e-05  0.00000000e+00  8.47798136e-03
  0.00000000e+00  5.22340084e-03  0.00000000e+00  1.20540019e-04
  3.88540662e-02  4.90035358e-01  9.64320154e-04  1.20540019e-04
  3.21440051e-04  7.45740919e-02  3.73955320e-01  0.00000000e+00
  0.00000000e+00  0.00000000e+00  6.54934105e-03  1.12563992e+06
 -1.12563992e+06 -4.49864208e+06  4.56780596e+08 -1.62552374e+24
 -1.08668119e+00 -1.12563992e+06]
std [1.67965557e+00 6.99619606e-01 2.26891537e-01 3.10385283e-02
 1.07314622e-01 1.62671211e-01 1.200483


[A
100%|██████████| 744/744 [00:00<00:00, 6288.31it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'Ca', 'Na', 'Al', 'V', 'Ti', 'B', 'Cr', 'Cl', 'Cu', 'F', 'Fe', 'As', 'N', 'H', 'P', 'Se', 'Si', 'O', 'C', 'Ni', 'Zn', 'Ge', 'S'}
selected atomic keys []
selected bond keys ['bond_length']
selected global keys ['NR-AR']
... > Building graphs and featurizing



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 744/744 [00:02<00:00, 312.23it/s]


included in labels
{'global': ['NR-AR']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_Ca', 'chemical_symbol_Na', 'chemical_symbol_Al', 'chemical_symbol_V', 'chemical_symbol_Ti', 'chemical_symbol_B', 'chemical_symbol_Cr', 'chemical_symbol_Cl', 'chemical_symbol_Cu', 'chemical_symbol_F', 'chemical_symbol_Fe', 'chemical_symbol_As', 'chemical_symbol_N', 'chemical_symbol_H', 'chemical_symbol_P', 'chemical_symbol_Se', 'chemical_symbol_Si', 'chemical_symbol_O', 'chemical_symbol_C', 'chemical_symbol_Ni', 'chemical_symbol_Zn', 'chemical_symbol_Ge', 'chemical_symbol_S'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  di


744it [00:00, 30320.86it/s]


... > number of categories for each label: 
...... > label  NR-AR :  2 with distribution:  [667, 25]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  52
... > Scaling features
mean [1.65871103e+00 3.99349084e-01 5.44439087e-02 9.64320154e-04
 1.16522019e-02 2.72018644e-02 1.46255223e-02 7.63420122e-04
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 4.01800064e-05 0.00000000e+00 8.47798136e-03
 0.00000000e+00 5.22340084e-03 0.00000000e+00 1.20540019e-04
 3.88540662e-02 4.90035358e-01 9.64320154e-04 1.20540019e-04
 3.21440051e-04 7.45740919e-02 3.73955320e-01 0.00000000e+00
 0.00000000e+00 0.00000000e+00 6.54934105e-03]
std [1.67965557 0.69961961 0.22689154 0.03103853 0.10731462 0.16267121
 0.12004839 0.02761951 0.         0.         0.         0.
 0.         0.00633864 0.         0.09168481 0.         0.0720841
 0.         0.01097841 0.19324706 0.4999007  0.031038


[A
[A
[A
100%|██████████| 744/744 [00:00<00:00, 2009.16it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'Ca', 'Na', 'Al', 'V', 'Ti', 'B', 'Cr', 'Cl', 'Cu', 'F', 'Fe', 'As', 'N', 'H', 'P', 'Se', 'Si', 'O', 'C', 'Ni', 'Zn', 'Ge', 'S'}
selected atomic keys ['extra_feat_atom_Hamiltonian_K', 'extra_feat_atom_e_density', 'extra_feat_atom_lap_e_density', 'extra_feat_atom_esp_total', 'extra_feat_atom_det_hessian', 'extra_feat_atom_eta', 'extra_feat_atom_energy_density']
selected bond keys ['extra_feat_bond_Lagrangian_K', 'bond_length', 'extra_feat_bond_Hamiltonian_K', 'extra_feat_bond_e_density', 'extra_feat_bond_lap_e_density', 'extra_feat_bond_e_loc_func', 'extra_feat_bond_esp_e', 'extra_feat_bond_esp_total', 'extra_feat_bond_grad_norm', 'extra_feat_bond_lap_norm', 'extra_feat_bond_ellip_e_dens', 'extra_feat_bond_eta']
selected global keys ['NR-AR-LBD']
... > Building graphs and featurizing



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 744/744 [00:02<00:00, 305.05it/s]


included in labels
{'global': ['NR-AR-LBD']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_Ca', 'chemical_symbol_Na', 'chemical_symbol_Al', 'chemical_symbol_V', 'chemical_symbol_Ti', 'chemical_symbol_B', 'chemical_symbol_Cr', 'chemical_symbol_Cl', 'chemical_symbol_Cu', 'chemical_symbol_F', 'chemical_symbol_Fe', 'chemical_symbol_As', 'chemical_symbol_N', 'chemical_symbol_H', 'chemical_symbol_P', 'chemical_symbol_Se', 'chemical_symbol_Si', 'chemical_symbol_O', 'chemical_symbol_C', 'chemical_symbol_Ni', 'chemical_symbol_Zn', 'chemical_symbol_Ge', 'chemical_symbol_S', 'extra_feat_atom_Hamiltonian_K', 'extra_feat_atom_e_density', 'extra_feat_atom_lap_e_density', 'extra_feat_atom_esp_total', 'extra_feat_atom_det_hessian', 'extra_feat_atom_eta', 'extra_feat_atom_energy_density'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5',


744it [00:00, 31432.25it/s]


... > number of categories for each label: 
...... > label  NR-AR-LBD :  2 with distribution:  [631, 17]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  96
... > Scaling features
mean [ 1.64485736e+00  3.94627284e-01  5.41215654e-02  9.20285727e-04
  1.07804899e-02  2.85726807e-02  1.38481090e-02  7.01170078e-04
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  4.38231298e-05  0.00000000e+00  7.75669398e-03
  0.00000000e+00  4.82054428e-03  0.00000000e+00  1.31469390e-04
  3.91778781e-02  4.92133748e-01  1.05175512e-03  1.31469390e-04
  3.50585039e-04  7.43678514e-02  3.72628073e-01  0.00000000e+00
  0.00000000e+00  0.00000000e+00  6.70493887e-03  1.07658723e+06
 -1.07658723e+06 -4.30261260e+06  4.97986751e+08 -1.53266112e+24
 -1.08662543e+00 -1.07658723e+06]
std [1.67330334e+00 6.94975844e-01 2.26257423e-01 3.03222493e-02
 1.03267957e-01 1.66602169e-01 1.16


[A
100%|██████████| 744/744 [00:00<00:00, 6191.70it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'Ca', 'Na', 'Al', 'V', 'Ti', 'B', 'Cr', 'Cl', 'Cu', 'F', 'Fe', 'As', 'N', 'H', 'P', 'Se', 'Si', 'O', 'C', 'Ni', 'Zn', 'Ge', 'S'}
selected atomic keys []
selected bond keys ['bond_length']
selected global keys ['NR-AR-LBD']
... > Building graphs and featurizing



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 744/744 [00:02<00:00, 308.00it/s]


included in labels
{'global': ['NR-AR-LBD']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_Ca', 'chemical_symbol_Na', 'chemical_symbol_Al', 'chemical_symbol_V', 'chemical_symbol_Ti', 'chemical_symbol_B', 'chemical_symbol_Cr', 'chemical_symbol_Cl', 'chemical_symbol_Cu', 'chemical_symbol_F', 'chemical_symbol_Fe', 'chemical_symbol_As', 'chemical_symbol_N', 'chemical_symbol_H', 'chemical_symbol_P', 'chemical_symbol_Se', 'chemical_symbol_Si', 'chemical_symbol_O', 'chemical_symbol_C', 'chemical_symbol_Ni', 'chemical_symbol_Zn', 'chemical_symbol_Ge', 'chemical_symbol_S'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:


744it [00:00, 31383.25it/s]


... > number of categories for each label: 
...... > label  NR-AR-LBD :  2 with distribution:  [631, 17]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  96
... > Scaling features
mean [1.64485736e+00 3.94627284e-01 5.41215654e-02 9.20285727e-04
 1.07804899e-02 2.85726807e-02 1.38481090e-02 7.01170078e-04
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 4.38231298e-05 0.00000000e+00 7.75669398e-03
 0.00000000e+00 4.82054428e-03 0.00000000e+00 1.31469390e-04
 3.91778781e-02 4.92133748e-01 1.05175512e-03 1.31469390e-04
 3.50585039e-04 7.43678514e-02 3.72628073e-01 0.00000000e+00
 0.00000000e+00 0.00000000e+00 6.70493887e-03]
std [1.67330334 0.69497584 0.22625742 0.03032225 0.10326796 0.16660217
 0.11686034 0.02647033 0.         0.         0.         0.
 0.         0.00661976 0.         0.08772986 0.         0.06926259
 0.         0.01146526 0.19401797 0.49993812 0.0


[A
[A
[A
100%|██████████| 744/744 [00:00<00:00, 2058.52it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'Ca', 'Na', 'Al', 'V', 'Ti', 'B', 'Cr', 'Cl', 'Cu', 'F', 'Fe', 'As', 'N', 'H', 'P', 'Se', 'Si', 'O', 'C', 'Ni', 'Zn', 'Ge', 'S'}
selected atomic keys ['extra_feat_atom_Hamiltonian_K', 'extra_feat_atom_e_density', 'extra_feat_atom_lap_e_density', 'extra_feat_atom_esp_total', 'extra_feat_atom_det_hessian', 'extra_feat_atom_eta', 'extra_feat_atom_energy_density']
selected bond keys ['extra_feat_bond_Lagrangian_K', 'bond_length', 'extra_feat_bond_Hamiltonian_K', 'extra_feat_bond_e_density', 'extra_feat_bond_lap_e_density', 'extra_feat_bond_e_loc_func', 'extra_feat_bond_esp_e', 'extra_feat_bond_esp_total', 'extra_feat_bond_grad_norm', 'extra_feat_bond_lap_norm', 'extra_feat_bond_ellip_e_dens', 'extra_feat_bond_eta']
selected global keys ['NR-AhR']
... > Building graphs and featurizing



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 744/744 [00:02<00:00, 304.66it/s]


included in labels
{'global': ['NR-AhR']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_Ca', 'chemical_symbol_Na', 'chemical_symbol_Al', 'chemical_symbol_V', 'chemical_symbol_Ti', 'chemical_symbol_B', 'chemical_symbol_Cr', 'chemical_symbol_Cl', 'chemical_symbol_Cu', 'chemical_symbol_F', 'chemical_symbol_Fe', 'chemical_symbol_As', 'chemical_symbol_N', 'chemical_symbol_H', 'chemical_symbol_P', 'chemical_symbol_Se', 'chemical_symbol_Si', 'chemical_symbol_O', 'chemical_symbol_C', 'chemical_symbol_Ni', 'chemical_symbol_Zn', 'chemical_symbol_Ge', 'chemical_symbol_S', 'extra_feat_atom_Hamiltonian_K', 'extra_feat_atom_e_density', 'extra_feat_atom_lap_e_density', 'extra_feat_atom_esp_total', 'extra_feat_atom_det_hessian', 'extra_feat_atom_eta', 'extra_feat_atom_energy_density'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'r


744it [00:00, 31265.97it/s]


... > number of categories for each label: 
...... > label  NR-AhR :  2 with distribution:  [550, 72]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  122
... > Scaling features
mean [ 1.64841118e+00  3.99890740e-01  5.49030320e-02  9.56022945e-04
  1.19275244e-02  2.73604662e-02  1.46590185e-02  7.73923336e-04
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  4.55249021e-05  0.00000000e+00  8.55868160e-03
  0.00000000e+00  4.37039060e-03  0.00000000e+00  1.36574706e-04
  3.82409178e-02  4.93353364e-01  9.56022945e-04  9.10498042e-05
  3.18674315e-04  7.46608395e-02  3.71756351e-01  0.00000000e+00
  0.00000000e+00  0.00000000e+00  6.73768551e-03  1.12513558e+06
 -1.12513558e+06 -4.49662513e+06  5.17256923e+08 -1.62203530e+24
 -1.08717504e+00 -1.12513558e+06]
std [1.67643065e+00 7.01884138e-01 2.27790889e-01 3.09048372e-02
 1.08559931e-01 1.63131453e-01 1.2018


[A
100%|██████████| 744/744 [00:00<00:00, 6360.26it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'Ca', 'Na', 'Al', 'V', 'Ti', 'B', 'Cr', 'Cl', 'Cu', 'F', 'Fe', 'As', 'N', 'H', 'P', 'Se', 'Si', 'O', 'C', 'Ni', 'Zn', 'Ge', 'S'}
selected atomic keys []
selected bond keys ['bond_length']
selected global keys ['NR-AhR']
... > Building graphs and featurizing



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 744/744 [00:02<00:00, 313.12it/s]


included in labels
{'global': ['NR-AhR']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_Ca', 'chemical_symbol_Na', 'chemical_symbol_Al', 'chemical_symbol_V', 'chemical_symbol_Ti', 'chemical_symbol_B', 'chemical_symbol_Cr', 'chemical_symbol_Cl', 'chemical_symbol_Cu', 'chemical_symbol_F', 'chemical_symbol_Fe', 'chemical_symbol_As', 'chemical_symbol_N', 'chemical_symbol_H', 'chemical_symbol_P', 'chemical_symbol_Se', 'chemical_symbol_Si', 'chemical_symbol_O', 'chemical_symbol_C', 'chemical_symbol_Ni', 'chemical_symbol_Zn', 'chemical_symbol_Ge', 'chemical_symbol_S'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  d


744it [00:00, 30805.46it/s]


... > number of categories for each label: 
...... > label  NR-AhR :  2 with distribution:  [550, 72]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  122
... > Scaling features
mean [1.64841118e+00 3.99890740e-01 5.49030320e-02 9.56022945e-04
 1.19275244e-02 2.73604662e-02 1.46590185e-02 7.73923336e-04
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 4.55249021e-05 0.00000000e+00 8.55868160e-03
 0.00000000e+00 4.37039060e-03 0.00000000e+00 1.36574706e-04
 3.82409178e-02 4.93353364e-01 9.56022945e-04 9.10498042e-05
 3.18674315e-04 7.46608395e-02 3.71756351e-01 0.00000000e+00
 0.00000000e+00 0.00000000e+00 6.73768551e-03]
std [1.67643065 0.70188414 0.22779089 0.03090484 0.10855993 0.16313145
 0.12018374 0.02780871 0.         0.         0.         0.
 0.         0.00674706 0.         0.0921164  0.         0.06596431
 0.         0.01168572 0.19177734 0.49995582 0.030


[A
[A
[A
100%|██████████| 744/744 [00:00<00:00, 2054.21it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'Ca', 'Na', 'Al', 'V', 'Ti', 'B', 'Cr', 'Cl', 'Cu', 'F', 'Fe', 'As', 'N', 'H', 'P', 'Se', 'Si', 'O', 'C', 'Ni', 'Zn', 'Ge', 'S'}
selected atomic keys ['extra_feat_atom_Hamiltonian_K', 'extra_feat_atom_e_density', 'extra_feat_atom_lap_e_density', 'extra_feat_atom_esp_total', 'extra_feat_atom_det_hessian', 'extra_feat_atom_eta', 'extra_feat_atom_energy_density']
selected bond keys ['extra_feat_bond_Lagrangian_K', 'bond_length', 'extra_feat_bond_Hamiltonian_K', 'extra_feat_bond_e_density', 'extra_feat_bond_lap_e_density', 'extra_feat_bond_e_loc_func', 'extra_feat_bond_esp_e', 'extra_feat_bond_esp_total', 'extra_feat_bond_grad_norm', 'extra_feat_bond_lap_norm', 'extra_feat_bond_ellip_e_dens', 'extra_feat_bond_eta']
selected global keys ['NR-Aromatase']
... > Building graphs and featurizing



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 744/744 [00:02<00:00, 306.42it/s]


included in labels
{'global': ['NR-Aromatase']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_Ca', 'chemical_symbol_Na', 'chemical_symbol_Al', 'chemical_symbol_V', 'chemical_symbol_Ti', 'chemical_symbol_B', 'chemical_symbol_Cr', 'chemical_symbol_Cl', 'chemical_symbol_Cu', 'chemical_symbol_F', 'chemical_symbol_Fe', 'chemical_symbol_As', 'chemical_symbol_N', 'chemical_symbol_H', 'chemical_symbol_P', 'chemical_symbol_Se', 'chemical_symbol_Si', 'chemical_symbol_O', 'chemical_symbol_C', 'chemical_symbol_Ni', 'chemical_symbol_Zn', 'chemical_symbol_Ge', 'chemical_symbol_S', 'extra_feat_atom_Hamiltonian_K', 'extra_feat_atom_e_density', 'extra_feat_atom_lap_e_density', 'extra_feat_atom_esp_total', 'extra_feat_atom_det_hessian', 'extra_feat_atom_eta', 'extra_feat_atom_energy_density'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_


744it [00:00, 31221.86it/s]


... > number of categories for each label: 
...... > label  NR-Aromatase :  2 with distribution:  [554, 28]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  162
... > Scaling features
mean [ 1.64148170e+00  3.96981724e-01  5.66416777e-02  8.81963839e-04
  1.10735460e-02  3.00847665e-02  1.46014013e-02  6.85971875e-04
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  4.89979911e-05  0.00000000e+00  8.08466853e-03
  0.00000000e+00  4.21382723e-03  0.00000000e+00  1.46993973e-04
  3.91003969e-02  4.94732716e-01  9.79959822e-04  4.89979911e-05
  2.93987946e-04  7.60448822e-02  3.69640845e-01  0.00000000e+00
  0.00000000e+00  0.00000000e+00  5.97775491e-03  1.00422848e+06
 -1.00422848e+06 -4.01343817e+06  5.56460017e+08 -1.41976815e+24
 -1.08752957e+00 -1.00422848e+06]
std [1.67561011e+00 6.99343261e-01 2.31156653e-01 2.96847769e-02
 1.04646656e-01 1.70820588e-01 


[A
100%|██████████| 744/744 [00:00<00:00, 6289.26it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'Ca', 'Na', 'Al', 'V', 'Ti', 'B', 'Cr', 'Cl', 'Cu', 'F', 'Fe', 'As', 'N', 'H', 'P', 'Se', 'Si', 'O', 'C', 'Ni', 'Zn', 'Ge', 'S'}
selected atomic keys []
selected bond keys ['bond_length']
selected global keys ['NR-Aromatase']
... > Building graphs and featurizing



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 744/744 [00:02<00:00, 309.53it/s]


included in labels
{'global': ['NR-Aromatase']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_Ca', 'chemical_symbol_Na', 'chemical_symbol_Al', 'chemical_symbol_V', 'chemical_symbol_Ti', 'chemical_symbol_B', 'chemical_symbol_Cr', 'chemical_symbol_Cl', 'chemical_symbol_Cu', 'chemical_symbol_F', 'chemical_symbol_Fe', 'chemical_symbol_As', 'chemical_symbol_N', 'chemical_symbol_H', 'chemical_symbol_P', 'chemical_symbol_Se', 'chemical_symbol_Si', 'chemical_symbol_O', 'chemical_symbol_C', 'chemical_symbol_Ni', 'chemical_symbol_Zn', 'chemical_symbol_Ge', 'chemical_symbol_S'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include nam


744it [00:00, 28645.04it/s]


... > number of categories for each label: 
...... > label  NR-Aromatase :  2 with distribution:  [554, 28]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  162
... > Scaling features
mean [1.64148170e+00 3.96981724e-01 5.66416777e-02 8.81963839e-04
 1.10735460e-02 3.00847665e-02 1.46014013e-02 6.85971875e-04
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 4.89979911e-05 0.00000000e+00 8.08466853e-03
 0.00000000e+00 4.21382723e-03 0.00000000e+00 1.46993973e-04
 3.91003969e-02 4.94732716e-01 9.79959822e-04 4.89979911e-05
 2.93987946e-04 7.60448822e-02 3.69640845e-01 0.00000000e+00
 0.00000000e+00 0.00000000e+00 5.97775491e-03]
std [1.67561011 0.69934326 0.23115665 0.02968478 0.10464666 0.17082059
 0.11995083 0.02618208 0.         0.         0.         0.
 0.         0.00699969 0.         0.08955058 0.         0.06477709
 0.         0.01212322 0.19383384 0.49997225


[A
[A
[A
100%|██████████| 744/744 [00:00<00:00, 2005.28it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'Ca', 'Na', 'Al', 'V', 'Ti', 'B', 'Cr', 'Cl', 'Cu', 'F', 'Fe', 'As', 'N', 'H', 'P', 'Se', 'Si', 'O', 'C', 'Ni', 'Zn', 'Ge', 'S'}
selected atomic keys ['extra_feat_atom_Hamiltonian_K', 'extra_feat_atom_e_density', 'extra_feat_atom_lap_e_density', 'extra_feat_atom_esp_total', 'extra_feat_atom_det_hessian', 'extra_feat_atom_eta', 'extra_feat_atom_energy_density']
selected bond keys ['extra_feat_bond_Lagrangian_K', 'bond_length', 'extra_feat_bond_Hamiltonian_K', 'extra_feat_bond_e_density', 'extra_feat_bond_lap_e_density', 'extra_feat_bond_e_loc_func', 'extra_feat_bond_esp_e', 'extra_feat_bond_esp_total', 'extra_feat_bond_grad_norm', 'extra_feat_bond_lap_norm', 'extra_feat_bond_ellip_e_dens', 'extra_feat_bond_eta']
selected global keys ['NR-ER']
... > Building graphs and featurizing



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 744/744 [00:02<00:00, 304.77it/s]


included in labels
{'global': ['NR-ER']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_Ca', 'chemical_symbol_Na', 'chemical_symbol_Al', 'chemical_symbol_V', 'chemical_symbol_Ti', 'chemical_symbol_B', 'chemical_symbol_Cr', 'chemical_symbol_Cl', 'chemical_symbol_Cu', 'chemical_symbol_F', 'chemical_symbol_Fe', 'chemical_symbol_As', 'chemical_symbol_N', 'chemical_symbol_H', 'chemical_symbol_P', 'chemical_symbol_Se', 'chemical_symbol_Si', 'chemical_symbol_O', 'chemical_symbol_C', 'chemical_symbol_Ni', 'chemical_symbol_Zn', 'chemical_symbol_Ge', 'chemical_symbol_S', 'extra_feat_atom_Hamiltonian_K', 'extra_feat_atom_e_density', 'extra_feat_atom_lap_e_density', 'extra_feat_atom_esp_total', 'extra_feat_atom_det_hessian', 'extra_feat_atom_eta', 'extra_feat_atom_energy_density'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ri


744it [00:00, 31520.83it/s]


... > number of categories for each label: 
...... > label  NR-ER :  2 with distribution:  [541, 64]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  139
... > Scaling features
mean [ 1.63431677e+00  3.90867624e-01  5.45419255e-02  1.01902174e-03
  1.09666149e-02  2.89693323e-02  1.35869565e-02  7.76397516e-04
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  4.85248447e-05  0.00000000e+00  8.20069876e-03
  0.00000000e+00  4.99805901e-03  0.00000000e+00  1.45574534e-04
  3.82861025e-02  4.92333075e-01  9.21972050e-04  9.70496894e-05
  3.39673913e-04  7.50194099e-02  3.72088509e-01  0.00000000e+00
  0.00000000e+00  0.00000000e+00  6.74495342e-03  1.13368263e+06
 -1.13368263e+06 -4.53078384e+06  5.51160538e+08 -1.64024502e+24
 -1.08644824e+00 -1.13368263e+06]
std [1.67195380e+00 6.94105104e-01 2.27083914e-01 3.19058511e-02
 1.04145803e-01 1.67720333e-01 1.15768


[A
100%|██████████| 744/744 [00:00<00:00, 6310.22it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'Ca', 'Na', 'Al', 'V', 'Ti', 'B', 'Cr', 'Cl', 'Cu', 'F', 'Fe', 'As', 'N', 'H', 'P', 'Se', 'Si', 'O', 'C', 'Ni', 'Zn', 'Ge', 'S'}
selected atomic keys []
selected bond keys ['bond_length']
selected global keys ['NR-ER']
... > Building graphs and featurizing



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 744/744 [00:02<00:00, 312.91it/s]


included in labels
{'global': ['NR-ER']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_Ca', 'chemical_symbol_Na', 'chemical_symbol_Al', 'chemical_symbol_V', 'chemical_symbol_Ti', 'chemical_symbol_B', 'chemical_symbol_Cr', 'chemical_symbol_Cl', 'chemical_symbol_Cu', 'chemical_symbol_F', 'chemical_symbol_Fe', 'chemical_symbol_As', 'chemical_symbol_N', 'chemical_symbol_H', 'chemical_symbol_P', 'chemical_symbol_Se', 'chemical_symbol_Si', 'chemical_symbol_O', 'chemical_symbol_C', 'chemical_symbol_Ni', 'chemical_symbol_Zn', 'chemical_symbol_Ge', 'chemical_symbol_S'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  di


744it [00:00, 31810.34it/s]


... > number of categories for each label: 
...... > label  NR-ER :  2 with distribution:  [541, 64]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  139
... > Scaling features
mean [1.63431677e+00 3.90867624e-01 5.45419255e-02 1.01902174e-03
 1.09666149e-02 2.89693323e-02 1.35869565e-02 7.76397516e-04
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 4.85248447e-05 0.00000000e+00 8.20069876e-03
 0.00000000e+00 4.99805901e-03 0.00000000e+00 1.45574534e-04
 3.82861025e-02 4.92333075e-01 9.21972050e-04 9.70496894e-05
 3.39673913e-04 7.50194099e-02 3.72088509e-01 0.00000000e+00
 0.00000000e+00 0.00000000e+00 6.74495342e-03]
std [1.6719538  0.6941051  0.22708391 0.03190585 0.1041458  0.16772033
 0.11576852 0.02785309 0.         0.         0.         0.
 0.         0.00696581 0.         0.09018563 0.         0.07052006
 0.         0.01206455 0.1918861  0.49994121 0.0303


[A
[A
[A
100%|██████████| 744/744 [00:00<00:00, 2030.64it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'Ca', 'Na', 'Al', 'V', 'Ti', 'B', 'Cr', 'Cl', 'Cu', 'F', 'Fe', 'As', 'N', 'H', 'P', 'Se', 'Si', 'O', 'C', 'Ni', 'Zn', 'Ge', 'S'}
selected atomic keys ['extra_feat_atom_Hamiltonian_K', 'extra_feat_atom_e_density', 'extra_feat_atom_lap_e_density', 'extra_feat_atom_esp_total', 'extra_feat_atom_det_hessian', 'extra_feat_atom_eta', 'extra_feat_atom_energy_density']
selected bond keys ['extra_feat_bond_Lagrangian_K', 'bond_length', 'extra_feat_bond_Hamiltonian_K', 'extra_feat_bond_e_density', 'extra_feat_bond_lap_e_density', 'extra_feat_bond_e_loc_func', 'extra_feat_bond_esp_e', 'extra_feat_bond_esp_total', 'extra_feat_bond_grad_norm', 'extra_feat_bond_lap_norm', 'extra_feat_bond_ellip_e_dens', 'extra_feat_bond_eta']
selected global keys ['NR-ER-LBD']
... > Building graphs and featurizing



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 744/744 [00:02<00:00, 304.30it/s]


included in labels
{'global': ['NR-ER-LBD']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_Ca', 'chemical_symbol_Na', 'chemical_symbol_Al', 'chemical_symbol_V', 'chemical_symbol_Ti', 'chemical_symbol_B', 'chemical_symbol_Cr', 'chemical_symbol_Cl', 'chemical_symbol_Cu', 'chemical_symbol_F', 'chemical_symbol_Fe', 'chemical_symbol_As', 'chemical_symbol_N', 'chemical_symbol_H', 'chemical_symbol_P', 'chemical_symbol_Se', 'chemical_symbol_Si', 'chemical_symbol_O', 'chemical_symbol_C', 'chemical_symbol_Ni', 'chemical_symbol_Zn', 'chemical_symbol_Ge', 'chemical_symbol_S', 'extra_feat_atom_Hamiltonian_K', 'extra_feat_atom_e_density', 'extra_feat_atom_lap_e_density', 'extra_feat_atom_esp_total', 'extra_feat_atom_det_hessian', 'extra_feat_atom_eta', 'extra_feat_atom_energy_density'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5',


744it [00:00, 31589.75it/s]


... > number of categories for each label: 
...... > label  NR-ER-LBD :  2 with distribution:  [651, 29]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  64
... > Scaling features
mean [ 1.65199701e+00  3.96952587e-01  5.35580835e-02  8.71875778e-04
  1.12098314e-02  2.78585070e-02  1.36178693e-02  8.30357884e-04
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  4.15178942e-05  0.00000000e+00  8.46965042e-03
  0.00000000e+00  4.52545047e-03  0.00000000e+00  1.24553683e-04
  3.89853027e-02  4.90783027e-01  8.30357884e-04  1.24553683e-04
  3.32143154e-04  7.44415843e-02  3.73910155e-01  0.00000000e+00
  0.00000000e+00  0.00000000e+00  6.60134518e-03  1.19207188e+06
 -1.19207188e+06 -4.76412458e+06  4.71954610e+08 -1.75551312e+24
 -1.08679175e+00 -1.19207188e+06]
std [1.67689209e+00 6.98265854e-01 2.25143544e-01 2.95146677e-02
 1.05281390e-01 1.64567344e-01 1.15


[A
100%|██████████| 744/744 [00:00<00:00, 6453.30it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'Ca', 'Na', 'Al', 'V', 'Ti', 'B', 'Cr', 'Cl', 'Cu', 'F', 'Fe', 'As', 'N', 'H', 'P', 'Se', 'Si', 'O', 'C', 'Ni', 'Zn', 'Ge', 'S'}
selected atomic keys []
selected bond keys ['bond_length']
selected global keys ['NR-ER-LBD']
... > Building graphs and featurizing



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 744/744 [00:02<00:00, 312.89it/s]


included in labels
{'global': ['NR-ER-LBD']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_Ca', 'chemical_symbol_Na', 'chemical_symbol_Al', 'chemical_symbol_V', 'chemical_symbol_Ti', 'chemical_symbol_B', 'chemical_symbol_Cr', 'chemical_symbol_Cl', 'chemical_symbol_Cu', 'chemical_symbol_F', 'chemical_symbol_Fe', 'chemical_symbol_As', 'chemical_symbol_N', 'chemical_symbol_H', 'chemical_symbol_P', 'chemical_symbol_Se', 'chemical_symbol_Si', 'chemical_symbol_O', 'chemical_symbol_C', 'chemical_symbol_Ni', 'chemical_symbol_Zn', 'chemical_symbol_Ge', 'chemical_symbol_S'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:


744it [00:00, 31170.71it/s]


... > number of categories for each label: 
...... > label  NR-ER-LBD :  2 with distribution:  [651, 29]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  64
... > Scaling features
mean [1.65199701e+00 3.96952587e-01 5.35580835e-02 8.71875778e-04
 1.12098314e-02 2.78585070e-02 1.36178693e-02 8.30357884e-04
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 4.15178942e-05 0.00000000e+00 8.46965042e-03
 0.00000000e+00 4.52545047e-03 0.00000000e+00 1.24553683e-04
 3.89853027e-02 4.90783027e-01 8.30357884e-04 1.24553683e-04
 3.32143154e-04 7.44415843e-02 3.73910155e-01 0.00000000e+00
 0.00000000e+00 0.00000000e+00 6.60134518e-03]
std [1.67689209 0.69826585 0.22514354 0.02951467 0.10528139 0.16456734
 0.11589833 0.02880396 0.         0.         0.         0.
 0.         0.0064433  0.         0.09164014 0.         0.06711908
 0.         0.01115967 0.19355994 0.49991504 0.0


[A
[A
[A
100%|██████████| 744/744 [00:00<00:00, 2068.27it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'Ca', 'Na', 'Al', 'V', 'Ti', 'B', 'Cr', 'Cl', 'Cu', 'F', 'Fe', 'As', 'N', 'H', 'P', 'Se', 'Si', 'O', 'C', 'Ni', 'Zn', 'Ge', 'S'}
selected atomic keys ['extra_feat_atom_Hamiltonian_K', 'extra_feat_atom_e_density', 'extra_feat_atom_lap_e_density', 'extra_feat_atom_esp_total', 'extra_feat_atom_det_hessian', 'extra_feat_atom_eta', 'extra_feat_atom_energy_density']
selected bond keys ['extra_feat_bond_Lagrangian_K', 'bond_length', 'extra_feat_bond_Hamiltonian_K', 'extra_feat_bond_e_density', 'extra_feat_bond_lap_e_density', 'extra_feat_bond_e_loc_func', 'extra_feat_bond_esp_e', 'extra_feat_bond_esp_total', 'extra_feat_bond_grad_norm', 'extra_feat_bond_lap_norm', 'extra_feat_bond_ellip_e_dens', 'extra_feat_bond_eta']
selected global keys ['NR-PPAR-gamma']
... > Building graphs and featurizing



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 744/744 [00:02<00:00, 313.93it/s]


included in labels
{'global': ['NR-PPAR-gamma']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_Ca', 'chemical_symbol_Na', 'chemical_symbol_Al', 'chemical_symbol_V', 'chemical_symbol_Ti', 'chemical_symbol_B', 'chemical_symbol_Cr', 'chemical_symbol_Cl', 'chemical_symbol_Cu', 'chemical_symbol_F', 'chemical_symbol_Fe', 'chemical_symbol_As', 'chemical_symbol_N', 'chemical_symbol_H', 'chemical_symbol_P', 'chemical_symbol_Se', 'chemical_symbol_Si', 'chemical_symbol_O', 'chemical_symbol_C', 'chemical_symbol_Ni', 'chemical_symbol_Zn', 'chemical_symbol_Ge', 'chemical_symbol_S', 'extra_feat_atom_Hamiltonian_K', 'extra_feat_atom_e_density', 'extra_feat_atom_lap_e_density', 'extra_feat_atom_esp_total', 'extra_feat_atom_det_hessian', 'extra_feat_atom_eta', 'extra_feat_atom_energy_density'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size


744it [00:00, 32394.84it/s]


... > number of categories for each label: 
...... > label  NR-PPAR-gamma :  2 with distribution:  [605, 15]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  124
... > Scaling features
mean [ 1.63701068e+00  3.91553974e-01  5.57532622e-02  8.54092527e-04
  1.12930012e-02  2.97508897e-02  1.38552788e-02  9.01542112e-04
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  4.74495848e-05  0.00000000e+00  8.25622776e-03
  0.00000000e+00  4.46026097e-03  0.00000000e+00  1.42348754e-04
  3.90510083e-02  4.91814947e-01  1.09134045e-03  9.48991696e-05
  3.79596679e-04  7.63463820e-02  3.70676157e-01  0.00000000e+00
  0.00000000e+00  0.00000000e+00  6.73784104e-03  1.24689782e+06
 -1.24689782e+06 -4.98322905e+06  5.39119377e+08 -1.86401673e+24
 -1.08702145e+00 -1.24689782e+06]
std [1.67081048e+00 6.92577033e-01 2.29444625e-01 2.92123784e-02
 1.05666784e-01 1.69899306e-01


[A
100%|██████████| 744/744 [00:00<00:00, 6477.60it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'Ca', 'Na', 'Al', 'V', 'Ti', 'B', 'Cr', 'Cl', 'Cu', 'F', 'Fe', 'As', 'N', 'H', 'P', 'Se', 'Si', 'O', 'C', 'Ni', 'Zn', 'Ge', 'S'}
selected atomic keys []
selected bond keys ['bond_length']
selected global keys ['NR-PPAR-gamma']
... > Building graphs and featurizing



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 744/744 [00:02<00:00, 322.14it/s]


included in labels
{'global': ['NR-PPAR-gamma']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_Ca', 'chemical_symbol_Na', 'chemical_symbol_Al', 'chemical_symbol_V', 'chemical_symbol_Ti', 'chemical_symbol_B', 'chemical_symbol_Cr', 'chemical_symbol_Cl', 'chemical_symbol_Cu', 'chemical_symbol_F', 'chemical_symbol_Fe', 'chemical_symbol_As', 'chemical_symbol_N', 'chemical_symbol_H', 'chemical_symbol_P', 'chemical_symbol_Se', 'chemical_symbol_Si', 'chemical_symbol_O', 'chemical_symbol_C', 'chemical_symbol_Ni', 'chemical_symbol_Zn', 'chemical_symbol_Ge', 'chemical_symbol_S'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include na


744it [00:00, 31863.61it/s]


... > number of categories for each label: 
...... > label  NR-PPAR-gamma :  2 with distribution:  [605, 15]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  124
... > Scaling features
mean [1.63701068e+00 3.91553974e-01 5.57532622e-02 8.54092527e-04
 1.12930012e-02 2.97508897e-02 1.38552788e-02 9.01542112e-04
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 4.74495848e-05 0.00000000e+00 8.25622776e-03
 0.00000000e+00 4.46026097e-03 0.00000000e+00 1.42348754e-04
 3.90510083e-02 4.91814947e-01 1.09134045e-03 9.48991696e-05
 3.79596679e-04 7.63463820e-02 3.70676157e-01 0.00000000e+00
 0.00000000e+00 0.00000000e+00 6.73784104e-03]
std [1.67081048 0.69257703 0.22944462 0.02921238 0.10566678 0.16989931
 0.11689016 0.03001215 0.         0.         0.         0.
 0.         0.0068882  0.         0.09048791 0.         0.06663608
 0.         0.01193015 0.19371636 0.499933 


[A
[A
[A
100%|██████████| 744/744 [00:00<00:00, 2062.43it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'Ca', 'Na', 'Al', 'V', 'Ti', 'B', 'Cr', 'Cl', 'Cu', 'F', 'Fe', 'As', 'N', 'H', 'P', 'Se', 'Si', 'O', 'C', 'Ni', 'Zn', 'Ge', 'S'}
selected atomic keys ['extra_feat_atom_Hamiltonian_K', 'extra_feat_atom_e_density', 'extra_feat_atom_lap_e_density', 'extra_feat_atom_esp_total', 'extra_feat_atom_det_hessian', 'extra_feat_atom_eta', 'extra_feat_atom_energy_density']
selected bond keys ['extra_feat_bond_Lagrangian_K', 'bond_length', 'extra_feat_bond_Hamiltonian_K', 'extra_feat_bond_e_density', 'extra_feat_bond_lap_e_density', 'extra_feat_bond_e_loc_func', 'extra_feat_bond_esp_e', 'extra_feat_bond_esp_total', 'extra_feat_bond_grad_norm', 'extra_feat_bond_lap_norm', 'extra_feat_bond_ellip_e_dens', 'extra_feat_bond_eta']
selected global keys ['SR-ARE']
... > Building graphs and featurizing



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 744/744 [00:02<00:00, 310.20it/s]


included in labels
{'global': ['SR-ARE']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_Ca', 'chemical_symbol_Na', 'chemical_symbol_Al', 'chemical_symbol_V', 'chemical_symbol_Ti', 'chemical_symbol_B', 'chemical_symbol_Cr', 'chemical_symbol_Cl', 'chemical_symbol_Cu', 'chemical_symbol_F', 'chemical_symbol_Fe', 'chemical_symbol_As', 'chemical_symbol_N', 'chemical_symbol_H', 'chemical_symbol_P', 'chemical_symbol_Se', 'chemical_symbol_Si', 'chemical_symbol_O', 'chemical_symbol_C', 'chemical_symbol_Ni', 'chemical_symbol_Zn', 'chemical_symbol_Ge', 'chemical_symbol_S', 'extra_feat_atom_Hamiltonian_K', 'extra_feat_atom_e_density', 'extra_feat_atom_lap_e_density', 'extra_feat_atom_esp_total', 'extra_feat_atom_det_hessian', 'extra_feat_atom_eta', 'extra_feat_atom_energy_density'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'r


744it [00:00, 30591.35it/s]


... > number of categories for each label: 
...... > label  SR-ARE :  2 with distribution:  [454, 83]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  207
... > Scaling features
mean [ 1.62192982e+00  3.80818713e-01  5.57894737e-02  1.05263158e-03
  1.16959064e-02  2.88304094e-02  1.42105263e-02  9.35672515e-04
  0.00000000e+00  0.00000000e+00  5.84795322e-05  0.00000000e+00
  0.00000000e+00  5.84795322e-05  0.00000000e+00  1.05263158e-02
  0.00000000e+00  4.32748538e-03  0.00000000e+00  1.16959064e-04
  3.97076023e-02  4.88654971e-01  9.94152047e-04  1.75438596e-04
  4.09356725e-04  7.58479532e-02  3.71520468e-01  0.00000000e+00
  0.00000000e+00  0.00000000e+00  6.66666667e-03  1.35519275e+06
 -1.35519275e+06 -5.41601424e+06  6.63887763e+08 -2.00194796e+24
 -1.08638986e+00 -1.35519275e+06]
std [1.65339941e+00 6.82106667e-01 2.29514724e-01 3.24272038e-02
 1.07513312e-01 1.67329665e-01 1.1835


[A
100%|██████████| 744/744 [00:00<00:00, 6101.70it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'Ca', 'Na', 'Al', 'V', 'Ti', 'B', 'Cr', 'Cl', 'Cu', 'F', 'Fe', 'As', 'N', 'H', 'P', 'Se', 'Si', 'O', 'C', 'Ni', 'Zn', 'Ge', 'S'}
selected atomic keys []
selected bond keys ['bond_length']
selected global keys ['SR-ARE']
... > Building graphs and featurizing



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 744/744 [00:02<00:00, 319.45it/s]


included in labels
{'global': ['SR-ARE']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_Ca', 'chemical_symbol_Na', 'chemical_symbol_Al', 'chemical_symbol_V', 'chemical_symbol_Ti', 'chemical_symbol_B', 'chemical_symbol_Cr', 'chemical_symbol_Cl', 'chemical_symbol_Cu', 'chemical_symbol_F', 'chemical_symbol_Fe', 'chemical_symbol_As', 'chemical_symbol_N', 'chemical_symbol_H', 'chemical_symbol_P', 'chemical_symbol_Se', 'chemical_symbol_Si', 'chemical_symbol_O', 'chemical_symbol_C', 'chemical_symbol_Ni', 'chemical_symbol_Zn', 'chemical_symbol_Ge', 'chemical_symbol_S'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  d


744it [00:00, 31348.89it/s]


... > number of categories for each label: 
...... > label  SR-ARE :  2 with distribution:  [454, 83]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  207
... > Scaling features
mean [1.62192982e+00 3.80818713e-01 5.57894737e-02 1.05263158e-03
 1.16959064e-02 2.88304094e-02 1.42105263e-02 9.35672515e-04
 0.00000000e+00 0.00000000e+00 5.84795322e-05 0.00000000e+00
 0.00000000e+00 5.84795322e-05 0.00000000e+00 1.05263158e-02
 0.00000000e+00 4.32748538e-03 0.00000000e+00 1.16959064e-04
 3.97076023e-02 4.88654971e-01 9.94152047e-04 1.75438596e-04
 4.09356725e-04 7.58479532e-02 3.71520468e-01 0.00000000e+00
 0.00000000e+00 0.00000000e+00 6.66666667e-03]
std [1.65339941 0.68210667 0.22951472 0.0324272  0.10751331 0.16732967
 0.11835788 0.03057445 0.         0.         0.00764697 0.
 0.         0.00764697 0.         0.10205642 0.         0.06564113
 0.         0.01081413 0.19527137 0.49987127 0.031


[A
[A
[A
100%|██████████| 744/744 [00:00<00:00, 2036.84it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'Ca', 'Na', 'Al', 'V', 'Ti', 'B', 'Cr', 'Cl', 'Cu', 'F', 'Fe', 'As', 'N', 'H', 'P', 'Se', 'Si', 'O', 'C', 'Ni', 'Zn', 'Ge', 'S'}
selected atomic keys ['extra_feat_atom_Hamiltonian_K', 'extra_feat_atom_e_density', 'extra_feat_atom_lap_e_density', 'extra_feat_atom_esp_total', 'extra_feat_atom_det_hessian', 'extra_feat_atom_eta', 'extra_feat_atom_energy_density']
selected bond keys ['extra_feat_bond_Lagrangian_K', 'bond_length', 'extra_feat_bond_Hamiltonian_K', 'extra_feat_bond_e_density', 'extra_feat_bond_lap_e_density', 'extra_feat_bond_e_loc_func', 'extra_feat_bond_esp_e', 'extra_feat_bond_esp_total', 'extra_feat_bond_grad_norm', 'extra_feat_bond_lap_norm', 'extra_feat_bond_ellip_e_dens', 'extra_feat_bond_eta']
selected global keys ['SR-ATAD5']
... > Building graphs and featurizing



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 744/744 [00:02<00:00, 310.24it/s]


included in labels
{'global': ['SR-ATAD5']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_Ca', 'chemical_symbol_Na', 'chemical_symbol_Al', 'chemical_symbol_V', 'chemical_symbol_Ti', 'chemical_symbol_B', 'chemical_symbol_Cr', 'chemical_symbol_Cl', 'chemical_symbol_Cu', 'chemical_symbol_F', 'chemical_symbol_Fe', 'chemical_symbol_As', 'chemical_symbol_N', 'chemical_symbol_H', 'chemical_symbol_P', 'chemical_symbol_Se', 'chemical_symbol_Si', 'chemical_symbol_O', 'chemical_symbol_C', 'chemical_symbol_Ni', 'chemical_symbol_Zn', 'chemical_symbol_Ge', 'chemical_symbol_S', 'extra_feat_atom_Hamiltonian_K', 'extra_feat_atom_e_density', 'extra_feat_atom_lap_e_density', 'extra_feat_atom_esp_total', 'extra_feat_atom_det_hessian', 'extra_feat_atom_eta', 'extra_feat_atom_energy_density'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 


744it [00:00, 32605.71it/s]


... > number of categories for each label: 
...... > label  SR-ATAD5 :  2 with distribution:  [669, 15]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  60
... > Scaling features
mean [ 1.64973474e+00  3.96131835e-01  5.53072392e-02  1.00254814e-03
  1.17799407e-02  2.80713480e-02  1.44534024e-02  7.93683947e-04
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  4.17728393e-05  0.00000000e+00  8.85584193e-03
  0.00000000e+00  4.51146664e-03  0.00000000e+00  1.25318518e-04
  3.86398764e-02  4.90287815e-01  9.60775304e-04  4.17728393e-05
  3.34182714e-04  7.53164293e-02  3.73240319e-01  0.00000000e+00
  0.00000000e+00  0.00000000e+00  6.85074565e-03  1.10246306e+06
 -1.10246306e+06 -4.40601718e+06  2.29515691e+06 -1.59211374e+24
 -1.08693802e+00 -1.10246306e+06]
std [1.67642303e+00 6.98278388e-01 2.28578977e-01 3.16471648e-02
 1.07894271e-01 1.65176716e-01 1.193


[A
100%|██████████| 744/744 [00:00<00:00, 6388.30it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'Ca', 'Na', 'Al', 'V', 'Ti', 'B', 'Cr', 'Cl', 'Cu', 'F', 'Fe', 'As', 'N', 'H', 'P', 'Se', 'Si', 'O', 'C', 'Ni', 'Zn', 'Ge', 'S'}
selected atomic keys []
selected bond keys ['bond_length']
selected global keys ['SR-ATAD5']
... > Building graphs and featurizing



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 744/744 [00:02<00:00, 320.82it/s]


included in labels
{'global': ['SR-ATAD5']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_Ca', 'chemical_symbol_Na', 'chemical_symbol_Al', 'chemical_symbol_V', 'chemical_symbol_Ti', 'chemical_symbol_B', 'chemical_symbol_Cr', 'chemical_symbol_Cl', 'chemical_symbol_Cu', 'chemical_symbol_F', 'chemical_symbol_Fe', 'chemical_symbol_As', 'chemical_symbol_N', 'chemical_symbol_H', 'chemical_symbol_P', 'chemical_symbol_Se', 'chemical_symbol_Si', 'chemical_symbol_O', 'chemical_symbol_C', 'chemical_symbol_Ni', 'chemical_symbol_Zn', 'chemical_symbol_Ge', 'chemical_symbol_S'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names: 


744it [00:00, 32939.21it/s]


... > number of categories for each label: 
...... > label  SR-ATAD5 :  2 with distribution:  [669, 15]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  60
... > Scaling features
mean [1.64973474e+00 3.96131835e-01 5.53072392e-02 1.00254814e-03
 1.17799407e-02 2.80713480e-02 1.44534024e-02 7.93683947e-04
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 4.17728393e-05 0.00000000e+00 8.85584193e-03
 0.00000000e+00 4.51146664e-03 0.00000000e+00 1.25318518e-04
 3.86398764e-02 4.90287815e-01 9.60775304e-04 4.17728393e-05
 3.34182714e-04 7.53164293e-02 3.73240319e-01 0.00000000e+00
 0.00000000e+00 0.00000000e+00 6.85074565e-03]
std [1.67642303 0.69827839 0.22857898 0.03164716 0.10789427 0.16517672
 0.11935033 0.02816121 0.         0.         0.         0.
 0.         0.00646306 0.         0.09368786 0.         0.06701577
 0.         0.01119387 0.19273515 0.49990566 0.03


[A
[A
[A
100%|██████████| 744/744 [00:00<00:00, 2062.38it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'Ca', 'Na', 'Al', 'V', 'Ti', 'B', 'Cr', 'Cl', 'Cu', 'F', 'Fe', 'As', 'N', 'H', 'P', 'Se', 'Si', 'O', 'C', 'Ni', 'Zn', 'Ge', 'S'}
selected atomic keys ['extra_feat_atom_Hamiltonian_K', 'extra_feat_atom_e_density', 'extra_feat_atom_lap_e_density', 'extra_feat_atom_esp_total', 'extra_feat_atom_det_hessian', 'extra_feat_atom_eta', 'extra_feat_atom_energy_density']
selected bond keys ['extra_feat_bond_Lagrangian_K', 'bond_length', 'extra_feat_bond_Hamiltonian_K', 'extra_feat_bond_e_density', 'extra_feat_bond_lap_e_density', 'extra_feat_bond_e_loc_func', 'extra_feat_bond_esp_e', 'extra_feat_bond_esp_total', 'extra_feat_bond_grad_norm', 'extra_feat_bond_lap_norm', 'extra_feat_bond_ellip_e_dens', 'extra_feat_bond_eta']
selected global keys ['SR-HSE']
... > Building graphs and featurizing



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 744/744 [00:02<00:00, 312.52it/s]


included in labels
{'global': ['SR-HSE']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_Ca', 'chemical_symbol_Na', 'chemical_symbol_Al', 'chemical_symbol_V', 'chemical_symbol_Ti', 'chemical_symbol_B', 'chemical_symbol_Cr', 'chemical_symbol_Cl', 'chemical_symbol_Cu', 'chemical_symbol_F', 'chemical_symbol_Fe', 'chemical_symbol_As', 'chemical_symbol_N', 'chemical_symbol_H', 'chemical_symbol_P', 'chemical_symbol_Se', 'chemical_symbol_Si', 'chemical_symbol_O', 'chemical_symbol_C', 'chemical_symbol_Ni', 'chemical_symbol_Zn', 'chemical_symbol_Ge', 'chemical_symbol_S', 'extra_feat_atom_Hamiltonian_K', 'extra_feat_atom_e_density', 'extra_feat_atom_lap_e_density', 'extra_feat_atom_esp_total', 'extra_feat_atom_det_hessian', 'extra_feat_atom_eta', 'extra_feat_atom_energy_density'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'r


744it [00:00, 31210.62it/s]


... > number of categories for each label: 
...... > label  SR-HSE :  2 with distribution:  [565, 31]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  148
... > Scaling features
mean [ 1.62790817e+00  3.80886639e-01  5.39960117e-02  1.07378432e-03
  1.08401084e-02  2.93501048e-02  1.27320141e-02  9.20386562e-04
  0.00000000e+00  0.00000000e+00  5.11325868e-05  0.00000000e+00
  0.00000000e+00  5.11325868e-05  0.00000000e+00  1.06867106e-02
  0.00000000e+00  4.65306540e-03  0.00000000e+00  1.53397760e-04
  3.99856829e-02  4.87037889e-01  1.07378432e-03  1.53397760e-04
  3.57928108e-04  7.33241295e-02  3.74699596e-01  0.00000000e+00
  0.00000000e+00  0.00000000e+00  6.85176663e-03  1.35256767e+06
 -1.35256767e+06 -5.40552551e+06  5.80813548e+08 -1.97518176e+24
 -1.08551791e+00 -1.35256767e+06]
std [1.65611459e+00 6.81525447e-01 2.26009828e-01 3.27510505e-02
 1.03549990e-01 1.68785889e-01 1.1211


[A
100%|██████████| 744/744 [00:00<00:00, 6237.29it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'Ca', 'Na', 'Al', 'V', 'Ti', 'B', 'Cr', 'Cl', 'Cu', 'F', 'Fe', 'As', 'N', 'H', 'P', 'Se', 'Si', 'O', 'C', 'Ni', 'Zn', 'Ge', 'S'}
selected atomic keys []
selected bond keys ['bond_length']
selected global keys ['SR-HSE']
... > Building graphs and featurizing



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 744/744 [00:02<00:00, 317.62it/s]


included in labels
{'global': ['SR-HSE']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_Ca', 'chemical_symbol_Na', 'chemical_symbol_Al', 'chemical_symbol_V', 'chemical_symbol_Ti', 'chemical_symbol_B', 'chemical_symbol_Cr', 'chemical_symbol_Cl', 'chemical_symbol_Cu', 'chemical_symbol_F', 'chemical_symbol_Fe', 'chemical_symbol_As', 'chemical_symbol_N', 'chemical_symbol_H', 'chemical_symbol_P', 'chemical_symbol_Se', 'chemical_symbol_Si', 'chemical_symbol_O', 'chemical_symbol_C', 'chemical_symbol_Ni', 'chemical_symbol_Zn', 'chemical_symbol_Ge', 'chemical_symbol_S'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  d


744it [00:00, 30295.54it/s]


... > number of categories for each label: 
...... > label  SR-HSE :  2 with distribution:  [565, 31]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  148
... > Scaling features
mean [1.62790817e+00 3.80886639e-01 5.39960117e-02 1.07378432e-03
 1.08401084e-02 2.93501048e-02 1.27320141e-02 9.20386562e-04
 0.00000000e+00 0.00000000e+00 5.11325868e-05 0.00000000e+00
 0.00000000e+00 5.11325868e-05 0.00000000e+00 1.06867106e-02
 0.00000000e+00 4.65306540e-03 0.00000000e+00 1.53397760e-04
 3.99856829e-02 4.87037889e-01 1.07378432e-03 1.53397760e-04
 3.57928108e-04 7.33241295e-02 3.74699596e-01 0.00000000e+00
 0.00000000e+00 0.00000000e+00 6.85176663e-03]
std [1.65611459 0.68152545 0.22600983 0.03275105 0.10354999 0.16878589
 0.11211561 0.03032391 0.         0.         0.00715052 0.
 0.         0.00715052 0.         0.10282269 0.         0.0680545
 0.         0.01238443 0.19592557 0.49983196 0.0327


[A
[A
[A
100%|██████████| 744/744 [00:00<00:00, 2025.63it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'Ca', 'Na', 'Al', 'V', 'Ti', 'B', 'Cr', 'Cl', 'Cu', 'F', 'Fe', 'As', 'N', 'H', 'P', 'Se', 'Si', 'O', 'C', 'Ni', 'Zn', 'Ge', 'S'}
selected atomic keys ['extra_feat_atom_Hamiltonian_K', 'extra_feat_atom_e_density', 'extra_feat_atom_lap_e_density', 'extra_feat_atom_esp_total', 'extra_feat_atom_det_hessian', 'extra_feat_atom_eta', 'extra_feat_atom_energy_density']
selected bond keys ['extra_feat_bond_Lagrangian_K', 'bond_length', 'extra_feat_bond_Hamiltonian_K', 'extra_feat_bond_e_density', 'extra_feat_bond_lap_e_density', 'extra_feat_bond_e_loc_func', 'extra_feat_bond_esp_e', 'extra_feat_bond_esp_total', 'extra_feat_bond_grad_norm', 'extra_feat_bond_lap_norm', 'extra_feat_bond_ellip_e_dens', 'extra_feat_bond_eta']
selected global keys ['SR-MMP']
... > Building graphs and featurizing



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 744/744 [00:02<00:00, 302.89it/s]


included in labels
{'global': ['SR-MMP']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_Ca', 'chemical_symbol_Na', 'chemical_symbol_Al', 'chemical_symbol_V', 'chemical_symbol_Ti', 'chemical_symbol_B', 'chemical_symbol_Cr', 'chemical_symbol_Cl', 'chemical_symbol_Cu', 'chemical_symbol_F', 'chemical_symbol_Fe', 'chemical_symbol_As', 'chemical_symbol_N', 'chemical_symbol_H', 'chemical_symbol_P', 'chemical_symbol_Se', 'chemical_symbol_Si', 'chemical_symbol_O', 'chemical_symbol_C', 'chemical_symbol_Ni', 'chemical_symbol_Zn', 'chemical_symbol_Ge', 'chemical_symbol_S', 'extra_feat_atom_Hamiltonian_K', 'extra_feat_atom_e_density', 'extra_feat_atom_lap_e_density', 'extra_feat_atom_esp_total', 'extra_feat_atom_det_hessian', 'extra_feat_atom_eta', 'extra_feat_atom_energy_density'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'r


744it [00:00, 30406.83it/s]


... > number of categories for each label: 
...... > label  SR-MMP :  2 with distribution:  [488, 90]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  166
... > Scaling features
mean [ 1.64173228e+00  3.96391907e-01  5.44702482e-02  8.97039769e-04
  1.12628327e-02  2.71603708e-02  1.51500050e-02  8.97039769e-04
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  4.98355427e-05  0.00000000e+00  8.62154889e-03
  0.00000000e+00  4.58486993e-03  0.00000000e+00  1.49506628e-04
  3.79746835e-02  4.93421708e-01  8.47204226e-04  1.49506628e-04
  2.49177714e-04  7.70457490e-02  3.69530549e-01  0.00000000e+00
  0.00000000e+00  0.00000000e+00  6.47862055e-03  1.29088861e+06
 -1.29088861e+06 -5.15903164e+06  5.66030968e+08 -1.92506953e+24
 -1.08778143e+00 -1.29088861e+06]
std [1.67808758e+00 6.96744533e-01 2.26943253e-01 2.99371857e-02
 1.05527159e-01 1.62550561e-01 1.2214


[A
100%|██████████| 744/744 [00:00<00:00, 5798.24it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'Ca', 'Na', 'Al', 'V', 'Ti', 'B', 'Cr', 'Cl', 'Cu', 'F', 'Fe', 'As', 'N', 'H', 'P', 'Se', 'Si', 'O', 'C', 'Ni', 'Zn', 'Ge', 'S'}
selected atomic keys []
selected bond keys ['bond_length']
selected global keys ['SR-MMP']
... > Building graphs and featurizing



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 744/744 [00:02<00:00, 308.96it/s]


included in labels
{'global': ['SR-MMP']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_Ca', 'chemical_symbol_Na', 'chemical_symbol_Al', 'chemical_symbol_V', 'chemical_symbol_Ti', 'chemical_symbol_B', 'chemical_symbol_Cr', 'chemical_symbol_Cl', 'chemical_symbol_Cu', 'chemical_symbol_F', 'chemical_symbol_Fe', 'chemical_symbol_As', 'chemical_symbol_N', 'chemical_symbol_H', 'chemical_symbol_P', 'chemical_symbol_Se', 'chemical_symbol_Si', 'chemical_symbol_O', 'chemical_symbol_C', 'chemical_symbol_Ni', 'chemical_symbol_Zn', 'chemical_symbol_Ge', 'chemical_symbol_S'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  d


744it [00:00, 30429.96it/s]


... > number of categories for each label: 
...... > label  SR-MMP :  2 with distribution:  [488, 90]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  166
... > Scaling features
mean [1.64173228e+00 3.96391907e-01 5.44702482e-02 8.97039769e-04
 1.12628327e-02 2.71603708e-02 1.51500050e-02 8.97039769e-04
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 4.98355427e-05 0.00000000e+00 8.62154889e-03
 0.00000000e+00 4.58486993e-03 0.00000000e+00 1.49506628e-04
 3.79746835e-02 4.93421708e-01 8.47204226e-04 1.49506628e-04
 2.49177714e-04 7.70457490e-02 3.69530549e-01 0.00000000e+00
 0.00000000e+00 0.00000000e+00 6.47862055e-03]
std [1.67808758 0.69674453 0.22694325 0.02993719 0.10552716 0.16255056
 0.12214943 0.02993719 0.         0.         0.         0.
 0.         0.00705925 0.         0.09245116 0.         0.06755626
 0.         0.01222638 0.19113505 0.49995672 0.029


[A
[A
[A
100%|██████████| 744/744 [00:00<00:00, 1989.77it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'Ca', 'Na', 'Al', 'V', 'Ti', 'B', 'Cr', 'Cl', 'Cu', 'F', 'Fe', 'As', 'N', 'H', 'P', 'Se', 'Si', 'O', 'C', 'Ni', 'Zn', 'Ge', 'S'}
selected atomic keys ['extra_feat_atom_Hamiltonian_K', 'extra_feat_atom_e_density', 'extra_feat_atom_lap_e_density', 'extra_feat_atom_esp_total', 'extra_feat_atom_det_hessian', 'extra_feat_atom_eta', 'extra_feat_atom_energy_density']
selected bond keys ['extra_feat_bond_Lagrangian_K', 'bond_length', 'extra_feat_bond_Hamiltonian_K', 'extra_feat_bond_e_density', 'extra_feat_bond_lap_e_density', 'extra_feat_bond_e_loc_func', 'extra_feat_bond_esp_e', 'extra_feat_bond_esp_total', 'extra_feat_bond_grad_norm', 'extra_feat_bond_lap_norm', 'extra_feat_bond_ellip_e_dens', 'extra_feat_bond_eta']
selected global keys ['SR-p53']
... > Building graphs and featurizing



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 744/744 [00:02<00:00, 303.93it/s]


included in labels
{'global': ['SR-p53']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_Ca', 'chemical_symbol_Na', 'chemical_symbol_Al', 'chemical_symbol_V', 'chemical_symbol_Ti', 'chemical_symbol_B', 'chemical_symbol_Cr', 'chemical_symbol_Cl', 'chemical_symbol_Cu', 'chemical_symbol_F', 'chemical_symbol_Fe', 'chemical_symbol_As', 'chemical_symbol_N', 'chemical_symbol_H', 'chemical_symbol_P', 'chemical_symbol_Se', 'chemical_symbol_Si', 'chemical_symbol_O', 'chemical_symbol_C', 'chemical_symbol_Ni', 'chemical_symbol_Zn', 'chemical_symbol_Ge', 'chemical_symbol_S', 'extra_feat_atom_Hamiltonian_K', 'extra_feat_atom_e_density', 'extra_feat_atom_lap_e_density', 'extra_feat_atom_esp_total', 'extra_feat_atom_det_hessian', 'extra_feat_atom_eta', 'extra_feat_atom_energy_density'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'r


744it [00:00, 29815.90it/s]


... > number of categories for each label: 
...... > label  SR-p53 :  2 with distribution:  [623, 38]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  83
... > Scaling features
mean [ 1.64869288e+00  3.95968819e-01  5.64623799e-02  1.03363625e-03
  1.23174986e-02  2.84680649e-02  1.46431802e-02  8.61363538e-04
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  4.30681769e-05  0.00000000e+00  8.65670356e-03
  0.00000000e+00  4.34988587e-03  0.00000000e+00  1.29204531e-04
  3.86752229e-02  4.91106421e-01  9.04431715e-04  1.29204531e-04
  3.44545415e-04  7.67905595e-02  3.71592231e-01  0.00000000e+00
  0.00000000e+00  0.00000000e+00  6.41715836e-03  1.22909428e+06
 -1.22909428e+06 -4.91207969e+06  4.89522714e+08 -1.82106231e+24
 -1.08701791e+00 -1.22909428e+06]
std [1.67994978e+00 6.97539333e-01 2.30812434e-01 3.21335937e-02
 1.10298585e-01 1.66305845e-01 1.20119


[A
100%|██████████| 744/744 [00:00<00:00, 6110.16it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'Ca', 'Na', 'Al', 'V', 'Ti', 'B', 'Cr', 'Cl', 'Cu', 'F', 'Fe', 'As', 'N', 'H', 'P', 'Se', 'Si', 'O', 'C', 'Ni', 'Zn', 'Ge', 'S'}
selected atomic keys []
selected bond keys ['bond_length']
selected global keys ['SR-p53']
... > Building graphs and featurizing



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 744/744 [00:02<00:00, 310.47it/s]


included in labels
{'global': ['SR-p53']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_Ca', 'chemical_symbol_Na', 'chemical_symbol_Al', 'chemical_symbol_V', 'chemical_symbol_Ti', 'chemical_symbol_B', 'chemical_symbol_Cr', 'chemical_symbol_Cl', 'chemical_symbol_Cu', 'chemical_symbol_F', 'chemical_symbol_Fe', 'chemical_symbol_As', 'chemical_symbol_N', 'chemical_symbol_H', 'chemical_symbol_P', 'chemical_symbol_Se', 'chemical_symbol_Si', 'chemical_symbol_O', 'chemical_symbol_C', 'chemical_symbol_Ni', 'chemical_symbol_Zn', 'chemical_symbol_Ge', 'chemical_symbol_S'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  d


744it [00:00, 30173.10it/s]


... > number of categories for each label: 
...... > label  SR-p53 :  2 with distribution:  [623, 38]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  83
... > Scaling features
mean [1.64869288e+00 3.95968819e-01 5.64623799e-02 1.03363625e-03
 1.23174986e-02 2.84680649e-02 1.46431802e-02 8.61363538e-04
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 4.30681769e-05 0.00000000e+00 8.65670356e-03
 0.00000000e+00 4.34988587e-03 0.00000000e+00 1.29204531e-04
 3.86752229e-02 4.91106421e-01 9.04431715e-04 1.29204531e-04
 3.44545415e-04 7.67905595e-02 3.71592231e-01 0.00000000e+00
 0.00000000e+00 0.00000000e+00 6.41715836e-03]
std [1.67994978 0.69753933 0.23081243 0.03213359 0.11029858 0.16630585
 0.12011976 0.02933635 0.         0.         0.         0.
 0.         0.00656249 0.         0.09263782 0.         0.06581006
 0.         0.01136608 0.19281973 0.4999209  0.0300

In [60]:
from pytorch_lightning.callbacks import (
    LearningRateMonitor,
    EarlyStopping,
    ModelCheckpoint,
)
import pytorch_lightning as pl

dataloader_qtaim_train = DataLoaderMoleculeGraphTask(
    dataset=dataset_train_qtaim,
    batch_size=1024,
    shuffle=True
)

dataloader_qtaim_test = DataLoaderMoleculeGraphTask(
    dataset=dataset_test_qtaim,
    batch_size=1024, 
    shuffle=False
)


dataloader_bl_train = DataLoaderMoleculeGraphTask(
    dataset=dataset_bl_train,
    batch_size=1024,
    shuffle=True
)

dataloader_bl_test = DataLoaderMoleculeGraphTask(
    dataset=dataset_bl_test,
    batch_size=1024,
    shuffle=False
)


qtaim_model_bl_dict = {
    "atom_feature_size": 38,
    "bond_feature_size": 18,
    "global_feature_size": 3,
    "conv_fn": "ResidualBlock",
    "target_dict": {"global": ['NR-AR',
  'NR-AR-LBD',
  'NR-AhR',
  'NR-Aromatase',
  'NR-ER',
  'NR-ER-LBD',
  'NR-PPAR-gamma',
  'SR-ARE',
  'SR-ATAD5',
  'SR-HSE',
  'SR-MMP',
  'SR-p53']},
    "dropout": 0.2,
    "batch_norm_tf": True,
    "activation": "ReLU",
    "bias": True,
    "norm": "both",
    "aggregate": "sum",
    "n_conv_layers": 4,
    "lr": 0.001,
    "weight_decay": 1e-05,
    "lr_plateau_patience": 10,
    "lr_scale_factor": 0.1,
    "scheduler_name": "reduce_on_plateau",
    "loss_fn": "mse",
    "resid_n_graph_convs": 4,
    "embedding_size": 100,
    "fc_layer_size": [512, 512, 512],
    "shape_fc": "flat",
    "fc_dropout": 0.2,
    "fc_batch_norm": True,
    "n_fc_layers": 2,
    "global_pooling_fn": "MeanPoolingThenCat",
    "ntypes_pool": ["atom", "bond", "global"],
    "ntypes_pool_direct_cat": ["global"],
    "lstm_iters": 15,
    "lstm_layers": 3,
    "num_heads": 1,
    "feat_drop": 0.2,
    "attn_drop": 0.1,
    "residual": True,
    "ntasks": 12,
    "num_heads_gat": 1,
    "dropout_feat_gat": 0.2,
    "dropout_attn_gat": 0.1,
    "hidden_size_gat": 100,
    "residual_gat": True,
    "shape_fc": "flat",
    "classifier": False,
    "batch_norm": False,
    "pooling_ntypes": ["atom", "bond", "global"],
    "pooling_ntypes_direct": ["global"],
    "fc_hidden_size_1": 1024,
    "fc_num_layers": 2,
    "restore": False,
    "classifier": True,
}

non_qtaim_model_bl_dict = {
    "atom_feature_size": 31,
    "bond_feature_size": 7,
    "global_feature_size": 3,
    "conv_fn": "GraphConvDropoutBatch",
    "target_dict": {"global": 
        ['NR-AR',
        'NR-AR-LBD',
        'NR-AhR',
        'NR-Aromatase',
        'NR-ER',
        'NR-ER-LBD',
        'NR-PPAR-gamma',
        'SR-ARE',
        'SR-ATAD5',
        'SR-HSE',
        'SR-MMP',
        'SR-p53']},
    "dropout": 0.2,
    "batch_norm_tf": True,
    "activation": "ReLU",
    "bias": True,
    "norm": "both",
    "fc_num_layers": 3,
    "aggregate": "sum",
    "n_conv_layers": 5,
    "lr": 0.001,
    "weight_decay": 1e-05,
    "lr_plateau_patience": 10,
    "lr_scale_factor": 0.1,
    "scheduler_name": "reduce_on_plateau",
    "loss_fn": "mse",
    "resid_n_graph_convs": 2,
    "embedding_size": 50,
    "fc_layer_size": [1024, 512, 256],
    "fc_dropout": 0.1,
    "fc_batch_norm": True,
    "n_fc_layers": 3,
    "global_pooling_fn": "MeanPoolingThenCat",
    "ntypes_pool": ["atom", "bond", "global"],
    "ntypes_pool_direct_cat": ["global"],
    "lstm_iters": 9,
    "lstm_layers": 2,
    "num_heads": 3,
    "feat_drop": 0.1,
    "attn_drop": 0.1,
    "residual": False,
    "hidden_size": 10,
    "ntasks": 12,
    "shape_fc": "cone",
    "num_heads_gat": 1,
    "dropout_feat_gat": 0.1,
    "dropout_attn_gat": 0.1,
    "hidden_size_gat": 10,
    "residual_gat": False,
    "batch_norm": True,
    "pooling_ntypes": ["atom", "bond", "global"],
    "pooling_ntypes_direct": ["global"],
    "fc_hidden_size_1": 1024,
    "restore": False,
    "classifier": True,
}

%load_ext autoreload
%autoreload 2
model_temp_qtaim = load_graph_level_model_from_config(qtaim_model_bl_dict)
model_temp_noqtaim = load_graph_level_model_from_config(non_qtaim_model_bl_dict)


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
:::CLASSIFIER MODEL:::
... > number of tasks: 12
:::CLASSIFIER MODEL:::
... > number of tasks: 12


In [61]:
early_stopping_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=0.00,
    patience=100,
    verbose=False,
    mode="min",
)
lr_monitor = LearningRateMonitor(logging_interval="step")

trainer = pl.Trainer(
    max_epochs=100,
    accelerator="gpu",
    devices=[0],
    num_nodes=1,
    gradient_clip_val=100.0,
    accumulate_grad_batches=2,
    enable_progress_bar=True,
    callbacks=[
        early_stopping_callback,
        lr_monitor,
    ],
    enable_checkpointing=True,
    strategy="auto",
    default_root_dir="./qtaim/",
    precision="bf16-mixed",
    num_sanity_val_steps=0
)

trainer.fit(model_temp_qtaim, dataloader_qtaim_train, val_dataloaders=dataloader_qtaim_test)

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

   | Name        | Type               | Params
----------------------------------------------------
0  | activation  | ReLU               | 0     
1  | embedding   | UnifySize          | 5.9 K 
2  | conv_layers | ModuleList         | 336 K 
3  | readout     | MeanPoolingThenCat | 0     
4  | fc_layers   | ModuleList         | 1.3 M 
5  | train_auroc | MultilabelAUROC    | 0     
6  | train_acc   | MultilabelAccuracy | 0     
7  | train_f1    | MultilabelF1Score  | 0     
8  | val_auroc   | MultilabelAUROC    | 0     
9  | val_acc     | MultilabelAccuracy | 0     
10 | val_f1      | MultilabelF1Score  | 0     
11 | test_auroc  | MultilabelAUROC    | 0     
12 | test_acc    | MultilabelAccuracy | 0     
13 | test_f1     | MultilabelF1Score  | 0     
14 | loss        | ModuleList         | 0     
----------------------------------------------------
1.6 M     Trainable params
0         Non-trainable params
1.6 M     Total params
6.557     Total 

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 00034: reducing learning rate of group 0 to 1.0000e-04.


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 00045: reducing learning rate of group 0 to 1.0000e-05.


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 00056: reducing learning rate of group 0 to 1.0000e-06.


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 00067: reducing learning rate of group 0 to 1.0000e-07.


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 00078: reducing learning rate of group 0 to 1.0000e-08.


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


In [62]:
from sklearn.metrics import f1_score


def manual_eval_separate_tasks(model, dataset_list):
    n_tasks = len(dataset_list)

    temp_data_loader = []
    for i in dataset_list:
        temp_data_loader.append(
            DataLoaderMoleculeGraphTask(i, batch_size=len(i.graphs), shuffle=False)
        )
    task_ind = 0
    statistics_dict = {"f1": [], "acc": [], "auroc": []}
    for task_ind in range(n_tasks):
        batch_graph, batch_label = next(iter(temp_data_loader[task_ind]))

        labels = batch_label["global"]
        labels_one_hot = torch.argmax(labels, axis=2)
        logits = model.forward(batch_graph, batch_graph.ndata["feat"])
        logits_one_hot = torch.argmax(logits, axis=-1)

        # take logits from first task only
        labels = labels.reshape(-1)
        logits = logits[:, task_ind].reshape(-1)
        logits_one_hot = logits_one_hot[:, task_ind].reshape(-1)
        labels_one_hot = labels_one_hot.reshape(-1)
        # compute acc, auroc, f1 manually
        acc_manual = torch.sum(logits_one_hot == labels_one_hot) / len(labels_one_hot)
        auroc_manual = torchmetrics.functional.auroc(
            logits, labels, num_classes=2, task="binary"
        )
        f1 = f1_score(labels_one_hot, logits_one_hot, pos_label=0)
        # print("-" * 50)
        # print(
        #    "acc: {:.4f}\t auroc: {:.4f}\t f1: {:.4f}".format(
        #        acc_manual, auroc_manual, f1
        #    )
        # )
        statistics_dict["f1"].append(f1)
        statistics_dict["acc"].append(acc_manual.numpy())
        statistics_dict["auroc"].append(auroc_manual.numpy())
    return statistics_dict

In [63]:


model_temp_qtaim.cpu()

stats_dict = manual_eval_separate_tasks(
    model_temp_qtaim, dataset_list=dict_datasets["qtaim"]["single_list"]
)
print(np.array(stats_dict["f1"]).mean(), np.array(stats_dict["auroc"]).mean())


In [None]:
early_stopping_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=0.00,
    patience=50,
    verbose=False,
    mode="min",
)
lr_monitor = LearningRateMonitor(logging_interval="step")

trainer = pl.Trainer(
    max_epochs=100,
    accelerator="gpu",
    devices=[0],
    num_nodes=1,
    gradient_clip_val=100.0,
    accumulate_grad_batches=2,
    enable_progress_bar=True,
    callbacks=[
        early_stopping_callback,
        lr_monitor,
    ],
    enable_checkpointing=True,
    strategy="auto",
    default_root_dir="./no_qtaim/",
    precision="bf16-mixed",
    num_sanity_val_steps=0
)

trainer.fit(model_temp_noqtaim, dataloader_bl_train, val_dataloaders=dataloader_bl_test)

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

   | Name        | Type               | Params
----------------------------------------------------
0  | activation  | ReLU               | 0     
1  | embedding   | UnifySize          | 2.0 K 
2  | conv_layers | ModuleList         | 119 K 
3  | readout     | MeanPoolingThenCat | 0     
4  | fc_layers   | ModuleList         | 820 K 
5  | train_auroc | MultilabelAUROC    | 0     
6  | train_acc   | MultilabelAccuracy | 0     
7  | train_f1    | MultilabelF1Score  | 0     
8  | val_auroc   | MultilabelAUROC    | 0     
9  | val_acc     | MultilabelAccuracy | 0     
10 | val_f1      | MultilabelF1Score  | 0     
11 | test_auroc  | MultilabelAUROC    | 0     
12 | test_acc    | MultilabelAccuracy | 0     
13 | test_f1     | MultilabelF1Score  | 0    

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 00080: reducing learning rate of group 0 to 1.0000e-04.


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 00091: reducing learning rate of group 0 to 1.0000e-05.


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


In [None]:


model_temp_noqtaim.cpu()

stats_dict = manual_eval_separate_tasks(
    model_temp_noqtaim, dataset_list=dict_datasets["bl"]["single_list"]
)
print(np.array(stats_dict["f1"]).mean(), np.array(stats_dict["auroc"]).mean())

0.9629272824038767 0.9429536
