# Test of counterfactuals for thrombolysis classification

## Code setup

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
import pickle
from xgboost import XGBClassifier
import xgboost

import dice_ml
from dice_ml.utils import helpers # helper functions

from dataclasses import dataclass

IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html


In [2]:
import dice_ml.explainer_interfaces.dice_xgboost

Set up paths and file names:

In [3]:
@dataclass(frozen=True)
class Paths:
    '''Singleton object for storing paths to data and database.'''

    data_read_path: str = 'data'
    # output_folder = 'output'
    model_folder = 'data'
    patient_data_path: str = '~/stroke-modelling/stroke-utilities/stroke_utilities/data/'
    data_test_filename: str = 'cohort_10000_test.csv'
    data_train_filename: str = 'cohort_10000_train.csv'

paths = Paths()

## Load model

Existing classification model for predictions

In [4]:
model = XGBClassifier({'nthread': 4})  # init model
filename = f'{paths.model_folder}/model_resave.json'
model.load_model(filename)

Pass `objective` as keyword args.


And for use with counterfactuals:

In [46]:
m = dice_ml.Model(
    model=model,
    backend={
        "model": "xgboost_model.XGBoostModel",
        "explainer": dice_ml.explainer_interfaces.dice_xgboost.DiceXGBoost
    },
    # func="ohe-min-max",
)

{'model': 'xgboost_model.XGBoostModel', 'explainer': <class 'dice_ml.explainer_interfaces.dice_xgboost.DiceXGBoost'>} backend not in supported backends sklearn,TF1,TF2,PYT


## Load data

Look at real cleaned data to get the column names and ranges or sets of available values:

Convert bool columns containing missing values to float to prevent errors later.

In [6]:
filename = paths.patient_data_path + paths.data_test_filename
test = pd.read_csv(filename)
test['data_source'] = 'test'

filename = paths.patient_data_path + paths.data_train_filename
train = pd.read_csv(filename)
train['data_source'] = 'train'

data = pd.concat((train, test), axis='rows', ignore_index=True)

data = data.drop('year', axis=1)
data = data.drop('data_source', axis=1)

In [7]:
data.columns

Index(['stroke_team_id', 'stroke_severity', 'prior_disability', 'age',
       'infarction', 'onset_to_arrival_time', 'precise_onset_known',
       'onset_during_sleep', 'arrival_to_scan_time', 'afib_anticoagulant',
       'thrombolysis'],
      dtype='object')

In [8]:
stroke_teams_list = sorted(data['stroke_team_id'].unique())

One-hot encode the stroke teams:

In [9]:
def one_hot_encode_column(X, col, prefix='team'):
    """
    Create a copy of X data with one column made one-hot-encoded.
    """
    # Keep copy of original, with 'Stroke team' not one-hot encoded
    # X_combined = X.copy(deep=True)
    
    # One-hot encode 'Stroke team'
    X_hosp = pd.get_dummies(X[col], prefix=prefix)
    X = pd.concat([X, X_hosp], axis=1)
    X.drop(col, axis=1, inplace=True)
    return X


data = one_hot_encode_column(data, 'stroke_team_id', prefix='team')
cols_teams = [c for c in data.columns if c.startswith('team')]
data[cols_teams] = data[cols_teams].astype(int)

In [10]:
data.columns[-5:]

Index(['team_115', 'team_116', 'team_117', 'team_118', 'team_119'], dtype='object')

## Define patients for generating counterfactuals

In [11]:
df_proto = pd.read_csv(f'{paths.data_read_path}/prototype_patients.csv', index_col=0)

In [12]:
df_proto.index

Index(['Ideal', 'Late arrival', 'Mild', 'Prior disability', 'Imprecise', 'Age',
       'Mild + Prior disability', 'Mild + Imprecise', 'Mild + Age',
       'Mild + Late', 'Imprecise + Prior disability', 'Imprecise + Age',
       'Imprecise + Late', 'Prior disability + Age', 'Prior disability + Late',
       'Mild + Prior disability + Imprecise',
       'Prior disability + Imprecise + Late'],
      dtype='object', name='Patient prototype')

Order columns to match model:

In [13]:
df_proto = df_proto[[
    'stroke_severity', 'prior_disability', 'age', 'infarction',
    'onset_to_arrival_time', 'precise_onset_known', 'onset_during_sleep',
    'arrival_to_scan_time', 'afib_anticoagulant', 'stroke_team'
]]

In [14]:
df_teams_empty = pd.DataFrame(np.full((len(df_proto), len(cols_teams)), 0),
                              columns=cols_teams, index=df_proto.index).astype(int)
df_proto = pd.concat((df_proto, df_teams_empty), axis='columns')
df_proto = df_proto.drop('stroke_team', axis='columns')

In [15]:
df_proto.head(3)

Unnamed: 0_level_0,stroke_severity,prior_disability,age,infarction,onset_to_arrival_time,precise_onset_known,onset_during_sleep,arrival_to_scan_time,afib_anticoagulant,team_1,...,team_110,team_111,team_112,team_113,team_114,team_115,team_116,team_117,team_118,team_119
Patient prototype,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Ideal,15,0,72.5,1,90,1,0,15,0,0,...,0,0,0,0,0,0,0,0,0,0
Late arrival,15,0,72.5,1,225,1,0,15,0,0,...,0,0,0,0,0,0,0,0,0,0
Mild,3,0,72.5,1,90,1,0,15,0,0,...,0,0,0,0,0,0,0,0,0,0


In [16]:
df_proto.tail(3)

Unnamed: 0_level_0,stroke_severity,prior_disability,age,infarction,onset_to_arrival_time,precise_onset_known,onset_during_sleep,arrival_to_scan_time,afib_anticoagulant,team_1,...,team_110,team_111,team_112,team_113,team_114,team_115,team_116,team_117,team_118,team_119
Patient prototype,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Prior disability + Late,15,3,72.5,1,225,1,0,15,0,0,...,0,0,0,0,0,0,0,0,0,0
Mild + Prior disability + Imprecise,3,3,72.5,1,90,0,0,15,0,0,...,0,0,0,0,0,0,0,0,0,0
Prior disability + Imprecise + Late,15,3,72.5,1,225,0,0,15,0,0,...,0,0,0,0,0,0,0,0,0,0


In [17]:
col_proto_source = 'Prior disability + Imprecise + Late'
col_this_patient = 'this_patient'

df_this_patient = df_proto.loc[[col_proto_source]].copy()
# Rename:
df_this_patient = df_this_patient.rename(index={col_proto_source: col_this_patient})
# Update values using loc so that column dtype doesn't change.
# Assign a team:
df_this_patient.loc[col_this_patient, 'team_42'] = 1
# Update other columns to reduce propensity for thrombolysis:
df_this_patient.loc[col_this_patient, 'stroke_severity'] = 3
# df_this_patient.loc[col_this_patient, 'afib_anticoagulant'] = pd.NA  # DiCE can't handle NaN



# Would this patient receive thrombolysis?
predicted_proba = model.predict_proba(df_this_patient)[:,1]
predict = model.predict(df_this_patient)

df_this_patient.loc[col_this_patient, 'thrombolysis_prob'] = predicted_proba
df_this_patient.loc[col_this_patient, 'thrombolysis'] = predict

In [18]:
df_this_patient.T

Patient prototype,this_patient
stroke_severity,3.00000
prior_disability,3.00000
age,72.50000
infarction,1.00000
onset_to_arrival_time,225.00000
...,...
team_117,0.00000
team_118,0.00000
team_119,0.00000
thrombolysis_prob,0.16238


In [19]:
df_this_patient_not_ohe = df_this_patient.copy()

team_here = [c for c in cols_teams if (df_this_patient_not_ohe[c].sum() > 0)]
df_this_patient_not_ohe['team'] = team_here[0].split('team_')[-1]
df_this_patient_not_ohe = df_this_patient_not_ohe.drop(cols_teams, axis='columns')

In [20]:
df_this_patient_not_ohe

Unnamed: 0_level_0,stroke_severity,prior_disability,age,infarction,onset_to_arrival_time,precise_onset_known,onset_during_sleep,arrival_to_scan_time,afib_anticoagulant,thrombolysis_prob,thrombolysis,team
Patient prototype,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
this_patient,3,3,72.5,1,225,0,0,15,0,0.16238,0.0,42


## Data setup for dice model

In [21]:
feature_cols = [c for c in data.columns if c != 'thrombolysis']
cols_discrete = ['infarction', 'prior_disability', 'stroke_severity', 'precise_onset_known', 'onset_during_sleep', 'afib_anticoagulant', 'age']
# cols_categorical += [c for c in feature_cols if c.startswith('team')]
cols_continuous = [c for c in feature_cols if ((c not in cols_discrete) & ('team' not in c))]

cols_continuous

['onset_to_arrival_time', 'arrival_to_scan_time']

In [22]:
features_dict = {}

for c in cols_continuous + cols_discrete + cols_teams:
    if c in cols_discrete:
        if any(data[c].isna()):
            # vals = [pd.NA] + 
            # Flag problem but don't use NaN as DiCE can't handle it:
            print(f'Missing data in {c} cannot be accessed')
            vals = sorted(list(set(data.loc[data[c].notna(), c])))
        else:
            vals = sorted(list(set(data[c])))
    else:
        vals = [data[c].min(), data[c].max()]
    features_dict[c] = vals

# features_dict['team'] = [str(s) for s in stroke_teams_list]


Missing data in afib_anticoagulant cannot be accessed


In [23]:
features_dict

{'onset_to_arrival_time': [1.0, 239.0],
 'arrival_to_scan_time': [1.0, 232.0],
 'infarction': [0.0, 1.0],
 'prior_disability': [0, 1, 2, 3, 4, 5],
 'stroke_severity': [0,
  1,
  2,
  3,
  4,
  5,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  13,
  14,
  15,
  16,
  17,
  18,
  19,
  20,
  21,
  22,
  23,
  24,
  25,
  26,
  27,
  28,
  29,
  30,
  31,
  32,
  33,
  34,
  35,
  36,
  37,
  38,
  39,
  40,
  41,
  42],
 'precise_onset_known': [0, 1],
 'onset_during_sleep': [0, 1],
 'afib_anticoagulant': [0.0, 1.0],
 'age': [37.5,
  42.5,
  47.5,
  52.5,
  57.5,
  62.5,
  67.5,
  72.5,
  77.5,
  82.5,
  87.5,
  92.5],
 'team_1': [0, 1],
 'team_2': [0, 1],
 'team_3': [0, 1],
 'team_4': [0, 1],
 'team_5': [0, 1],
 'team_6': [0, 1],
 'team_7': [0, 1],
 'team_8': [0, 1],
 'team_9': [0, 1],
 'team_10': [0, 1],
 'team_11': [0, 1],
 'team_12': [0, 1],
 'team_13': [0, 1],
 'team_14': [0, 1],
 'team_15': [0, 1],
 'team_16': [0, 1],
 'team_17': [0, 1],
 'team_18': [0, 1],
 'team_19': [0, 1],
 'team_20': 

Define custom weights:

In [24]:
cols_patient = ['stroke_severity', 'prior_disability', 'age', 'infarction', 'afib_anticoagulant']
cols_pathway = ['onset_to_arrival_time', 'precise_onset_known', 'onset_during_sleep', 'arrival_to_scan_time']
# cols_teams = [c for c in feature_cols if c.startswith('team')]

In [25]:
feature_weights = {}

for col in feature_cols:
    feature_weights[col] = 1.0 if col in cols_pathway else 1000.0

In [26]:
features_to_vary = cols_pathway

## Make data object

In [28]:
d = dice_ml.Data(
    features=features_dict,
    outcome_name='thrombolysis',
)

In [49]:
d.continuous_feature_names[:12]

['onset_to_arrival_time',
 'arrival_to_scan_time',
 'infarction',
 'prior_disability',
 'stroke_severity',
 'precise_onset_known',
 'onset_during_sleep',
 'afib_anticoagulant',
 'age',
 'team_1',
 'team_2',
 'team_3']

In [30]:
d.categorical_feature_names

['team']

Check default feature weights - requires access to full data, not just value ranges / sets.

Higher weights mean the feature is harder to change.

In [39]:
# get MAD
mads = d.get_mads(normalized=True)

# create feature weights
feature_weights = {}
for feature in mads:
    feature_weights[feature] = round(1/mads[feature], 2)

In [40]:
feature_weights

{}

Set teams to be one-hot-encoded.

## Make counterfactual model

In [47]:
# Using method=random for generating CFs
exp = dice_ml.Dice(d, m, method="random")

## Make counterfactual for this patient

In [42]:
df_this_patient.columns[:20]

Index(['stroke_severity', 'prior_disability', 'age', 'infarction',
       'onset_to_arrival_time', 'precise_onset_known', 'onset_during_sleep',
       'arrival_to_scan_time', 'afib_anticoagulant', 'team_1', 'team_2',
       'team_3', 'team_4', 'team_5', 'team_6', 'team_7', 'team_8', 'team_9',
       'team_10', 'team_11'],
      dtype='object')

In [43]:
data.columns[:20]

Index(['stroke_severity', 'prior_disability', 'age', 'infarction',
       'onset_to_arrival_time', 'precise_onset_known', 'onset_during_sleep',
       'arrival_to_scan_time', 'afib_anticoagulant', 'thrombolysis', 'team_1',
       'team_2', 'team_3', 'team_4', 'team_5', 'team_6', 'team_7', 'team_8',
       'team_9', 'team_10'],
      dtype='object')

In [None]:
e1 = exp.generate_counterfactuals(
    df_this_patient.drop(['thrombolysis', 'thrombolysis_prob'], axis='columns'),  # df_test,
    total_CFs=2, desired_class="opposite",
    # features_to_vary=features_to_vary,
    # proximity_weight=1.5, diversity_weight=1.0,
    verbose=True,
)

## Results

In [None]:
# Pick out results:
df_cf = e1.cf_examples_list[0].final_cfs_df
# Place in same dataframe as starting values:
df_results = pd.concat((df_this_patient, df_cf), axis='rows')

View columns that have changed values:

In [None]:
a = df_results.to_numpy()
cols_changed = df_results.columns[~(a[0] == a).all(axis=0)]

df_results[cols_changed]

In [None]:
df_this_patient[[c for c in df_this_patient.columns if (not c.startswith('team')) | (c.startswith('team') & (df_this_patient[c].sum() > 0))]]

## TO DO

Need to tell the model that it can't send patients to multiple stroke teams.