In [1]:
import sys
sys.path.append('..')
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder, MinMaxScaler, KBinsDiscretizer
from deep_logic.models.brl import XBRLClassifier
from sklearn.impute import SimpleImputer
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import TensorDataset, DataLoader
import time
from sklearn.metrics import accuracy_score
from sympy import simplify_logic
from sklearn.model_selection import train_test_split

from deep_logic.utils.base import validate_network, set_seed, tree_to_formula
from deep_logic.utils.relu_nn import get_reduced_model, prune_features
from deep_logic.utils.psi_nn import prune_equal_fanin
from deep_logic import logic
import deep_logic as dl

set_seed(0)

#%%

data = pd.read_csv('data/mimic-ii/full_cohort_data.csv')
# data.drop('hgb_first')
fs = [
    'aline_flg',
    'gender_num',
    # 'hosp_exp_flg',
    # 'icu_exp_flg',
    # 'day_28_flg',
    # 'censor_flg',
    'sepsis_flg', 'chf_flg', 'afib_flg',
    'renal_flg', 'liver_flg', 'copd_flg', 'cad_flg', 'stroke_flg',
    'mal_flg', 'resp_flg',
]
features = fs
data1 = data[fs].values
imp_mean = SimpleImputer(missing_values=np.nan, strategy='mean')
data1 = imp_mean.fit_transform(data1)

f2 = fs.copy()
f2.append('day_icu_intime')
f2.append('service_unit')
f2.append('day_28_flg')
f2.append('hospital_los_day')
f2.append('icu_exp_flg')
f2.append('hosp_exp_flg')
f2.append('censor_flg')
f2.append('mort_day_censored')
f2 = data.columns.difference(f2)
data2 = data[f2].values
imp_mean = SimpleImputer(missing_values=np.nan, strategy='mean')
data2 = imp_mean.fit_transform(data2)
scaler = MinMaxScaler((0, 1))
data2 = scaler.fit_transform(data2)
features = features + list(f2)
est = KBinsDiscretizer(n_bins=3, encode='onehot-dense', strategy='uniform')
data2d = est.fit_transform(data2)
f2d = []
for feature in f2:
    #f2d.append(feature + '_VLOW')
    f2d.append(feature + '_LOW')
    f2d.append(feature + '_NORMAL')
    f2d.append(feature + '_HIGH')
    #f2d.append(feature + '_VHIGH')
features = fs + f2d

datax = np.hstack((data1, data2d))
datay = data['day_28_flg'].values

x = torch.FloatTensor(datax)
y = torch.LongTensor(datay)
print(x.shape)
print(y.shape)


torch.Size([1776, 90])
torch.Size([1776])


In [2]:
results_dir = 'results/mimic_explain_classifier'
if not os.path.isdir(results_dir):
    os.makedirs(results_dir)


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# concepts = [f'c{i:03}' for i in range(x.shape[1])]
concepts =  features
n_rep = 10
tot_epochs = 5001
prune_epochs = 2001
seed = 42

n_splits = 10
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)

In [3]:
def train_nn(x_train, y_train, need_pruning, seed, device, l1=0.001, lr=0.001, relu=False, verbose=False):
    set_seed(seed)
    x_train = x_train.to(device)
    y_train = y_train.to(device)
    layers = [
        torch.nn.Linear(x_train.size(1), 100),
        torch.nn.ReLU() if relu else torch.nn.LeakyReLU(),
        torch.nn.Linear(100, 10),
        torch.nn.ReLU() if relu else torch.nn.LeakyReLU(),
        torch.nn.Linear(10, 2),
        torch.nn.Softmax(dim=1),
    ]
    model = torch.nn.Sequential(*layers).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_form = torch.nn.CrossEntropyLoss()
    model.train()
    for epoch in range(tot_epochs):
        # forward pass
        optimizer.zero_grad()
        y_pred = model(x_train)
        # Compute Loss
        loss = loss_form(y_pred, y_train)
        
        if need_pruning or relu:
            for module in model.children():
                if isinstance(module, torch.nn.Linear):
                    loss += l1 * torch.norm(module.weight, 1)
                    loss += l1 * torch.norm(module.bias, 1)
                    break

        # backward pass
        loss.backward()
        optimizer.step()

        if epoch > prune_epochs and need_pruning and not relu and epoch % 1000 == 1:
            prune_features(model, n_classes=1, device=device)
            need_pruning = True
            
        # compute accuracy
        if epoch % 500 == 0 and verbose:
            y_pred_d = torch.argmax(y_pred, dim=1)
            accuracy = y_pred_d.eq(y_train).sum().item() / y_train.size(0)
            print(f'\t Epoch {epoch}: train accuracy: {accuracy:.4f}')
    
    return model

In [34]:
def train_psi_nn(x_train, y_train, need_pruning, seed, device, verbose=False):
    set_seed(seed)
    x_train = x_train.to(device)
    y_train = y_train.to(device).to(torch.float)
    layers = [
        torch.nn.Linear(x_train.size(1), 5),
        torch.nn.Sigmoid(),
        torch.nn.Linear(5, 5),
        torch.nn.Sigmoid(),
        torch.nn.Linear(5, 1),
        torch.nn.Sigmoid(),
    ]
    model = torch.nn.Sequential(*layers).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    loss_form = torch.nn.BCELoss()
    model.train()
    for epoch in range(tot_epochs):
        # forward pass
        optimizer.zero_grad()
        y_pred = model(x_train).squeeze()
        # Compute Loss
        loss = loss_form(y_pred, y_train)

        for module in model.children():
            if isinstance(module, torch.nn.Linear):
                loss += 0.001 * torch.norm(module.weight, 1)

        # backward pass
        loss.backward()
        optimizer.step()

        if epoch > prune_epochs and need_pruning:
            model = prune_equal_fanin(model, 2, validate=True, device=device)
            need_pruning = False
            
        # compute accuracy
        if epoch % 500 == 0 and verbose:
            y_pred_d = y_pred > 0.5
            accuracy = y_pred_d.eq(y_train).sum().item() / y_train.size(0)
            print(f'\t Epoch {epoch}: train accuracy: {accuracy:.4f}')
    
    return model

In [40]:
def c_to_y(method, need_pruning, l1=0.0001, lr=0.0001, relu=False, verbose=False):
    methods = []
    splits = []
    explanations = []
    explanations_inv = []
    model_accuracies = []
    explanation_accuracies = []
    explanation_accuracies_inv = []
    elapsed_times = []
    elapsed_times_inv = []
    explanation_fidelities = []
    explanation_complexities = []
    for split, (trainval_index, test_index) in enumerate(skf.split(x.cpu().detach().numpy(), y.cpu().detach().numpy())):
        print(f'Split [{split+1}/{n_splits}]')
        x_trainval, x_test = torch.FloatTensor(x[trainval_index]), torch.FloatTensor(x[test_index])
        y_trainval, y_test = torch.LongTensor(y[trainval_index]), torch.LongTensor(y[test_index])
        x_train, x_val, y_train, y_val = train_test_split(x_trainval, y_trainval, test_size=0.3, random_state=42)
        
        model = train_nn(x_trainval, y_trainval, False, seed, device, l1=0, lr=0.001, relu=False, verbose=False)
        y_preds_trainval = model(x_trainval.to(device)).cpu().argmax(dim=1)
        y_preds_train = model(x_train.to(device)).cpu().argmax(dim=1)
        y_preds_val = model(x_val.to(device)).cpu().argmax(dim=1)
        y_preds_test = model(x_test.to(device)).cpu().argmax(dim=1)
    
        explanation, explanation_inv = '', ''
        explanation_accuracy, explanation_accuracy_inv = 0, 0
        
        if method == 'tree':
            classifier = DecisionTreeClassifier(random_state=seed)
            classifier.fit(x_trainval.cpu().detach().numpy(), y_preds_trainval.cpu().detach().numpy())
            y_preds = classifier.predict(x_test.cpu().detach().numpy())
            model_accuracy = accuracy_score(y_preds, y_test)

            target_class = 1
            start = time.time()
            explanation = tree_to_formula(classifier, concepts, target_class)
            elapsed_time = time.time() - start
            explanation_accuracy = model_accuracy
            explanation_fidelity = accuracy_score(y_preds_test.cpu().detach().numpy(), y_preds)
            explanation_complexity = dl.logic.complexity(explanation)
        
        elif method == 'brl':
            y_preds_trainval2 = torch.zeros((y_preds_trainval.shape[0], 2))
            print(y_preds_trainval2.shape)
            y_preds_trainval2[:, 0] = 1-y_preds_trainval
            y_preds_trainval2[:, 1] = y_preds_trainval
            train_dataset = TensorDataset(x_trainval, y_preds_trainval2)
            y_preds_val2 = torch.zeros((y_preds_val.shape[0], 2))
            print(y_preds_val2.shape)
            y_preds_val2[:, 0] = 1-y_preds_val
            y_preds_val2[:, 1] = y_preds_val
            val_dataset = TensorDataset(x_val, y_preds_val2)
            test_dataset = TensorDataset(x_test, y_preds_test)
            dummy_concepts = [f'feature{i:010}' for i in range(len(concepts))]
            model = XBRLClassifier(name=os.path.join(results_dir, f'{method}_{split}'), 
                                   n_classes=len(y_train.unique()), 
                                   n_features=x_train.shape[1], feature_names=dummy_concepts, 
                                   class_names=['~day_28_flg', 'day_28_flg'], discretize=True)
            results = model.fit(train_dataset, metric=accuracy_score, save=True)
            model_accuracy = model.evaluate(test_dataset)
            target_class = 1
            global_explanation, elapsed_time = model.get_global_explanation(target_class, concepts=dummy_concepts, return_time=True)
            explanation_fidelity = 0
            explanation_complexity = 0
            if global_explanation:
                explanation = logic.base.replace_names(global_explanation, concepts)
                explanation_accuracy, y_formula = logic.base.test_explanation(global_explanation, 
                                                                              target_class, 
                                                                              x=x_test, y=y_test,
                                                                              metric=accuracy_score)
                explanation_fidelity = dl.logic.fidelity(y_formula, y_preds_test)
                explanation_complexity = dl.logic.complexity(global_explanation)
                
        else:
            if method == 'psi':
                # positive class
                target_class = 1
                model = train_psi_nn(x_trainval, y_preds_trainval.eq(target_class), need_pruning, seed, device, verbose)
                y_preds = model(x_test.to(device)).cpu().detach().numpy()>0.5
                model_accuracy = accuracy_score(y_preds_test.cpu().detach().numpy(), y_preds)
                
            else:
                model = train_nn(x_trainval, y_preds_trainval, need_pruning, seed, device, l1, lr, relu, verbose)
                y_preds = model(x_test.to(device)).cpu().detach().numpy().argmax(axis=1)
                model_accuracy = accuracy_score(y_preds_test.cpu().detach().numpy(), y_preds)

            # positive class
            target_class = 1
            start = time.time()
            if method == 'psi':
                global_explanation = logic.generate_fol_explanations(model, device)[0]
                print(global_explanation)
            else:
                global_explanation, _, _ = logic.relu_nn.combine_local_explanations(model, 
                                                                                   x_val.to(device), 
                                                                                   y_preds_val.to(device), 
                                                                                   topk_explanations=2,
                                                                                   target_class=target_class,
                                                                                   method=method, device=device)
            elapsed_time = time.time() - start
            
            explanation_fidelity = 0
            explanation_complexity = 0
            if global_explanation:
                explanation = logic.base.replace_names(global_explanation, concepts)
                explanation_accuracy, y_formula = logic.base.test_explanation(global_explanation, 
                                                                              target_class, 
                                                                              x=x_test, y=y_test,
                                                                              metric=accuracy_score)
                explanation_fidelity = dl.logic.fidelity(y_formula, y_preds_test)
                explanation_complexity = dl.logic.complexity(global_explanation)

        if verbose:
            print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
            print(f'\t Class {target_class} - Global explanation: "{explanation}" - Accuracy: {explanation_accuracy:.4f}')
            print(f'\t Elapsed time {elapsed_time}')
            print(f'\t Fidelity: "{explanation_fidelity:.4f}" - Complexity: "{explanation_complexity}"')

        methods.append(method)
        splits.append(split)
        explanations.append(explanation)
        model_accuracies.append(model_accuracy)
        explanation_accuracies.append(explanation_accuracy)
        elapsed_times.append(elapsed_time)
        explanation_fidelities.append(explanation_fidelity)
        explanation_complexities.append(explanation_complexity)
    
    explanation_consistency = dl.logic.formula_consistency(explanations)
    print(f'Consistency of explanations: {explanation_consistency:.4f}')
    
    results = pd.DataFrame({
        'method': methods,
        'split': splits,
        'explanation': explanations,
        'model_accuracy': model_accuracies,
        'explanation_accuracy': explanation_accuracies,
        'explanation_fidelity': explanation_fidelities,
        'explanation_complexity': explanation_complexities,
        'explanation_consistency': explanation_consistency,
        'elapsed_time': elapsed_times,
    })
    results.to_csv(os.path.join(results_dir, f'results_{method}.csv'))
    
    return results

# General pruning

In [6]:
method = 'pruning'
need_pruning = True
relu = False
results_pruning = c_to_y(method, need_pruning, relu=relu)
results_pruning

Split [1/10]
Split [2/10]
Split [3/10]
Split [4/10]
Split [5/10]
Split [6/10]
Split [7/10]
Split [8/10]
Split [9/10]
Split [10/10]
Consistency of explanations: 0.3167


Unnamed: 0,method,split,explanation,model_accuracy,explanation_accuracy,explanation_fidelity,explanation_complexity,explanation_consistency,elapsed_time
0,pruning,0,stroke_flg,0.91573,0.842697,0.91573,1,0.316667,0.062833
1,pruning,1,afib_flg & stroke_flg,0.88764,0.848315,0.88764,2,0.316667,0.023936
2,pruning,2,stroke_flg & age_HIGH & ~sapsi_first_LOW,0.88764,0.876404,0.88764,3,0.316667,0.053856
3,pruning,3,,0.898876,0.0,0.0,0,0.316667,0.003024
4,pruning,4,stroke_flg,0.865169,0.848315,0.865169,1,0.316667,0.05286
5,pruning,5,stroke_flg,0.882022,0.831461,0.882022,1,0.316667,0.030914
6,pruning,6,stroke_flg & ~sapsi_first_LOW,0.898305,0.830508,0.898305,2,0.316667,0.044881
7,pruning,7,stroke_flg & age_HIGH & ~sapsi_first_LOW & ~we...,0.943503,0.875706,0.943503,4,0.316667,0.035904
8,pruning,8,stroke_flg & ~weight_first_NORMAL,0.903955,0.841808,0.903955,2,0.316667,0.050864
9,pruning,9,stroke_flg & age_HIGH,0.920904,0.870056,0.920904,2,0.316667,0.034907


# ReLUs

In [28]:
method = 'weights'
need_pruning = False
relu = True
results_weights = c_to_y(method='weights', need_pruning=False, l1=0.00001, lr=0.0005, relu=True, verbose=True)
results_weights

Split [1/10]


KeyboardInterrupt: 

# Psi network

In [35]:
method = 'psi'
need_pruning = True
relu = False
results_psi = c_to_y(method, need_pruning, l1=0.0001, lr=0.0001, relu=False, verbose=True)
results_psi

Split [1/10]
	 Epoch 0: train accuracy: 0.1176
	 Epoch 500: train accuracy: 0.9380
	 Epoch 1000: train accuracy: 0.9406
	 Epoch 1500: train accuracy: 0.9380
	 Epoch 2000: train accuracy: 0.9355
	 Epoch 2500: train accuracy: 0.8824
	 Epoch 3000: train accuracy: 0.8824
	 Epoch 3500: train accuracy: 0.8824
	 Epoch 4000: train accuracy: 0.8824
	 Epoch 4500: train accuracy: 0.8824
	 Epoch 5000: train accuracy: 0.8824
(feature0000000009)
	 Model's accuracy: 0.9101
	 Class 1 - Global explanation: "(stroke_flg)" - Accuracy: 0.8427
	 Elapsed time 0.00501704216003418
	 Fidelity: "0.9157" - Complexity: "1"
Split [2/10]
	 Epoch 0: train accuracy: 0.1214
	 Epoch 500: train accuracy: 0.9337
	 Epoch 1000: train accuracy: 0.9305
	 Epoch 1500: train accuracy: 0.9387
	 Epoch 2000: train accuracy: 0.9368
	 Epoch 2500: train accuracy: 0.8867
	 Epoch 3000: train accuracy: 0.8861
	 Epoch 3500: train accuracy: 0.8861
	 Epoch 4000: train accuracy: 0.8861
	 Epoch 4500: train accuracy: 0.8861
	 Epoch 5000: trai

Unnamed: 0,method,split,explanation,model_accuracy,explanation_accuracy,explanation_fidelity,explanation_complexity,explanation_consistency,elapsed_time
0,psi,0,(stroke_flg),0.910112,0.842697,0.91573,1,0.325,0.005017
1,psi,1,(stroke_flg),0.853933,0.825843,0.853933,1,0.325,0.004986
2,psi,2,(~sapsi_first_LOW),0.898876,0.488764,0.477528,1,0.325,0.00402
3,psi,3,(stroke_flg & ~sapsi_first_LOW),0.91573,0.859551,0.91573,2,0.325,0.004988
4,psi,4,(stroke_flg | abg_count_HIGH),0.865169,0.848315,0.865169,2,0.325,0.008008
5,psi,5,(stroke_flg),0.870787,0.831461,0.882022,1,0.325,0.003989
6,psi,6,(~sapsi_first_LOW),0.830508,0.463277,0.485876,1,0.325,0.004986
7,psi,7,(stroke_flg),0.920904,0.858757,0.915254,1,0.325,0.003989
8,psi,8,(stroke_flg & ~weight_first_NORMAL),0.903955,0.841808,0.903955,2,0.325,0.004987
9,psi,9,(stroke_flg),0.870056,0.881356,0.909605,1,0.325,0.003991


# Decision tree

In [41]:
method = 'tree'
need_pruning = False
relu = False
results_tree = c_to_y(method, need_pruning, relu, verbose=True)
results_tree

Split [1/10]
	 Model's accuracy: 0.8483
	 Class 1 - Global explanation: "(stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & platelet_first_HIGH <= 0.50 & chloride_first_LOW <= 0.50 & bun_first_NORMAL <= 0.50 & hour_icu_intime_NORMAL > 0.50 & chloride_first_NORMAL <= 0.50 & hr_1st_LOW > 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & platelet_first_HIGH <= 0.50 & chloride_first_LOW <= 0.50 & bun_first_NORMAL <= 0.50 & hour_icu_intime_NORMAL > 0.50 & chloride_first_NORMAL > 0.50 & sofa_first_HIGH <= 0.50 & mal_flg <= 0.50 & afib_flg > 0.50 & aline_flg > 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & platelet_first_HIGH <= 0.50 & chloride_first_LOW <= 0.50 & bun_first_NORMAL <= 0.50 & hour_icu_intime_NORMAL > 0.50 & chloride_first_NORMAL > 0.50 & sofa_first_HIGH <= 0.50 & mal_flg > 0.50 & hr_1st_NORMAL <= 0.50 & gender_num <= 0.50 & aline_flg <= 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.

	 Model's accuracy: 0.8258
	 Class 1 - Global explanation: "(stroke_flg <= 0.50 & afib_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & platelet_first_HIGH <= 0.50 & chloride_first_LOW <= 0.50 & cad_flg <= 0.50 & sofa_first_HIGH <= 0.50 & mal_flg > 0.50 & hour_icu_intime_NORMAL > 0.50 & hr_1st_NORMAL <= 0.50 & gender_num <= 0.50) | (stroke_flg <= 0.50 & afib_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & platelet_first_HIGH <= 0.50 & chloride_first_LOW <= 0.50 & cad_flg <= 0.50 & sofa_first_HIGH > 0.50 & hour_icu_intime_NORMAL > 0.50 & po2_first_LOW <= 0.50) | (stroke_flg <= 0.50 & afib_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & platelet_first_HIGH <= 0.50 & chloride_first_LOW <= 0.50 & cad_flg > 0.50 & tco2_first_NORMAL <= 0.50 & day_icu_intime_num_LOW > 0.50 & aline_flg > 0.50) | (stroke_flg <= 0.50 & afib_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & platelet_first_HIGH <= 0.50 & chloride_first_LOW > 0.50 & map_1st_NORMAL <= 0.5

	 Model's accuracy: 0.8371
	 Class 1 - Global explanation: "(age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi_first_HIGH <= 0.50 & bun_first_NORMAL <= 0.50 & hour_icu_intime_NORMAL > 0.50 & sofa_first_HIGH <= 0.50 & mal_flg <= 0.50 & afib_flg > 0.50 & po2_first_NORMAL > 0.50) | (age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi_first_HIGH <= 0.50 & bun_first_NORMAL <= 0.50 & hour_icu_intime_NORMAL > 0.50 & sofa_first_HIGH <= 0.50 & mal_flg > 0.50 & hr_1st_NORMAL <= 0.50 & gender_num <= 0.50 & tco2_first_NORMAL <= 0.50) | (age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi_first_HIGH <= 0.50 & bun_first_NORMAL <= 0.50 & hour_icu_intime_NORMAL > 0.50 & sofa_first_HIGH <= 0.50 & mal_flg > 0.50 & hr_1st_NORMAL <= 0.50 & gender_num <= 0.50 & tco2_first_NORMAL > 0.50 & hgb_first_NORMAL > 0.50 & sofa_first_LOW <= 0.50) | (age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi_first_HIGH <= 0.50 & bun_first_NORMAL <= 0.50 & hour_icu_intime_NORMAL > 0.50 & sofa_first_HIGH > 0.50 & service_num_LOW > 0.50) | (age_HIG

	 Model's accuracy: 0.8708
	 Class 1 - Global explanation: "(stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & afib_flg <= 0.50 & sodium_first_LOW <= 0.50 & hr_1st_HIGH > 0.50 & iv_day_1_NORMAL <= 0.50 & mal_flg > 0.50 & tco2_first_LOW > 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & afib_flg <= 0.50 & sodium_first_LOW <= 0.50 & hr_1st_HIGH > 0.50 & iv_day_1_NORMAL > 0.50 & potassium_first_NORMAL > 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & afib_flg <= 0.50 & sodium_first_LOW > 0.50 & mal_flg > 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & afib_flg > 0.50 & bun_first_NORMAL <= 0.50 & day_icu_intime_num_LOW > 0.50 & chloride_first_LOW > 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & afib_flg > 0.50 & bun_first_NORMAL > 0.50 & day_icu_intime_num_NORMAL > 0.50 & renal_flg <= 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH > 0.50 & hour

	 Model's accuracy: 0.8427
	 Class 1 - Global explanation: "(age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & stroke_flg <= 0.50 & abg_count_HIGH <= 0.50 & platelet_first_HIGH <= 0.50 & afib_flg <= 0.50 & hr_1st_HIGH <= 0.50 & mal_flg > 0.50 & hour_icu_intime_NORMAL > 0.50 & hr_1st_LOW > 0.50 & aline_flg <= 0.50) | (age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & stroke_flg <= 0.50 & abg_count_HIGH <= 0.50 & platelet_first_HIGH <= 0.50 & afib_flg <= 0.50 & hr_1st_HIGH > 0.50 & map_1st_LOW <= 0.50 & mal_flg > 0.50 & tco2_first_LOW > 0.50) | (age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & stroke_flg <= 0.50 & abg_count_HIGH <= 0.50 & platelet_first_HIGH <= 0.50 & afib_flg <= 0.50 & hr_1st_HIGH > 0.50 & map_1st_LOW > 0.50 & chloride_first_LOW <= 0.50 & day_icu_intime_num_HIGH > 0.50 & sofa_first_HIGH > 0.50) | (age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & stroke_flg <= 0.50 & abg_count_HIGH <= 0.50 & platelet_first_HIGH <= 0.50 & afib_flg <= 0.50 & hr_1st_HIGH > 0.50 & map_1st_LOW > 0.50 & chl

	 Model's accuracy: 0.8146
	 Class 1 - Global explanation: "(age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi_first_HIGH <= 0.50 & platelet_first_HIGH <= 0.50 & mal_flg <= 0.50 & iv_day_1_NORMAL <= 0.50 & sofa_first_HIGH <= 0.50 & afib_flg > 0.50 & po2_first_NORMAL > 0.50 & day_icu_intime_num_NORMAL > 0.50) | (age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi_first_HIGH <= 0.50 & platelet_first_HIGH <= 0.50 & mal_flg <= 0.50 & iv_day_1_NORMAL <= 0.50 & sofa_first_HIGH > 0.50 & hr_1st_HIGH > 0.50 & service_num_LOW > 0.50) | (age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi_first_HIGH <= 0.50 & platelet_first_HIGH <= 0.50 & mal_flg <= 0.50 & iv_day_1_NORMAL > 0.50 & hr_1st_LOW <= 0.50 & map_1st_NORMAL <= 0.50 & chloride_first_NORMAL <= 0.50) | (age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi_first_HIGH <= 0.50 & platelet_first_HIGH <= 0.50 & mal_flg <= 0.50 & iv_day_1_NORMAL > 0.50 & hr_1st_LOW <= 0.50 & map_1st_NORMAL <= 0.50 & chloride_first_NORMAL > 0.50 & age_LOW <= 0.50 & chf_flg <= 0.50) | (ag

	 Model's accuracy: 0.8136
	 Class 1 - Global explanation: "(stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & chloride_first_LOW <= 0.50 & platelet_first_HIGH <= 0.50 & afib_flg <= 0.50 & cad_flg <= 0.50 & sofa_first_HIGH <= 0.50 & mal_flg > 0.50 & hr_1st_NORMAL <= 0.50 & resp_flg > 0.50 & gender_num <= 0.50 & sofa_first_LOW <= 0.50 & hour_icu_intime_NORMAL > 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & chloride_first_LOW <= 0.50 & platelet_first_HIGH <= 0.50 & afib_flg <= 0.50 & cad_flg <= 0.50 & sofa_first_HIGH <= 0.50 & mal_flg > 0.50 & hr_1st_NORMAL <= 0.50 & resp_flg > 0.50 & gender_num <= 0.50 & sofa_first_LOW > 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & chloride_first_LOW <= 0.50 & platelet_first_HIGH <= 0.50 & afib_flg <= 0.50 & cad_flg <= 0.50 & sofa_first_HIGH > 0.50 & day_icu_intime_num_HIGH > 0.50 & po2_first_NORMAL > 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & 

	 Model's accuracy: 0.8192
	 Class 1 - Global explanation: "(age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi_first_HIGH <= 0.50 & platelet_first_HIGH <= 0.50 & chloride_first_LOW <= 0.50 & bun_first_NORMAL <= 0.50 & sofa_first_HIGH <= 0.50 & afib_flg <= 0.50 & mal_flg > 0.50 & hour_icu_intime_NORMAL > 0.50 & day_icu_intime_num_HIGH > 0.50 & sofa_first_NORMAL > 0.50) | (age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi_first_HIGH <= 0.50 & platelet_first_HIGH <= 0.50 & chloride_first_LOW <= 0.50 & bun_first_NORMAL <= 0.50 & sofa_first_HIGH <= 0.50 & afib_flg > 0.50 & po2_first_NORMAL > 0.50 & tco2_first_NORMAL <= 0.50 & day_icu_intime_num_LOW <= 0.50) | (age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi_first_HIGH <= 0.50 & platelet_first_HIGH <= 0.50 & chloride_first_LOW <= 0.50 & bun_first_NORMAL <= 0.50 & sofa_first_HIGH > 0.50 & resp_flg > 0.50 & hour_icu_intime_NORMAL > 0.50) | (age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi_first_HIGH <= 0.50 & platelet_first_HIGH <= 0.50 & chloride_first_LO

	 Model's accuracy: 0.8136
	 Class 1 - Global explanation: "(stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & chloride_first_LOW <= 0.50 & platelet_first_HIGH <= 0.50 & bun_first_NORMAL <= 0.50 & resp_flg > 0.50 & sofa_first_HIGH <= 0.50 & mal_flg <= 0.50 & afib_flg > 0.50 & hour_icu_intime_NORMAL > 0.50 & po2_first_NORMAL > 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & chloride_first_LOW <= 0.50 & platelet_first_HIGH <= 0.50 & bun_first_NORMAL <= 0.50 & resp_flg > 0.50 & sofa_first_HIGH <= 0.50 & mal_flg > 0.50 & hour_icu_intime_HIGH <= 0.50 & hr_1st_NORMAL <= 0.50 & sofa_first_NORMAL <= 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & chloride_first_LOW <= 0.50 & platelet_first_HIGH <= 0.50 & bun_first_NORMAL <= 0.50 & resp_flg > 0.50 & sofa_first_HIGH <= 0.50 & mal_flg > 0.50 & hour_icu_intime_HIGH <= 0.50 & hr_1st_NORMAL <= 0.50 & sofa_first_NORMAL > 0.50 & day_icu_intime_num_HIGH > 0.50 & hr_1st_HIGH <= 0.5

	 Model's accuracy: 0.8475
	 Class 1 - Global explanation: "(age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi_first_HIGH <= 0.50 & chloride_first_LOW <= 0.50 & afib_flg <= 0.50 & liver_flg <= 0.50 & iv_day_1_NORMAL <= 0.50 & bun_first_NORMAL <= 0.50 & mal_flg > 0.50 & hour_icu_intime_NORMAL > 0.50 & hr_1st_LOW > 0.50 & aline_flg <= 0.50) | (age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi_first_HIGH <= 0.50 & chloride_first_LOW <= 0.50 & afib_flg <= 0.50 & liver_flg <= 0.50 & iv_day_1_NORMAL <= 0.50 & bun_first_NORMAL > 0.50 & mal_flg > 0.50 & sapsi_first_LOW <= 0.50) | (age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi_first_HIGH <= 0.50 & chloride_first_LOW <= 0.50 & afib_flg <= 0.50 & liver_flg <= 0.50 & iv_day_1_NORMAL > 0.50 & hr_1st_LOW <= 0.50 & map_1st_NORMAL <= 0.50 & aline_flg <= 0.50 & hgb_first_NORMAL > 0.50) | (age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi_first_HIGH <= 0.50 & chloride_first_LOW <= 0.50 & afib_flg <= 0.50 & liver_flg <= 0.50 & iv_day_1_NORMAL > 0.50 & hr_1st_LOW > 0

Unnamed: 0,method,split,explanation,model_accuracy,explanation_accuracy,explanation_fidelity,explanation_complexity,explanation_consistency,elapsed_time
0,tree,0,(stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi...,0.848315,0.848315,0.921348,691,0.745255,0.001995
1,tree,1,(stroke_flg <= 0.50 & afib_flg <= 0.50 & age_H...,0.825843,0.825843,0.921348,726,0.745255,0.000995
2,tree,2,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi...,0.837079,0.837079,0.893258,723,0.745255,0.001996
3,tree,3,(stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi...,0.870787,0.870787,0.949438,703,0.745255,0.001992
4,tree,4,(age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 &...,0.842697,0.842697,0.926966,653,0.745255,0.000998
5,tree,5,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi...,0.814607,0.814607,0.932584,652,0.745255,0.000997
6,tree,6,(stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi...,0.813559,0.813559,0.847458,725,0.745255,0.001995
7,tree,7,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi...,0.819209,0.819209,0.920904,730,0.745255,0.000998
8,tree,8,(stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi...,0.813559,0.813559,0.887006,654,0.745255,0.001995
9,tree,9,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi...,0.847458,0.847458,0.909605,758,0.745255,0.000997


# BRL

In [37]:
method = 'brl'
need_pruning = False
relu = False
results_brl = c_to_y(method, need_pruning, relu, verbose=True)
results_brl

Split [1/10]
torch.Size([1598, 2])
torch.Size([480, 2])
Labels (1598, 2)
Discretized features
Completed model 1/2!
Completed model 2/2!
Train_acc: 0.9, Val_acc: 0.0
	 Model's accuracy: 91.5730
	 Class 1 - Global explanation: "stroke_flg & age_HIGH & ~stroke_flg & ~abg_count_NORMAL & ~chloride_first_HIGH & ~pco2_first_HIGH & ~sapsi_first_LOW" - Accuracy: 0.8371
	 Elapsed time 0.006980419158935547
	 Fidelity: "0.9101" - Complexity: "7"
Split [2/10]
torch.Size([1598, 2])
torch.Size([480, 2])
Labels (1598, 2)
Discretized features
Completed model 1/2!
Completed model 2/2!
Train_acc: 0.9, Val_acc: 0.0
	 Model's accuracy: 89.3258
	 Class 1 - Global explanation: "stroke_flg & weight_first_LOW & ~renal_flg & ~creatinine_first_HIGH & ~sapsi_first_LOW" - Accuracy: 0.8596
	 Elapsed time 0.003988742828369141
	 Fidelity: "0.8876" - Complexity: "5"
Split [3/10]
torch.Size([1598, 2])
torch.Size([480, 2])
Labels (1598, 2)
Discretized features
Completed model 1/2!
Completed model 2/2!
Train_acc: 0.9, Va

Unnamed: 0,method,split,explanation,model_accuracy,explanation_accuracy,explanation_fidelity,explanation_complexity,explanation_consistency,elapsed_time
0,brl,0,stroke_flg & age_HIGH & ~stroke_flg & ~abg_cou...,91.573034,0.837079,0.910112,7,0.23125,0.00698
1,brl,1,stroke_flg & weight_first_LOW & ~renal_flg & ~...,89.325843,0.859551,0.88764,5,0.23125,0.003989
2,brl,2,(stroke_flg & ~creatinine_first_NORMAL & ~saps...,89.325843,0.876404,0.898876,14,0.23125,0.012965
3,brl,3,stroke_flg & chloride_first_NORMAL & ~age_NORM...,91.573034,0.859551,0.91573,4,0.23125,0.002992
4,brl,4,(age_HIGH & ~bun_first_LOW) | (stroke_flg & ag...,93.258427,0.853933,0.938202,27,0.23125,0.015911
5,brl,5,stroke_flg & age_HIGH & bun_first_LOW & weight...,93.820225,0.876404,0.938202,5,0.23125,0.00211
6,brl,6,(stroke_flg & weight_first_LOW & ~afib_flg & ~...,91.525424,0.836158,0.881356,17,0.23125,0.011966
7,brl,7,(stroke_flg & hr_1st_NORMAL & ~sapsi_first_LOW...,93.785311,0.870056,0.926554,13,0.23125,0.009004
8,brl,8,(afib_flg & resp_flg & ~bmi_HIGH & ~pco2_first...,93.220339,0.819209,0.892655,18,0.23125,0.012903
9,brl,9,stroke_flg & hgb_first_NORMAL & hr_1st_NORMAL ...,91.525424,0.847458,0.864407,6,0.23125,0.002986


# Summary

In [42]:
cols = ['model_accuracy', 'explanation_accuracy', 'explanation_fidelity', 
        'explanation_complexity', 'elapsed_time', 'explanation_consistency']
mean_cols = [f'{c}_mean' for c in cols]
sem_cols = [f'{c}_sem' for c in cols]

# pruning
df_mean = results_pruning[cols].mean()
df_sem = results_pruning[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_pruning = pd.concat([df_mean, df_sem])
summary_pruning.name = 'pruning'

# weights
df_mean = results_weights[cols].mean()
df_sem = results_weights[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_weights = pd.concat([df_mean, df_sem])
summary_weights.name = 'weights'

# psi
df_mean = results_psi[cols].mean()
df_sem = results_psi[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_psi = pd.concat([df_mean, df_sem])
summary_psi.name = 'psi'

# tree
df_mean = results_tree[cols].mean()
df_sem = results_tree[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_tree = pd.concat([df_mean, df_sem])
summary_tree.name = 'tree'

# BRL
df_mean = results_brl[cols].mean()
df_sem = results_brl[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_brl = pd.concat([df_mean, df_sem])
summary_brl.name = 'BRL'

summary = pd.concat([summary_pruning,
                     summary_weights, 
                     summary_psi, 
                     summary_tree,
                     summary_brl], axis=1).T
summary.columns = mean_cols + sem_cols
summary

Unnamed: 0,model_accuracy_mean,explanation_accuracy_mean,explanation_fidelity_mean,explanation_complexity_mean,elapsed_time_mean,explanation_consistency_mean,model_accuracy_sem,explanation_accuracy_sem,explanation_fidelity_sem,explanation_complexity_sem,elapsed_time_sem,explanation_consistency_sem
pruning,0.900375,0.766527,0.810487,1.8,0.039398,0.316667,0.007022,0.085338,0.090327,0.359011,0.005545,1.850372e-17
weights,0.95665,0.727442,0.74768,8.2,3.96321,0.325,0.003843,0.063313,0.079528,1.554921,2.977609,1.850372e-17
psi,0.884003,0.774183,0.81248,1.3,0.004896,0.325,0.009521,0.049974,0.055566,0.152753,0.00038,0.0
tree,0.833311,0.833311,0.910992,701.5,0.001496,0.745255,0.006071,0.006071,0.009093,11.92127,0.000166,0.0
BRL,91.89329,0.85358,0.905374,11.6,0.008181,0.23125,0.522567,0.005944,0.007823,2.386071,0.001598,9.251859000000001e-18


In [43]:
summary.to_csv(os.path.join(results_dir, 'summary.csv'))