In [1]:
import sys
sys.path.append('../')

In [2]:
import itertools
import pandas as pd
import numpy as np
from tqdm.auto import tqdm, trange
from pathlib import Path

from models.train import test as test_clas
from models.train_reg import test as test_reg
from models.utils import ContagionDataset, set_seed, ConfusionMatrix, save_pickle, load_pickle
from sklearn.metrics import matthews_corrcoef, mean_squared_error
from models.results import ResultCollection

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Seed passed to set_seed() before each dataset is built in the evaluation loop.
seed = 4444

# Name of the dataset directory inside each network folder, and the prefix of every
# saved-model directory (e.g. 'saved_gcn').
data_dir_name = 'data'
save_path_name = 'saved_'
networks = ['sym_network', 'europe_network']

# Target column/attribute name forwarded to ContagionDataset.
target = 'additional_stress'

# Split fractions keyed by the suffix of the model directory (models_<type>_<key>).
# The key equals the sum of the first two fractions in percent (0.6 + 0.15 -> '75').
# Presumably the order is (train, val, test) — confirm against ContagionDataset.
dict_sets_lengths = {
    '75': (0.6, 0.15, 0.25),
    '40': (0.32, 0.08, 0.6),
    '10': (0.08, 0.02, 0.9),
}
# Evaluation entry point per task type: 'reg' -> models.train_reg.test,
# 'clas' -> models.train.test.
dict_test_type = {
    'reg': test_reg,
    'clas': test_clas,
}
# (metric name, flag) per task type, forwarded to ResultCollection.df and save_best.
# The flag is True for MCC and False for RMSE — presumably "higher is better";
# confirm in models.results.
dict_metric = {
    'reg': ('test_rmse_perc', False),
    'clas': ('test_mcc', True),
}

# Fail fast on configuration mistakes before the long-running evaluation starts.
for _split_key, _fractions in dict_sets_lengths.items():
    assert len(_fractions) == 3, f"split {_split_key!r} must have exactly 3 fractions"
    assert abs(sum(_fractions) - 1.0) < 1e-9, f"split {_split_key!r} fractions must sum to 1"
assert dict_test_type.keys() == dict_metric.keys(), "every test type needs a metric entry"


In [4]:
# col = ResultCollection()
# paths = list(itertools.chain.from_iterable([[k for k in Path(n).glob(f"models*/{save_path_name}*") if k.is_dir()] for n in networks]))
# network_paths = [[k for k in Path(n).glob(f"models*/{save_path_name}*") if k.is_dir()] for n in networks]

In [5]:
def approach(model_name):
    """Map a saved-model name to the ``approach_default`` passed to the test function.

    Names containing 'base' map to 'base_n', names containing 'scale' map to
    'scale', and every other name (e.g. 'gcn', 'sage_dist') maps to 'dist'.
    'base' is checked first, so it wins if a name contains both substrings.
    """
    if 'base' in model_name:
        return 'base_n'
    if 'scale' in model_name:
        return 'scale'
    return 'dist'


# Evaluate every saved-model directory for each (network, test type) pair and collect
# the results into one ResultCollection per pair, keyed "<network>_<test_type>".
# Directory layout implied by the glob pattern and the printed paths:
#   <network>/models_<test_type>_<sets_type>/<save_path_name><model_name>/
# with the dataset read from <network>/<data_dir_name>/.
# Pairs with no matching directories keep an empty collection.
col_results = {}

for network, test_type in itertools.product(networks, dict_test_type.keys()):
    col = ResultCollection()
    metric = dict_metric[test_type]
    metric_name, metric_flag = metric
    col_results[f"{network}_{test_type}"] = (col, metric)

    # Path.glob yields entries in filesystem order, which is not guaranteed;
    # sorting makes the evaluation order (and printed log) reproducible.
    model_dirs = sorted(
        d for d in Path(network).glob(f"models_{test_type}*/{save_path_name}*") if d.is_dir()
    )
    for model_dir in model_dirs:
        print(model_dir)
        # Strip the prefix (guaranteed by the glob): 'saved_sage_base' -> 'sage_base'.
        name = model_dir.name[len(save_path_name):]
        # Everything after the second underscore is the split key: 'models_reg_10' -> '10'.
        sets_type = model_dir.parent.name.split('_', 2)[-1]
        if sets_type not in dict_sets_lengths:
            raise ValueError(
                f"Unknown split suffix {sets_type!r} in directory {model_dir.parent}; "
                f"expected one of {sorted(dict_sets_lengths)}"
            )
        sets_lengths = dict_sets_lengths[sets_type]
        data_dir = model_dir.parent.parent.joinpath(data_dir_name)

        # Re-seed before building the dataset — presumably so the random split is
        # identical for every model evaluated on this network.
        set_seed(seed)
        dataset_val = ContagionDataset(
            raw_dir=data_dir,
            drop_edges=0,
            sets_lengths=sets_lengths,
            target=target,
            add_self_loop=True,
        )

        test_output = dict_test_type[test_type](
            dataset=dataset_val,
            save_path=str(model_dir),
            n_runs=1,
            debug_mode=False,
            use_cpu=False,
            save=True,
            use_edge_weight=True,
            approach_default=approach(name),
        )

        # NOTE(review): element 2 of the test function's return value is what gets
        # registered in the collection — confirm its meaning in models.train(_reg).
        result = col.add(test_output[2], f"{sets_type}_{name}", sets_type, model=name)
        # Presumably persists the best run by `metric_name` under `model_dir` — see models.results.
        result.save_best(metric_name, model_dir, metric_flag)


sym_network\models_reg_10\saved_fnn


100%|██████████| 747/747 [00:27<00:00, 26.87it/s]


sym_network\models_reg_10\saved_gat


100%|██████████| 7433/7433 [05:46<00:00, 21.46it/s]


sym_network\models_reg_10\saved_gcn


100%|██████████| 1066/1066 [00:38<00:00, 27.78it/s]


sym_network\models_reg_10\saved_sage_base
ahora


100%|██████████| 5341/5341 [06:55<00:00, 12.87it/s]


sym_network\models_reg_10\saved_sage_dist


100%|██████████| 2922/2922 [04:55<00:00,  9.89it/s]


sym_network\models_reg_10\saved_sage_scale


100%|██████████| 5751/5751 [08:27<00:00, 11.33it/s]


sym_network\models_reg_40\saved_fnn


100%|██████████| 551/551 [00:31<00:00, 17.42it/s]


sym_network\models_reg_40\saved_gat


100%|██████████| 4471/4471 [05:12<00:00, 14.29it/s]


sym_network\models_reg_40\saved_gcn


100%|██████████| 1191/1191 [01:18<00:00, 15.20it/s]


sym_network\models_reg_40\saved_sage_base
ahora


100%|██████████| 2774/2774 [09:44<00:00,  4.75it/s]


sym_network\models_reg_40\saved_sage_dist


100%|██████████| 4026/4026 [14:16<00:00,  4.70it/s]


sym_network\models_reg_40\saved_sage_scale


100%|██████████| 3768/3768 [13:18<00:00,  4.72it/s]


sym_network\models_reg_75\saved_fnn


100%|██████████| 513/513 [00:29<00:00, 17.24it/s]


sym_network\models_reg_75\saved_gat


100%|██████████| 4959/4959 [05:14<00:00, 15.77it/s]


sym_network\models_reg_75\saved_gcn


100%|██████████| 1228/1228 [01:21<00:00, 15.15it/s]


sym_network\models_reg_75\saved_sage_base
ahora


100%|██████████| 2974/2974 [10:31<00:00,  4.71it/s]


sym_network\models_reg_75\saved_sage_dist


100%|██████████| 4128/4128 [14:37<00:00,  4.71it/s]


sym_network\models_reg_75\saved_sage_scale


100%|██████████| 4155/4155 [14:36<00:00,  4.74it/s]


sym_network\models_clas_10\saved_fnn


100%|██████████| 692/692 [00:36<00:00, 18.93it/s]


sym_network\models_clas_10\saved_gat


100%|██████████| 16260/16260 [16:52<00:00, 16.06it/s]


sym_network\models_clas_10\saved_gcn


100%|██████████| 2020/2020 [02:14<00:00, 15.05it/s]


sym_network\models_clas_10\saved_sage


100%|██████████| 5784/5784 [16:16<00:00,  5.92it/s]


sym_network\models_clas_40\saved_fnn


100%|██████████| 994/994 [00:52<00:00, 18.78it/s]


sym_network\models_clas_40\saved_gat


100%|██████████| 9929/9929 [10:09<00:00, 16.28it/s]


sym_network\models_clas_40\saved_gcn


100%|██████████| 3092/3092 [03:31<00:00, 14.62it/s]


sym_network\models_clas_40\saved_sage


100%|██████████| 6925/6925 [24:18<00:00,  4.75it/s]


sym_network\models_clas_75\saved_fnn


100%|██████████| 991/991 [00:50<00:00, 19.72it/s]


sym_network\models_clas_75\saved_gat


100%|██████████| 10641/10641 [11:02<00:00, 16.06it/s]


sym_network\models_clas_75\saved_gcn


100%|██████████| 2294/2294 [02:31<00:00, 15.16it/s]


sym_network\models_clas_75\saved_sage


100%|██████████| 6957/6957 [23:58<00:00,  4.84it/s] 


europe_network\models_clas_10\saved_fnn


100%|██████████| 1033/1033 [00:52<00:00, 19.55it/s]


europe_network\models_clas_10\saved_gat


100%|██████████| 3288/3288 [03:15<00:00, 16.80it/s]


europe_network\models_clas_10\saved_gcn


100%|██████████| 1775/1775 [01:50<00:00, 16.03it/s]


europe_network\models_clas_10\saved_sage


100%|██████████| 17/17 [00:04<00:00,  3.50it/s]


europe_network\models_clas_40\saved_fnn


100%|██████████| 2011/2011 [01:46<00:00, 18.80it/s]


europe_network\models_clas_40\saved_gat


100%|██████████| 9542/9542 [09:30<00:00, 16.72it/s]


europe_network\models_clas_40\saved_gcn


100%|██████████| 2417/2417 [02:46<00:00, 14.52it/s]


europe_network\models_clas_40\saved_sage


100%|██████████| 400/400 [02:10<00:00,  3.08it/s]


europe_network\models_clas_75\saved_fnn


100%|██████████| 1972/1972 [01:42<00:00, 19.33it/s]


europe_network\models_clas_75\saved_gat


100%|██████████| 10357/10357 [10:27<00:00, 16.51it/s]


europe_network\models_clas_75\saved_gcn


100%|██████████| 1832/1832 [02:00<00:00, 15.20it/s]


europe_network\models_clas_75\saved_sage


100%|██████████| 423/423 [02:18<00:00,  3.05it/s]


In [6]:
# Persist the aggregated collections so the long-running evaluation loop does not
# need to be repeated; they can be restored later with load_pickle.
results_pickle_path = 'results.pickle'
save_pickle(col_results, results_pickle_path)

In [7]:
# Sanity check: one (ResultCollection, (metric_name, flag)) pair per
# "<network>_<test_type>" key. The collections' default repr only shows object
# addresses; the per-key DataFrames built next are the readable view.
col_results

{'sym_network_reg': (<models.results.ResultCollection at 0x2d9af7498b0>,
  ('test_rmse_perc', False)),
 'sym_network_clas': (<models.results.ResultCollection at 0x2d9e5779100>,
  ('test_mcc', True)),
 'europe_network_reg': (<models.results.ResultCollection at 0x2dbb6542bb0>,
  ('test_rmse_perc', False)),
 'europe_network_clas': (<models.results.ResultCollection at 0x2dc7d6a4310>,
  ('test_mcc', True))}

In [8]:
# Tabulate each collection with the metric configured for its task type.
# metric_flag: True for MCC, False for RMSE — presumably "higher is better";
# confirm in ResultCollection.df.
r = {
    key: collection.df(metric_name, metric_flag)
    for key, (collection, (metric_name, metric_flag)) in col_results.items()
}

In [9]:
# Round once, export to Excel, then display the same table.
sym_clas_table = r['sym_network_clas'].round(3)
sym_clas_table.to_excel('sym_clas.xlsx')
sym_clas_table

Unnamed: 0_level_0,train_mcc,val_mcc,test_mcc,train_acc,val_acc,test_acc,train_rmse,val_rmse,test_rmse,train_mae,val_mae,test_mae,train_rmse_perc,val_rmse_perc,test_rmse_perc,train_mae_perc,val_mae_perc,test_mae_perc,group,model
uid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
75_sage,0.957,0.83,0.802,0.968,0.871,0.851,0.18,0.377,0.45,0.032,0.133,0.165,0.077,0.103,0.11,0.064,0.077,0.082,75,sage
75_gcn,0.567,0.503,0.589,0.673,0.622,0.688,0.664,0.745,0.687,0.361,0.431,0.36,0.163,0.176,0.161,0.118,0.128,0.115,75,gcn
75_fnn,0.373,0.37,0.409,0.494,0.48,0.52,0.896,0.894,0.919,0.602,0.613,0.6,0.219,0.218,0.22,0.172,0.174,0.171,75,fnn
75_gat,0.419,0.446,0.375,0.561,0.578,0.528,0.91,0.94,0.981,0.556,0.556,0.621,0.228,0.224,0.243,0.163,0.156,0.177,75,gat
40_sage,0.9,0.713,0.717,0.925,0.783,0.788,0.285,0.492,0.533,0.077,0.225,0.233,0.089,0.127,0.129,0.07,0.092,0.093,40,sage
40_gcn,0.59,0.548,0.541,0.692,0.658,0.656,0.649,0.665,0.693,0.346,0.375,0.387,0.16,0.167,0.17,0.116,0.126,0.125,40,gcn
40_fnn,0.36,0.339,0.375,0.51,0.5,0.523,0.926,0.922,0.904,0.612,0.617,0.589,0.225,0.231,0.224,0.173,0.175,0.171,40,fnn
40_gat,0.467,0.356,0.37,0.596,0.508,0.519,0.885,1.194,0.995,0.525,0.758,0.637,0.216,0.3,0.242,0.155,0.208,0.178,40,gat
10_sage,0.659,0.676,0.599,0.742,0.767,0.696,0.725,0.483,0.706,0.342,0.233,0.363,0.169,0.115,0.17,0.113,0.084,0.119,10,sage
10_gcn,0.627,0.446,0.494,0.717,0.6,0.619,0.677,0.632,0.719,0.342,0.4,0.425,0.158,0.153,0.175,0.114,0.122,0.131,10,gcn


In [10]:
# Round once, export to Excel, then display the same table.
eur_clas_table = r['europe_network_clas'].round(3)
eur_clas_table.to_excel('eur_clas.xlsx')
eur_clas_table

Unnamed: 0_level_0,train_mcc,val_mcc,test_mcc,train_acc,val_acc,test_acc,train_rmse,val_rmse,test_rmse,train_mae,val_mae,test_mae,train_rmse_perc,val_rmse_perc,test_rmse_perc,train_mae_perc,val_mae_perc,test_mae_perc,group,model
uid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
75_sage,0.91,0.69,0.629,0.932,0.769,0.721,0.383,0.736,0.849,0.091,0.329,0.417,0.115,0.194,0.219,0.078,0.121,0.141,75,sage
75_fnn,0.56,0.451,0.489,0.665,0.569,0.61,0.849,0.94,0.975,0.46,0.579,0.569,0.217,0.237,0.248,0.149,0.165,0.174,75,fnn
75_gat,0.424,0.443,0.459,0.551,0.569,0.577,1.42,1.442,1.349,0.888,0.894,0.815,0.364,0.376,0.357,0.253,0.259,0.245,75,gat
75_gcn,0.269,0.339,0.254,0.403,0.491,0.395,1.49,1.445,1.513,1.052,0.94,1.08,0.368,0.363,0.382,0.283,0.263,0.296,75,gcn
40_sage,0.974,0.613,0.578,0.981,0.713,0.683,0.14,0.928,0.957,0.019,0.461,0.507,0.075,0.236,0.248,0.063,0.15,0.164,40,sage
40_fnn,0.524,0.576,0.494,0.643,0.687,0.616,0.87,0.786,0.928,0.489,0.409,0.537,0.224,0.213,0.236,0.154,0.142,0.164,40,fnn
40_gat,0.466,0.49,0.41,0.589,0.617,0.541,1.217,1.244,1.318,0.727,0.713,0.829,0.311,0.311,0.331,0.211,0.206,0.234,40,gat
40_gcn,0.278,0.232,0.276,0.346,0.287,0.361,1.164,1.146,1.198,0.887,0.913,0.904,0.29,0.289,0.301,0.235,0.242,0.242,40,gcn
10_fnn,0.651,0.605,0.463,0.739,0.714,0.59,0.734,0.707,0.958,0.348,0.357,0.574,0.187,0.163,0.243,0.128,0.115,0.171,10,fnn
10_sage,0.977,0.534,0.46,0.983,0.643,0.594,0.132,1.035,1.216,0.017,0.571,0.719,0.078,0.259,0.305,0.067,0.17,0.208,10,sage


In [11]:
# Round once, export to Excel, then display the same table.
sym_reg_table = r['sym_network_reg'].round(3)
sym_reg_table.to_excel('sym_reg.xlsx')
sym_reg_table

Unnamed: 0_level_0,train_mcc,val_mcc,test_mcc,train_acc,val_acc,test_acc,train_rmse,val_rmse,test_rmse,train_mae,val_mae,test_mae,train_rmse_perc,val_rmse_perc,test_rmse_perc,train_mae_perc,val_mae_perc,test_mae_perc,group,model
uid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
10_sage_dist,0.802,0.669,0.654,0.85,0.767,0.739,0.418,0.483,0.554,0.158,0.233,0.276,0.084,0.106,0.121,0.068,0.079,0.087,10,sage_dist
10_sage_scale,0.769,0.575,0.576,0.825,0.7,0.68,0.474,0.548,0.62,0.192,0.3,0.341,0.109,0.144,0.136,0.089,0.119,0.104,10,sage_scale
10_sage_base,0.521,0.443,0.486,0.592,0.6,0.582,0.658,0.632,0.702,0.417,0.4,0.441,0.124,0.12,0.15,0.096,0.095,0.109,10,sage_base
10_gcn,0.6,0.444,0.491,0.692,0.6,0.607,0.599,0.816,0.723,0.325,0.467,0.434,0.151,0.206,0.175,0.119,0.16,0.132,10,gcn
10_fnn,0.158,0.377,0.231,0.333,0.533,0.4,0.917,0.753,0.906,0.725,0.5,0.67,0.212,0.174,0.21,0.182,0.144,0.177,10,fnn
10_gat,0.292,0.184,0.259,0.458,0.367,0.432,0.801,0.856,0.919,0.575,0.667,0.657,0.183,0.209,0.22,0.148,0.173,0.176,10,gat
40_sage_dist,0.942,0.789,0.728,0.956,0.842,0.796,0.209,0.398,0.469,0.044,0.158,0.209,0.056,0.078,0.086,0.045,0.057,0.062,40,sage_dist
40_sage_base,0.77,0.696,0.701,0.823,0.767,0.774,0.421,0.483,0.491,0.177,0.233,0.23,0.098,0.102,0.104,0.079,0.079,0.079,40,sage_base
40_sage_scale,0.889,0.773,0.752,0.915,0.825,0.808,0.292,0.418,0.463,0.085,0.175,0.199,0.093,0.098,0.108,0.074,0.077,0.08,40,sage_scale
40_gcn,0.519,0.402,0.502,0.629,0.542,0.621,0.668,0.847,0.7,0.396,0.533,0.414,0.164,0.223,0.166,0.129,0.163,0.128,40,gcn


In [12]:
# Replace the in-memory `col_results` with the copy saved on disk. Unpickling can
# execute arbitrary code, so only load a results file this notebook produced.
results_pickle_path = 'results.pickle'
col_results = load_pickle(results_pickle_path)

In [15]:
# Inspect the ResultCollection for the 'sym_network' classification runs.
sym_clas_collection, _sym_clas_metric = col_results['sym_network_clas']
sym_clas_collection

<models.results.ResultCollection at 0x2de49a935b0>