# Figure 2 - Sepsis manuscript

This Jupyter notebook generates Figure 2 of the sepsis manuscript.

First, we will set up the notebook requirements.

In [None]:
# Switch to the project's src/ directory; presumably this is what makes the
# local `utils` and `xgboost2` packages importable in the next cell.
%cd ../src
# Show the resulting working directory as a sanity check.
%pwd
# Render matplotlib figures inline in the notebook.
%matplotlib inline
# Automatically reload edited project modules before each cell executes.
%load_ext autoreload
%autoreload 2

Next, we will import all required libraries.

In [None]:
# general imports
import os
import pandas as pd
import matplotlib.pyplot as plt
import string
import datetime as dt

# utils import
from utils.files import get_latest_version
from utils.cross_validation import cross_validate_ROC, cross_validate_risk_ROC, compare_models, cross_validate_calibration
from utils.risk_scores import read_data_risk_score
from xgboost2.cross_validate import read_data

# xgboost
from xgboost import XGBClassifier

Then, we define the model versions to use:
version 20 for shock and version 22 for mortality.

In [None]:
# Model versions used for this figure, one per outcome:
# shock is read with out=1, mortality with out=3 in the cells below.
model_version_shock, model_version_mort = 20, 22

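Now, let's start. The first two cells below cross-validate the XGBoost model on datasets 1 and 3 (one panel each), compare the two models in the third panel, and save the figure as a new version under `figures/crossvalidation`.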

In [None]:
# Shock outcome (out=1): cross-validated ROC panels for XGBoost on datasets 1
# and 3, a third panel comparing both models, then save as the next versioned SVG.
modelslist = []
scoreslist = []  # no risk-score baselines are compared here (risk_score=False)

datasets = [1, 3]
# One ROC panel per dataset plus one comparison panel.
fig, ax = plt.subplots(ncols=len(datasets) + 1, figsize=(18, 6))

for i, model in enumerate(datasets):
    data_dict = read_data(model=model,
                          out=1,  # shock
                          version=model_version_shock)

    modelslist.append(['XGBoost', model, 1,  # shock
                       cross_validate_ROC(('XGBoost', XGBClassifier(), data_dict['features']), data_dict,
                                          kfolds=5, model_nr=model, output_nr=1, ax=ax[i])])

# The last panel holds the comparison of the per-dataset models.
compare_models(modelslist,
               scoreslist,
               risk_score=False, ax=ax[-1])

figures_dir = os.path.join(os.getcwd(), '..', 'figures', 'crossvalidation')
last_version = get_latest_version(figures_dir)
current_date = dt.datetime.now().strftime('%d %B %y')
file_name = 'v{0:03d} - {1}.svg'.format(last_version + 1, current_date)
# Save via the explicit figure handle rather than pyplot's "current figure",
# so a helper that opens its own figure cannot redirect what gets written.
fig.savefig(os.path.join(figures_dir, file_name), format='svg', dpi=1200)
plt.show()

In [None]:
# Mortality outcome (out=3): cross-validated ROC panels for XGBoost on datasets
# 1 and 3, a third panel comparing both models, then save as the next versioned SVG.
modelslist = []
# NOTE(review): scoreslist stays empty although compare_models is called with
# risk_score=True, and cross_validate_risk_ROC / read_data_risk_score are
# imported but never used — confirm whether risk-score baselines belong here.
scoreslist = []

datasets = [1, 3]
# One ROC panel per dataset plus one comparison panel.
fig, ax = plt.subplots(ncols=len(datasets) + 1, figsize=(18, 6))

for i, model in enumerate(datasets):
    data_dict = read_data(model=model,
                          out=3,  # mortality
                          version=model_version_mort)

    # save_fig=True is forwarded to the project helper; what it writes is
    # defined in utils.cross_validation (not visible here).
    modelslist.append(['XGBoost', model, 3,  # mortality
                       cross_validate_ROC(('XGBoost', XGBClassifier(), data_dict['features']), data_dict,
                                          kfolds=5, model_nr=model, output_nr=3, save_fig=True, ax=ax[i])])

# The last panel holds the comparison of the per-dataset models.
compare_models(modelslist,
               scoreslist,
               risk_score=True, ax=ax[-1])

figures_dir = os.path.join(os.getcwd(), '..', 'figures', 'crossvalidation')
last_version = get_latest_version(figures_dir)
current_date = dt.datetime.now().strftime('%d %B %y')
file_name = 'v{0:03d} - {1}.svg'.format(last_version + 1, current_date)
# Save via the explicit figure handle rather than pyplot's "current figure":
# the helper above is asked to save figures itself and may change pyplot state.
fig.savefig(os.path.join(figures_dir, file_name), format='svg', dpi=1200)
plt.show()

In [None]:
# Mortality outcome (out=3): cross-validated calibration panels for XGBoost on
# datasets 1 and 3, drawn into the first two axes of a 3-panel figure.
# NOTE(review): ax[2] is not used in the lines shown; presumably the rest of
# this cell fills it (e.g. a comparison panel) — confirm.
modelslist = []
scoreslist = []

# Manual panel index: dataset k is drawn into ax[i].
i = 0
fig, ax = plt.subplots(ncols=3, figsize=(18, 6))
    
for model in [1,3]: #datasets 1 and 3
    data_dict = read_data(model = model, 
                       out = 3, # mortality
                       version = model_version_mort)
        
    # Each entry records [model name, dataset, outcome, helper result];
    # save_fig=True is forwarded to the project helper, whose effect is
    # defined in utils.cross_validation (not visible here).
    modelslist.append(['XGBoost', model, 3, # mortality
                        cross_validate_calibration(('XGBoost', XGBClassifier(), data_dict['features']), data_dict,
                        kfolds=5, model_nr = model, output_nr = 3, save_fig=True, ax=ax[i])])
        
    i += 1