## Explain XGB on TBI signal data

Explain XGBoost model trained on raw data

In [None]:
from tbi_downstream_prediction import *
import os  # used below; made explicit instead of relying on the star import

# Root of the TBI subset, the processed hypoxemia data, and the results tree.
PATH = "/homes/gws/hughchen/phase/tbi_subset/"
DPATH = PATH+"tbi/processed_data/hypoxemia/"
RESULTPATH = PATH+"results/"

# Set important variables
label_type = "desat_bool92_5_nodesat"; eta = 0.02
hosp_data = "tbi"; data_type = "raw[top11]"
# Model identifier; it names the per-model result directory below.
mod_type = "xgb_{}_eta{}".format(label_type,eta)

# Set up result directory (<RESULTPATH><mod_type>/).
# exist_ok=True makes this a single call with no gap between an existence
# check and the creation.
RESDIR = '{}{}/'.format(RESULTPATH, mod_type)
os.makedirs(RESDIR, exist_ok=True)

# Load the TBI labels and signals as read-only memory maps.
# (Per its file name, the signal array is the imputed, standardized version.)
y_tbi = np.load(DPATH + "tbiy.npy", mmap_mode="r")
X_tbi = np.load(DPATH + "X_tbi_imp_standard.npy", mmap_mode="r")

# Signal names; a name's position in this list is its index along axis 1 of X_tbi.
feats = [
    "ECGRATE", "ETCO2", "ETSEV", "ETSEVO", "FIO2", "NIBPD", "NIBPM",
    "NIBPS", "PEAK", "PEEP", "PIP", "RESPRATE", "SAO2", "TEMP1", "TV",
]
# Signals excluded from the model input.
weird_feats = ["ETSEV", "PIP", "PEEP", "TV"]

# Names of the signals that are kept, and their positions along axis 1.
feat_lst2 = [name for name in feats if name not in weird_feats]
feat_inds = np.array(
    [pos for pos, name in enumerate(feats) if name not in weird_feats]
)
X_tbi2 = X_tbi[:, feat_inds, :]

# split_data is provided by tbi_downstream_prediction; the train rows are
# later sliced in blocks of 60 columns per signal, so it presumably returns
# flattened 2-D splits here -- TODO confirm against its definition.
(X_test, y_test, X_valid, y_valid, X_train, y_train) = split_data(
    DPATH, X_tbi2, y_tbi
)

# XGBoost parameters of the trained model.  In the cells visible here only
# 'eta' is read back, to rebuild the saved model's file name.
param = dict(
    max_depth=6,
    eta=eta,
    subsample=0.5,
    gamma=1.0,
    min_child_weight=10,
    base_score=y_train.mean(),  # mean of the training labels
    objective='binary:logistic',
    eval_metric=["logloss"],
)

# Generate attributions
import shap

# Directory that holds the trained booster; the attributions are written
# into the same directory.
save_path = RESDIR+"hosp{}_data/{}/".format(hosp_data,data_type)
model_file = save_path + 'mod_eta{}.model'.format(param['eta'])
if DEBUG:
    print("[DEBUG] Loading model from {}".format(save_path))
bst = xgb.Booster()
bst.load_model(model_file)

# Background set: 1000 training rows drawn with a fixed seed.  The draw uses
# np.random.choice's default, i.e. sampling with replacement.
np.random.seed(10)
n_train_rows = X_train.shape[0]
background_inds = np.random.choice(n_train_rows, 1000)
X_train_background = X_train[background_inds]

# Explain every training row against that background and store the result
# (np.save appends the ".npy" extension).
ind_explainer = shap.TreeExplainer(
    bst,
    data=X_train_background,
    feature_dependence="independent",
)
X_train_ind_explanations = ind_explainer.shap_values(X_train)
np.save(save_path + "X_train_explanations_ind", X_train_ind_explanations)

#### Plot the attributions for each time point

In [None]:
import itertools

save_path = RESDIR+"hosp{}_data/{}/".format(hosp_data,data_type)

# One label per column of the explanation matrix: "<signal>_<t>" with
# t = 0..59, all time steps of the first signal first, then the next signal.
feat_lst2_expanded = [
    name + "_" + str(step)
    for name, step in itertools.product(feat_lst2, range(0, 60))
]

# Summary plot over individual (signal, time step) columns; top 10 only.
shap.summary_plot(
    X_train_ind_explanations,
    features=X_train,
    feature_names=feat_lst2_expanded,
    max_display=10,
)

#### Plot the aggregated attributions

In [None]:
def _sum_per_signal(flat, n_signals, steps=60):
    """Sum each signal's block of ``steps`` consecutive columns.

    ``flat`` is a 2-D array whose columns are grouped signal by signal.
    Returns an array with one row per input row and one column per signal.
    """
    per_signal = [
        flat[:, k * steps:(k + 1) * steps].sum(1) for k in range(n_signals)
    ]
    return np.stack(per_signal, axis=1)


# Collapse the 60 time steps of every signal into a single column, for both
# the attributions and the input values shown alongside them in the plot.
X_train_explanations_agg = _sum_per_signal(X_train_ind_explanations, len(feat_lst2))
X_train_agg = _sum_per_signal(X_train, len(feat_lst2))
shap.summary_plot(
    X_train_explanations_agg,
    features=X_train_agg,
    feature_names=feat_lst2,
)

### Plot the time series signal and attribution for a single sample

In [None]:
import os
import numpy as np
DPATH = "/homes/gws/hughchen/phase/tbi_subset/tbi/processed_data/hypoxemia/"

# Signal names; a name's position in this list is used further down as its
# index along axis 1 of X_tbi.
feat_lst = [
    "ECGRATE", "ETCO2", "ETSEV", "ETSEVO", "FIO2", "NIBPD", "NIBPM",
    "NIBPS", "PEAK", "PEEP", "PIP", "RESPRATE", "SAO2", "TEMP1", "TV",
]

# The model was trained on normalized data.  For plotting we want the
# un-normalized values instead, so load the raw array here (read-only memory
# map); it is imputed further down in this cell.
raw_signal_file = DPATH + "tbiX.npy"
X_tbi = np.load(raw_signal_file, mmap_mode="r")

# Impute missing data
def impute_data(dataX):
    """Replace the zeros in ``dataX`` with the median of its non-zero entries.

    A value of exactly 0 is treated as "missing".  The array is modified in
    place and also returned, so pass a copy if the original must be kept.

    When there is nothing to take a median of (all entries are zero), or the
    median is NaN (NaN entries count as non-zero and propagate through
    ``np.median``), the zeros are left at 0.
    """
    missing = dataX == 0
    observed = dataX[~missing]
    # Skip the median for an empty selection: np.median of an empty array
    # returns NaN and emits a RuntimeWarning.
    imp_val = np.median(observed) if observed.size else 0
    if np.isnan(imp_val):
        imp_val = 0
    dataX[missing] = imp_val
    return dataX

# Impute the raw TBI data one signal at a time, then restack the signals
# along axis 1.
X_tbi_imp = []
for feat_pos, feat in enumerate(feat_lst):
    # Train/validation array for this signal: the first file in the signal's
    # directory whose name contains "<feat>.npy" and "X_train_val".
    # NOTE(review): X_trval_imp is computed but never read afterwards.  Either
    # this load is leftover, or the TBI data was meant to be imputed with the
    # train/validation median -- confirm the intent.
    fpath = "/homes/gws/hughchen/RNN/LSTM_Feature/code/"
    fpath += "min5_data/{}minimum5/hospital_0/raw/".format(feat)
    candidates = [f for f in os.listdir(fpath) if feat+".npy" in f and "X_train_val" in f]
    fname = candidates[0]
    X_trval = np.load(fpath+fname)
    X_trval_imp = impute_data(X_trval)

    # Work on a copy: X_tbi is a read-only memory map and impute_data writes
    # into its argument.
    X_tbi_curr = np.copy(X_tbi[:, feat_pos, :])
    if feat == "TV":
        # TV is divided by 1000 (presumably a unit conversion -- TODO confirm).
        X_tbi_curr = X_tbi_curr/1000.0

    # NaNs become 0 first so that impute_data treats them as missing.
    X_tbi_curr[np.isnan(X_tbi_curr)] = 0
    X_tbi_curr_imp = impute_data(X_tbi_curr)
    X_tbi_imp.append(X_tbi_curr_imp)

X_tbi_imp = np.stack(X_tbi_imp, axis=1)

# Only the training split of the un-normalized data is needed for plotting.
(_, _, _, _, X_train_raw_imp, _) = split_data(
    DPATH, X_tbi_imp, y_tbi, flatten=False
)

import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (5,2)

# The figures are written to ./fig; create it up front so that savefig does
# not fail on a missing directory.
os.makedirs("fig", exist_ok=True)

# Training-set row to plot.  Fixed so that the saved figures are reproducible.
# Alternative: the sample with the largest attribution on the last FIO2 step.
# FIO2_59_ind = ((60*(feat_lst2.index("FIO2")+1))-1)
# FIO2_59_exp = X_train_ind_explanations[:,FIO2_59_ind]
# best_samp_ind = np.where(FIO2_59_exp==FIO2_59_exp.max())[0][0]
samp_ind = 14898


def _plot_signal_and_attribution(signame, file_tag):
    """Plot one signal's raw values and its attributions for row ``samp_ind``.

    Two figures are saved and shown:
    ``fig/<file_tag>valu<samp_ind>.pdf`` (raw signal values) and
    ``fig/<file_tag>attr<samp_ind>.pdf`` (attributions of the same signal).
    """
    # The raw array carries every signal in feat_lst; the explanations only
    # cover the signals in feat_lst2, 60 consecutive columns per signal.
    raw_pos = feat_lst.index(signame)
    ind = feat_lst2.index(signame)

    # NOTE(review): X_tbi_imp is stacked as (sample, signal, time) before the
    # split, while this indexing uses the LAST axis as the signal axis.  That
    # is only right if split_data(..., flatten=False) returns
    # (sample, time, signal) -- TODO confirm against split_data.
    plt.plot(X_train_raw_imp[samp_ind, :, raw_pos])
    plt.ylabel(signame)
    plt.savefig("fig/{}valu{}.pdf".format(file_tag, samp_ind))
    plt.show()

    plt.plot(X_train_ind_explanations[samp_ind, (60*ind):(60*(ind+1))])
    plt.savefig("fig/{}attr{}.pdf".format(file_tag, samp_ind))
    plt.show()


# (signal name, file-name prefix) for each pair of figures.
for signame, file_tag in [("TEMP1", "temp"), ("SAO2", "sao2"), ("FIO2", "fio2")]:
    _plot_signal_and_attribution(signame, file_tag)