In [1]:
import numpy as np
import warnings
warnings.filterwarnings('ignore')
import wandb
from pathlib import Path

from views_forecasts.extensions import *
from utils import fetch_data, transform_data, get_config_path, get_config_from_path, retrain_transformed_sweep, evaluate
from utils_map import plot_cm_map
# import importlib
# import utils  
# importlib.reload(utils)

import os
os.environ['WANDB_SILENT'] = 'true'

# Hyper-parameter names used (in this order) to build W&B run names for each
# model family. The key is the last '_'-separated token of a sweep config's
# file stem.
_SHARED_PARAS = ['transform', 'n_estimators', 'n_jobs', 'learning_rate', 'max_depth']
_XGB_STYLE_PARAS = _SHARED_PARAS + ['min_child_weight', 'subsample', 'colsample_bytree']
_GBM_STYLE_PARAS = _SHARED_PARAS + ['min_samples_split', 'min_samples_leaf']

PARA_DICT = {
    'rf': list(_XGB_STYLE_PARAS),   # separate copies so editing one entry never affects the other
    'xgb': list(_XGB_STYLE_PARAS),
    'gbm': list(_GBM_STYLE_PARAS),
}

In [2]:
# Level of analysis handed to fetch_data below.
# NOTE(review): 'cm' presumably means country-month — confirm against utils.fetch_data.
level = 'cm'
# Directory handed to get_config_path to locate the common / wandb / model / sweep configs.
config_path = Path('./my_config')

In [3]:
# Fetch the datasets once, then build one transformed copy per target transform.
# Both result dicts are keyed by transform name.
transforms = ['raw', 'log', 'normalize', 'standardize']
qslist, Datasets = fetch_data(level)

Datasets_transformed, para_transformed = {}, {}
for transform_name in transforms:
    transformed, transform_paras = transform_data(Datasets, transform_name, by_group=True)
    Datasets_transformed[transform_name] = transformed
    # Kept so evaluate() can be given the matching transform parameters later.
    para_transformed[transform_name] = transform_paras

Fetching query sets
Fetching datasets
 .     

In [4]:
def train():
    """Execute a single W&B sweep trial: name the run, retrain, evaluate.

    Intended to be passed as ``function=`` to ``wandb.agent``. It takes no
    arguments and reads these notebook-level globals, which must be defined
    before the agent starts: ``common_config``, ``wandb_config``,
    ``model_config``, ``sweep_paras``, ``Datasets_transformed`` and
    ``para_transformed``.

    The run is always closed. If training or evaluation raises, the run is
    finished with a non-zero exit code (so it is not left open or reported as
    successful) and the exception is re-raised.
    """
    run = wandb.init(config=common_config, project=wandb_config['project'], entity=wandb_config['entity'])
    try:
        wandb.config.update(model_config, allow_val_change=True)

        # Run name is "<para>_<value>" for every swept parameter, joined by "_".
        wandb.run.name = '_'.join(f'{para}_{run.config[para]}' for para in sweep_paras)

        retrain_transformed_sweep(Datasets_transformed, sweep_paras)
        evaluate('calib', para_transformed, by_group=True, plot_map=True)
    except Exception:
        run.finish(exit_code=1)
        raise
    else:
        run.finish()

In [5]:
# Resolve the four config locations under config_path, then load the two
# configs that are shared by every sweep (model/sweep configs are loaded per file later).
common_config_path, wandb_config_path, model_config_path, sweep_config_path = get_config_path(config_path)
common_config = get_config_from_path(common_config_path, 'common')
# wandb_config supplies the 'project' and 'entity' keys used by wandb.init / wandb.sweep.
wandb_config = get_config_from_path(wandb_config_path, 'wandb')

In [6]:
# Launch a W&B sweep for the first supported sweep configuration.
#
# Each sweep config must have a model config with the same file name in
# model_config_path. The model family is the last '_'-separated token of the
# file stem and selects the run-name parameters from PARA_DICT.
#
# sorted() makes the choice of "first" file deterministic; iterdir() order is
# filesystem-dependent.
for sweep_file in sorted(sweep_config_path.iterdir()):
    if not sweep_file.is_file():
        # Skip directories (e.g. .ipynb_checkpoints) instead of letting them
        # reach the `break` below and end the loop before any sweep has run.
        continue

    stem_parts = sweep_file.stem.split('_')
    # Check this before touching the model config so unsupported hurdle sweeps
    # are skipped cleanly; the length guard avoids IndexError on stems without '_'.
    if len(stem_parts) >= 2 and stem_parts[-2] == 'hurdle':
        continue # Currently Hurdle models are not supported

    model_file = model_config_path / sweep_file.name
    if not model_file.is_file():
        raise FileNotFoundError(f'The corresponding model configuration file {model_file} does not exist.')

    model = stem_parts[-1]
    if model not in PARA_DICT:
        raise KeyError(f'Unknown model type {model!r} in sweep file {sweep_file.name}; expected one of {sorted(PARA_DICT)}')
    sweep_paras = PARA_DICT[model]

    # sweep_config / model_config / sweep_paras are read as globals by train().
    sweep_config = get_config_from_path(sweep_file, 'sweep')
    model_config = get_config_from_path(model_file, 'model')

    sweep_id = wandb.sweep(sweep_config, project=wandb_config['project'],
                           entity=wandb_config['entity'])
    wandb.agent(sweep_id, function=train)
    print(f'Finish sweeping over model {sweep_file.stem}')

    # NOTE(review): only one sweep is run per execution, as in the original
    # cell. Remove this `break` to sweep over every supported configuration.
    break

Create sweep with ID: raqaib7u
Sweep URL: https://wandb.ai/model-development-and-deployment/add_map/sweeps/raqaib7u


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01112789675567506, max=1.0)…

{'n_estimators': 100, 'n_jobs': 12, 'learning_rate': 0.05, 'max_depth': 12, 'min_child_weight': 12, 'subsample': 0.5, 'colsample_bytree': 0.5}
Training model fatalities003_nl_baseline_rf
Calibration partition (log)
 * == Performing a run: "fatalities003_nl_baseline_rf_calib_transform_log_n_estimators_100_n_jobs_12_learning_rate_0.05_max_depth_12_min_child_weight_12_subsample_0.5_colsample_bytree_0.5" == * 
Model object named "fatalities003_nl_baseline_rf_calib_transform_log_n_estimators_100_n_jobs_12_learning_rate_0.05_max_depth_12_min_child_weight_12_subsample_0.5_colsample_bytree_0.5" with equivalent metadata already exists.
Fetching "fatalities003_nl_baseline_rf_calib_transform_log_n_estimators_100_n_jobs_12_learning_rate_0.05_max_depth_12_min_child_weight_12_subsample_0.5_colsample_bytree_0.5" from storage
Getting predictions
pr_56_cm_fatalities003_nl_baseline_rf_calib_transform_log_n_estimators_100_n_jobs_12_learning_rate_0.05_max_depth_12_min_child_weight_12_subsample_0.5_colsamp

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011141128233349364, max=1.0…

{'n_estimators': 100, 'n_jobs': 12, 'learning_rate': 0.05, 'max_depth': 12, 'min_child_weight': 12, 'subsample': 0.5, 'colsample_bytree': 0.5}
Training model fatalities003_nl_baseline_rf
Calibration partition (normalize)
 * == Performing a run: "fatalities003_nl_baseline_rf_calib_transform_normalize_n_estimators_100_n_jobs_12_learning_rate_0.05_max_depth_12_min_child_weight_12_subsample_0.5_colsample_bytree_0.5" == * 
Model object named "fatalities003_nl_baseline_rf_calib_transform_normalize_n_estimators_100_n_jobs_12_learning_rate_0.05_max_depth_12_min_child_weight_12_subsample_0.5_colsample_bytree_0.5" with equivalent metadata already exists.
Fetching "fatalities003_nl_baseline_rf_calib_transform_normalize_n_estimators_100_n_jobs_12_learning_rate_0.05_max_depth_12_min_child_weight_12_subsample_0.5_colsample_bytree_0.5" from storage
Getting predictions
pr_56_cm_fatalities003_nl_baseline_rf_calib_transform_normalize_n_estimators_100_n_jobs_12_learning_rate_0.05_max_depth_12_min_child_w

In [11]:
# Identify the stored prediction set to inspect.
run_id = 'Fatalities003'
name = 'cm_fatalities003_nl_baseline_rf_calib_transform_raw_n_estimators_100_n_jobs_15_learning_rate_0.05_max_depth_12_min_child_weight_12_subsample_0.5_colsample_bytree_0.5'

# Columns to keep: the 'ged_sb_dep' column followed by one prediction column per step 1..36.
steps = list(range(1, 37))
stepcols = ['ged_sb_dep'] + [f'step_pred_{s}' for s in steps]

In [12]:
# Load the stored predictions, replace +/-inf with 0, and keep only the
# 'ged_sb_dep' column plus the per-step prediction columns.
# NOTE(review): `pd` is not imported explicitly in this notebook; presumably it
# arrives via `from views_forecasts.extensions import *` — confirm.
df = (
    pd.DataFrame.forecasts
    .read_store(run=run_id, name=name)
    .replace([np.inf, -np.inf], 0)
    [stepcols]
)
df

pr_56_cm_fatalities003_nl_baseline_rf_calib_transform_raw_n_estimators_100_n_jobs_15_learning_rate_0.05_max_depth_12_min_child_weight_12_subsample_0.5_colsample_bytree_0.5.parquet


Unnamed: 0_level_0,Unnamed: 1_level_0,ged_sb_dep,step_pred_1,step_pred_2,step_pred_3,step_pred_4,step_pred_5,step_pred_6,step_pred_7,step_pred_8,step_pred_9,...,step_pred_27,step_pred_28,step_pred_29,step_pred_30,step_pred_31,step_pred_32,step_pred_33,step_pred_34,step_pred_35,step_pred_36
month_id,country_id,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
397,1,0.0,1.651095,1.689242,1.657279,0.913426,0.979597,1.065750,0.787519,0.919267,0.836542,...,0.698809,0.687637,0.650110,0.653180,0.696426,0.604014,0.565468,0.674215,0.576182,0.701868
397,2,0.0,3.204240,3.025700,2.939039,1.457489,1.581790,1.575650,1.085819,1.564273,1.311603,...,0.733626,0.694635,0.608645,0.702105,0.721626,0.617485,0.519229,0.568201,0.610041,0.778324
397,3,0.0,2.058623,1.834005,1.912669,1.009168,1.172435,1.024413,0.774664,1.049518,0.965083,...,0.619093,0.600686,0.616504,0.580385,0.594178,0.592638,0.520877,0.557706,0.587092,0.705077
397,4,0.0,0.555807,0.557101,0.612779,0.613710,0.634974,0.942594,0.608548,0.648162,0.608087,...,0.787829,0.899514,1.137710,0.772197,0.844141,0.973165,1.078962,1.244385,1.054108,0.920116
397,5,0.0,3.265088,3.040656,3.225754,1.547952,1.757818,1.631098,1.095765,1.619773,1.376115,...,0.743519,0.729198,0.627620,0.705535,0.750535,0.639610,0.527611,0.581906,0.630597,0.788689
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
444,242,0.0,0.952560,2.511681,0.899215,2.616912,1.150216,0.674771,0.804736,0.797980,0.836487,...,1.489074,1.951522,1.318421,0.960552,1.177219,0.836330,1.184410,0.849898,0.951530,1.018249
444,243,0.0,0.888121,0.780973,0.772871,0.718171,0.822518,0.779895,0.784185,0.843954,0.843327,...,0.963516,1.062782,0.719849,0.939175,0.949503,0.844645,0.999449,0.838120,1.003479,1.053536
444,244,0.0,0.514005,0.573902,0.520572,0.526588,0.546783,0.530654,0.552278,0.673362,0.531548,...,0.573374,0.681835,0.678342,0.662469,0.613124,0.594671,0.676165,0.609238,0.609334,0.627836
444,245,0.0,3.458892,5.573108,9.200547,5.286830,6.728735,7.581571,5.567369,6.420694,5.976310,...,4.409314,4.210153,3.821799,5.964530,5.077895,4.411976,5.496883,5.285884,3.480923,2.899676


In [9]:
# Apply exp(x) - 1 to the loaded frame and store the result under a NEW name.
# The original `df = np.exp(df) - 1` overwrote its own input, so every re-run of
# the cell compounded the exponential and later cells saw a different `df`
# depending on execution order.
# np.expm1 computes exp(x) - 1 without the precision loss of the explicit form near 0.
# NOTE(review): this is applied to every column, including 'ged_sb_dep', and the
# run loaded above is named '..._transform_raw_...' — confirm the frame is really
# on a log scale before relying on these values.
df_expm1 = np.expm1(df)
df_expm1

Unnamed: 0_level_0,Unnamed: 1_level_0,ged_sb_dep,step_pred_1,step_pred_2,step_pred_3,step_pred_4,step_pred_5,step_pred_6,step_pred_7,step_pred_8,step_pred_9,...,step_pred_27,step_pred_28,step_pred_29,step_pred_30,step_pred_31,step_pred_32,step_pred_33,step_pred_34,step_pred_35,step_pred_36
month_id,country_id,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
397,1,0.0,0.613933,0.616953,0.615754,0.614390,0.616320,0.616750,0.614932,0.620469,0.617564,...,0.617548,0.616562,0.615784,0.617495,0.619100,0.617310,0.613890,0.617663,0.614962,0.620047
397,2,0.0,0.628708,0.627454,0.626578,0.623116,0.627264,0.625510,0.623309,0.636764,0.630672,...,0.621717,0.615270,0.616696,0.617528,0.622300,0.619750,0.613505,0.619197,0.620475,0.622134
397,3,0.0,0.620750,0.618255,0.618701,0.618201,0.619136,0.617050,0.617961,0.624933,0.621965,...,0.616367,0.612105,0.616021,0.612847,0.615287,0.614563,0.611961,0.614759,0.615695,0.617963
397,4,0.0,0.623216,0.613123,0.613084,0.613172,0.614843,0.615930,0.614272,0.615191,0.614701,...,0.624097,0.625719,0.633003,0.623588,0.631281,0.633249,0.636581,0.634476,0.641765,0.637501
397,5,0.0,0.628965,0.627593,0.626561,0.623811,0.628646,0.624390,0.624241,0.636951,0.631354,...,0.622985,0.615514,0.616201,0.617878,0.622994,0.619938,0.613623,0.619270,0.621060,0.622389
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
444,242,0.0,0.627109,0.635177,0.629947,0.625829,0.630717,0.627193,0.628726,0.621645,0.630045,...,0.640423,0.655568,0.629049,0.635070,0.639029,0.633139,0.637537,0.638622,0.649304,0.653889
444,243,0.0,0.646165,0.636574,0.633543,0.626536,0.643574,0.631771,0.638303,0.638083,0.628973,...,0.641926,0.659257,0.626271,0.643949,0.646233,0.638871,0.646016,0.635770,0.646017,0.653868
444,244,0.0,0.614472,0.614351,0.612793,0.612464,0.612587,0.613666,0.612700,0.614742,0.613690,...,0.620062,0.620757,0.621573,0.620810,0.616627,0.620747,0.615645,0.619771,0.616820,0.616574
444,245,0.0,0.694161,0.798478,0.900139,0.834527,0.846576,0.877188,0.880774,0.885611,0.882882,...,0.781598,0.763405,0.765723,0.832082,0.844601,0.834185,0.850778,0.833401,0.790267,0.765193


In [13]:
# Per-column summary statistics for 'ged_sb_dep' and each step prediction column.
df.describe()

Unnamed: 0,ged_sb_dep,step_pred_1,step_pred_2,step_pred_3,step_pred_4,step_pred_5,step_pred_6,step_pred_7,step_pred_8,step_pred_9,...,step_pred_27,step_pred_28,step_pred_29,step_pred_30,step_pred_31,step_pred_32,step_pred_33,step_pred_34,step_pred_35,step_pred_36
count,9168.0,9168.0,9168.0,9168.0,9168.0,9168.0,9168.0,9168.0,9168.0,9168.0,...,9168.0,9168.0,9168.0,9168.0,9168.0,9168.0,9168.0,9168.0,9168.0,9168.0
mean,43.984293,2.763453,2.719436,2.763168,1.894036,1.928245,1.827149,1.621647,1.83769,1.649706,...,1.242454,1.237369,1.164332,1.202674,1.216925,1.16641,1.115325,1.138484,1.148242,1.224143
std,478.269169,3.672684,3.542995,3.772312,2.644334,2.367929,2.027407,1.869207,2.093655,1.758595,...,1.432982,1.447862,1.432327,1.397935,1.329776,1.451369,1.308437,1.410775,1.367916,1.252939
min,0.0,0.480216,0.480745,0.482173,0.490586,0.491545,0.485932,0.496379,0.491325,0.493941,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,0.0,0.748876,0.772803,0.729018,0.711627,0.79151,0.739163,0.752162,0.738416,0.753256,...,0.650071,0.647873,0.608584,0.632843,0.64985,0.600785,0.527729,0.576196,0.598334,0.699598
50%,0.0,2.634222,2.419861,2.716885,1.371664,1.39656,1.437981,1.095765,1.448651,1.228742,...,0.745319,0.72931,0.657638,0.716428,0.750603,0.643519,0.587826,0.602538,0.630868,0.793501
75%,0.0,3.265092,3.173449,3.235476,1.622635,1.757818,1.741029,1.335304,1.620372,1.421709,...,0.967016,0.993814,0.858438,0.96792,0.985358,0.851137,0.967359,0.89546,0.911281,1.062088
max,19000.0,46.281792,47.852463,49.803467,33.822624,28.254168,28.297398,17.863476,23.64448,18.066242,...,20.86784,22.295361,28.512068,24.516321,16.908724,20.12236,16.604708,29.344843,22.985306,15.077058


In [103]:
# Same column selection as before, now pointing at the topics RF prediction set.
run_id = 'Fatalities003'
name = 'cm_fatalities003_nl_topics_rf_calib_transform_raw_n_estimators_250_n_jobs_12_learning_rate_0.05_max_depth_12_min_child_weight_12_subsample_0.5_colsample_bytree_0.5'

# 'ged_sb_dep' first, then one prediction column for each of the 36 steps.
steps = list(range(1, 36 + 1))
stepcols = ['ged_sb_dep', *(f'step_pred_{s}' for s in steps)]

In [112]:
# Load the stored predictions for `name`, replace +/-inf with 0, and keep only
# the columns listed in stepcols.
df = (
    pd.DataFrame.forecasts
    .read_store(run=run_id, name=name)
    .replace([np.inf, -np.inf], 0)
    [stepcols]
)

pr_56_cm_fatalities003_nl_topics_rf_calib_transform_raw_n_estimators_250_n_jobs_12_learning_rate_0.05_max_depth_12_min_child_weight_12_subsample_0.5_colsample_bytree_0.5.parquet


In [113]:
# Display the loaded prediction frame (indexed by month_id, country_id).
df

Unnamed: 0_level_0,Unnamed: 1_level_0,ged_sb_dep,step_pred_1,step_pred_2,step_pred_3,step_pred_4,step_pred_5,step_pred_6,step_pred_7,step_pred_8,step_pred_9,...,step_pred_27,step_pred_28,step_pred_29,step_pred_30,step_pred_31,step_pred_32,step_pred_33,step_pred_34,step_pred_35,step_pred_36
month_id,country_id,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
397,1,0.0,1.888033,1.536575,2.049761,1.194238,1.080280,1.230281,1.012267,0.979016,1.124231,...,0.605951,0.616939,0.591293,0.696502,0.823638,0.557715,0.525643,0.541731,0.555350,0.671805
397,2,0.0,3.313491,2.442976,3.095902,1.791524,1.584468,1.561841,1.353563,1.191040,1.440145,...,0.726233,0.770196,0.647228,0.721609,0.706091,0.572721,0.510407,0.583434,0.579779,0.788106
397,3,0.0,1.746300,1.800268,1.744825,1.028135,1.015283,1.214884,0.834816,0.719306,0.796419,...,1.161782,1.024357,0.661682,0.779543,0.571230,0.505648,0.748100,0.514143,0.524082,0.612821
397,4,0.0,0.487359,0.500316,0.511034,0.510376,0.575729,0.510639,0.576125,0.510816,0.504026,...,0.665250,0.683974,0.645012,0.484309,1.083846,1.449021,1.139949,0.482186,0.479546,0.479713
397,5,0.0,3.312054,2.441470,3.093398,1.789339,1.576807,1.559876,1.353323,1.190878,1.439366,...,0.723306,0.757622,0.646705,0.720935,0.699926,0.572176,0.510458,0.581869,0.579295,0.787506
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
444,242,0.0,4.505231,5.732021,4.511056,11.456407,1.927837,1.705806,2.175097,7.634606,8.995864,...,9.225223,3.546402,3.188511,5.113416,3.803534,3.561557,3.248026,2.929720,2.788255,3.098599
444,243,0.0,0.564945,0.647136,0.593240,0.794530,0.965720,0.944105,0.807292,1.222221,1.019295,...,1.075368,1.527027,0.485269,0.516639,0.520443,0.820115,0.501266,0.590600,0.635062,1.199985
444,244,0.0,0.670239,1.084977,0.824759,0.710583,1.426395,1.099138,1.305220,1.311780,0.932088,...,2.428319,3.952767,0.813094,1.311723,0.843666,1.109788,2.646547,1.136765,1.150167,1.101795
444,245,0.0,5.004318,6.232212,8.742626,7.999347,5.930683,6.081397,5.129094,5.645001,6.278614,...,3.930102,2.930245,3.637518,4.189796,4.103630,3.408108,2.996001,3.379579,2.662481,2.405901


In [114]:
# Per-column summary statistics for 'ged_sb_dep' and each step prediction column.
df.describe()

Unnamed: 0,ged_sb_dep,step_pred_1,step_pred_2,step_pred_3,step_pred_4,step_pred_5,step_pred_6,step_pred_7,step_pred_8,step_pred_9,...,step_pred_27,step_pred_28,step_pred_29,step_pred_30,step_pred_31,step_pred_32,step_pred_33,step_pred_34,step_pred_35,step_pred_36
count,9168.0,9168.0,9168.0,9168.0,9168.0,9168.0,9168.0,9168.0,9168.0,9168.0,...,9168.0,9168.0,9168.0,9168.0,9168.0,9168.0,9168.0,9168.0,9168.0,9168.0
mean,43.984293,3.064313,2.938524,3.024721,2.463671,2.344805,2.241646,2.080354,2.044865,2.113336,...,1.541634,1.638868,1.574594,1.621234,1.630235,1.54555,1.448616,1.550858,1.50365,1.700572
std,478.269169,4.013247,3.939282,3.012343,2.652986,2.608635,2.326226,2.094101,2.217072,2.18277,...,1.607832,1.771869,1.72824,1.659796,1.791969,1.802667,1.751915,1.778156,1.798335,2.024134
min,0.0,0.476367,0.476422,0.476632,0.476452,0.476629,0.476453,0.476443,0.47649,0.476624,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,0.0,1.141073,1.428083,1.226387,1.061084,1.03064,1.07464,0.995747,0.931747,0.949853,...,0.71953,0.752138,0.645578,0.720144,0.699926,0.572176,0.51293,0.581869,0.579295,0.768986
50%,0.0,2.967103,2.444391,3.093398,1.789339,1.576807,1.629165,1.357512,1.201662,1.439366,...,0.785038,0.801496,0.758299,0.81774,0.767103,0.69922,0.666647,0.680501,0.691359,0.833531
75%,0.0,3.312054,2.83725,3.34623,2.211347,2.121895,2.07509,1.941592,2.056416,2.010623,...,1.74657,1.930422,1.836017,1.928235,1.854694,1.79206,1.655463,1.842366,1.660575,1.698188
max,19000.0,57.567913,58.776665,34.229748,25.622738,30.41069,36.324074,19.378551,22.624584,20.225647,...,19.464668,20.670267,17.030479,17.005161,19.832689,28.565922,27.920639,24.754251,28.713552,22.110678


In [141]:
# Values of the first index level (month_id in the frame displayed above) as a plain list.
months = df.index.levels[0].tolist()

In [147]:
# Render prediction maps with plot_cm_map and log them to W&B as images.

# All months / steps available in the frame. Kept as notebook variables; the
# loop below only uses the hand-picked subsets.
months = df.index.levels[0].tolist()
step_preds = [f'step_pred_{i}' for i in range(1, 37)]

# Subsets that are actually plotted and logged.
months_to_log = [399, 400]
steps_to_log = ['step_pred_1']

wandb.init(project='test', entity='model-development-and-deployment')
try:
    for month in months_to_log:
        for step in steps_to_log:
            fig = plot_cm_map(df, month, step)
            wandb.log({f'month_{month}_step_{step}': wandb.Image(fig)})
except Exception:
    # Close the run (marked as failed) instead of leaving it open when
    # plotting or logging raises.
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()