In [65]:
import sys
sys.path.append('..')
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder, MinMaxScaler, KBinsDiscretizer
from sklearn.impute import SimpleImputer
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold
import time
from sklearn.metrics import accuracy_score
from sympy import simplify_logic

from deep_logic.utils.base import validate_network, set_seed, tree_to_formula
from deep_logic.utils.relunn import get_reduced_model, prune_features
from deep_logic.utils.sigmoidnn import prune_equal_fanin
from deep_logic import logic

set_seed(0)

#%%

data = pd.read_csv('data/mimic-ii/full_cohort_data.csv')
# data.drop('hgb_first')
fs = [
    'aline_flg',
    'gender_num',
    # 'hosp_exp_flg',
    # 'icu_exp_flg',
    # 'day_28_flg',
    # 'censor_flg',
    'sepsis_flg', 'chf_flg', 'afib_flg',
    'renal_flg', 'liver_flg', 'copd_flg', 'cad_flg', 'stroke_flg',
    'mal_flg', 'resp_flg',
]
features = fs
data1 = data[fs].values
imp_mean = SimpleImputer(missing_values=np.nan, strategy='mean')
data1 = imp_mean.fit_transform(data1)

f2 = fs.copy()
f2.append('day_icu_intime')
f2.append('service_unit')
f2.append('day_28_flg')
f2.append('hospital_los_day')
f2.append('icu_exp_flg')
f2.append('hosp_exp_flg')
f2.append('censor_flg')
f2.append('mort_day_censored')
f2 = data.columns.difference(f2)
data2 = data[f2].values
imp_mean = SimpleImputer(missing_values=np.nan, strategy='mean')
data2 = imp_mean.fit_transform(data2)
scaler = MinMaxScaler((0, 1))
data2 = scaler.fit_transform(data2)
features = features + list(f2)
est = KBinsDiscretizer(n_bins=3, encode='onehot-dense', strategy='uniform')
data2d = est.fit_transform(data2)
f2d = []
for feature in f2:
    #f2d.append(feature + '_VLOW')
    f2d.append(feature + '_LOW')
    f2d.append(feature + '_NORMAL')
    f2d.append(feature + '_HIGH')
    #f2d.append(feature + '_VHIGH')
features = fs + f2d

datax = np.hstack((data1, data2d))
datay = data['day_28_flg'].values

x = torch.FloatTensor(datax)
y = torch.LongTensor(datay)
print(x.shape)
print(y.shape)


torch.Size([1776, 90])
torch.Size([1776])


In [66]:
results_dir = 'results/mimic'
if not os.path.isdir(results_dir):
    os.makedirs(results_dir)


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# concepts = [f'c{i:03}' for i in range(x.shape[1])]
concepts =  features
n_rep = 10
tot_epochs = 5001
prune_epochs = 2001

n_splits = 10
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=0)

# General pruning

In [67]:
def train_nn(x_train, y_train, need_pruning, seed, device, relu=False):
    set_seed(seed)
    x_train = x_train.to(device)
    y_train = y_train.to(device)
    layers = [
        torch.nn.Linear(x_train.size(1), 100),
        torch.nn.ReLU() if relu else torch.nn.LeakyReLU(),
        torch.nn.Linear(100, 50),
        torch.nn.ReLU() if relu else torch.nn.LeakyReLU(),
        torch.nn.Linear(50, 10),
        torch.nn.ReLU() if relu else torch.nn.LeakyReLU(),
        torch.nn.Linear(10, 2),
        torch.nn.Softmax(dim=1),
    ]
    model = torch.nn.Sequential(*layers).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
    loss_form = torch.nn.CrossEntropyLoss()
    model.train()
    for epoch in range(tot_epochs):
        # forward pass
        optimizer.zero_grad()
        y_pred = model(x_train)
        # Compute Loss
        loss = loss_form(y_pred, y_train)

        for module in model.children():
            if isinstance(module, torch.nn.Linear):
                loss += 0.001 * torch.norm(module.weight, 1)
                loss += 0.001 * torch.norm(module.bias, 1)
                break

        # backward pass
        loss.backward()
        optimizer.step()

        if epoch > prune_epochs and need_pruning:
            prune_features(model, n_classes=1, device=device)
            need_pruning = False
            
        # compute accuracy
        if epoch % 500 == 0:
            y_pred_d = torch.argmax(y_pred, dim=1)
            accuracy = y_pred_d.eq(y_train).sum().item() / y_train.size(0)
            print(f'\t Epoch {epoch}: train accuracy: {accuracy:.4f}')
    
    return model

In [79]:
need_pruning = True
method = 'pruning'
methods = []
splits = []
explanations = []
explanations_inv = []
model_accuracies = []
explanation_accuracies = []
explanation_accuracies_inv = []
elapsed_times = []
elapsed_times_inv = []
for split, (train_index, test_index) in enumerate(skf.split(x, y)):
    print(f'Split [{split+1}/{n_splits}]')
    x_train, x_test = torch.FloatTensor(x[train_index]), torch.FloatTensor(x[test_index])
    y_train, y_test = torch.LongTensor(y[train_index]), torch.LongTensor(y[test_index])
    
#     if split not in [5]: continue
    
    model = train_nn(x_train, y_train, need_pruning, split, device)
    
    y_preds = model(x_test.to(device)).cpu().detach().numpy()
    model_accuracy = accuracy_score(y_test.cpu().detach().numpy(), y_preds.argmax(axis=1))
    print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
    
    # positive class
    target_class = 1
    start = time.time()
    global_explanation, _, counter = logic.relunn.combine_local_explanations(model, 
                                                                       x_train.to(device), y_train.to(device), 
                                                                       target_class=target_class,
                                                                       topk_explanations=3,
                                                                       method=method, device=device)
    elapsed_time = time.time() - start
    if global_explanation:
        explanation_accuracy, _ = logic.base.test_explanation(global_explanation, target_class, x_test, y_test)
        explanation = logic.base.replace_names(global_explanation, concepts)
    print(f'\t Class {target_class} - Global explanation: "{explanation}" - Accuracy: {explanation_accuracy:.4f}')
    print(f'\t Elapsed time {elapsed_time}')
        
    # negative class
    target_class = 0
    start = time.time()
    global_explanation_inv, _, counter_inv = logic.relunn.combine_local_explanations(model, 
                                                                           x_train.to(device), y_train.to(device), 
                                                                           target_class=target_class,
                                                                           topk_explanations=3,
                                                                           method=method, device=device)
    elapsed_time_inv = time.time() - start
    if global_explanation_inv:
        explanation_accuracy_inv, _ = logic.base.test_explanation(global_explanation_inv, target_class, x_test, y_test)
        explanation_inv = logic.base.replace_names(global_explanation_inv, concepts)
    print(f'\t Class {target_class} - Global explanation: "{explanation_inv}" - Accuracy: {explanation_accuracy_inv:.4f}')
    print(f'\t Elapsed time {elapsed_time_inv}')
    
    methods.append(method)
    splits.append(split)
    explanations.append(explanation)
    explanations_inv.append(explanation_inv)
    model_accuracies.append(model_accuracy)
    explanation_accuracies.append(explanation_accuracy)
    explanation_accuracies_inv.append(explanation_accuracy_inv)
    elapsed_times.append(elapsed_time)
    elapsed_times_inv.append(elapsed_time_inv)

Split [1/10]
	 Epoch 0: train accuracy: 0.7910
	 Epoch 500: train accuracy: 0.8411
	 Epoch 1000: train accuracy: 0.9168
	 Epoch 1500: train accuracy: 0.9193
	 Epoch 2000: train accuracy: 0.9155
	 Epoch 2500: train accuracy: 0.8748
	 Epoch 3000: train accuracy: 0.8748
	 Epoch 3500: train accuracy: 0.8748
	 Epoch 4000: train accuracy: 0.8748
	 Epoch 4500: train accuracy: 0.8748
	 Epoch 5000: train accuracy: 0.8748
	 Model's accuracy: 0.8596
	 Class 1 - Global explanation: "(afib_flg & age_HIGH & ~bun_first_LOW) | (afib_flg & stroke_flg & ~sapsi_first_LOW & ~weight_first_NORMAL) | (stroke_flg & age_HIGH & ~sapsi_first_LOW & ~weight_first_NORMAL)" - Accuracy: 0.8596
	 Elapsed time 0.5704751014709473
	 Class 0 - Global explanation: "(bun_first_LOW & sapsi_first_LOW) | (bun_first_LOW & weight_first_NORMAL) | (bun_first_LOW & ~stroke_flg)" - Accuracy: 0.8427
	 Elapsed time 2.38761830329895
Split [2/10]
	 Epoch 0: train accuracy: 0.1589
	 Epoch 500: train accuracy: 0.8411
	 Epoch 1000: train a

In [69]:
results_pruning = pd.DataFrame({
    'method': methods,
    'split': splits,
    'explanation': explanations,
    'explanation_inv': explanations_inv,
    'model_accuracy': model_accuracies,
    'explanation_accuracy': explanation_accuracies,
    'explanation_accuracy_inv': explanation_accuracies_inv,
    'elapsed_time': elapsed_times,
    'elapsed_time_inv': elapsed_times_inv,
})
results_pruning.to_csv(os.path.join(results_dir, 'results_pruning.csv'))
results_pruning

Unnamed: 0,method,split,explanation,explanation_inv,model_accuracy,explanation_accuracy,explanation_accuracy_inv,elapsed_time,elapsed_time_inv
0,pruning,0,stroke_flg & age_HIGH & ~sapsi_first_LOW & ~we...,bun_first_LOW & ~stroke_flg,0.859551,0.865169,0.831461,0.552446,2.299253
1,pruning,1,stroke_flg & age_HIGH,bun_first_LOW & ~age_HIGH,0.870787,0.876404,0.764045,0.468843,1.945834
2,pruning,2,stroke_flg & age_HIGH,~age_HIGH,0.870787,0.865169,0.735955,0.482188,2.105507
3,pruning,3,stroke_flg & ~sapsi_first_LOW & ~weight_first_...,~stroke_flg,0.831461,0.831461,0.808989,0.53266,2.089467
4,pruning,4,stroke_flg & ~bun_first_LOW,bun_first_LOW,0.842697,0.842697,0.808989,0.538386,1.759399
5,pruning,5,stroke_flg & age_HIGH,~age_HIGH,0.870787,0.859551,0.792135,0.471477,2.032206
6,pruning,6,stroke_flg & age_HIGH & sofa_first_LOW & ~rena...,sapsi_first_LOW & sofa_first_LOW,0.864407,0.870056,0.39548,1.665268,5.964563
7,pruning,7,stroke_flg & age_HIGH & bun_first_LOW & ~sapsi...,bun_first_LOW & sapsi_first_LOW,0.892655,0.875706,0.497175,0.775244,2.862986
8,pruning,8,stroke_flg & age_HIGH,bun_first_LOW & ~age_HIGH,0.836158,0.841808,0.740113,0.515775,2.563376
9,pruning,9,stroke_flg & age_HIGH,~age_HIGH,0.858757,0.870056,0.813559,0.42438,1.795328


# ReLUs

In [70]:
need_pruning = False
method = 'weights'
methods = []
splits = []
explanations = []
explanations_inv = []
model_accuracies = []
explanation_accuracies = []
explanation_accuracies_inv = []
elapsed_times = []
elapsed_times_inv = []

for split, (train_index, test_index) in enumerate(skf.split(x.cpu().detach().numpy(), y.cpu().detach().numpy())):
    print(f'Split [{split+1}/{n_splits}]')
    x_train, x_test = torch.FloatTensor(x[train_index]), torch.FloatTensor(x[test_index])
    y_train, y_test = torch.LongTensor(y[train_index]), torch.LongTensor(y[test_index])
    
    model = train_nn(x_train, y_train, need_pruning, split, device, relu=True)
    
    y_preds = model(x_test.to(device)).cpu().detach().numpy()
    model_accuracy = accuracy_score(y_test.cpu().detach().numpy(), y_preds.argmax(axis=1))
    print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
    
    # positive class
    target_class = 1
    start = time.time()
    global_explanation, _, _ = logic.relunn.combine_local_explanations(model, 
                                                                       x_train.to(device), y_train.to(device),
                                                                       topk_explanations=1, 
                                                                       target_class=target_class,
                                                                       method=method, device=device)
    elapsed_time = time.time() - start
    if global_explanation:
        explanation_accuracy, _ = logic.base.test_explanation(global_explanation, target_class, x_test, y_test)
        explanation = logic.base.replace_names(global_explanation, concepts)
    print(f'\t Class {target_class} - Global explanation: "{explanation}" - Accuracy: {explanation_accuracy:.4f}')
    print(f'\t Elapsed time {elapsed_time}')
        
    # negative class
    target_class = 0
    start = time.time()
    global_explanation_inv, _, _ = logic.relunn.combine_local_explanations(model, 
                                                                           x_train.to(device), y_train.to(device), 
                                                                           topk_explanations=1, 
                                                                           target_class=target_class,
                                                                           method=method, device=device)
    elapsed_time_inv = time.time() - start
    if global_explanation_inv:
        explanation_accuracy_inv, _ = logic.base.test_explanation(global_explanation_inv, target_class, x_test, y_test)
        explanation_inv = logic.base.replace_names(global_explanation_inv, concepts)
    print(f'\t Class {target_class} - Global explanation: "{explanation_inv}" - Accuracy: {explanation_accuracy_inv:.4f}')
    print(f'\t Elapsed time {elapsed_time_inv}')
    
    methods.append(method)
    splits.append(split)
    explanations.append(explanation)
    explanations_inv.append(explanation_inv)
    model_accuracies.append(model_accuracy)
    explanation_accuracies.append(explanation_accuracy)
    explanation_accuracies_inv.append(explanation_accuracy_inv)
    elapsed_times.append(elapsed_time)
    elapsed_times_inv.append(elapsed_time_inv)

Split [1/10]
	 Epoch 0: train accuracy: 0.7985
	 Epoch 500: train accuracy: 0.8411
	 Epoch 1000: train accuracy: 0.9318
	 Epoch 1500: train accuracy: 0.9368
	 Epoch 2000: train accuracy: 0.9406
	 Epoch 2500: train accuracy: 0.9380
	 Epoch 3000: train accuracy: 0.8411
	 Epoch 3500: train accuracy: 0.9099
	 Epoch 4000: train accuracy: 0.9249
	 Epoch 4500: train accuracy: 0.9249
	 Epoch 5000: train accuracy: 0.9262
	 Model's accuracy: 0.8315
	 Class 1 - Global explanation: "bun_first_LOW & ~renal_flg & ~stroke_flg" - Accuracy: 0.1629
	 Elapsed time 1.2202937602996826
	 Class 0 - Global explanation: "bun_first_LOW & ~renal_flg & ~stroke_flg" - Accuracy: 0.8371
	 Elapsed time 3.742262363433838
Split [2/10]
	 Epoch 0: train accuracy: 0.1589
	 Epoch 500: train accuracy: 0.8411
	 Epoch 1000: train accuracy: 0.9368
	 Epoch 1500: train accuracy: 0.9399
	 Epoch 2000: train accuracy: 0.8411
	 Epoch 2500: train accuracy: 0.8598
	 Epoch 3000: train accuracy: 0.8817
	 Epoch 3500: train accuracy: 0.92

In [71]:
results_weights = pd.DataFrame({
    'method': methods,
    'split': splits,
    'explanation': explanations,
    'explanation_inv': explanations_inv,
    'model_accuracy': model_accuracies,
    'explanation_accuracy': explanation_accuracies,
    'explanation_accuracy_inv': explanation_accuracies_inv,
    'elapsed_time': elapsed_times,
    'elapsed_time_inv': elapsed_times_inv,
})
results_weights.to_csv(os.path.join(results_dir, 'results_weights.csv'))
results_weights

Unnamed: 0,method,split,explanation,explanation_inv,model_accuracy,explanation_accuracy,explanation_accuracy_inv,elapsed_time,elapsed_time_inv
0,weights,0,bun_first_LOW & ~renal_flg & ~stroke_flg,bun_first_LOW & ~renal_flg & ~stroke_flg,0.831461,0.162921,0.837079,1.220294,3.742262
1,weights,1,bun_first_LOW & ~renal_flg & ~stroke_flg,bun_first_LOW,0.837079,0.162921,0.820225,0.71949,3.16263
2,weights,2,~stroke_flg & ~abg_count_NORMAL & ~hgb_first_L...,sapsi_first_LOW,0.882022,0.134831,0.488764,0.984264,4.657933
3,weights,3,bun_first_LOW & ~mal_flg & ~age_LOW & ~weight_...,bun_first_LOW & ~stroke_flg & ~hr_1st_HIGH,0.859551,0.567416,0.758427,1.043212,4.226703
4,weights,4,bun_first_LOW & ~stroke_flg & ~sapsi_first_LOW,bun_first_LOW & ~stroke_flg & ~sapsi_first_LOW,0.870787,0.455056,0.544944,1.417243,3.647221
5,weights,5,bun_first_LOW & ~stroke_flg,bun_first_LOW & ~stroke_flg,0.837079,0.146067,0.853933,0.977387,3.655289
6,weights,6,bun_first_LOW & ~stroke_flg,bun_first_LOW,0.841808,0.146067,0.779661,0.746969,3.202442
7,weights,7,bun_first_LOW & ~stroke_flg,tco2_first_NORMAL,0.841808,0.146067,0.683616,0.798894,3.619297
8,weights,8,bun_first_LOW & ~stroke_flg,sofa_first_LOW,0.841808,0.146067,0.531073,0.785931,3.9165
9,weights,9,bun_first_LOW & ~age_LOW & ~sapsi_first_LOW & ...,bun_first_LOW & sapsi_first_LOW,0.79661,0.649718,0.429379,1.065148,4.268564


# Psi network

In [72]:
def train_psi_nn(x_train, y_train, need_pruning, seed, device):
    set_seed(seed)
    x_train = x_train.to(device)
    y_train = y_train.to(device).to(torch.float)
    layers = [
        torch.nn.Linear(x_train.size(1), 10),
        torch.nn.Sigmoid(),
        torch.nn.Linear(10, 4),
        torch.nn.Sigmoid(),
        torch.nn.Linear(4, 1),
        torch.nn.Sigmoid(),
    ]
    model = torch.nn.Sequential(*layers).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_form = torch.nn.BCELoss()
    model.train()
    for epoch in range(tot_epochs):
        # forward pass
        optimizer.zero_grad()
        y_pred = model(x_train).squeeze()
        # Compute Loss
        loss = loss_form(y_pred, y_train)

        for module in model.children():
            if isinstance(module, torch.nn.Linear):
                loss += 0.00001 * torch.norm(module.weight, 1)

        # backward pass
        loss.backward()
        optimizer.step()

        if epoch > prune_epochs and need_pruning:
            model = prune_equal_fanin(model, 2, validate=True, device=device)
            need_pruning = False
            
        # compute accuracy
        if epoch % 500 == 0:
            y_pred_d = y_pred > 0.5
            accuracy = y_pred_d.eq(y_train).sum().item() / y_train.size(0)
            print(f'\t Epoch {epoch}: train accuracy: {accuracy:.4f}')
    
    return model

In [73]:
need_pruning = True
method = 'psi'
methods = []
splits = []
explanations = []
explanations_inv = []
model_accuracies = []
explanation_accuracies = []
explanation_accuracies_inv = []
elapsed_times = []
elapsed_times_inv = []

for split, (train_index, test_index) in enumerate(skf.split(x.cpu().detach().numpy(), y.cpu().detach().numpy())):
    print(f'Split [{split+1}/{n_splits}]')
    x_train, x_test = torch.FloatTensor(x[train_index]), torch.FloatTensor(x[test_index])
    y_train, y_test = torch.LongTensor(y[train_index]), torch.LongTensor(y[test_index])
    
    # positive class
    target_class = 1
    model = train_psi_nn(x_train, y_train, need_pruning, split, device)
    
    y_preds = model(x_test.to(device)).cpu().detach().numpy()
    model_accuracy = accuracy_score(y_test.cpu().detach().numpy(), y_preds > 0.5)
    print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
    
    start = time.time()
    global_explanation = logic.generate_fol_explanations(model, device)[0]
    elapsed_time = time.time() - start
    if global_explanation not in ['(False)', '(True)', '']:
        explanation_accuracy, _ = logic.base.test_explanation(global_explanation, target_class, x_test, y_test)
        explanation = logic.base.replace_names(global_explanation, concepts)
        print(f'\t Class {target_class} - Global explanation: "{explanation}" - Accuracy: {explanation_accuracy:.4f}')
        print(f'\t Elapsed time {elapsed_time}')
    else:
        explanation_accuracy = 0
        explanation = global_explanation
        
    # negative class
    target_class = 0
    model = train_psi_nn(x_train, y_train.eq(target_class), need_pruning, split, device)
    
    y_preds = model(x_test.to(device)).cpu().detach().numpy()
    model_accuracy = accuracy_score(y_test.eq(target_class).cpu().detach().numpy(), y_preds > 0.5)
    print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
    
    start = time.time()
    global_explanation_inv = logic.generate_fol_explanations(model, device)[0]
    elapsed_time_inv = time.time() - start
    if global_explanation_inv not in ['(False)', '(True)', '']:
        explanation_accuracy_inv, _ = logic.base.test_explanation(global_explanation_inv, 
                                                                  target_class, x_test, y_test)
        explanation_inv = logic.base.replace_names(global_explanation_inv, concepts)
        print(f'\t Class {target_class} - Global explanation: "{explanation_inv}" - Accuracy: {explanation_accuracy_inv:.4f}')
        print(f'\t Elapsed time {elapsed_time_inv}')
    else:
        explanation_accuracy_inv = 0
        explanation_inv = global_explanation_inv
    
    methods.append(method)
    splits.append(split)
    explanations.append(explanation)
    explanations_inv.append(explanation_inv)
    model_accuracies.append(model_accuracy)
    explanation_accuracies.append(explanation_accuracy)
    explanation_accuracies_inv.append(explanation_accuracy_inv)
    elapsed_times.append(elapsed_time)
    elapsed_times_inv.append(elapsed_time_inv)

Split [1/10]
	 Epoch 0: train accuracy: 0.1589
	 Epoch 500: train accuracy: 0.8411
	 Epoch 1000: train accuracy: 0.8411
	 Epoch 1500: train accuracy: 0.8411
	 Epoch 2000: train accuracy: 0.8723
	 Epoch 2500: train accuracy: 0.8411
	 Epoch 3000: train accuracy: 0.8411
	 Epoch 3500: train accuracy: 0.8411
	 Epoch 4000: train accuracy: 0.8411
	 Epoch 4500: train accuracy: 0.8411
	 Epoch 5000: train accuracy: 0.8404
	 Model's accuracy: 0.8371
	 Class 1 - Global explanation: "(copd_flg | stroke_flg)" - Accuracy: 0.7584
	 Elapsed time 0.0359339714050293
	 Epoch 0: train accuracy: 0.8411
	 Epoch 500: train accuracy: 0.8411
	 Epoch 1000: train accuracy: 0.8411
	 Epoch 1500: train accuracy: 0.8773
	 Epoch 2000: train accuracy: 0.9099
	 Epoch 2500: train accuracy: 0.8411
	 Epoch 3000: train accuracy: 0.8411
	 Epoch 3500: train accuracy: 0.8411
	 Epoch 4000: train accuracy: 0.8411
	 Epoch 4500: train accuracy: 0.8411
	 Epoch 5000: train accuracy: 0.8411
	 Model's accuracy: 0.8371
	 Class 0 - Glob

	 Epoch 4500: train accuracy: 0.8405
	 Epoch 5000: train accuracy: 0.8405
	 Model's accuracy: 0.8418
	 Class 1 - Global explanation: "(stroke_flg | sapsi_first_HIGH)" - Accuracy: 0.8531
	 Elapsed time 0.03194618225097656
	 Epoch 0: train accuracy: 0.8405
	 Epoch 500: train accuracy: 0.8405
	 Epoch 1000: train accuracy: 0.8405
	 Epoch 1500: train accuracy: 0.8856
	 Epoch 2000: train accuracy: 0.9250
	 Epoch 2500: train accuracy: 0.8405
	 Epoch 3000: train accuracy: 0.8405
	 Epoch 3500: train accuracy: 0.8405
	 Epoch 4000: train accuracy: 0.8405
	 Epoch 4500: train accuracy: 0.8405
	 Epoch 5000: train accuracy: 0.8405
	 Model's accuracy: 0.8418
Split [9/10]
	 Epoch 0: train accuracy: 0.8405
	 Epoch 500: train accuracy: 0.8405
	 Epoch 1000: train accuracy: 0.8405
	 Epoch 1500: train accuracy: 0.8768
	 Epoch 2000: train accuracy: 0.9131
	 Epoch 2500: train accuracy: 0.8405
	 Epoch 3000: train accuracy: 0.8405
	 Epoch 3500: train accuracy: 0.8405
	 Epoch 4000: train accuracy: 0.8405
	 Epoch

In [74]:
results_psi = pd.DataFrame({
    'method': methods,
    'split': splits,
    'explanation': explanations,
    'explanation_inv': explanations_inv,
    'model_accuracy': model_accuracies,
    'explanation_accuracy': explanation_accuracies,
    'explanation_accuracy_inv': explanation_accuracies_inv,
    'elapsed_time': elapsed_times,
    'elapsed_time_inv': elapsed_times_inv,
})
results_psi.to_csv(os.path.join(results_dir, 'results_psi.csv'))
results_psi

Unnamed: 0,method,split,explanation,explanation_inv,model_accuracy,explanation_accuracy,explanation_accuracy_inv,elapsed_time,elapsed_time_inv
0,psi,0,(copd_flg | stroke_flg),(~copd_flg & ~sapsi_first_HIGH),0.837079,0.758427,0.775281,0.035934,0.044879
1,psi,1,(stroke_flg),(~abg_count_HIGH & ~bun_first_NORMAL & (~strok...,0.831461,0.825843,0.820225,0.01898,0.045846
2,psi,2,(stroke_flg & ~sapsi_first_LOW),(True),0.837079,0.882022,0.0,0.020974,0.019979
3,psi,3,(stroke_flg | sapsi_first_HIGH),(~stroke_flg & (~copd_flg | ~bun_first_NORMAL)...,0.837079,0.797753,0.842697,0.025961,0.031879
4,psi,4,(stroke_flg | sofa_first_HIGH),(~stroke_flg & ~bun_first_NORMAL),0.870787,0.865169,0.842697,0.041893,0.025967
5,psi,5,(stroke_flg | sapsi_first_HIGH | (renal_flg & ...,(~copd_flg & ~stroke_flg & ~bun_first_HIGH & ~...,0.842697,0.859551,0.808989,0.037923,0.035934
6,psi,6,(stroke_flg | sapsi_first_HIGH),(~sapsi_first_HIGH & (~copd_flg | ~bun_first_N...,0.836158,0.841808,0.824859,0.023937,0.029918
7,psi,7,(stroke_flg | sapsi_first_HIGH),(True),0.841808,0.853107,0.0,0.031946,0.037929
8,psi,8,(False),(~stroke_flg & (~copd_flg | ~sapsi_first_HIGH)),0.830508,0.0,0.841808,0.027956,0.043883
9,psi,9,(copd_flg | stroke_flg | sapsi_first_HIGH),(map_1st_HIGH | (~stroke_flg & ~bun_first_NORM...,0.847458,0.79661,0.80791,0.018952,0.051887


# Decision tree

In [75]:
need_pruning = False
method = 'decision_tree'
methods = []
splits = []
explanations = []
explanations_inv = []
model_accuracies = []
explanation_accuracies = []
explanation_accuracies_inv = []
elapsed_times = []
elapsed_times_inv = []

for split, (train_index, test_index) in enumerate(skf.split(x.cpu().detach().numpy(), y.cpu().detach().numpy())):
    print(f'Split [{split+1}/{n_splits}]')
    x_train, x_test = x[train_index], x[test_index]
    y_train, y_test = y[train_index], y[test_index]
    
    classifier = DecisionTreeClassifier(random_state=split)
    classifier.fit(x_train.cpu().detach().numpy(), y_train.cpu().detach().numpy())
    y_preds = classifier.predict(x_test.cpu().detach().numpy())
    model_accuracy = accuracy_score(y_test.cpu().detach().numpy(), y_preds)
    print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
    
    target_class = 1
    start = time.time()
    explanation = tree_to_formula(classifier, concepts, target_class)
    elapsed_time = time.time() - start
    print(f'\t Class {target_class} - Global explanation: {explanation}')
    print(f'\t Elapsed time {elapsed_time}')
    
    target_class = 0
    start = time.time()
    explanation_inv = tree_to_formula(classifier, concepts, target_class)
    elapsed_time_inv = time.time() - start
    print(f'\t Class {target_class} - Global explanation: {explanation_inv}')
    print(f'\t Elapsed time {elapsed_time_inv}')
    
    methods.append(method)
    splits.append(split)
    explanations.append(explanation)
    explanations_inv.append(explanation_inv)
    model_accuracies.append(model_accuracy)
    explanation_accuracies.append(model_accuracy)
    explanation_accuracies_inv.append(model_accuracy)
    elapsed_times.append(0)
    elapsed_times_inv.append(0)

Split [1/10]
	 Model's accuracy: 0.8034
	 Class 1 - Global explanation: (age_HIGH <= 0.50 & stroke_flg <= 0.50 & resp_flg <= 0.50 & chf_flg <= 0.50 & wbc_first_NORMAL <= 0.50 & liver_flg <= 0.50 & sapsi_first_NORMAL <= 0.50 & copd_flg > 0.50 & service_num_LOW > 0.50 & potassium_first_NORMAL > 0.50) | (age_HIGH <= 0.50 & stroke_flg <= 0.50 & resp_flg <= 0.50 & chf_flg <= 0.50 & wbc_first_NORMAL <= 0.50 & liver_flg <= 0.50 & sapsi_first_NORMAL > 0.50 & po2_first_NORMAL <= 0.50 & icu_los_day_NORMAL <= 0.50 & hr_1st_HIGH <= 0.50 & potassium_first_NORMAL > 0.50 & aline_flg <= 0.50 & sofa_first_LOW <= 0.50 & service_num_HIGH > 0.50) | (age_HIGH <= 0.50 & stroke_flg <= 0.50 & resp_flg <= 0.50 & chf_flg <= 0.50 & wbc_first_NORMAL <= 0.50 & liver_flg <= 0.50 & sapsi_first_NORMAL > 0.50 & po2_first_NORMAL <= 0.50 & icu_los_day_NORMAL <= 0.50 & hr_1st_HIGH > 0.50 & mal_flg > 0.50 & service_num_HIGH > 0.50) | (age_HIGH <= 0.50 & stroke_flg <= 0.50 & resp_flg <= 0.50 & chf_flg <= 0.50 & wbc_first_N

	 Class 1 - Global explanation: (age_HIGH <= 0.50 & stroke_flg <= 0.50 & age_LOW <= 0.50 & resp_flg <= 0.50 & chf_flg <= 0.50 & wbc_first_NORMAL <= 0.50 & hr_1st_HIGH <= 0.50 & liver_flg <= 0.50 & sapsi_first_NORMAL > 0.50 & po2_first_NORMAL <= 0.50 & hour_icu_intime_LOW > 0.50 & potassium_first_LOW <= 0.50 & day_icu_intime_num_HIGH > 0.50 & tco2_first_LOW <= 0.50) | (age_HIGH <= 0.50 & stroke_flg <= 0.50 & age_LOW <= 0.50 & resp_flg <= 0.50 & chf_flg <= 0.50 & wbc_first_NORMAL <= 0.50 & hr_1st_HIGH <= 0.50 & liver_flg <= 0.50 & sapsi_first_NORMAL > 0.50 & po2_first_NORMAL <= 0.50 & hour_icu_intime_LOW > 0.50 & potassium_first_LOW > 0.50 & day_icu_intime_num_NORMAL > 0.50 & service_num_LOW <= 0.50 & gender_num > 0.50 & sofa_first_LOW <= 0.50 & hgb_first_NORMAL > 0.50 & tco2_first_LOW <= 0.50 & iv_day_1_LOW > 0.50) | (age_HIGH <= 0.50 & stroke_flg <= 0.50 & age_LOW <= 0.50 & resp_flg <= 0.50 & chf_flg <= 0.50 & wbc_first_NORMAL <= 0.50 & hr_1st_HIGH <= 0.50 & liver_flg <= 0.50 & sapsi_f

	 Class 1 - Global explanation: (stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & age_LOW <= 0.50 & resp_flg <= 0.50 & chf_flg <= 0.50 & wbc_first_NORMAL <= 0.50 & hgb_first_LOW <= 0.50 & sofa_first_NORMAL <= 0.50 & copd_flg <= 0.50 & day_icu_intime_num_LOW > 0.50 & liver_flg <= 0.50 & hour_icu_intime_NORMAL > 0.50 & hr_1st_LOW > 0.50 & aline_flg > 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & age_LOW <= 0.50 & resp_flg <= 0.50 & chf_flg <= 0.50 & wbc_first_NORMAL <= 0.50 & hgb_first_LOW <= 0.50 & sofa_first_NORMAL <= 0.50 & copd_flg <= 0.50 & day_icu_intime_num_LOW > 0.50 & liver_flg > 0.50 & service_num_HIGH <= 0.50 & sofa_first_LOW > 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & age_LOW <= 0.50 & resp_flg <= 0.50 & chf_flg <= 0.50 & wbc_first_NORMAL <= 0.50 & hgb_first_LOW <= 0.50 & sofa_first_NORMAL <= 0.50 & copd_flg > 0.50 & day_icu_intime_num_NORMAL > 0.50 & hgb_first_NORMAL <= 0.50 & potassium_first_N

In [76]:
results_tree = pd.DataFrame({
    'method': methods,
    'split': splits,
    'explanation': explanations,
    'explanation_inv': explanations_inv,
    'model_accuracy': model_accuracies,
    'explanation_accuracy': explanation_accuracies,
    'explanation_accuracy_inv': explanation_accuracies_inv,
    'elapsed_time': elapsed_times,
    'elapsed_time_inv': elapsed_times_inv,
})
results_tree.to_csv(os.path.join(results_dir, 'results_tree.csv'))
results_tree

Unnamed: 0,method,split,explanation,explanation_inv,model_accuracy,explanation_accuracy,explanation_accuracy_inv,elapsed_time,elapsed_time_inv
0,decision_tree,0,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & resp_...,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & resp_...,0.803371,0.803371,0.803371,0,0
1,decision_tree,1,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi...,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi...,0.803371,0.803371,0.803371,0,0
2,decision_tree,2,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & resp_...,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & resp_...,0.859551,0.859551,0.859551,0,0
3,decision_tree,3,(stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi...,(stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi...,0.797753,0.797753,0.797753,0,0
4,decision_tree,4,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & age_L...,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & age_L...,0.808989,0.808989,0.808989,0,0
5,decision_tree,5,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & resp_...,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & resp_...,0.814607,0.814607,0.814607,0,0
6,decision_tree,6,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi...,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi...,0.745763,0.745763,0.745763,0,0
7,decision_tree,7,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi...,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi...,0.836158,0.836158,0.836158,0,0
8,decision_tree,8,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & resp_...,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & resp_...,0.779661,0.779661,0.779661,0,0
9,decision_tree,9,(stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi...,(stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi...,0.779661,0.779661,0.779661,0,0


# Summary

In [77]:
cols = ['model_accuracy', 'explanation_accuracy', 'explanation_accuracy_inv', 'elapsed_time', 'elapsed_time_inv']
mean_cols = [f'{c}_mean' for c in cols]
sem_cols = [f'{c}_sem' for c in cols]

# pruning
df_mean = results_pruning[cols].mean()
df_sem = results_pruning[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_pruning = pd.concat([df_mean, df_sem])
summary_pruning.name = 'pruning'

# # lime
# df_mean = results_lime[cols].mean()
# df_sem = results_lime[cols].sem()
# df_mean.columns = mean_cols
# df_sem.columns = sem_cols
# summary_lime = pd.concat([df_mean, df_sem])
# summary_lime.name = 'lime'

# weights
df_mean = results_weights[cols].mean()
df_sem = results_weights[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_weights = pd.concat([df_mean, df_sem])
summary_weights.name = 'weights'

# psi
df_mean = results_psi[cols].mean()
df_sem = results_psi[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_psi = pd.concat([df_mean, df_sem])
summary_psi.name = 'psi'

# tree
df_mean = results_tree[cols].mean()
df_sem = results_tree[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_tree = pd.concat([df_mean, df_sem])
summary_tree.name = 'tree'

summary = pd.concat([summary_pruning, 
#                      summary_lime, 
                     summary_weights, 
                     summary_psi, 
                     summary_tree], axis=1).T
summary.columns = mean_cols + sem_cols
summary

Unnamed: 0,model_accuracy_mean,explanation_accuracy_mean,explanation_accuracy_inv_mean,elapsed_time_mean,elapsed_time_inv_mean,model_accuracy_sem,explanation_accuracy_sem,explanation_accuracy_inv_sem,elapsed_time_sem,elapsed_time_inv_sem
pruning,0.859804,0.859808,0.71879,0.642667,2.541792,0.005886,0.004962,0.047113,0.117541,0.395207
weights,0.844001,0.271713,0.67271,0.975883,3.809884,0.007392,0.064074,0.050548,0.070919,0.148842
psi,0.841211,0.748029,0.656446,0.028446,0.03681,0.003656,0.083964,0.1096,0.002585,0.00316
tree,0.802888,0.802888,0.802888,0.0,0.0,0.009923,0.009923,0.009923,0.0,0.0


In [78]:
summary.to_csv(os.path.join(results_dir, 'summary.csv'))