In [None]:
import os
import csv

import pandas as pd
import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.neighbors import NearestNeighbors, LocalOutlierFactor
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

import keras
import tensorflow as tf
import random as python_random
from keras import layers
from keras.preprocessing import sequence
from keras.preprocessing.text import Tokenizer
from tensorflow.keras.callbacks import EarlyStopping
from keras import backend as K

import pandas.io.sql as sqlio
import psycopg2
import getpass

import matplotlib.pyplot as plt

## 1. Load data

In [None]:
# Directory that holds the MIMIC event-sequence text files
data_path = '../mimic_data/'

# Read the positive and negative training files (neither has a header row)
train_pos = pd.read_csv(data_path + 'train_pos.txt', header=None)
train_neg = pd.read_csv(data_path + 'train_neg.txt', header=None)

# Tag every row with its target class: 1 for the positive file, 0 for the negative file
train_pos['survival'] = 1
train_neg['survival'] = 0

In [None]:
train_pos.shape

In [None]:
# Stack positives and negatives into one frame, then shuffle the rows with a fixed seed
train = pd.concat([train_pos, train_neg]).reset_index()
train_reordered = train.sample(frac=1, random_state=3)

# Column 0 is the single column read from the text files; 'survival' is the class label
X_train = train_reordered[0]
y_train = train_reordered['survival']

In [None]:
X_train.head()

In [None]:
# Read the positive and negative validation files (neither has a header row)
validation_pos = pd.read_csv(data_path + 'validation_pos.txt', header=None)
validation_neg = pd.read_csv(data_path + 'validation_neg.txt', header=None)

# Tag every row with its target class: 1 for the positive file, 0 for the negative file
validation_pos['survival'] = 1
validation_neg['survival'] = 0

In [None]:
# Stack positives and negatives into one frame, then shuffle the rows with a fixed seed
validation = pd.concat([validation_pos, validation_neg]).reset_index()
validation_reordered = validation.sample(frac=1, random_state=3)

# Column 0 is the single column read from the text files; 'survival' is the class label
X_val = validation_reordered[0]
y_val = validation_reordered['survival']

In [None]:
X_val.head()

### Data preprocessing

#### 1.1 Convert all the events into sequence (token) ids

In [None]:
# Set the vocab size and max sequence length
vocab_size = 1100 #(max vocab id=1024 in the training data)
max_seq_length = 74 #(the maximum sequence length in training/testing data)

In [None]:
# Use a text tokenizer to convert events into integer token ids.
# The vocabulary is fitted on the training sequences only; `num_words` is set from `vocab_size`.
tokenizer = Tokenizer(num_words = vocab_size)
tokenizer.fit_on_texts(X_train)

In [None]:
X_train_sequences = tokenizer.texts_to_sequences(X_train)
X_val_sequences = tokenizer.texts_to_sequences(X_val)

In [None]:
# Before texts_to_sequences()
print(f'Before texts_to_sequences():\n {X_train.iloc[0]}\n')

# After texts_to_sequences()
print(f'After texts_to_sequences():\n {X_train_sequences[0]}')

#### 1.2 Padding converted sequences

In [None]:
# Pad X_train_sequences and X_val_sequences
X_train_padded = sequence.pad_sequences(X_train_sequences, maxlen=max_seq_length, padding='post')
X_val_padded = sequence.pad_sequences(X_val_sequences, maxlen=max_seq_length, padding='post')

In [None]:
X_train_padded.shape

In [None]:
X_val_padded.shape

In [None]:
X_val_padded[0]

## 2. Train the main LSTM model for survival prediction

In [None]:
# Helper for plotting a training metric next to its validation counterpart
def plot_graphs(history, string):
    """Plot one metric and its ``val_`` twin over the training epochs.

    history: object whose ``history`` attribute maps metric names to per-epoch
        values (this notebook passes the return value of ``main_model.fit``).
    string: metric name such as "accuracy" or "loss"; ``'val_' + string`` must
        also be present in ``history.history``.
    """
    val_key = 'val_'+string

    # Training curve first, validation curve second, so the legend order matches
    for key in (string, val_key):
        plt.plot(history.history[key])

    plt.xlabel("Epochs")
    plt.ylabel(string)
    plt.legend([string, val_key])
    plt.show()

In [None]:
# Fix the random seeds to get consistent models
## ref: https://keras.io/getting_started/faq/#how-can-i-obtain-reproducible-results-using-keras-during-development
seed_value = 3

def reset_seeds(seed_value=3):
    """Re-seed every random number source this notebook relies on.

    Sets PYTHONHASHSEED in the process environment and seeds NumPy's global
    generator, Python's built-in `random` module and TensorFlow with `seed_value`.
    Call it right before any step whose result should be repeatable.
    """
    os.environ['PYTHONHASHSEED']=str(seed_value)
    # NumPy generated random numbers start from a well-defined initial state
    np.random.seed(seed_value)
    # Core Python generated random numbers start from a well-defined state
    python_random.seed(seed_value)
    # TensorFlow random number generation
    tf.random.set_seed(seed_value)

# Seed once up front (this replaces the four statements that used to be duplicated here)
reset_seeds(seed_value)

# configure a new global `tensorflow` session, with both thread pools limited to one thread
session_conf = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
sess = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(), config=session_conf)
tf.compat.v1.keras.backend.set_session(sess)

# Seed again after the session is configured; pass `seed_value` so that changing the
# constant above affects every call in this cell
reset_seeds(seed_value)

In [None]:
# Define the early stopping criteria: monitor validation accuracy with a patience of 3 epochs
# NOTE(review): `restore_best_weights` is not passed, so the model presumably keeps the weights
# of the last epoch that ran rather than those of the best epoch — confirm this is intended.
early_stopping = EarlyStopping(monitor='val_accuracy', patience=3)

In [None]:
# Define the model structure
# Input for variable-length sequences of integers
inputs = keras.Input(shape=(None,), dtype="int32")

# Embed each integer in a 128-dimensional vector
x = layers.Embedding(vocab_size, 128)(inputs)

# Add 2 bidirectional LSTMs
x = layers.Bidirectional(layers.LSTM(64, return_sequences=True))(x)
x = layers.Bidirectional(layers.LSTM(64))(x)

# Add a classifier
outputs = layers.Dense(1, activation="sigmoid")(x)
main_model = keras.Model(inputs, outputs)

main_model.summary()

In [None]:
main_model.compile("adam", "binary_crossentropy", metrics=["accuracy"])

reset_seeds()
model_history = main_model.fit(
    X_train_padded, 
    y_train, 
    epochs=30, 
    batch_size=64, 
    validation_data=(X_val_padded, y_val), 
    callbacks=[early_stopping])

In [None]:
# Visualize the training/validation accuracy and loss
plot_graphs(model_history, "accuracy")
plot_graphs(model_history, "loss")

In [None]:
# Threshold the predicted probabilities at 0.5: y_pred = 1 if pred > 0.5, otherwise 0.
# The comparison is done on the whole prediction array and then flattened to one label per row.
y_pred = (main_model.predict(X_val_padded) > 0.5).astype(int).ravel()

In [None]:
# Validation accuracy: share of predictions that match the true labels
validation_acc = (y_pred == y_val).sum()/len(y_val)
validation_acc

In [None]:
# Get the confusion matrix
confusion_matrix_df = pd.DataFrame(
        confusion_matrix(y_true=y_val, y_pred=y_pred, labels=[1, 0]),
        index=['True:pos', 'True:neg'], 
        columns=['Pred:pos', 'Pred:neg']
    )
confusion_matrix_df

In [None]:
# Counts of positive and negative predictions.
# Uses Series.value_counts because the top-level pd.value_counts is deprecated in recent pandas.
pd.Series(y_pred).value_counts()

## 3. Get the negative predictions from LSTM, for counterfactual explanations

In [None]:
# Get these instances of negative predictions
X_pred_negative = X_val_padded[y_pred == 0]

In [None]:
X_pred_negative.shape

#### Export as the desired input format of the DRG framework

In [None]:
# Convert negatively predicted instances back to medical event form
original_event_sequences = tokenizer.sequences_to_texts(X_pred_negative)

In [None]:
original_event_sequences[:5]

In [None]:
# Write the negatively predicted sequences, one per row, to test_neg.txt with no index,
# no header and no quoting.
# NOTE(review): `sep` and `escapechar` are both a single space, so the spaces inside each
# sequence are presumably written escaped (i.e. doubled) — confirm the DRG reader accepts this.
# NOTE(review): the output directory is spelled out here instead of reusing `data_path`.
pd.DataFrame(original_event_sequences).to_csv(path_or_buf='../mimic_data/test_neg.txt', index=False, header=False, sep=' ', quoting = csv.QUOTE_NONE, escapechar = ' ')

Here, we need to use the inference script from the DRG framework (instructions in the README file) to modify those 110 negative predictions into positive instances. After that, we import the transformed results as below.

### 3.1 DeleteOnly model results

In [None]:
# Load the transformed data
results_path = '../pred_delete2/'
trans_results_delete = pd.read_csv(results_path+'preds', header=None)

In [None]:
X_test_sequences = tokenizer.texts_to_sequences(trans_results_delete[0])

X_test_padded = sequence.pad_sequences(X_test_sequences, maxlen=max_seq_length, padding='post')

### 3.2 DeleteAndRetrieve model results

In [None]:
# Load the transformed data
delete_generate_results_path = '../pred_delete_retrieve2/'
delete_generate_results = pd.read_csv(delete_generate_results_path+'preds', header=None)

In [None]:
X_test_sequences2 = tokenizer.texts_to_sequences(delete_generate_results[0])

X_test_padded2 = sequence.pad_sequences(X_test_sequences2, maxlen=max_seq_length, padding='post')

### 3.3 Use 1NN baseline method to modify the negatively predicted instances

In [None]:
# Fit an unsupervised 1NN with all the positive sequences, using 'hamming' distance.
# `n_neighbors` is passed by keyword: scikit-learn releases with keyword-only estimator
# parameters raise a TypeError for the positional form `NearestNeighbors(1, ...)`.
nn_model = NearestNeighbors(n_neighbors=1, metric='hamming')

# Candidate counterfactuals are the training sequences labelled with the target class
target_label = 1
# Convert the boolean Series to a plain array so the mask is applied by position
X_target_label = X_train_padded[(y_train == target_label).to_numpy()]

nn_model.fit(X_target_label)

In [None]:
# Find the closest neighbor (positive sequence) with the minimum 'hamming' distance, take it as a counterfactual
closest = nn_model.kneighbors(X_pred_negative, return_distance=False)
trans_results_nn = X_target_label[closest[:, 0]]

trans_results_nn[0]

In [None]:
# Rename 'trans_results_nn' to 'X_test_padded3' for result comparison
X_test_padded3 = trans_results_nn

### 3.4 Convert transformed results to event sequence format

In [None]:
# Convert transformed sequences back to the form of original event sequences
trans_event_sequences1 = tokenizer.sequences_to_texts(X_test_padded)
trans_event_sequences2 = tokenizer.sequences_to_texts(X_test_padded2)
trans_event_sequences3 = tokenizer.sequences_to_texts(X_test_padded3)

## 4. Results comparison

### 4.1 Comparison between fraction of valid CFs (i.e. successfully generated counterfactuals)

In [None]:
# Get the total counts 
test_size = X_pred_negative.shape[0]

In [None]:
# Fraction of valid transformed sequences, for DeleteOnly
fraction_success = np.sum(main_model.predict(X_test_padded) > 0.5)/test_size
print(round(fraction_success, 4))

In [None]:
# For DeleteAndRetrieve
fraction_success2 = np.sum(main_model.predict(X_test_padded2) > 0.5)/test_size
print(round(fraction_success2, 4))

In [None]:
# For 1NN modification
fraction_success3 = np.sum(main_model.predict(X_test_padded3) > 0.5)/test_size
print(round(fraction_success3, 4))

### 4.2 Local outlier factor (LOF score)

In [None]:
# Fit the model for novelty detection (novelty=True), in order to get LOF score
clf = LocalOutlierFactor(n_neighbors=20, novelty=True, contamination=0.1)
clf.fit(X_train_padded)

In [None]:
# Get the LOF score for leave-out validation data
y_pred_val = clf.predict(X_val_padded)

n_error_val = y_pred_val[y_pred_val == -1].size

In [None]:
validation_size = X_val_padded.shape[0]
outlier_score_val = n_error_val/validation_size

outlier_score_val

In [None]:
# Get the LOF score for DeleteOnly results
y_pred_test = clf.predict(X_test_padded)
n_error_test = y_pred_test[y_pred_test == -1].size

outlier_score_test = n_error_test / test_size
print(round(outlier_score_test, 4))

In [None]:
# Get the outlier score for DeleteAndRetrieve results
y_pred_test2 = clf.predict(X_test_padded2)
n_error_test2 = y_pred_test2[y_pred_test2 == -1].size

outlier_score_test2 = n_error_test2 / test_size
print(round(outlier_score_test2, 4))

In [None]:
# Outlier score for 1NN baseline method
y_pred_test3 = clf.predict(X_test_padded3)
n_error_test3 = y_pred_test3[y_pred_test3 == -1].size

outlier_score_test3 = n_error_test3 / test_size
print(round(outlier_score_test3, 4))

### 4.3 BLEU-4 score (cumulative 4-gram BLEU score) 

In [None]:
# Define smoothing function
chencherry = SmoothingFunction()

In [None]:
# Pairwise BLEU between each original sequence and its transformed counterpart
def get_pairwise_bleu(original, transformed):
    """Return one smoothed BLEU score per (original, transformed) pair.

    original: iterable of whitespace-separated event strings, used as the reference.
    transformed: iterable of whitespace-separated event strings, used as the hypothesis.

    Pairs are formed with ``zip``, so any extra items in the longer input are ignored.
    Returns a list of scores in pair order.
    """
    # Equal weights on 1- to 4-grams, i.e. 4-gram BLEU scores calculated cumulatively
    four_gram_weights = [0.25, 0.25, 0.25, 0.25]

    results = []
    for reference_text, hypothesis_text in zip(original, transformed):
        score = sentence_bleu(
            references=[reference_text.split()],
            hypothesis=hypothesis_text.split(),
            weights=four_gram_weights,
            smoothing_function=chencherry.method1)
        results.append(score)

    return results

In [None]:
pairwise_bleu = get_pairwise_bleu(original_event_sequences, trans_event_sequences1)
avg_bleu = sum(pairwise_bleu)/test_size
print(round(avg_bleu, 4))

In [None]:
pairwise_bleu2 = get_pairwise_bleu(original_event_sequences, trans_event_sequences2)
avg_bleu2 = sum(pairwise_bleu2)/test_size
print(round(avg_bleu2, 4))

In [None]:
pairwise_bleu3 = get_pairwise_bleu(original_event_sequences, trans_event_sequences3)
avg_bleu3 = sum(pairwise_bleu3)/test_size
print(round(avg_bleu3, 4))

#### 4.3.1 Plot histograms of individual BLEU-4 scores

In [None]:
# One histogram of pairwise BLEU-4 scores per counterfactual method, side by side
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16,4))

bleu_panels = [
    ('DeleteOnly', pairwise_bleu),
    ('DeleteAndRetrieve', pairwise_bleu2),
    ('1-NN', pairwise_bleu3),
]

for panel_ax, (method_name, scores) in zip(ax, bleu_panels):
    # Titles previously read "BLUE score"; the metric plotted is BLEU
    panel_ax.set_title(f'{method_name}, BLEU score')
    panel_ax.hist(scores, density=True, bins=30)
    panel_ax.set_xlabel('BLEU-4 score')
    panel_ax.set_ylabel('Density')

plt.show()

### 4.4 Plot histograms of event count differences (modification counts)

In [None]:
# Empty frame with the three count columns (total, drug events, procedures) for the original sequences.
# NOTE(review): `original_counts` is not referenced again in the visible cells — the counts
# used later are built in `df_original_counts`; this looks like a leftover, confirm before removing.
original_counts = pd.DataFrame(columns=['total', 'drug', 'procedure'])

In [None]:
def get_counts_table(event_sequences, drug_itemid_min=220000):
    """Count total, drug and procedure events in each event sequence.

    Args:
        event_sequences: iterable of strings, each a whitespace-separated list of
            integer event codes.
        drug_itemid_min: codes greater than or equal to this value are counted as
            drug events; everything else is counted as a procedure. The default
            follows the note that MetaVision ITEMID values are all above 220000
            and that this data only contains MetaVision items.

    Returns:
        DataFrame with one row per sequence and the columns 'total', 'drug' and
        'procedure'. The columns are present even when `event_sequences` is empty.

    Raises:
        ValueError: if a token in a sequence cannot be parsed as an integer.
    """
    count_columns = ['total', 'drug', 'procedure']

    rows = []
    for seq in event_sequences:
        codes = seq.split()
        total = len(codes)
        drug = sum(1 for code in codes if int(code) >= drug_itemid_min)
        rows.append({'total': total, 'drug': drug, 'procedure': total - drug})

    # Passing the columns explicitly keeps the schema stable for empty input
    return pd.DataFrame(rows, columns=count_columns)

In [None]:
df_original_counts = get_counts_table(original_event_sequences)

In [None]:
df_original_counts.head()

In [None]:
# Get count tables for all the transformed results (generated counterfactuals):
# 1 = DeleteOnly, 2 = DeleteAndRetrieve, 3 = 1-NN baseline
trans_counts1 = get_counts_table(trans_event_sequences1)
trans_counts2 = get_counts_table(trans_event_sequences2)
trans_counts3 = get_counts_table(trans_event_sequences3)

In [None]:
# Subtract the original counts to get event modifications for total, drug events and procedures
# (transformed minus original for each count column)
substracted1 = trans_counts1.subtract(df_original_counts)
substracted2 = trans_counts2.subtract(df_original_counts)
substracted3 = trans_counts3.subtract(df_original_counts)

In [None]:
# Plot 3x3 subplots: one row per counterfactual method, one column per count type
fig, ax = plt.subplots(nrows=3, ncols=3, figsize=(16,12))

# (method name used in the title, table of count differences)
method_rows = [
    ('DeleteOnly', substracted1),
    ('DeleteAndRetrieve', substracted2),
    ('1-NN', substracted3),
]

# (column in the difference table, title suffix, number of histogram bins)
count_panels = [
    ('total', 'total difference', 30),
    ('drug', 'drug event difference', 30),
    ('procedure', 'procedure difference', 12),
]

for row_idx, (method_name, diff_table) in enumerate(method_rows):
    for col_idx, (count_col, title_suffix, n_bins) in enumerate(count_panels):
        plt.sca(ax[row_idx, col_idx])
        plt.title(f'{method_name}, {title_suffix}')
        plt.hist(diff_table[count_col], density=True, bins=n_bins)

plt.show()

### 4.5 Export example counterfactuals 

In [None]:
# Open a connection to the local MIMIC-III Postgres database, used below to
# convert event ids to their original names
conn = psycopg2.connect(
    database="mimic",
    # The cell previously held a `$your_username$` placeholder here, which is not valid Python.
    # Take the user from the PGUSER environment variable, falling back to the OS login name.
    user=os.environ.get("PGUSER", getpass.getuser()),
    # Prompt for the password so it is never stored in the notebook
    password=getpass.getpass("Enter postgres password"),
    host="127.0.0.1",
    port="5432",
    options='-c search_path=mimiciii')

In [None]:
# Get a mapping from itemid to name (drug events)
itemid_to_name = pd.read_sql(
    """
    SELECT itemid, abbreviation, label
    FROM d_items;
    """, conn)

itemid_to_name = itemid_to_name[itemid_to_name['itemid'] >= 220000]
itemid_to_name.head()

In [None]:
# Get another mapping from procedure itemid to name 
itemid_to_name2 = pd.read_sql(
    """
    SELECT icd9_code, short_title, long_title
    FROM d_icd_procedures;
    """, conn)

itemid_to_name2.head()

In [None]:
# Concatenate the two itemid_to_name tables into one.
# The procedure table's columns are first renamed to match the drug-event table.
itemid_to_name2 = itemid_to_name2.rename(columns={'icd9_code': 'itemid', 'short_title': 'abbreviation', 'long_title': 'label'})

itemid_to_name_concat = pd.concat([itemid_to_name, itemid_to_name2])

In [None]:
# Convert data type to be consistent when filtering, e.g. 'itemid_to_name_concat['itemid'] == 9671'
itemid_to_name_concat['label'] = itemid_to_name_concat['label'].astype('str') 
itemid_to_name_concat['itemid'] = itemid_to_name_concat['itemid'].astype('int') 

In [None]:
# Define a method to convert event codes to original names
def code_to_name(event_sequence, mapping=None):
    """Translate a whitespace-separated string of event codes into their labels.

    Args:
        event_sequence: string of integer event codes separated by whitespace.
        mapping: optional DataFrame with 'itemid' and 'label' columns. When omitted,
            the notebook-level `itemid_to_name_concat` table is used.

    Returns:
        List with one label per code, in the order the codes appear.

    Raises:
        ValueError: if a token is not an integer, or if a code matches no row or
            more than one row of the mapping table.
    """
    lookup = itemid_to_name_concat if mapping is None else mapping

    # Parse every token first so a malformed sequence fails before any lookup
    code_sequence = [int(event) for event in event_sequence.split()]

    names = []
    for code in code_sequence:
        matches = lookup.loc[lookup['itemid'] == code, 'label']
        # A bare `.item()` raises an opaque size error here; name the offending code instead
        if len(matches) != 1:
            raise ValueError(
                f"expected exactly one label for itemid {code}, found {len(matches)}")
        names.append(matches.item())

    return names

In [None]:
# Get the sample for example counterfactuals
sample_id = 44

In [None]:
code_to_name(original_event_sequences[sample_id])

In [None]:
# original_event_sequences[sample_id]

In [None]:
code_to_name(trans_event_sequences1[sample_id])

In [None]:
# trans_event_sequences1[sample_id]

In [None]:
code_to_name(trans_event_sequences2[sample_id])

In [None]:
# trans_event_sequences2[sample_id]

In [None]:
code_to_name(trans_event_sequences3[sample_id])

In [None]:
# trans_event_sequences3[sample_id]