In [1]:
import sys

# Put the parent directory on the import path (presumably where the `models`
# package imported in the next cell lives -- confirm against the repo layout).
# The membership guard keeps the cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path.
if '../' not in sys.path:
    sys.path.append('../')

In [2]:
import itertools
import pandas as pd
import numpy as np
from tqdm.auto import tqdm, trange
from pathlib import Path

from models.train import test as test_clas
from models.train_reg import test as test_reg
from models.utils import ContagionDataset, set_seed, ConfusionMatrix, save_pickle, load_pickle
from sklearn.metrics import matthews_corrcoef, mean_squared_error
from models.results import ResultCollection

Using backend: pytorch


In [3]:
# Seed handed to set_seed() before each dataset is built.
seed = 4444

# Metric names for the validation and test MCC columns.
metric_filter_1, metric_filter_2 = 'val_mcc', 'test_mcc'

# Naming conventions used to locate the data folder and the saved-model folders.
data_dir_name = 'data'
save_path_name = 'saved_'
networks = ['sym_network', 'europe_network']

# Target name passed to ContagionDataset.
target = 'additional_stress'

# Split fractions per experiment tag. Each tag equals the sum of the first two
# fractions expressed in percent (0.5 + 0.25 -> '75', 0.3 + 0.1 -> '40',
# 0.07 + 0.03 -> '10').
# NOTE(review): tuple order is presumably (train, val, test) -- confirm against
# ContagionDataset.
dict_sets_lengths = dict([
    ('75', (0.5, 0.25, 0.25)),
    ('40', (0.3, 0.1, 0.6)),
    ('10', (0.07, 0.03, 0.9)),
])

# Task tag -> evaluation function for that task.
dict_test_type = dict(reg=test_reg, clas=test_clas)

# Task tag -> (metric name, flag). Both values are forwarded to
# Result.save_best() and ResultCollection.df().
# NOTE(review): the flag presumably means "higher is better" -- confirm.
dict_metric = dict(
    reg=('test_rmse_perc', False),
    clas=('test_mcc', True),
)


In [4]:
# col = ResultCollection()
# paths = list(itertools.chain.from_iterable([[k for k in Path(n).glob(f"models*/{save_path_name}*") if k.is_dir()] for n in networks]))
# network_paths = [[k for k in Path(n).glob(f"models*/{save_path_name}*") if k.is_dir()] for n in networks]

In [5]:
# Maps "<network>_<task>" -> (ResultCollection, (metric name, flag)).
col_results = {}

for n, test_type in itertools.product(networks, dict_test_type.keys()):
    # One collection per (network, task) pair, stored together with the metric
    # tuple that is later used to rank / display its entries.
    col = ResultCollection()
    metric = dict_metric[test_type]
    col_results[f"{n}_{test_type}"] = (col, metric)

    # Saved-model folders follow <network>/models_<task>_<split>/saved_<model>.
    # Sorted so the processing order does not depend on the file system.
    saved_model_dirs = sorted(
        k for k in Path(n).glob(f"models_{test_type}*/{save_path_name}*") if k.is_dir()
    )

    for p in saved_model_dirs:
        print(p)

        # Strip only the leading prefix; maxsplit=1 keeps the rest of the name
        # intact even if it happens to contain the prefix again.
        name = p.name.split(save_path_name, 1)[1]
        # Parent folder is "models_<task>_<split>"; the last piece selects the
        # split fractions. maxsplit=2 matches the three-way unpack, so an
        # unexpected suffix surfaces as a KeyError on the lookup below.
        _, _, sets_type = p.parent.name.split('_', 2)
        sets_lengths = dict_sets_lengths[sets_type]
        data_dir = p.parent.parent.joinpath(data_dir_name)

        # Re-seed before building each dataset so every model is evaluated with
        # the same seed.
        set_seed(seed)
        dataset_val = ContagionDataset(
            raw_dir=data_dir,
            drop_edges=0,
            sets_lengths=sets_lengths,
            target=target,
            add_self_loop=True,
        )

        # Models whose folder name contains "base" are evaluated with the
        # 'base_n' approach, all others with 'scale'.
        # TODO: verify that 'base_n' is really the approach these models use.
        is_base_model = 'base' in name

        r = dict_test_type[test_type](
            dataset=dataset_val,
            save_path=str(p),
            n_runs=1,
            debug_mode=False,
            use_cpu=False,
            save=True,
            use_edge_weight=True,
            approach_default='base_n' if is_base_model else 'scale',
        )

        # r[2] is the third element returned by the test function.
        # NOTE(review): presumably the per-run metrics -- confirm against
        # models.train.test / models.train_reg.test.
        result = col.add(r[2], f"{sets_type}_{name}", sets_type, model=name)
        result.save_best(metric[0], p, metric[1])


sym_network\models_reg_10\saved_fnn


100%|██████████| 359/359 [00:19<00:00, 18.88it/s]


sym_network\models_reg_10\saved_gat


100%|██████████| 6979/6979 [06:06<00:00, 19.03it/s]


sym_network\models_reg_10\saved_gcn


100%|██████████| 1085/1085 [00:58<00:00, 18.53it/s]


sym_network\models_reg_10\saved_sage_base
ahora


100%|██████████| 3552/3552 [05:46<00:00, 10.24it/s]


sym_network\models_reg_10\saved_sage_dist


100%|██████████| 6266/6266 [12:14<00:00,  8.54it/s]


sym_network\models_reg_40\saved_fnn


100%|██████████| 516/516 [00:21<00:00, 24.22it/s]


sym_network\models_reg_40\saved_gat


100%|██████████| 4255/4255 [03:18<00:00, 21.38it/s]


sym_network\models_reg_40\saved_gcn


100%|██████████| 1050/1050 [00:52<00:00, 20.18it/s]


sym_network\models_reg_40\saved_sage_base
ahora


100%|██████████| 1937/1937 [06:22<00:00,  5.06it/s]


sym_network\models_reg_40\saved_sage_dist


100%|██████████| 2568/2568 [08:20<00:00,  5.13it/s]


sym_network\models_reg_40\saved_sage_scale


100%|██████████| 2538/2538 [08:19<00:00,  5.08it/s]


sym_network\models_reg_75\saved_fnn


100%|██████████| 525/525 [00:21<00:00, 23.98it/s]


sym_network\models_reg_75\saved_gat


100%|██████████| 5177/5177 [04:03<00:00, 21.22it/s]


sym_network\models_reg_75\saved_gcn


100%|██████████| 1101/1101 [00:55<00:00, 19.92it/s]


sym_network\models_reg_75\saved_sage_base
ahora


100%|██████████| 2342/2342 [07:42<00:00,  5.06it/s]


sym_network\models_reg_75\saved_sage_dist


100%|██████████| 2814/2814 [09:03<00:00,  5.17it/s]


sym_network\models_reg_75\saved_sage_scale


100%|██████████| 2763/2763 [09:05<00:00,  5.07it/s]


sym_network\models_clas_10\saved_fnn


100%|██████████| 745/745 [00:30<00:00, 24.23it/s]


sym_network\models_clas_10\saved_gat


100%|██████████| 18755/18755 [14:29<00:00, 21.58it/s]


sym_network\models_clas_10\saved_gcn


100%|██████████| 2226/2226 [01:47<00:00, 20.79it/s]


sym_network\models_clas_10\saved_sage


100%|██████████| 5723/5723 [14:05<00:00,  6.77it/s]


sym_network\models_clas_40\saved_fnn


100%|██████████| 1224/1224 [00:49<00:00, 24.64it/s]


sym_network\models_clas_40\saved_gat


100%|██████████| 11433/11433 [08:50<00:00, 21.53it/s]


sym_network\models_clas_40\saved_gcn


100%|██████████| 3453/3453 [02:56<00:00, 19.61it/s]


sym_network\models_clas_40\saved_sage


100%|██████████| 6212/6212 [20:16<00:00,  5.10it/s]


sym_network\models_clas_75\saved_fnn


100%|██████████| 1195/1195 [00:48<00:00, 24.59it/s]


sym_network\models_clas_75\saved_gat


100%|██████████| 11551/11551 [08:57<00:00, 21.50it/s]


sym_network\models_clas_75\saved_gcn


100%|██████████| 2514/2514 [02:03<00:00, 20.33it/s]


sym_network\models_clas_75\saved_sage


100%|██████████| 6562/6562 [21:16<00:00,  5.14it/s]


europe_network\models_clas_10\saved_fnn


100%|██████████| 1301/1301 [00:54<00:00, 24.05it/s]


europe_network\models_clas_10\saved_sage


100%|██████████| 99/99 [00:21<00:00,  4.71it/s]


europe_network\models_clas_40\saved_fnn


100%|██████████| 2479/2479 [01:47<00:00, 23.13it/s]


europe_network\models_clas_40\saved_sage


100%|██████████| 475/475 [01:54<00:00,  4.15it/s]


europe_network\models_clas_75\saved_fnn


100%|██████████| 2404/2404 [01:44<00:00, 23.07it/s]


europe_network\models_clas_75\saved_sage


100%|██████████| 654/654 [02:37<00:00,  4.15it/s]


In [6]:
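# Optional cache: uncomment the first line to persist the collected results, or
# the second to reload them instead of re-running the evaluation loop above.
# save_pickle(col_results, 'results.pickle')
# col_results = load_pickle('results.pickle')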

In [7]:
# Inspect the collected results: one (ResultCollection, (metric name, flag))
# entry per "<network>_<task>" key.
col_results

{'sym_network_reg': (<models.results.ResultCollection at 0x2baecbdbfd0>,
  ('test_rmse_perc', False)),
 'sym_network_clas': (<models.results.ResultCollection at 0x2bad37c0d30>,
  ('test_mcc', True)),
 'europe_network_reg': (<models.results.ResultCollection at 0x2bc17fa8c40>,
  ('test_rmse_perc', False)),
 'europe_network_clas': (<models.results.ResultCollection at 0x2bce85207c0>,
  ('test_mcc', True))}

In [8]:
# Build one table per "<network>_<task>" key by calling the collection's df()
# with that key's metric name and flag.
r = {
    key: collection.df(metric_name, metric_flag)
    for key, (collection, (metric_name, metric_flag)) in col_results.items()
}

In [9]:
# Results table for the 'sym_network' / 'clas' key, rounded to 3 decimals.
r['sym_network_clas'].round(decimals=3)

Unnamed: 0_level_0,train_mcc,val_mcc,test_mcc,train_acc,val_acc,test_acc,train_rmse,val_rmse,test_rmse,train_mae,val_mae,test_mae,train_rmse_perc,val_rmse_perc,test_rmse_perc,train_mae_perc,val_mae_perc,test_mae_perc,group,model
uid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
75_sage,0.922,0.787,0.802,0.941,0.84,0.851,0.273,0.41,0.482,0.064,0.163,0.173,0.088,0.103,0.12,0.069,0.079,0.086,75,sage
75_gcn,0.37,0.371,0.43,0.507,0.491,0.544,0.921,0.911,0.921,0.603,0.611,0.576,0.228,0.224,0.228,0.172,0.175,0.169,75,gcn
75_fnn,0.386,0.36,0.405,0.517,0.485,0.525,0.881,0.914,0.917,0.579,0.621,0.595,0.214,0.223,0.22,0.167,0.177,0.172,75,fnn
75_gat,0.456,0.469,0.372,0.591,0.6,0.528,0.855,0.818,0.947,0.507,0.483,0.603,0.213,0.199,0.234,0.151,0.142,0.172,75,gat
40_sage,1.0,0.814,0.764,1.0,0.86,0.822,0.0,0.374,0.481,0.0,0.14,0.193,0.072,0.11,0.112,0.062,0.081,0.084,40,sage
40_gcn,0.381,0.361,0.419,0.511,0.473,0.533,0.904,0.894,0.894,0.591,0.613,0.568,0.222,0.226,0.219,0.167,0.176,0.165,40,gcn
40_gat,0.477,0.372,0.407,0.604,0.52,0.548,0.856,1.117,0.956,0.502,0.7,0.592,0.215,0.282,0.232,0.154,0.194,0.167,40,gat
40_fnn,0.364,0.313,0.36,0.511,0.46,0.496,1.02,1.046,1.093,0.662,0.72,0.726,0.248,0.261,0.264,0.185,0.203,0.201,40,fnn
10_sage,0.795,0.609,0.603,0.848,0.711,0.701,0.488,0.65,0.643,0.181,0.333,0.335,0.12,0.154,0.153,0.086,0.107,0.112,10,sage
10_gcn,0.375,0.403,0.408,0.533,0.511,0.523,1.005,0.83,0.942,0.648,0.556,0.613,0.259,0.19,0.236,0.192,0.145,0.179,10,gcn


In [10]:
# Results table for the 'europe_network' / 'clas' key, rounded to 3 decimals.
r['europe_network_clas'].round(decimals=3)

Unnamed: 0_level_0,train_mcc,val_mcc,test_mcc,train_acc,val_acc,test_acc,train_rmse,val_rmse,test_rmse,train_mae,val_mae,test_mae,train_rmse_perc,val_rmse_perc,test_rmse_perc,train_mae_perc,val_mae_perc,test_mae_perc,group,model
uid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
75_sage,0.866,0.699,0.636,0.899,0.77,0.726,0.46,0.735,0.803,0.137,0.33,0.396,0.132,0.195,0.215,0.086,0.125,0.139,75,sage
75_fnn,0.569,0.418,0.494,0.669,0.551,0.612,0.823,0.976,0.944,0.442,0.615,0.548,0.211,0.25,0.242,0.144,0.178,0.169,75,fnn
40_sage,0.975,0.613,0.571,0.982,0.708,0.678,0.18,0.997,0.973,0.023,0.493,0.513,0.081,0.256,0.25,0.066,0.156,0.165,40,sage
40_fnn,0.528,0.556,0.482,0.644,0.667,0.607,0.869,0.898,0.967,0.487,0.486,0.565,0.219,0.241,0.243,0.152,0.162,0.168,40,fnn
10_sage,0.987,0.662,0.509,0.99,0.744,0.63,0.1,0.849,1.123,0.01,0.395,0.633,0.076,0.213,0.285,0.066,0.128,0.192,10,sage
10_fnn,0.648,0.531,0.443,0.733,0.651,0.573,0.724,0.876,0.982,0.347,0.488,0.602,0.183,0.207,0.249,0.125,0.143,0.177,10,fnn


In [11]:
# Results table for the 'sym_network' / 'reg' key, rounded to 3 decimals.
r['sym_network_reg'].round(decimals=3)

Unnamed: 0_level_0,train_mcc,val_mcc,test_mcc,train_acc,val_acc,test_acc,train_rmse,val_rmse,test_rmse,train_mae,val_mae,test_mae,train_rmse_perc,val_rmse_perc,test_rmse_perc,train_mae_perc,val_mae_perc,test_mae_perc,group,model
uid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
10_sage_dist,0.73,0.558,0.595,0.8,0.667,0.696,0.535,0.577,0.612,0.229,0.333,0.327,0.111,0.136,0.137,0.08,0.105,0.098,10,sage_dist
10_sage_base,0.434,0.467,0.394,0.552,0.622,0.531,0.669,0.667,0.736,0.448,0.4,0.493,0.139,0.159,0.161,0.122,0.127,0.135,10,sage_base
10_gat,0.461,0.386,0.349,0.6,0.533,0.509,0.775,0.816,0.916,0.467,0.533,0.602,0.172,0.198,0.215,0.133,0.149,0.164,10,gat
10_gcn,0.255,0.406,0.337,0.333,0.556,0.431,0.961,0.882,0.889,0.752,0.556,0.643,0.235,0.204,0.215,0.198,0.156,0.179,10,gcn
10_fnn,0.264,0.406,0.328,0.333,0.556,0.425,0.931,0.882,0.895,0.733,0.556,0.65,0.236,0.219,0.223,0.2,0.17,0.185,10,fnn
40_sage_dist,0.852,0.759,0.709,0.887,0.813,0.777,0.346,0.455,0.504,0.116,0.193,0.232,0.087,0.088,0.106,0.066,0.066,0.076,40,sage_dist
40_sage_scale,0.867,0.768,0.718,0.898,0.82,0.783,0.32,0.447,0.514,0.102,0.187,0.231,0.089,0.091,0.109,0.069,0.068,0.078,40,sage_scale
40_sage_base,0.575,0.596,0.526,0.649,0.667,0.621,0.598,0.577,0.637,0.353,0.333,0.388,0.118,0.125,0.126,0.095,0.096,0.099,40,sage_base
40_gat,0.394,0.363,0.387,0.536,0.513,0.536,0.776,0.856,0.798,0.509,0.56,0.519,0.174,0.214,0.183,0.134,0.155,0.14,40,gat
40_fnn,0.355,0.363,0.326,0.516,0.52,0.493,0.874,0.868,0.886,0.578,0.567,0.598,0.199,0.207,0.207,0.158,0.171,0.164,40,fnn
