In [1]:
import sys
sys.path.append('..')
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder, MinMaxScaler, KBinsDiscretizer
from sklearn.impute import SimpleImputer
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold
import time
from sklearn.metrics import accuracy_score
from sympy import simplify_logic
from sklearn.model_selection import train_test_split

from lens.utils.base import validate_network, set_seed, tree_to_formula
from lens.utils.relu_nn import get_reduced_model, prune_features
from lens.utils.psi_nn import prune_equal_fanin
from lens import logic
import lens

set_seed(0)

#%%

data = pd.read_csv('data/mimic-ii/full_cohort_data.csv')
# data.drop('hgb_first')
fs = [
    'aline_flg',
    'gender_num',
    # 'hosp_exp_flg',
    # 'icu_exp_flg',
    # 'day_28_flg',
    # 'censor_flg',
    'sepsis_flg', 'chf_flg', 'afib_flg',
    'renal_flg', 'liver_flg', 'copd_flg', 'cad_flg', 'stroke_flg',
    'mal_flg', 'resp_flg',
]
features = fs
data1 = data[fs].values
imp_mean = SimpleImputer(missing_values=np.nan, strategy='mean')
data1 = imp_mean.fit_transform(data1)

f2 = fs.copy()
f2.append('day_icu_intime')
f2.append('service_unit')
f2.append('day_28_flg')
f2.append('hospital_los_day')
f2.append('icu_exp_flg')
f2.append('hosp_exp_flg')
f2.append('censor_flg')
f2.append('mort_day_censored')
f2 = data.columns.difference(f2)
data2 = data[f2].values
imp_mean = SimpleImputer(missing_values=np.nan, strategy='mean')
data2 = imp_mean.fit_transform(data2)
scaler = MinMaxScaler((0, 1))
data2 = scaler.fit_transform(data2)
features = features + list(f2)
est = KBinsDiscretizer(n_bins=3, encode='onehot-dense', strategy='uniform')
data2d = est.fit_transform(data2)
f2d = []
for feature in f2:
    #f2d.append(feature + '_VLOW')
    f2d.append(feature + '_LOW')
    f2d.append(feature + '_NORMAL')
    f2d.append(feature + '_HIGH')
    #f2d.append(feature + '_VHIGH')
features = fs + f2d

datax = np.hstack((data1, data2d))
datay = data['day_28_flg'].values

x = torch.FloatTensor(datax)
y = torch.LongTensor(datay)
print(x.shape)
print(y.shape)


torch.Size([1776, 90])
torch.Size([1776])


In [2]:
results_dir = 'results/mimic'
if not os.path.isdir(results_dir):
    os.makedirs(results_dir)


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# concepts = [f'c{i:03}' for i in range(x.shape[1])]
concepts =  features
n_rep = 10
tot_epochs = 5001
prune_epochs = 2001
seed = 42

n_splits = 10
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)

In [3]:
def train_nn(x_train, y_train, need_pruning, seed, device, l1=0.001, lr=0.001, relu=False, verbose=False):
    set_seed(seed)
    x_train = x_train.to(device)
    y_train = y_train.to(device)
    layers = [
        torch.nn.Linear(x_train.size(1), 100),
        torch.nn.ReLU() if relu else torch.nn.LeakyReLU(),
        torch.nn.Linear(100, 10),
        torch.nn.ReLU() if relu else torch.nn.LeakyReLU(),
        torch.nn.Linear(10, 2),
        torch.nn.Softmax(dim=1),
    ]
    model = torch.nn.Sequential(*layers).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_form = torch.nn.CrossEntropyLoss()
    model.train()
    for epoch in range(tot_epochs):
        # forward pass
        optimizer.zero_grad()
        y_pred = model(x_train)
        # Compute Loss
        loss = loss_form(y_pred, y_train)

        for module in model.children():
            if isinstance(module, torch.nn.Linear):
                loss += l1 * torch.norm(module.weight, 1)
                loss += l1 * torch.norm(module.bias, 1)
                break

        # backward pass
        loss.backward()
        optimizer.step()

        if epoch > prune_epochs and need_pruning and epoch % 1000 == 1:
            prune_features(model, n_classes=1, device=device)
            need_pruning = True
            
        # compute accuracy
        if epoch % 500 == 0 and verbose:
            y_pred_d = torch.argmax(y_pred, dim=1)
            accuracy = y_pred_d.eq(y_train).sum().item() / y_train.size(0)
            print(f'\t Epoch {epoch}: train accuracy: {accuracy:.4f}')
    
    return model

In [4]:
def train_psi_nn(x_train, y_train, need_pruning, seed, device, verbose=False):
    set_seed(seed)
    x_train = x_train.to(device)
    y_train = y_train.to(device).to(torch.float)
    layers = [
        torch.nn.Linear(x_train.size(1), 10),
        torch.nn.Sigmoid(),
        torch.nn.Linear(10, 4),
        torch.nn.Sigmoid(),
        torch.nn.Linear(4, 1),
        torch.nn.Sigmoid(),
    ]
    model = torch.nn.Sequential(*layers).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_form = torch.nn.BCELoss()
    model.train()
    for epoch in range(tot_epochs):
        # forward pass
        optimizer.zero_grad()
        y_pred = model(x_train).squeeze()
        # Compute Loss
        loss = loss_form(y_pred, y_train)

        for module in model.children():
            if isinstance(module, torch.nn.Linear):
                loss += 0.00001 * torch.norm(module.weight, 1)

        # backward pass
        loss.backward()
        optimizer.step()

        if epoch > prune_epochs and need_pruning:
            model = prune_equal_fanin(model, 2, validate=True, device=device)
            need_pruning = False
            
        # compute accuracy
        if epoch % 500 == 0 and verbose:
            y_pred_d = y_pred > 0.5
            accuracy = y_pred_d.eq(y_train).sum().item() / y_train.size(0)
            print(f'\t Epoch {epoch}: train accuracy: {accuracy:.4f}')
    
    return model

In [14]:
def c_to_y(method, need_pruning, l1=0.001, lr=0.001, relu=False, verbose=False):
    methods = []
    splits = []
    explanations = []
    explanations_inv = []
    model_accuracies = []
    explanation_accuracies = []
    explanation_accuracies_inv = []
    elapsed_times = []
    elapsed_times_inv = []
    explanation_fidelities = []
    explanation_complexities = []
    for split, (trainval_index, test_index) in enumerate(skf.split(x.cpu().detach().numpy(), y.cpu().detach().numpy())):
        print(f'Split [{split+1}/{n_splits}]')
        x_trainval, x_test = torch.FloatTensor(x[trainval_index]), torch.FloatTensor(x[test_index])
        y_trainval, y_test = torch.LongTensor(y[trainval_index]), torch.LongTensor(y[test_index])
        x_train, x_val, y_train, y_val = train_test_split(x_trainval, y_trainval, test_size=0.3, random_state=42)
    
        explanation, explanation_inv = '', ''
        explanation_accuracy, explanation_accuracy_inv = 0, 0
        
        if method == 'tree':
            classifier = DecisionTreeClassifier(random_state=seed)
            classifier.fit(x_trainval.cpu().detach().numpy(), y_trainval.cpu().detach().numpy())
            y_preds = classifier.predict(x_test.cpu().detach().numpy())
            model_accuracy = accuracy_score(y_test.cpu().detach().numpy(), y_preds)

            target_class = 1
            start = time.time()
            explanation = tree_to_formula(classifier, concepts, target_class)
            elapsed_time = time.time() - start
            explanation_accuracy = model_accuracy
            explanation_fidelity = 1.
            explanation_complexity = lens.logic.complexity(explanation)

            target_class_inv = 0
            start = time.time()
            explanation_inv = tree_to_formula(classifier, concepts, target_class_inv)
            elapsed_time_inv = time.time() - start
            explanation_accuracy_inv = model_accuracy
        
        else:
            if method == 'psi':
                # positive class
                target_class = 1
                model = train_psi_nn(x_trainval, y_trainval.eq(target_class), need_pruning, seed, device, verbose)
                y_preds = model(x_test.to(device)).cpu().detach().numpy()>0.5
                model_accuracy = accuracy_score(y_test.cpu().detach().numpy(), y_preds)
                
            else:
                model = train_nn(x_trainval, y_trainval, need_pruning, seed, device, l1, lr, relu, verbose)
                y_preds = model(x_test.to(device)).cpu().detach().numpy().argmax(axis=1)
                model_accuracy = accuracy_score(y_test.cpu().detach().numpy(), y_preds)

            # positive class
            target_class = 1
            start = time.time()
            if method == 'psi':
                global_explanation = logic.generate_fol_explanations(model, device)[0]
            else:
                global_explanation, _, _ = logic.relu_nn.combine_local_explanations(model, 
                                                                                   x_val.to(device), 
                                                                                   y_val.to(device), 
                                                                                   topk_explanations=3,
                                                                                   target_class=target_class,
                                                                                   method=method, device=device)
            elapsed_time = time.time() - start
            
            if global_explanation:
#                 explanation_accuracy, _ = logic.base.test_explanation(global_explanation, target_class, x_test, y_test)
                explanation = logic.base.replace_names(global_explanation, concepts)
                explanation_accuracy, y_formula = logic.base.test_explanation(global_explanation, 
                                                                              target_class, 
                                                                              x=x_test, y=y_test,
                                                                              metric=accuracy_score)
                explanation_fidelity = lens.logic.fidelity(y_formula, y_preds)
                explanation_complexity = lens.logic.complexity(global_explanation)

#             # negative class
#             target_class_inv = 0
#             if method == 'psi':
#                 model = train_psi_nn(x_trainval, y_trainval.eq(target_class_inv), need_pruning, seed, device, verbose)
            
#             start = time.time()
#             if method == 'psi':
#                 global_explanation_inv = logic.generate_fol_explanations(model, device)[0]
#             else:
#                 global_explanation_inv, _, _ = logic.relu_nn.combine_local_explanations(model, 
#                                                                                        x_val.to(device), 
#                                                                                        y_val.to(device), 
#                                                                                        topk_explanations=2,
#                                                                                        target_class=target_class_inv,
#                                                                                        method=method, device=device)
#             elapsed_time_inv = time.time() - start
#             if global_explanation_inv:
#                 explanation_accuracy_inv, _ = logic.base.test_explanation(global_explanation_inv, 
#                                                                           target_class_inv, x_test, y_test)
#                 explanation_inv = logic.base.replace_names(global_explanation_inv, concepts)
        
        if verbose:
            print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
            print(f'\t Class {target_class} - Global explanation: "{explanation}" - Accuracy: {explanation_accuracy:.4f}')
            print(f'\t Elapsed time {elapsed_time}')
            print(f'\t Fidelity: "{explanation_fidelity:.4f}" - Complexity: "{explanation_complexity}"')
#             print(f'\t Class {target_class_inv} - Global explanation: "{explanation_inv}" - Accuracy: {explanation_accuracy_inv:.4f}')
#             print(f'\t Elapsed time {elapsed_time_inv}')

        methods.append(method)
        splits.append(split)
        explanations.append(explanation)
#         explanations_inv.append(explanation_inv)
        model_accuracies.append(model_accuracy)
        explanation_accuracies.append(explanation_accuracy)
#         explanation_accuracies_inv.append(explanation_accuracy_inv)
        elapsed_times.append(elapsed_time)
#         elapsed_times_inv.append(elapsed_time_inv)
        explanation_fidelities.append(explanation_fidelity)
        explanation_complexities.append(explanation_complexity)
    
    explanation_consistency = lens.logic.formula_consistency(explanations)
    print(f'Consistency of explanations: {explanation_consistency:.4f}')
    
    results = pd.DataFrame({
        'method': methods,
        'split': splits,
        'explanation': explanations,
#         'explanation_inv': explanations_inv,
        'model_accuracy': model_accuracies,
        'explanation_accuracy': explanation_accuracies,
        'explanation_fidelity': explanation_fidelities,
        'explanation_complexity': explanation_complexities,
        'explanation_consistency': explanation_consistency,
#         'explanation_accuracy_inv': explanation_accuracies_inv,
        'elapsed_time': elapsed_times,
#         'elapsed_time_inv': elapsed_times_inv,
    })
    results.to_csv(os.path.join(results_dir, f'results_{method}.csv'))
    
    return results

# General pruning

In [6]:
method = 'pruning'
need_pruning = True
relu = False
results_pruning = c_to_y(method, need_pruning, relu=relu)
results_pruning

Split [1/10]
Split [2/10]
Split [3/10]
Split [4/10]
Split [5/10]
Split [6/10]
Split [7/10]
Split [8/10]
Split [9/10]
Split [10/10]
Consistency of explanations: 0.3222


Unnamed: 0,method,split,explanation,model_accuracy,explanation_accuracy,explanation_fidelity,explanation_complexity,explanation_consistency,elapsed_time
0,pruning,0,stroke_flg,0.842697,0.842697,1.0,1,0.322222,0.043916
1,pruning,1,(stroke_flg & bun_first_LOW & ~sapsi_first_LOW...,0.837079,0.837079,0.988764,7,0.322222,0.077024
2,pruning,2,stroke_flg,0.853933,0.853933,1.0,1,0.322222,0.040891
3,pruning,3,stroke_flg & age_HIGH,0.865169,0.865169,1.0,2,0.322222,0.072806
4,pruning,4,(stroke_flg & age_HIGH & ~liver_flg) | (age_HI...,0.865169,0.865169,1.0,6,0.322222,0.09076
5,pruning,5,(stroke_flg & age_HIGH & bun_first_LOW) | (age...,0.859551,0.865169,0.994382,7,0.322222,0.053856
6,pruning,6,(stroke_flg & age_HIGH) | (age_HIGH & ~bun_fir...,0.847458,0.847458,1.0,4,0.322222,0.058986
7,pruning,7,(stroke_flg & age_HIGH) | (age_HIGH & ~bun_fir...,0.864407,0.864407,1.0,4,0.322222,0.049867
8,pruning,8,(stroke_flg & abg_count_LOW & bun_first_LOW & ...,0.830508,0.824859,0.99435,9,0.322222,0.115722
9,pruning,9,stroke_flg & ~sapsi_first_LOW,0.875706,0.875706,1.0,2,0.322222,0.051389


# ReLUs

In [7]:
method = 'weights'
need_pruning = False
relu = True
results_weights = c_to_y(method, need_pruning, 0.008, lr=0.0005, relu=relu, verbose=True)
results_weights

Split [1/10]
	 Epoch 0: train accuracy: 0.1589
	 Epoch 500: train accuracy: 0.8411
	 Epoch 1000: train accuracy: 0.8411
	 Epoch 1500: train accuracy: 0.8411
	 Epoch 2000: train accuracy: 0.8411
	 Epoch 2500: train accuracy: 0.8411
	 Epoch 3000: train accuracy: 0.8411
	 Epoch 3500: train accuracy: 0.8411
	 Epoch 4000: train accuracy: 0.8780
	 Epoch 4500: train accuracy: 0.8842
	 Epoch 5000: train accuracy: 0.8936
	 Model's accuracy: 0.8596
	 Class 1 - Global explanation: "(tco2_first_NORMAL & ~renal_flg) | (age_HIGH & ~stroke_flg & ~weight_first_NORMAL)" - Accuracy: 0.3483
	 Elapsed time 0.25098729133605957
	 Fidelity: "0.2753" - Complexity: "5"
Split [2/10]
	 Epoch 0: train accuracy: 0.1589
	 Epoch 500: train accuracy: 0.8411
	 Epoch 1000: train accuracy: 0.8411
	 Epoch 1500: train accuracy: 0.8411
	 Epoch 2000: train accuracy: 0.8411
	 Epoch 2500: train accuracy: 0.8411
	 Epoch 3000: train accuracy: 0.8411
	 Epoch 3500: train accuracy: 0.8411
	 Epoch 4000: train accuracy: 0.8411
	 Epo

Unnamed: 0,method,split,explanation,model_accuracy,explanation_accuracy,explanation_fidelity,explanation_complexity,explanation_consistency,elapsed_time
0,weights,0,(tco2_first_NORMAL & ~renal_flg) | (age_HIGH &...,0.859551,0.348315,0.275281,5,0.194444,0.250987
1,weights,1,,0.837079,0.0,0.275281,5,0.194444,0.001992
2,weights,2,age_HIGH | (bun_first_LOW & ~mal_flg & ~age_LO...,0.859551,0.657303,0.674157,6,0.194444,0.279933
3,weights,3,,0.842697,0.0,0.674157,6,0.194444,0.001995
4,weights,4,,0.842697,0.0,0.674157,6,0.194444,0.001503
5,weights,5,~stroke_flg | ~sofa_first_LOW | (bun_first_LOW...,0.859551,0.168539,0.140449,5,0.194444,0.094747
6,weights,6,,0.841808,0.0,0.140449,5,0.194444,0.001995
7,weights,7,age_HIGH | (stroke_flg & hgb_first_HIGH & sofa...,0.881356,0.774011,0.779661,8,0.194444,0.742731
8,weights,8,(bun_first_LOW & sofa_first_NORMAL) | (hour_ic...,0.841808,0.463277,0.440678,6,0.194444,0.142297
9,weights,9,bun_first_LOW | ~stroke_flg,0.875706,0.158192,0.112994,2,0.194444,0.224867


# Psi network

In [12]:
method = 'psi'
need_pruning = True
relu = False
results_psi = c_to_y(method, need_pruning, relu, verbose=False)
results_psi

Split [1/10]
Split [2/10]
Split [3/10]
Split [4/10]
Split [5/10]
Split [6/10]
Split [7/10]
Split [8/10]
Split [9/10]
Split [10/10]
Consistency of explanations: 0.3714


Unnamed: 0,method,split,explanation,model_accuracy,explanation_accuracy,explanation_fidelity,explanation_complexity,explanation_consistency,elapsed_time
0,psi,0,(stroke_flg | bun_first_NORMAL | sapsi_first_H...,0.837079,0.831461,0.848315,3,0.371429,0.030917
1,psi,1,(stroke_flg | abg_count_NORMAL),0.837079,0.814607,0.853933,2,0.371429,0.026901
2,psi,2,(stroke_flg | abg_count_HIGH | sapsi_first_HIGH),0.837079,0.842697,0.837079,3,0.371429,0.030916
3,psi,3,(stroke_flg | sapsi_first_HIGH),0.842697,0.842697,0.876404,2,0.371429,0.023905
4,psi,4,(stroke_flg | bun_first_NORMAL),0.842697,0.831461,0.842697,2,0.371429,0.025898
5,psi,5,(stroke_flg | bun_first_HIGH | chloride_first_...,0.842697,0.814607,0.825843,3,0.371429,0.031914
6,psi,6,(stroke_flg | bun_first_HIGH | chloride_first_...,0.836158,0.79096,0.819209,3,0.371429,0.029914
7,psi,7,(stroke_flg | bun_first_NORMAL | bun_first_HIGH),0.841808,0.847458,0.870056,3,0.371429,0.048869
8,psi,8,(stroke_flg | bun_first_NORMAL),0.841808,0.80791,0.819209,2,0.371429,0.027926
9,psi,9,(stroke_flg | chloride_first_HIGH | sapsi_firs...,0.841808,0.858757,0.870056,3,0.371429,0.04089


# Decision tree

In [15]:
method = 'tree'
need_pruning = False
relu = False
results_tree = c_to_y(method, need_pruning, relu)
results_tree

Split [1/10]
Split [2/10]
Split [3/10]
Split [4/10]
Split [5/10]
Split [6/10]
Split [7/10]
Split [8/10]
Split [9/10]
Split [10/10]
Consistency of explanations: 0.8150


Unnamed: 0,method,split,explanation,model_accuracy,explanation_accuracy,explanation_fidelity,explanation_complexity,explanation_consistency,elapsed_time
0,tree,0,(stroke_flg <= 0.50 & age_HIGH <= 0.50 & resp_...,0.803371,0.803371,1.0,1363,0.815033,0.002984
1,tree,1,(stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi...,0.803371,0.803371,1.0,1424,0.815033,0.002992
2,tree,2,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & age_N...,0.808989,0.808989,1.0,1345,0.815033,0.002992
3,tree,3,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & age_N...,0.780899,0.780899,1.0,1391,0.815033,0.002992
4,tree,4,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi...,0.786517,0.786517,1.0,1387,0.815033,0.003986
5,tree,5,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & resp_...,0.848315,0.848315,1.0,1319,0.815033,0.003986
6,tree,6,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & resp_...,0.762712,0.762712,1.0,1271,0.815033,0.002988
7,tree,7,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & resp_...,0.813559,0.813559,1.0,1321,0.815033,0.004001
8,tree,8,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & resp_...,0.774011,0.774011,1.0,1353,0.815033,0.003003
9,tree,9,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi...,0.774011,0.774011,1.0,1252,0.815033,0.004003


# Summary

In [17]:
cols = ['model_accuracy', 'explanation_accuracy', 'explanation_fidelity', 
        'explanation_complexity', 'elapsed_time', 'explanation_consistency']
mean_cols = [f'{c}_mean' for c in cols]
sem_cols = [f'{c}_sem' for c in cols]

# pruning
df_mean = results_pruning[cols].mean()
df_sem = results_pruning[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_pruning = pd.concat([df_mean, df_sem])
summary_pruning.name = 'pruning'

# # lime
# df_mean = results_lime[cols].mean()
# df_sem = results_lime[cols].sem()
# df_mean.columns = mean_cols
# df_sem.columns = sem_cols
# summary_lime = pd.concat([df_mean, df_sem])
# summary_lime.name = 'lime'

# weights
df_mean = results_weights[cols].mean()
df_sem = results_weights[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_weights = pd.concat([df_mean, df_sem])
summary_weights.name = 'weights'

# psi
df_mean = results_psi[cols].mean()
df_sem = results_psi[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_psi = pd.concat([df_mean, df_sem])
summary_psi.name = 'psi'

# tree
df_mean = results_tree[cols].mean()
df_sem = results_tree[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_tree = pd.concat([df_mean, df_sem])
summary_tree.name = 'tree'

summary = pd.concat([summary_pruning, 
#                      summary_lime, 
                     summary_weights, 
                     summary_psi, 
                     summary_tree], axis=1).T
summary.columns = mean_cols + sem_cols
summary

Unnamed: 0,model_accuracy_mean,explanation_accuracy_mean,explanation_fidelity_mean,explanation_complexity_mean,elapsed_time_mean,explanation_consistency_mean,model_accuracy_sem,explanation_accuracy_sem,explanation_fidelity_sem,explanation_complexity_sem,elapsed_time_sem,explanation_consistency_sem
pruning,0.854167,0.854164,0.99775,4.3,0.065522,0.322222,0.004556,0.005011,0.001243,0.895048,0.007483,1.850372e-17
weights,0.85418,0.256964,0.418727,5.4,0.174305,0.194444,0.004888,0.091907,0.082731,0.476095,0.072187,0.0
psi,0.840091,0.828261,0.84628,2.6,0.031805,0.371429,0.000894,0.006595,0.006734,0.163299,0.002393,0.0
tree,0.795575,0.795575,1.0,1342.6,0.003393,0.815033,0.007968,0.007968,0.0,16.94055,0.000164,3.700743e-17


In [18]:
summary.to_csv(os.path.join(results_dir, 'summary.csv'))