In [6]:
import pandas as pd
import xgboost as xgb
import random
import numpy as np
import os, sys
import matplotlib.pyplot as plt
import pickle as pkl

from typing import List
from sklearn.impute import SimpleImputer
from sklearn import preprocessing, model_selection
from sklearn import metrics

module_path = os.path.abspath(os.path.join('../CATENets/'))
if module_path not in sys.path:
    sys.path.append(module_path)


import catenets.models.torch.pseudo_outcome_nets as pseudo_outcome_nets

In [7]:

def normalize_data(x_train):
    
    x_normalized_train = (x_train - np.min(x_train, axis=0)) / (np.max(x_train, axis=0) - np.min(x_train, axis=0))

    return x_normalized_train


def subgroup_ate(
    method: str,
    features: List[int],
    y_true_train: np.ndarray,
    y_true_test: np.ndarray,
    estimated_ate_test: np.ndarray,
    iss_test: np.ndarray
) -> None:
    
    xgb_model = xgb.XGBClassifier(  
        max_depth=3,
        reg_lambda=2,
        # min_split_loss=2
    )

    xgb_model.fit(x_train[:, features], y_true_train)

    y_pred = xgb_model.predict(x_test[:, features])
    y_pred_train = xgb_model.predict(x_train[:, features])
    
    ate = np.sum(estimated_ate_test[y_pred == 1])/len(estimated_ate_test)
    auroc = metrics.roc_auc_score(y_true_test, y_pred)
        
    print("===================")
    print("%s - auroc %s"%(method, auroc))
    print("%s - ATE %s"%(method, ate))

def feature_idx(
    method: str,
    cohort: str,
    learner: str
)-> List[int]:
    
    if method == "shap":
        file_path = f"../results/{cohort}/naive_shap_top_5_features_{learner}.csv"
    elif method == "ig":
        file_path = f"../results/{cohort}/integrated_gradients_top_5_features_{learner}.csv"
    elif method == "shap - 0 ":
        file_path = f"../results/{cohort}/shapley_value_sampling_top_5_features_{learner}.csv"
        
    df = pd.read_csv(file_path,keep_default_na=False)
    
    df_sorted = df.sort_values(
        by='count (%)', 
        ascending=False
    )
    print(df_sorted["feature"].head(5).tolist())
    
    indices  = [ x.columns.get_loc(col) for col in df_sorted["feature"].head(5) ]
    
    for i in indices:
        if i > treatment_index:
            i -= 1
    return indices

In [8]:
fluid_cohort = pd.read_pickle("../data/low_bp_survival.pkl")


all_year = pd.read_csv("../data/all_year.csv", index_col=0)

fluid_cohort = pd.merge(fluid_cohort,all_year[['registryid','iss']],on='registryid', how='left')
fluid_cohort["iss"] = pd.to_numeric(fluid_cohort["iss"], errors='coerce')

#
fluid_cohort = fluid_cohort[fluid_cohort.columns.drop(list(fluid_cohort.filter(regex='proc')))]
fluid_cohort = fluid_cohort[fluid_cohort.columns.drop(list(fluid_cohort.filter(regex='ethnicity')))]
fluid_cohort = fluid_cohort[fluid_cohort.columns.drop(list(fluid_cohort.filter(regex='residencestate')))]
fluid_cohort = fluid_cohort[fluid_cohort.columns.drop(list(fluid_cohort.filter(regex='toxicologyresults')))]


x = fluid_cohort.loc[:, ~fluid_cohort.columns.isin(["registryid",
                                                            "COV",
                                                            "TT",
                                                            "scenegcsmotor",
                                                            "scenegcseye",
                                                            "scenegcsverbal",
                                                            "edgcsmotor",
                                                            "edgcseye",
                                                            "edgcsverbal",
                                                            "outcome",
                                                            "sex_F",
                                                            "traumatype_P",
                                                            "traumatype_other"
                                                            ])]

n, feature_size = x.shape
names = x.drop(["treated"], axis=1).columns
treatment_index = x.columns.get_loc("treated")
iss_index = x.columns.get_loc("iss")

var_index = [i for i in range(feature_size) if i != treatment_index and i != iss_index]

x_norm = normalize_data(x)

## impute missing value

imp = SimpleImputer(missing_values=np.nan, strategy='mean')
imp.fit(x_norm)
x_train_scaled = imp.transform(x_norm)

x_train, x_test, y_train, y_test = model_selection.train_test_split(
                                             x_train_scaled,  
                                             fluid_cohort["outcome"], 
                                             test_size=0.2, 
                                             random_state=42,
                                             stratify=fluid_cohort["treated"]
                                    )

x_train, x_val, y_train, y_val = model_selection.train_test_split(
                                             x_train,  
                                             y_train, 
                                             test_size=0.2, 
                                             random_state=42,
                                             stratify=x_train[:, treatment_index]
                                    )

w_train = x_train[:, treatment_index]
w_val = x_val[:, treatment_index]
w_test =  x_test[:, treatment_index]


iss_train = x_train[:, iss_index]
iss_test =  x_test[:, iss_index]

x_train = x_train[:,var_index]
x_val = x_val[:, var_index]
x_test = x_test[:, var_index]

  all_year = pd.read_csv("../data/all_year.csv", index_col=0)


In [9]:
results_train = pkl.load(open("../results/massive_trans/train_xlearner.pkl", "rb"))
results_test = pkl.load(open("../results/massive_trans/test_xlearner.pkl", "rb"))

print(np.mean(results_train), np.std(results_train)/np.sqrt(results_train.shape[1]))
print(np.mean(results_test), np.std(results_test)/np.sqrt(results_test.shape[1]))

estimated_ate_train = np.mean(results_train, axis=0)
estimated_ate_test = np.mean(results_test, axis=0)
threshold = np.mean(estimated_ate_train)

y_true_train = (estimated_ate_train > threshold)
y_true_test = (estimated_ate_test > threshold)

scenefirstbloodpressure = x.columns.get_loc("scenefirstbloodpressure")
lac  = x.columns.get_loc("LAC")
inr  = x.columns.get_loc("INR")
hgb  = x.columns.get_loc("HGB")

explainers = {
    
    "shap": feature_idx("shap","massive_trans", "xlearner"),
    "shap - 0 ": feature_idx("shap - 0 ","massive_trans" ,"xlearner"), #[temp, ph, bd, hgb, pulse ]
    "ig": feature_idx("ig","massive_trans" ,"xlearner"), #[ph, na, temp, gender, fio2 ],
    
    "clinical": [lac, inr, hgb,scenefirstbloodpressure ],
    "full features": [ i for i in range(x_train.shape[1])],
    "random features": np.random.randint(x_train.shape[1], size=(5)),
}


print("mean ISS: ", np.mean(iss_test)*74+1)
print("original", np.sum(estimated_ate_test[w_test==1])/n)
print("original - iss", np.mean(iss_test[w_test==1])*74+1, np.mean(iss_test[w_test==0])*74+1)
print("===================================")

for explainer, features in explainers.items():
    subgroup_ate(
        explainer,
        features,
        y_true_train,
        y_true_test,
        estimated_ate_test,
        iss_test
    )

-0.04359583797298134 0.01224235412480223
-0.044979371980935406 0.02237256661326439
['traumatype_B', 'edgcs', 'causecode_FALL', 'sex_M', 'causecode_CUT']
['temps2', 'HCT', 'NA', 'HGB', 'edfirstpulse']
['temps2', 'HCT', 'traumatype_B', 'PH', 'NA']
mean ISS:  29.1565598580524
original -0.007728684288254787
original - iss 32.176198801198794 25.356564783531073
shap - auroc 0.794240317775571
shap - ATE 0.039242999966659434
shap - 0  - auroc 0.5737835153922541
shap - 0  - ATE -0.0031156017095202676
ig - auroc 0.6503475670307846
ig - ATE 0.009749850991010594
clinical - auroc 0.5156901688182721
clinical - ATE -0.024012609621491376
full features - auroc 0.8109235352532274
full features - ATE 0.040060226552290336
random features - auroc 0.5940913604766633
random features - ATE 0.00457393794751434


In [10]:
results_train = pkl.load(open("../results/massive_trans/train_xlearner.pkl", "rb"))
results_test = pkl.load(open("../results/massive_trans/test_xlearner.pkl", "rb"))

print(np.mean(results_train), np.std(results_train)/np.sqrt(results_train.shape[1]))
print(np.mean(results_test), np.std(results_test)/np.sqrt(results_test.shape[1]))

estimated_ate_train = np.mean(results_train, axis=0)
estimated_ate_test = np.mean(results_test, axis=0)
threshold = np.mean(estimated_ate_train)

y_true_train = (estimated_ate_train > threshold)
y_true_test = (estimated_ate_test > threshold)

scenefirstbloodpressure = x.columns.get_loc("scenefirstbloodpressure")
lac  = x.columns.get_loc("LAC")
inr  = x.columns.get_loc("INR")
hgb  = x.columns.get_loc("HGB")

explainers = {
    
    "shap": feature_idx("shap","massive_trans", "ensemble"),
    "shap - 0 ": feature_idx("shap - 0 ","massive_trans" ,"ensemble"), #[temp, ph, bd, hgb, pulse ]
    "ig": feature_idx("ig","massive_trans" ,"ensemble"), #[ph, na, temp, gender, fio2 ],
    
    "clinical": [lac, inr, hgb,scenefirstbloodpressure ],
    "full features": [ i for i in range(x_train.shape[1])],
    "random features": np.random.randint(x_train.shape[1], size=(5)),
}


print("mean ISS: ", np.mean(iss_test)*74+1)
print("original", np.sum(estimated_ate_test[w_test==1])/n)
print("original - iss", np.mean(iss_test[w_test==1])*74+1, np.mean(iss_test[w_test==0])*74+1)
print("===================================")

for explainer, features in explainers.items():
    subgroup_ate(
        explainer,
        features,
        y_true_train,
        y_true_test,
        estimated_ate_test,
        iss_test
    )

-0.04359583797298134 0.01224235412480223
-0.044979371980935406 0.02237256661326439
['causecode_PEDESTRIAN', 'sex_M', 'edgcs', 'scenegcs', 'traumatype_B']
['edgcs', 'scenegcs', 'traumatype_B', 'scenefirstpulse', 'HGB']
['FIO2', 'HGB', 'sex_M', 'temps2', 'traumatype_B']
mean ISS:  29.1565598580524
original -0.007728684288254787
original - iss 32.176198801198794 25.356564783531073
shap - auroc 0.7263654419066534
shap - ATE 0.03105440185826958
shap - 0  - auroc 0.7379841112214497
shap - 0  - ATE 0.029720891329612746
ig - auroc 0.6919066534260179
ig - ATE 0.02325804521797679
clinical - auroc 0.5156901688182721
clinical - ATE -0.024012609621491376
full features - auroc 0.8109235352532274
full features - ATE 0.040060226552290336
random features - auroc 0.5857497517378352
random features - ATE 0.0009342930210170483
