
# ViEWS 3 ensembles
## Fatalities project, cm level
This notebook evaluates constituent models and explores some very simple ensemble algorithms

## Importing modules

In [1]:
# Basics
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cbook as cbook
# sklearn
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.ensemble import AdaBoostRegressor
from sklearn import linear_model
from sklearn.metrics import mean_squared_error
# Views 3
from viewser.operations import fetch
from viewser import Queryset, Column
import views_runs
from views_partitioning import data_partitioner, legacy
from stepshift import views
import views_dataviz
from views_runs import storage, ModelMetadata
from views_runs.storage import store, retrieve, list, fetch_metadata
from views_forecasts.extensions import *

# Mapper
import geopandas as gpd

from views_dataviz.map import mapper, utils
from views_dataviz import color
from views_dataviz.map.presets import ViewsMap

import sqlalchemy as sa
from ingester3.config import source_db_path

import sqlalchemy as sa
from ingester3.config import source_db_path

# Other packages
import pickle as pkl

#Parallelization
from joblib import Parallel, delayed, cpu_count
from functools import partial
from genetic2 import *

from pathlib import Path

Refreshing


In [2]:
# Common parameters shared by the cells below
run_id = 'Fat_devel_v7'
FutureStart = 505
RunGeneticAlgo = False

# Which steps to train and predict for: 1, 2, ..., 36
steps = list(range(1, 37))
# Which steps to present feature importances for
fi_steps = [1, 3, 6, 12, 36]

# Specifying partitions: (first, last) bounds of the training and prediction periods
# (presumably month ids -- confirm against views_runs.DataPartitioner)
calib_partitioner_dict = {"train": (121, 396), "predict": (397, 444)}
test_partitioner_dict = {"train": (121, 444), "predict": (445, 492)}
future_partitioner_dict = {"train": (121, 492), "predict": (493, 504)}

calib_partitioner = views_runs.DataPartitioner({"calib": calib_partitioner_dict})
test_partitioner = views_runs.DataPartitioner({"test": test_partitioner_dict})
future_partitioner = views_runs.DataPartitioner({"future": future_partitioner_dict})

# Local output locations
Mydropbox = '/Users/havardhegre/Dropbox (ViEWS)/ViEWS'
overleafpath = '/Users/havardhegre/Dropbox (ViEWS)/Apps/Overleaf/ViEWS predicting fatalities/Tables/'


In [3]:
# Display settings: 3 decimal places, and never wrap repr(DataFrame) across additional lines
for option_name, option_value in (("display.precision", 3),
                                  ("display.expand_frame_repr", False)):
    pd.set_option(option_name, option_value)

In [5]:
# Built on script from Geoff Ruddock: https://geoffruddock.com/building-a-hurdle-regression-estimator-in-scikit-learn/

from typing import Optional, Union
import numpy as np
import pandas as pd

from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.base import BaseEstimator
from sklearn.utils.estimator_checks import check_estimator
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor
from lightgbm import LGBMClassifier, LGBMRegressor


class HurdleRegression(BaseEstimator):
    """ Regression model which handles excessive zeros by fitting a two-part model and combining predictions:
            1) binary classifier (is the outcome > 0?)
            2) continuous regression (fitted on the rows with outcome > 0 only)
    Implemented as a valid sklearn estimator, so it can be used in pipelines and GridSearch objects.
    Args:
        clf_name: key of the classifier in ``_resolve_estimator``: 'logistic', 'LGBMClassifier',
            'RFClassifier', 'GBMClassifier', 'XGBClassifier' or 'HGBClassifier'
        reg_name: key of the regressor in ``_resolve_estimator``: 'linear', 'LGBMRegressor',
            'RFRegressor', 'GBMRegressor', 'XGBRegressor' or 'HGBRegressor'
        clf_params: dict of parameters to pass to classifier sub-model when initialized
        reg_params: dict of parameters to pass to regression sub-model when initialized
    Attributes set by ``fit``:
        clf_, reg_: the fitted classifier and regressor
        clf_fi, reg_fi: ``feature_importances_`` of the sub-models, or an empty list when
            the sub-model does not expose that attribute
    """

    def __init__(self,
                 clf_name: str = 'logistic',
                 reg_name: str = 'linear',
                 clf_params: Optional[dict] = None,
                 reg_params: Optional[dict] = None):

        self.clf_name = clf_name
        self.reg_name = reg_name
        self.clf_params = clf_params
        self.reg_params = reg_params
        self.clf_fi = []
        self.reg_fi = []

    @staticmethod
    def _resolve_estimator(func_name: str):
        """ Lookup table for supported estimators.
        This is necessary because sklearn estimator default arguments
        must pass equality test, and instantiated sub-estimators are not equal.

        Every entry is a zero-argument factory, so only the requested estimator is
        constructed. Names used by other entries (the XGBoost classes, ``nj``) are
        therefore looked up only when that entry is actually requested.
        Raises KeyError for an unknown ``func_name``. """

        # NOTE(review): XGBRFRegressor, XGBRFClassifier, XGBRegressor, XGBClassifier,
        # HistGradientBoostingClassifier and nj are not defined in the visible part of
        # the notebook -- they must be in scope before the corresponding key is used.
        factories = {
            'linear': lambda: LinearRegression(),
            'logistic': lambda: LogisticRegression(solver='liblinear'),
            'LGBMRegressor': lambda: LGBMRegressor(n_estimators=100),
            'LGBMClassifier': lambda: LGBMClassifier(n_estimators=100),
            'RFRegressor': lambda: XGBRFRegressor(n_estimators=300, n_jobs=nj),
            'RFClassifier': lambda: XGBRFClassifier(n_estimators=300, n_jobs=nj),
            'GBMRegressor': lambda: GradientBoostingRegressor(n_estimators=200),
            'GBMClassifier': lambda: GradientBoostingClassifier(n_estimators=200),
            'XGBRegressor': lambda: XGBRegressor(n_estimators=200, tree_method='hist', n_jobs=nj),
            'XGBClassifier': lambda: XGBClassifier(n_estimators=200, tree_method='hist', n_jobs=nj,
                                                   use_label_encoder=False),
            'HGBRegressor': lambda: HistGradientBoostingRegressor(max_iter=200),
            'HGBClassifier': lambda: HistGradientBoostingClassifier(max_iter=200),
        }

        return factories[func_name]()

    @staticmethod
    def _importances_or_empty(estimator):
        """ Return ``estimator.feature_importances_`` or ``[]`` when the estimator has none. """
        return getattr(estimator, 'feature_importances_', [])

    def fit(self,
            X: Union[np.ndarray, pd.DataFrame],
            y: Union[np.ndarray, pd.Series]):
        """ Fit the classifier on ``y > 0`` and the regressor on the rows where ``y > 0``.

        Raises:
            ValueError: if X has fewer than two features or y has no positive value.
        Returns:
            self
        """
        X, y = check_X_y(X, y, dtype=None,
                         accept_sparse=False,
                         accept_large_sparse=False,
                         force_all_finite='allow-nan')

        if X.shape[1] < 2:
            raise ValueError('Cannot fit model when n_features = 1')

        positive = y > 0
        if not positive.any():
            # The regression stage would otherwise be fitted on zero rows.
            raise ValueError('Cannot fit the regression stage: y has no positive values')

        self.clf_ = self._resolve_estimator(self.clf_name)
        if self.clf_params:
            self.clf_.set_params(**self.clf_params)
        self.clf_.fit(X, positive)
        # Linear/logistic and HistGradientBoosting models have no feature_importances_.
        self.clf_fi = self._importances_or_empty(self.clf_)

        self.reg_ = self._resolve_estimator(self.reg_name)
        if self.reg_params:
            self.reg_.set_params(**self.reg_params)
        self.reg_.fit(X[positive], y[positive])
        self.reg_fi = self._importances_or_empty(self.reg_)

        self.is_fitted_ = True
        return self

    def predict_bck(self, X: Union[np.ndarray, pd.DataFrame]):
        """ Predict combined response using binary classification outcome """
        # NaN is accepted here because fit() accepts it as well.
        X = check_array(X, accept_sparse=False, accept_large_sparse=False,
                        force_all_finite='allow-nan')
        check_is_fitted(self, 'is_fitted_')
        return self.clf_.predict(X) * self.reg_.predict(X)

    def predict(self, X: Union[np.ndarray, pd.DataFrame]):
        """ Predict combined response using probabilistic classification outcome:
        P(y > 0) from the classifier times the regressor's prediction. """
        # NaN is accepted here because fit() accepts it as well.
        X = check_array(X, accept_sparse=False, accept_large_sparse=False,
                        force_all_finite='allow-nan')
        check_is_fitted(self, 'is_fitted_')
        return self.clf_.predict_proba(X)[:, 1] * self.reg_.predict(X)


def manual_test():
    """ Validate estimator using sklearn's provided utility and ensure it can fit and predict on fake dataset. """
    from sklearn.datasets import make_regression
    # check_estimator expects an estimator instance, not the class itself.
    check_estimator(HurdleRegression())
    X, y = make_regression()
    reg = HurdleRegression()
    reg.fit(X, y)
    reg.predict(X)


#if __name__ == '__main__':
#    manual_test()

## Retrieve models and predictions

In [6]:
#import pickle as pkl
#path = '/Users/havardhegre/temp/'
# RetrieveAll selects the local pickle; otherwise the Dropbox copy is read.
RetrieveAll = True

if RetrieveAll:
    localpath = '/Users/havardhegre/temp/'
    picklename = localpath + 'ModelList_cm_' + run_id + '.p'
else:
    path = '/Users/havardhegre/Dropbox (ViEWS)/ViEWS/Projects/PredictingFatalities/Predictions/'
    picklename = path + 'ModelList_db_cm_' + run_id + '.p'
# NOTE: pickle.load can execute arbitrary code -- only load files you produced yourself.
# The context manager closes the file handle, which the bare open() call never did.
with open(picklename, "rb") as picklefile:
    EnsembleList = pkl.load(picklefile)

In [7]:
# Print each model's position in EnsembleList next to its name
for i, model in enumerate(EnsembleList):
    print(i, model['modelname'])

0 fat_baseline_rf
1 fat_conflicthistory_srf
2 fat_conflicthistory_rf
3 fat_conflicthistory_gbm
4 fat_conflicthistory_hurdle_rf
5 fat_conflicthistory_hurdle_xgb
6 fat_conflicthistory_hurdle_lgb
7 fat_conflicthistory_histgbm
8 fat_conflicthistory_xgb
9 fat_conflicthistory_lgb
10 fat_conflicthistory_long_rf
11 fat_conflicthistory_long_xgb
12 fat_vdem_rf
13 fat_vdem_xgb
14 fat_vdem_hurdle_xgb
15 fat_wdi_rf
16 fat_wdi_xgb
17 fat_wdi_hurdle_xgb
18 fat_topics_rf
19 fat_topics_histgbm
20 fat_topics_xgb
21 fat_topics_hurdle_xgb
22 fat_prs_rf
23 fat_prs_xgb
24 fat_prs_hurdle_xgb
25 fat_broad_rf
26 fat_broad_xgb
27 fat_broad_hurdle_xgb
28 fat_greatest_hits_rf
29 fat_greatest_hits_xgb
30 fat_greatest_hits_hurdle_xgb
31 fat_greatest_hits_lgb
32 fat_hh20_srf
33 fat_hh20_rf
34 fat_hh20_gbm
35 fat_hh20_hurdle_rf
36 fat_hh20_hurdle_xgb
37 fat_hh20_hurdle_lgb
38 fat_hh20_histgbm
39 fat_hh20_xgb
40 fat_hh20_lgb
41 fat_all_pca3_xgb
42 fat_all_pca3_lgb
43 fat_topics_pca3_xgb
44 fat_topics_pca3_lgb
45 fat_v

In [8]:
# Calibrate to conform to mean and standard deviation

# Calibration function
def mean_sd_calibrated(y_true_calpart,y_pred_calpart,y_pred_test,shift, threshold=0):
    '''
    Calibrates predictions. Expects the input columns from calibration partition to be without infinity values

    Only values >= threshold take part: they define the statistics computed on the
    calibration partition, and only predictions >= threshold are rescaled/shifted.

    Args:
        y_true_calpart: observed outcomes in the calibration partition (pandas Series)
        y_pred_calpart: predictions in the calibration partition (pandas Series)
        y_pred_test: predictions to calibrate (pandas Series); not modified
        shift: if truthy, additionally add (mean of observed - mean of predicted) after expanding
        threshold: lower bound for values taking part in the calibration
    Returns:
        (calibrated_pred, expand, shiftsize); shiftsize is 0 when shift is falsy
    '''
    true_above = y_true_calpart.loc[y_true_calpart >= threshold]
    pred_above = y_pred_calpart.loc[y_pred_calpart >= threshold]

    # Ratio of standard deviations; inf/NaN if the predictions have zero variance.
    expand = true_above.std() / pred_above.std()
    shiftsize = 0

    calibrated_pred = y_pred_test.copy()
    to_expand = calibrated_pred >= threshold
    calibrated_pred.loc[to_expand] = calibrated_pred.loc[to_expand] * expand

    # Plain truthiness, so that any flag value yields a defined result.
    if shift:
        # NOTE(review): the shift is computed from the unexpanded calibration predictions
        # but applied to the expanded ones -- confirm this is intended.
        shiftsize = true_above.mean() - pred_above.mean()
        to_shift = calibrated_pred >= threshold  # evaluated on the expanded values
        calibrated_pred.loc[to_shift] = calibrated_pred.loc[to_shift] + shiftsize

    return (calibrated_pred,expand,shiftsize)
    

In [None]:
!conda list | grep views-forecasts

In [None]:
#ViewsMetadata().mine().with_name('baseline').fetch()
run_id

In [9]:
# Retrieving the predictions
# The EnsembleList contains the predictions organized by model
RetrieveFuture = False
level = 'cm'
# Target column followed by one prediction column per step
stepcols = ['ln_ged_sb_dep'] + ['step_pred_' + str(step) for step in steps]

for i, model in enumerate(EnsembleList):
    print(i, model['modelname'])
    stored_modelname_calib = level + '_' + model['modelname'] + '_calib'
    stored_modelname_test = level + '_' + model['modelname'] + '_test'
    stored_modelname_future = level +  '_' + model['modelname'] + '_f' + str(FutureStart)
    # Infinite values are set to 0. The result is assigned rather than replaced in place,
    # so the stored frame is always a plain, independent object.
    model['predictions_calib_df'] = (
        pd.DataFrame.forecasts.read_store(stored_modelname_calib, run=run_id)[stepcols]
        .replace([np.inf, -np.inf], 0)
    )
    model['predictions_test_df'] = (
        pd.DataFrame.forecasts.read_store(stored_modelname_test, run=run_id)[stepcols]
        .replace([np.inf, -np.inf], 0)
    )
    if RetrieveFuture:
        # Future predictions keep all stored columns (no stepcols selection).
        model['predictions_future_df'] = (
            pd.DataFrame.forecasts.read_store(stored_modelname_future, run=run_id)
            .replace([np.inf, -np.inf], 0)
        )

fat_baseline_rf
pr_33_cm_fat_baseline_rf_calib.parquet
pr_33_cm_fat_baseline_rf_test.parquet
fat_conflicthistory_srf
pr_33_cm_fat_conflicthistory_srf_calib.parquet
pr_33_cm_fat_conflicthistory_srf_test.parquet
fat_conflicthistory_rf
pr_33_cm_fat_conflicthistory_rf_calib.parquet
pr_33_cm_fat_conflicthistory_rf_test.parquet
fat_conflicthistory_gbm
pr_33_cm_fat_conflicthistory_gbm_calib.parquet
pr_33_cm_fat_conflicthistory_gbm_test.parquet
fat_conflicthistory_hurdle_rf
pr_33_cm_fat_conflicthistory_hurdle_rf_calib.parquet
pr_33_cm_fat_conflicthistory_hurdle_rf_test.parquet
fat_conflicthistory_hurdle_xgb
pr_33_cm_fat_conflicthistory_hurdle_xgb_calib.parquet
pr_33_cm_fat_conflicthistory_hurdle_xgb_test.parquet
fat_conflicthistory_hurdle_lgb
pr_33_cm_fat_conflicthistory_hurdle_lgb_calib.parquet
pr_33_cm_fat_conflicthistory_hurdle_lgb_test.parquet
fat_conflicthistory_histgbm
pr_33_cm_fat_conflicthistory_histgbm_calib.parquet
pr_33_cm_fat_conflicthistory_histgbm_test.parquet
fat_conflicthistory

In [None]:
#ViewsMetadata().mine().with_name('Markov').fetch()

In [10]:
# Prediction target
# In this particular ensemble available in model 0
target_modelname = EnsembleList[0]['modelname']
stored_modelname_calib = level + '_' + target_modelname + '_calib'
stored_modelname_test = level + '_' + target_modelname + '_test'

# Observed outcome column for the calibration and test partitions, read from the store
y_calib = pd.DataFrame.forecasts.read_store(stored_modelname_calib, run=run_id)['ln_ged_sb_dep']
y_test = pd.DataFrame.forecasts.read_store(stored_modelname_test, run=run_id)['ln_ged_sb_dep']
target = {'y_calib': y_calib, 'y_test': y_test}

pr_33_cm_fat_baseline_rf_calib.parquet
pr_33_cm_fat_baseline_rf_test.parquet


# To do with calibration:
Coordinate with Jim to incorporate his improvements, and to work the retrieval of the GAM object into his version.

In [11]:
# GAM-based calibration function
def gam_calibrated(y_true_calpart,y_pred_calpart,y_pred_test,n_splines):
    '''
    GAM-based calibration of predictions.

    Fits a LinearGAM with a single monotonically increasing spline term that maps
    calibration-partition predictions onto the observed outcomes, then applies it
    to y_pred_test.
    Expects the input columns from calibration partition to be without infinity values

    Returns:
        (calibrated_pred, gam): the calibrated predictions and the fitted GAM object
    '''
    from pygam import LogisticGAM, LinearGAM, s, te
    monotone_term = s(0, constraints='monotonic_inc', n_splines=n_splines)
    gam = LinearGAM(monotone_term)
    gam.fit(y_pred_calpart, y_true_calpart)
    return (gam.predict(y_pred_test), gam)
    

In [12]:
# Calibration
# For every model and step column two calibrated versions of the predictions are stored:
#   *_df_cal_expand : rescaled with mean_sd_calibrated (shift=False)
#   *_df_calibrated : GAM-calibrated (the two Markov models reuse the expanded version)
IncludeFuture = False
for model in EnsembleList:
    model['calib_df_cal_expand'] = model['predictions_calib_df'].copy()
    model['test_df_cal_expand'] = model['predictions_test_df'].copy()
    if IncludeFuture:
        model['future_df_cal_expand'] = model['predictions_future_df'].copy()
    model['calib_df_calibrated'] = model['predictions_calib_df'].copy()
    model['test_df_calibrated'] = model['predictions_test_df'].copy()
    if IncludeFuture:
        model['future_df_calibrated'] = model['predictions_future_df'].copy()
    print(model['modelname'])
    model['calibration_gams'] = [] # Will hold calibration GAM objects, one for each step
    for col in stepcols[1:]:
        thisstep = int(col[10:])  # 'step_pred_<n>' -> n
        thismonth = FutureStart + thisstep
        # NOTE(review): the 'GAM' entry is never filled; the fitted object is stored
        # under 'calibration_GAM' below. Both keys are kept for existing consumers.
        calibration_gam_dict = {
            'Step': thisstep,
            'GAM': []
        }
        # Remove from model dfs rows where [col] has infinite values (due to the 2011 split of Sudan)
        df_calib = model['predictions_calib_df'][~np.isinf(model['predictions_calib_df'][col])].fillna(0)
        df_test = model['predictions_test_df'][~np.isinf(model['predictions_test_df'][col])].fillna(0)
        if IncludeFuture:
            df_future = model['predictions_future_df'][~np.isinf(model['predictions_future_df']['step_combined'])].fillna(0)

        # model['expanded'] / model['shiftsize'] are overwritten per step; the last step's values remain.
        (model['calib_df_cal_expand'][col],model['expanded'],model['shiftsize']) = mean_sd_calibrated(
            y_true_calpart = df_calib['ln_ged_sb_dep'],
            y_pred_calpart = df_calib[col],
            y_pred_test = df_calib[col],
            shift=False,
            threshold = 0
        )
        (model['test_df_cal_expand'][col],model['expanded'],model['shiftsize']) = mean_sd_calibrated(
            y_true_calpart = df_calib['ln_ged_sb_dep'],
            y_pred_calpart = df_calib[col],
            y_pred_test = df_test[col],
            shift=False,
            threshold = 0
        )
        if IncludeFuture:
            # NOTE(review): `.loc[thismonth]['step_combined'] = ...` is chained indexing and
            # most likely assigns to a temporary copy -- verify before enabling IncludeFuture.
            (model['future_df_cal_expand'].loc[thismonth]['step_combined'], model['expanded'],model['shiftsize']) = mean_sd_calibrated(
                y_true_calpart = df_calib['ln_ged_sb_dep'],
                y_pred_calpart = df_calib[col],
                y_pred_test = df_future.loc[thismonth]['step_combined'],
                shift=False,
                threshold = 0
            )
        if model['modelname'] == 'fat_hh20_Markov_glm' or model['modelname'] == 'fat_hh20_Markov_rf':
            model['calib_df_calibrated'][col] = model['calib_df_cal_expand'][col]
            model['test_df_calibrated'][col] = model['test_df_cal_expand'][col]
        else:
            # Fit the calibration GAM once per step and reuse it for every partition.
            # The training data are the same for all of them, so refitting only repeats work.
            (model['calib_df_calibrated'][col], gam) = gam_calibrated(
                    y_true_calpart = df_calib['ln_ged_sb_dep'],
                    y_pred_calpart = df_calib[col],
                    y_pred_test = df_calib[col],
                    n_splines = 15
            )
            calibration_gam_dict['calibration_GAM'] = gam
            model['test_df_calibrated'][col] = gam.predict(df_test[col])
            if IncludeFuture:
                # NOTE(review): chained indexing, see the note above.
                model['future_df_calibrated'].loc[thismonth]['step_combined'] = gam.predict(
                    df_future.loc[thismonth]['step_combined']
                )
        model['calibration_gams'].append(calibration_gam_dict)
                

fat_baseline_rf
fat_conflicthistory_srf
fat_conflicthistory_rf
fat_conflicthistory_gbm
fat_conflicthistory_hurdle_rf
fat_conflicthistory_hurdle_xgb
fat_conflicthistory_hurdle_lgb
fat_conflicthistory_histgbm
fat_conflicthistory_xgb
fat_conflicthistory_lgb
fat_conflicthistory_long_rf
fat_conflicthistory_long_xgb
fat_vdem_rf
fat_vdem_xgb
fat_vdem_hurdle_xgb
fat_wdi_rf
fat_wdi_xgb
fat_wdi_hurdle_xgb
fat_topics_rf
fat_topics_histgbm
fat_topics_xgb
fat_topics_hurdle_xgb
fat_prs_rf
fat_prs_xgb
fat_prs_hurdle_xgb
fat_broad_rf
fat_broad_xgb
fat_broad_hurdle_xgb
fat_greatest_hits_rf
fat_greatest_hits_xgb
fat_greatest_hits_hurdle_xgb
fat_greatest_hits_lgb
fat_hh20_srf
fat_hh20_rf
fat_hh20_gbm
fat_hh20_hurdle_rf
fat_hh20_hurdle_xgb
fat_hh20_hurdle_lgb
fat_hh20_histgbm
fat_hh20_xgb
fat_hh20_lgb
fat_all_pca3_xgb
fat_all_pca3_lgb
fat_topics_pca3_xgb
fat_topics_pca3_lgb
fat_vdem_pca3_lgb
fat_wdi_pca3_lgb
fat_hh20_Markov_glm
fat_hh20_Markov_rf


In [13]:
EnsembleList[46]['calibration_gams']

[{'Step': 1,
  'GAM': [],
  'calibration_GAM': LinearGAM(callbacks=[Deviance(), Diffs()], fit_intercept=True, 
     max_iter=100, scale=None, terms=s(0) + intercept, tol=0.0001, 
     verbose=False)},
 {'Step': 2,
  'GAM': [],
  'calibration_GAM': LinearGAM(callbacks=[Deviance(), Diffs()], fit_intercept=True, 
     max_iter=100, scale=None, terms=s(0) + intercept, tol=0.0001, 
     verbose=False)},
 {'Step': 3,
  'GAM': [],
  'calibration_GAM': LinearGAM(callbacks=[Deviance(), Diffs()], fit_intercept=True, 
     max_iter=100, scale=None, terms=s(0) + intercept, tol=0.0001, 
     verbose=False)},
 {'Step': 4,
  'GAM': [],
  'calibration_GAM': LinearGAM(callbacks=[Deviance(), Diffs()], fit_intercept=True, 
     max_iter=100, scale=None, terms=s(0) + intercept, tol=0.0001, 
     verbose=False)},
 {'Step': 5,
  'GAM': [],
  'calibration_GAM': LinearGAM(callbacks=[Deviance(), Diffs()], fit_intercept=True, 
     max_iter=100, scale=None, terms=s(0) + intercept, tol=0.0001, 
     verbose=Fals

In [None]:
print(model['predictions_future_df'].loc[510]['step_combined'].describe(), 
model['future_df_calibrated'].loc[510]['step_combined'].describe())

In [None]:
for col in stepcols[1:]:
    print(int(col[10:]))

In [None]:
# Illustrating calibration: raw predictions (x axis) against GAM-calibrated ones (y axis)
model = EnsembleList[5]
print(model['modelname'])
col = 'step_pred_1'
period = 'test'

calibrated_key = f'{period}_df_calibrated'
raw_key = f'predictions_{period}_df'
print(model[calibrated_key][col].describe())
print(model[raw_key][col].describe())

plt.scatter(model[raw_key][col], model[calibrated_key][col])

# NOTE: this rebinds the global overleafpath to the figures folder
overleafpath = '/Users/havardhegre/Dropbox (ViEWS)/Apps/Overleaf/ViEWS predicting fatalities/Figures/PredictionPlots/'
filename = overleafpath + 'Calibration_example_' + model['modelname'] + '.png'
plt.savefig(filename, dpi=300)

In [14]:
# Save EnsembleList with all content locally and on dropbox
localpath = '/Users/havardhegre/temp/'
dbpath = '/Users/havardhegre/Dropbox (ViEWS)/ViEWS/Projects/PredictingFatalities/Predictions/'
# Context managers make sure each file is flushed and closed once it has been written.
picklename = localpath + 'EnsembleList_cm_' + run_id + '.p'
with open(picklename, "wb") as picklefile:
    pkl.dump(EnsembleList, picklefile)
picklename = dbpath + 'EnsembleList_db_cm_' + run_id + '.p'
with open(picklename, "wb") as picklefile:
    pkl.dump(EnsembleList, picklefile)

In [None]:
percentiles = [.75,.9,.95,.99]
print(model['modelname'], 
      model['predictions_calib_df']['step_pred_1'].describe(percentiles = percentiles),
      model['calib_df_calibrated']['step_pred_1'].describe(percentiles = percentiles),
      model['predictions_test_df']['ln_ged_sb_dep'].describe(percentiles = percentiles),
      model['predictions_test_df']['step_pred_1'].describe(percentiles = percentiles),
      model['test_df_calibrated']['step_pred_1'].describe(percentiles = percentiles))

In [None]:
EnsembleList[12]

In [None]:
cols_to_inspect = ['step_pred_1','step_pred_6','step_pred_12','step_pred_36']
EnsembleList[12]['predictions_calib_df'][cols_to_inspect].describe()

In [None]:
# Create unweighted average ensemble
# The gam calibrated is basis currently

IncludeFuture = False

# Prediction frames that are averaged across the constituent models
averaged_keys = ['calib_df_calibrated', 'test_df_calibrated',
                 'calib_df_cal_expand', 'test_df_cal_expand']
if IncludeFuture:
    averaged_keys = averaged_keys + ['future_df_calibrated', 'future_df_cal_expand']

# Model: Ensemble, GED logged dependent variable,
ensemble = {
    'modelname': 'ensemble_unweighted',
    'depvar': "ln_ged_sb_dep",
    'loggeddepvar': True,
    'predictions_file_calib': "",
    'predictions_file_test': "",
    'calibration_gam': []
}
# Seed the running sums with model 0 ...
for key in averaged_keys:
    ensemble[key] = EnsembleList[0][key].copy()

# ... and add the remaining models. Model 0 must not be added again here: doing so
# counted it twice and divided by (number of models + 1).
# NOTE: re-running this cell would also average in a previously appended ensemble.
n_models = 1
for model in EnsembleList[1:]:
    for key in averaged_keys:
        ensemble[key] = ensemble[key].add(model[key])
    n_models = n_models + 1

# The target column 'ln_ged_sb_dep' is summed and divided as well, which gives it back unchanged.
for key in averaged_keys:
    ensemble[key] = ensemble[key].divide(n_models)

EnsembleList.append(ensemble)

# Save ensemble predictions

predstore_calib = level +  '_' + ensemble['modelname'] + '_calib'
ensemble['calib_df_calibrated'].forecasts.set_run(run_id)
ensemble['calib_df_calibrated'].forecasts.to_store(name=predstore_calib, overwrite = True)
predstore_test = level +  '_' + ensemble['modelname'] + '_test'
ensemble['test_df_calibrated'].forecasts.set_run(run_id)
ensemble['test_df_calibrated'].forecasts.to_store(name=predstore_test, overwrite = True)
if IncludeFuture:
    predstore_future = level +  '_' + ensemble['modelname'] + '_future'
    ensemble['future_df_calibrated'].forecasts.set_run(run_id)
    ensemble['future_df_calibrated'].forecasts.to_store(name=predstore_future, overwrite = True)


In [None]:
EnsembleList[-10]['future_df_calibrated'].xs(78, level=1).plot(kind='line')

In [None]:
# Switch for building an ensemble from a reduced ("ablated") list of models
CreateAblatedEnsemble = False
if CreateAblatedEnsemble:
    # Create ablated unweighted average ensemble
    # Shallow copy: the list is new, the model dicts are shared with EnsembleList.
    AblatedEnsembleList = EnsembleList.copy()
    # The pops are positional and applied in sequence, so every index refers to the
    # list as it is after the preceding pops (not to positions in EnsembleList).
    AblatedEnsembleList.pop(0)
    AblatedEnsembleList.pop(0)
    AblatedEnsembleList.pop(19)
    AblatedEnsembleList.pop(23)
    AblatedEnsembleList.pop(23)
    AblatedEnsembleList.pop(25)
    # Drops the last entry (presumably the appended ensemble -- verify against cell order)
    AblatedEnsembleList.pop(-1)

    # List the models that remain
    m = 0
    for model in AblatedEnsembleList:
        print(m, model['modelname'])
        m = m + 1

In [None]:
 # Model: Ensemble, ablated
if CreateAblatedEnsemble:
    # Seed the running sums with the first model of the ablated list. Seeding with
    # EnsembleList[0] would bring back a model that was removed from that list.
    ensemble = {
        'modelname': 'ensemble_unw_ablated',
        'run': '',
        'depvar': "ln_ged_sb_dep",
        'data': '',
        'loggeddepvar': True,
        'predictions_file_calib': "",
        'predictions_file_test': "",
        'fi': "",
        'trained': False,
        'calib_df_calibrated': AblatedEnsembleList[0]['calib_df_calibrated'].copy(),
        'test_df_calibrated': AblatedEnsembleList[0]['test_df_calibrated'].copy()
    }

    # Add the remaining ablated-list models and count all contributors
    n_models = 1
    for model in AblatedEnsembleList[1:]:
        ensemble['calib_df_calibrated'] = ensemble['calib_df_calibrated'].add(model['calib_df_calibrated'])
        ensemble['test_df_calibrated'] = ensemble['test_df_calibrated'].add(model['test_df_calibrated'])
        n_models = n_models + 1

    ensemble['calib_df_calibrated'] = ensemble['calib_df_calibrated'].divide(n_models)
    ensemble['test_df_calibrated'] = ensemble['test_df_calibrated'].divide(n_models)
    # NOTE: after this append the ablated ensemble is the last EnsembleList entry.
    EnsembleList.append(ensemble)
    # Save ensemble predictions
    path = Mydropbox + '/Projects/PredictingFatalities/Predictions/'
    filename_test = path + 'EnsembleEqualWeights_ablated' + '_test.csv'
    ensemble['test_df_calibrated'].to_csv(filename_test)

    ensemble['test_df_calibrated'].describe()

# Create two 'mask' dfs to emulate 'onset' categories
Based on predictions from the baseline model: 'onsets' are defined as country-months where the baseline model assigns a low expected count of deaths.

In [None]:
# Bin every column of the model-0 predictions into four groups with left-closed intervals
# [0, 0.001), [0.001, 0.1), [0.1, 1), [1, 10); values outside the bins become NaN
cut_bins = [0,0.001,0.1,1,10]
group_labels = ('grp0','grp1','grp2','grp3')
onset_mask_test = EnsembleList[0]['predictions_test_df'].copy()
onset_mask_calib = EnsembleList[0]['predictions_calib_df'].copy()
for df in [onset_mask_test, onset_mask_calib]:
    for col in df.columns:
        df[col] = pd.cut(df[col], bins=cut_bins, right=False, labels=group_labels)
# df is the last frame of the loop above, i.e. onset_mask_calib
df.describe()

# Estimating ensembles

In [None]:
target['y_calib'].head()

In [None]:
# Checking missingness: per column, number of rows, missing values and infinite values
N=51
df = EnsembleList[0]['predictions_test_df']
#df = pd.DataFrame(target['y_test'])
for col in df.iloc[: , :N].columns:
    # Count infinities in this column only (not across the whole frame)
    print(col,len(df[col]), 'missing:', df[col].isnull().sum(), 'infinity:', np.isinf(df[col]).sum())



In [None]:
# Compute ablation MSE, calibration partition

from numpy import array
from numpy.linalg import norm

def ensemble_predictions(yhats, weights):
    """Return the weighted sum across ensemble members.

    yhats holds one sequence of predictions per member and weights one weight per
    member; the result is np.dot(weights, yhats) with yhats stacked into an array.
    """
    member_matrix = np.array(yhats)
    return np.dot(weights, member_matrix)

def evaluate_ensemble(yhats, weights, test_y):
    """Mean squared error of the weighted ensemble prediction against test_y.

    yhats and weights are combined by ensemble_predictions; test_y holds the
    observed values in the same row order as the member predictions.
    """
    ensemble_y = ensemble_predictions(yhats, weights)
    # sklearn's signature is mean_squared_error(y_true, y_pred)
    return mean_squared_error(test_y, ensemble_y)

# normalize a vector to have unit norm
def normalize(weights):
    """Scale weights so that their L1 norm is 1; an all-zero vector is returned as is."""
    l1_norm = norm(weights, 1)
    return weights if l1_norm == 0.0 else weights / l1_norm


ensemble_mses = [] # List to hold unweighted ensemble mses 

# Constituent models: every EnsembleList entry except the last one
# (assumes the ensemble is the last entry -- verify against the cells run before)
constituent_models = EnsembleList[0:-1]

# Count models, set up lists
number_of_models = 0
mlist = []
for model in constituent_models:
    number_of_models = number_of_models + 1
    # One slot per entry of stepcols: index 0 is the target column, then one per step
    model['Ablation_MSE']=[0] * (len(steps)+1)
    mlist.append(model['modelname'])
print('Models:',number_of_models)        

# Compute unweighted ensemble mses
for col in stepcols:
    yhats = []
    weights = []
    for model in constituent_models:
        df_calib = model['calib_df_calibrated'][~np.isinf(model['calib_df_calibrated'][col])].fillna(0)
        yhats.append(df_calib[col])
        weights.append(1/number_of_models)
    # The observed values are taken from the last model's frame
    emse = evaluate_ensemble(yhats, weights, df_calib['ln_ged_sb_dep'])
    ensemble_mses.append(emse)

# Missing values are set to 0 once, up front, rather than for every
# (step, model, model) combination inside the loops below.
# Not sure what is best to do with NAs
for model in constituent_models:
    model['calib_df_calibrated'] = model['calib_df_calibrated'].fillna(0)

# Compute ablation scores: ensemble MSE minus the MSE of the ensemble without the model.
# A negative value means the ensemble is better with the model than without it.
for colno, col in enumerate(stepcols):
    print('Step',col)
    for model in constituent_models: # Assuming the ablated ensemble exists!
        calib = model['calib_df_calibrated']
        y = calib['ln_ged_sb_dep'][~np.isinf(calib[col])]
        yhats = []
        weights = []
        for abl_model in constituent_models:
            if model['modelname'] != abl_model['modelname']:
                df_calib = abl_model['calib_df_calibrated'][~np.isinf(abl_model['calib_df_calibrated'][col])]
                yhats.append(df_calib[col])
                weights.append(1/(number_of_models-1))
        ablated_mse = evaluate_ensemble(yhats, weights, y)
        Ablation_MSE = ensemble_mses[colno] - ablated_mse
        model['Ablation_MSE'][colno] = Ablation_MSE
    

In [None]:
from statistics import mean
mean(EnsembleList[0]['Ablation_MSE'][1:])

In [None]:
# Go through the ablation MSE for pre-screening of ensemble inclusion
# model['Include'] is set to True if the mean ablation MSE is negative over all steps
# or in at least one of the four step segments; otherwise it stays False

from statistics import mean, stdev

# Ablation_MSE[0] belongs to the target column and Ablation_MSE[k] to step k (steps 1..36).
# Python slices exclude their end, so steps first..last are [first:last + 1].
step_segments = [(1, 36), (1, 6), (7, 12), (13, 24), (25, 36)]

m = 0
for model in EnsembleList[:-1]:
    m_all, m_1, m_2, m_3, m_4 = [
        mean(model['Ablation_MSE'][first:last + 1]) for (first, last) in step_segments
    ]
    model['Include'] = any(pmean < 0 for pmean in [m_all, m_1, m_2, m_3, m_4])
    
    print(m, model['Include'], model['modelname'], ', aMSE steps all, 1-6, 7-12, 13-24, 25-36', 
          f'{m_all:.4f}',f'{m_1:.4f}',f'{m_2:.4f}',f'{m_3:.4f}',f'{m_4:.4f}',)
    m = m + 1

In [None]:
# Build one entry per prediction step. Each entry is a dict holding the step
# column name, a calib and a test frame with the observed outcome plus one
# column per model, and two frames that will collect ensemble predictions.
StepEnsembles = []
for col in stepcols[1:]:
    Step_prediction = dict(
        step_pred=col,
        df_calib=pd.DataFrame(target['y_calib']),
        df_test=pd.DataFrame(target['y_test']),
        ensembles_calib=pd.DataFrame(target['y_calib']),
        ensembles_test=pd.DataFrame(target['y_test']),
    )
    calib_members = Step_prediction['df_calib']
    test_members = Step_prediction['df_test']
    for model in EnsembleList:
        modelname = model['modelname']
        calib_members[modelname] = model['calib_df_calibrated'][col]
        test_members[modelname] = model['test_df_calibrated'][col]
    StepEnsembles.append(Step_prediction)

# Unweighted average ensemble: row-wise mean over every model column
# (the observed outcome column is dropped before averaging).
for Step_prediction in StepEnsembles:
    test_predictions = Step_prediction['df_test'].drop(columns='ln_ged_sb_dep')
    calib_predictions = Step_prediction['df_calib'].drop(columns='ln_ged_sb_dep')
    Step_prediction['ensembles_test']['unweighted_average'] = test_predictions.mean(axis=1)
    Step_prediction['ensembles_calib']['unweighted_average'] = calib_predictions.mean(axis=1)

    

# Genetic algorithm

In [None]:
# Inspect the import
model = EnsembleList[0]
model.keys()
len(EnsembleList)
print(RunGeneticAlgo)

In [None]:
from genetic2 import *

def make_run_from_step (
    step,
    e_set,
    df_name = 'calib_df_calibrated',
    target = 'ln_ged_sb_dep',
    population_count = 100,
    initial_population = None,
    base_genes = np.array([0,1]),
    number_of_generations = 500
):
    """
    Run the genetic weight search for a single prediction step.

    Parameters
    ----------
    step : int
        Step to optimize; the column f'step_pred_{step}' is read from each model.
    e_set : list of dict
        Models to combine, in the same structure as EnsembleList. Each entry
        needs a 'modelname' and a DataFrame stored under `df_name`.
    df_name : str
        Key of the DataFrame to read from every entry of `e_set`.
    target : str
        Name of the outcome column (Y); taken from the first entry of `e_set`.
    population_count : int
        Passed to init_population_sum when no initial population is supplied.
    initial_population : optional
        Starting population; when None a new one is drawn.
    base_genes : array-like
        Allowed gene (weight) values, passed on to the genetic algorithm.
    number_of_generations : int
        Passed to genetic_algorithm as `ngen`.

    Returns
    -------
    dict
        {'step': step, 'memoization_id': temp_file_name, 'generation': <result
        of genetic_algorithm>}.

    Raises
    ------
    ValueError
        If `e_set` is empty.

    Side effects
    ------------
    Creates ./exploration_pickle/ if needed and writes a one-row csv there
    linking `step` to the memoization id.
    """
    from genetic2 import temp_file_name

    df_step = f'step_pred_{step}'

    # Assemble target + one prediction column per model. The models come from
    # the e_set argument (previously the global EnsembleList was used and the
    # argument was silently ignored). Every prediction column is renamed to
    # include its model name so the 'ensemble' filter below sees all of them.
    aggregate_df = None
    for i_ens in e_set:
        member_col = f'{df_step}_{i_ens["modelname"]}'
        if aggregate_df is None:
            # First model: also brings in the target column.
            aggregate_df = i_ens[df_name][[target, df_step]].copy()
            aggregate_df = aggregate_df.rename(columns = {df_step : member_col})
        else:
            member_df = i_ens[df_name][[df_step]].rename(columns = {df_step : member_col})
            aggregate_df = aggregate_df.join(member_df)

    if aggregate_df is None:
        raise ValueError('e_set is empty: at least one model is needed to optimize weights')

    aggregate_df = aggregate_df.dropna()
    # Drop predictions coming from models whose name contains 'ensemble'.
    aggregate_df = aggregate_df[aggregate_df.columns[~aggregate_df.columns.str.contains('ensemble')]]

    X = aggregate_df.drop(columns=[target])
    Y = aggregate_df[target]

    # Fitness function with Y, X and the metric pre-bound.
    inst_mse = partial(weighted_mse_score, Y, X, mean_squared_error)
    if initial_population is None:
        population =  init_population_sum(population_count,base_genes,X.shape[1],0.5,3)
    else: 
        population = initial_population

    # Record which memoization id belongs to this step.
    Path('./exploration_pickle/').mkdir(parents=True, exist_ok=True) 
    pd.DataFrame({'step':[step], 'memoization_id':[temp_file_name]}).to_csv(f'exploration_pickle/id_{temp_file_name}.csv', index=False)

    generation = genetic_algorithm(population, 
                                   inst_mse, 
                                   base_genes, 
                                   f_thres=None, 
                                   ngen=number_of_generations, 
                                   pmut=0.2)
    return {'step':step, 'memoization_id':temp_file_name, 'generation':generation}
    

In [None]:
# Collect the constituent models flagged for inclusion, printing a running
# index next to each model name; m ends up as the number of models kept.
EnsembleRun = []
for model in EnsembleList[:-1]:
    if model['Include'] != True:
        continue
    print(len(EnsembleRun), model['modelname'])
    EnsembleRun.append(model)
m = len(EnsembleRun)

In [None]:
super_walrus_genes = np.array([0, 0.001, 0.002, 0.003, 0.005, 0.007, 0.010, 0.015, 0.020, 0.025, 0.030, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.12, 0.14, 0.16, 0.18, 0.20, 0.25])
steps_to_optimize = [1,2,3,4,6,9,12,15,18,24,30,36]

In [None]:
filled_function = partial(make_run_from_step, 
    e_set = EnsembleRun,
    df_name = 'calib_df_calibrated',
    target = 'ln_ged_sb_dep',
    population_count = 100,
    initial_population = None,
    base_genes = super_walrus_genes,
    number_of_generations = 500
)



In [None]:
# Number of parallel workers: keep four cores free, but never go below one.
# (cpu_count() - 4 alone gives -1 or 0 on 3- and 4-core machines; joblib treats
# -1 as "use all cores" and rejects 0.)
cpus = max(1, cpu_count() - 4)
# Workers left over (positive) or missing (negative) relative to the number of steps to optimize.
cpus - len(steps_to_optimize)

In [None]:
if RunGeneticAlgo:
    generations = Parallel(n_jobs=cpus)(delayed(filled_function)(i) for i in steps_to_optimize)

In [None]:
# Persist the optimization output. Only done when the genetic algorithm was run
# in this session: otherwise `generations` does not exist, and opening the file
# with 'wb' would truncate a previously saved result before the dump failed.
if RunGeneticAlgo:
    Path('exploration_pickle').mkdir(parents=True, exist_ok=True)
    with open('exploration_pickle/full_gen.pickle', 'wb') as handle:
        # The pickle module is imported as `pkl` in this notebook.
        pkl.dump(generations, handle, protocol=pkl.HIGHEST_PROTOCOL)

In [None]:
#The output contains a list of dictionaries. Each dictionary contains three elements:
# 1. the step that was optimized.
# 2. the memoization ID.
# 3. A generation object, containing a tuple: (a list of organisms sorted by score descending; scores)
generations[10]

In [None]:
# Load the trained weights from file so that they survive a kernel restart.
# The pickle module is imported as `pkl` in this notebook.
# NOTE: unpickling executes code from the file; only load files you produced yourself.
with open('exploration_pickle/full_gen.pickle', 'rb') as f:
    generations = pkl.load(f)

In [None]:
# Print the memoization id's so that you can explore the training process in the visualizer
for i in generations:
    print (i['step'], i['memoization_id'])

In [None]:
# Fetch the best organism.
for gen in generations:
    print ('\nStep: ',gen['step'],'\n','*'*24,'\n')
    print (gen['generation'][0])
    #The best is always the top organism. You can get the top 20 by slicing gen['generation'][0:20] and so on

In [None]:
# Results: one entry per optimized step, holding the weight vector ('Org'), its fitness, and the step

GeneticAlgoResult = [
{'Org': [0.2, 0.0, 0.0, 0.18, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.05, 0.0, 0.02, 0.0, 0.09, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.25],'Fitness': 3.735356478485664,'Step':36},
{'Org': [0.0, 0.0, 0.02, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.001, 0.0, 0.003, 0.0, 0.03, 0.0, 0.003, 0.25, 0.04, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.25],'Fitness':4.317382478784119,'Step':30},
{'Org': [0.0, 0.0, 0.001, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.14, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.01, 0.0, 0.2, 0.0, 0.001, 0.015, 0.0, 0.0, 0.0, 0.0, 0.0, 0.001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.18],'Fitness':4.9754311051495375,'Step':24},
{'Org': [0.0, 0.0, 0.09, 0.09, 0.14, 0.0, 0.16, 0.0, 0.0, 0.001, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.005, 0.0, 0.01, 0.0, 0.0, 0.0, 0.18, 0.0, 0.0, 0.0, 0.0, 0.0, 0.05, 0.0, 0.0, 0.001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.001, 0.0, 0.0, 0.0, 0.05, 0.0, 0.0, 0.0, 0.025, 0.0],'Fitness':6.379993395323842,'Step':18},
{'Org': [0.03, 0.0, 0.1, 0.2, 0.0, 0.0, 0.16, 0.0, 0.0, 0.0, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.015, 0.0, 0.05, 0.002, 0.01, 0.1, 0.01, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.003, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.007, 0.0, 0.0, 0.09, 0.0],'Fitness':7.45380689233033,'Step':15},
{'Org': [0.003, 0.0, 0.16, 0.16, 0.003, 0.0, 0.03, 0.0, 0.0, 0.0, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.05, 0.0, 0.001, 0.007, 0.001, 0.001, 0.0, 0.0, 0.003, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.002, 0.05, 0.0, 0.04, 0.0, 0.0, 0.0, 0.007, 0.16, 0.0, 0.0, 0.08, 0.01],'Fitness':9.288744127966044,'Step':12},
{'Org': [0.1, 0.001, 0.16, 0.04, 0.025, 0.0, 0.18, 0.0, 0.0, 0.14, 0.12, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.005, 0.0, 0.03, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03, 0.0, 0.0, 0.002, 0.001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.015, 0.0, 0.0, 0.02, 0.0, 0.0, 0.025, 0.12],'Fitness':10.003185652932892,'Step':9},
{'Org': [0.08, 0.08, 0.2, 0.05, 0.0, 0.0, 0.0, 0.005, 0.0, 0.1, 0.18, 0.001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.025, 0.001, 0.03, 0.001, 0.0, 0.0, 0.0, 0.001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.14],'Fitness':12.062857149691187,'Step':6},
{'Org': [0.0, 0.0, 0.14, 0.18, 0.0, 0.0, 0.003, 0.06, 0.0, 0.01, 0.005, 0.0, 0.0, 0.002, 0.0, 0.0, 0.0, 0.0, 0.09, 0.0, 0.12, 0.0, 0.001, 0.0, 0.0, 0.06, 0.05, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.003, 0.0, 0.0, 0.05, 0.0, 0.0, 0.14, 0.08],'Fitness':14.844532043805717,'Step':4},
{'Org': [0.0, 0.0, 0.003, 0.01, 0.0, 0.0, 0.0, 0.2, 0.0, 0.005, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.12, 0.001, 0.001, 0.1, 0.0, 0.0, 0.25, 0.06, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.01, 0.0, 0.0, 0.0, 0.0, 0.0, 0.14, 0.09],'Fitness':16.607509430086104,'Step':3},
{'Org': [0.01, 0.0, 0.06, 0.003, 0.0, 0.0, 0.18, 0.18, 0.0, 0.0, 0.007, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.04, 0.0, 0.0, 0.0, 0.2, 0.18, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.007, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.025, 0.0, 0.0, 0.1, 0.01],'Fitness':23.24984770618241,'Step':2},
{'Org': [0.025, 0.0, 0.001, 0.015, 0.12, 0.0, 0.07, 0.2, 0.0, 0.1, 0.12, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.08, 0.0, 0.003, 0.002, 0.0, 0.007, 0.07, 0.005, 0.002, 0.001, 0.0, 0.0, 0.0, 0.007, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.015, 0.0, 0.001, 0.1, 0.0, 0.0, 0.07, 0.0],'Fitness':38.53772692394461,'Step':1},
]


In [None]:
len(GeneticAlgoResult[0]['Org'])

In [None]:
# list some weights
modelno = 48
for line in GeneticAlgoResult: 
    print('Weight model', modelno, 'step', line['Step'], ':', line['Org'][modelno])

# Assignment of the genetic weights 

In [None]:
# Reading from GeneticAlgoResult:
w_step = [None] * 37
for line in GeneticAlgoResult:
    w_step[line['Step']] = line['Org']
w_step[1]
print(sum(w_step[1]))


In [None]:
# Weights per step before any interpolation: optimized steps get their genetic
# weights, every other step gets a NaN placeholder vector of the same length.
print(steps_to_optimize)
WeightMatrix = [None] * 37
modelnames = []
for model in EnsembleList[:-1]:
    modelnames.append(model['modelname'])
for step in steps:
    if step in steps_to_optimize:
#        print(step, 'is optimized')
        WeightMatrix[step] = w_step[step]
    else:
        # A list of NaNs, one per model (np.nan * n would be a single scalar NaN).
        WeightMatrix[step] = [np.nan] * len(w_step[1])
   



In [None]:
WeightMatrix

In [None]:


# Rebuild the list of included models from scratch. Appending to an already
# filled EnsembleRun would list every included model twice, and the printed
# index would continue from whatever value m was left at.
EnsembleRun = []
m = 0
for model in EnsembleList[:-1]:
    if model['Include'] == True:
        EnsembleRun.append(model)
        print(m,model['modelname'])
        m = m + 1

In [None]:
# For each of the 36 steps (position = step - 1), the optimized step whose
# weights are assigned to it.
StepAssigner = [1,2,3,4,4,6,6,9,9,9,12,12,12,15,15,15,18,18,18,18,18,24,24,24,24,24,24,30,30,30,30,30,30,36,36,36]

# Row labels of the weight table: names of all models except the last entry.
modelnames = [entry['modelname'] for entry in EnsembleList[:-1]]

# WeightMatrix[step] holds the weight vector used for that step; slot 0 stays unused.
WeightMatrix = [None] * 37
for step in steps:
    assigned_step = StepAssigner[step - 1]
    WeightMatrix[step] = w_step[assigned_step]

# Transpose so that models are rows and steps are columns.
wmt = np.array(WeightMatrix[1:]).T
weights_df = pd.DataFrame(wmt, index=modelnames, columns=stepcols[1:])
weights_df



In [None]:
# Interpolated weights.
# Steps that were optimized keep their own weights. Every other step gets a
# linear interpolation between the nearest optimized step below and above it,
# each weighted by its distance to the opposite anchor. The anchors are looked
# up from steps_to_optimize instead of hand-written offsets (the old offsets
# for steps 20, 26 and 32 pointed at step+3 rather than the anchor at step+4).
i_weights_df = weights_df.copy()
anchor_steps = sorted(steps_to_optimize)
for step in steps:
    col = 'step_pred_' + str(step)
    if step in anchor_steps:
        continue
    earlier_anchors = [anchor for anchor in anchor_steps if anchor < step]
    later_anchors = [anchor for anchor in anchor_steps if anchor > step]
    if not earlier_anchors or not later_anchors:
        # Nothing to interpolate between; keep the weights already assigned.
        continue
    prestep = earlier_anchors[-1]
    poststep = later_anchors[0]
    prestepcol = 'step_pred_' + str(prestep)
    poststepcol = 'step_pred_' + str(poststep)
    # Read the anchors from weights_df so results never depend on loop order.
    i_weights_df[col] = (
        (weights_df[prestepcol] * (poststep - step))
        + (weights_df[poststepcol] * (step - prestep))
    ) / (poststep - prestep)

print(steps_to_optimize)
i_weights_df

In [None]:
# Export weights 
i_weights_df.to_csv('GeneticWeights.csv')

In [None]:
import seaborn as sns
palette = 'vlag'
palette = sns.color_palette('BrBG',n_colors=50)
palette = sns.cubehelix_palette(start=2, rot=0, dark=0, light=1, n_colors=100)
overleafpath = '/Users/havardhegre/Dropbox (ViEWS)/Apps/Overleaf/ViEWS predicting fatalities/Figures/Pred_Eval/'

fig, ax =plt.subplots(1,figsize=(16,11))
ax = sns.heatmap(i_weights_df, xticklabels=2, linewidths=.5, cmap=palette,square=True)
filename = overleafpath + 'genetic_weights.png'
plt.savefig(filename, dpi=300)

In [None]:

# Calculating weighted average ensembles
# Based on the weights_df dataframe filled with Mihai's weights above

def ensemble_predictions(yhats, weights):
    """Return the weighted sum of the member predictions.

    yhats   : sequence of prediction vectors, one per ensemble member.
    weights : one weight per ensemble member.
    """
    # Stack the members into one array and take the dot product with the weights.
    return np.dot(weights, np.array(yhats))

# normalize a vector to have unit L1 norm
def normalize(weights):
    """Scale `weights` so that the absolute values sum to one.

    Returns `weights` unchanged when all entries are zero (the norm is 0),
    otherwise `weights` divided by its L1 norm.
    """
    # np.linalg.norm is used explicitly: a bare `norm` is not imported in this notebook.
    result = np.linalg.norm(weights, 1)
    # check for a vector of all zeros
    if result == 0.0:
        return weights
    # return normalized vector (unit norm)
    return weights / result

# Weighted average ensemble: for every step, multiply each model's prediction
# column by that model's weight for the step and sum across models.
for i, col in enumerate(stepcols[1:]):
    step_weights = i_weights_df[col]
    df_calib = StepEnsembles[i]['df_calib'].drop(columns='ln_ged_sb_dep')
    df_test = StepEnsembles[i]['df_test'].drop(columns='ln_ged_sb_dep')
    # Weights are matched to prediction columns by model name.
    StepEnsembles[i]['ensembles_calib']['weighted_average'] = df_calib.mul(step_weights, axis='columns').sum(axis=1)
    StepEnsembles[i]['ensembles_test']['weighted_average'] = df_test.mul(step_weights, axis='columns').sum(axis=1)

In [None]:
# Save the weights dfs
dflist = [
    (i_weights_df,'i_weights_df'), 
]

path = '/Users/havardhegre/Dropbox (ViEWS)/ViEWS/Projects/PredictingFatalities/MSEs/'
for df in dflist:
    filename = path + df[1] + '.csv'
    df[0].to_csv(filename)
    

In [None]:
# Calculating some ensembles
from sklearn import datasets, linear_model
if False:

    i = 0
    for col in stepcols[1:]:    # Linear regression
        regr = linear_model.LinearRegression()
        df = StepEnsembles[i]['df_calib'][~np.isinf(StepEnsembles[i]['df_calib'])].fillna(0)
        regr.fit(df.drop('ln_ged_sb_dep', axis=1), df['ln_ged_sb_dep'])
    #    regr.fit(StepEnsembles[i]['df_calib'].drop('lndepvar', axis=1), StepEnsembles[i]['df_calib']['lndepvar'])
        StepEnsembles[i]['ensembles_calib']['linear_regression'] = regr.predict(StepEnsembles[i]['df_calib'].drop('ln_ged_sb_dep', axis=1))
        StepEnsembles[i]['ensembles_test']['linear_regression'] = regr.predict(StepEnsembles[i]['df_test'].drop('ln_ged_sb_dep', axis=1))
        # Random forest regression
        rf = RandomForestRegressor()
        rf.fit(StepEnsembles[i]['df_calib'].drop('ln_ged_sb_dep', axis=1), StepEnsembles[i]['df_calib']['ln_ged_sb_dep'])
        StepEnsembles[i]['ensembles_calib']['rf_regression'] = rf.predict(StepEnsembles[i]['df_calib'].drop('ln_ged_sb_dep', axis=1))
        StepEnsembles[i]['ensembles_test']['rf_regression'] = rf.predict(StepEnsembles[i]['df_test'].drop('ln_ged_sb_dep', axis=1))
        i = i + 1

    ensemble_models = ['unweighted_average','linear_regression','rf_regression']

In [None]:
# Reshape the ensemble predictions into the same dict layout as the entries
# of EnsembleList, and append the genetic (weighted-average) ensemble to it.

# Disabled: entries for the linear-regression and random-forest ensembles.
if False: 
    linear = {
        'modelname': 'ensemble_linear',
        'algorithm': '',
        'depvar': "ln_ged_sb_dep",
        'calib_df_calibrated': EnsembleList[0]['calib_df_calibrated'].copy(),
        'test_df_calibrated': EnsembleList[0]['test_df_calibrated'].copy()
    }
    rf = {
        'modelname': 'ensemble_rf',
        'algorithm': '',
        'depvar': "ln_ged_sb_dep",
        'calib_df_calibrated': EnsembleList[0]['calib_df_calibrated'].copy(),
        'test_df_calibrated': EnsembleList[0]['test_df_calibrated'].copy()
    }
# The frames start as copies of the first model's frames; the step columns are
# overwritten in the loop below, any other column keeps the first model's values.
genetic = {
        'modelname': 'ensemble_genetic',
        'algorithm': '',
        'depvar': "ln_ged_sb_dep",
        'calib_df_calibrated': EnsembleList[0]['calib_df_calibrated'].copy(),
        'test_df_calibrated': EnsembleList[0]['test_df_calibrated'].copy(),
        'calibration_gam': []
    }    

# Copy the 'weighted_average' predictions of each step into the matching step column.
for step in StepEnsembles:
    colname = step['step_pred']
    print(colname)
#    linear['calib_df_calibrated'][colname] = step['ensembles_calib']['linear_regression']
#    linear['test_df_calibrated'][colname] = step['ensembles_test']['linear_regression']
#    rf['calib_df_calibrated'][colname] = step['ensembles_calib']['rf_regression']
#    rf['test_df_calibrated'][colname] = step['ensembles_test']['rf_regression']
    genetic['calib_df_calibrated'][colname] = step['ensembles_calib']['weighted_average']
    genetic['test_df_calibrated'][colname] = step['ensembles_test']['weighted_average']
   
#EnsembleList.append(linear)
#EnsembleList.append(rf)
# Adding placeholder data for the old calibration method for the weighted ensemble

# The *_cal_expand keys point at the same DataFrame objects as the *_calibrated keys (no copy).
genetic['calib_df_cal_expand'] = genetic['calib_df_calibrated']
genetic['test_df_cal_expand'] = genetic['test_df_calibrated']
# NOTE(review): re-running this cell appends another 'ensemble_genetic' entry,
# which shifts what EnsembleList[:-1] and EnsembleList[-1] refer to elsewhere.
EnsembleList.append(genetic)

In [None]:
# Save ensemble predictions to the prediction store under run_id, using the
# names '<level>_<modelname>_calib' and '<level>_<modelname>_test'.
# Existing entries with the same name are overwritten.
# NOTE(review): `level` is not assigned in this cell; confirm it is defined
# before this point when the notebook is run top to bottom.
IncludeFuture = False
predstore_calib = level +  '_' + genetic['modelname'] + '_calib'
genetic['calib_df_calibrated'].forecasts.set_run(run_id)
genetic['calib_df_calibrated'].forecasts.to_store(name=predstore_calib, overwrite = True)
predstore_test = level +  '_' + genetic['modelname'] + '_test'
genetic['test_df_calibrated'].forecasts.set_run(run_id)
genetic['test_df_calibrated'].forecasts.to_store(name=predstore_test, overwrite = True)
# NOTE(review): the genetic dict built in this section has no
# 'future_df_calibrated' key; it must be added before enabling IncludeFuture.
if IncludeFuture:
    predstore_future = level +  '_' + genetic['modelname'] + '_future'
    genetic['future_df_calibrated'].forecasts.set_run(run_id)
    genetic['future_df_calibrated'].forecasts.to_store(name=predstore_future, overwrite = True)

In [None]:
# Check they are there
ViewsMetadata().mine().with_name('genetic').fetch()

In [None]:
# Saving revised model list, stripped down, in dropbox.
# Each entry keeps only modelname, algorithm, depvar, queryset and
# calibration_gam. Markov and ensemble models get an empty algorithm string and
# an empty calibration_gam list; ensemble models also get an empty queryset.
dbpath = '/Users/havardhegre/Dropbox (ViEWS)/ViEWS/Projects/PredictingFatalities/Predictions/'
db_EnsembleList = []
for model in EnsembleList:
    print(model['modelname'])
    if 'Markov' in model['modelname']:
        db_dict = {
            'modelname':       model['modelname'],
            'algorithm':       '',
            'depvar':          model['depvar'],
            'queryset':        model['queryset'],
            'calibration_gam': []
        }
    elif 'ensemble' in model['modelname']:
        db_dict = {
            'modelname':       model['modelname'],
            'algorithm':       '',
            'depvar':          model['depvar'],
            'queryset':        '',
            'calibration_gam': []
        }
    else:
        db_dict = {
            'modelname':       model['modelname'],
            'algorithm':       str(model['algorithm']),
            'depvar':          model['depvar'],
            'queryset':        model['queryset'],
            'calibration_gam': model['calibration_gam']
        }
    db_EnsembleList.append(db_dict)
picklename = dbpath + 'ModelList_calibrated_db_cm_' + run_id + '.p'
#pkl.dump(ModelList, open(picklename, "wb" ) )
# Use a context manager so the file is flushed and closed even if dumping fails.
with open(picklename, "wb") as picklefile:
    pkl.dump(db_EnsembleList, picklefile)

In [None]:
EnsembleList[1]['calibration_gam']

# Correlation of predictions

In [None]:
# Calculate 
meancorr_df = pd.DataFrame(StepEnsembles[0]['df_calib'].corr().mean(axis=1))
for step in [2,5,11,23,35]:
    colname = 'step_' + str(step+1)
    meancorr_df[colname] = StepEnsembles[step]['df_calib'].corr().mean(axis=1)
meancorr_df['average'] = meancorr_df.mean(axis=1)
meancorr_df
# Save the corr dfs
dflist = [
    (meancorr_df,'meancorr_df'), 
]

path = '/Users/havardhegre/Dropbox (ViEWS)/ViEWS/Projects/PredictingFatalities/MSEs/'
for df in dflist:
    filename = path + df[1] + '.csv'
    df[0].to_csv(filename)
    

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
level = 'cm'

cols_to_see = ['lndepvar', 'fat_hh20_xgb', 'fat_hh20_hurdle_xgb','fat_hh20_rf','fat_hh20_xgbrf']

sns.set_context("notebook")
sns.set() # Setting seaborn as default style even if use only matplotlib
path = '/Users/havardhegre/Dropbox (ViEWS)/ViEWS/Projects/PredictingFatalities/PredictionPlots/'
overleafpath = '/Users/havardhegre/Dropbox (ViEWS)/Apps/Overleaf/ViEWS predicting fatalities/Figures/'
cm = "YlGnBu" 
#cm = 'mako'
#cm = 'rocket'
# One panel per selected step: correlation matrix of the calibration-partition
# frame for that step. The second tuple element is the position in StepEnsembles.
fig, ((ax1, ax3), (ax6,ax12), (ax24,ax36)) = plt.subplots(3, 2, figsize=(30,30))
for subplot in [(ax1,0),(ax3,2),(ax6,5),(ax12,11),(ax24,23),(ax36,35)]:
    subplot[0].set_box_aspect(1)
    sns.heatmap(StepEnsembles[subplot[1]]['df_calib'].corr(), ax=subplot[0], cmap=cm) 
    subplot_title=('step ' + str(subplot[1]+1))
    # set_title must be called; assigning to it replaces the method and draws no title.
    subplot[0].set_title(subplot_title)

plt.tight_layout()
# Save before showing so the file is written from the fully drawn figure.
figname = overleafpath + 'Correlations/PredictionCorrelations_calib_' + level + '.png'
fig.savefig(figname, dpi=300)
plt.show()

In [None]:
corrdf = StepEnsembles[0]['df_test'].corr()
corrdf


In [None]:
meancorr_df.head()

# Create sc predictions and country prediction tables

In [None]:
EnsembleList[0]['test_df_calibrated'].describe()

In [None]:
# Construct a set of step-combined series starting from a series of start months
# -- to see how predictions react to events at different calendar times.
# For every model this adds:
#   model['sc_df']        outcome column plus one 'sc_<month>' column per start month
#   model['sc_df_smooth'] see NOTE(review) below
#   model['CountryList']  one dict per country id 0..249 with per-country frames
last_in_training = 444
last_in_test = 492
# t_range specifies the duration of the step-combined series to construct
t_range = range(0, 48)
# NOTE(review): range(1,36) covers steps 1..35 only; confirm step 36 is meant to be left out.
step_range = range(1,36)

for model in EnsembleList:
    print(model['modelname'])
    df = model['test_df_calibrated']
    model['sc_df'] = pd.DataFrame(df['ln_ged_sb_dep'])
    # First index level, used as the month id (the original code selected rows with .loc[<month>, :]).
    month_ids = df.index.get_level_values(0)
    for month in t_range:
        # Create a column to hold predictions starting from a given "last observed" month
        col = 'sc_' + str(last_in_training + month)
#        print(col)
        model['sc_df'][col] = np.nan  # np.NaN no longer exists in NumPy 2.0
        for step in step_range:
            target_month = last_in_training + month + step
            if target_month <= last_in_test: # To avoid generating series beyond last month in partition
                predcol = 'step_pred_' + str(step) # The column in the in-df that contains predictions for this step
#                print('For month', last_in_training + month + step, 'use', predcol, last_in_training + step + month)
                # Single .loc assignment on the frame (no chained indexing), so
                # the write always lands in model['sc_df']. sc_df shares df's
                # index, so the same row mask applies to both. A month missing
                # from the index simply leaves the column untouched.
                in_target_month = month_ids == target_month
                model['sc_df'].loc[in_target_month, col] = df.loc[in_target_month, predcol].values
    # NOTE(review): this stores a GroupBy object (not a frame), and the rolling
    # mean is taken over the whole frame before grouping -- confirm the intent.
    model['sc_df_smooth']=model['sc_df'].rolling(3,center=True).mean().groupby(level=1)   
    # Sub-list of predictions by country 

    model['CountryList'] = []
    countries = model['test_df_calibrated'].index.unique(level='country_id').tolist()
    for cnt in range(250):
        # Position in CountryList equals the country id; ids without data keep only these two keys.
        cntdict = {
        'country_id': cnt,
        'country_name': ''
        }
        if cnt in countries:
            cntdict['test_df_calibrated'] = model['test_df_calibrated'].xs(cnt, level='country_id').copy()
            cntdict['sc_df'] = model['sc_df'].xs(cnt, level='country_id')
            # Centered 3-month rolling mean of the sc columns; where it is
            # NaN the unsmoothed values are used instead.
            cntdict['sc_df_smooth'] = cntdict['sc_df'].copy()
            predcols = cntdict['sc_df_smooth'].columns[1:]
            cntdict['sc_df_smooth'][predcols] = cntdict['sc_df'][predcols].rolling(3,center=True).mean()
            cntdict['sc_df_smooth'] = cntdict['sc_df_smooth'].fillna(cntdict['sc_df'])
        model['CountryList'].append(cntdict)
    

In [None]:
print(last_in_training, step, month)

In [None]:
step = 1
month = 1
col = 'sc_445'
df[predcol].loc[last_in_training + step + month,:]

In [None]:
model['sc_df'][col].loc[[last_in_training + month + step], :] = df[predcol].loc[last_in_training + step + month,:].values

In [None]:
model['sc_df'][col].loc[[last_in_training + month + step], :]

In [None]:
EnsembleList[0]['test_df_calibrated'].head()

In [None]:
cnt = 246
EnsembleList[0]['sc_df'].xs(cnt, level='country_id').head(20)

In [None]:
# Select models from this data structure:
modelname = 'baseline_rf'
cid = 50
step_to_inspect = 'step_pred_1'

thismodel = [mod for mod in EnsembleList if mod['modelname'] == modelname][0]
print(thismodel['test_df_calibrated'][step_to_inspect].describe())

# Select country predictions for model:

thiscm = [cnt for cnt in thismodel['CountryList'] if cnt['country_id'] == cid][0]
print(thiscm['test_df_calibrated'][step_to_inspect].describe())

thiscm2 = thismodel['CountryList'][cid]
print(thiscm2['test_df_calibrated'][step_to_inspect].describe())

In [None]:
# Inspect calibration dfs for the models in the ensemble:
cols = ['ln_ged_sb_dep','step_pred_1','step_pred_12','step_pred_24','step_pred_36']
pti = [.75,.90,.95,.99]
for m in EnsembleList:
    print(m['modelname'])
    print('Before calibration')
    print(m['predictions_calib_df'][cols].describe(percentiles=pti))
    print(m['predictions_test_df'][cols].describe(percentiles=pti))
#    print('Expanded parameter:',m['expanded'],'shiftsize parameter:',m['shiftsize'])
    print('After calibration, calibration partition (top), test partition (bottom)')
    print(m['calib_df_calibrated'][cols].describe(percentiles=pti))
    print(m['test_df_calibrated'][cols].describe(percentiles=pti))

## Identify escalation and de-escalation periods based on sc for lndepvar

In [None]:
print(EnsembleList[-1]['modelname'])
# Same object as the last model's sc_df: the columns below are added to it in place.
sc_df = EnsembleList[-1]['sc_df']

# Rolling means of the outcome, computed separately for each value of index
# level 1, requiring a full window (3, 6 or 10 observations).
sc_df['rolldep_3'] =  sc_df.groupby(level=1)['ln_ged_sb_dep'].transform(lambda x: x.rolling(3, 3).mean())
sc_df['rolldep_6'] =  sc_df.groupby(level=1)['ln_ged_sb_dep'].transform(lambda x: x.rolling(6, 6).mean())
sc_df['rolldep_10'] =  sc_df.groupby(level=1)['ln_ged_sb_dep'].transform(lambda x: x.rolling(10, 10).mean())
sc_df['rolldep_3_lag'] =  sc_df.groupby(level=1)['rolldep_3'].shift(1)
sc_df['rolldep_10_lag'] =  sc_df.groupby(level=1)['rolldep_10'].shift(1)

# Escalation is flagged (1) when either
#   - the monthly value exceeds last month's 3-month mean by more than escalation_threshold, or
#   - the 3-month mean exceeds last month's 10-month mean by more than escalation_threshold.
# NOTE(review): an earlier version of this comment described a threshold of
# 2.398 (appr. log(10)) and a comparison against the 10-month mean with a
# "monthly > 0" condition; the code uses 1 and the rules above -- confirm which is intended.
escalation_threshold = 1
sc_df['escalation'] = np.nan
# Direct .loc assignments on the frame instead of chained indexing, so the
# writes are guaranteed to reach sc_df.
sc_df.loc[457:, 'escalation'] = 0
sc_df.loc[sc_df['ln_ged_sb_dep'] > sc_df['rolldep_3_lag'] + escalation_threshold, 'escalation'] = 1
sc_df.loc[sc_df['rolldep_3'] > sc_df['rolldep_10_lag'] + (escalation_threshold), 'escalation'] = 1

cols_to_inspect = ['ln_ged_sb_dep','rolldep_3','rolldep_6','rolldep_10','escalation']
# 57: Ethiopia
# 47 Burkina Faso
# 50 Mali
# 79 Nigeria
cnt= 79
sc_df[cols_to_inspect].xs(cnt, level='country_id')

In [None]:
# Plot the observed outcome together with the escalation flag for country `cnt`.
# The outcome column in sc_df is 'ln_ged_sb_dep'; there is no 'lndepvar' column.
cols_to_inspect = ['ln_ged_sb_dep','escalation']
sc_df[cols_to_inspect].xs(cnt, level='country_id').plot()

## Evaluation of constituent models

In [None]:
model['calib_df_calibrated'].describe()

In [None]:
# Evaluation of constituent models
calculate_grpMSEs = False

MSE_calib_all = []
MSE_calib_zeros = []
MSE_calib_nonzeros = []
if calculate_grpMSEs:
    MSE_calib_grp0 = []
    MSE_calib_grp1 = []
    MSE_calib_grp2 = []
    MSE_calib_grp3 = []
MSE_test_all = []
MSE_test_zeros = []
MSE_test_nonzeros = []
MSE_test_exp_all = []
MSE_test_exp_zeros = []
MSE_test_exp_nonzeros = []

for model in EnsembleList:
    calib_all_line = [model['modelname']]
    calib_zeros_line = [model['modelname']]
    calib_nonzeros_line = [model['modelname']]
    test_all_line = [model['modelname']]
    test_zeros_line = [model['modelname']]
    test_nonzeros_line = [model['modelname']]
    test_exp_all_line = [model['modelname']]
    test_exp_zeros_line = [model['modelname']]
    test_exp_nonzeros_line = [model['modelname']]
    print(model['modelname'])
    model['mse_calib'] = []
    model['mse_calib_zeros'] = []
    model['mse_calib_nonzeros'] = []
    model['mse_test'] = []
    model['mse_test_zeros'] = []
    model['mse_test_nonzeros'] = []
    model['mse_test_exp'] = []
    model['mse_test_exp_zeros'] = []
    model['mse_test_exp_nonzeros'] = []
    if calculate_grpMSEs:
        calib_grp0_line = [model['modelname']]
        calib_grp1_line = [model['modelname']]
        calib_grp2_line = [model['modelname']]
        calib_grp3_line = [model['modelname']]
        model['mse_calib_grp0'] = []
        model['mse_calib_grp1'] = []
        model['mse_calib_grp2'] = []
        model['mse_calib_grp3'] = []
    for cnt in model['CountryList']:
        if cnt['country_id'] in countries:
            cnt['mse'] = []
    for col in stepcols[1:]:
        # Remove from evaluation rows where [col] has infinite values (due to the 2011 split of Sudan)
        df_calib = model['calib_df_calibrated'][~np.isinf(model['calib_df_calibrated'][col])]
        df_test = model['test_df_calibrated'][~np.isinf(model['test_df_calibrated'][col])]
        df_test_exp = model['test_df_cal_expand'][~np.isinf(model['test_df_cal_expand'][col])]

        mse_calib = mean_squared_error(df_calib[col], df_calib['ln_ged_sb_dep'])
        model['mse_calib'].append(mse_calib)
        calib_all_line.append(mse_calib)
        
        mse_calib_zeros = mean_squared_error(df_calib[col].loc[df_calib['ln_ged_sb_dep'] == 0], df_calib['ln_ged_sb_dep'].loc[df_calib['ln_ged_sb_dep'] == 0])
        model['mse_calib_zeros'].append(mse_calib_zeros)
        calib_zeros_line.append(mse_calib_zeros)
        
        mse_calib_nonzeros = mean_squared_error(df_calib[col].loc[df_calib['ln_ged_sb_dep'] > 0], df_calib['ln_ged_sb_dep'].loc[df_calib['ln_ged_sb_dep'] > 0])
        model['mse_calib_nonzeros'].append(mse_calib_nonzeros)
        calib_nonzeros_line.append(mse_calib_nonzeros)
        
        
        if calculate_grpMSEs:
            # MSE for groups of cases based on baseline model predictions:
            # Group 0
            df_calib_grp0 = df_calib[onset_mask_calib[col]=='grp0']
            mse_calib_grp0 = mean_squared_error(df_calib_grp0[col], df_calib_grp0['ln_ged_sb_dep'])
            model['mse_calib_grp0'].append(mse_calib_grp0)
            calib_grp0_line.append(mse_calib_grp0)

            # Group 1
            df_calib_grp1 = df_calib[onset_mask_calib[col]=='grp1']
            mse_calib_grp1 = mean_squared_error(df_calib_grp1[col], df_calib_grp1['ln_ged_sb_dep'])
            model['mse_calib_grp1'].append(mse_calib_grp1)
            calib_grp1_line.append(mse_calib_grp1)

            # Group 2
            df_calib_grp2 = df_calib[onset_mask_calib[col]=='grp2']
            mse_calib_grp2 = mean_squared_error(df_calib_grp2[col], df_calib_grp2['ln_ged_sb_dep'])
            model['mse_calib_grp2'].append(mse_calib_grp2)
            calib_grp2_line.append(mse_calib_grp2)

            # Group 3
            df_calib_grp3 = df_calib[onset_mask_calib[col]=='grp3']
            mse_calib_grp3 = mean_squared_error(df_calib_grp3[col], df_calib_grp3['ln_ged_sb_dep'])
            model['mse_calib_grp3'].append(mse_calib_grp3)
            calib_grp3_line.append(mse_calib_grp3)
        
        
#        mse_test = mean_squared_error(model['predictions_test_df'][col], model['predictions_test_df']['ln_ged_sb_dep'])
        mse_test = mean_squared_error(df_test[col], target['y_test'])
        model['mse_test'].append(mse_test)
        test_all_line.append(mse_test)
        
        mse_zeros = mean_squared_error(df_test[col].loc[df_test['ln_ged_sb_dep'] == 0], df_test['ln_ged_sb_dep'].loc[df_test['ln_ged_sb_dep'] == 0])
        model['mse_test_zeros'].append(mse_zeros)
        test_zeros_line.append(mse_zeros)
        
        mse_nonzeros = mean_squared_error(df_test[col].loc[df_test['ln_ged_sb_dep'] > 0], df_test['ln_ged_sb_dep'].loc[df_test['ln_ged_sb_dep'] > 0])
        model['mse_test_nonzeros'].append(mse_nonzeros)
        test_nonzeros_line.append(mse_nonzeros)

        mse_test_exp = mean_squared_error(df_test_exp[col], target['y_test'])
        model['mse_test_exp'].append(mse_test_exp)
        test_exp_all_line.append(mse_test_exp)
        
        mse_exp_zeros = mean_squared_error(df_test_exp[col].loc[df_test_exp['ln_ged_sb_dep'] == 0], df_test_exp['ln_ged_sb_dep'].loc[df_test_exp['ln_ged_sb_dep'] == 0])
        model['mse_test_exp_zeros'].append(mse_exp_zeros)
        test_exp_zeros_line.append(mse_exp_zeros)
        
        mse_exp_nonzeros = mean_squared_error(df_test_exp[col].loc[df_test_exp['ln_ged_sb_dep'] > 0], df_test_exp['ln_ged_sb_dep'].loc[df_test_exp['ln_ged_sb_dep'] > 0])
        model['mse_test_exp_nonzeros'].append(mse_exp_nonzeros)
        test_exp_nonzeros_line.append(mse_exp_nonzeros)


        # Country-level MSE for the current prediction column.
        # The per-country frame gets its own name so that it does not rebind
        # `df_test`, which the partition-wide MSE computations above read.
        countries = model['test_df_calibrated'].index.unique(level='country_id').tolist()
        for cnt in model['CountryList']:
            if cnt['country_id'] in countries:
                cnt_test_df = cnt['test_df_calibrated']
                # Drop rows where the prediction in this column is infinite
                cnt_test_df = cnt_test_df[~np.isinf(cnt_test_df[col])]
                cnt_mse = mean_squared_error(cnt_test_df[col], cnt_test_df['ln_ged_sb_dep'])
                cnt['mse'].append(cnt_mse)
        
    MSE_calib_all.append(calib_all_line)
    MSE_calib_zeros.append(calib_zeros_line)
    MSE_calib_nonzeros.append(calib_nonzeros_line)
    if calculate_grpMSEs:
        MSE_calib_grp0.append(calib_grp0_line)
        MSE_calib_grp1.append(calib_grp1_line)
        MSE_calib_grp2.append(calib_grp2_line)
        MSE_calib_grp3.append(calib_grp3_line)
    MSE_test_all.append(test_all_line)
    MSE_test_zeros.append(test_zeros_line)
    MSE_test_nonzeros.append(test_nonzeros_line)
    MSE_test_exp_all.append(test_exp_all_line)
    MSE_test_exp_zeros.append(test_exp_zeros_line)
    MSE_test_exp_nonzeros.append(test_exp_nonzeros_line)
    
# Turn the collected MSE rows into one dataframe per metric, with the
# entries of stepcols as column names.

# Calibration partition: these three frames are indexed by 'ln_ged_sb_dep'.
MSE_calib_all_df = pd.DataFrame(MSE_calib_all, columns=stepcols).set_index('ln_ged_sb_dep')
MSE_calib_zeros_df = pd.DataFrame(MSE_calib_zeros, columns=stepcols).set_index('ln_ged_sb_dep')
MSE_calib_nonzeros_df = pd.DataFrame(MSE_calib_nonzeros, columns=stepcols).set_index('ln_ged_sb_dep')

# Group-wise calibration frames are only built when group MSEs were requested.
if calculate_grpMSEs:
    MSE_calib_grp0_df, MSE_calib_grp1_df, MSE_calib_grp2_df, MSE_calib_grp3_df = [
        pd.DataFrame(grp_rows, columns=stepcols)
        for grp_rows in (MSE_calib_grp0, MSE_calib_grp1, MSE_calib_grp2, MSE_calib_grp3)
    ]

# Test partition, on the log scale and on the exponentiated scale.
MSE_test_all_df, MSE_test_zeros_df, MSE_test_nonzeros_df = [
    pd.DataFrame(test_rows, columns=stepcols)
    for test_rows in (MSE_test_all, MSE_test_zeros, MSE_test_nonzeros)
]
MSE_test_exp_all_df, MSE_test_exp_zeros_df, MSE_test_exp_nonzeros_df = [
    pd.DataFrame(exp_rows, columns=stepcols)
    for exp_rows in (MSE_test_exp_all, MSE_test_exp_zeros, MSE_test_exp_nonzeros)
]

print('All models done')

In [None]:
# Write each MSE table to its own CSV file; the file name is the name of
# the dataframe variable.
path = '/Users/havardhegre/Dropbox (ViEWS)/ViEWS/Projects/PredictingFatalities/MSEs/'

dflist = [
    # calibration partition
    (MSE_calib_all_df, 'MSE_calib_all_df'),
    (MSE_calib_zeros_df, 'MSE_calib_zeros_df'),
    (MSE_calib_nonzeros_df, 'MSE_calib_nonzeros_df'),
    # test partition
    (MSE_test_all_df, 'MSE_test_all_df'),
    (MSE_test_zeros_df, 'MSE_test_zeros_df'),
    (MSE_test_nonzeros_df, 'MSE_test_nonzeros_df'),
    # test partition, exponentiated
    (MSE_test_exp_all_df, 'MSE_test_exp_all_df'),
    (MSE_test_exp_zeros_df, 'MSE_test_exp_zeros_df'),
    (MSE_test_exp_nonzeros_df, 'MSE_test_exp_nonzeros_df'),
]

for mse_frame, frame_name in dflist:
    filename = path + frame_name + '.csv'
    mse_frame.to_csv(filename)
    

## List global MSEs

In [None]:
# For every model, print its name followed by the stored MSE lists.
mse_report_fields = (
    ('MSE calibration partition:', 'mse_calib'),
    ('MSE test partition:', 'mse_test'),
    ('MSE test partition, zeros:', 'mse_test_zeros'),
    ('MSE test partition, non-zeros:', 'mse_test_nonzeros'),
)
for model in EnsembleList:
    print(model['modelname'])
    for caption, mse_key in mse_report_fields:
        print(caption, model[mse_key])

# Plotting performance as heatmaps

## Ablation MSE

In [None]:
# Heatmap of ablation MSEs
import seaborn as sns

# Frame holding the ablation MSEs: one row per entry of the first model's
# 'Ablation_MSE' list, one column per name in mlist, initialised to 0.0.
n_ablation_rows = len(EnsembleList[0]['Ablation_MSE'])
abl_df = pd.DataFrame(0.0, index=np.arange(n_ablation_rows), columns=mlist)

# The last two entries of EnsembleList are left out
# (assuming the ablated ensemble exists).
for model in EnsembleList[:-2]:
    abl_df[model['modelname']] = model['Ablation_MSE']

# Drop the first row, then transpose so that models are rows and the
# remaining entries, labelled with stepcols[1:], are columns.
abl_df = abl_df.iloc[1:].T
abl_df.columns = stepcols[1:]

palette = 'vlag'
fig, ax = plt.subplots(1, figsize=(16, 11))
ax = sns.heatmap(
    abl_df,
    center=0,
    xticklabels=2,
    linewidths=.5,
    cmap=palette,
    square=True,
)

overleafpath = '/Users/havardhegre/Dropbox (ViEWS)/Apps/Overleaf/ViEWS predicting fatalities/Figures/Pred_Eval/'
filename = overleafpath + 'Ablation_MSEs.png'
plt.savefig(filename, dpi=300)


## Dataframes with country-specific MSEs

In [None]:
# Build a dataframe of country-level MSEs for each model and show it with matshow.
# Countries_to_plot holds the country ids to include; Namelist holds the row
# label for each of them, in the same order (a name where one was filled in,
# otherwise the id as a string).
Countries_to_plot = [41,42,43,47,48,49,50,52,53,54,55,56,60,62,67,69,70,124,155,156,157,162]
Namelist = ['Cote dIvoire','Ghana','Liberia','47','48','49','50',
            '52','53','54','55','56','60','62','67',
            '69','70','124','155','156','157','162']


for model in EnsembleList:
    # Look the entries up by their 'country_id' field. Indexing CountryList
    # with the id itself selects a list position, which need not hold the
    # country with that id.
    country_by_id = {entry['country_id']: entry for entry in model['CountryList']}
    listdata = [[country_id] + country_by_id[country_id]['mse']
                for country_id in Countries_to_plot]
    colnames = ['Country'] + stepcols[1:]
    model['CC_MSEs'] = pd.DataFrame(listdata, columns=colnames)

    plt.matshow(model['CC_MSEs'][stepcols[1:]])
    cb = plt.colorbar()
    cb.ax.tick_params(labelsize=14)
    # matshow draws column j at x position j (0-based), so the tick positions
    # are 0..len(steps)-1 and the step numbers are used as labels.
    # assumes one column per entry of `steps` — TODO confirm
    plt.xticks(range(len(steps)), steps)
    # One row per plotted country, labelled from Namelist
    plt.yticks(range(len(Countries_to_plot)), Namelist)
    plt.title(model['modelname'])

    plt.show()
    
    


In [None]:
# Show the name of the model stored at position 24 of EnsembleList
model_at_24 = EnsembleList[24]
print(model_at_24['modelname'])

In [None]:
# Plotting prediction vs actuals: a plx-by-ply grid of scatter plots, one
# panel per model, observed values on x and step-specific predictions on y.
#ModelSelection=[0,2,3,6,8,9,13,17:24,26:27,31,35, 37:38]


plx = 2
ply = 2
fig, axs = plt.subplots(plx, ply, sharey=True,sharex=True,figsize=(13,12))
# Tick positions on the log1p scale, labelled with the corresponding raw counts
log_scale_value = np.array([np.log1p(0), np.log1p(1), np.log1p(10), np.log1p(100), np.log1p(1000), np.log1p(10000)])
log_scale_naming = ['0','1', '10', '100', '1000','10000']

month = [487,488,489,490,491,492]
step = 13
predvar = 'step_pred_' + str(step)
size = 20

model_data = []
# NOTE(review): the same model (index 47) is drawn in all four panels —
# presumably a placeholder for a selection of models; confirm.
models_to_plot = [EnsembleList[47],EnsembleList[47],EnsembleList[47],EnsembleList[47]]
for panel_no, model in enumerate(models_to_plot):
    # Panels are filled column by column: (0,0), (1,0), (0,1), (1,1)
    fpy, fpx = divmod(panel_no, plx)
    panel_ax = axs[fpx,fpy]
    test_df = model['test_df_calibrated']

    print(model['modelname'], fpx, fpy)
    print('Prediction mean: ', test_df[predvar].mean())
    panel_ax.scatter(test_df['ln_ged_sb_dep'].loc[month], test_df[predvar].loc[month], s=size, alpha=0.5)
    panel_ax.set_ylabel(model['modelname'], fontsize=10)
    panel_ax.set_xlabel('Actually observed', fontsize=10)
    # Ticks are set on the panel itself: plt.xticks / plt.yticks only act on
    # the current axes, not on the panel being filled.
    panel_ax.set_xticks(log_scale_value)
    panel_ax.set_xticklabels(log_scale_naming, rotation=30)
    panel_ax.set_yticks(log_scale_value)
    panel_ax.set_yticklabels(log_scale_naming, rotation=30)
    panel_ax.grid(True)

fig.tight_layout()

#plt.show()

overleafpath = '/Users/havardhegre/Dropbox (ViEWS)/Apps/Overleaf/ViEWS predicting fatalities/Figures/PredictionPlots/'
filename = overleafpath + 'PredictionVsActuals_cm_s' + str(step) + '.png'
plt.savefig(filename, dpi=300)

In [None]:
# Display the 'test_df_calibrated' frame of `model`.
# `model` is not assigned in this cell; it is whatever a previously run cell
# last bound to that name.
model['test_df_calibrated']

In [None]:
# Identify the outliers: rows with more than 1000 observed fatalities
# (compared on the log1p scale), restricted to the months in `month`.
df = model['test_df_calibrated']
cols = ['ln_ged_sb_dep','step_pred_2']
above_1000 = df['ln_ged_sb_dep'] > np.log1p(1000)
print(df.loc[above_1000, cols].loc[month].head(20))
# 57: Ethiopia
# 120: Somalia
# 126: Azerbaijan (Nagorno-Karabakh)
# 133: Afghanistan
# 167: DRC
# 220: Mali

# Reference table: log1p value next to the raw count it corresponds to
for raw_count in (10, 25, 100, 1000):
    i = np.log1p(raw_count)
    print('exp(', str(i),')', np.rint(np.expm1(i)))

In [None]:
# Mapping the predictions
# Reads the country geometries from the database, joins them onto each
# model's calibrated test predictions for the selected month, and then builds
# a Mapper figure with one layer and a colorbar.
engine = sa.create_engine(source_db_path)
gdf = gpd.GeoDataFrame.from_postgis(
    "SELECT id as country_id, in_africa, in_me, geom FROM prod.country", 
    engine, 
    geom_col='geom'
)


# These rebind the notebook-level names month / step / predvar.
# predvar is not used further down in this cell.
month = [480]
step = 6
predvar = 'step_pred_' + str(step)

# NOTE(review): gdf is rebound at the end of every pass, so from the second
# model on, gdf.set_index("country_id") is applied to the previous model's
# joined frame instead of the country table read above — looks unintended;
# confirm, and consider a separate name for the joined frame.
for model in EnsembleList:
    print(model['modelname'])
    monthdata = model['test_df_calibrated'].loc[month]
    data = monthdata.join(gdf.set_index("country_id"))
    gdf = gpd.GeoDataFrame(data, geometry="geom")

# From Malika
# NOTE(review): the map reads index value 457 and column "ln_ged_sb", while
# the loop above selected month 480 — confirm which month is intended and
# that the joined frame has this column.
sb_457 = mapper.Mapper(
    width=10,
    height=10,
    frame_on=True,
    title="State based fatalities, Jan 2018",
    bbox=[-18.5, 64.0, -35.5, 43.0], 
).add_layer(
    gdf=gdf.loc[457],
    cmap="viridis",
    edgecolor="black",
    linewidth=0.5,
    column="ln_ged_sb"
).add_colorbar(
cmap = 'viridis', 
vmin=gdf.loc[457]['ln_ged_sb'].min(), 
vmax=gdf.loc[457]['ln_ged_sb'].max()
)

# `scale` and `labels` are not defined in this cell; they must already exist
# in the kernel when it runs.
sb_457.cbar.set_ticks(scale)
sb_457.cbar.set_ticklabels(labels)

In [None]:
# Print the second index entry of the first row of every level-1 group of
# sc_df (stored as country_id), then store a smoothed copy of sc_df.
for _, cnt_df in model['sc_df'].groupby(level=1):
    country_id = cnt_df.index.values[0][1]
    print(country_id)

# 3-row centred rolling mean. The result does not depend on the group being
# visited, so it is computed once here rather than once per group.
# NOTE(review): the window runs over the whole frame, i.e. also across group
# boundaries — confirm whether smoothing within each group was intended.
model['sc_df_smooth'] = model['sc_df'].rolling(3, center=True).mean()

In [None]:
    # Make a smoothed version: 3-row centred rolling mean computed separately
    # within each level-1 group, so windows never span two groups.
    # transform keeps the original index, so 'sc_df_smooth' is a frame with
    # the same shape as 'sc_df' (not a GroupBy object).
    model['sc_df_smooth'] = model['sc_df'].groupby(level=1).transform(
        lambda column: column.rolling(3, center=True).mean()
    )

In [None]:
# Country lookup table built from literal rows. No column names are assigned,
# so the columns are numbered 0-24; every row carries 25 positional fields
# (the row for id 75 previously had only 24, leaving a missing value in the
# last column).
# NOTE(review): several text fields look garbled (e.g. 'HaitI', 'cameroon',
# 'cOL', and names containing a stray apostrophe such as
# "Micronesia', Federated States of") — verify against the original source.
CountryData = pd.DataFrame([
    [1,'Guyana',211982.004988,'Georgetown',-58.2,6.8,0,110,1966,5,26,2016,6,30,110,1966,5,26,2016,6,30,'Guyana',328,'GY','GUY'
],[2,'Suriname',145952.274029,'Paramaribo',-55.2,5.833333,1,115,1975,11,25,2016,6,30,115,1975,11,25,2016,6,30,'Suriname',740,'SR','SUR'
],[3,'Trinidad and Tobago',5041.72895193,'Port-of-Spain',-61.5,10.65,2,52,1962,8,31,2016,6,30,52,1962,8,31,2016,6,30,'Trinidad and Tobago',780,'TT','TTO'
],[4,'Venezuela',916782.217193,'Caracas',-66.9,10.5,3,101,1946,1,1,2016,6,30,101,1946,1,1,2016,6,30,'Venezuela',862,'VE','VEN'
],[5,'Samoa',2955.21236649,'Apia',-172,-13.8,4,990,1976,12,15,2016,6,30,990,1962,1,1,2016,6,30,'Samoa',882,'WS','WSM'
],[6,'Tonga',464.74733142,'Nukualofa',-175,-21.1,5,955,1999,9,14,2016,6,30,972,1970,6,4,2016,6,30,'Tonga',776,'TO','TON'
],[7,'Argentina',2787442.09772,'Buenos Aires',-58.7,-34.6,6,160,1946,1,1,2016,6,30,160,1946,1,1,2016,6,30,'Argentina',32,'AR','ARG'
],[8,'Bolivia',1092697.43562,'La Paz',-68.2,-16.5,7,145,1946,1,1,2016,6,30,145,1946,1,1,2016,6,30,'Bolivia',68,'BO','BOL'
],[9,'Brazil',8523619.57145,'Brasilia',-47.9,-15.8,8,140,1960,4,21,2016,6,30,140,1960,4,21,2016,6,30,'Brazil',76,'BR','BRA'
],[10,'Chile',745808.493557,'Santiago',-70.7,-33.5,9,155,1946,1,1,2016,6,30,155,1946,1,1,2016,6,30,'Chile',152,'CL','CHL'
],[11,'Ecuador',257026.690766,'Quito',-78.5,-0.2166667,10,130,1946,1,1,2016,6,30,130,1946,1,1,2016,6,30,'Ecuador',218,'EC','ECU'
],[12,'Paraguay',400653.400582,'Asuncion',-57.7,-25.3,11,150,1946,1,1,2016,6,30,150,1946,1,1,2016,6,30,'Paraguay',600,'Py','PRY'
],[13,'Peru',1299028.07763,'Lima',-77,-12.1,12,135,1946,1,1,2016,6,30,135,1946,1,1,2016,6,30,'Peru',604,'Pe','PER'
],[14,'Uruguay',178357.24144,'Montevideo',-56.2,-34.9,13,165,1946,1,1,2016,6,30,165,1946,1,1,2016,6,30,'Uruguay',858,'Uy','URY'
],[15,'Canada',9526178.12487,'Ottawa',-75.7,45.42083,14,20,1946,1,1,1948,6,30,20,1946,1,1,1948,6,30,'',0,'',''
],[27,'Belize',22208.8585502,'Belmopan',-88.8,17.25,26,80,1981,9,21,2016,6,30,80,1981,9,21,2016,6,30,'Belize',84,'BZ','BLZ'
],[28,'Colombia',1142725.71315,'Bogota',-74.1,4.6,27,100,1946,1,1,2016,6,30,100,1946,1,1,2016,6,30,'colombia',170,'co','cOL'
],[16,'Guatemala',109653.931118,'Guatemala City',-90.5,14.62111,15,90,1946,1,1,2016,6,30,90,1946,1,1,2016,6,30,'Guatemala',320,'GT','GTM'
],[17,'Mexico',1965660.69014,'Mexico City',-99.1,19.43417,16,70,1946,1,1,2016,6,30,70,1946,1,1,2016,6,30,'Mexico',484,'MX','MEX'
],[18,'Barbados',448.945911949,'Bridgetown',-59.6,13.1,17,53,1966,11,30,2016,6,30,53,1966,11,30,2016,6,30,'Barbados',52,'Bb','BRB'
],[19,'Dominica',774.491216087,'Roseau',-61.4,15.3,18,54,1978,11,3,2016,6,30,54,1978,11,3,2016,6,30,'Dominica',212,'DM','DMA'
],[20,'Grenada',348.318551633,'Saint Georges',-61.8,12.05,19,55,1974,2,7,2016,6,30,55,1974,2,7,2016,6,30,'Grenada',308,'GD','GRD'
],[21,'St. Lucia',639.059498815,'castries',-61,14,20,56,1979,2,22,2016,6,30,56,1979,2,22,2016,6,30,'Saint Lucia',662,'LC','LCA'
],[22,'St. Vincent and the Grenadines',344.300049243,'Kingstown',-61.2,13.16667,21,57,1979,10,27,2016,6,30,57,1979,10,27,2016,6,30,'Saint Vincent and the Grenadines',670,'VC','VCT'
],[36,'St. Kitts and Nevis',196.699148006,'Basseterre',-62.7,17.3,35,60,1983,9,19,2016,6,30,60,1983,9,19,2016,6,30,'Saint Kitts and Nevis',659,'Kn','KNA'
],[37,'Iceland',102510.041259,'Reykjavik',-21.9,64.15,36,395,1946,1,1,2016,6,30,395,1946,1,1,2016,6,30,'Iceland',352,'Is','ISL'
],[38,'Ireland',69507.08947,'Dublin',-6.248889,53.33306,37,205,1946,1,1,2016,6,30,205,1946,1,1,2016,6,30,'Ireland',372,'Ie','IRL'
],[39,'United Kingdom',243845.075909,'London',-0.1166667,51.5,38,200,1946,1,1,2016,6,30,200,1946,1,1,2016,6,30,'United Kingdom',826,'Gb','GBR'
],[23,'Dominican Republic',48625.6990635,'Santo Domingo',-69.9,18.46667,22,42,1946,1,1,2016,6,30,42,1946,1,1,2016,6,30,'Dominican Republic',214,'Do','DOM'
],[140,'Brunei',5785.18039944,'Bandar Seri Begawan',115,4.883333,141,835,1984,1,1,2016,6,30,835,1984,1,1,2016,6,30,'Brunei Darussalam',96,'Bn','BRN'
],[141,'China',9424689.74644,'Beijing',116,39.92889,142,710,1946,1,1,1949,12,7,710,1946,1,1,1949,12,7,'',0,'',''
],[142,'Japan',371862.299662,'Tokyo',140,35.685,143,740,1952,4,28,2016,6,30,740,1946,1,1,2016,6,30,'Japan',392,'JP','JPN'
],[24,'HaitI',27315.4591425,'Port-au-Prince',-72.3,18.5392,23,41,1946,1,1,2016,6,30,41,1946,1,1,2016,6,30,'HaitI',332,'Ht','HTI'
],[25,'Jamaica',11097.4197969,'Kingston',-76.8,18,24,51,1962,8,6,2016,6,30,51,1962,8,6,2016,6,30,'Jamaica',388,'JM','JAM'
],[26,'Bahamas',12192.0404721,'Nassau',-77.3,25.08333,25,31,1973,7,10,2016,6,30,31,1973,7,10,2016,6,30,'Bahamas',44,'Bs','BHS'
],[29,'Costa Rica',51402.6253926,'San Jose',-84.1,9.933333,28,94,1946,1,1,2016,6,30,94,1946,1,1,2016,6,30,'Costa Rica',188,'cr','cRI'
],[30,'Cuba',109740.925034,'Havana',-82.4,23.13194,29,40,1946,1,1,2016,6,30,40,1946,1,1,2016,6,30,'cuba',192,'cu','cUB'
],[31,'El Salvador',20692.8390883,'San Salvador',-89.2,13.70861,30,92,1946,1,1,2016,6,30,92,1946,1,1,2016,6,30,'El Salvador',222,'SV','SLV'
],[153,'Micronesia',519.642298193,'Palikir',158,6.916667,154,987,1991,9,17,2016,6,30,987,1986,11,3,2016,6,30,"Micronesia', Federated States of",583,'FM','FSM'
],[32,'Honduras',112871.255492,'Tegucigalpa',-87.2,14.1,31,91,1946,1,1,2016,6,30,91,1946,1,1,2016,6,30,'Honduras',340,'Hn','HND'
],[155,'BurundI',27367.749959,'Bujumbura',29.36,-3.376111,156,516,1962,7,1,2016,6,30,516,1962,7,1,2016,6,30,'BurundI',108,'BI','BDI'
],[33,'Nicaragua',128883.087518,'Managua',-86.3,12.15083,32,93,1946,1,1,2016,6,30,93,1946,1,1,2016,6,30,'Nicaragua',558,'NI','NIC'
],[40,'Cape Verde',4054.30324666,'Praia',-23.5,14.91667,39,402,1975,7,5,2016,6,30,402,1975,7,5,2016,6,30,'Cape Verde',132,'cV','cPV'
],[41,'Cote dIvoire',323412.228545,'Yamoussoukro',-5.283333,6.816667,40,437,1960,8,7,2016,6,30,437,1960,8,7,2016,6,30,'Cote dIvoire',384,'cI','cIV'
],[34,'Panama',74612.5675615,'Panama City',-79.5,8.966667,33,95,1946,1,1,2016,6,30,95,1946,1,1,2016,6,30,'Panama',591,'Pa','PAN'
],[35,'Antigua and Barbuda',539.862118904,'St. Johns',-61.9,17.11667,34,58,1981,11,1,2016,6,30,58,1981,11,1,2016,6,30,'Antigua and Barbuda',28,'Ag','ATG'
],[42,'Ghana',240585.519658,'Accra',-0.2166667,5.55,41,452,1957,3,6,2016,6,30,452,1957,3,6,2016,6,30,'Ghana',288,'GH','GHA'
],[156,'Rwanda',25305.6244393,'KigalI',30.06056,-1.953611,158,517,1962,7,1,2016,6,30,517,1962,1,7,2016,6,30,'Rwanda',646,'Rw','RWA'
],[157,'Zambia',756428.205434,'Lusaka',28.28333,-15.4,160,551,1964,10,24,2016,6,30,551,1964,10,24,2016,6,30,'Zambia',894,'ZM','ZMB'
],[43,'Liberia',96634.5174739,'Monrovia',-10.8,6.310555,42,450,1946,1,1,2016,6,30,450,1946,1,1,2016,6,30,'Liberia',430,'Lr','LBR'
],[53,'Sierra Leone',72952.3426516,'Freetown',-13.2,8.49,52,451,1961,4,27,2016,6,30,451,1961,4,27,2016,6,30,'Sierra Leone',694,'Sl','SLE'
],[44,'Morocco',401444.80109,'Rabat',-6.83,34.02,43,600,1956,3,2,1958,3,31,600,1956,3,2,1958,3,31,'',0,'',''
],[45,'Portugal',92026.5855825,'Lisbon',-9.133333,38.71667,44,235,1946,1,1,2016,6,30,235,1946,1,1,2016,6,30,'Portugal',620,'Pt','PRT'
],[46,'Spain',506733.603989,'Madrid',-3.683333,40.4,45,230,1946,1,1,2016,6,30,230,1946,1,1,2016,6,30,'Spain',724,'Es','ESP'
],[47,'Burkina Faso',274009.474052,'Ouagadougou',-1.524722,12.37028,46,439,1960,8,5,2016,6,30,439,1960,8,5,2016,6,30,'Burkina Faso',854,'BF','BFA'
],[76,'Equatorial Guinea',27102.4666914,'Malabo',8.783333,3.75,76,411,1968,10,12,2016,6,30,411,1968,10,12,2016,6,30,'Equatorial Guinea',226,'GQ','GNQ'
],[48,'Guinea',246594.677066,'conakry',-13.7,9.509167,47,438,1958,10,2,2016,6,30,438,1958,10,2,2016,6,30,'Guinea',324,'Gn','GIN'
],[49,'Guinea-Bissau',33462.1154205,'Bissau',-15.6,11.85,48,404,1974,9,10,2016,6,30,404,1974,9,10,2016,6,30,'Guinea-Bissau',624,'Gw','GNB'
],[50,'MalI',1259194.50123,'Bamako',-8,12.65,49,432,1960,8,20,2016,6,30,432,1960,9,22,2016,6,30,'MalI',466,'Ml','MLI'
],[77,'KiribatI',425.447053045,'Tarawa',173,1.316667,77,946,1999,9,14,2016,6,30,970,1979,7,12,2016,6,30,'KiribatI',296,'KI','KIR'
],[51,'Mauritania',1043783.42747,'Nouakchott',-16,18.11944,50,435,1960,11,28,1976,3,31,435,1960,11,28,1976,3,31,'Mauritania',478,'Mr','MRT'
],[52,'Senegal',197172.573237,'Dakar',-17.5,14.70889,51,433,1960,8,20,2016,6,30,433,1960,4,4,2016,6,30,'Senegal',686,'Sn','SEN'
],[54,'The Gambia',10787.1404565,'Banjul',-16.6,13.45306,53,420,1965,2,18,2016,6,30,420,1965,2,18,2016,6,30,'Gambia',270,'GM','GMB'
],[55,'DjiboutI',21567.9097045,'DjiboutI',43.14806,11.595,54,522,1977,6,27,2016,6,30,522,1977,6,27,2016,6,30,'DjiboutI',262,'DJ','DJI'
],[82,'Albania',28680.7144592,'Tirane',19.81889,41.3275,82,339,1946,1,1,2016,6,30,339,1946,1,1,2016,6,30,'Albania',8,'Al','ALB'
],[56,'Eritrea',121605.399367,'Asmara',38.93333,15.33333,55,531,1993,5,23,2016,6,30,531,1993,5,24,2016,6,30,'Eritrea',232,'Er','ERI'
],[57,'Ethiopia',1134773.29692,'Addis Ababa',38.7,9.033333,56,530,1993,5,23,2016,6,30,530,1993,5,24,2016,6,30,'Ethiopia',231,'Et','ETH'
],[58,'Mongolia',1562323.42683,'Ulan Bator',107,47.91667,57,712,1946,1,1,2016,6,30,712,1946,1,1,2016,6,30,'Mongolia',496,'Mn','MNG'
],[59,'Sudan',2501656.1935,'Khartoum',32.53417,15.58806,58,625,1956,1,1,2011,7,8,625,1956,1,1,2011,7,8,'Sudan',736,'Sd','SDN'
],[60,'Iraq',437462.661411,'Baghdad',44.39389,33.33861,60,645,1946,1,1,2016,6,30,645,1946,1,1,2016,6,30,'Iraq',368,'Iq','IRQ'
],[118,'Maldives',33.8538943201,'Male',73.5,4.166667,119,781,1965,7,26,2016,6,30,781,1965,5,26,2016,6,30,'Maldives',462,'MV','MDV'
],[61,'Israel',20783.9644258,'Jerusalem',35.22361,31.77917,61,666,1948,5,14,1967,5,31,666,1948,5,14,1967,5,31,'',0,'',''
],[62,'Jordan',89491.0684856,'Amman',35.93333,31.95,62,663,1946,3,22,2016,6,30,663,1946,5,25,2016,6,30,'Jordan',400,'Jo','JOR'
],[63,'Kazakhstan',2721207.07916,'Astana',71.42778,51.18111,63,705,1991,12,26,2016,6,30,705,1991,12,16,2016,6,30,'Kazakhstan',398,'KZ','KAZ'
],[64,'Norway',319586.663464,'Oslo',10.75,59.91667,64,385,1946,1,1,2016,6,30,385,1946,1,1,2016,6,30,'Norway',578,'No','NOR'
],[65,'Russia',16827198.0013,'Moscow',37.61555,55.75222,65,365,1991,12,26,2016,6,30,365,1991,12,21,2016,6,30,'Russian Federation',643,'Ru','RUS'
],[66,'Sweden',444337.592324,'Stockholm',18.05,59.33333,66,380,1946,1,1,2016,6,30,380,1946,1,1,2016,6,30,'Sweden',752,'Se','SWE'
],[67,'Algeria',2326148.44131,'Algiers',3.050556,36.76305,67,615,1962,7,5,2016,6,30,615,1962,7,5,2016,6,30,'Algeria',12,'DZ','DZA'
],[68,'Andorra',507.451028558,'Andorra la Vella',1.516667,42.5,68,232,1993,7,28,2016,6,30,232,1946,1,1,2016,6,30,'Andorra',20,'Ad','AND'
],[69,'cameroon',467816.059699,'Yaounde',11.51667,3.866667,69,471,1961,10,1,2016,6,30,471,1961,10,1,2016,6,30,'cameroon',120,'cm','cMR'
],[70,'Central African Republic',622696.58145,'BanguI',18.58333,4.366667,70,482,1960,8,13,2016,6,30,482,1960,8,13,2016,6,30,'Central African Republic',140,'cF','cAF'
],[71,'Libya',1623932.59444,'TripolI',13.18,32.8925,71,620,1951,12,24,1972,12,31,620,1951,12,24,1972,12,31,'',0,'',''
],[72,'Monaco',9.27583711959,'Monaco',7.416667,43.73333,72,221,1993,5,28,2016,6,30,221,1946,1,1,2016,6,30,'Monaco',492,'MC','MCO'
],[73,'Tunisia',155771.465695,'Tunis',10.17972,36.80278,73,616,1956,3,20,2016,6,30,616,1956,1,1,2016,6,30,'Tunisia',788,'Tn','TUN'
],[74,'Benin',116922.187423,'Porto-Novo',2.616667,6.483333,74,434,1960,8,1,2016,6,30,434,1960,8,1,2016,6,30,'Benin',204,'BJ','BEN'
],[81,'Togo',57485.4484335,'Lome',1.6,7.1,81,461,1960,4,27,2016,6,30,461,1960,4,27,2016,6,30,'Togo',768,'Tg','TGO'
],[75,'chad',1279118.91278,'NDjamena',15.21667,11.71667,75,483,1960,8,11,1972,12,31,483,1960,8,11,1972,12,31,'',0,'',''
],[78,'Niger',1188522.5248,'Niamey',2.116667,13.51667,78,436,1960,10,3,2016,6,30,436,1960,8,3,2016,6,30,'Niger',562,'Ne','NER'
],[79,'Nigeria',914291.337152,'Abuja',7.533333,9.083333,79,475,1961,6,1,2016,6,30,475,1961,6,1,2016,6,30,'Nigeria',566,'Ng','NGA'
],[80,'Sao Tome and Principe',1149.77537193,'Sao Tome',6.681389,0.336111,80,403,1975,7,12,2016,6,30,403,1975,7,12,2016,6,30,'Sao Tome and Principe',678,'St','STP'
],[83,'Bosnia and Herzegovina',51537.9748266,'Sarajevo',18.38333,43.85,83,346,1992,4,7,2016,6,30,346,1992,4,27,2016,6,30,'Bosnia and Herzegovina',70,'Ba','BIH'
],[84,'croatia',55889.1151742,'Zagreb',16,45.8,84,344,1992,1,15,2016,6,30,344,1992,4,27,2016,6,30,'croatia',191,'Hr','HRV'
],[85,'Italy',300242.620682,'Rome',12.48333,41.9,85,325,1946,1,1,2016,6,30,325,1946,1,1,2016,6,30,'Italy',380,'It','ITA'
],[86,'Macedonia',25483.1932119,'Skopje',21.43333,42,86,343,1993,4,8,2016,6,30,343,1991,11,20,2016,6,30,"Macedonia', the former Yugoslav Republic of",807,'Mk','MKD'
],[87,'Malta',295.019435295,'Valletta',14.51472,35.89972,87,338,1964,9,21,2016,6,30,338,1964,9,21,2016,6,30,'Malta',470,'Mt','MLT'
],[88,'San Marino',59.8811750005,'San Marino',12.45,43.93333,88,331,1992,3,2,2016,6,30,331,1946,1,1,2016,6,30,'San Marino',674,'Sm','SMR'
],[160,'Lesotho',30621.4287343,'Maseru',27.48333,-29.3,163,570,1966,10,4,2016,6,30,570,1966,10,4,2016,6,30,'Lesotho',426,'Ls','LSO'
],[89,'Bulgaria',111083.068328,'Sofia',23.31667,42.68333,90,355,1946,1,1,2016,6,30,355,1946,1,1,2016,6,30,'Bulgaria',100,'Bg','BGR'
],[90,'cyprus',9157.70860517,'Nicosia',33.36666,35.16667,91,352,1960,8,16,2016,6,30,352,1960,8,16,2016,6,30,'cyprus',196,'cy','cYP'
],[91,'Egypt',1002473.97844,'cairo',31.25,30.05,92,651,1946,1,1,1958,1,31,651,1946,1,1,1967,5,31,'',0,'',''
],[170,'Namibia',828714.948494,'Windhoek',17.08361,-22.6,173,565,1990,3,21,2016,6,30,565,1990,3,21,2016,6,30,'Namibia',516,'Na','NAM'
],[171,'New Zealand',268943.318871,'Wellington',175,-41.3,174,920,1946,1,1,2016,6,30,920,1946,1,1,2016,6,30,'New Zealand',554,'Nz','NZL'
],[92,'Georgia',70004.6046417,'TbilisI',44.79083,41.725,93,372,1991,12,26,2016,6,30,372,1991,9,6,2016,6,30,'Georgia',268,'Ge','GEO'
],[93,'Greece',130247.978372,'Athens',23.73333,37.98333,94,350,1946,1,1,2016,6,30,350,1946,1,1,2016,6,30,'Greece',300,'Gr','GRC'
],[94,'Lebanon',10240.2335568,'Beirut',35.50972,33.87194,95,660,1946,3,10,2016,6,30,660,1946,1,1,2016,6,30,'Lebanon',422,'Lb','LBN'
],[172,'Madagascar',596098.591457,'Antananarivo',47.51667,-18.9,175,580,1960,6,20,2016,6,30,580,1960,6,26,2016,6,30,'Madagascar',450,'Mg','MDG'
],[173,'Mauritius',2155.89713857,'Port Louis',57.49889,-20.2,176,590,1968,3,12,2016,6,30,590,1968,3,12,2016,6,30,'Mauritius',480,'Mu','MUS'
],[95,'Syria',188436.319945,'Damascus',36.3,33.5,96,652,1946,4,17,1958,2,1,652,1946,1,1,1967,5,31,'',0,'',''
]
])
CountryData.head()


In [None]:
,[96,'Turkey',781079.220751,'Ankara',32.86444,39.92722,97,640,1946,1,1,2016,6,30,640,1946,1,1,2016,6,30,'Turkey',792,'Tr','TUR
],[125,'Armenia',29702.4992641,'Yerevan',44.51361,40.18111,126,371,1991,12,26,2016,6,30,371,1991,12,21,2016,6,30,'Armenia',51,'Am','ARM
],[97,'Austria',83896.0779033,'Vienna',16.36667,48.2,98,305,1955,7,27,2016,6,30,305,1946,1,1,2016,6,30,'Austria',40,'At','AUT
],[128,'Iran',1626149.96028,'Tehran',51.42445,35.67194,129,630,1946,1,1,2016,6,30,630,1946,1,1,2016,6,30,"Iran', Islamic Republic of",364,'Ir','IRN
],[174,'Seychelles',380.172711433,'Victoria',55.45,-4.616667,177,591,1976,6,29,2016,6,30,591,1976,6,29,2016,6,30,'Seychelles',690,'SC','SYC
],[175,'Indonesia',1478049.02833,'Jakarta',107,-6.174444,178,850,1949,12,27,1963,4,30,850,1946,1,1,1963,4,30,'',0,'',''
],[98,'Czech Republic',78668.4117434,'Prague',14.46667,50.08333,99,316,1993,1,1,2016,6,30,316,1993,1,1,2016,6,30,'Czech Republic',203,'cz','cZE
],[99,'Denmark',42604.6743577,'copenhagen',12.58333,55.66667,100,390,1946,1,1,2016,6,30,390,1946,1,1,2016,6,30,'Denmark',208,'Dk','DNK
],[129,'Kuwait',16798.1110648,'Kuwait',47.97833,29.36972,130,690,1961,6,19,2016,6,30,690,1961,6,19,2016,6,30,'Kuwait',414,'Kw','KWT
],[100,'Hungary',92949.1325297,'Budapest',19.08333,47.5,101,310,1946,1,1,2016,6,30,310,1946,1,1,2016,6,30,'Hungary',348,'Hu','HUN
],[101,'Poland',311161.237356,'Warsaw',21,52.25,102,290,1946,1,1,2016,6,30,290,1946,1,1,2016,6,30,'Poland',616,'Pl','POL
],[130,'Qatar',11143.6634192,'Doha',51.53333,25.28667,131,694,1971,9,3,2016,6,30,694,1971,9,3,2016,6,30,'Qatar',634,'Qa','QAT
],[131,'Saudi Arabia',1933234.38469,'Riyadh',46.77278,24.64083,132,670,2000,6,12,2016,6,30,670,2000,6,12,2016,6,30,'Saudi Arabia',682,'Sa','SAU
],[102,'Slovakia',48884.7200983,'Bratislava',17.11667,48.15,103,317,1993,1,1,2016,6,30,317,1993,1,1,2016,6,30,'Slovakia',703,'Sk','SVK
],[103,'Slovenia',20415.6125511,'Ljubljana',14.51444,46.05528,104,349,1992,1,15,2016,6,30,349,1992,4,27,2016,6,30,'Slovenia',705,'SI','SVN
],[104,'Belgium',30611.6991041,'Brussels',4.333333,50.83333,105,211,1946,1,1,2016,6,30,211,1946,1,1,2016,6,30,'Belgium',56,'Be','BEL
],[105,'France',547871.255355,'Paris',2.333333,48.86666,106,220,1946,1,1,2016,6,30,220,1946,1,1,2016,6,30,'France',250,'Fr','FRA
],[106,'Liechtenstein',176.309935955,'Vaduz',9.516666,47.13334,107,223,1990,9,18,2016,6,30,223,1946,1,1,2016,6,30,'Liechtenstein',438,'LI','LIE
],[107,'Luxembourg',2577.8092797,'Luxembourg',6.13,49.61167,108,212,1946,1,1,2016,6,30,212,1946,1,1,2016,6,30,'Luxembourg',442,'Lu','LUX
],[108,'Netherlands',35485.5749945,'Amsterdam',4.916667,52.35,109,210,1946,1,1,2016,6,30,210,1946,1,1,2016,6,30,'Netherlands',528,'Nl','NLD
],[109,'Switzerland',41475.5712863,'Bern',7.466667,46.91667,110,225,1946,1,1,2016,6,30,225,1946,1,1,2016,6,30,'Switzerland',756,'ch','cHE
],[176,'Timor Leste',15138.8240257,'DilI',126,-8.558611,179,860,2002,9,27,2016,6,30,860,2002,5,20,2016,6,30,'Timor-Leste',626,'Tl','TLS
],[177,'Australia',7718924.6969,'canberra',149,-35.3,180,900,1946,1,1,2016,6,30,900,1946,1,1,2016,6,30,'Australia',36,'Au','AUS
],[110,'Belarus',207315.892132,'Minsk',27.56667,53.9,111,370,1991,12,26,2016,6,30,370,1991,8,25,2016,6,30,'Belarus',112,'By','BLR
],[111,'Estonia',45791.5800937,'Tallinn',24.72806,59.43389,112,366,1991,9,6,2016,6,30,366,1991,9,6,2016,6,30,'Estonia',233,'Ee','EST
],[112,'Finland',333891.198261,'HelsinkI',24.93417,60.17556,113,375,1946,1,1,2016,6,30,375,1946,1,1,2016,6,30,'Finland',246,'FI','FIN
],[178,'Nauru',27.3444606991,'Yaren',167,-0.543425,181,970,1999,9,14,2016,6,30,971,1968,12,31,2016,6,30,'Nauru',520,'Nr','NRU
],[179,'Papua New Guinea',465827.907822,'Port Moresby',147,-9.464723,182,910,1975,9,16,2016,6,30,910,1975,9,16,2016,6,30,'Papua New Guinea',598,'Pg','PNG
],[180,'Solomon Is.',27157.4690545,'Honiara',160,-9.433333,183,940,1978,7,7,2016,6,30,940,1978,7,7,2016,6,30,'Solomon Islands',90,'Sb','SLB
],[113,'Latvia',64469.1914752,'Riga',24.13333,56.96667,114,367,1991,9,6,2016,6,30,367,1991,9,6,2016,6,30,'Latvia',428,'Lv','LVA
],[181,'Tuvalu',29.2116662587,'FunafutI',179,-8.516666,184,947,2000,9,5,2016,6,30,973,1978,10,1,2016,6,30,'Tuvalu',798,'Tv','TUV
],[114,'Lithuania',64857.5568131,'Vilnius',25.31667,54.68333,115,368,1991,9,6,2016,6,30,368,1991,9,6,2016,6,30,'Lithuania',440,'Lt','LTU
],[115,'Moldova',33670.8700671,'chisinau',28.8575,47.00555,116,359,1991,12,26,2016,6,30,359,1991,8,27,2016,6,30,'Moldova',498,'Md','MDA
],[116,'Romania',237334.065117,'Bukarest',26.1,44.43333,117,360,1946,1,1,2016,6,30,360,1946,1,1,2016,6,30,'Romania',642,'Ro','ROU
],[117,'Ukraine',596958.708921,'Kiev',30.51667,50.43333,118,369,1991,12,26,2016,6,30,369,1991,12,1,2016,6,30,'Ukraine',804,'Ua','UKR
],[119,'Oman',310330.681626,'Masqat',58.59333,23.61333,120,698,1971,10,7,2016,6,30,698,1946,1,1,2016,6,30,'Oman',512,'Om','OMN
],[120,'Somalia',640483.58889,'Mogadishu',45.36666,2.066667,121,520,1960,7,1,2016,6,30,520,1960,7,1,2016,6,30,'Somalia',706,'So','SOM
],[121,'Sri Lanka',66468.3589613,'colombo',79.84778,6.931944,122,780,1948,2,4,2016,6,30,780,1948,2,4,2016,6,30,'Sri Lanka',144,'Lk','LKA
],[122,'Turkmenistan',472337.174738,'Ashgabat',58.38334,37.95,123,701,1991,12,26,2016,6,30,701,1991,10,27,2016,6,30,'Turkmenistan',795,'Tm','TKM
],[123,'Uzbekistan',446971.059951,'Tashkent',69.25,41.31667,124,704,1991,12,26,2016,6,30,704,1991,8,31,2016,6,30,'Uzbekistan',860,'Uz','UZB
],[190,'Zanzibar',2632.47294664,'Zanzibar City',39.18333,-6.166667,194,511,1963,12,19,1964,4,26,511,1963,12,19,1964,4,26,'',0,'',''
],[191,'Ethiopia',1256378.69628,'Addis Ababa',38.7,9.033333,195,530,1946,1,1,1993,5,22,530,1946,1,1,1993,5,23,'Ethiopia',230,'Et','ETH
],[192,'South Africa',2053300.74589,'Pretoria',28.22944,-25.7,196,560,1946,1,1,1990,2,28,560,1946,1,1,1990,3,20,'South Africa',710,'Za','ZAF
],[193,'Egypt',1002473.97844,'cairo',31.25,30.05,197,651,1961,9,1,1967,5,31,-1,-1,-1,-1,-1,-1,-1,'',0,'',''
],[124,'Yemen',456146.698895,'Sanaa',44.18333,15.28333,125,679,2000,6,12,2016,6,30,678,2000,6,12,2016,6,30,'Yemen',887,'Ye','YEM
],[126,'Azerbaijan',86045.8745881,'Baku',49.88222,40.39528,127,373,1991,12,26,2016,6,30,373,1991,12,21,2016,6,30,'Azerbaijan',31,'Az','AZE
],[127,'Bahrain',642.915228571,'Manama',50.58306,26.23611,128,692,1971,8,15,2016,6,30,692,1971,8,15,2016,6,30,'Bahrain',48,'Bh','BHR
],[132,'United Arab Emirates',70685.1741153,'Abu Dhabi',54.36666,24.46667,133,696,1971,12,2,2016,6,30,696,1971,12,2,2016,6,30,'United Arab Emirates',784,'Ae','ARE
],[133,'Afghanistan',643557.140913,'Kabul',69.18333,34.51667,134,700,1946,1,1,2016,6,30,700,1946,1,1,2016,6,30,'Afghanistan',4,'AF,'AFG
],[134,'Kyrgyzstan',199732.539747,'Bishkek',74.60028,42.87305,135,703,1991,12,26,2016,6,30,703,1991,8,31,2016,6,30,'Kyrgyzstan',417,'Kg','KGZ
],[139,'Bhutan',39991.3561241,'Thimphu',89.6,27.48333,140,760,1971,9,21,2016,6,30,760,1949,1,1,2016,6,30,'Bhutan',64,'Bt','BTN
],[135,'Nepal',147710.194614,'Kathmandu',85.31667,27.71667,136,790,1946,1,1,2016,6,30,790,1946,1,1,2016,6,30,'Nepal',524,'NP,'NPL
],[136,'Pakistan',879492.89693,'Islamabad',73.16666,33.7,137,770,1971,12,16,2016,6,30,770,1971,12,16,2016,6,30,'Pakistan',586,'Pk','PAK
],[137,'Tajikistan',142643.839584,'Dushanbe',68.77389,38.56,138,702,1991,12,26,2016,6,30,702,1991,9,9,2016,6,30,'Tajikistan',762,'TJ,'TJK
],[138,'Bangladesh',138505.328342,'Dhaka',90.40861,23.72305,139,771,1971,12,16,2016,6,30,771,1971,12,16,2016,6,30,'Bangladesh',50,'Bd','BGD
],[143,'North Korea',122351.726844,'Pyongyang',126,39.01944,144,731,1948,9,9,2016,6,30,731,1948,9,9,2016,6,30,"Korea', Democratic People's Republic of",408,'KP,'PRK
],[144,'Palau',382.603448129,'Koror',134,7.340556,145,986,1994,12,15,2016,6,30,986,1994,10,1,2016,6,30,'Palau',585,'Pw','PLW
],[145,'Philippines',294201.601138,'Manila',121,14.60417,146,840,1946,7,4,2016,6,30,840,1946,7,4,2016,6,30,'Philippines',608,'Ph','PHL
],[146,'South Korea',97429.7163986,'Seoul',127,37.56639,147,732,1949,6,29,2016,6,30,732,1948,8,15,2016,6,30,"Korea', Republic of",410,'Kr','KOR
],[147,'Cambodia',182847.295534,'Phnom Penh',105,11.55,148,811,1953,11,9,2016,6,30,811,1953,11,9,2016,6,30,'cambodia',116,'Kh','KHM
],[148,'Laos',231120.578521,'Vientiane',103,17.96667,149,812,1953,10,23,2016,6,30,812,1954,5,1,2016,6,30,'Lao Peoples Democratic Republic',418,'La','LAO
],[149,'Myanmar',670372.095395,'Yangon',96.16666,16.78333,150,775,1948,1,4,2016,6,30,775,1948,1,4,2016,6,30,'Myanmar',104,'Mm','MMR
],[150,'Thailand',515247.086971,'Bangkok',101,13.75,151,800,1946,1,1,2016,6,30,800,1946,1,1,2016,6,30,'Thailand',764,'Th','THA
],[151,'Vietnam',326086.275122,'HanoI',106,21.03333,152,816,1975,5,1,2016,6,30,816,1975,5,1,2016,6,30,'Viet Nam',704,'Vn','VNM
],[152,'Marshall Is.',34.9181835654,'Majuro',171,7.1,153,983,1991,9,17,2016,6,30,983,1986,10,21,2016,6,30,'Marshall Islands',584,'Mh','MHL
],[154,'Botswana',581127.216322,'Gaborone',25.91195,-24.6,155,571,1966,9,30,2016,6,30,571,1966,9,30,2016,6,30,'Botswana',72,'Bw','BWA
],[158,'Zimbabwe',391930.531069,'Harare',31.04472,-17.8,161,552,1965,11,11,2016,6,30,552,1965,11,11,2016,6,30,'Zimbabwe',716,'Zw','ZWE
],[159,'comoros',1726.73596857,'MoronI',43.24028,-11.7,162,581,1975,12,31,2016,6,30,581,1975,7,6,2016,6,30,'comoros',174,'Km','cOM
],[161,'MalawI',119198.864167,'Lilongwe',33.78333,-14,164,553,1964,7,6,2016,6,30,553,1964,7,6,2016,6,30,'MalawI',454,'Mw','MWI
],[162,'Mozambique',790679.445056,'Maputo',32.58917,-26,165,541,1975,6,25,2016,6,30,541,1975,6,25,2016,6,30,'Mozambique',508,'Mz','MOZ
],[163,'South Africa',1224585.7974,'Pretoria',28.22944,-25.7,166,560,1990,3,1,2016,6,30,560,1990,3,21,2016,6,30,'South Africa',710,'Za','ZAF
],[164,'Swaziland',17179.4175602,'Mbabane',31.13333,-26.3,167,572,1968,9,6,2016,6,30,572,1968,9,6,2016,6,30,'Swaziland',748,'Sz','SWZ
],[165,'Angola',1254965.95582,'Luanda',13.23444,-8.838333,168,540,1975,11,11,2016,6,30,540,1975,11,11,2016,6,30,'Angola',24,'Ao','AGO
],[166,'congo',346333.853886,'Brazzaville',15.28472,-4.259167,169,484,1960,8,15,2016,6,30,484,1960,8,15,2016,6,30,'congo',178,'cg','cOG
],[167,"CongoDRC",2342042.84456,'Kinshasa',15.315,-4.329722,170,490,1960,6,30,2016,6,30,490,1960,6,30,2016,6,30,"Congo', Democratic Republic of the",180,'cd','cOD
],[168,'Fiji',18172.8656291,'Suva',178,-18.1,171,950,1970,10,10,2016,6,30,950,1970,10,10,2016,6,30,'FijI',242,'FJ,'FJI
],[197,'Yemen Peoples Republic',289055.34435,'Aden',45.03667,12.77944,201,680,1967,11,30,1990,5,21,680,1967,11,30,1990,5,21,'Democratic Yemen',720,'Yd','YMD
],[169,'Gabon',262447.341739,'Libreville',9.45,0.3833333,172,481,1960,8,17,2016,6,30,481,1960,8,17,2016,6,30,'Gabon',266,'Ga','GAB
],[182,'Vanuatu',12335.1176704,'Port-Vila',168,-17.7,185,935,1981,9,15,2016,6,30,935,1980,6,30,2016,6,30,'Vanuatu',548,'Vu','VUT
],[183,'canada',9923995.40515,'Ottawa',-75.7,45.42083,186,20,1948,7,1,2016,6,30,20,1948,7,1,2016,6,30,'canada',124,'ca','cAN
],[184,'Germany',356448.186123,'Berlin',13.4,52.51667,187,255,1990,10,3,2016,6,30,260,1990,10,3,2016,6,30,'Germany',276,'De','DEU
],[185,'Germany Federal Republic',247366.38371,'Bonn',7.1,50.73333,188,260,1955,5,5,1990,10,2,260,1949,9,21,1990,10,2,"Germany', Federal Republic of",280,'De','DEU
],[186,'Germany Democratic Republic',109081.236269,'Berlin',13.4,52.51667,189,265,1954,3,25,1990,10,2,265,1949,10,5,1990,10,2,'German Democratic Republic',278,'Dd','DDR
],[198,'Taiwan',36184.0535634,'TaipeI',122,25.03917,202,713,1949,12,8,2016,6,30,713,1949,12,8,2016,6,30,"Taiwan', Province of China",158,'Tw','TWN
],[199,'china',9388505.69287,'Beijing',116,39.92889,203,710,1949,12,8,2016,6,30,710,1949,12,8,2016,6,30,'china',156,'cn','cHN
],[187,'czechoslovakia',127553.131842,'Prague',14.46667,50.08333,190,315,1946,1,1,1992,12,31,315,1946,1,1,1992,12,31,'czechoslovakia',200,'cs','cSK
],[188,'Yugoslavia',255286.979077,'Belgrade',20.46806,44.81861,191,345,1946,1,1,1992,1,14,345,1946,1,1,1991,11,19,"Yugoslavia', Socialist Federal Republic of",890,'Yu','YUG
],[189,'USSR',22008906.4725,'Moscow',37.61555,55.75222,192,365,1946,1,1,1991,9,5,365,1946,1,1,1991,8,24,'Union of Soviet Socialist Republics',810,'Su','SUN
],[200,'Pakistan',1017998.22527,'Islamabad',73.16666,33.7,204,770,1949,1,1,1971,12,15,770,1949,1,1,1971,12,15,'',0,'',''
],[194,'Syria',188436.319945,'Damascus',36.3,33.5,198,652,1961,9,29,1967,5,31,-1,-1,-1,-1,-1,-1,-1,'',0,'',''
],[195,'Egypt (United Arab Republic)',1190910.29838,'cairo',31.25,30.05,199,651,1958,2,1,1961,8,31,-1,-1,-1,-1,-1,-1,-1,'',0,'',''
],[196,'Yemen Arab Republic',137077.141408,'Sanaa',44.18333,15.28333,200,678,1946,1,1,1990,5,21,678,1946,1,1,1990,5,21,'Yemen',886,'Ye','YEM
],[212,'Nigeria',867986.036104,'Abuja',7.533333,9.083333,216,475,1960,10,1,1961,5,31,475,1960,10,1,1961,5,31,'',0,'',''
],[213,'Libya',1741136.62485,'TripolI',13.18,32.8925,217,620,1973,1,1,2016,6,30,620,1973,1,1,2016,6,30,'Libyan Arab Jamahiriya',434,'Ly','LBY
],[201,'Republic of Vietnam',170073.211147,'Saigon',107,10.75,205,817,1954,6,4,1975,4,30,817,1954,5,1,1975,4,30,'Republic of Viet Nam',714,'Vn','VNM
],[202,'Vietnam',155964.80192,'HanoI',106,21.03333,206,816,1954,7,21,1975,4,30,816,1954,5,1,1975,4,30,'Democratic Republic of Viet-Nam',704,'Vd','VDR
],[203,'Malaysia',331246.371802,'Kuala Lumpur',102,3.166667,207,820,1963,10,1,1965,7,31,820,1963,10,1,1965,7,31,'',0,'',''
],[218,'Israel',28159.60054ma32,'Jerusalem',35.22361,31.77917,222,666,1979,5,1,2016,6,30,666,1979,5,1,2016,6,30,'Israel',376,'Il','ISR
],[204,'Malaysia',132392.827233,'Kuala Lumpur',102,3.166667,208,820,1957,8,31,1963,9,30,820,1957,8,31,1963,9,30,'',0,'',''
],[205,'Malaysia',330691.511322,'Kuala Lumpur',102,3.166667,209,820,1965,8,1,2016,6,30,820,1965,8,1,2016,6,30,'Malaysia',458,'My','MYS
],[206,'Singapore',554.860480541,'Singapore City',104,1.293056,210,830,1965,8,9,2016,6,30,830,1965,8,9,2016,6,30,'Singapore',702,'Sg','SGP
],[207,'Indonesia',1890541.85149,'Jakarta',107,-6.174444,211,850,1963,5,1,1976,6,30,850,1963,5,1,1976,6,30,'Indonesia',360,'Id','IDN
],[208,'Indonesia',1905680.67552,'Jakarta',107,-6.174444,212,850,1976,7,1,2002,9,26,850,1976,7,1,2002,5,19,'Indonesia',360,'Id','IDN
],[209,'Indonesia',1890541.85149,'Jakarta',107,-6.174444,213,850,2002,9,27,2016,6,30,850,2002,5,20,2016,6,30,'Indonesia',360,'Id','IDN
],[210,'Mali Federation',1456367.07447,'Dakar',-17.5,14.70889,214,432,1960,6,20,1960,8,19,-1,-1,-1,-1,-1,-1,-1,'',0,'',''
],[221,'Egypt',941702.026033,'cairo',31.25,30.05,225,651,1967,6,1,1979,4,30,651,1967,6,1,1979,4,30,'Egypt',818,'Eg','EGY
],[222,'Egypt',1002473.97844,'cairo',31.25,30.05,226,651,1979,5,1,2016,6,30,651,1979,5,1,2016,6,30,'Egypt',818,'Eg','EGY
],[223,'India',3166803.18858,'New Delhi',77.2,28.6,227,750,1949,1,1,2016,6,30,750,1949,1,1,2016,6,30,'India',356,'In','IND
],[224,'Pakistan',936135.398004,'Islamabad',73.16666,33.7,228,770,1947,8,14,1948,12,31,770,1947,8,14,1948,12,31,'',0,'',''
],[211,'cameroon',425857.301369,'Yaounde',11.51667,3.866667,215,471,1960,1,1,1961,9,30,471,1960,1,1,1961,9,30,'',0,'',''
],[214,'chad',1161920.14763,'NDjamena',15.21667,11.71667,218,483,1973,1,1,2016,6,30,483,1973,1,1,2016,6,30,'chad',148,'Td','TCD
],[225,'India',3060680.75594,'New Delhi',77.2,28.6,229,750,1947,8,15,1948,12,31,750,1947,8,15,1948,12,31,'',0,'',''
],[215,'Morocco',404454.140456,'Rabat',-6.83,34.02,219,600,1958,4,1,1976,3,31,600,1958,4,1,1976,3,31,'Morocco',504,'Ma','MAR
],[216,'Morocco',576351.80116,'Rabat',-6.83,34.02,220,600,1976,4,1,1979,8,4,600,1976,4,1,1979,8,4,'Morocco',504,'Ma','MAR
],[217,'Mauritania',1142045.83291,'Nouakchott',-16,18.11944,221,435,1976,4,1,1979,8,4,435,1976,4,1,1979,8,4,'Mauritania',478,'Mr','MRT
],[219,'Israel',88932.2151228,'Jerusalem',35.22361,31.77917,223,666,1967,6,1,1979,4,30,666,1967,6,1,1979,4,30,'Israel',376,'Il','ISR
],[220,'Syria',187318.011221,'Damascus',36.3,33.5,224,652,1967,6,1,2016,6,30,652,1967,6,1,2016,6,30,'Syrian Arab Republic',760,'Sy','SYR
],[226,'Yugoslavia',178982.251352,'Belgrade',20.46806,44.81861,230,345,1992,1,15,1992,4,6,-1,-1,-1,-1,-1,-1,-1,"Yugoslavia', Socialist Federal Republic of",890,'Yu','YUG
],[227,'Serbia and Montenegro',101961.083314,'Belgrade',20.46806,44.81861,231,345,1993,4,8,2006,6,11,345,1992,4,27,2006,6,2,'Serbia and Montenegro',891,'cs','SCG
],[228,'USSR',21833788.1441,'Moscow',37.61555,55.75222,232,365,1991,9,6,1991,12,25,-1,-1,-1,-1,-1,-1,-1,'Union of Soviet Socialist Republics',810,'Su','SUN
],[229,'Yugoslavia',127444.276526,'Belgrade',20.46806,44.81861,233,345,1992,4,7,1993,4,7,-1,-1,-1,-1,-1,-1,-1,"Yugoslavia', Socialist Federal Republic of",890,'Yu','YUG
],[230,'Serbia',87939.1875043,'Belgrade',20.46806,44.81861,234,345,2006,6,12,2008,2,19,340,2006,6,3,2008,2,16,'Serbia',688,'Rs','SRB
],[231,'Montenegro',14021.8959544,'Podgorica',19.26361,42.44111,235,341,2006,6,12,2016,6,30,341,2006,6,3,2016,6,30,'Montenegro',499,'Me','MNE
],[232,'Kosovo',10737.5275034,'Pristina',21.16667,42.66667,236,347,2008,2,20,2016,6,30,347,2008,2,17,2016,6,30,'',0,'',''
],[235,'Uganda',243702.806282,'Kampala',32.56556,0.3155556,59,500,1962,10,9,2016,6,30,500,1962,10,9,2016,6,30,'Uganda',800,'Ug','UGA
],[236,'Tanzania',947555.660283,'Dar es Salaam',39.28333,-6.8,159,510,1964,4,1,1996,1,31,510,1964,4,1,1996,1,31,"Tanzania', United Republic of",834,'Tz','TZA
],[233,'Serbia',77201.6601505,'Belgrade',20.46806,44.81861,237,345,2008,2,20,2016,6,30,340,2008,2,17,2016,6,30,'Serbia',688,'Rs','SRB
],[234,'United States',9468306.22016,'Washington',-77,38.895,238,2,1946,1,1,2016,6,30,2,1946,1,1,2016,6,30,'United States',840,'Us','USA
],[254,''USSR'',19664153.4543,'Moscow',37.61555,55.75222,255,-1,-1,-1,-1,-1,-1,-1,365,1991,12,1,1991,12,15,'Union of Soviet Socialist Republics',810,'Su','SUN
],[255,''USSR'',16942946.3752,'Moscow',37.61555,55.75222,256,-1,-1,-1,-1,-1,-1,-1,365,1991,12,16,1991,12,20,'Union of Soviet Socialist Republics',810,'Su','SUN
],[237,'Kenya',585730.013409,'NairobI',36.81667,-1.283333,157,501,1963,12,12,2016,6,30,501,1963,12,12,2016,6,30,'Kenya',404,'Ke','KEN
],[238,'Tanzania',944923.187336,'Dar es Salaam',39.28333,-6.8,193,510,1961,12,9,1964,3,31,510,1961,12,9,1964,3,31,'',0,'',''
],[239,'Saudi Arabia',1963227.28406,'Riyadh',46.77278,24.64083,239,670,1946,1,1,2000,6,11,670,1946,1,1,2000,6,11,'Saudi Arabia',682,'Sa','SAU
],[240,'Yemen',426132.485757,'Sanaa',44.18333,15.28333,240,679,1990,5,22,2000,6,11,678,1990,5,22,2000,6,11,'Yemen',887,'Ye','YEM
],[241,'Brazil',8523619.57145,'Rio de Janeiro',-43.2,-22.9,241,140,1946,1,1,1960,4,20,140,1946,1,1,1960,4,20,'',0,'',''
],[242,'Tanzania',947555.660283,'Dodoma',35.7419,-6.173,242,510,1996,2,1,2016,6,30,510,1996,2,1,2016,6,30,"Tanzania', United Republic of",834,'Tz','TZA
],[243,'Morocco',674614.242721,'Rabat',-6.83,34.02,243,600,1979,8,5,2016,6,30,600,1979,8,5,2016,6,30,'Morocco',504,'Ma','MAR
],[244,'Mauritania',1043783.42747,'Nouakchott',-16,18.11944,244,435,1979,8,5,2016,6,30,435,1979,8,5,2016,6,30,'Mauritania',478,'Mr','MRT
],[245,'Sudan',1870245.08989,'Khartoum',32.53417,15.58806,245,625,2011,7,9,2016,6,30,625,2011,7,9,2016,6,30,'Sudan',729,'Sd','SDN
],[246,'South Sudan',631411.105296,'Juba',31.6,4.85,246,626,2011,7,9,2016,6,30,626,2011,7,9,2016,6,30,'South Sudan',728,'Ss','SSD
],[247,'Yugoslavia',229803.785868,'Belgrade',20.46806,44.81861,247,-1,-1,-1,-1,-1,-1,-1,345,1991,11,20,1992,4,26,"Yugoslavia', Socialist Federal Republic of",890,'Yu','YUG
],[248,'USSR',21801590.5804,'Moscow',37.61555,55.75222,249,-1,-1,-1,-1,-1,-1,-1,365,1991,8,25,1991,8,26,'Union of Soviet Socialist Republics',810,'Su','SUN
],[249,'USSR',21767919.7103,'Moscow',37.61555,55.75222,250,-1,-1,-1,-1,-1,-1,-1,365,1991,8,27,1991,8,30,'Union of Soviet Socialist Republics',810,'Su','SUN
],[250,'USSR',21121216.1106,'Moscow',37.61555,55.75222,251,-1,-1,-1,-1,-1,-1,-1,365,1991,8,31,1991,9,5,'Union of Soviet Socialist Republics',810,'Su','SUN
],[251,'USSR',20876093.1776,'Moscow',37.61555,55.75222,252,-1,-1,-1,-1,-1,-1,-1,365,1991,9,6,1991,9,8,'Union of Soviet Socialist Republics',810,'Su','SUN
],[252,'USSR',20733449.338,'Moscow',37.61555,55.75222,253,-1,-1,-1,-1,-1,-1,-1,365,1991,9,9,1991,10,26,'Union of Soviet Socialist Republics',810,'Su','SUN
],[253,'USSR',20261112.1633,'Moscow',37.61555,55.75222,254,-1,-1,-1,-1,-1,-1,-1,365,1991,10,27,1991,11,30,'Union of Soviet Socialist Republics',810,'Su','SUN]]


In [None]:
# Plot figures with predictions as they evolve over time
# New version
# This first part of the cell only sets up constants; the plotting loop follows below.

from matplotlib import cm

#ModelSelection = [1,3,5,9,11]

plt.rcParams["figure.figsize"] = (6, 6)
# NOTE: `overleafpath` and `path` hold the same directory here; the plotting loop
# in this cell only reads `path`.
overleafpath = '/Users/havardhegre/Dropbox (ViEWS)/ViEWS/Projects/PredictingFatalities/PredictionPlots/'
path = '/Users/havardhegre/Dropbox (ViEWS)/ViEWS/Projects/PredictingFatalities/PredictionPlots/'

# y-axis tick positions on the log1p scale, and the raw counts used to label them
log_scale_value = np.array([np.log1p(0), np.log1p(1), np.log1p(3), np.log1p(10), np.log1p(30), np.log1p(100),np.log1p(300),np.log1p(1000),np.log1p(3000)])
log_scale_naming = ['0','1','3','10','30','100','300','1000','3000']
# x-axis tick positions and their labels (presumably ViEWS month ids: 445 is labelled 'Jan-17')
month_value = np.array([445,451,457,463,469,475,481,487,492])
month_name = ['Jan-17','Jul-17','Jan-18','Jul-18','Jan-19','Jul-19','Jan-20','Jul-20','Dec-20']
# Base month: the loop looks up the cumulated observed series starting at first_month + 1.
# The same value is hard-coded as 444 when the 'sc_<month>' column names are built.
first_month = 444

# One tuple per country to plot:
#   (label used in the title and file name,
#    key into model['CountryList'],
#    upper y-limit of the cumulative plot)
CountryList = [
    ('Angola',165,500),
    ('Botswana',154,500),
    ('BurkinaFaso',47,2000),
    ('Burundi',155,2000),
#    ('Cameroon',211,2000),
    ('Chad',214,2000),
    ('Congo',166,2000),
    ('DR Congo',167,20000),
    ('Egypt',222,5000),
    ('Ethiopia',57,2000),
    ('Gabon',169,500),
    ('Iran',128,2000),
    ('Israel',218,2000),
    ('Jordan',62,2000),
    ('Kenya',237,2000),
    ('Lebanon',94,2000),
    ('Libya',213,5000),
    ('Madagascar',172,2000),
    ('Mali',50,20000),
    ('Mauritania',244,500),
    ('Morocco',243,500),
    ('Mozambique',162,2000),
    ('Namibia',170,500),
    ('Niger',78,2000),
    ('Nigeria',79,20000),
    ('Oman',119,2000),
    ('Rwanda',156,2000),
    ('Saudi Arabia',131,500),
    ('South Africa',163,2000),
    ('South Sudan',246,5000),
    ('Sudan',245,2000),
    ('Syria',220,50000),
    ('Tanzania',242,500),
    ('Uganda',235,2000),
    ('Yemen',124,20000),
    ('Zimbabwe',158,5000),
]
# Offsets m for which the series 'sc_<444+m>' is drawn (every second one up to 36)
#t_range = range(0, 23)
t_range = [0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36]
#t_range = [0,3,6,9,12,15,18,21,24,27,30,33,36]
    
# For every ensemble model and every country in CountryList, write two PNGs:
#   <path>OverTime/<model>_<country>.png   - logged predictions over time
#   <path>Cumulative/<model>_<country>.png - cumulative non-logged fatalities
for model in EnsembleList:
    print(model['modelname'])
    for cnt in CountryList:
        # cnt = (label, key into model['CountryList'], y-limit of the cumulative plot)
        sc_df = model['CountryList'][cnt[1]]['sc_df_smooth']
        months = sc_df.index.to_series()

        # Back-transform every column from log1p to rounded counts, then accumulate over rows
        sc_df_exp = sc_df.copy()
        sc_df_cum = sc_df.copy()
        for column in sc_df.columns:
            sc_df_exp[column] = np.rint(np.expm1(sc_df[column]))
            sc_df_cum[column] = sc_df_exp[column].cumsum(axis=0, skipna=True)

        # Lift each series from the third column on by the cumulated observed count
        # at month first_month + 1, first_month + 2, ... (one month later per column),
        # so the cumulative series are offset by what had already been observed.
        cumdepvar = sc_df_cum['ln_ged_sb_dep']
        for i, column in enumerate(sc_df.columns[2:], start=first_month + 1):
            sc_df_cum[column] = sc_df_cum[column] + cumdepvar[i]

        # --- Figure 1: logged observations (bars) and prediction series (lines) ---
        plt.clf()
        plt.bar(months, 'ln_ged_sb_dep', data=sc_df, color='.8')
        for m in t_range:
            series = 'sc_' + str(444+m)
            plt.plot(months, series, data=sc_df, c=cm.hot(np.abs((m/60)+.2)))
        plt.ylabel('Number of fatalities')
        plt.yticks(log_scale_value, log_scale_naming, rotation=30)
        plt.xticks(month_value, month_name, rotation=30)
        plt.grid(axis='y')
        plt.ylim([0,np.log1p(3000)])
        # Set the title on the axes. Assigning to `plt.title` would only rebind the
        # pyplot function to a string and leave the figure untitled.
        plt.gca().set_title(cnt[0])
        filename = path + 'OverTime/' + model['modelname'] + '_' + cnt[0] + '.png'
        plt.savefig(filename, dpi=200)

        # --- Figure 2: cumulative non-logged observations and predictions ---
        plt.clf()
        plt.gca().set_title(cnt[0])
        plt.bar(months, 'ln_ged_sb_dep', data=sc_df_cum, color='.8')
        for m in t_range:
            series = 'sc_' + str(444+m)
            plt.plot(months, series, data=sc_df_cum, c=cm.hot(np.abs((m/60)+.2)))
        plt.ylabel('Cumulative (non-logged) fatalities')
        plt.xticks(month_value, month_name, rotation=30)
        plt.grid(axis='y')
        plt.ylim([0,cnt[2]])
        filename = path + 'Cumulative/' + model['modelname'] + '_' + cnt[0] + '.png'
        plt.savefig(filename, dpi=200)



In [None]:
# Rolling figures for a few selected countries: one PNG per forecast origin, showing
# the observations available up to that origin together with the matching prediction series.

CountryList = [
    ('BurkinaFaso',47,2000),
    ('Mali',50,20000),
    ('Ethiopia',57,2000),
    ('Nigeria',79,20000),
]
t_range = [0] + steps

# NOTE(review): the slice start 47 is hard-coded; presumably it skips the
# constituent models and keeps only the ensembles -- confirm against EnsembleList.
for model in EnsembleList[47:]:
    print(model['modelname'])
    for cnt in CountryList:
        # Work on a copy: the helper column added below must not end up in the
        # model's stored frame, where it would change `sc_df.columns` for other cells.
        sc_df = model['CountryList'][cnt[1]]['sc_df_smooth'].copy()
        months = sc_df.index.to_series()

        for m in t_range:
            plt.clf()
            plt.ylabel('Number of fatalities')
            plt.yticks(log_scale_value, log_scale_naming, rotation=30)
            plt.xticks(month_value, month_name, rotation=30)
            plt.grid(axis='y')
            plt.ylim([0,np.log1p(3000)])
            # Keep only the first m observations; the remaining rows become NaN on
            # assignment and therefore draw no bar.
            sc_df['truncated_ged_sb'] = sc_df['ln_ged_sb_dep'].iloc[:m]
            predseries = 'sc_' + str(444+m)
            # Set the title on the axes. Assigning to `plt.title` would only rebind
            # the pyplot function to a string.
            plt.gca().set_title(cnt[0])
            plt.bar(months, 'truncated_ged_sb', data=sc_df, color='.8')
            plt.plot(months, predseries, data=sc_df, c=cm.hot(np.abs((m/60)+.2)))

            filename = path + 'OverTime/Rolling/' + model['modelname'] + '_' + cnt[0] + '_' + str(444+m) + '.png'
            plt.savefig(filename, dpi=200)

In [None]:
# Scratch expression: displays the label of the country the plotting loops above
# visited last (`cnt` is their leftover loop variable), followed by a colon.
cnt[0] + ':'

# Uncertainty of predictions

In [None]:
# Summary statistics of the 'test_df_calibrated' frame of the last entry in EnsembleList
EnsembleList[-1]['test_df_calibrated'].describe()

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Index labels of the rows to evaluate (selected with .loc below)
months = [487, 488,489,490,491,492]
plt.rcParams["figure.figsize"] = (6, 6)
overleafpath = '/Users/havardhegre/Dropbox (ViEWS)/ViEWS/Projects/PredictingFatalities/PredictionPlots/'
path = '/Users/havardhegre/Dropbox (ViEWS)/ViEWS/Projects/PredictingFatalities/PredictionPlots/'

# Explicit copy: new columns are added to `df` in this and the following cells, and
# they must neither touch the model's stored frame nor trigger SettingWithCopyWarning.
df = EnsembleList[-1]['test_df_calibrated'].loc[months].copy()

# Bins on the log1p scale. The inner edges sit roughly halfway between
# log1p(0), log1p(1), log1p(3), log1p(10), ... (see the reference table printed below).
bins = pd.IntervalIndex.from_tuples([(-1, 0.3), (0.3, 1.05), (1.05, 1.89), (1.89, 2.92), (2.92, 4.02), (4.02, 5.15), (5.15, 6.3), (6.3, 7.45), (7.45, 10)])
df['fatalitybins_1'] = pd.cut(df['step_pred_1'],bins)

# Reference table: raw count -> log1p value
for value in [1,3,10,30,100,300,1000,3000]:
    print('Count:', value, 'logged:', np.log1p(value))

# Last expression so that the notebook actually renders it
df['fatalitybins_1'].describe()

In [None]:
# Percentiles requested from .describe() in the per-bin summaries
percentiles = (0.05, 0.10, 0.25,0.5,0.75, 0.9,0.95, 0.99)

# Back-transform the observed outcome and the step-2 / step-13 predictions from log1p to counts
df['ged_sb_dep'] = np.expm1(df['ln_ged_sb_dep'])
df['exp_pred_2'] = np.expm1(df['step_pred_2'])
df['exp_pred_13'] = np.expm1(df['step_pred_13'])

# Bins on the count scale. Only one definition can be active; the alternatives are
# kept as comments instead of being assigned and immediately overwritten.
#bins2 = pd.IntervalIndex.from_tuples([(0, 3), (3, 10), (10, 30), (30, 100), (100, 300), (300, 1000), (1000, 100000)])
#bins2 = pd.IntervalIndex.from_tuples([(10, 30), (100, 300),  (1000, 100000)])
bins2 = pd.IntervalIndex.from_tuples([(3, 10), (30, 100),  (300, 1000)])

# Predictions that fall outside every interval of bins2 get NaN from pd.cut
df['fatalitybins2_2'] = pd.cut(df['exp_pred_2'],bins2)
df['fatalitybins2_13'] = pd.cut(df['exp_pred_13'],bins2)

# One summary table for both binned columns (only a cell's last expression is rendered)
df[['fatalitybins2_2', 'fatalitybins2_13']].describe()


In [None]:
# For each step-1 prediction bin, summarise the observed (logged) outcome of the rows in it.
# The loop variable is not called `bin`, which would shadow the builtin for the rest of the notebook.
for interval in bins:
    print(interval)
    in_interval = df['fatalitybins_1'] == interval
    print(df.loc[in_interval, 'ln_ged_sb_dep'].describe(percentiles = percentiles))

df.head()

In [None]:
sns.set_theme(style="ticks")

# Forecast step to show. It selects both the binned-prediction column and the output
# file name, so the two cannot drift apart. The column must already exist in `df`
# (the preceding cell creates it for steps 2 and 13).
step=13
pred_bin_column = f"fatalitybins2_{step}"

# Initialize the figure with a logarithmic x axis
f, ax = plt.subplots(figsize=(7, 6))
ax.set_xscale("log")

# Horizontal boxes: observed fatalities within each predicted-fatalities bin,
# with whiskers at the 5th and 95th percentiles
sns.boxplot(x="ged_sb_dep", y=pred_bin_column, data=df,
            whis=[5, 95], width=.9, palette="vlag", ax=ax)

# Add in points to show each observation
sns.stripplot(x="ged_sb_dep", y=pred_bin_column, data=df,
              size=4, color=".3", linewidth=0, ax=ax)

# Tweak the visual presentation
ax.xaxis.grid(True)
ax.set(ylabel="Predicted number of fatalities", xlabel="Observed number of fatalities")
sns.despine(trim=True, left=True)

overleafpath = '/Users/havardhegre/Dropbox (ViEWS)/Apps/Overleaf/ViEWS predicting fatalities/Figures/PredictionPlots/'
filename = overleafpath + 'PredictionUncertainty_cm_s' + str(step) + '.png'
f.savefig(filename, dpi=300)
                      

In [None]:
# Horizontal box plots with observations
# NOTE(review): this cell plots seaborn's example "planets" dataset, not project data;
# it looks like a template kept for reference -- confirm whether it is still needed.
import seaborn as sns
import matplotlib.pyplot as plt

sns.set_theme(style="ticks")

# Initialize the figure with a logarithmic x axis
f, ax = plt.subplots(figsize=(7, 6))
ax.set_xscale("log")

# Load the example planets dataset
# (presumably fetched over the network unless cached locally -- confirm)
planets = sns.load_dataset("planets")

# Horizontal boxes of `distance` for each `method`; whis=[0, 100] puts the whiskers
# at the minimum and maximum
sns.boxplot(x="distance", y="method", data=planets,
            whis=[0, 100], width=.6, palette="vlag")

# Add in points to show each observation
sns.stripplot(x="distance", y="method", data=planets,
              size=4, color=".3", linewidth=0)

# Tweak the visual presentation
ax.xaxis.grid(True)
ax.set(ylabel="")
sns.despine(trim=True, left=True)

In [None]:
# hist
# Stand-alone histogram demo on synthetic data; it does not use any project data.
import matplotlib.pyplot as plt
import numpy as np

#plt.style.use('_mpl-gallery')

# make data
# 200 normal draws (mean 0, sd 1.5) shifted by 4; the seed is fixed so the figure is reproducible
np.random.seed(1)
x = 4 + np.random.normal(0, 1.5, 200)

# plot:
fig, ax = plt.subplots()

ax.hist(x, bins=8, linewidth=0.5, edgecolor="white")

# Fixed axis limits and tick positions
ax.set(xlim=(0, 8), xticks=np.arange(1, 8),
       ylim=(0, 56), yticks=np.linspace(0, 56, 9))

plt.show()

In [None]:
# 2d hist
# Stand-alone 2d-histogram demo on synthetic data; it does not use any project data.

import matplotlib.pyplot as plt
import numpy as np

# make data: correlated + noise (seed fixed so the figure is reproducible)
np.random.seed(1)
x = np.random.randn(5000)
y = 1.2 * x + np.random.randn(5000) / 3

# plot:
# The gallery style is applied through a context manager so it stays local to this
# figure; plt.style.use() would change the rcParams of every figure drawn afterwards.
with plt.style.context('_mpl-gallery-nogrid'):
    fig, ax = plt.subplots()

    ax.hist2d(x, y, bins=(np.arange(-3, 3, 0.1), np.arange(-3, 3, 0.1)))

    ax.set(xlim=(-2, 2), ylim=(-3, 3))

    plt.show()

In [None]:
# Box plot
# Stand-alone box-plot demo on synthetic data; it does not use any project data.

import matplotlib.pyplot as plt
import numpy as np

# make data: 100 draws for each of three normal distributions
# (means 3, 5, 4; standard deviations 1.25, 1.00, 1.25), seed fixed
np.random.seed(10)
D = np.random.normal((3, 5, 4), (1.25, 1.00, 1.25), (100, 3))

# plot
# The gallery style is applied through a context manager so it stays local to this
# figure; plt.style.use() would change the rcParams of every figure drawn afterwards.
with plt.style.context('_mpl-gallery'):
    fig, ax = plt.subplots()
    VP = ax.boxplot(D, positions=[2, 4, 6], widths=1.5, patch_artist=True,
                    showmeans=False, showfliers=False,
                    medianprops={"color": "white", "linewidth": 0.5},
                    boxprops={"facecolor": "C0", "edgecolor": "white",
                              "linewidth": 0.5},
                    whiskerprops={"color": "C0", "linewidth": 1.5},
                    capprops={"color": "C0", "linewidth": 1.5})

    ax.set(xlim=(0, 8), xticks=np.arange(1, 8),
           ylim=(0, 8), yticks=np.arange(1, 8))

    plt.show()

In [None]:
# Re-draw the logged prediction series for the model/country left over from the
# plotting loops above (`model`, `cnt` and `sc_df` are their leftover variables).
# The x values are rebuilt from `sc_df` here because the name `months` is reassigned
# to a plain list of six month ids in the uncertainty section of this notebook.
sc_months = sc_df.index.to_series()
for m in t_range:
    series = 'sc_' + str(444+m)
    plt.plot(sc_months, series, data=sc_df, c=cm.hot(np.abs((m/60)+.2)))
plt.ylabel('Number of fatalities')
plt.yticks(log_scale_value, log_scale_naming, rotation=30)
plt.xticks(month_value, month_name, rotation=30)
plt.grid(axis='y')
plt.ylim([0,np.log1p(3000)])
# Set the title on the axes. Assigning to `plt.title` would only rebind the pyplot
# function to a string.
plt.gca().set_title(cnt[0])
filename = path + 'OverTime/' + model['modelname'] + '_' + cnt[0] + '.png'
plt.savefig(filename, dpi=200)

# Mapping

In [None]:
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import contextily as ctx

from views_dataviz import color
from views_dataviz.map import utils
from views_dataviz.map.presets import ViewsMap

import sqlalchemy as sa
from ingester3.config import source_db_path
from ingester3.Country import Country
from ingester3.extensions import *
from ingester3.ViewsMonth import ViewsMonth

import geopandas as gpd
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import geopandas as gpd
import pandas as pd
import numpy as np

class Mapper2:
    """
    Layered map builder with an optional manual colour scale.

    A `Mapper2` owns one matplotlib figure/axes pair. Layers are added one after
    another with `add_layer` (each call returns the mapper, so calls can be
    chained), which makes it possible to prepare mapping "presets" at any level
    of layeredness that can be built on further.

    Compared with a plain layered map, `add_layer` accepts `map_scale`, a manual
    scale whose minimum and maximum fix the colour range of both the layer and
    its colorbar.

    TODO: re-add the code for labels once it can be tested.

    Attributes
    ----------
    width: Integer value for width in inches.
    height: Integer value for height in inches.
    bbox: List for the bbox per [xmin, xmax, ymin, ymax].
    cmap: Default colormap for layers that do not pass their own.
    fig, ax: The matplotlib figure and axes that are drawn into.
    cax, cbar: Colorbar axes and colorbar; only present after a colorbar was added.
    vmin, vmax: Colour range recorded by the last layer added with
        `inform_colorbar=True` and a `column`.
    """

    def __init__(
        self,
        width,
        height,
        bbox=None,
        cmap=None,
        frame_on=True,
        title="",  # Default title without customization. (?)
        figure=None,
    ):
        """Create the figure (or adopt an existing one) and apply basic styling.

        Parameters
        ----------
        width: Integer value for width in inches.
        height: Integer value for height in inches.
        bbox: Optional list [xmin, xmax, ymin, ymax] used as axis limits.
        cmap: Optional default colormap for the layers.
        frame_on: If True keep the frame and only hide ticks and tick labels;
            if False switch the axis off entirely.
        title: Title set on the axes.
        figure: Optional tuple of (fig, ax) to plot into an already existing
            figure and axes, rather than making a new one.
        """
        self.width = width
        self.height = height
        self.bbox = bbox  # xmin, xmax, ymin, ymax
        self.cmap = cmap
        if figure is None:
            self.fig, self.ax = plt.subplots(figsize=(self.width, self.height))
        else:
            self.fig, self.ax = figure
        self.texts = []  # not used by the methods of this class (see the labels TODO)
        self.ax.set_title(title)

        if frame_on:  # Remove axis ticks only.
            self.ax.tick_params(
                top=False,
                bottom=False,
                left=False,
                right=False,
                labelleft=False,
                labelbottom=False,
            )
        else:
            self.ax.axis("off")

        if bbox is not None:
            self.ax.set_xlim((self.bbox[0], self.bbox[1]))
            self.ax.set_ylim((self.bbox[2], self.bbox[3]))

    @staticmethod
    def _scale_limits(map_scale):
        """Return (min, max) of a manual scale, or None when no usable scale was given."""
        if map_scale is None or map_scale is False:
            return None
        try:
            values = list(map_scale)
        except TypeError:  # not iterable: treat as "no manual scale"
            return None
        if not values:
            return None
        return min(values), max(values)

    def add_layer(self, gdf, map_scale=False, map_dictionary=False, cmap=None, inform_colorbar=False, **kwargs):
        """Add a geopandas plot to a new layer.

        A colorbar is drawn when the layer is colormapped (no `color` keyword) and
        either a manual `map_scale` is given, or `inform_colorbar` is True and a
        `column` is plotted. A colorbar drawn by an earlier layer is replaced.

        Parameters
        ----------
        gdf: Geopandas GeoDataFrame to plot.
        map_scale: Optional sequence of scale values. Its minimum and maximum are
            used as `vmin`/`vmax` of both the layer and the colorbar and take
            precedence over `vmin`/`vmax` passed in the kwargs. Leave at the
            default to let the data (or the kwargs) determine the range.
        map_dictionary: Currently unused; accepted for backward compatibility.
        cmap: Optional matplotlib colormap object or string reference
            (e.g. "viridis"). Falls back to the mapper's default colormap.
        inform_colorbar: Set or overwrite colorbar with the current layer.
            Not applicable when `color` is supplied in the kwargs.
        **kwargs: Geopandas `.plot` keyword arguments.

        Returns
        -------
        The mapper itself, so calls can be chained.
        """
        limits = self._scale_limits(map_scale)
        plot_kwargs = dict(kwargs)
        draw_colorbar = False

        if "color" in kwargs:
            # Single-colour layer: no colormap and therefore no colorbar.
            colormap = None
        else:
            colormap = self.cmap if cmap is None else cmap
            if inform_colorbar and "column" in kwargs:
                # Record the range of this layer: explicit kwargs win over the data.
                if "vmin" in kwargs:
                    self.vmin = kwargs["vmin"]
                else:
                    self.vmin = gdf[kwargs["column"]].min()
                if "vmax" in kwargs:
                    self.vmax = kwargs["vmax"]
                else:
                    self.vmax = gdf[kwargs["column"]].max()
                draw_colorbar = True
            if limits is not None:
                # A manual scale fixes the colour range of the layer and gets a colorbar.
                plot_kwargs["vmin"], plot_kwargs["vmax"] = limits
                draw_colorbar = True

        if draw_colorbar:
            if hasattr(self, "cax"):
                self.cax.remove()
            bar_min, bar_max = limits if limits is not None else (self.vmin, self.vmax)
            self.add_colorbar(colormap, bar_min, bar_max)

        self.ax = gdf.plot(ax=self.ax, cmap=colormap, **plot_kwargs)
        return self

    def add_colorbar(
        self,
        cmap,
        vmin,
        vmax,
        location="right",
        size="5%",
        pad=0.1,
        alpha=1,
        labelsize=16,
        tickparams=None,
    ):
        """Add custom colorbar to Map.

        Needed since GeoPandas legend and plot axes do not align, see:
        https://geopandas.readthedocs.io/en/latest/docs/user_guide/mapping.html

        Parameters
        ----------
        cmap: Matplotlib colormap object or string reference (e.g. "viridis").
            None selects matplotlib's default colormap.
        vmin: Minimum value of range colorbar.
        vmax: Maximum value of range colorbar.
        location: String for location of colorbar: "top", "bottom", "left"
            or "right".
        size: Size in either string percentage or number of pixels.
        pad: Float for padding between the plot's frame and colorbar.
        alpha: Float for alpha to apply to colorbar.
        labelsize: Integer value for the text size of the ticklabels.
        tickparams: Dictionary containing value-label pairs. For example:
            {0.05: "5%", 0.1: "10%"}

        Returns
        -------
        The mapper itself, so calls can be chained.
        """
        norm = plt.Normalize(vmin, vmax)
        if cmap is None or isinstance(cmap, str):
            # plt.get_cmap(None) resolves to the default colormap from rcParams
            cmap = plt.get_cmap(cmap)
        cmap = color.force_alpha_colormap(cmap=cmap, alpha=alpha)
        scalar_to_rgba = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        divider = make_axes_locatable(self.ax)
        self.cax = divider.append_axes(location, size, pad)
        self.cax.tick_params(labelsize=labelsize)
        tickvalues = (
            list(tickparams.keys()) if tickparams is not None else None
        )
        self.cbar = plt.colorbar(
            scalar_to_rgba, cax=self.cax, ticks=tickvalues
        )
        if tickparams is not None:
            self.cbar.set_ticklabels(list(tickparams.values()))
        return self

    def save(
        self, path, dpi=200, **kwargs
    ):  # Just some defaults to reduce work.
        """Save Map figure to file and close the figure.

        Parameters
        ----------
        path: String path, e.g. "./example.png".
        dpi: Integer dots per inch. Increase for higher resolution figures.
        **kwargs: Matplotlib `savefig` keyword arguments.
        """
        self.fig.savefig(path, dpi=dpi, bbox_inches="tight", **kwargs)
        plt.close(self.fig)

In [None]:
def vid2date(i):
    """Format month id `i` as 'year/month', using the year and month of ViewsMonth(i)."""
    views_month = ViewsMonth(i)
    return f"{views_month.year}/{views_month.month}"

In [None]:
# Month ids, the steps listed alongside them, and a 'year/month' title for each month id
times = [445, 447, 450, 468]
allsteps = [1, 3, 6, 24]
titles = list(map(vid2date, times))
# note: the zip function occurred earlier

# Colour-scale tick positions (log1p of the fatality counts) and their labels
standard_scale = [np.log1p(count) for count in (0, 10, 50, 100, 1000, 10000)]
standard_scale_labels = [str(count) for count in (0, 10, 50, 100, 1000, 10000)]

# Shorter scale that stops at 500
small_scale = [np.log1p(count) for count in (0, 10, 50, 100, 500)]
small_scale_labels = [str(count) for count in (0, 10, 50, 100, 500)]

In [None]:
# Load the country geometries once; later cells work on copies of gdf_base.
_country_query = "SELECT id as country_id, in_africa, in_me, geom FROM prod.country"
engine = sa.create_engine(source_db_path)
gdf_base = gpd.GeoDataFrame.from_postgis(_country_query, engine, geom_col='geom')
gdf = gdf_base.copy()

In [None]:
# Preview the first rows of the country geodataframe.
gdf.head()

In [None]:
# Test partition maps, predictions, rolling:
# one map per (step, forecast origin) pair, saved as PNG under `path`.
times_steps = [1, 3]
lastmonthwithdata = 444
path = '/Users/havardhegre/Dropbox (ViEWS)/ViEWS/Projects/PredictingFatalities/maps/cm_rolling/'

model = EnsembleList[-1]
modelname = model['modelname']

# Attach the country geometries to the calibrated test-partition predictions.
gdf2 = gdf_base.copy()
df = model['test_df_calibrated'].copy()
df = df.join(gdf2.set_index("country_id"))
gdf = gpd.GeoDataFrame(df, geometry="geom")

for step in times_steps:
    # The forecast origin moves forward three months at a time (0, 3, ..., 45).
    for tshift in range(0, 46, 3):
        month = step + tshift + lastmonthwithdata
        issue_month = vid2date(lastmonthwithdata + tshift)
        map_title = f"Model: {modelname}, predictions as of {issue_month}, {step} months ahead"
        m = Mapper2(
            width=10,
            height=10,
            frame_on=True,
            title=map_title,
            bbox=[-18.5, 64.0, -35.5, 43.0],
        )
        m = m.add_layer(
            gdf=gdf.loc[month],
            map_scale=standard_scale,
            cmap="rainbow",
            edgecolor="black",
            linewidth=0.5,
            column=f"step_pred_{step}",
            inform_colorbar=True,
        )
        # Label the colour bar with counts rather than log1p values.
        m.cbar.set_ticks(standard_scale)
        m.cbar.set_ticklabels(standard_scale_labels)

        m.save(f'{path}cm_{modelname}_standard_scale_s{step}_t{tshift}_m{month}.png')

In [None]:
# Test partition maps, actuals:
# one map of the recorded outcome (`ln_ged_sb_dep`) per step in `steps`.
lastmonthwithdata = 444
path = '/Users/havardhegre/Dropbox (ViEWS)/ViEWS/Projects/PredictingFatalities/maps/cm_actuals/'

model = EnsembleList[-1]
modelname = model['modelname']

# Attach the country geometries to the test-partition dataframe.
gdf2 = gdf_base.copy()
df = model['test_df_calibrated'].copy()
df = df.join(gdf2.set_index("country_id"))
gdf = gpd.GeoDataFrame(df, geometry="geom")

for step in steps:
    month = step + lastmonthwithdata
    map_title = 'Actually recorded fatalities, month ' + vid2date(month)
    m = Mapper2(
        width=10,
        height=10,
        frame_on=True,
        title=map_title,
        bbox=[-18.5, 64.0, -35.5, 43.0],
    )
    m = m.add_layer(
        gdf=gdf.loc[month],
        map_scale=standard_scale,
        cmap="rainbow",
        edgecolor="black",
        linewidth=0.5,
        column="ln_ged_sb_dep",
        inform_colorbar=True,
    )
    # Label the colour bar with counts rather than log1p values.
    m.cbar.set_ticks(standard_scale)
    m.cbar.set_ticklabels(standard_scale_labels)

    m.save(f'{path}cm_actuals_standard_scale_s{step}_m{month}.png')

In [None]:
# Future maps

In [None]:
# Future maps: one map of `step_combined` per selected step ahead.
times_steps = [1,3,6,12,24,36]
# NOTE(review): FutureStart is 505 in the parameter cell and the line plots
# use months from 505 on; 492 + step may not exist in the future dataframe.
# TODO confirm.
lastmonthwithdata = 492
path = '/Users/havardhegre/Dropbox (ViEWS)/ViEWS/Projects/PredictingFatalities/maps/FuturePredictions_cm/'

model = EnsembleList[-1]
modelname = model['modelname']

# Attach the country geometries to the calibrated future predictions.
gdf2 = gdf_base.copy()
df = model['future_df_calibrated'].copy()
df = df.join(gdf2.set_index("country_id"))
gdf = gpd.GeoDataFrame(df, geometry="geom")

for step in times_steps:
    month = step + lastmonthwithdata
    map_title = f"model {modelname}, predictions, {step}"
    m = Mapper2(
        width=10,
        height=10,
        frame_on=True,
        title=map_title,
        bbox=[-18.5, 64.0, -35.5, 43.0],
    )
    m = m.add_layer(
        gdf=gdf.loc[month],
        map_scale=standard_scale,
        cmap="rainbow",
        edgecolor="black",
        linewidth=0.5,
        column="step_combined",
        inform_colorbar=True,
    )
    # Label the colour bar with counts rather than log1p values.
    m.cbar.set_ticks(standard_scale)
    m.cbar.set_ticklabels(standard_scale_labels)

    m.save(f'{path}cm_{modelname}_standard_scale_{step}_{month}.png')

In [None]:
# Summary statistics of `df` as left behind by the most recently run cell.
df.describe()

# Line plots, future predictions

In [None]:
import seaborn as sns

# Future predictions of the first ensemble, back-transformed from the log1p
# scale to counts, with the month id exposed as an ordinary column.
df = EnsembleList[0]['future_df_calibrated'].copy()
df['exp_pred'] = np.expm1(df['step_combined'])
df['month'] = df.index.get_level_values('month_id')
months_to_plot = [*(range(505,540))]
month_labels = [vid2date(i) for i in months_to_plot]

# Tuples of (name, country_id, number).
# NOTE(review): the third element looks like an intended y-axis maximum for
# the country's plot, but nothing reads it yet. TODO confirm and apply.
# Full list, kept for reference; the plots below use the shorter CountryList.
CountryListAll = [
    ('Angola',165,500),
    ('Botswana',154,500),
    ('BurkinaFaso',47,2000),
    ('Burundi',155,2000),
#    ('Cameroon',211,2000),
    ('Chad',214,2000),
    ('Congo',166,2000),
    ('DR Congo',167,20000),
    ('Egypt',222,5000),
    ('Ethiopia',57,2000),
    ('Gabon',169,500),
    ('Iran',128,2000),
    ('Israel',218,2000),
    ('Jordan',62,2000),
    ('Kenya',237,2000),
    ('Lebanon',94,2000),
    ('Libya',213,5000),
    ('Mali',50,20000),
    ('Mauritania',244,500),
    ('Morocco',243,500),
    ('Mozambique',162,2000),
    ('Namibia',170,500),
    ('Niger',78,2000),
    ('Nigeria',79,20000),
    ('Oman',119,2000),
    ('Rwanda',156,2000),
    ('Saudi Arabia',131,500),
    ('South Africa',163,2000),
    ('South Sudan',246,5000),
    ('Sudan',245,2000),
    ('Syria',220,50000),
    ('Tanzania',242,500),
    ('Uganda',235,2000),
    ('Yemen',124,20000),
    ('Zimbabwe',158,5000),
]

# Countries that are actually plotted.
CountryList = [
    ('BurkinaFaso',47,2000),
    ('Burundi',155,2000),
#    ('Cameroon',211,2000),
    ('Chad',214,2000),
    ('Congo',166,2000),
    ('DR Congo',167,20000),
    ('Egypt',222,5000),
    ('Ethiopia',57,2000),
    ('Libya',213,5000),
    ('Mali',50,20000),
    ('Mozambique',162,2000),
    ('Niger',78,2000),
    ('Nigeria',79,20000),
    ('Oman',119,2000),
    ('Rwanda',156,2000),
    ('South Sudan',246,5000),
    ('Sudan',245,2000),
    ('Syria',220,50000),
    ('Tanzania',242,500),
    ('Uganda',235,2000),
    ('Yemen',124,20000),
    ('Zimbabwe',158,5000),
]

sns.set_theme(style="whitegrid")
for cnt in CountryList:
    country_name, country_id, _ = cnt
    print(country_name)
    df_cnt = df.xs(country_id, level='country_id')
    # One figure per country: drawing every country on a shared axes without
    # a legend makes the individual lines impossible to tell apart.
    cnt_fig, cnt_ax = plt.subplots()
    sns.lineplot(x="month", y="exp_pred", data=df_cnt, ax=cnt_ax)
    cnt_ax.set(
        title=country_name,
        xlabel="Month",
        ylabel="Predicted fatalities",
    )
    plt.show()

# Surrogate models

# Old stuff from here

In [None]:
# Cumulative fatalities
# NOTE(review): relies on a `cnt` variable left over from an earlier cell.
# With the current CountryList tuples cnt[2] is an integer, so `.head(24)`
# would fail; confirm which dataframe this was meant to show.


cnt[2].head(24)

In [None]:
plt.rcParams["figure.figsize"] = (10, 10)
from matplotlib import cm

# Offsets for the per-origin series; only a disabled overlay loop used them.
t_range = list(range(0, 25, 2))
width = .75 # width of a bar
fig, ax = plt.subplots()

# Bars: logged outcome. Line: the `sc_444` series.
Mozambique['lndepvar'].plot(kind= 'bar', width=width,color='grey')
Mozambique['sc_444'].plot(kind='line', color='blue',ms=100)

# Label every third tick (Jan/Apr/Jul/Oct of 2018-2020), blanks in between.
_quarter_labels = [
    f'{month_name}-{year}'
    for year in (18, 19, 20)
    for month_name in ('Jan', 'Apr', 'Jul', 'Oct')
]
_tick_labels = tuple(
    text for label in _quarter_labels for text in (label, '', '')
)
ax.set_xticklabels(_tick_labels)

plt.show()

In [None]:
# Distribution of the first step ensemble's test predictions, including
# upper-tail percentiles.
percentiles_to_inspect = [.75, .85, .9, .95, .99]
ensembles_test = StepEnsembles[0]['ensembles_test']
ensembles_test.describe(percentiles=percentiles_to_inspect)

In [None]:
# Display MSE_ensemble_all_df (presumably the ensemble MSE table for all
# observations; it is written to LaTeX as 'MSE_ensemble' in the next cell).
MSE_ensemble_all_df
#ensemble_models

In [None]:
# Write each MSE table to a .tex file in the Overleaf folder.
overleafpath = '/Users/havardhegre/Dropbox (ViEWS)/Apps/Overleaf/ViEWS predicting fatalities/Tables/'
tables = [
    (MSE_ensemble_all_df, 'MSE_ensemble'),
    (MSE_ensemble_zeros_df, 'MSE_ensemble_zeros'),
    (MSE_ensemble_nonzeros_df, 'MSE_ensemble_nonzeros'),
]
columns = ensemble_models
for table in tables:
    table_df, table_name = table
    print('Table name', table_name)
    filename = f'{overleafpath}{table_name}.tex'
    print('File name: ', filename)
    with open(filename, "w") as f:
        # Three decimals, no index column.
        f.write(table_df.to_latex(float_format="%.3f", index=False))

In [None]:
# Scratch calculation: exp(6.13) - 1, i.e. 6.13 on the log1p scale as a count.
np.expm1(6.13)

In [None]:
# Preparing for mapping: country geometries in EPSG:4326 joined onto the
# first step ensemble's test predictions, then restricted to month 480.
engine = sa.create_engine(source_db_path)
_country_query = "SELECT id as country_id, in_africa, in_me, geom FROM prod.country"
gdf = gpd.GeoDataFrame.from_postgis(
    _country_query, engine, geom_col='geom'
).to_crs(4326)
data = StepEnsembles[0]['ensembles_test'].join(gdf.set_index("country_id"))
gdf = gpd.GeoDataFrame(data, geometry="geom").loc[480]



In [None]:
# Mapping the outcome variable and the ensemble predictions: one map per
# column, drawn from the month-480 geodataframe prepared above.

ensemble_columns = ['lndepvar'] + ['unweighted_average','linear_regression','rf_regression']

for column in ensemble_columns:
    m = ViewsMap(
        width=10,
        label=f"Ensembles: {column}",
        title="2021-8",
        scale=None,
        bbox="africa_middle_east"
    ).add_layer(
        gdf,
        edgecolor="black",
        linewidth=0.5,
        # Plot the column named in the label; a fixed "lndepvar" here would
        # draw the same map on every iteration.
        column=column,
        inform_colorbar=True
    )

In [None]:
# Plotting over time

## Inspect the predictions, calibration partition

In [None]:
# Outcome plus the step 1-4 predictions, summarised for the test partition.
to_inspect = ['ln_ged_sb_dep'] + [f'step_pred_{step_no}' for step_no in range(1, 5)]
predictions_test[to_inspect].describe()

In [None]:
# Back-transform the step 1-4 predictions and the outcome from the log1p
# scale (expm1) in both partitions, to explore calibration in counts.
_step_pred_columns = [f'step_pred_{step_no}' for step_no in range(1, 5)]
for var in _step_pred_columns:
    expvar = f'exp_{var}'
    for frame in (predictions_calib, predictions_test):
        frame[expvar] = np.expm1(frame[var])

for frame in (predictions_calib, predictions_test):
    frame['ged_sb'] = np.expm1(frame['ln_ged_sb_dep'])

to_inspect = ['ln_ged_sb_dep', 'ged_sb'] + [f'exp_{col}' for col in _step_pred_columns]
predictions_test[to_inspect].describe()

In [None]:
# 245: Sudan. 91: Egypt. 79: Nigeria. 50: Mali
# NOTE(review): the CountryList used for the line plots gives Egypt as 222,
# not 91; confirm the id before adding it to countryset.
# Actuals next to the back-transformed predictions for steps 1-4.
colset = ['ged_sb','exp_step_pred_1','exp_step_pred_2','exp_step_pred_3','exp_step_pred_4']
# Restricted to Mali; use [50, 79, 91, 245] to show all four countries.
countryset = [50]
idx = pd.IndexSlice
predictions_calib.loc[idx[397:408, countryset], colset]

In [None]:
#hh_data.loc[idx[380:408, countryset], :]['ln_ged_sb']

# Summary of the logged outcome in the calibration partition, with the
# upper-tail percentiles listed below.
percentiles_to_inspect = [.75, .85, .9, .95, .99]
predictions_calib['ln_ged_sb_dep'].describe(percentiles=percentiles_to_inspect)

In [None]:
# Calibration in terms of logged predictions/actuals.
def _add_step1_log_calibration(calibrate, column, shift, threshold):
    """Store a calibrated version of `step_pred_1` in both prediction frames.

    `calibrate` always receives the calibration partition's actuals
    (`ln_ged_sb_dep`) and predictions (`step_pred_1`) as reference. It is
    applied first to the calibration partition and then to the test
    partition; the result is written to `column` in each frame. The meaning
    of `shift` and `threshold` is defined by `calibrate`.
    """
    for frame in (predictions_calib, predictions_test):
        frame[column] = calibrate(
            y_true_calpart=predictions_calib['ln_ged_sb_dep'],
            y_pred_calpart=predictions_calib['step_pred_1'],
            y_pred_test=frame['step_pred_1'],
            shift=shift,
            threshold=threshold,
        )


# "Old function", non-shifted
print("In terms of logged predictions/actuals, non-shifted, all obs")
_add_step1_log_calibration(
    mean_sd_calibrated_simple, 'step_pred_1_cal_log_simple',
    shift=False, threshold=0,
)

# New function, non-shifted
print("In terms of logged predictions/actuals, non-shifted, all obs")
_add_step1_log_calibration(
    mean_sd_calibrated, 'step_pred_1_cal_log_nonshifted',
    shift=False, threshold=0,
)

# New function, shifted
print("In terms of logged predictions/actuals, shifted, all obs")
_add_step1_log_calibration(
    mean_sd_calibrated, 'step_pred_1_cal_log_shifted',
    shift=True, threshold=0,
)

# Only "non-zero" predictions (threshold = 1), shifted
print("In terms of logged predictions/actuals, shifted, only non-zeros")
_add_step1_log_calibration(
    mean_sd_calibrated, 'step_pred_1_cal_log_nonzero_shifted',
    shift=True, threshold=1,
)

# Only "non-zero" predictions (threshold = 1), non-shifted
print("In terms of logged predictions/actuals, non-shifted, only non-zeros")
_add_step1_log_calibration(
    mean_sd_calibrated, 'step_pred_1_cal_log_nonzero_nonshifted',
    shift=False, threshold=1,
)

# Back-transform (expm1) the calibrated and raw step-1 predictions into
# 'exp_'-prefixed columns, for presentation and count-based calibration.
_log_scale_columns = [
    'step_pred_1_cal_log_simple',
    'step_pred_1_cal_log_nonshifted',
    'step_pred_1_cal_log_shifted',
    'step_pred_1_cal_log_nonzero_shifted',
    'step_pred_1_cal_log_nonzero_nonshifted',
    'step_pred_1',
]
for var in _log_scale_columns:
    expvar = f'exp_{var}'
    for frame in (predictions_test, predictions_calib):
        frame[expvar] = np.expm1(frame[var])

# Calibration in terms of counts.
def _add_step1_count_calibration(column, threshold):
    """Store a count-scale calibration of `exp_step_pred_1` in both frames.

    `mean_sd_calibrated` is given the calibration partition's `ged_sb` and
    `exp_step_pred_1` as reference, with shift=False. It is applied first to
    the calibration partition and then to the test partition; the result is
    written to `column` in each frame.
    """
    for frame in (predictions_calib, predictions_test):
        frame[column] = mean_sd_calibrated(
            y_true_calpart=predictions_calib['ged_sb'],
            y_pred_calpart=predictions_calib['exp_step_pred_1'],
            y_pred_test=frame['exp_step_pred_1'],
            shift=False,
            threshold=threshold,
        )


print("In terms of non-logged predictions/actuals, non-shifted")
_add_step1_count_calibration('exp_step_pred_1_cal_count', threshold=0)

print("In terms of non-logged predictions/actuals, non-shifted, non-zeros")
_add_step1_count_calibration('exp_step_pred_1_cal_count_nonzeros', threshold=1)

# Columns compared in the describe() cells: outcome, raw step-1 prediction
# and every count-scale variant of it.
# NOTE(review): 'exp_step_pred_1_nln' is not created in the cells around
# this one; confirm it exists before the describe() calls.
_step1_count_column = 'exp_step_pred_1'
_step1_variants = [
    '',
    '_cal_log_simple',
    '_cal_log_nonshifted',
    '_cal_log_shifted',
    '_cal_log_nonzero_shifted',
    '_cal_log_nonzero_nonshifted',
    '_cal_count',
    '_cal_count_nonzeros',
    '_nln',
]
to_inspect = ['ln_ged_sb_dep', 'ged_sb', 'step_pred_1'] + [
    _step1_count_column + variant for variant in _step1_variants
]
percentiles_to_inspect = [.75, .85, .9, .95, .99]

In [None]:
# Inspect calibration for the calibration partition:
# summary statistics of the `to_inspect` columns with upper-tail percentiles.
predictions_calib[to_inspect].describe(percentiles=percentiles_to_inspect)

In [None]:
# Inspect calibration for the test partition:
# summary statistics of the `to_inspect` columns with upper-tail percentiles.

predictions_test[to_inspect].describe(percentiles=percentiles_to_inspect)

In [None]:
from sklearn.utils import gen_even_slices

# From https://scikit-learn.org/stable/auto_examples/linear_model/plot_poisson_regression_non_normal_loss.html?highlight=calibration
# Simplified: removed weights
def _mean_frequency_by_risk_group(y_true, y_pred, n_bins=100):
    """Compare predictions and observations for bins ordered by y_pred.

    The samples are ordered by ``y_pred`` and split into ``n_bins`` groups
    using ``gen_even_slices``. In each group the observed mean is compared
    with the predicted mean.

    Parameters
    ----------
    y_true: array-like of shape (n_samples,)
        Ground truth (correct) target values.
    y_pred: array-like of shape (n_samples,)
        Estimated target values.
    n_bins: int
        Number of bins to use.

    Returns
    -------
    bin_centers: ndarray of shape (n_bins,)
        bin centers, as fractions in (0, 1)
    y_true_bin: ndarray of shape (n_bins,)
        average y_true for each bin
    y_pred_bin: ndarray of shape (n_bins,)
        average y_pred for each bin
    """
    # The positional indexing below needs plain arrays: a pandas Series
    # would be indexed by label rather than by position.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    # Sort both arrays once by predicted value instead of once per bin.
    idx_sort = np.argsort(y_pred)
    y_pred_sorted = y_pred[idx_sort]
    y_true_sorted = y_true[idx_sort]

    bin_centers = np.arange(0, 1, 1 / n_bins) + 0.5 / n_bins
    y_pred_bin = np.zeros(n_bins)
    y_true_bin = np.zeros(n_bins)

    for n, sl in enumerate(gen_even_slices(len(y_true), n_bins)):
        y_pred_bin[n] = np.average(y_pred_sorted[sl])
        y_true_bin[n] = np.average(y_true_sorted[sl])
    return bin_centers, y_true_bin, y_pred_bin

In [None]:
# Calibration curves for the test partition: one panel per prediction
# variant, comparing mean predicted and mean observed fatalities across 40
# bins ordered by the prediction.
print(f"Actual number of fatalities: {predictions_test['ged_sb'].sum()}")
fig, ax = plt.subplots(nrows=3, ncols=3, figsize=(12, 8))
plt.subplots_adjust(wspace=0.3)

_prediction_columns = [
    'exp_step_pred_1',
    'exp_step_pred_1_cal_log_nonshifted',
    'exp_step_pred_1_cal_log_shifted',
    'exp_step_pred_1_cal_log_nonzero_shifted',
    'exp_step_pred_1_cal_log_nonzero_nonshifted',
    'exp_step_pred_1_cal_count',
    'exp_step_pred_1_cal_count_nonzeros',
    'exp_step_pred_1_nln',
]
# The observed counts are the same for every panel.
y_true = predictions_test['ged_sb'].values

for axi, prediction_column in zip(ax.ravel(), _prediction_columns):
    y_pred = predictions_test[prediction_column].values
    q, y_true_seg, y_pred_seg = _mean_frequency_by_risk_group(
        y_true, y_pred, n_bins=40
    )

    # Total predicted fatalities, to compare with the actual total above.
    print(f"Predicted number of fatalities, {prediction_column}: {np.sum(y_pred):.1f}")

    axi.plot(q, y_pred_seg, marker="x", linestyle="--", label="predictions")
    axi.plot(q, y_true_seg, marker="o", linestyle="--", label="observations")
    axi.set(
        xlim=(0, 1.0),
        ylim=(0, 2000),
        title=prediction_column,
        xlabel="Fraction of samples sorted by y_pred",
        ylabel="Mean number of fatalities (y_pred)",
    )
    axi.legend()
plt.tight_layout()

# Evaluation

In [None]:
# Scatter of step-1 predictions against the logged outcome in the test
# partition, plus their mean squared error.
import matplotlib.pyplot as plt

# Fixing random state for reproducibility
np.random.seed(19680801)

x = predictions_test['step_pred_1']
y = predictions_test['ln_ged_sb_dep']
N = len(x)
# The random colours and areas are overwritten two lines below; only their
# effect on the global NumPy random state remains.
colors = np.random.rand(N)
area = (30 * np.random.rand(N))**2  # 0 to 15 point radii
colors = x
area = 4

# NOTE(review): this rebinds the notebook-level name `ax` (used for Axes
# objects in other cells) to the integer 0.
ax=0 # per-column
mse = ((x - y)**2).mean(axis=ax)
print(mse)

plt.scatter(x, y, s=area, c=colors, alpha=0.5)
plt.show()

In [None]:
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_poisson_deviance


def score_estimator(y_pred, df_test):
    """Print MSE, MAE and mean Poisson deviance of predictions vs. actuals.

    Parameters
    ----------
    y_pred: array-like of shape (n_samples,)
        Predicted values. Must support boolean-mask indexing.
    df_test: array-like of shape (n_samples,)
        Observed values (despite the name, the target itself rather than a
        dataframe of features). Must support boolean-mask indexing.

    Returns
    -------
    None. The three scores are printed.
    """
    print("MSE: %.3f" % mean_squared_error(df_test, y_pred))
    print("MAE: %.3f" % mean_absolute_error(df_test, y_pred))

    # Ignore non-positive predictions, as they are invalid for
    # the Poisson deviance.
    mask = y_pred > 0
    if (~mask).any():
        n_masked, n_samples = (~mask).sum(), mask.shape[0]
        print(
            "WARNING: Estimator yields invalid, non-positive predictions "
            f" for {n_masked} samples out of {n_samples}. These predictions "
            "are ignored when computing the Poisson deviance."
        )

    # Unweighted: there is no exposure variable here, and weighting by the
    # observed values would give zero-fatality observations no influence.
    print(
        "mean Poisson deviance: %.3f"
        % mean_poisson_deviance(df_test[mask], y_pred[mask])
    )


# Score the step-1 predictions twice: on the logged scale and on the
# back-transformed (count) scale.
print("Evaluation:")
score_estimator(predictions_test['step_pred_1'], predictions_test['ln_ged_sb_dep'])
score_estimator(predictions_test['exp_step_pred_1'], predictions_test['ged_sb'])

## Mapping

In [None]:
import geopandas as gpd
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from views_dataviz.map import mapper, utils
from views_dataviz import color

import sqlalchemy as sa
from ingester3.config import source_db_path

In [None]:
# Load the country geometries and attach the predictions selected by label 397.
engine = sa.create_engine(source_db_path)
gdf = gpd.GeoDataFrame.from_postgis(
    "SELECT id, geom FROM prod.country", engine, geom_col='geom'
#    "SELECT id, geom FROM prod.country WHERE in_africa=1", engine, geom_col='geom'
)
# NOTE(review): this assignment aligns on the index. `gdf` keeps the index
# returned by the query, while the selected predictions are presumably
# indexed by country id; confirm they line up (e.g. set_index("id") first).
gdf["preds"] = predictions.loc[397,:]
gdf["month_id"] = 397

In [None]:
# Scratch calculation: exp(13).
np.exp(13)

In [None]:
# Show the rows of `predictions` selected by label 397.
predictions.loc[397,:]