In [1]:
import sys
# Put the parent directory on the import path so the local `models` package
# (imported in the next cell) can be found. A relative sys.path entry is
# resolved against the kernel's current working directory, so the notebook
# must be started from its own folder for this to work.
sys.path.append('../')

In [2]:
import itertools
import pandas as pd
import numpy as np
from tqdm.auto import tqdm, trange
from pathlib import Path

from models.train import test as test_clas
from models.train_reg import test as test_reg
from models.utils import ContagionDataset, set_seed, ConfusionMatrix, save_pickle, load_pickle
from sklearn.metrics import matthews_corrcoef, mean_squared_error
from models.results import ResultCollection

Using backend: pytorch


In [3]:
# Run configuration.
seed = 4444  # passed to set_seed() before each dataset is built

# NOTE(review): not referenced by the cells shown here — confirm they are still needed.
metric_filter_1 = 'val_mcc'
metric_filter_2 = 'test_mcc'

data_dir_name = 'data'     # dataset folder, looked up as <network>/data
save_path_name = 'saved_'  # prefix of every saved-model folder (saved_<model>)
networks = ['sym_network', 'europe_network']

target = 'additional_stress'

# Split fractions keyed by the suffix of the models_<task>_<key> folders.
# Presumably (train, val, test); the first two fractions add up to key/100
# (e.g. 0.5 + 0.25 = 0.75 for '75').
dict_sets_lengths = {
    '75':(0.5, 0.25, 0.25),
    '40':(0.3, 0.1, 0.6),
    '10':(0.07, 0.03, 0.9),
}
# Guard against typos in the split table: every split must cover the whole dataset.
assert all(abs(sum(fractions) - 1) < 1e-9 for fractions in dict_sets_lengths.values()), \
    "every entry of dict_sets_lengths must sum to 1"

# Task name -> evaluation function; also the <task> part of models_<task>_* folders.
dict_test_type = {
    'reg': test_reg,
    'clas': test_clas,
}
# Task name -> (metric name, flag) forwarded to ResultCollection.df / save_best.
# NOTE(review): the flag looks like "higher is better" (MCC True, RMSE False) — confirm.
dict_metric = {
    'reg': ('test_rmse_perc', False),
    'clas': ('test_mcc', True),
}


## Collect all saved models and test them

In [4]:
# col = ResultCollection()
# paths = list(itertools.chain.from_iterable([[k for k in Path(n).glob(f"models*/{save_path_name}*") if k.is_dir()] for n in networks]))
# network_paths = [[k for k in Path(n).glob(f"models*/{save_path_name}*") if k.is_dir()] for n in networks]

# list(Path(networks[0]).glob(f"models_{list(dict_test_type.keys())[0]}*/{save_path_name}*"))

In [5]:
# Evaluate every saved model for each (network, task) pair.
# col_results maps '<network>_<test_type>' -> (ResultCollection, metric pair),
# where the metric pair comes from dict_metric and is forwarded to save_best().
col_results = {}

for n, test_type in itertools.product(networks, dict_test_type.keys()):
    col = ResultCollection()
    metric = dict_metric[test_type]
    metric_name, metric_flag = metric
    col_results[f"{n}_{test_type}"] = (col, metric)

    # Expected layout: <network>/models_<test_type>_<split>/saved_<model>/.
    # The '_' after the task name stops e.g. 'models_reg*' from matching
    # unrelated folders; sorting makes the evaluation order independent of
    # the filesystem's directory-listing order.
    model_dirs = sorted(
        k for k in Path(n).glob(f"models_{test_type}_*/{save_path_name}*")
        if k.is_dir()
    )
    for p in model_dirs:
        print(p)
        # Drop only the leading prefix ('saved_sage_base' -> 'sage_base').
        name = p.name[len(save_path_name):]
        # 'models_reg_10' -> '10'; maxsplit=2 keeps any extra '_' in the last
        # part instead of failing the unpack.
        _, _, sets_type = p.parent.name.split('_', 2)
        sets_lengths = dict_sets_lengths[sets_type]
        data_dir = p.parent.parent.joinpath(data_dir_name)

        # Re-seed before every dataset, presumably so each model is evaluated
        # on the same split.
        set_seed(seed)
        dataset_val = ContagionDataset(
            raw_dir=data_dir,
            drop_edges=0,
            sets_lengths=sets_lengths,
            target=target,
            add_self_loop=True,
        )

        # TODO: confirm that 'base_n' is really the approach the *_base models
        # were trained with.
        approach = 'base_n' if 'base' in name else 'scale'

        test_output = dict_test_type[test_type](
            dataset=dataset_val,
            save_path=str(p),
            n_runs=1,
            debug_mode=False,
            use_cpu=False,
            save=True,
            use_edge_weight=True,
            approach_default=approach,
        )

        # Element 2 of the test function's return value is what the collection stores.
        result = col.add(test_output[2], f"{sets_type}_{name}", sets_type, model=name)
        result.save_best(metric_name, p, metric_flag)


sym_network\models_reg_10\saved_fnn


100%|██████████| 266/266 [00:13<00:00, 19.04it/s]


sym_network\models_reg_10\saved_gat


100%|██████████| 6338/6338 [05:31<00:00, 19.14it/s]


sym_network\models_reg_10\saved_gcn


100%|██████████| 794/794 [00:50<00:00, 15.73it/s]


sym_network\models_reg_10\saved_sage


100%|██████████| 1388/1388 [04:22<00:00,  5.28it/s]


sym_network\models_reg_10\saved_sage_base
ahora


100%|██████████| 2774/2774 [05:27<00:00,  8.47it/s]


sym_network\models_reg_40\saved_fnn


100%|██████████| 362/362 [00:15<00:00, 24.01it/s]


sym_network\models_reg_40\saved_gat


100%|██████████| 3170/3170 [02:39<00:00, 19.83it/s]


sym_network\models_reg_40\saved_gcn


100%|██████████| 794/794 [00:43<00:00, 18.41it/s]


sym_network\models_reg_40\saved_sage


100%|██████████| 1190/1190 [04:42<00:00,  4.22it/s]


sym_network\models_reg_40\saved_sage_base
ahora


100%|██████████| 1190/1190 [05:12<00:00,  3.81it/s]


sym_network\models_reg_75\saved_fnn


100%|██████████| 356/356 [00:14<00:00, 24.32it/s]


sym_network\models_reg_75\saved_gat


100%|██████████| 3169/3169 [02:56<00:00, 17.92it/s]


sym_network\models_reg_75\saved_gcn


100%|██████████| 793/793 [00:47<00:00, 16.63it/s]


sym_network\models_reg_75\saved_sage


100%|██████████| 1189/1189 [04:38<00:00,  4.26it/s]


sym_network\models_reg_75\saved_sage_base
ahora


100%|██████████| 1189/1189 [05:17<00:00,  3.74it/s]


sym_network\models_clas_10\saved_fnn


100%|██████████| 745/745 [00:41<00:00, 18.13it/s]


sym_network\models_clas_10\saved_gat


100%|██████████| 18755/18755 [15:38<00:00, 19.97it/s]


sym_network\models_clas_10\saved_gcn


100%|██████████| 2226/2226 [01:54<00:00, 19.43it/s]


sym_network\models_clas_10\saved_sage


100%|██████████| 5723/5723 [16:54<00:00,  5.64it/s]


sym_network\models_clas_40\saved_fnn


100%|██████████| 1224/1224 [00:51<00:00, 23.76it/s]


sym_network\models_clas_40\saved_gat


100%|██████████| 11433/11433 [09:36<00:00, 19.84it/s]


sym_network\models_clas_40\saved_gcn


100%|██████████| 3453/3453 [03:01<00:00, 19.07it/s]


sym_network\models_clas_40\saved_sage


100%|██████████| 6212/6212 [25:37<00:00,  4.04it/s]  


sym_network\models_clas_75\saved_fnn


100%|██████████| 1195/1195 [01:06<00:00, 17.95it/s]


sym_network\models_clas_75\saved_gat


100%|██████████| 11551/11551 [10:45<00:00, 17.88it/s] 


sym_network\models_clas_75\saved_gcn


100%|██████████| 2514/2514 [02:12<00:00, 18.98it/s]


sym_network\models_clas_75\saved_sage


100%|██████████| 6562/6562 [26:39<00:00,  4.10it/s] 


europe_network\models_clas_10\saved_fnn


100%|██████████| 1301/1301 [01:04<00:00, 20.31it/s]


europe_network\models_clas_10\saved_sage


100%|██████████| 99/99 [00:29<00:00,  3.34it/s]


europe_network\models_clas_40\saved_fnn


100%|██████████| 2479/2479 [01:52<00:00, 22.08it/s]


europe_network\models_clas_40\saved_sage


100%|██████████| 475/475 [02:04<00:00,  3.81it/s]


europe_network\models_clas_75\saved_fnn


100%|██████████| 2404/2404 [02:20<00:00, 17.05it/s]


europe_network\models_clas_75\saved_sage


100%|██████████| 654/654 [02:56<00:00,  3.71it/s]


In [6]:
# save_pickle(col_results, 'results.pickle')
# col_results = load_pickle('results.pickle')

In [7]:
# Keys are '<network>_<test_type>'; each value is (ResultCollection, metric pair from dict_metric).
col_results

{'sym_network_reg': (<models.results.ResultCollection at 0x20d2e499a30>,
  ('test_rmse_perc', False)),
 'sym_network_clas': (<models.results.ResultCollection at 0x20d2e3088b0>,
  ('test_mcc', True)),
 'europe_network_reg': (<models.results.ResultCollection at 0x20e55e9a370>,
  ('test_rmse_perc', False)),
 'europe_network_clas': (<models.results.ResultCollection at 0x20f4a18b580>,
  ('test_mcc', True))}

In [8]:
# One results DataFrame per '<network>_<test_type>' key, built from that
# key's collection with its (metric name, flag) pair.
r = {
    key: collection.df(metric[0], metric[1])
    for key, (collection, metric) in col_results.items()
}
r

{'sym_network_reg':               train_mcc   val_mcc  test_mcc  train_acc   val_acc  test_acc  \
 uid                                                                          
 10_sage        0.634618  0.521291  0.581344   0.723809  0.644444  0.684444   
 10_sage_base   0.244881  0.266535  0.370856   0.390476  0.488889  0.503704   
 10_gat         0.442096  0.344674  0.355414   0.561905  0.511111  0.505185   
 10_fnn         0.312659  0.374865  0.332307   0.428571  0.555555  0.480000   
 10_gcn         0.308340  0.381849  0.326334   0.380952  0.555555  0.438519   
 40_sage        0.835089  0.733372  0.712704   0.871111  0.793333  0.776667   
 40_sage_base   0.707345  0.646037  0.613525   0.764444  0.726667  0.697778   
 40_gat         0.518122  0.432167  0.444630   0.633333  0.566667  0.580000   
 40_fnn         0.360827  0.368414  0.325004   0.515556  0.520000  0.490000   
 40_gcn         0.315727  0.329497  0.312130   0.435556  0.426667  0.436667   
 75_sage        0.801031  0.77617

In [9]:
# Results for the 'clas' task (evaluated with test_clas) on sym_network.
r['sym_network_clas']

Unnamed: 0_level_0,train_mcc,val_mcc,test_mcc,train_acc,val_acc,test_acc,train_rmse,val_rmse,test_rmse,train_mae,val_mae,test_mae,train_rmse_perc,val_rmse_perc,test_rmse_perc,train_mae_perc,val_mae_perc,test_mae_perc,group,model
uid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
75_sage,0.922016,0.786613,0.802167,0.941333,0.84,0.850667,0.273252,0.409878,0.481664,0.064,0.162667,0.173333,0.087515,0.102572,0.12033,0.068813,0.079427,0.085693,75,sage
75_gcn,0.370189,0.370565,0.430011,0.506667,0.490667,0.544,0.920869,0.910677,0.920869,0.602667,0.610667,0.576,0.228003,0.223579,0.228121,0.171693,0.175053,0.16924,75,gcn
75_fnn,0.385798,0.360453,0.405166,0.517333,0.485333,0.525333,0.880909,0.913601,0.916515,0.578667,0.621333,0.594667,0.214458,0.223132,0.220422,0.166867,0.177,0.171667,75,fnn
75_gat,0.455612,0.469401,0.371722,0.590667,0.6,0.528,0.85479,0.818128,0.946573,0.506667,0.482667,0.602667,0.212804,0.198866,0.23432,0.15104,0.141987,0.17212,75,gat
40_sage,1.0,0.814434,0.764037,1.0,0.86,0.822222,0.0,0.374166,0.48074,0.0,0.14,0.193333,0.071691,0.109938,0.112366,0.061533,0.0812,0.083578,40,sage
40_gcn,0.381281,0.360564,0.418741,0.511111,0.473333,0.533333,0.904311,0.894427,0.893806,0.591111,0.613333,0.567778,0.221574,0.226097,0.219237,0.166689,0.175733,0.164889,40,gcn
40_gat,0.476819,0.372075,0.407336,0.604444,0.52,0.547778,0.856349,1.116542,0.956266,0.502222,0.7,0.592222,0.214827,0.282347,0.232304,0.1536,0.194333,0.167211,40,gat
40_fnn,0.364046,0.313141,0.360033,0.511111,0.46,0.495556,1.019804,1.045626,1.092906,0.662222,0.72,0.725556,0.247736,0.260934,0.264152,0.185356,0.203333,0.201178,40,fnn
10_sage,0.795001,0.608728,0.603387,0.847619,0.711111,0.701481,0.48795,0.649786,0.64291,0.180952,0.333333,0.334815,0.120183,0.15369,0.15326,0.085571,0.107444,0.1116,10,sage
10_gcn,0.374599,0.403404,0.407791,0.533333,0.511111,0.522963,1.004751,0.829993,0.942416,0.647619,0.555556,0.612593,0.258689,0.190492,0.235553,0.192238,0.144556,0.17857,10,gcn


In [10]:
# Results for the 'reg' task (evaluated with test_reg) on sym_network.
r['sym_network_reg']

Unnamed: 0_level_0,train_mcc,val_mcc,test_mcc,train_acc,val_acc,test_acc,train_rmse,val_rmse,test_rmse,train_mae,val_mae,test_mae,train_rmse_perc,val_rmse_perc,test_rmse_perc,train_mae_perc,val_mae_perc,test_mae_perc,group,model
uid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
10_sage,0.634618,0.521291,0.581344,0.723809,0.644444,0.684444,0.601585,0.649786,0.643486,0.304762,0.377778,0.347407,0.134155,0.150404,0.143559,0.104882,0.111082,0.104955,10,sage
10_sage_base,0.244881,0.266535,0.370856,0.390476,0.488889,0.503704,0.798809,0.760117,0.73686,0.619048,0.533333,0.511852,0.177122,0.165596,0.165149,0.152154,0.133731,0.133605,10,sage_base
10_gat,0.442096,0.344674,0.355414,0.561905,0.511111,0.505185,0.762202,0.869227,0.843274,0.485714,0.577778,0.564444,0.165999,0.187621,0.196543,0.134906,0.139174,0.152113,10,gat
10_fnn,0.312659,0.374865,0.332307,0.428571,0.555555,0.48,0.87831,0.802773,0.883595,0.638095,0.511111,0.604444,0.210453,0.185058,0.211368,0.17903,0.139817,0.169835,10,fnn
10_gcn,0.30834,0.381849,0.326334,0.380952,0.555555,0.438519,0.905012,0.802773,0.859371,0.685714,0.511111,0.62,0.23138,0.200166,0.217017,0.199017,0.16091,0.18167,10,gcn
40_sage,0.835089,0.733372,0.712704,0.871111,0.793333,0.776667,0.368179,0.454606,0.512076,0.131111,0.206667,0.235556,0.095655,0.097373,0.112529,0.075609,0.076762,0.082301,40,sage
40_sage_base,0.707345,0.646037,0.613525,0.764444,0.726667,0.697778,0.485341,0.522813,0.561743,0.235556,0.273333,0.306667,0.105357,0.113051,0.119162,0.084245,0.091583,0.092553,40,sage_base
40_gat,0.518122,0.432167,0.44463,0.633333,0.566667,0.58,0.707107,0.83666,0.767391,0.411111,0.513333,0.473333,0.159566,0.205826,0.173368,0.11821,0.140881,0.130058,40,gat
40_fnn,0.360827,0.368414,0.325004,0.515556,0.52,0.49,0.85505,0.856349,0.864741,0.566667,0.56,0.587778,0.197678,0.203618,0.205201,0.158469,0.167623,0.164736,40,fnn
40_gcn,0.315727,0.329497,0.31213,0.435556,0.426667,0.436667,0.85894,0.890693,0.848528,0.622222,0.646667,0.615556,0.211265,0.219865,0.206994,0.175307,0.190456,0.173165,40,gcn
