In [1]:
import torch
from torchmetrics.wrappers import MultioutputWrapper
import torchmetrics

from qtaim_embed.models.utils import load_graph_level_model_from_config
from qtaim_embed.data.dataloader import DataLoaderMoleculeGraphTask
from qtaim_embed.core.dataset import HeteroGraphGraphLabelClassifierDataset


In [3]:
! ls ../../../../data/saved_models/1218/tox21/*

../../../../data/saved_models/1218/tox21/bl:
'model_lightning_epoch=130-val_loss=4.51.ckpt'
'model_lightning_epoch=162-val_loss=4.51.ckpt'
'model_lightning_epoch=171-val_loss=4.51.ckpt'
'model_lightning_epoch=192-val_loss=4.51.ckpt'
'model_lightning_epoch=192-val_loss=4.51-v1.ckpt'
'model_lightning_epoch=21-val_loss=4.51.ckpt'
'model_lightning_epoch=26-val_loss=4.51.ckpt'
'model_lightning_epoch=322-val_loss=4.51.ckpt'
'model_lightning_epoch=34-val_loss=4.51.ckpt'
'model_lightning_epoch=40-val_loss=4.51.ckpt'
'model_lightning_epoch=47-val_loss=4.50.ckpt'
'model_lightning_epoch=490-val_loss=4.51.ckpt'
'model_lightning_epoch=510-val_loss=4.50.ckpt'
'model_lightning_epoch=64-val_loss=4.51.ckpt'
'model_lightning_epoch=67-val_loss=4.51.ckpt'
'model_lightning_epoch=76-val_loss=4.51.ckpt'
'model_lightning_epoch=83-val_loss=4.51.ckpt'
'model_lightning_epoch=85-val_loss=4.51.ckpt'
'model_lightning_epoch=86-val_loss=4.51.ckpt'
'model_lightning_epoch=92-val_loss=4.50.ckpt'
'model_lightning_epoch=9

:::RESTORING MODEL FROM EXISTING FILE:::
readout in size 212
... > number of tasks: 12
... > number of tasks: 12
:::MODEL LOADED:::


In [18]:
bl_keys = {
    "atom": [],
    "bond": ["bond_length"],
    "global": [
        "NR-AR",
        "NR-AR-LBD",
        "NR-AhR",
        "NR-Aromatase",
        "NR-ER",
        "NR-ER-LBD",
        "NR-PPAR-gamma",
        "SR-ARE",
        "SR-ATAD5",
        "SR-HSE",
        "SR-MMP",
        "SR-p53",
    ],
}

base_dict = {
    "atom": [],
    "bond": [],
    "global": [
        "NR-AR",
        "NR-AR-LBD",
        "NR-AhR",
        "NR-Aromatase",
        "NR-ER",
        "NR-ER-LBD",
        "NR-PPAR-gamma",
        "SR-ARE",
        "SR-ATAD5",
        "SR-HSE",
        "SR-MMP",
        "SR-p53",
    ],
}

tox21_loc = "../../../../data/splits_1205/test_tox21_qtaim_1205_labelled.pkl"

dataset_bl = HeteroGraphGraphLabelClassifierDataset(
    file=tox21_loc,
    allowed_ring_size=[3, 4, 5, 6],
    allowed_charges=None,
    allowed_spins=None,
    self_loop=True,
    extra_keys=bl_keys,
    target_list=[
        "NR-AR",
        "NR-AR-LBD",
        "NR-AhR",
        "NR-Aromatase",
        "NR-ER",
        "NR-ER-LBD",
        "NR-PPAR-gamma",
        "SR-ARE",
        "SR-ATAD5",
        "SR-HSE",
        "SR-MMP",
        "SR-p53",
    ],
    extra_dataset_info={},
    debug=False,
    impute=False,
    element_set={
        "Fe",
        "H",
        "Cu",
        "Cr",
        "Ge",
        "Na",
        "P",
        "N",
        "C",
        "Br",
        "S",
        "V",
        "F",
        "Se",
        "B",
        "Cl",
        "Zn",
        "Ti",
        "O",
        "Si",
        "Ni",
        "Ca",
        "Al",
        "As",
    },
    log_scale_features=True,
    standard_scale_features=True,
)
dataset_bare = HeteroGraphGraphLabelClassifierDataset(
    file=tox21_loc,
    allowed_ring_size=[3, 4, 5, 6],
    allowed_charges=None,
    allowed_spins=None,
    self_loop=True,
    extra_keys=base_dict,
    target_list=[
        "NR-AR",
        "NR-AR-LBD",
        "NR-AhR",
        "NR-Aromatase",
        "NR-ER",
        "NR-ER-LBD",
        "NR-PPAR-gamma",
        "SR-ARE",
        "SR-ATAD5",
        "SR-HSE",
        "SR-MMP",
        "SR-p53",
    ],
    extra_dataset_info={},
    debug=False,
    impute=False,
    element_set={
        "Fe",
        "H",
        "Cu",
        "Cr",
        "Ge",
        "Na",
        "P",
        "N",
        "C",
        "Br",
        "S",
        "V",
        "F",
        "Se",
        "B",
        "Cl",
        "Zn",
        "Ti",
        "O",
        "Si",
        "Ni",
        "Ca",
        "Al",
        "As",
    },
    log_scale_features=True,
    standard_scale_features=True,
)

dataset_qtaim = HeteroGraphGraphLabelClassifierDataset(
    file=tox21_loc,
    allowed_ring_size=[3, 4, 5, 6],
    allowed_charges=None,
    allowed_spins=None,
    self_loop=True,
    extra_keys=base_dict,
    target_list=[
        "NR-AR",
        "NR-AR-LBD",
        "NR-AhR",
        "NR-Aromatase",
        "NR-ER",
        "NR-ER-LBD",
        "NR-PPAR-gamma",
        "SR-ARE",
        "SR-ATAD5",
        "SR-HSE",
        "SR-MMP",
        "SR-p53",
    ],
    extra_dataset_info={},
    debug=False,
    impute=False,
    element_set={
        "Fe",
        "H",
        "Cu",
        "Cr",
        "Ge",
        "Na",
        "P",
        "N",
        "C",
        "Br",
        "S",
        "V",
        "F",
        "Se",
        "B",
        "Cl",
        "Zn",
        "Ti",
        "O",
        "Si",
        "Ni",
        "Ca",
        "Al",
        "As",
    },
    log_scale_features=True,
    standard_scale_features=True,
)


# make a dataset for each task
dict_datasets = {}
dict_datasets["bare"] = {}
dict_datasets["bl"] = {}
dict_datasets["bare"]["test"] = dataset_bare
dict_datasets["bl"]["test"] = dataset_bl
dict_datasets["bare"]["single_list"] = []
dict_datasets["bl"]["single_list"] = []

for task in base_dict["global"]:
    base_dict_temp = base_dict.copy()
    base_dict_temp["global"] = [task]
    base_dict_bl_temp = bl_keys.copy()
    base_dict_bl_temp["global"] = [task]

    dict_datasets["bare"]["single_list"].append(
        HeteroGraphGraphLabelClassifierDataset(
            file=tox21_loc,
            allowed_ring_size=[3, 4, 5, 6],
            allowed_charges=None,
            allowed_spins=None,
            self_loop=True,
            extra_keys=base_dict_temp,
            target_list=[task],
            extra_dataset_info={},
            debug=False,
            impute=False,
            element_set={
                "Fe",
                "H",
                "Cu",
                "Cr",
                "Ge",
                "Na",
                "P",
                "N",
                "C",
                "Br",
                "S",
                "V",
                "F",
                "Se",
                "B",
                "Cl",
                "Zn",
                "Ti",
                "O",
                "Si",
                "Ni",
                "Ca",
                "Al",
                "As",
            },
            log_scale_features=True,
            standard_scale_features=True,
        )
    )

    dict_datasets["bl"]["single_list"].append(
        HeteroGraphGraphLabelClassifierDataset(
            file=tox21_loc,
            allowed_ring_size=[3, 4, 5, 6],
            allowed_charges=None,
            allowed_spins=None,
            self_loop=True,
            extra_keys=base_dict_bl_temp,
            target_list=[task],
            extra_dataset_info={},
            debug=False,
            impute=False,
            element_set={
                "Fe",
                "H",
                "Cu",
                "Cr",
                "Ge",
                "Na",
                "P",
                "N",
                "C",
                "Br",
                "S",
                "V",
                "F",
                "Se",
                "B",
                "Cl",
                "Zn",
                "Ti",
                "O",
                "Si",
                "Ni",
                "Ca",
                "Al",
                "As",
            },
            log_scale_features=True,
            standard_scale_features=True,
        )
    )


... > creating MoleculeWrapper objects


100%|██████████| 744/744 [00:00<00:00, 2741.09it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'O', 'Cu', 'V', 'N', 'Ti', 'C', 'H', 'S', 'Zn', 'Se', 'Na', 'Ge', 'P', 'B', 'Al', 'As', 'F', 'Ca', 'Fe', 'Si', 'Cr', 'Ni', 'Cl'}
selected atomic keys []
selected bond keys ['bond_length']
selected global keys ['NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD', 'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53']
... > Building graphs and featurizing


100%|██████████| 744/744 [00:03<00:00, 225.29it/s]


included in labels
{'global': ['NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD', 'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_O', 'chemical_symbol_Cu', 'chemical_symbol_V', 'chemical_symbol_N', 'chemical_symbol_Ti', 'chemical_symbol_C', 'chemical_symbol_H', 'chemical_symbol_S', 'chemical_symbol_Zn', 'chemical_symbol_Se', 'chemical_symbol_Na', 'chemical_symbol_Ge', 'chemical_symbol_P', 'chemical_symbol_B', 'chemical_symbol_Al', 'chemical_symbol_As', 'chemical_symbol_F', 'chemical_symbol_Ca', 'chemical_symbol_Fe', 'chemical_symbol_Si', 'chemical_symbol_Cr', 'chemical_symbol_Ni', 'chemical_symbol_Cl'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight'

744it [00:00, 13065.33it/s]


... > number of categories for each label: 
...... > label  NR-AR :  2 with distribution:  [306, 8]
...... > label  NR-AR-LBD :  2 with distribution:  [308, 6]
...... > label  NR-AhR :  2 with distribution:  [294, 20]
...... > label  NR-Aromatase :  2 with distribution:  [306, 8]
...... > label  NR-ER :  2 with distribution:  [291, 23]
...... > label  NR-ER-LBD :  2 with distribution:  [306, 8]
...... > label  NR-PPAR-gamma :  2 with distribution:  [314, 0]
...... > label  SR-ARE :  2 with distribution:  [291, 23]
...... > label  SR-ATAD5 :  2 with distribution:  [314, 0]
...... > label  SR-HSE :  2 with distribution:  [309, 5]
...... > label  SR-MMP :  2 with distribution:  [296, 18]
...... > label  SR-p53 :  2 with distribution:  [311, 3]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  430
... > Log scaling features
... > Log scaling features complete
... > Scaling features
mean [7.512262

100%|██████████| 744/744 [00:00<00:00, 11642.11it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'O', 'Cu', 'V', 'N', 'Ti', 'C', 'H', 'S', 'Zn', 'Se', 'Na', 'Ge', 'P', 'B', 'Al', 'As', 'F', 'Ca', 'Fe', 'Si', 'Cr', 'Ni', 'Cl'}
selected atomic keys []
selected bond keys []
selected global keys ['NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD', 'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53']
... > Building graphs and featurizing


100%|██████████| 744/744 [00:02<00:00, 319.49it/s]


included in labels
{'global': ['NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD', 'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_O', 'chemical_symbol_Cu', 'chemical_symbol_V', 'chemical_symbol_N', 'chemical_symbol_Ti', 'chemical_symbol_C', 'chemical_symbol_H', 'chemical_symbol_S', 'chemical_symbol_Zn', 'chemical_symbol_Se', 'chemical_symbol_Na', 'chemical_symbol_Ge', 'chemical_symbol_P', 'chemical_symbol_B', 'chemical_symbol_Al', 'chemical_symbol_As', 'chemical_symbol_F', 'chemical_symbol_Ca', 'chemical_symbol_Fe', 'chemical_symbol_Si', 'chemical_symbol_Cr', 'chemical_symbol_Ni', 'chemical_symbol_Cl'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loa

744it [00:00, 27272.64it/s]


... > number of categories for each label: 
...... > label  NR-AR :  2 with distribution:  [306, 8]
...... > label  NR-AR-LBD :  2 with distribution:  [308, 6]
...... > label  NR-AhR :  2 with distribution:  [294, 20]
...... > label  NR-Aromatase :  2 with distribution:  [306, 8]
...... > label  NR-ER :  2 with distribution:  [291, 23]
...... > label  NR-ER-LBD :  2 with distribution:  [306, 8]
...... > label  NR-PPAR-gamma :  2 with distribution:  [314, 0]
...... > label  SR-ARE :  2 with distribution:  [291, 23]
...... > label  SR-ATAD5 :  2 with distribution:  [314, 0]
...... > label  SR-HSE :  2 with distribution:  [309, 5]
...... > label  SR-MMP :  2 with distribution:  [296, 18]
...... > label  SR-p53 :  2 with distribution:  [311, 3]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  430
... > Log scaling features
... > Log scaling features complete
... > Scaling features
mean [7.512262

100%|██████████| 744/744 [00:00<00:00, 12456.14it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'O', 'Cu', 'V', 'N', 'Ti', 'C', 'H', 'S', 'Zn', 'Se', 'Na', 'Ge', 'P', 'B', 'Al', 'As', 'F', 'Ca', 'Fe', 'Si', 'Cr', 'Ni', 'Cl'}
selected atomic keys []
selected bond keys []
selected global keys ['NR-AR']
... > Building graphs and featurizing


100%|██████████| 744/744 [00:02<00:00, 319.91it/s]


included in labels
{'global': ['NR-AR']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_O', 'chemical_symbol_Cu', 'chemical_symbol_V', 'chemical_symbol_N', 'chemical_symbol_Ti', 'chemical_symbol_C', 'chemical_symbol_H', 'chemical_symbol_S', 'chemical_symbol_Zn', 'chemical_symbol_Se', 'chemical_symbol_Na', 'chemical_symbol_Ge', 'chemical_symbol_P', 'chemical_symbol_B', 'chemical_symbol_Al', 'chemical_symbol_As', 'chemical_symbol_F', 'chemical_symbol_Ca', 'chemical_symbol_Fe', 'chemical_symbol_Si', 'chemical_symbol_Cr', 'chemical_symbol_Ni', 'chemical_symbol_Cl'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  dict_keys(['globa

744it [00:00, 30084.09it/s]


... > number of categories for each label: 
...... > label  NR-AR :  2 with distribution:  [667, 25]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  52
... > Log scaling features
... > Log scaling features complete
... > Scaling features
mean [7.83730508e-01 2.44589755e-01 3.77376419e-02 6.68415798e-04
 8.07669089e-03 1.88548956e-02 1.01376396e-02 5.29162507e-04
 5.16908217e-02 0.00000000e+00 0.00000000e+00 2.69315865e-02
 0.00000000e+00 2.59206076e-01 3.39666628e-01 4.53965729e-03
 0.00000000e+00 8.35519747e-05 0.00000000e+00 0.00000000e+00
 6.68415798e-04 2.78506582e-05 0.00000000e+00 8.35519747e-05
 3.62058557e-03 0.00000000e+00 0.00000000e+00 2.22805266e-04
 0.00000000e+00 0.00000000e+00 5.87648889e-03]
std [0.62816143 0.3965248  0.15726923 0.02151427 0.07438483 0.11275509
 0.0832112  0.01914439 0.18209175 0.         0.         0.13394866
 0.         0.33538064 0.34650476 0.05591102 0. 

100%|██████████| 744/744 [00:00<00:00, 6339.00it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'O', 'Cu', 'V', 'N', 'Ti', 'C', 'H', 'S', 'Zn', 'Se', 'Na', 'Ge', 'P', 'B', 'Al', 'As', 'F', 'Ca', 'Fe', 'Si', 'Cr', 'Ni', 'Cl'}
selected atomic keys []
selected bond keys ['bond_length']
selected global keys ['NR-AR']
... > Building graphs and featurizing


100%|██████████| 744/744 [00:02<00:00, 310.96it/s]


included in labels
{'global': ['NR-AR']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_O', 'chemical_symbol_Cu', 'chemical_symbol_V', 'chemical_symbol_N', 'chemical_symbol_Ti', 'chemical_symbol_C', 'chemical_symbol_H', 'chemical_symbol_S', 'chemical_symbol_Zn', 'chemical_symbol_Se', 'chemical_symbol_Na', 'chemical_symbol_Ge', 'chemical_symbol_P', 'chemical_symbol_B', 'chemical_symbol_Al', 'chemical_symbol_As', 'chemical_symbol_F', 'chemical_symbol_Ca', 'chemical_symbol_Fe', 'chemical_symbol_Si', 'chemical_symbol_Cr', 'chemical_symbol_Ni', 'chemical_symbol_Cl'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  di

744it [00:00, 30397.95it/s]


... > number of categories for each label: 
...... > label  NR-AR :  2 with distribution:  [667, 25]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  52
... > Log scaling features
... > Log scaling features complete
... > Scaling features
mean [7.83730508e-01 2.44589755e-01 3.77376419e-02 6.68415798e-04
 8.07669089e-03 1.88548956e-02 1.01376396e-02 5.29162507e-04
 5.16908217e-02 0.00000000e+00 0.00000000e+00 2.69315865e-02
 0.00000000e+00 2.59206076e-01 3.39666628e-01 4.53965729e-03
 0.00000000e+00 8.35519747e-05 0.00000000e+00 0.00000000e+00
 6.68415798e-04 2.78506582e-05 0.00000000e+00 8.35519747e-05
 3.62058557e-03 0.00000000e+00 0.00000000e+00 2.22805266e-04
 0.00000000e+00 0.00000000e+00 5.87648889e-03]
std [0.62816143 0.3965248  0.15726923 0.02151427 0.07438483 0.11275509
 0.0832112  0.01914439 0.18209175 0.         0.         0.13394866
 0.         0.33538064 0.34650476 0.05591102 0. 

100%|██████████| 744/744 [00:00<00:00, 12401.74it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'O', 'Cu', 'V', 'N', 'Ti', 'C', 'H', 'S', 'Zn', 'Se', 'Na', 'Ge', 'P', 'B', 'Al', 'As', 'F', 'Ca', 'Fe', 'Si', 'Cr', 'Ni', 'Cl'}
selected atomic keys []
selected bond keys []
selected global keys ['NR-AR-LBD']
... > Building graphs and featurizing


100%|██████████| 744/744 [00:02<00:00, 322.11it/s]


included in labels
{'global': ['NR-AR-LBD']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_O', 'chemical_symbol_Cu', 'chemical_symbol_V', 'chemical_symbol_N', 'chemical_symbol_Ti', 'chemical_symbol_C', 'chemical_symbol_H', 'chemical_symbol_S', 'chemical_symbol_Zn', 'chemical_symbol_Se', 'chemical_symbol_Na', 'chemical_symbol_Ge', 'chemical_symbol_P', 'chemical_symbol_B', 'chemical_symbol_Al', 'chemical_symbol_As', 'chemical_symbol_F', 'chemical_symbol_Ca', 'chemical_symbol_Fe', 'chemical_symbol_Si', 'chemical_symbol_Cr', 'chemical_symbol_Ni', 'chemical_symbol_Cl'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  dict_keys(['g

744it [00:00, 30097.14it/s]


... > number of categories for each label: 
...... > label  NR-AR-LBD :  2 with distribution:  [631, 17]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  96
... > Log scaling features
... > Log scaling features complete
... > Scaling features
mean [7.78572648e-01 2.41965878e-01 3.75142105e-02 6.37893459e-04
 7.47246623e-03 1.98050731e-02 9.59877776e-03 4.86014064e-04
 5.15478666e-02 0.00000000e+00 0.00000000e+00 2.71560358e-02
 0.00000000e+00 2.58286099e-01 3.41121121e-01 4.64750948e-03
 0.00000000e+00 9.11276369e-05 0.00000000e+00 0.00000000e+00
 7.29021096e-04 3.03758790e-05 0.00000000e+00 9.11276369e-05
 3.34134669e-03 0.00000000e+00 0.00000000e+00 2.43007032e-04
 0.00000000e+00 0.00000000e+00 5.37653058e-03]
std [0.62740766 0.39467953 0.1568297  0.02101778 0.07157989 0.11547982
 0.08100141 0.01834784 0.18186004 0.         0.         0.13448301
 0.         0.33513963 0.3465307  0.05656685

100%|██████████| 744/744 [00:00<00:00, 6391.15it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'O', 'Cu', 'V', 'N', 'Ti', 'C', 'H', 'S', 'Zn', 'Se', 'Na', 'Ge', 'P', 'B', 'Al', 'As', 'F', 'Ca', 'Fe', 'Si', 'Cr', 'Ni', 'Cl'}
selected atomic keys []
selected bond keys ['bond_length']
selected global keys ['NR-AR-LBD']
... > Building graphs and featurizing


100%|██████████| 744/744 [00:05<00:00, 140.34it/s]


included in labels
{'global': ['NR-AR-LBD']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_O', 'chemical_symbol_Cu', 'chemical_symbol_V', 'chemical_symbol_N', 'chemical_symbol_Ti', 'chemical_symbol_C', 'chemical_symbol_H', 'chemical_symbol_S', 'chemical_symbol_Zn', 'chemical_symbol_Se', 'chemical_symbol_Na', 'chemical_symbol_Ge', 'chemical_symbol_P', 'chemical_symbol_B', 'chemical_symbol_Al', 'chemical_symbol_As', 'chemical_symbol_F', 'chemical_symbol_Ca', 'chemical_symbol_Fe', 'chemical_symbol_Si', 'chemical_symbol_Cr', 'chemical_symbol_Ni', 'chemical_symbol_Cl'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:

744it [00:00, 29760.45it/s]


... > number of categories for each label: 
...... > label  NR-AR-LBD :  2 with distribution:  [631, 17]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  96
... > Log scaling features
... > Log scaling features complete
... > Scaling features
mean [7.78572648e-01 2.41965878e-01 3.75142105e-02 6.37893459e-04
 7.47246623e-03 1.98050731e-02 9.59877776e-03 4.86014064e-04
 5.15478666e-02 0.00000000e+00 0.00000000e+00 2.71560358e-02
 0.00000000e+00 2.58286099e-01 3.41121121e-01 4.64750948e-03
 0.00000000e+00 9.11276369e-05 0.00000000e+00 0.00000000e+00
 7.29021096e-04 3.03758790e-05 0.00000000e+00 9.11276369e-05
 3.34134669e-03 0.00000000e+00 0.00000000e+00 2.43007032e-04
 0.00000000e+00 0.00000000e+00 5.37653058e-03]
std [0.62740766 0.39467953 0.1568297  0.02101778 0.07157989 0.11547982
 0.08100141 0.01834784 0.18186004 0.         0.         0.13448301
 0.         0.33513963 0.3465307  0.05656685

100%|██████████| 744/744 [00:00<00:00, 11787.76it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'O', 'Cu', 'V', 'N', 'Ti', 'C', 'H', 'S', 'Zn', 'Se', 'Na', 'Ge', 'P', 'B', 'Al', 'As', 'F', 'Ca', 'Fe', 'Si', 'Cr', 'Ni', 'Cl'}
selected atomic keys []
selected bond keys []
selected global keys ['NR-AhR']
... > Building graphs and featurizing


100%|██████████| 744/744 [00:02<00:00, 321.66it/s]


included in labels
{'global': ['NR-AhR']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_O', 'chemical_symbol_Cu', 'chemical_symbol_V', 'chemical_symbol_N', 'chemical_symbol_Ti', 'chemical_symbol_C', 'chemical_symbol_H', 'chemical_symbol_S', 'chemical_symbol_Zn', 'chemical_symbol_Se', 'chemical_symbol_Na', 'chemical_symbol_Ge', 'chemical_symbol_P', 'chemical_symbol_B', 'chemical_symbol_Al', 'chemical_symbol_As', 'chemical_symbol_F', 'chemical_symbol_Ca', 'chemical_symbol_Fe', 'chemical_symbol_Si', 'chemical_symbol_Cr', 'chemical_symbol_Ni', 'chemical_symbol_Cl'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  dict_keys(['glob

744it [00:00, 30007.14it/s]


... > number of categories for each label: 
...... > label  NR-AhR :  2 with distribution:  [550, 72]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  122
... > Log scaling features
... > Log scaling features complete
... > Scaling features
mean [7.79493605e-01 2.44615143e-01 3.80558819e-02 6.62664610e-04
 8.26752990e-03 1.89648300e-02 1.01608574e-02 5.36442780e-04
 5.17509505e-02 0.00000000e+00 0.00000000e+00 2.65065844e-02
 0.00000000e+00 2.57681867e-01 3.41966494e-01 4.67020773e-03
 0.00000000e+00 6.31109153e-05 0.00000000e+00 0.00000000e+00
 6.62664610e-04 3.15554576e-05 0.00000000e+00 9.46663729e-05
 3.02932393e-03 0.00000000e+00 0.00000000e+00 2.20888203e-04
 0.00000000e+00 0.00000000e+00 5.93242604e-03]
std [0.62830949 0.397145   0.15789261 0.0214216  0.07524801 0.11307411
 0.08330502 0.01927553 0.18218909 0.         0.         0.13292993
 0.         0.33497987 0.34654297 0.05670388 0

100%|██████████| 744/744 [00:00<00:00, 6040.69it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'O', 'Cu', 'V', 'N', 'Ti', 'C', 'H', 'S', 'Zn', 'Se', 'Na', 'Ge', 'P', 'B', 'Al', 'As', 'F', 'Ca', 'Fe', 'Si', 'Cr', 'Ni', 'Cl'}
selected atomic keys []
selected bond keys ['bond_length']
selected global keys ['NR-AhR']
... > Building graphs and featurizing


100%|██████████| 744/744 [00:02<00:00, 314.17it/s]


included in labels
{'global': ['NR-AhR']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_O', 'chemical_symbol_Cu', 'chemical_symbol_V', 'chemical_symbol_N', 'chemical_symbol_Ti', 'chemical_symbol_C', 'chemical_symbol_H', 'chemical_symbol_S', 'chemical_symbol_Zn', 'chemical_symbol_Se', 'chemical_symbol_Na', 'chemical_symbol_Ge', 'chemical_symbol_P', 'chemical_symbol_B', 'chemical_symbol_Al', 'chemical_symbol_As', 'chemical_symbol_F', 'chemical_symbol_Ca', 'chemical_symbol_Fe', 'chemical_symbol_Si', 'chemical_symbol_Cr', 'chemical_symbol_Ni', 'chemical_symbol_Cl'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  d

744it [00:00, 30636.40it/s]


... > number of categories for each label: 
...... > label  NR-AhR :  2 with distribution:  [550, 72]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  122
... > Log scaling features
... > Log scaling features complete
... > Scaling features
mean [7.79493605e-01 2.44615143e-01 3.80558819e-02 6.62664610e-04
 8.26752990e-03 1.89648300e-02 1.01608574e-02 5.36442780e-04
 5.17509505e-02 0.00000000e+00 0.00000000e+00 2.65065844e-02
 0.00000000e+00 2.57681867e-01 3.41966494e-01 4.67020773e-03
 0.00000000e+00 6.31109153e-05 0.00000000e+00 0.00000000e+00
 6.62664610e-04 3.15554576e-05 0.00000000e+00 9.46663729e-05
 3.02932393e-03 0.00000000e+00 0.00000000e+00 2.20888203e-04
 0.00000000e+00 0.00000000e+00 5.93242604e-03]
std [0.62830949 0.397145   0.15789261 0.0214216  0.07524801 0.11307411
 0.08330502 0.01927553 0.18218909 0.         0.         0.13292993
 0.         0.33497987 0.34654297 0.05670388 0

100%|██████████| 744/744 [00:00<00:00, 12478.85it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'O', 'Cu', 'V', 'N', 'Ti', 'C', 'H', 'S', 'Zn', 'Se', 'Na', 'Ge', 'P', 'B', 'Al', 'As', 'F', 'Ca', 'Fe', 'Si', 'Cr', 'Ni', 'Cl'}
selected atomic keys []
selected bond keys []
selected global keys ['NR-Aromatase']
... > Building graphs and featurizing


100%|██████████| 744/744 [00:02<00:00, 325.09it/s]


included in labels
{'global': ['NR-Aromatase']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_O', 'chemical_symbol_Cu', 'chemical_symbol_V', 'chemical_symbol_N', 'chemical_symbol_Ti', 'chemical_symbol_C', 'chemical_symbol_H', 'chemical_symbol_S', 'chemical_symbol_Zn', 'chemical_symbol_Se', 'chemical_symbol_Na', 'chemical_symbol_Ge', 'chemical_symbol_P', 'chemical_symbol_B', 'chemical_symbol_Al', 'chemical_symbol_As', 'chemical_symbol_F', 'chemical_symbol_Ca', 'chemical_symbol_Fe', 'chemical_symbol_Si', 'chemical_symbol_Cr', 'chemical_symbol_Ni', 'chemical_symbol_Cl'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  dict_keys(

744it [00:00, 30316.15it/s]


... > number of categories for each label: 
...... > label  NR-Aromatase :  2 with distribution:  [554, 28]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  162
... > Log scaling features
... > Log scaling features complete
... > Scaling features
mean [7.76823959e-01 2.43039199e-01 3.92610193e-02 6.11330750e-04
 7.67559720e-03 2.08531712e-02 1.01209202e-02 4.75479473e-04
 5.27102958e-02 0.00000000e+00 0.00000000e+00 2.71023299e-02
 0.00000000e+00 2.56215510e-01 3.42922588e-01 4.14346397e-03
 0.00000000e+00 3.39628195e-05 0.00000000e+00 0.00000000e+00
 6.79256389e-04 3.39628195e-05 0.00000000e+00 1.01888458e-04
 2.92080247e-03 0.00000000e+00 0.00000000e+00 2.03776917e-04
 0.00000000e+00 0.00000000e+00 5.60386521e-03]
std [0.62783794 0.39587277 0.16022558 0.02057592 0.07253553 0.11840381
 0.08314358 0.01814804 0.18373246 0.         0.         0.13435538
 0.         0.33458731 0.34655436 0.0534

100%|██████████| 744/744 [00:00<00:00, 6531.28it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'O', 'Cu', 'V', 'N', 'Ti', 'C', 'H', 'S', 'Zn', 'Se', 'Na', 'Ge', 'P', 'B', 'Al', 'As', 'F', 'Ca', 'Fe', 'Si', 'Cr', 'Ni', 'Cl'}
selected atomic keys []
selected bond keys ['bond_length']
selected global keys ['NR-Aromatase']
... > Building graphs and featurizing


100%|██████████| 744/744 [00:02<00:00, 312.88it/s]


included in labels
{'global': ['NR-Aromatase']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_O', 'chemical_symbol_Cu', 'chemical_symbol_V', 'chemical_symbol_N', 'chemical_symbol_Ti', 'chemical_symbol_C', 'chemical_symbol_H', 'chemical_symbol_S', 'chemical_symbol_Zn', 'chemical_symbol_Se', 'chemical_symbol_Na', 'chemical_symbol_Ge', 'chemical_symbol_P', 'chemical_symbol_B', 'chemical_symbol_Al', 'chemical_symbol_As', 'chemical_symbol_F', 'chemical_symbol_Ca', 'chemical_symbol_Fe', 'chemical_symbol_Si', 'chemical_symbol_Cr', 'chemical_symbol_Ni', 'chemical_symbol_Cl'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include nam

744it [00:00, 30447.78it/s]


... > number of categories for each label: 
...... > label  NR-Aromatase :  2 with distribution:  [554, 28]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  162
... > Log scaling features
... > Log scaling features complete
... > Scaling features
mean [7.76823959e-01 2.43039199e-01 3.92610193e-02 6.11330750e-04
 7.67559720e-03 2.08531712e-02 1.01209202e-02 4.75479473e-04
 5.27102958e-02 0.00000000e+00 0.00000000e+00 2.71023299e-02
 0.00000000e+00 2.56215510e-01 3.42922588e-01 4.14346397e-03
 0.00000000e+00 3.39628195e-05 0.00000000e+00 0.00000000e+00
 6.79256389e-04 3.39628195e-05 0.00000000e+00 1.01888458e-04
 2.92080247e-03 0.00000000e+00 0.00000000e+00 2.03776917e-04
 0.00000000e+00 0.00000000e+00 5.60386521e-03]
std [0.62783794 0.39587277 0.16022558 0.02057592 0.07253553 0.11840381
 0.08314358 0.01814804 0.18373246 0.         0.         0.13435538
 0.         0.33458731 0.34655436 0.0534

100%|██████████| 744/744 [00:00<00:00, 11745.26it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'O', 'Cu', 'V', 'N', 'Ti', 'C', 'H', 'S', 'Zn', 'Se', 'Na', 'Ge', 'P', 'B', 'Al', 'As', 'F', 'Ca', 'Fe', 'Si', 'Cr', 'Ni', 'Cl'}
selected atomic keys []
selected bond keys []
selected global keys ['NR-ER']
... > Building graphs and featurizing


100%|██████████| 744/744 [00:02<00:00, 318.92it/s]


included in labels
{'global': ['NR-ER']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_O', 'chemical_symbol_Cu', 'chemical_symbol_V', 'chemical_symbol_N', 'chemical_symbol_Ti', 'chemical_symbol_C', 'chemical_symbol_H', 'chemical_symbol_S', 'chemical_symbol_Zn', 'chemical_symbol_Se', 'chemical_symbol_Na', 'chemical_symbol_Ge', 'chemical_symbol_P', 'chemical_symbol_B', 'chemical_symbol_Al', 'chemical_symbol_As', 'chemical_symbol_F', 'chemical_symbol_Ca', 'chemical_symbol_Fe', 'chemical_symbol_Si', 'chemical_symbol_Cr', 'chemical_symbol_Ni', 'chemical_symbol_Cl'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  dict_keys(['globa

744it [00:00, 30163.77it/s]


... > number of categories for each label: 
...... > label  NR-ER :  2 with distribution:  [541, 64]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  139
... > Log scaling features
... > Log scaling features complete
... > Scaling features
mean [7.74287592e-01 2.39476831e-01 3.78055820e-02 7.06332047e-04
 7.60147822e-03 2.00800111e-02 9.41776063e-03 5.38157750e-04
 5.19994926e-02 0.00000000e+00 0.00000000e+00 2.65379041e-02
 0.00000000e+00 2.57912102e-01 3.41259283e-01 4.67524546e-03
 0.00000000e+00 6.72697188e-05 0.00000000e+00 0.00000000e+00
 6.39062329e-04 3.36348594e-05 0.00000000e+00 1.00904578e-04
 3.46439052e-03 0.00000000e+00 0.00000000e+00 2.35444016e-04
 0.00000000e+00 0.00000000e+00 5.68429124e-03]
std [0.62712493 0.39373202 0.15740257 0.02211545 0.07218837 0.11625488
 0.08024463 0.01930629 0.18259067 0.         0.         0.13300531
 0.         0.33504088 0.34653284 0.05673425 0.

100%|██████████| 744/744 [00:00<00:00, 6216.40it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'O', 'Cu', 'V', 'N', 'Ti', 'C', 'H', 'S', 'Zn', 'Se', 'Na', 'Ge', 'P', 'B', 'Al', 'As', 'F', 'Ca', 'Fe', 'Si', 'Cr', 'Ni', 'Cl'}
selected atomic keys []
selected bond keys ['bond_length']
selected global keys ['NR-ER']
... > Building graphs and featurizing


100%|██████████| 744/744 [00:02<00:00, 311.01it/s]


included in labels
{'global': ['NR-ER']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_O', 'chemical_symbol_Cu', 'chemical_symbol_V', 'chemical_symbol_N', 'chemical_symbol_Ti', 'chemical_symbol_C', 'chemical_symbol_H', 'chemical_symbol_S', 'chemical_symbol_Zn', 'chemical_symbol_Se', 'chemical_symbol_Na', 'chemical_symbol_Ge', 'chemical_symbol_P', 'chemical_symbol_B', 'chemical_symbol_Al', 'chemical_symbol_As', 'chemical_symbol_F', 'chemical_symbol_Ca', 'chemical_symbol_Fe', 'chemical_symbol_Si', 'chemical_symbol_Cr', 'chemical_symbol_Ni', 'chemical_symbol_Cl'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  di

744it [00:00, 30383.15it/s]


... > number of categories for each label: 
...... > label  NR-ER :  2 with distribution:  [541, 64]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  139
... > Log scaling features
... > Log scaling features complete
... > Scaling features
mean [7.74287592e-01 2.39476831e-01 3.78055820e-02 7.06332047e-04
 7.60147822e-03 2.00800111e-02 9.41776063e-03 5.38157750e-04
 5.19994926e-02 0.00000000e+00 0.00000000e+00 2.65379041e-02
 0.00000000e+00 2.57912102e-01 3.41259283e-01 4.67524546e-03
 0.00000000e+00 6.72697188e-05 0.00000000e+00 0.00000000e+00
 6.39062329e-04 3.36348594e-05 0.00000000e+00 1.00904578e-04
 3.46439052e-03 0.00000000e+00 0.00000000e+00 2.35444016e-04
 0.00000000e+00 0.00000000e+00 5.68429124e-03]
std [0.62712493 0.39373202 0.15740257 0.02211545 0.07218837 0.11625488
 0.08024463 0.01930629 0.18259067 0.         0.         0.13300531
 0.         0.33504088 0.34653284 0.05673425 0.

100%|██████████| 744/744 [00:00<00:00, 12275.14it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'O', 'Cu', 'V', 'N', 'Ti', 'C', 'H', 'S', 'Zn', 'Se', 'Na', 'Ge', 'P', 'B', 'Al', 'As', 'F', 'Ca', 'Fe', 'Si', 'Cr', 'Ni', 'Cl'}
selected atomic keys []
selected bond keys []
selected global keys ['NR-ER-LBD']
... > Building graphs and featurizing


100%|██████████| 744/744 [00:02<00:00, 321.62it/s]


included in labels
{'global': ['NR-ER-LBD']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_O', 'chemical_symbol_Cu', 'chemical_symbol_V', 'chemical_symbol_N', 'chemical_symbol_Ti', 'chemical_symbol_C', 'chemical_symbol_H', 'chemical_symbol_S', 'chemical_symbol_Zn', 'chemical_symbol_Se', 'chemical_symbol_Na', 'chemical_symbol_Ge', 'chemical_symbol_P', 'chemical_symbol_B', 'chemical_symbol_Al', 'chemical_symbol_As', 'chemical_symbol_F', 'chemical_symbol_Ca', 'chemical_symbol_Fe', 'chemical_symbol_Si', 'chemical_symbol_Cr', 'chemical_symbol_Ni', 'chemical_symbol_Cl'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  dict_keys(['g

744it [00:00, 30894.21it/s]


... > number of categories for each label: 
...... > label  NR-ER-LBD :  2 with distribution:  [651, 29]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  64
... > Log scaling features
... > Log scaling features complete
... > Scaling features
mean [7.81433372e-01 2.43156756e-01 3.71236347e-02 6.04338239e-04
 7.77006308e-03 1.93100456e-02 9.43918774e-03 5.75560228e-04
 5.15989744e-02 0.00000000e+00 0.00000000e+00 2.70225527e-02
 0.00000000e+00 2.59174771e-01 3.40184873e-01 4.57570381e-03
 0.00000000e+00 8.63340342e-05 0.00000000e+00 0.00000000e+00
 5.75560228e-04 2.87780114e-05 0.00000000e+00 8.63340342e-05
 3.13680324e-03 0.00000000e+00 0.00000000e+00 2.30224091e-04
 0.00000000e+00 0.00000000e+00 5.87071432e-03]
std [0.6273452  0.3956909  0.15605761 0.02045801 0.0729755  0.11406939
 0.0803346  0.01996539 0.18194293 0.         0.         0.13416552
 0.         0.33537248 0.3465147  0.05613109

100%|██████████| 744/744 [00:00<00:00, 6603.88it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'O', 'Cu', 'V', 'N', 'Ti', 'C', 'H', 'S', 'Zn', 'Se', 'Na', 'Ge', 'P', 'B', 'Al', 'As', 'F', 'Ca', 'Fe', 'Si', 'Cr', 'Ni', 'Cl'}
selected atomic keys []
selected bond keys ['bond_length']
selected global keys ['NR-ER-LBD']
... > Building graphs and featurizing


100%|██████████| 744/744 [00:06<00:00, 120.35it/s]


included in labels
{'global': ['NR-ER-LBD']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_O', 'chemical_symbol_Cu', 'chemical_symbol_V', 'chemical_symbol_N', 'chemical_symbol_Ti', 'chemical_symbol_C', 'chemical_symbol_H', 'chemical_symbol_S', 'chemical_symbol_Zn', 'chemical_symbol_Se', 'chemical_symbol_Na', 'chemical_symbol_Ge', 'chemical_symbol_P', 'chemical_symbol_B', 'chemical_symbol_Al', 'chemical_symbol_As', 'chemical_symbol_F', 'chemical_symbol_Ca', 'chemical_symbol_Fe', 'chemical_symbol_Si', 'chemical_symbol_Cr', 'chemical_symbol_Ni', 'chemical_symbol_Cl'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:

744it [00:00, 30627.38it/s]


... > number of categories for each label: 
...... > label  NR-ER-LBD :  2 with distribution:  [651, 29]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  64
... > Log scaling features
... > Log scaling features complete
... > Scaling features
mean [7.81433372e-01 2.43156756e-01 3.71236347e-02 6.04338239e-04
 7.77006308e-03 1.93100456e-02 9.43918774e-03 5.75560228e-04
 5.15989744e-02 0.00000000e+00 0.00000000e+00 2.70225527e-02
 0.00000000e+00 2.59174771e-01 3.40184873e-01 4.57570381e-03
 0.00000000e+00 8.63340342e-05 0.00000000e+00 0.00000000e+00
 5.75560228e-04 2.87780114e-05 0.00000000e+00 8.63340342e-05
 3.13680324e-03 0.00000000e+00 0.00000000e+00 2.30224091e-04
 0.00000000e+00 0.00000000e+00 5.87071432e-03]
std [0.6273452  0.3956909  0.15605761 0.02045801 0.0729755  0.11406939
 0.0803346  0.01996539 0.18194293 0.         0.         0.13416552
 0.         0.33537248 0.3465147  0.05613109

100%|██████████| 744/744 [00:00<00:00, 11635.90it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'O', 'Cu', 'V', 'N', 'Ti', 'C', 'H', 'S', 'Zn', 'Se', 'Na', 'Ge', 'P', 'B', 'Al', 'As', 'F', 'Ca', 'Fe', 'Si', 'Cr', 'Ni', 'Cl'}
selected atomic keys []
selected bond keys []
selected global keys ['NR-PPAR-gamma']
... > Building graphs and featurizing


100%|██████████| 744/744 [00:02<00:00, 318.16it/s]


included in labels
{'global': ['NR-PPAR-gamma']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_O', 'chemical_symbol_Cu', 'chemical_symbol_V', 'chemical_symbol_N', 'chemical_symbol_Ti', 'chemical_symbol_C', 'chemical_symbol_H', 'chemical_symbol_S', 'chemical_symbol_Zn', 'chemical_symbol_Se', 'chemical_symbol_Na', 'chemical_symbol_Ge', 'chemical_symbol_P', 'chemical_symbol_B', 'chemical_symbol_Al', 'chemical_symbol_As', 'chemical_symbol_F', 'chemical_symbol_Ca', 'chemical_symbol_Fe', 'chemical_symbol_Si', 'chemical_symbol_Cr', 'chemical_symbol_Ni', 'chemical_symbol_Cl'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  dict_keys

744it [00:00, 29769.54it/s]


... > number of categories for each label: 
...... > label  NR-PPAR-gamma :  2 with distribution:  [605, 15]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  124
... > Log scaling features
... > Log scaling features complete
... > Scaling features
mean [7.75704674e-01 2.40185455e-01 3.86452166e-02 5.92011828e-04
 7.82771195e-03 2.06217454e-02 9.60374744e-03 6.24901374e-04
 5.29192796e-02 0.00000000e+00 0.00000000e+00 2.70680964e-02
 0.00000000e+00 2.56933134e-01 3.40900145e-01 4.67031554e-03
 0.00000000e+00 6.57790920e-05 0.00000000e+00 0.00000000e+00
 7.56459559e-04 3.28895460e-05 0.00000000e+00 9.86686381e-05
 3.09161733e-03 0.00000000e+00 0.00000000e+00 2.63116368e-04
 0.00000000e+00 0.00000000e+00 5.72278101e-03]
std [0.62676045 0.39354562 0.1590389  0.02024848 0.07324263 0.11776523
 0.08102209 0.02080284 0.18406629 0.         0.         0.13427395
 0.         0.33478029 0.34652715 0.056

100%|██████████| 744/744 [00:00<00:00, 6641.77it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'O', 'Cu', 'V', 'N', 'Ti', 'C', 'H', 'S', 'Zn', 'Se', 'Na', 'Ge', 'P', 'B', 'Al', 'As', 'F', 'Ca', 'Fe', 'Si', 'Cr', 'Ni', 'Cl'}
selected atomic keys []
selected bond keys ['bond_length']
selected global keys ['NR-PPAR-gamma']
... > Building graphs and featurizing


100%|██████████| 744/744 [00:02<00:00, 313.54it/s]


included in labels
{'global': ['NR-PPAR-gamma']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_O', 'chemical_symbol_Cu', 'chemical_symbol_V', 'chemical_symbol_N', 'chemical_symbol_Ti', 'chemical_symbol_C', 'chemical_symbol_H', 'chemical_symbol_S', 'chemical_symbol_Zn', 'chemical_symbol_Se', 'chemical_symbol_Na', 'chemical_symbol_Ge', 'chemical_symbol_P', 'chemical_symbol_B', 'chemical_symbol_Al', 'chemical_symbol_As', 'chemical_symbol_F', 'chemical_symbol_Ca', 'chemical_symbol_Fe', 'chemical_symbol_Si', 'chemical_symbol_Cr', 'chemical_symbol_Ni', 'chemical_symbol_Cl'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include na

744it [00:00, 30736.88it/s]


... > number of categories for each label: 
...... > label  NR-PPAR-gamma :  2 with distribution:  [605, 15]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  124
... > Log scaling features
... > Log scaling features complete
... > Scaling features
mean [7.75704674e-01 2.40185455e-01 3.86452166e-02 5.92011828e-04
 7.82771195e-03 2.06217454e-02 9.60374744e-03 6.24901374e-04
 5.29192796e-02 0.00000000e+00 0.00000000e+00 2.70680964e-02
 0.00000000e+00 2.56933134e-01 3.40900145e-01 4.67031554e-03
 0.00000000e+00 6.57790920e-05 0.00000000e+00 0.00000000e+00
 7.56459559e-04 3.28895460e-05 0.00000000e+00 9.86686381e-05
 3.09161733e-03 0.00000000e+00 0.00000000e+00 2.63116368e-04
 0.00000000e+00 0.00000000e+00 5.72278101e-03]
std [0.62676045 0.39354562 0.1590389  0.02024848 0.07324263 0.11776523
 0.08102209 0.02080284 0.18406629 0.         0.         0.13427395
 0.         0.33478029 0.34652715 0.056

100%|██████████| 744/744 [00:00<00:00, 13756.36it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'O', 'Cu', 'V', 'N', 'Ti', 'C', 'H', 'S', 'Zn', 'Se', 'Na', 'Ge', 'P', 'B', 'Al', 'As', 'F', 'Ca', 'Fe', 'Si', 'Cr', 'Ni', 'Cl'}
selected atomic keys []
selected bond keys []
selected global keys ['SR-ARE']
... > Building graphs and featurizing


100%|██████████| 744/744 [00:02<00:00, 322.07it/s]


included in labels
{'global': ['SR-ARE']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_O', 'chemical_symbol_Cu', 'chemical_symbol_V', 'chemical_symbol_N', 'chemical_symbol_Ti', 'chemical_symbol_C', 'chemical_symbol_H', 'chemical_symbol_S', 'chemical_symbol_Zn', 'chemical_symbol_Se', 'chemical_symbol_Na', 'chemical_symbol_Ge', 'chemical_symbol_P', 'chemical_symbol_B', 'chemical_symbol_Al', 'chemical_symbol_As', 'chemical_symbol_F', 'chemical_symbol_Ca', 'chemical_symbol_Fe', 'chemical_symbol_Si', 'chemical_symbol_Cr', 'chemical_symbol_Ni', 'chemical_symbol_Cl'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  dict_keys(['glob

744it [00:00, 30299.96it/s]


... > number of categories for each label: 
...... > label  SR-ARE :  2 with distribution:  [454, 83]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  207
... > Log scaling features
... > Log scaling features complete
... > Scaling features
mean [7.71552337e-01 2.34248621e-01 3.86703165e-02 7.29628613e-04
 8.10698459e-03 1.99837170e-02 9.84998628e-03 6.48558767e-04
 5.25737951e-02 0.00000000e+00 0.00000000e+00 2.75232127e-02
 0.00000000e+00 2.57518366e-01 3.38709816e-01 4.62098122e-03
 0.00000000e+00 1.21604769e-04 0.00000000e+00 0.00000000e+00
 6.89093690e-04 4.05349230e-05 4.05349230e-05 8.10698459e-05
 2.99958430e-03 0.00000000e+00 0.00000000e+00 2.83744461e-04
 0.00000000e+00 0.00000000e+00 7.29628613e-03]
std [0.62420862 0.38910538 0.15908748 0.02247682 0.07452255 0.11598409
 0.08203943 0.02119259 0.18351396 0.         0.         0.1353518
 0.         0.33493644 0.34648436 0.05640626 0.

100%|██████████| 744/744 [00:00<00:00, 6761.99it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'O', 'Cu', 'V', 'N', 'Ti', 'C', 'H', 'S', 'Zn', 'Se', 'Na', 'Ge', 'P', 'B', 'Al', 'As', 'F', 'Ca', 'Fe', 'Si', 'Cr', 'Ni', 'Cl'}
selected atomic keys []
selected bond keys ['bond_length']
selected global keys ['SR-ARE']
... > Building graphs and featurizing


100%|██████████| 744/744 [00:02<00:00, 312.91it/s]


included in labels
{'global': ['SR-ARE']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_O', 'chemical_symbol_Cu', 'chemical_symbol_V', 'chemical_symbol_N', 'chemical_symbol_Ti', 'chemical_symbol_C', 'chemical_symbol_H', 'chemical_symbol_S', 'chemical_symbol_Zn', 'chemical_symbol_Se', 'chemical_symbol_Na', 'chemical_symbol_Ge', 'chemical_symbol_P', 'chemical_symbol_B', 'chemical_symbol_Al', 'chemical_symbol_As', 'chemical_symbol_F', 'chemical_symbol_Ca', 'chemical_symbol_Fe', 'chemical_symbol_Si', 'chemical_symbol_Cr', 'chemical_symbol_Ni', 'chemical_symbol_Cl'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  d

744it [00:00, 30729.01it/s]


... > number of categories for each label: 
...... > label  SR-ARE :  2 with distribution:  [454, 83]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  207
... > Log scaling features
... > Log scaling features complete
... > Scaling features
mean [7.71552337e-01 2.34248621e-01 3.86703165e-02 7.29628613e-04
 8.10698459e-03 1.99837170e-02 9.84998628e-03 6.48558767e-04
 5.25737951e-02 0.00000000e+00 0.00000000e+00 2.75232127e-02
 0.00000000e+00 2.57518366e-01 3.38709816e-01 4.62098122e-03
 0.00000000e+00 1.21604769e-04 0.00000000e+00 0.00000000e+00
 6.89093690e-04 4.05349230e-05 4.05349230e-05 8.10698459e-05
 2.99958430e-03 0.00000000e+00 0.00000000e+00 2.83744461e-04
 0.00000000e+00 0.00000000e+00 7.29628613e-03]
std [0.62420862 0.38910538 0.15908748 0.02247682 0.07452255 0.11598409
 0.08203943 0.02119259 0.18351396 0.         0.         0.1353518
 0.         0.33493644 0.34648436 0.05640626 0.

100%|██████████| 744/744 [00:00<00:00, 13609.79it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'O', 'Cu', 'V', 'N', 'Ti', 'C', 'H', 'S', 'Zn', 'Se', 'Na', 'Ge', 'P', 'B', 'Al', 'As', 'F', 'Ca', 'Fe', 'Si', 'Cr', 'Ni', 'Cl'}
selected atomic keys []
selected bond keys []
selected global keys ['SR-ATAD5']
... > Building graphs and featurizing


100%|██████████| 744/744 [00:06<00:00, 114.70it/s]


included in labels
{'global': ['SR-ATAD5']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_O', 'chemical_symbol_Cu', 'chemical_symbol_V', 'chemical_symbol_N', 'chemical_symbol_Ti', 'chemical_symbol_C', 'chemical_symbol_H', 'chemical_symbol_S', 'chemical_symbol_Zn', 'chemical_symbol_Se', 'chemical_symbol_Na', 'chemical_symbol_Ge', 'chemical_symbol_P', 'chemical_symbol_B', 'chemical_symbol_Al', 'chemical_symbol_As', 'chemical_symbol_F', 'chemical_symbol_Ca', 'chemical_symbol_Fe', 'chemical_symbol_Si', 'chemical_symbol_Cr', 'chemical_symbol_Ni', 'chemical_symbol_Cl'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  dict_keys(['gl

744it [00:00, 30889.93it/s]


... > number of categories for each label: 
...... > label  SR-ATAD5 :  2 with distribution:  [669, 15]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  60
... > Log scaling features
... > Log scaling features complete
... > Scaling features
mean [7.80454831e-01 2.42565052e-01 3.83360570e-02 6.94913421e-04
 8.16523269e-03 1.94575758e-02 1.00183351e-02 5.50139791e-04
 5.22053707e-02 0.00000000e+00 0.00000000e+00 2.67831214e-02
 0.00000000e+00 2.58710476e-01 3.39841617e-01 4.74857504e-03
 0.00000000e+00 2.89547259e-05 0.00000000e+00 0.00000000e+00
 6.65958695e-04 2.89547259e-05 0.00000000e+00 8.68641776e-05
 3.12711039e-03 0.00000000e+00 0.00000000e+00 2.31637807e-04
 0.00000000e+00 0.00000000e+00 6.13840188e-03]
std [0.62746471 0.39557266 0.15843887 0.02193614 0.07478661 0.11449178
 0.08272735 0.01951987 0.1829224  0.         0.         0.13359382
 0.         0.33525114 0.3465082  0.0571744  

100%|██████████| 744/744 [00:00<00:00, 6284.85it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'O', 'Cu', 'V', 'N', 'Ti', 'C', 'H', 'S', 'Zn', 'Se', 'Na', 'Ge', 'P', 'B', 'Al', 'As', 'F', 'Ca', 'Fe', 'Si', 'Cr', 'Ni', 'Cl'}
selected atomic keys []
selected bond keys ['bond_length']
selected global keys ['SR-ATAD5']
... > Building graphs and featurizing


100%|██████████| 744/744 [00:02<00:00, 312.98it/s]


included in labels
{'global': ['SR-ATAD5']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_O', 'chemical_symbol_Cu', 'chemical_symbol_V', 'chemical_symbol_N', 'chemical_symbol_Ti', 'chemical_symbol_C', 'chemical_symbol_H', 'chemical_symbol_S', 'chemical_symbol_Zn', 'chemical_symbol_Se', 'chemical_symbol_Na', 'chemical_symbol_Ge', 'chemical_symbol_P', 'chemical_symbol_B', 'chemical_symbol_Al', 'chemical_symbol_As', 'chemical_symbol_F', 'chemical_symbol_Ca', 'chemical_symbol_Fe', 'chemical_symbol_Si', 'chemical_symbol_Cr', 'chemical_symbol_Ni', 'chemical_symbol_Cl'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names: 

744it [00:00, 31175.07it/s]


... > number of categories for each label: 
...... > label  SR-ATAD5 :  2 with distribution:  [669, 15]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  60
... > Log scaling features
... > Log scaling features complete
... > Scaling features
mean [7.80454831e-01 2.42565052e-01 3.83360570e-02 6.94913421e-04
 8.16523269e-03 1.94575758e-02 1.00183351e-02 5.50139791e-04
 5.22053707e-02 0.00000000e+00 0.00000000e+00 2.67831214e-02
 0.00000000e+00 2.58710476e-01 3.39841617e-01 4.74857504e-03
 0.00000000e+00 2.89547259e-05 0.00000000e+00 0.00000000e+00
 6.65958695e-04 2.89547259e-05 0.00000000e+00 8.68641776e-05
 3.12711039e-03 0.00000000e+00 0.00000000e+00 2.31637807e-04
 0.00000000e+00 0.00000000e+00 6.13840188e-03]
std [0.62746471 0.39557266 0.15843887 0.02193614 0.07478661 0.11449178
 0.08272735 0.01951987 0.1829224  0.         0.         0.13359382
 0.         0.33525114 0.3465082  0.0571744  

100%|██████████| 744/744 [00:00<00:00, 13868.18it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'O', 'Cu', 'V', 'N', 'Ti', 'C', 'H', 'S', 'Zn', 'Se', 'Na', 'Ge', 'P', 'B', 'Al', 'As', 'F', 'Ca', 'Fe', 'Si', 'Cr', 'Ni', 'Cl'}
selected atomic keys []
selected bond keys []
selected global keys ['SR-HSE']
... > Building graphs and featurizing


100%|██████████| 744/744 [00:02<00:00, 325.05it/s]


included in labels
{'global': ['SR-HSE']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_O', 'chemical_symbol_Cu', 'chemical_symbol_V', 'chemical_symbol_N', 'chemical_symbol_Ti', 'chemical_symbol_C', 'chemical_symbol_H', 'chemical_symbol_S', 'chemical_symbol_Zn', 'chemical_symbol_Se', 'chemical_symbol_Na', 'chemical_symbol_Ge', 'chemical_symbol_P', 'chemical_symbol_B', 'chemical_symbol_Al', 'chemical_symbol_As', 'chemical_symbol_F', 'chemical_symbol_Ca', 'chemical_symbol_Fe', 'chemical_symbol_Si', 'chemical_symbol_Cr', 'chemical_symbol_Ni', 'chemical_symbol_Cl'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  dict_keys(['glob

744it [00:00, 31282.58it/s]


... > number of categories for each label: 
...... > label  SR-HSE :  2 with distribution:  [565, 31]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  148
... > Log scaling features
... > Log scaling features complete
... > Scaling features
mean [7.73826656e-01 2.34350583e-01 3.74271833e-02 7.44290578e-04
 7.51379060e-03 2.03439425e-02 8.82515971e-03 6.37963352e-04
 5.08244137e-02 0.00000000e+00 0.00000000e+00 2.77159634e-02
 0.00000000e+00 2.59721969e-01 3.37588941e-01 4.74928274e-03
 0.00000000e+00 1.06327225e-04 0.00000000e+00 0.00000000e+00
 7.44290578e-04 3.54424085e-05 3.54424085e-05 1.06327225e-04
 3.22525917e-03 0.00000000e+00 0.00000000e+00 2.48096859e-04
 0.00000000e+00 0.00000000e+00 7.40746337e-03]
std [0.62441394 0.38906064 0.15665808 0.0227013  0.07177538 0.11699346
 0.07771262 0.02101893 0.18068115 0.         0.         0.13580526
 0.         0.3355146  0.34645711 0.05717863 0

100%|██████████| 744/744 [00:00<00:00, 6696.77it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'O', 'Cu', 'V', 'N', 'Ti', 'C', 'H', 'S', 'Zn', 'Se', 'Na', 'Ge', 'P', 'B', 'Al', 'As', 'F', 'Ca', 'Fe', 'Si', 'Cr', 'Ni', 'Cl'}
selected atomic keys []
selected bond keys ['bond_length']
selected global keys ['SR-HSE']
... > Building graphs and featurizing


100%|██████████| 744/744 [00:02<00:00, 314.80it/s]


included in labels
{'global': ['SR-HSE']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_O', 'chemical_symbol_Cu', 'chemical_symbol_V', 'chemical_symbol_N', 'chemical_symbol_Ti', 'chemical_symbol_C', 'chemical_symbol_H', 'chemical_symbol_S', 'chemical_symbol_Zn', 'chemical_symbol_Se', 'chemical_symbol_Na', 'chemical_symbol_Ge', 'chemical_symbol_P', 'chemical_symbol_B', 'chemical_symbol_Al', 'chemical_symbol_As', 'chemical_symbol_F', 'chemical_symbol_Ca', 'chemical_symbol_Fe', 'chemical_symbol_Si', 'chemical_symbol_Cr', 'chemical_symbol_Ni', 'chemical_symbol_Cl'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  d

744it [00:00, 30921.15it/s]


... > number of categories for each label: 
...... > label  SR-HSE :  2 with distribution:  [565, 31]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  148
... > Log scaling features
... > Log scaling features complete
... > Scaling features
mean [7.73826656e-01 2.34350583e-01 3.74271833e-02 7.44290578e-04
 7.51379060e-03 2.03439425e-02 8.82515971e-03 6.37963352e-04
 5.08244137e-02 0.00000000e+00 0.00000000e+00 2.77159634e-02
 0.00000000e+00 2.59721969e-01 3.37588941e-01 4.74928274e-03
 0.00000000e+00 1.06327225e-04 0.00000000e+00 0.00000000e+00
 7.44290578e-04 3.54424085e-05 3.54424085e-05 1.06327225e-04
 3.22525917e-03 0.00000000e+00 0.00000000e+00 2.48096859e-04
 0.00000000e+00 0.00000000e+00 7.40746337e-03]
std [0.62441394 0.38906064 0.15665808 0.0227013  0.07177538 0.11699346
 0.07771262 0.02101893 0.18068115 0.         0.         0.13580526
 0.         0.3355146  0.34645711 0.05717863 0

100%|██████████| 744/744 [00:00<00:00, 13848.05it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'O', 'Cu', 'V', 'N', 'Ti', 'C', 'H', 'S', 'Zn', 'Se', 'Na', 'Ge', 'P', 'B', 'Al', 'As', 'F', 'Ca', 'Fe', 'Si', 'Cr', 'Ni', 'Cl'}
selected atomic keys []
selected bond keys []
selected global keys ['SR-MMP']
... > Building graphs and featurizing


100%|██████████| 744/744 [00:02<00:00, 323.79it/s]


included in labels
{'global': ['SR-MMP']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_O', 'chemical_symbol_Cu', 'chemical_symbol_V', 'chemical_symbol_N', 'chemical_symbol_Ti', 'chemical_symbol_C', 'chemical_symbol_H', 'chemical_symbol_S', 'chemical_symbol_Zn', 'chemical_symbol_Se', 'chemical_symbol_Na', 'chemical_symbol_Ge', 'chemical_symbol_P', 'chemical_symbol_B', 'chemical_symbol_Al', 'chemical_symbol_As', 'chemical_symbol_F', 'chemical_symbol_Ca', 'chemical_symbol_Fe', 'chemical_symbol_Si', 'chemical_symbol_Cr', 'chemical_symbol_Ni', 'chemical_symbol_Cl'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  dict_keys(['glob

744it [00:00, 31163.24it/s]


... > number of categories for each label: 
...... > label  SR-MMP :  2 with distribution:  [488, 90]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  166
... > Log scaling features
... > Log scaling features complete
... > Scaling features
mean [7.76939430e-01 2.42993602e-01 3.77558991e-02 6.21780588e-04
 7.80680072e-03 1.88261345e-02 1.05011833e-02 6.21780588e-04
 5.34040439e-02 0.00000000e+00 0.00000000e+00 2.63220449e-02
 0.00000000e+00 2.56139059e-01 3.42013867e-01 4.49063758e-03
 0.00000000e+00 1.03630098e-04 0.00000000e+00 0.00000000e+00
 5.87237222e-04 3.45433660e-05 0.00000000e+00 1.03630098e-04
 3.17798967e-03 0.00000000e+00 0.00000000e+00 1.72716830e-04
 0.00000000e+00 0.00000000e+00 5.97600232e-03]
std [0.62751994 0.39526672 0.15730508 0.02075088 0.07314585 0.11267146
 0.08466753 0.02075088 0.18483742 0.         0.         0.13248472
 0.         0.33456666 0.34654359 0.05561031 0

100%|██████████| 744/744 [00:00<00:00, 6809.07it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'O', 'Cu', 'V', 'N', 'Ti', 'C', 'H', 'S', 'Zn', 'Se', 'Na', 'Ge', 'P', 'B', 'Al', 'As', 'F', 'Ca', 'Fe', 'Si', 'Cr', 'Ni', 'Cl'}
selected atomic keys []
selected bond keys ['bond_length']
selected global keys ['SR-MMP']
... > Building graphs and featurizing


100%|██████████| 744/744 [00:02<00:00, 314.24it/s]


included in labels
{'global': ['SR-MMP']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_O', 'chemical_symbol_Cu', 'chemical_symbol_V', 'chemical_symbol_N', 'chemical_symbol_Ti', 'chemical_symbol_C', 'chemical_symbol_H', 'chemical_symbol_S', 'chemical_symbol_Zn', 'chemical_symbol_Se', 'chemical_symbol_Na', 'chemical_symbol_Ge', 'chemical_symbol_P', 'chemical_symbol_B', 'chemical_symbol_Al', 'chemical_symbol_As', 'chemical_symbol_F', 'chemical_symbol_Ca', 'chemical_symbol_Fe', 'chemical_symbol_Si', 'chemical_symbol_Cr', 'chemical_symbol_Ni', 'chemical_symbol_Cl'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  d

744it [00:00, 31352.04it/s]


... > number of categories for each label: 
...... > label  SR-MMP :  2 with distribution:  [488, 90]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  166
... > Log scaling features
... > Log scaling features complete
... > Scaling features
mean [7.76939430e-01 2.42993602e-01 3.77558991e-02 6.21780588e-04
 7.80680072e-03 1.88261345e-02 1.05011833e-02 6.21780588e-04
 5.34040439e-02 0.00000000e+00 0.00000000e+00 2.63220449e-02
 0.00000000e+00 2.56139059e-01 3.42013867e-01 4.49063758e-03
 0.00000000e+00 1.03630098e-04 0.00000000e+00 0.00000000e+00
 5.87237222e-04 3.45433660e-05 0.00000000e+00 1.03630098e-04
 3.17798967e-03 0.00000000e+00 0.00000000e+00 1.72716830e-04
 0.00000000e+00 0.00000000e+00 5.97600232e-03]
std [0.62751994 0.39526672 0.15730508 0.02075088 0.07314585 0.11267146
 0.08466753 0.02075088 0.18483742 0.         0.         0.13248472
 0.         0.33456666 0.34654359 0.05561031 0

100%|██████████| 744/744 [00:00<00:00, 12402.23it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'O', 'Cu', 'V', 'N', 'Ti', 'C', 'H', 'S', 'Zn', 'Se', 'Na', 'Ge', 'P', 'B', 'Al', 'As', 'F', 'Ca', 'Fe', 'Si', 'Cr', 'Ni', 'Cl'}
selected atomic keys []
selected bond keys []
selected global keys ['SR-p53']
... > Building graphs and featurizing


100%|██████████| 744/744 [00:02<00:00, 316.03it/s]


included in labels
{'global': ['SR-p53']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_O', 'chemical_symbol_Cu', 'chemical_symbol_V', 'chemical_symbol_N', 'chemical_symbol_Ti', 'chemical_symbol_C', 'chemical_symbol_H', 'chemical_symbol_S', 'chemical_symbol_Zn', 'chemical_symbol_Se', 'chemical_symbol_Na', 'chemical_symbol_Ge', 'chemical_symbol_P', 'chemical_symbol_B', 'chemical_symbol_Al', 'chemical_symbol_As', 'chemical_symbol_F', 'chemical_symbol_Ca', 'chemical_symbol_Fe', 'chemical_symbol_Si', 'chemical_symbol_Cr', 'chemical_symbol_Ni', 'chemical_symbol_Cl'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  dict_keys(['glob

744it [00:00, 30674.95it/s]


... > number of categories for each label: 
...... > label  SR-p53 :  2 with distribution:  [623, 38]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  83
... > Log scaling features
... > Log scaling features complete
... > Scaling features
mean [7.79577000e-01 2.42537877e-01 3.91367396e-02 7.16462052e-04
 8.53783945e-03 1.97325590e-02 1.01498791e-02 5.97051710e-04
 5.32271599e-02 0.00000000e+00 0.00000000e+00 2.68076218e-02
 0.00000000e+00 2.57568108e-01 3.40409032e-01 4.44803524e-03
 0.00000000e+00 8.95577565e-05 0.00000000e+00 0.00000000e+00
 6.26904295e-04 2.98525855e-05 0.00000000e+00 8.95577565e-05
 3.01511113e-03 0.00000000e+00 0.00000000e+00 2.38820684e-04
 0.00000000e+00 0.00000000e+00 6.00036968e-03]
std [0.62790027 0.39544352 0.15998699 0.02227331 0.07645315 0.11527443
 0.08326068 0.02033441 0.18455656 0.         0.         0.13365246
 0.         0.33494966 0.34651876 0.05534761 0.

100%|██████████| 744/744 [00:00<00:00, 6404.48it/s]


... > bond_feats_error_count:  0
... > atom_feats_error_count:  0
element set {'Br', 'O', 'Cu', 'V', 'N', 'Ti', 'C', 'H', 'S', 'Zn', 'Se', 'Na', 'Ge', 'P', 'B', 'Al', 'As', 'F', 'Ca', 'Fe', 'Si', 'Cr', 'Ni', 'Cl'}
selected atomic keys []
selected bond keys ['bond_length']
selected global keys ['SR-p53']
... > Building graphs and featurizing


100%|██████████| 744/744 [00:02<00:00, 313.94it/s]


included in labels
{'global': ['SR-p53']}
included in graph features
{'atom': ['total_degree', 'total_H', 'is_in_ring', 'ring_size_3', 'ring_size_4', 'ring_size_5', 'ring_size_6', 'chemical_symbol_Br', 'chemical_symbol_O', 'chemical_symbol_Cu', 'chemical_symbol_V', 'chemical_symbol_N', 'chemical_symbol_Ti', 'chemical_symbol_C', 'chemical_symbol_H', 'chemical_symbol_S', 'chemical_symbol_Zn', 'chemical_symbol_Se', 'chemical_symbol_Na', 'chemical_symbol_Ge', 'chemical_symbol_P', 'chemical_symbol_B', 'chemical_symbol_Al', 'chemical_symbol_As', 'chemical_symbol_F', 'chemical_symbol_Ca', 'chemical_symbol_Fe', 'chemical_symbol_Si', 'chemical_symbol_Cr', 'chemical_symbol_Ni', 'chemical_symbol_Cl'], 'bond': ['metal bond', 'ring inclusion', 'ring size_3', 'ring size_4', 'ring size_5', 'ring size_6', 'bond_length'], 'global': ['num atoms', 'num bonds', 'molecule weight']}
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys([])
include names:  d

744it [00:00, 30876.78it/s]


... > number of categories for each label: 
...... > label  SR-p53 :  2 with distribution:  [623, 38]
original loader node types: dict_keys(['atom', 'bond', 'global'])
original loader label types: dict_keys(['global'])
number of graphs filtered:  83
... > Log scaling features
... > Log scaling features complete
... > Scaling features
mean [7.79577000e-01 2.42537877e-01 3.91367396e-02 7.16462052e-04
 8.53783945e-03 1.97325590e-02 1.01498791e-02 5.97051710e-04
 5.32271599e-02 0.00000000e+00 0.00000000e+00 2.68076218e-02
 0.00000000e+00 2.57568108e-01 3.40409032e-01 4.44803524e-03
 0.00000000e+00 8.95577565e-05 0.00000000e+00 0.00000000e+00
 6.26904295e-04 2.98525855e-05 0.00000000e+00 8.95577565e-05
 3.01511113e-03 0.00000000e+00 0.00000000e+00 2.38820684e-04
 0.00000000e+00 0.00000000e+00 6.00036968e-03]
std [0.62790027 0.39544352 0.15998699 0.02227331 0.07645315 0.11527443
 0.08326068 0.02033441 0.18455656 0.         0.         0.13365246
 0.         0.33494966 0.34651876 0.05534761 0.

In [19]:
# print(len(dataset))
# [print(len(i)) for i in dataset_list]


In [3]:
! ls ../../../../data/saved_models/1218/tox21/qtaim

# All at once testing


In [20]:
model_config = {
    "model": {
        "restore": True,
        "restore_path": "../../../../data/saved_models/1218/tox21/no_bl/model_lightning_epoch=236-val_loss=4.50.ckpt",
    }
}

model = load_graph_level_model_from_config(model_config["model"])
model.cpu()

data_loader = DataLoaderMoleculeGraphTask(
    dict_datasets["bl"]["test"],
    batch_size=len(dict_datasets["bl"]["test"].graphs),
    shuffle=False,
)


def evaluate_manually(model, batch_graph, batch_label):
    """
    Evaluate a set of data manually
    Takes
        feats: dict, dictionary of batched features
        scaler_list: list, list of scalers
    """

    labels = batch_label["global"]
    labels_one_hot = torch.argmax(labels, axis=2)
    logits = model.forward(batch_graph, batch_graph.ndata["feat"])
    logits_one_hot = torch.argmax(logits, axis=-1)
    # print(logits_one_hot.shape)
    # print(labels_one_hot.shape)
    print("task: ", model.hparams.ntasks)

    if model.hparams.ntasks > 1:
        # create a dict of softmax layers
        test_auroc = torchmetrics.classification.MultilabelAUROC(
            num_labels=model.hparams.ntasks
        )
        test_acc = torchmetrics.classification.MultilabelAccuracy(
            num_labels=model.hparams.ntasks
        )
        test_f1 = torchmetrics.classification.MultilabelF1Score(
            num_labels=model.hparams.ntasks, average="micro"
        )

        test_acc.update(logits, labels)
        test_f1.update(logits, labels)
        test_auroc.update(logits, labels)

        # compute accuracy manually for each task outside of torchmetrics
        acc_manual = []
        for i in range(model.hparams.ntasks):
            acc_manual.append(
                torch.sum(logits_one_hot[:, i] == labels_one_hot[:, i])
                / len(labels_one_hot[:, i])
            )
        acc_manual = torch.stack(acc_manual)

        auroc_manual = []
        for i in range(model.hparams.ntasks):
            auroc_manual.append(
                torchmetrics.functional.auroc(
                    logits[:, i], labels[:, i], num_classes=2, task="binary"
                )
            )

    else:
        labels_one_hot = labels_one_hot.reshape(-1)

        test_auroc = torchmetrics.classification.AUROC(num_labels=1, task="binary")
        test_f1 = torchmetrics.F1Score(num_classes=2, task="binary")
        test_acc = torchmetrics.Accuracy(num_classes=2, task="binary")

        test_f1.update(logits_one_hot, labels_one_hot)
        test_auroc.update(logits_one_hot, labels_one_hot)
        test_acc.update(logits_one_hot, labels_one_hot)

    auroc = test_auroc.compute()
    acc = test_acc.compute()
    f1 = test_f1.compute()

    return acc, acc_manual, auroc, auroc_manual, f1


batch_graph, batched_labels = next(iter(data_loader))
acc, acc_manual, auroc, auroc_manual, f1 = evaluate_manually(
    model, batch_graph, batched_labels
)
print("-" * 50)
print("acc: {:.4f}\t auroc: {:.4f}\t f1: {:.4f}".format(acc, auroc, f1))
print("acc manual: ", acc_manual)
print("auroc manual: ", auroc_manual)


:::RESTORING MODEL FROM EXISTING FILE:::
readout in size 300
... > number of tasks: 12
... > number of tasks: 12
:::MODEL LOADED:::


RuntimeError: mat1 and mat2 shapes cannot be multiplied (7700x7 and 6x100)

In [17]:
# distro for training set, no impute
print(2608 / (2608 + 51))
print(2631 / (2631 + 28))
print(2527 / (2527 + 132))
print(2612 / (2612 + 47))
print(2451 / (2451 + 208))
print(2603 / (2603 + 56))
print(2634 / (2634 + 25))
print(2496 / (2496 + 163))
print(2648 / (2648 + 11))
print(2616 / (2616 + 43))
print(2539 / (2539 + 120))
print(2635 / (2635 + 24))


0.9808198570891312
0.9894697254606996
0.9503572771718691
0.9823241820233171
0.9217751034223393
0.978939450921399
0.9905979691613388
0.9386987589319293
0.9958631064309891
0.9838285069575028
0.9548702519744264
0.9909740503948853


In [18]:
# Individual task performance


tensor(0.9567)

In [16]:
from sklearn.metrics import f1_score


def manual_eval_separate_tasks(model, dataset_list):
    n_tasks = len(dataset_list)

    temp_data_loader = []
    for i in dataset_list:
        temp_data_loader.append(
            DataLoaderMoleculeGraphTask(i, batch_size=len(i.graphs), shuffle=False)
        )
    task_ind = 0
    statistics_dict = {"f1": [], "acc": [], "auroc": []}
    for task_ind in range(n_tasks):
        batch_graph, batch_label = next(iter(temp_data_loader[task_ind]))

        labels = batch_label["global"]
        labels_one_hot = torch.argmax(labels, axis=2)
        logits = model.forward(batch_graph, batch_graph.ndata["feat"])
        logits_one_hot = torch.argmax(logits, axis=-1)

        # take logits from first task only
        labels = labels.reshape(-1)
        logits = logits[:, task_ind].reshape(-1)
        logits_one_hot = logits_one_hot[:, task_ind].reshape(-1)
        labels_one_hot = labels_one_hot.reshape(-1)
        # compute acc, auroc, f1 manually
        acc_manual = torch.sum(logits_one_hot == labels_one_hot) / len(labels_one_hot)
        auroc_manual = torchmetrics.functional.auroc(
            logits, labels, num_classes=2, task="binary"
        )
        f1 = f1_score(labels_one_hot, logits_one_hot, pos_label=0)
        # print("-" * 50)
        # print(
        #    "acc: {:.4f}\t auroc: {:.4f}\t f1: {:.4f}".format(
        #        acc_manual, auroc_manual, f1
        #    )
        # )
        statistics_dict["f1"].append(f1)
        statistics_dict["acc"].append(acc_manual.numpy())
        statistics_dict["auroc"].append(auroc_manual.numpy())
    return statistics_dict


In [21]:
import os
import numpy as np

# model_root = "../../../../data/saved_models/1212/tox21/"
# model_list = os.listdir(model_root)

model_root = "../../../../data/saved_models/1218/tox21/no_bl/"
model_root_qtaim = "../../../../data/saved_models/1218/tox21/bl/"
models_no_qtaim = os.listdir(model_root)
models_qtaim = os.listdir(model_root_qtaim)
model_list = models_no_qtaim + models_qtaim

for ind, model_path in enumerate(model_list):
    if ind < len(models_no_qtaim):
        print("bare")
        model_loc = model_root + model_path
        dict_key = "bare"
    else:
        print("bl")
        dict_key = "bl"
        model_loc = model_root_qtaim + model_path
    model_config = {
        "model": {
            "restore": True,
            "restore_path": model_loc,
        }
    }
    model = load_graph_level_model_from_config(model_config["model"])
    model.cpu()
    print(model_path)
    stats_dict = manual_eval_separate_tasks(
        model, dataset_list=dict_datasets[dict_key]["single_list"]
    )
    print(np.array(stats_dict["f1"]).mean())


bare
:::RESTORING MODEL FROM EXISTING FILE:::
readout in size 412
... > number of tasks: 12
... > number of tasks: 12
:::MODEL LOADED:::
model_lightning_epoch=484-val_loss=4.48.ckpt
0.9609143275244217
bare
:::RESTORING MODEL FROM EXISTING FILE:::
readout in size 300
... > number of tasks: 12
... > number of tasks: 12
:::MODEL LOADED:::
model_lightning_epoch=64-val_loss=4.50.ckpt
0.9580763806365323
bare
:::RESTORING MODEL FROM EXISTING FILE:::
readout in size 212
... > number of tasks: 12
... > number of tasks: 12
:::MODEL LOADED:::
model_lightning_epoch=58-val_loss=4.50.ckpt
0.9592234120145146
bare
:::RESTORING MODEL FROM EXISTING FILE:::
readout in size 212
... > number of tasks: 12
... > number of tasks: 12
:::MODEL LOADED:::
model_lightning_epoch=660-val_loss=4.49.ckpt
0.9620616097775211
bare
:::RESTORING MODEL FROM EXISTING FILE:::
readout in size 300
... > number of tasks: 12
... > number of tasks: 12
:::MODEL LOADED:::
model_lightning_epoch=236-val_loss=4.50.ckpt
0.95879175224755