In [1]:
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import pandas as pd
import numpy as np
from tqdm import tqdm

## Dataset

In [2]:
PATH_TO_DATA_FOLDER = "../Dataset time-series/data/Yemen/"

In [3]:
# Read the training table: the CSV has a two-row column header and the dates in its first column.
freq = "D"
df = pd.read_csv(PATH_TO_DATA_FOLDER + "all_train.csv", header = [0, 1], index_col = 0)
# Turn the index into timestamps, label it, then attach the frequency declared above.
df.index = pd.to_datetime(df.index)
df.index.name = "Datetime"
df.index.freq = freq

In [4]:
df

AdminStrata,Abyan,Abyan,Abyan,Abyan,Abyan,Abyan,Abyan,Abyan,Abyan,Abyan,...,Taizz,Taizz,Taizz,Taizz,Taizz,Taizz,Taizz,Taizz,Taizz,Taizz
Indicator,1 Month Anomaly (%) Rainfall,3 Months Anomaly (%) Rainfall,Cereals and tubers,Exchange rate (USD/LCU),FCS,Fatality,Lat,Lon,NDVI Anomaly,Population,...,Exchange rate (USD/LCU),FCS,Fatality,Lat,Lon,NDVI Anomaly,Population,Rainfall (mm),Ramadan,rCSI
Datetime,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
2018-08-22,273.127727,330.007392,0.328827,0.370033,31.809805,10.0,13.704878,46.158142,3051.954690,615154.0,...,0.370157,39.043078,171.0,13.416517,43.778161,3234.545997,3065034.0,41.6606,0.0,50.919038
2018-08-23,272.395637,327.547664,0.330076,0.372509,32.374660,10.0,13.704878,46.158142,3048.945680,615154.0,...,0.372918,39.371670,147.0,13.416517,43.778161,3229.722645,3065034.0,42.6437,0.0,50.415852
2018-08-24,271.663548,325.087935,0.331326,0.374986,33.772110,10.0,13.704878,46.158142,3045.883031,615154.0,...,0.375680,36.662083,145.0,13.416517,43.778161,3224.674241,3065034.0,43.6268,0.0,49.391298
2018-08-25,270.931459,322.628207,0.332575,0.377463,34.533738,10.0,13.704878,46.158142,3042.766744,615154.0,...,0.378441,37.205170,156.0,13.416517,43.778161,3219.400787,3065034.0,44.6099,0.0,50.302392
2018-08-26,270.199369,320.168478,0.333825,0.379939,32.327892,10.0,13.704878,46.158142,3039.596818,615154.0,...,0.381202,37.025723,164.0,13.416517,43.778161,3213.902282,3065034.0,45.5930,0.0,50.293046
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2020-05-27,483.374576,503.090907,0.498504,0.545873,36.877190,110.0,13.704878,46.158142,3854.300870,615154.0,...,0.542802,28.549600,128.0,13.416517,43.778161,3868.085212,3065034.0,51.6261,27.0,47.568770
2020-05-28,486.728971,506.709718,0.498504,0.545873,36.669190,131.0,13.704878,46.158142,3857.571292,615154.0,...,0.542802,27.438180,129.0,13.416517,43.778161,3866.480959,3065034.0,51.2398,26.0,47.568770
2020-05-29,490.083367,510.328529,0.498504,0.545873,37.472700,131.0,13.704878,46.158142,3860.690117,615154.0,...,0.542802,28.970260,116.0,13.416517,43.778161,3864.253975,3065034.0,50.8535,25.0,46.352410
2020-05-30,493.437762,513.947340,0.498504,0.545873,37.743430,134.0,13.704878,46.158142,3863.657346,615154.0,...,0.542802,28.549380,105.0,13.416517,43.778161,3861.404262,3065034.0,50.4672,24.0,45.931540


In [5]:
# Distinct labels of the first column level, selected by its name ("AdminStrata").
PROVINCES = df.columns.get_level_values("AdminStrata").unique()
PROVINCES

Index(['Abyan', 'Aden', 'Al Bayda', 'Al Dhale'e', 'Al Hudaydah', 'Al Jawf',
       'Al Maharah', 'Al Mahwit', 'Amanat Al Asimah', 'Amran', 'Dhamar',
       'Hajjah', 'Ibb', 'Lahj', 'Marib', 'Raymah', 'Sa'ada', 'Sana'a',
       'Shabwah', 'Taizz'],
      dtype='object', name='AdminStrata')

In [6]:
# Distinct labels of the second column level, selected by its name ("Indicator").
PREDICTORS = df.columns.get_level_values("Indicator").unique()
PREDICTORS

Index(['1 Month Anomaly (%) Rainfall', '3 Months Anomaly (%) Rainfall',
       'Cereals and tubers', 'Exchange rate (USD/LCU)', 'FCS', 'Fatality',
       'Lat', 'Lon', 'NDVI Anomaly', 'Population', 'Rainfall (mm)', 'Ramadan',
       'rCSI'],
      dtype='object', name='Indicator')

In [7]:
from ste import STE

# Lag importance

In [8]:
# Endogenous variable only: take the "FCS" cross-section of the second column level.
# drop_level = False keeps both column levels, so columns stay (AdminStrata, "FCS") pairs.
df_fcs = df.xs(key = "FCS", axis = "columns", level = 1, drop_level = False)
df_fcs.head()

AdminStrata,Abyan,Aden,Al Bayda,Al Dhale'e,Al Hudaydah,Al Jawf,Al Maharah,Al Mahwit,Amanat Al Asimah,Amran,Dhamar,Hajjah,Ibb,Lahj,Marib,Raymah,Sa'ada,Sana'a,Shabwah,Taizz
Indicator,FCS,FCS,FCS,FCS,FCS,FCS,FCS,FCS,FCS,FCS,FCS,FCS,FCS,FCS,FCS,FCS,FCS,FCS,FCS,FCS
Datetime,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2
2018-08-22,31.809805,19.412643,49.318569,38.415584,17.902243,35.067187,18.083183,25.745639,24.075777,36.360875,30.204461,32.363844,41.695435,50.265111,48.752228,51.354731,29.250284,30.901922,33.048813,39.043078
2018-08-23,32.37466,21.149675,51.034483,37.643521,18.287813,32.744186,11.840689,25.433362,25.171302,35.571979,29.365854,32.252939,40.600909,49.943757,48.817673,52.44785,29.93026,31.738683,31.034483,39.37167
2018-08-24,33.77211,25.506867,49.224466,39.629049,19.172334,36.061269,14.88498,26.628819,23.225682,38.183475,30.981888,33.020252,38.79717,48.953202,47.548161,51.786465,31.395349,32.759045,25.936048,36.662083
2018-08-25,34.533738,24.338942,48.076367,40.495283,21.687916,40.400411,16.566265,25.19826,23.574934,39.430147,33.146592,33.111702,38.377358,46.274738,46.073439,51.113811,30.455291,33.138993,28.135259,37.20517
2018-08-26,32.327892,24.294671,48.911223,40.991926,21.371394,35.632689,20.193152,25.808842,24.204882,45.047022,29.52953,32.469268,37.981147,44.569574,41.661668,48.916896,25.828248,33.035913,28.347996,37.025723


In [9]:
lags = 20

def history_length_Y(serie, max_lag = None):
    """Evaluate `STE.entropy_rate` on one series for every history length 1..max_lag.

    Parameters
    ----------
    serie : pd.Series
        The column handed over by `DataFrame.apply` / `progress_apply`.
    max_lag : int, optional
        Largest history length `k` to evaluate. When omitted, the notebook-level
        `lags` value is read at call time (the behaviour of the original code).

    Returns
    -------
    list
        `STE.entropy_rate(serie, m = 3, k = k)` for k = 1, ..., max_lag, in that order.
    """
    if max_lag is None:
        # Fall back to the global so existing `progress_apply(history_length_Y)` calls are unchanged.
        max_lag = lags
    return [STE.entropy_rate(serie, m = 3, k = k) for k in range(1, max_lag + 1)]

# Register `progress_apply` on pandas objects, then evaluate every FCS column.
tqdm.pandas()
df_results_Y = df_fcs.progress_apply(history_length_Y)
# Shift the row labels by one so that they start at lag 1 instead of 0.
df_results_Y.index += 1
df_results_Y.head()

100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:15<00:00,  1.32it/s]


AdminStrata,Abyan,Aden,Al Bayda,Al Dhale'e,Al Hudaydah,Al Jawf,Al Maharah,Al Mahwit,Amanat Al Asimah,Amran,Dhamar,Hajjah,Ibb,Lahj,Marib,Raymah,Sa'ada,Sana'a,Shabwah,Taizz
Indicator,FCS,FCS,FCS,FCS,FCS,FCS,FCS,FCS,FCS,FCS,FCS,FCS,FCS,FCS,FCS,FCS,FCS,FCS,FCS,FCS
1,1.447442,1.471874,1.480932,1.438884,1.41099,1.466584,1.477994,1.469904,1.423165,1.44646,1.482986,1.438975,1.440279,1.427457,1.473366,1.433772,1.458303,1.471596,1.478427,1.482024
2,1.418601,1.430676,1.461778,1.425798,1.383435,1.430214,1.449241,1.442238,1.407213,1.400857,1.463248,1.422195,1.419345,1.394567,1.439141,1.405206,1.43219,1.443888,1.446682,1.443498
3,1.312922,1.340119,1.36108,1.35034,1.296553,1.33741,1.366129,1.366367,1.307153,1.308935,1.403087,1.305609,1.306903,1.32782,1.336615,1.281924,1.337758,1.33517,1.363872,1.334513
4,1.095396,1.141668,1.137752,1.093005,1.092344,1.097571,1.095603,1.138362,1.060439,1.063219,1.14582,1.06404,1.087776,1.108813,1.11278,1.04567,1.10977,1.111445,1.130532,1.086956
5,0.760458,0.765101,0.730201,0.73377,0.82673,0.723279,0.756863,0.756622,0.714431,0.745536,0.718986,0.754394,0.782317,0.770436,0.735356,0.757025,0.749989,0.721619,0.72137,0.724947


In [10]:
import seaborn as sns
from ipywidgets import interact, widgets, fixed

def plot_df(name, df):
    """Show an interactive Plotly figure with one line per column of `df[name]`.

    Parameters
    ----------
    name : hashable
        Column label used to select `df[name]`; it is also used as the figure title.
    df : pd.DataFrame
        Frame for which `df[name]` is itself a frame (e.g. two-level columns). Its
        index is drawn on the x axis, which is labelled "Lags".
    """
    group = df[name]

    # Size the palette to the number of traces. Plotly cycles through the colorway,
    # so with a fixed 8 colours a ninth line would repeat the colour of the first one.
    n_colors = max(8, len(group.columns))
    colorway = sns.color_palette("hls", n_colors).as_hex()

    fig = go.Figure(layout = go.Layout(colorway = colorway))

    for column in group.columns:
        fig.add_trace(go.Scatter(x = group.index, y = group[column], name = column, mode = "lines",
                                 showlegend = True, line = dict(width = 1.5)))

    # Title (horizontally centred) and x-axis label in a single layout update.
    fig.update_layout(title = dict(text = name, y = 0.9, x = 0.5),
                      xaxis_title = dict(text = "Lags"))

    fig.show()

In [11]:
# One toggle button per first-level column label; choosing one calls `plot_df` on df_results_Y.
w = widgets.ToggleButtons(
    options = df_results_Y.columns.levels[0],
    description = "Adminstrata:",
    disabled = False,
)
p = interact(plot_df, name = w, df = fixed(df_results_Y))

interactive(children=(ToggleButtons(description='Adminstrata:', options=('Abyan', 'Aden', 'Al Bayda', "Al Dhal…

In [12]:
# Exogenous indicators only: remove the FCS target and the static features
# (Lat, Lon, Population) from the second column level of every AdminStrata.
df_no_fcs = df.drop(columns = ["FCS", "Lat", "Lon", "Population"], level = 1)
df_no_fcs.head()

AdminStrata,Abyan,Abyan,Abyan,Abyan,Abyan,Abyan,Abyan,Abyan,Abyan,Aden,...,Shabwah,Taizz,Taizz,Taizz,Taizz,Taizz,Taizz,Taizz,Taizz,Taizz
Indicator,1 Month Anomaly (%) Rainfall,3 Months Anomaly (%) Rainfall,Cereals and tubers,Exchange rate (USD/LCU),Fatality,NDVI Anomaly,Rainfall (mm),Ramadan,rCSI,1 Month Anomaly (%) Rainfall,...,rCSI,1 Month Anomaly (%) Rainfall,3 Months Anomaly (%) Rainfall,Cereals and tubers,Exchange rate (USD/LCU),Fatality,NDVI Anomaly,Rainfall (mm),Ramadan,rCSI
Datetime,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
2018-08-22,273.127727,330.007392,0.328827,0.370033,10.0,3051.95469,12.9544,0.0,37.42188,316.119821,...,34.3022,297.556947,434.194043,0.349354,0.370157,171.0,3234.545997,41.6606,0.0,50.919038
2018-08-23,272.395637,327.547664,0.330076,0.372509,10.0,3048.94568,13.0949,0.0,37.346207,315.989157,...,36.742499,300.172125,429.447759,0.349575,0.372918,147.0,3229.722645,42.6437,0.0,50.415852
2018-08-24,271.663548,325.087935,0.331326,0.374986,10.0,3045.883031,13.2354,0.0,37.999408,315.858492,...,35.857748,302.787303,424.701475,0.349796,0.37568,145.0,3224.674241,43.6268,0.0,49.391298
2018-08-25,270.931459,322.628207,0.332575,0.377463,10.0,3042.766744,13.3759,0.0,36.068234,315.727828,...,34.658434,305.402481,419.955191,0.350017,0.378441,156.0,3219.400787,44.6099,0.0,50.302392
2018-08-26,270.199369,320.168478,0.333825,0.379939,10.0,3039.596818,13.5164,0.0,38.218464,315.597164,...,35.91308,308.017659,415.208908,0.350238,0.381202,164.0,3213.902282,45.593,0.0,50.293046


In [13]:
lags = 180

def history_length_X(serie, max_lag = None):
    """Evaluate `STE.calc_ste` between one indicator and its AdminStrata's FCS for lags 1..max_lag.

    Parameters
    ----------
    serie : pd.Series
        Column handed over by `DataFrame.apply` / `progress_apply`; the first element
        of `serie.name` is used as the AdminStrata key into the notebook-level `df_fcs`.
    max_lag : int, optional
        Largest `kx` to evaluate. When omitted, the notebook-level `lags` value is
        read at call time (the behaviour of the original code).

    Returns
    -------
    list
        `STE.calc_ste(serie, target, m = 3, kx = k, ky = 1)` for k = 1, ..., max_lag.
    """
    if max_lag is None:
        # Fall back to the global so existing `progress_apply(history_length_X)` calls are unchanged.
        max_lag = lags
    adminstrata = serie.name[0]
    # The FCS column does not depend on the lag: look it up once, not on every iteration.
    target = df_fcs[adminstrata]["FCS"]
    return [STE.calc_ste(serie, target, m = 3, kx = k, ky = 1) for k in range(1, max_lag + 1)]

# Register `progress_apply` on pandas objects, then evaluate every exogenous column.
# NOTE: the recorded run of this cell took about 56 minutes.
tqdm.pandas()
df_results_X = df_no_fcs.progress_apply(history_length_X)
# Shift the row labels by one so that they start at lag 1 instead of 0.
df_results_X.index += 1
df_results_X.head()

100%|████████████████████████████████████████████████████████████████████████████████| 180/180 [56:35<00:00, 18.86s/it]


AdminStrata,Abyan,Abyan,Abyan,Abyan,Abyan,Abyan,Abyan,Abyan,Abyan,Aden,...,Shabwah,Taizz,Taizz,Taizz,Taizz,Taizz,Taizz,Taizz,Taizz,Taizz
Indicator,1 Month Anomaly (%) Rainfall,3 Months Anomaly (%) Rainfall,Cereals and tubers,Exchange rate (USD/LCU),Fatality,NDVI Anomaly,Rainfall (mm),Ramadan,rCSI,1 Month Anomaly (%) Rainfall,...,rCSI,1 Month Anomaly (%) Rainfall,3 Months Anomaly (%) Rainfall,Cereals and tubers,Exchange rate (USD/LCU),Fatality,NDVI Anomaly,Rainfall (mm),Ramadan,rCSI
1,0.042594,0.037121,0.039075,0.020509,0.059726,0.023402,0.044402,0.019067,0.066331,0.036163,...,0.072714,0.048549,0.053239,0.035383,0.031334,0.047913,0.0361,0.047341,0.030115,0.088939
2,0.068722,0.057437,0.053301,0.036206,0.104562,0.0392,0.098261,0.029227,0.256069,0.056524,...,0.260138,0.084165,0.073576,0.057196,0.047391,0.222504,0.060663,0.089768,0.039824,0.279563
3,0.101525,0.079855,0.076327,0.047228,0.201309,0.053074,0.138407,0.033421,0.566574,0.076103,...,0.585155,0.118331,0.111752,0.079579,0.062529,0.494737,0.085909,0.129731,0.048326,0.581052
4,0.124783,0.108614,0.093525,0.05578,0.270242,0.065578,0.191104,0.040155,0.885807,0.099596,...,0.959168,0.155767,0.138686,0.104706,0.077963,0.810982,0.11213,0.169541,0.05638,0.943797
5,0.149979,0.124409,0.121509,0.07027,0.324171,0.076856,0.229756,0.049489,1.16406,0.124861,...,1.192155,0.194448,0.164771,0.126879,0.091129,1.056742,0.133972,0.21453,0.063942,1.178371


In [14]:
# One toggle button per first-level column label; choosing one calls `plot_df` on df_results_X.
w = widgets.ToggleButtons(
    options = df_results_X.columns.levels[0],
    description = "Adminstrata:",
    disabled = False,
)
p = interact(plot_df, name = w, df = fixed(df_results_X))

interactive(children=(ToggleButtons(description='Adminstrata:', options=('Abyan', 'Aden', 'Al Bayda', "Al Dhal…

In [15]:
def quantile_25(x):
    """Return the 0.25 quantile of `x`, computed by its own `quantile` method."""
    lower_quartile = x.quantile(0.25)
    return lower_quartile

def quantile_75(x):
    """Return the 0.75 quantile of `x`, computed by its own `quantile` method."""
    upper_quartile = x.quantile(0.75)
    return upper_quartile

In [16]:
# Create figure.
fig = go.Figure()

# Seeded generator used only for trace colours, so the figure looks the same on every run.
_COLOR_RNG = np.random.default_rng(0)

def plot_quantiles(group):
    """Add the row-wise mean and the 25%-75% band of `group` to the notebook-level `fig`.

    Parameters
    ----------
    group : pd.DataFrame
        Columns to summarise row by row. `group.name` is used as the legend entry and
        legend group of the three traces that are added.

    Notes
    -----
    Mutates the global `fig` (three new traces plus axis titles); returns nothing.
    """
    # "mean" as a string gives the same "mean" column as np.mean did, without the
    # FutureWarning recent pandas versions raise for NumPy callables in `agg`.
    statistics_group = group.agg(["mean", quantile_25, quantile_75], axis = "columns")

    # Random colour, formatted from plain Python ints. Building the string with
    # str(tuple(...)) of NumPy scalars yields "rgb(np.int64(12), ...)" on NumPy >= 2.
    red, green, blue = (int(channel) for channel in _COLOR_RNG.integers(0, 256, size = 3))
    color_mean = "rgb({}, {}, {})".format(red, green, blue)
    color_quantile = "rgba({}, {}, {}, 0.3)".format(red, green, blue)

    # Plot mean.
    fig.add_trace(go.Scatter(x = statistics_group.index, y = statistics_group["mean"], mode = "lines",
                             name = group.name, legendgroup = group.name,
                             line = dict(width = 1.5, color = color_mean)))
    # Plot quantiles: the lower bound first, then the upper bound filled down to it ("tonexty").
    fig.add_trace(go.Scatter(x = statistics_group.index, y = statistics_group["quantile_25"],
                             legendgroup = group.name, showlegend = False, fill = None, mode = "lines",
                             fillcolor = color_quantile, line = dict(width = .2, color = color_quantile)))
    fig.add_trace(go.Scatter(x = statistics_group.index, y = statistics_group["quantile_75"],
                             legendgroup = group.name, showlegend = False, fill = "tonexty", mode = "lines",
                             fillcolor = color_quantile, line = dict(width = .2, color = color_quantile)))

    fig.update_layout(xaxis_title = dict(text = "Lags", font = dict(size = 15)),
                      yaxis_title = dict(text = "Symbolic Transfer Entropy", font = dict(size = 15)))
    
# One mean/quantile band per indicator, pooling that indicator's columns over all AdminStrata.
# `DataFrame.groupby(axis = 1)` is deprecated, so each indicator's columns are selected
# explicitly; sorted() keeps the key order the groupby used.
indicator_labels = df_results_X.columns.get_level_values(1)
for indicator in sorted(indicator_labels.unique()):
    group = df_results_X.loc[:, indicator_labels == indicator]
    # plot_quantiles reads the legend label from `group.name`, as groupby-apply provided it.
    group.name = indicator
    plot_quantiles(group)
fig.show()

# Feature importance

In [17]:
# Delete static features (Lat, Lon, Population) from the second column level; FCS is kept.
# `drop` returns a new frame and leaves `df` untouched, so no prior `copy()` is needed.
df_feature_importance = df.drop(columns = ["Lat", "Lon", "Population"], level = 1)
df_feature_importance.head()

AdminStrata,Abyan,Abyan,Abyan,Abyan,Abyan,Abyan,Abyan,Abyan,Abyan,Abyan,...,Taizz,Taizz,Taizz,Taizz,Taizz,Taizz,Taizz,Taizz,Taizz,Taizz
Indicator,1 Month Anomaly (%) Rainfall,3 Months Anomaly (%) Rainfall,Cereals and tubers,Exchange rate (USD/LCU),FCS,Fatality,NDVI Anomaly,Rainfall (mm),Ramadan,rCSI,...,1 Month Anomaly (%) Rainfall,3 Months Anomaly (%) Rainfall,Cereals and tubers,Exchange rate (USD/LCU),FCS,Fatality,NDVI Anomaly,Rainfall (mm),Ramadan,rCSI
Datetime,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
2018-08-22,273.127727,330.007392,0.328827,0.370033,31.809805,10.0,3051.95469,12.9544,0.0,37.42188,...,297.556947,434.194043,0.349354,0.370157,39.043078,171.0,3234.545997,41.6606,0.0,50.919038
2018-08-23,272.395637,327.547664,0.330076,0.372509,32.37466,10.0,3048.94568,13.0949,0.0,37.346207,...,300.172125,429.447759,0.349575,0.372918,39.37167,147.0,3229.722645,42.6437,0.0,50.415852
2018-08-24,271.663548,325.087935,0.331326,0.374986,33.77211,10.0,3045.883031,13.2354,0.0,37.999408,...,302.787303,424.701475,0.349796,0.37568,36.662083,145.0,3224.674241,43.6268,0.0,49.391298
2018-08-25,270.931459,322.628207,0.332575,0.377463,34.533738,10.0,3042.766744,13.3759,0.0,36.068234,...,305.402481,419.955191,0.350017,0.378441,37.20517,156.0,3219.400787,44.6099,0.0,50.302392
2018-08-26,270.199369,320.168478,0.333825,0.379939,32.327892,10.0,3039.596818,13.5164,0.0,38.218464,...,308.017659,415.208908,0.350238,0.381202,37.025723,164.0,3213.902282,45.593,0.0,50.293046


In [18]:
# Accumulators appended to by `feature_importance`: raw STE rows and their rank rows.
rows_ste = []
rows_ste_rank = []

def feature_importance(group):
    """Append the FCS row of `STE.compute_T(group, ...)` and its ranking to the global lists.

    Parameters
    ----------
    group : pd.DataFrame
        Columns of a single AdminStrata; `group.name` must hold that AdminStrata label.

    Notes
    -----
    Returns nothing: results are appended to `rows_ste` and `rows_ste_rank`.
    NOTE(review): the indexing below implies that `compute_T` returns a frame labelled
    by (AdminStrata, indicator) pairs on both axes — confirm against the `ste` module.
    """
    province = group.name
    transfer = STE.compute_T(group, m = 3, h = 1, kx = 1, ky = 1)

    # Row belonging to this AdminStrata's FCS.
    fcs_row = transfer.loc[province].loc["FCS"]
    # Remove the FCS-vs-FCS entry and keep only the indicator level in the labels.
    fcs_row = fcs_row.drop((province, "FCS"))
    fcs_row = fcs_row.droplevel(level = 0, axis = 0)
    fcs_row.name = province + " - FCS"

    rows_ste.append(fcs_row)
    # Descending rank: the largest value gets rank 1; ties share the highest rank number.
    rows_ste_rank.append(fcs_row.rank(method = "max", ascending = False))

# Run `feature_importance` once per AdminStrata, with a progress bar.
# `DataFrame.groupby(level = 0, axis = 1)` is deprecated, so each AdminStrata's columns
# are selected explicitly; sorted() keeps the key order the groupby used.
admin_labels = df_feature_importance.columns.get_level_values(0)
for adminstrata in tqdm(sorted(admin_labels.unique())):
    group = df_feature_importance.loc[:, admin_labels == adminstrata]
    # feature_importance reads the AdminStrata from `group.name`, as groupby-apply provided it.
    group.name = adminstrata
    feature_importance(group)
# Stack the collected rows: one row per "<AdminStrata> - FCS", one column per indicator.
T_fcs = pd.concat(rows_ste, axis = 1).transpose()
T_fcs_rank = pd.concat(rows_ste_rank, axis = 1).transpose()

100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:24<00:00,  1.25s/it]


In [19]:
# Heatmap of T_fcs: rows are the "<AdminStrata> - FCS" targets, columns the indicators.
# The hover text labels the y value "AdminStrata"; it was labelled "Indicator" like the x value.
fig = go.Figure(data = go.Heatmap(z = T_fcs, x = T_fcs.columns, y = T_fcs.index, xgap = 2, ygap = 2, hoverinfo = "x+y+z", colorbar = {"title": "STE"},
                                  colorscale = "Reds", reversescale = False, hovertemplate = "<br><b>Indicator</b>: %{x}<br>" +
                                  "<b>AdminStrata</b>: %{y}" + "<br><b>STE</b>: %{z:.3f}<br>", hoverlabel = dict(namelength = 0)))
# Reversed y axis keeps the first row of T_fcs at the top.
fig.update_layout(title = dict(text = "T - Influence different indicators for FCS"),
                  width = 500, height = 600,
                  yaxis = dict(autorange = "reversed", tickfont = dict(size = 10)),
                  xaxis = dict(tickfont = dict(size = 10)))
fig.show()

In [20]:
# One box per indicator: the distribution of its rank over the rows of T_fcs_rank.
fig = go.Figure(data = [go.Box(y = T_fcs_rank[col], name = col) for col in T_fcs_rank.columns])
# Reversed y axis puts rank 1 at the top.
fig.update_layout(yaxis = dict(autorange = "reversed",
                               title = dict(text = "Ranking position importance", font = dict(size = 15))))
fig.show()