In [1]:
import pandas as pd
import numpy as np

import shap
import os
import sys
import collections
import torch
import xgboost as xgb

from scipy import stats
from shapreg import shapley, games, removal, shapley_sampling
from sklearn.impute import SimpleImputer
from sklearn import preprocessing, model_selection
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier


from captum.attr import (
    DeepLift,
    FeatureAblation,
    FeaturePermutation,
    IntegratedGradients,
    KernelShap,
    Lime,
    ShapleyValueSampling,
    GradientShap,
)


module_path = os.path.abspath(os.path.join('../CATENets/'))
if module_path not in sys.path:
    sys.path.append(module_path)

import catenets.models.torch.pseudo_outcome_nets as pseudo_outcome_nets
import catenets.models as cate_models
import catenets.models.torch.tlearner as tlearner

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def plot_feature_values(feature_values):

    
    ind = np.argpartition(np.abs(feature_values).mean(0).round(2), -15)[-15:]
    
    feature_names = [
        a + ": " + str(b) for a,b in zip(names[ind], np.abs(feature_values[:, ind]).mean(0).round(2))
    ]

    shap.summary_plot(
        feature_values[:, ind],
        X_test[:, ind], 
        feature_names=feature_names,
        title = "IG"
     )
    
def plot_feature_values_ind(feature_values, indices):
    
    selected_sample = feature_values[indices]
    filtered_test = X_test[indices]
    
    ind = np.argpartition(np.abs(selected_sample).mean(0).round(2), -15)[-15:]
    
    feature_names = [
        a + ": " + str(b) for a,b in zip(names[ind], np.abs(selected_sample[:, ind]).mean(0).round(2))
    ]

    shap.summary_plot(
        selected_sample[:, ind],
        filtered_test[:, ind], 
        feature_names=feature_names,
        title = "IG"
     )

def normalize_data(X_train):
    
    X_normalized_train = (X_train - np.min(X_train, axis=0)) / (np.max(X_train, axis=0) - np.min(X_train, axis=0))

    return X_normalized_train


In [3]:
ist3 = pd.read_sas("../data/datashare_aug2015.sas7bdat")

continuous_vars = [
                    "gender",
                    "age",
                    "weight",
                    "glucose",
                    "gcs_eye_rand",
                    "gcs_motor_rand",
                    "gcs_verbal_rand",
                    # "gcs_score_rand",   
                     "nihss" ,
                     "sbprand",
                     "dbprand",
                  ]

cate_variables = [
                     # "livealone_rand",
                     # "indepinadl_rand",
                     "infarct",
                     "antiplat_rand",
                     # "atrialfib_rand",
                    #  "liftarms_rand",
                    # "ablewalk_rand",
                    # "weakface_rand",
                    # "weakarm_rand",
                    # "weakleg_rand",
                    # "dysphasia_rand",
                    # "hemianopia_rand",
                    # "visuospat_rand",
                    # "brainstemsigns_rand",
                    # "otherdeficit_rand",
                    "stroketype"
                 ]

outcomes = ["dead7","dead6mo","aliveind6"]
treatment = ["itt_treat"]

In [4]:
x = ist3[continuous_vars + cate_variables + treatment]

x = pd.get_dummies(x, columns=cate_variables)

n, feature_size = x.shape


names = x.drop(["itt_treat"], axis=1).columns
treatment_index = x.columns.get_loc("itt_treat")
var_index = [i for i in range(feature_size) if i != treatment_index]

x_norm = normalize_data(x)

## impute missing value

imp = SimpleImputer(missing_values=np.nan, strategy='mean')
imp.fit(x_norm)
x_train_scaled = imp.transform(x_norm)

X_train, X_test, y_train, y_test = model_selection.train_test_split(
                                             x_train_scaled,  
                                             ist3["aliveind6"], 
                                             test_size=0.2, 
                                             random_state=10,
                                    )


w_train = X_train[:, treatment_index] == 0
w_test =  X_test[:, treatment_index] == 0

X_train = X_train[:,var_index]
X_test = X_test[:, var_index]

y_train = y_train == 0
y_test = y_test == 0

In [5]:
learners = {
    "XLearner":pseudo_outcome_nets.XLearner(  
                X_train.shape[1],
                binary_y=(len(np.unique(y_train)) == 2),
                n_layers_out=2,
                n_units_out=100,
                batch_size=128,
                n_iter=1000,
                nonlin="relu",
                device="cuda:1"
            ),
    "DRLearner": pseudo_outcome_nets.DRLearner(
               X_train.shape[1],
               binary_y=(len(np.unique(y_train)) == 2),
                n_layers_out=2,
                n_units_out=100,
                batch_size=128,
                n_iter=1000,
                nonlin="relu",
                device="cuda:1"
           ),
    "SLearner": cate_models.torch.SLearner(
               X_train.shape[1],
               binary_y=(len(np.unique(y_train)) == 2),
                n_layers_out=2,
                n_units_out=100,
                batch_size=128,
                n_iter=1000,
                nonlin="relu",
                device="cuda:1"
           ),
    "TLearner": cate_models.torch.TLearner(
               X_train.shape[1],
               binary_y=(len(np.unique(y_train)) == 2),
                n_layers_out=2,
                n_units_out=100,
                batch_size=128,
                n_iter=1000,
                nonlin="relu",
                device="cuda:1"
           ),
    
}

trials = 5

if_results = np.zeros(( trials, len(learners)))
seeds = np.arange(0, 5, 1, dtype=int)


for i, seed in enumerate(seeds):
    np.random.seed(seed)
    
    # train plugin-estimator
    
    xgb_plugin1 = xgb.XGBClassifier(max_depth=6, random_state=seed, n_estimators=100)
    xgb_plugin0 = xgb.XGBClassifier(max_depth=6, random_state=seed, n_estimators=100)   

    rf = RandomForestClassifier(max_depth=6, random_state=seed)

    x0 = X_train[w_train==0]
    x1 = X_train[w_train==1]

    y0 = y_train[w_train==0]
    y1 = y_train[w_train==1]


    xgb_plugin0.fit(x0, y0)
    xgb_plugin1.fit(x1, y1)

    rf.fit(X_train, w_train)

    y_pred0 = xgb_plugin0.predict(X_test)
    y_pred1 = xgb_plugin1.predict(X_test)

    t_plugin = y_pred1 - y_pred0

    ps = rf.predict_proba(X_test)[:, 1]
    a = (w_test - ps)

    ident = np.array([1]*len(ps))
    c = (ps*(ident-ps))

    b = np.array([2]*len(w_test))*w_test*(w_test-ps) / c

    for learner_idx, (learner_name,_ ) in enumerate(learners.items()):
        
        model = learners[learner_name]
        model.fit(X_train, y_train, w_train)

        cate = model.predict(X_test).detach().cpu().numpy()
        cate = cate.flatten()

        plug_in = (t_plugin - cate)**2
        l_de = (ident - b) * t_plugin**2 + b*y_test*(t_plugin - cate) + (- a*(t_plugin - cate)**2 + cate**2)
        
        if_results[i, learner_idx] = np.sum(l_de) + np.sum(plug_in)

[po_estimator_0_impute_pos] Epoch: 0, current validation loss: 0.2958980202674866, train_loss: 0.30754077434539795
[po_estimator_0_impute_pos] Epoch: 50, current validation loss: 0.016850119456648827, train_loss: 0.008345781825482845
[po_estimator_0_impute_pos] Epoch: 100, current validation loss: 0.011545160785317421, train_loss: 0.0031158681958913803
[po_estimator_0_impute_pos] Epoch: 150, current validation loss: 0.009853529743850231, train_loss: 0.002956594806164503
[po_estimator_0_impute_pos] Epoch: 200, current validation loss: 0.009441617876291275, train_loss: 0.0029753106646239758
[po_estimator_1_impute_pos] Epoch: 0, current validation loss: 0.2110263556241989, train_loss: 0.22708368301391602
[po_estimator_1_impute_pos] Epoch: 50, current validation loss: 0.04732479155063629, train_loss: 0.02508283033967018
[po_estimator_1_impute_pos] Epoch: 100, current validation loss: 0.03256930038332939, train_loss: 0.010434633120894432
[po_estimator_1_impute_pos] Epoch: 150, current valid

In [8]:
np.mean(if_results, axis=0)

array([ 8.41133901,  8.43228651,  2.38054062, 11.47251618])