In [1]:
import sys
sys.path.append('../')

In [2]:
import itertools
import pandas as pd
import numpy as np
from tqdm.auto import tqdm, trange
from pathlib import Path

from models.train import test as test_clas
from models.train_reg import test as test_reg
from models.utils import ContagionDataset, set_seed, ConfusionMatrix, save_pickle, load_pickle
from sklearn.metrics import matthews_corrcoef, mean_squared_error
from models.results import ResultCollection

Using backend: pytorch


In [3]:
seed = 4444

metric_filter_1 = 'val_mcc'
metric_filter_2 = 'test_mcc'

data_dir_name = 'data'
save_path_name = 'saved_'
networks = ['sym_network', 'europe_network']

target = 'additional_stress'
dict_sets_lengths = {
    '75':(0.5, 0.25, 0.25),
    '40':(0.3, 0.1, 0.6),
    '10':(0.07, 0.03, 0.9),
}
dict_test_type = {
    'clas': test_clas,
    'reg': test_reg,
}


Get all models that will be tested

In [4]:
# col = ResultCollection()
# paths = list(itertools.chain.from_iterable([[k for k in Path(n).glob(f"models*/{save_path_name}*") if k.is_dir()] for n in networks]))
network_paths = [[k for k in Path(n).glob(f"models*/{save_path_name}*") if k.is_dir()] for n in networks]

In [5]:
col_results = {}

for n,paths in zip(networks, network_paths):
    col = ResultCollection()
    col_results[n] = col

    for p in paths:
        print(p)
        name = p.name.split(save_path_name)[1]
        _,test_type,sets_type = p.parent.name.split('_', 3)
        sets_lengths = dict_sets_lengths[sets_type]
        data_dir = p.parent.parent.joinpath(data_dir_name)
        
        set_seed(seed)
        dataset_val = ContagionDataset(
            raw_dir=data_dir,
            drop_edges=0,
            sets_lengths=sets_lengths,
            target = target,
            add_self_loop=True,
        )

        r = dict_test_type[test_type](
            dataset=dataset_val,
            save_path=str(p),
            n_runs=1,
            debug_mode=False,
            use_cpu=False,
            save=True,
            use_edge_weight=True,
        )

        result = col.add(r[2], f"{p.parent.name}_{name}")
        result.save_best('test_mcc',p,True)


sym_network\models_clas_10\saved_fnn


100%|██████████| 743/743 [00:30<00:00, 24.33it/s]


sym_network\models_clas_10\saved_gat


100%|██████████| 18753/18753 [14:51<00:00, 21.04it/s]


sym_network\models_clas_10\saved_gcn


100%|██████████| 2224/2224 [01:47<00:00, 20.70it/s]


sym_network\models_clas_10\saved_sage


100%|██████████| 5721/5721 [14:32<00:00,  6.56it/s]


sym_network\models_clas_40\saved_fnn


100%|██████████| 1222/1222 [00:48<00:00, 25.09it/s]


sym_network\models_clas_40\saved_gat


100%|██████████| 11431/11431 [08:40<00:00, 21.94it/s]


sym_network\models_clas_40\saved_gcn


100%|██████████| 3451/3451 [02:52<00:00, 19.99it/s]


sym_network\models_clas_40\saved_sage


100%|██████████| 6210/6210 [20:51<00:00,  4.96it/s]


sym_network\models_clas_75\saved_fnn


100%|██████████| 1193/1193 [00:50<00:00, 23.86it/s]


sym_network\models_clas_75\saved_gat


100%|██████████| 11549/11549 [08:50<00:00, 21.76it/s]


sym_network\models_clas_75\saved_gcn


100%|██████████| 2512/2512 [02:12<00:00, 18.97it/s]


sym_network\models_clas_75\saved_sage


100%|██████████| 6560/6560 [22:19<00:00,  4.90it/s]


sym_network\models_reg_10\saved_fnn


100%|██████████| 264/264 [00:10<00:00, 24.78it/s]


sym_network\models_reg_10\saved_gat


100%|██████████| 6336/6336 [05:13<00:00, 20.21it/s]


sym_network\models_reg_10\saved_gcn


100%|██████████| 792/792 [00:42<00:00, 18.71it/s]


sym_network\models_reg_10\saved_sage


100%|██████████| 1386/1386 [03:49<00:00,  6.05it/s]


sym_network\models_reg_10\saved_sage_base


100%|██████████| 2772/2772 [04:45<00:00,  9.69it/s]


sym_network\models_reg_40\saved_fnn


100%|██████████| 360/360 [00:15<00:00, 22.81it/s]


sym_network\models_reg_40\saved_gat


100%|██████████| 3168/3168 [02:37<00:00, 20.08it/s]


sym_network\models_reg_40\saved_gcn


100%|██████████| 792/792 [00:43<00:00, 18.41it/s]


sym_network\models_reg_40\saved_sage


100%|██████████| 1188/1188 [04:15<00:00,  4.65it/s]


sym_network\models_reg_40\saved_sage_base


100%|██████████| 1188/1188 [04:07<00:00,  4.80it/s]


sym_network\models_reg_75\saved_fnn


100%|██████████| 355/355 [00:16<00:00, 21.23it/s]


sym_network\models_reg_75\saved_gat


100%|██████████| 3168/3168 [02:43<00:00, 19.38it/s]


sym_network\models_reg_75\saved_gcn


100%|██████████| 792/792 [00:42<00:00, 18.45it/s]


sym_network\models_reg_75\saved_sage


100%|██████████| 1188/1188 [03:52<00:00,  5.12it/s]


sym_network\models_reg_75\saved_sage_base


100%|██████████| 1188/1188 [03:35<00:00,  5.50it/s]


europe_network\models_clas_10\saved_fnn


100%|██████████| 1299/1299 [00:48<00:00, 26.70it/s]


europe_network\models_clas_10\saved_sage


100%|██████████| 97/97 [00:20<00:00,  4.64it/s]


europe_network\models_clas_40\saved_fnn


100%|██████████| 2477/2477 [01:34<00:00, 26.18it/s]


europe_network\models_clas_40\saved_sage


100%|██████████| 473/473 [01:55<00:00,  4.08it/s]


europe_network\models_clas_75\saved_fnn


100%|██████████| 2402/2402 [01:31<00:00, 26.19it/s]


europe_network\models_clas_75\saved_sage


100%|██████████| 652/652 [02:39<00:00,  4.10it/s]


In [6]:
# col_results = load_pickle('results.pickle')

In [7]:
r = {k:v.df('test_mcc', True) for k,v in col_results.items()}
r

{'sym_network':                          train_acc   val_acc  train_mcc   val_mcc  test_mcc  \
 name                                                                          
 models_reg_75_sage_base   0.736000  0.674667   0.653794  0.576053  0.606178   
 models_reg_75_sage        0.944000  0.853333   0.927167  0.810044  0.788355   
 models_reg_75_gcn         0.438667  0.413333   0.345090  0.341689  0.323545   
 models_reg_75_gat         0.650667  0.666667   0.534243  0.555925  0.482151   
 models_reg_75_fnn         0.516000  0.480000   0.386519  0.352674  0.409551   
 models_reg_40_sage_base   0.764444  0.726667   0.707345  0.646037  0.613525   
 models_reg_40_sage        0.875556  0.786667   0.840620  0.724312  0.719139   
 models_reg_40_gcn         0.437778  0.420000   0.347737  0.346852  0.338712   
 models_reg_40_gat         0.588889  0.546667   0.451973  0.396770  0.453376   
 models_reg_40_fnn         0.508889  0.440000   0.382851  0.309779  0.370818   
 models_reg_10_sage_base 

In [13]:
s = r['sym_network']
s[[k for k in s.columns if 'test' in k]]

Unnamed: 0_level_0,test_mcc,test_rmse,test_acc,test_mae,test_rmse_perc,test_mae_perc
name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
models_reg_75_sage_base,0.606178,0.623966,0.696,0.330667,0.124594,0.09133
models_reg_75_sage,0.788355,0.478888,0.837333,0.181333,0.117416,0.081585
models_reg_75_gcn,0.323545,0.92376,0.426667,0.666667,0.219243,0.181033
models_reg_75_gat,0.482151,0.719259,0.610667,0.432,0.162221,0.119653
models_reg_75_fnn,0.409551,0.91214,0.525333,0.592,0.244187,0.19388
models_reg_40_sage_base,0.613525,0.561743,0.697778,0.306667,0.119162,0.092553
models_reg_40_sage,0.719139,0.507718,0.781111,0.231111,0.112933,0.082773
models_reg_40_gcn,0.338712,0.883805,0.432222,0.638889,0.216487,0.179924
models_reg_40_gat,0.453376,0.785281,0.59,0.476667,0.183264,0.13601
models_reg_40_fnn,0.370818,0.879394,0.498889,0.591111,0.223169,0.17866
