# Test of counterfactuals for thrombolysis classification

## Code setup

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
import pickle
from xgboost import XGBClassifier
import xgboost

import dice_ml
from dice_ml.utils import helpers # helper functions

from dataclasses import dataclass

IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html


In [2]:
import dice_ml.explainer_interfaces.dice_xgboost

Set up paths and file names:

In [3]:
@dataclass(frozen=True)
class Paths:
    '''Immutable container for the folders and file names used in this notebook.

    Every attribute is an annotated dataclass field, so each one can be
    overridden by keyword when constructing (e.g. ``Paths(model_folder=...)``)
    and each one appears in the repr and in equality comparisons.
    '''

    # Folder holding notebook-local input files (e.g. the prototype patients).
    data_read_path: str = 'data'
    # output_folder = 'output'
    # Folder holding the cleaned patient cohort CSV files.
    # NOTE(review): this is a user-specific location under the home directory.
    patient_data_path: str = '~/stroke-modelling/stroke-utilities/stroke_utilities/data/'
    # File names of the test and train cohorts inside patient_data_path.
    data_test_filename: str = 'cohort_10000_test.csv'
    data_train_filename: str = 'cohort_10000_train.csv'
    # Folder holding the saved classifier. Declared last so that the
    # positional order of the pre-existing fields is unchanged.
    model_folder: str = 'data'

paths = Paths()

## Load model

Existing classification model for predictions

In [4]:
# Create an empty classifier, then restore the trained model from disk.
# Settings must be passed by keyword: a dict passed positionally is bound to
# the classifier's first parameter (`objective`) instead of being unpacked,
# which is what triggered the "Pass `objective` as keyword args" warning.
model = XGBClassifier(n_jobs=4)  # init model with 4 threads
filename = f'{paths.model_folder}/model_resave.json'
model.load_model(filename)

Pass `objective` as keyword args.


And for use with counterfactuals:

In [5]:
# Wrap the trained classifier so that DiCE can query it. The backend is given
# as a mapping that names the model wrapper and supplies the explainer class.
m = dice_ml.Model(
    model=model,
    backend=dict(
        model="xgboost_model.XGBoostModel",
        explainer=dice_ml.explainer_interfaces.dice_xgboost.DiceXGBoost,
    ),
    func="ohe-min-max",
)

{'model': 'xgboost_model.XGBoostModel', 'explainer': <class 'dice_ml.explainer_interfaces.dice_xgboost.DiceXGBoost'>} backend not in supported backends sklearn,TF1,TF2,PYT


## Load data

Look at real cleaned data to get the column names and ranges or sets of available values:

Convert bool columns containing missing values to float to prevent errors later.

In [6]:
# Read the test cohort and tag every row with where it came from.
filename = paths.patient_data_path + paths.data_test_filename
test = pd.read_csv(filename).assign(data_source='test')

# Same again for the train cohort.
filename = paths.patient_data_path + paths.data_train_filename
train = pd.read_csv(filename).assign(data_source='train')

# Stack train on top of test with a fresh index, then discard the columns
# that are not wanted downstream.
data = (
    pd.concat([train, test], axis=0, ignore_index=True)
    .drop(columns=['year', 'data_source'])
)

In [7]:
data.columns

Index(['stroke_team_id', 'stroke_severity', 'prior_disability', 'age',
       'infarction', 'onset_to_arrival_time', 'precise_onset_known',
       'onset_during_sleep', 'arrival_to_scan_time', 'afib_anticoagulant',
       'thrombolysis'],
      dtype='object')

One-hot encode the stroke teams:

In [8]:
def one_hot_encode_column(X, col, prefix='team'):
    """
    Return a new frame in which column `col` is replaced by indicator columns.

    One indicator column is created per distinct value of `X[col]`, named
    `<prefix>_<value>`, and appended after the remaining columns of `X`.
    The input frame `X` itself is not modified.
    """
    indicators = pd.get_dummies(X[col], prefix=prefix)
    remaining = X.drop(columns=col)
    return pd.concat([remaining, indicators], axis=1)


# Replace 'stroke_team_id' with one indicator column per team ('team_<id>').
data = one_hot_encode_column(data, 'stroke_team_id', prefix='team')

In [9]:
# Names of all one-hot team columns, in frame order.
cols_teams = data.columns[data.columns.str.startswith('team')].tolist()
# Store the team indicators as integers.
data[cols_teams] = data[cols_teams].astype(int)

In [10]:
data.columns[-5:]

Index(['team_115', 'team_116', 'team_117', 'team_118', 'team_119'], dtype='object')

In [11]:
# Every column except the outcome is a model feature.
feature_cols = list(data.columns[data.columns != 'thrombolysis'])

# Features with a fixed set of allowed values, plus all team indicators.
cols_categorical = [
    'infarction',
    'prior_disability',
    'stroke_severity',
    'precise_onset_known',
    'onset_during_sleep',
    'afib_anticoagulant',
    *(name for name in feature_cols if name.startswith('team')),
]

# Whatever is left over is treated as continuous.
cols_continuous = list(
    filter(lambda name: name not in cols_categorical, feature_cols)
)

cols_continuous

['age', 'onset_to_arrival_time', 'arrival_to_scan_time']

In [12]:
# Describe each feature for DiCE: categorical features map to the sorted list
# of values seen in the data, continuous features map to [min, max].
features_dict = {}

for c in feature_cols:
    series = data[c]
    if c not in cols_categorical:
        vals = [series.min(), series.max()]
    else:
        if series.isna().any():
            # Flag problem but don't use NaN as DiCE can't handle it:
            print(f'Missing data in {c} cannot be accessed')
            series = series[series.notna()]
        vals = sorted(set(series))
    features_dict[c] = vals

Missing data in afib_anticoagulant cannot be accessed


In [13]:
features_dict

{'stroke_severity': [0,
  1,
  2,
  3,
  4,
  5,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  13,
  14,
  15,
  16,
  17,
  18,
  19,
  20,
  21,
  22,
  23,
  24,
  25,
  26,
  27,
  28,
  29,
  30,
  31,
  32,
  33,
  34,
  35,
  36,
  37,
  38,
  39,
  40,
  41,
  42],
 'prior_disability': [0, 1, 2, 3, 4, 5],
 'age': [37.5, 92.5],
 'infarction': [0.0, 1.0],
 'onset_to_arrival_time': [1.0, 239.0],
 'precise_onset_known': [0, 1],
 'onset_during_sleep': [0, 1],
 'arrival_to_scan_time': [1.0, 232.0],
 'afib_anticoagulant': [0.0, 1.0],
 'team_1': [0, 1],
 'team_2': [0, 1],
 'team_3': [0, 1],
 'team_4': [0, 1],
 'team_5': [0, 1],
 'team_6': [0, 1],
 'team_7': [0, 1],
 'team_8': [0, 1],
 'team_9': [0, 1],
 'team_10': [0, 1],
 'team_11': [0, 1],
 'team_12': [0, 1],
 'team_13': [0, 1],
 'team_14': [0, 1],
 'team_15': [0, 1],
 'team_16': [0, 1],
 'team_17': [0, 1],
 'team_18': [0, 1],
 'team_19': [0, 1],
 'team_20': [0, 1],
 'team_21': [0, 1],
 'team_22': [0, 1],
 'team_23': [0, 1],
 'team_24': [0

In [14]:
# Describe the data to DiCE using only per-feature metadata (allowed values
# or ranges) rather than a full training frame.
d = dice_ml.Data(
    features=features_dict,
    outcome_name='thrombolysis',
    continuous_features=cols_continuous,
)

## Generate counterfactuals

In [15]:
# Load the prototype patients; the first CSV column is used as the index.
df_proto = pd.read_csv(f'{paths.data_read_path}/prototype_patients.csv', index_col=0)

In [16]:
df_proto.index

Index(['Ideal', 'Late arrival', 'Mild', 'Prior disability', 'Imprecise', 'Age',
       'Mild + Prior disability', 'Mild + Imprecise', 'Mild + Age',
       'Mild + Late', 'Imprecise + Prior disability', 'Imprecise + Age',
       'Imprecise + Late', 'Prior disability + Age', 'Prior disability + Late',
       'Mild + Prior disability + Imprecise',
       'Prior disability + Imprecise + Late'],
      dtype='object', name='Patient prototype')

Order columns to match model:

In [17]:
# Select the feature columns in a fixed order; 'stroke_team' stays last.
df_proto = df_proto[[
    'stroke_severity',
    'prior_disability',
    'age',
    'infarction',
    'onset_to_arrival_time',
    'precise_onset_known',
    'onset_during_sleep',
    'arrival_to_scan_time',
    'afib_anticoagulant',
    'stroke_team',
]]

In [18]:
# One zero-filled indicator column per team, aligned on the prototype index.
df_teams_empty = pd.DataFrame(
    np.zeros((len(df_proto), len(cols_teams)), dtype=int),
    index=df_proto.index,
    columns=cols_teams,
)
# Append the indicators and remove the single 'stroke_team' column.
df_proto = pd.concat([df_proto, df_teams_empty], axis=1).drop(columns='stroke_team')

In [19]:
df_proto.head(3)

Unnamed: 0_level_0,stroke_severity,prior_disability,age,infarction,onset_to_arrival_time,precise_onset_known,onset_during_sleep,arrival_to_scan_time,afib_anticoagulant,team_1,...,team_110,team_111,team_112,team_113,team_114,team_115,team_116,team_117,team_118,team_119
Patient prototype,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Ideal,15,0,72.5,1,90,1,0,15,0,0,...,0,0,0,0,0,0,0,0,0,0
Late arrival,15,0,72.5,1,225,1,0,15,0,0,...,0,0,0,0,0,0,0,0,0,0
Mild,3,0,72.5,1,90,1,0,15,0,0,...,0,0,0,0,0,0,0,0,0,0


In [20]:
df_proto.tail(3)

Unnamed: 0_level_0,stroke_severity,prior_disability,age,infarction,onset_to_arrival_time,precise_onset_known,onset_during_sleep,arrival_to_scan_time,afib_anticoagulant,team_1,...,team_110,team_111,team_112,team_113,team_114,team_115,team_116,team_117,team_118,team_119
Patient prototype,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Prior disability + Late,15,3,72.5,1,225,1,0,15,0,0,...,0,0,0,0,0,0,0,0,0,0
Mild + Prior disability + Imprecise,3,3,72.5,1,90,0,0,15,0,0,...,0,0,0,0,0,0,0,0,0,0
Prior disability + Imprecise + Late,15,3,72.5,1,225,0,0,15,0,0,...,0,0,0,0,0,0,0,0,0,0


In [21]:
# Index label of the prototype to start from, and the label for the new row.
col_proto_source = 'Prior disability + Imprecise + Late'
col_this_patient = 'this_patient'

# Take an independent one-row copy of the prototype and relabel it.
df_this_patient = (
    df_proto.loc[[col_proto_source]]
    .copy()
    .rename(index={col_proto_source: col_this_patient})
)
# Scalar label-based assignment so that column dtype doesn't change.
# Assign a team:
df_this_patient.at[col_this_patient, 'team_42'] = 1
# Update other columns to reduce propensity for thrombolysis:
df_this_patient.at[col_this_patient, 'stroke_severity'] = 3
# afib_anticoagulant is left as-is: DiCE can't handle NaN.
df_this_patient

Unnamed: 0_level_0,stroke_severity,prior_disability,age,infarction,onset_to_arrival_time,precise_onset_known,onset_during_sleep,arrival_to_scan_time,afib_anticoagulant,team_1,...,team_110,team_111,team_112,team_113,team_114,team_115,team_116,team_117,team_118,team_119
Patient prototype,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
this_patient,3,3,72.5,1,225,0,0,15,0,0,...,0,0,0,0,0,0,0,0,0,0


In [22]:
df_this_patient.dtypes[:10]

stroke_severity            int64
prior_disability           int64
age                      float64
infarction                 int64
onset_to_arrival_time      int64
precise_onset_known        int64
onset_during_sleep         int64
arrival_to_scan_time       int64
afib_anticoagulant         int64
team_1                     int64
dtype: object

Would this patient receive thrombolysis?

In [23]:
# Keep only the second probability column (index 1) for this patient.
# presumably the probability of class 1 (thrombolysis given) — confirm class order
predicted_proba = model.predict_proba(df_this_patient)[:,1]

In [24]:
predicted_proba

array([0.16237952], dtype=float32)

In [25]:
# Predicted class label for the same patient.
predict = model.predict(df_this_patient)

In [26]:
predict

array([0])

In [27]:
# Using method=random for generating CFs
# Create the DiCE explainer from the data description (d) and the wrapped
# model (m), using method="random" to generate counterfactuals (CFs).
exp = dice_ml.Dice(d, m, method="random")

In [28]:
# Ask for two counterfactuals whose predicted class is the opposite of the
# query patient's.
# NOTE(review): no random seed is set here, so results may differ between
# runs — confirm whether this explainer accepts a seed argument.
e1 = exp.generate_counterfactuals(df_this_patient, total_CFs=2, desired_class="opposite")

100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.19it/s]


In [29]:
# Display the query instance and the counterfactuals, showing only the
# values that changed ('-' marks unchanged values in the output).
# NOTE(review): this call appears to only display its tables; df_e1 is
# likely None — confirm before using it downstream.
df_e1 = e1.visualize_as_dataframe(show_only_changes=True)

Query instance (original outcome : 0)


Unnamed: 0,stroke_severity,prior_disability,age,infarction,onset_to_arrival_time,precise_onset_known,onset_during_sleep,arrival_to_scan_time,afib_anticoagulant,team_1,...,team_111,team_112,team_113,team_114,team_115,team_116,team_117,team_118,team_119,thrombolysis
0,3,3,72.5,1,225,0,0,15,0,0,...,0,0,0,0,0,0,0,0,0,0



Diverse Counterfactual set without sparsity correction since only metadata about each  feature is available (new outcome: 1


Unnamed: 0,stroke_severity,prior_disability,age,infarction,onset_to_arrival_time,precise_onset_known,onset_during_sleep,arrival_to_scan_time,afib_anticoagulant,team_1,...,team_111,team_112,team_113,team_114,team_115,team_116,team_117,team_118,team_119,thrombolysis
0,-,1.0,-,-,-,-,-,-,-,-,...,-,-,-,-,-,-,-,-,-,1.0
1,-,-,-,-,-,-,-,-,-,-,...,-,-,-,-,-,-,-,-,-,1.0


In [30]:
e1.cf_examples_list[0].final_cfs_df

Unnamed: 0,stroke_severity,prior_disability,age,infarction,onset_to_arrival_time,precise_onset_known,onset_during_sleep,arrival_to_scan_time,afib_anticoagulant,team_1,...,team_111,team_112,team_113,team_114,team_115,team_116,team_117,team_118,team_119,thrombolysis
0,3,1,72.5,1,225,0,0,15,0,0,...,0,0,0,0,0,0,0,0,0,1
1,3,3,72.5,1,225,0,0,15,0,0,...,0,0,0,0,0,0,0,0,0,1


In [31]:
# Show the counterfactuals transposed (one row per feature), temporarily
# raising pandas' row display limit to 500 for this cell only.
with pd.option_context('display.max_rows', 500):
    display(e1.cf_examples_list[0].final_cfs_df.T)

Unnamed: 0,0,1
stroke_severity,3.0,3.0
prior_disability,1.0,3.0
age,72.5,72.5
infarction,1.0,1.0
onset_to_arrival_time,225.0,225.0
precise_onset_known,0.0,0.0
onset_during_sleep,0.0,0.0
arrival_to_scan_time,15.0,15.0
afib_anticoagulant,0.0,0.0
team_1,0.0,0.0


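Need to tell the counterfactual generator that a patient can't be sent to multiple stroke teams: the one-hot `team_*` columns are currently treated as independent features, so a counterfactual can switch on more than one of them at once.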