In [1]:
import pandas as pd
import numpy as np

import shap
import os
import sys
import collections
import torch

from scipy import stats
from shapreg import shapley, games, removal, shapley_sampling
from sklearn.impute import SimpleImputer
from sklearn import preprocessing, model_selection

from captum.attr import (
    DeepLift,
    FeatureAblation,
    FeaturePermutation,
    IntegratedGradients,
    KernelShap,
    Lime,
    ShapleyValueSampling,
    GradientShap,
)


# Make the local CATENets checkout importable. The relative path is resolved
# against the current working directory, so the notebook has to be started
# from the folder that contains CATENets/. Appending only when absent keeps
# sys.path from growing when this cell is re-run.
module_path = os.path.abspath("CATENets/")
if module_path not in sys.path:
    sys.path.append(module_path)

import catenets.models.torch.pseudo_outcome_nets as pseudo_outcome_nets

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def plot_feature_values(feature_values):

    
    ind = np.argpartition(np.abs(feature_values).mean(0).round(2), -15)[-15:]
    
    feature_names = [
        a + ": " + str(b) for a,b in zip(names[ind], np.abs(feature_values[:, ind]).mean(0).round(2))
    ]

    shap.summary_plot(
        feature_values[:, ind],
        X_test[:, ind], 
        feature_names=feature_names,
        title = "IG"
     )
    
def plot_feature_values_ind(feature_values, indices):
    
    selected_sample = feature_values[indices]
    filtered_test = X_test[indices]
    
    ind = np.argpartition(np.abs(selected_sample).mean(0).round(2), -15)[-15:]
    
    feature_names = [
        a + ": " + str(b) for a,b in zip(names[ind], np.abs(selected_sample[:, ind]).mean(0).round(2))
    ]

    shap.summary_plot(
        selected_sample[:, ind],
        filtered_test[:, ind], 
        feature_names=feature_names,
        title = "IG"
     )

def normalize_data(X_train):
    """Min-max scale each column of ``X_train`` to the range [0, 1].

    Accepts a 2-D NumPy array (or array-like) or a pandas DataFrame; a DataFrame
    is returned as a DataFrame with the same index and columns. Values are cast
    to float first so boolean one-hot columns (pandas >= 2 ``get_dummies``) can
    be subtracted.

    NaNs are ignored when computing each column's min/max and stay NaN in the
    output, so imputation can run afterwards. Constant columns (zero range) map
    to 0 rather than 0/0 = NaN; an all-NaN column would otherwise be dropped by
    a later mean imputer and shift every following column position.

    The min/max are taken from ``X_train`` itself. To scale a held-out set with
    training statistics, use a fitted scaler instead.
    """
    if hasattr(X_train, "astype"):
        X = X_train.astype(float)
    else:
        X = np.asarray(X_train, dtype=float)

    col_min = np.nanmin(X, axis=0)
    col_range = np.nanmax(X, axis=0) - col_min
    # Avoid dividing by zero for constant columns: (x - min) is 0 there anyway.
    col_range = np.where(col_range == 0, 1.0, col_range)

    return (X - col_min) / col_range

In [47]:
# Exploratory range check on glucose. This cell sits (and was executed) before
# the cell that loads `ist3`, so on Restart & Run All it would raise NameError;
# evaluate it only once the data frame exists (None is not displayed).
ist3["glucose"].max() if "ist3" in globals() else None

20.0

In [74]:
# Raw IST-3 data extract (SAS7BDAT file, path relative to the working
# directory); every later cell derives its data from `ist3`.
ist3_path = "data/datashare_aug2015.sas7bdat"
ist3 = pd.read_sas(ist3_path)

# Covariates used as single numeric columns (scaled later, not one-hot
# encoded). Note that `gender` and the three GCS sub-scores are in this list,
# so they enter the model as numeric values.
# Excluded candidate kept for reference: "gcs_score_rand".
continuous_vars = [
    "gender",
    "age",
    "weight",
    "glucose",
    "gcs_eye_rand",
    "gcs_motor_rand",
    "gcs_verbal_rand",
    "nihss",
    "sbprand",
    "dbprand",
]

# Covariates that are one-hot encoded with pd.get_dummies downstream.
# Excluded candidates kept for reference: livealone_rand, indepinadl_rand,
# atrialfib_rand, liftarms_rand, ablewalk_rand, weakface_rand, weakarm_rand,
# weakleg_rand, dysphasia_rand, hemianopia_rand, visuospat_rand,
# brainstemsigns_rand, otherdeficit_rand.
cate_variables = [
    "infarct",
    "antiplat_rand",
    "stroketype",
]

# Candidate outcome columns and the treatment-assignment column.
outcomes = ["dead7", "dead6mo"]
treatment = ["itt_treat"]

In [75]:
# --- Design matrix ------------------------------------------------------------
# Covariates + treatment column, categorical covariates one-hot encoded.
# Cast to float so one-hot columns (bool in pandas >= 2) are numeric.
x = pd.get_dummies(
    ist3[continuous_vars + cate_variables + treatment], columns=cate_variables
).astype(float)

# Patients without a recorded 6-month outcome cannot be labelled: the `== 0`
# comparison further down would silently turn a NaN outcome into y=False.
has_outcome = ist3["dead6mo"].notna()
x = x.loc[has_outcome]
dead6mo = ist3.loc[has_outcome, "dead6mo"]

n, feature_size = x.shape

names = x.drop(["itt_treat"], axis=1).columns
treatment_index = x.columns.get_loc("itt_treat")
var_index = [i for i in range(feature_size) if i != treatment_index]

# --- Train/test split -----------------------------------------------------------
# Split BEFORE fitting any preprocessing, so scaling and imputation statistics
# come from the training rows only (fitting them on all rows leaks test-set
# information into training).
x_train_raw, x_test_raw, y_train, y_test = model_selection.train_test_split(
    x,
    dead6mo,
    test_size=0.2,
    random_state=10,
)

# --- Min-max scaling + mean imputation (fitted on training rows only) -----------
# MinMaxScaler ignores NaNs when fitting and keeps them when transforming, and
# maps constant columns to 0 instead of dividing by zero.
scaler = preprocessing.MinMaxScaler().fit(x_train_raw)

imp = SimpleImputer(missing_values=np.nan, strategy='mean')
imp.fit(scaler.transform(x_train_raw))

X_train = imp.transform(scaler.transform(x_train_raw))
X_test = imp.transform(scaler.transform(x_test_raw))

# SimpleImputer drops all-NaN columns, which would desynchronise
# treatment_index / var_index / names from the matrix columns.
assert X_train.shape[1] == feature_size, "imputer dropped an all-missing column"
assert X_test.shape[1] == feature_size, "imputer dropped an all-missing column"

# --- Treatment indicator, covariates, outcome ------------------------------------
# After min-max scaling, 0 corresponds to the smallest itt_treat value among the
# training rows. NOTE(review): confirm which trial arm that value encodes.
w_train = X_train[:, treatment_index] == 0
w_test = X_test[:, treatment_index] == 0

X_train = X_train[:, var_index]
X_test = X_test[:, var_index]

# y is True where dead6mo == 0.
y_train = y_train == 0
y_test = y_test == 0

In [76]:
# The notebook was developed on GPU index 1. Fall back to the first GPU, or to
# the CPU, when that device does not exist instead of failing on such machines.
if torch.cuda.device_count() > 1:
    DEVICE = "cuda:1"
elif torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"

# CATENets X-learner. y_train is boolean after preprocessing, so binary_y is
# True whenever both outcome values occur in the training rows.
model = pseudo_outcome_nets.XLearner(
    X_train.shape[1],
    binary_y=(len(np.unique(y_train)) == 2),
    n_layers_out=2,
    n_units_out=100,
    batch_size=128,
    n_iter=1000,
    nonlin="relu",
    device=DEVICE,
)

model.fit(X_train, y_train, w_train)

[po_estimator_0_impute_pos] Epoch: 0, current validation loss: 0.777880847454071, train_loss: 0.793130099773407
[po_estimator_0_impute_pos] Epoch: 50, current validation loss: 0.5362330079078674, train_loss: 0.5426479578018188
[po_estimator_0_impute_pos] Epoch: 100, current validation loss: 0.4684380888938904, train_loss: 0.4624243676662445
[po_estimator_0_impute_pos] Epoch: 150, current validation loss: 0.44742637872695923, train_loss: 0.41952669620513916
[po_estimator_0_impute_pos] Epoch: 200, current validation loss: 0.4395703673362732, train_loss: 0.38949960470199585
[po_estimator_0_impute_pos] Epoch: 250, current validation loss: 0.4383486211299896, train_loss: 0.3938311040401459
[po_estimator_0_impute_pos] Epoch: 300, current validation loss: 0.4376206398010254, train_loss: 0.3900189995765686
[po_estimator_1_impute_pos] Epoch: 0, current validation loss: 0.7410100698471069, train_loss: 0.7296158671379089
[po_estimator_1_impute_pos] Epoch: 50, current validation loss: 0.5362434387

XLearner(
  (_te_estimator): BasicNet(
    (model): Sequential(
      (0): Linear(in_features=20, out_features=100, bias=True)
      (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Linear(in_features=100, out_features=100, bias=True)
      (4): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): Linear(in_features=100, out_features=1, bias=True)
    )
  )
  (_po_estimator): BasicNet(
    (model): Sequential(
      (0): Linear(in_features=20, out_features=100, bias=True)
      (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Linear(in_features=100, out_features=100, bias=True)
      (4): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): Linear(in_features=100, out_features=1, bias=True)
      (7): Sigmoid()
    )
  )
  (_propensity_estimato

In [None]:
learner_explanations = {}

# Model-agnostic Shapley values via shapreg. MarginalExtension is given X_test
# as its background data, and every test row gets its own ShapleySampling run
# (thresh=0.01, batch_size=128).
n_rows, n_cols = X_test.shape
shapley_matrix = np.zeros((n_rows, n_cols))
learner_explanations["shapley_sampling"] = shapley_matrix
marginal_extension = removal.MarginalExtension(X_test, model)

for row_idx, row in enumerate(X_test):
    row_game = games.PredictionGame(marginal_extension, row)
    row_explanation = shapley_sampling.ShapleySampling(row_game, thresh=0.01, batch_size=128)
    # shapley_matrix is the same object stored in the dict, so this fills it in place.
    shapley_matrix[row_idx] = row_explanation.values.reshape(-1, n_cols)

plot_feature_values(learner_explanations["shapley_sampling"])

100%|█████████████████████████████████████████████| 1/1 [00:18<00:00, 18.41s/it]
100%|█████████████████████████████████████████████| 1/1 [00:03<00:00,  3.75s/it]
100%|█████████████████████████████████████████████| 1/1 [00:08<00:00,  8.19s/it]
100%|█████████████████████████████████████████████| 1/1 [00:05<00:00,  5.80s/it]
100%|█████████████████████████████████████████████| 1/1 [00:07<00:00,  7.51s/it]
100%|█████████████████████████████████████████████| 1/1 [00:10<00:00, 10.60s/it]
100%|█████████████████████████████████████████████| 1/1 [00:06<00:00,  6.15s/it]
100%|█████████████████████████████████████████████| 1/1 [00:07<00:00,  7.86s/it]
100%|█████████████████████████████████████████████| 1/1 [00:11<00:00, 11.28s/it]
100%|█████████████████████████████████████████████| 1/1 [00:13<00:00, 13.00s/it]
100%|█████████████████████████████████████████████| 1/1 [00:08<00:00,  8.22s/it]
100%|█████████████████████████████████████████████| 1/1 [00:05<00:00,  5.14s/it]
100%|███████████████████████

In [None]:
# Shapley value sampling (captum). This is perturbation-based, so the inputs do
# not need requires_grad.
shapley_value_sampling_model = ShapleyValueSampling(model)

# Send the inputs to whatever device the model's weights live on instead of
# repeating a hard-coded device string.
model_device = next(model.parameters()).device

learner_explanations["shapley_sampling_0"] = shapley_value_sampling_model.attribute(
    torch.from_numpy(X_test).to(model_device),
    n_samples=500,
    perturbations_per_eval=10,
).detach().cpu().numpy()

plot_feature_values(learner_explanations["shapley_sampling_0"])

In [None]:
# Integrated gradients (captum). This is gradient-based; with no baseline given,
# captum integrates from an all-zeros input.
ig = IntegratedGradients(model)

# Send the inputs to whatever device the model's weights live on instead of
# repeating a hard-coded device string.
model_device = next(model.parameters()).device

learner_explanations["ig"] = ig.attribute(
    torch.from_numpy(X_test).to(model_device).requires_grad_(),
    n_steps=500,
).detach().cpu().numpy()

plot_feature_values(learner_explanations["ig"])