In [1]:
import sys
sys.path.append('..')
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder, MinMaxScaler, KBinsDiscretizer
from lens.models.brl import XBRLClassifier
from sklearn.impute import SimpleImputer
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import TensorDataset, DataLoader
import time
from sklearn.metrics import accuracy_score
from sympy import simplify_logic
from sklearn.model_selection import train_test_split

from lens.utils.base import validate_network, set_seed, tree_to_formula
from lens.utils.relu_nn import get_reduced_model, prune_features
from lens.utils.psi_nn import prune_equal_fanin
from lens import logic
import lens

set_seed(0)

#%%

data = pd.read_csv('data/mimic-ii/full_cohort_data.csv')

# Binary clinical flags, used (after mean imputation) directly as concepts.
# Outcome-related columns (hosp_exp_flg, icu_exp_flg, day_28_flg, censor_flg)
# are left out of the inputs; day_28_flg is the prediction target below.
fs = [
    'aline_flg',
    'gender_num',
    'sepsis_flg', 'chf_flg', 'afib_flg',
    'renal_flg', 'liver_flg', 'copd_flg', 'cad_flg', 'stroke_flg',
    'mal_flg', 'resp_flg',
]
data1 = data[fs].values
imp_mean = SimpleImputer(missing_values=np.nan, strategy='mean')
data1 = imp_mean.fit_transform(data1)

# Every other column, except the flags above and the columns listed here, is
# mean-imputed, scaled to [0, 1] and discretised into three equal-width bins.
f2_excluded = fs + [
    'day_icu_intime',
    'service_unit',
    'day_28_flg',
    'hospital_los_day',
    'icu_exp_flg',
    'hosp_exp_flg',
    'censor_flg',
    'mort_day_censored',
]
f2 = data.columns.difference(f2_excluded)
data2 = data[f2].values
imp_mean = SimpleImputer(missing_values=np.nan, strategy='mean')
data2 = imp_mean.fit_transform(data2)
scaler = MinMaxScaler((0, 1))
data2 = scaler.fit_transform(data2)
est = KBinsDiscretizer(n_bins=3, encode='onehot-dense', strategy='uniform')
data2d = est.fit_transform(data2)
# One concept name per one-hot column, feature by feature:
# <feature>_LOW, <feature>_NORMAL, <feature>_HIGH.
f2d = [f'{feature}_{level}' for feature in f2 for level in ('LOW', 'NORMAL', 'HIGH')]
features = fs + f2d

datax = np.hstack((data1, data2d))
# The imputer may drop all-NaN columns and the discretiser may give a constant
# column fewer than 3 bins; either would silently misalign the names above.
assert datax.shape[1] == len(features), (
    f'{datax.shape[1]} input columns but {len(features)} concept names'
)
datay = data['day_28_flg'].values

x = torch.FloatTensor(datax)
y = torch.LongTensor(datay)
print(x.shape)
print(y.shape)


torch.Size([1776, 90])
torch.Size([1776])


In [2]:
results_dir = 'results/mimic_explain_classifier'
# exist_ok=True avoids the check-then-create race of isdir() followed by makedirs().
os.makedirs(results_dir, exist_ok=True)


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Names used to render explanations; expected to have one entry per column of `x`.
concepts = features
n_rep = 10  # NOTE(review): not referenced by any visible cell
tot_epochs = 5001    # full-batch training epochs for every neural model
prune_epochs = 2001  # pruning epoch threshold used by train_nn / train_psi_nn
seed = 42

n_splits = 10
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)

In [3]:
def train_nn(x_train, y_train, need_pruning, seed, device, l1=0.001, lr=0.001, relu=False, verbose=False):
    """Train a 5-layer MLP with a 2-way softmax output on the full batch.

    Args:
        x_train: 2-D input tensor; its second dimension sets the input width.
        y_train: class-index targets for CrossEntropyLoss (two classes: 0/1).
        need_pruning: if True, add an L1 penalty on the first Linear layer and,
            for non-ReLU networks, call ``prune_features`` once at epoch
            ``prune_epochs``.
        seed: seed passed to ``set_seed`` before the layers are created.
        device: torch device used for the data and the model.
        l1: coefficient of the L1 penalty on the first Linear layer.
        lr: AdamW learning rate.
        relu: use ReLU instead of LeakyReLU. This also enables the L1 penalty.
        verbose: print the training accuracy every 500 epochs.

    Returns:
        The trained ``torch.nn.Sequential``. It is left in train mode.

    Reads the module-level ``tot_epochs`` and ``prune_epochs``.
    """
    set_seed(seed)
    x_train = x_train.to(device)
    y_train = y_train.to(device)
    layers = [
        torch.nn.Linear(x_train.size(1), 100),
        torch.nn.ReLU() if relu else torch.nn.LeakyReLU(),
        torch.nn.Linear(100, 50),
        torch.nn.ReLU() if relu else torch.nn.LeakyReLU(),
        torch.nn.Linear(50, 30),
        torch.nn.ReLU() if relu else torch.nn.LeakyReLU(),
        torch.nn.Linear(30, 10),
        torch.nn.ReLU() if relu else torch.nn.LeakyReLU(),
        torch.nn.Linear(10, 2),
        # NOTE(review): CrossEntropyLoss already applies log-softmax, so this
        # Softmax means the loss sees probabilities (double softmax). It is kept
        # because downstream explanation code may rely on probability outputs.
        # Confirm this before removing it.
        torch.nn.Softmax(dim=1),
    ]
    model = torch.nn.Sequential(*layers).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_form = torch.nn.CrossEntropyLoss()
    model.train()
    for epoch in range(tot_epochs):
        # forward pass
        optimizer.zero_grad()
        y_pred = model(x_train)
        # Compute Loss
        loss = loss_form(y_pred, y_train)

        if need_pruning or relu:
            # L1 penalty on the FIRST Linear layer only (the break stops after it).
            for module in model.children():
                if isinstance(module, torch.nn.Linear):
                    loss += l1 * torch.norm(module.weight, 1)
                    loss += l1 * torch.norm(module.bias, 1)
                    break

        # backward pass
        loss.backward()
        optimizer.step()

        # Prune exactly once, at epoch `prune_epochs`. The old extra
        # `epoch % 1000 == 1` test silently disabled pruning for any
        # prune_epochs not congruent to 1 mod 1000.
        if need_pruning and not relu and epoch == prune_epochs:
            # NOTE(review): n_classes=1 while the network has 2 outputs. Confirm
            # this matches prune_features' expectations.
            prune_features(model, n_classes=1, device=device)

        # compute accuracy
        if epoch % 500 == 0 and verbose:
            y_pred_d = torch.argmax(y_pred, dim=1)
            accuracy = y_pred_d.eq(y_train).sum().item() / y_train.size(0)
            print(f'\t Epoch {epoch}: train accuracy: {accuracy:.4f}')

    return model

In [4]:
def train_psi_nn(x_train, y_train, need_pruning, seed, device, verbose=False, l1=0.001, lr=0.1):
    """Train a small sigmoid network (5-5-1) with BCE loss for one binary target.

    Args:
        x_train: 2-D input tensor; its second dimension sets the input width.
        y_train: binary targets. They are cast to float for BCELoss.
        need_pruning: if True, call ``prune_equal_fanin(model, 2, ...)`` once, at
            the first epoch after ``prune_epochs``, then keep training.
        seed: seed passed to ``set_seed`` before the layers are created.
        device: torch device used for the data and the model.
        verbose: print the training accuracy every 500 epochs.
        l1: coefficient of the L1 penalty on the weights of every Linear layer.
            The default matches the previously hard-coded value.
        lr: Adam learning rate. The default matches the previously hard-coded value.

    Returns:
        The trained (and possibly pruned) model.

    Reads the module-level ``tot_epochs`` and ``prune_epochs``.
    """
    set_seed(seed)
    x_train = x_train.to(device)
    y_train = y_train.to(device).to(torch.float)
    layers = [
        torch.nn.Linear(x_train.size(1), 5),
        torch.nn.Sigmoid(),
        torch.nn.Linear(5, 5),
        torch.nn.Sigmoid(),
        torch.nn.Linear(5, 1),
        torch.nn.Sigmoid(),
    ]
    model = torch.nn.Sequential(*layers).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_form = torch.nn.BCELoss()
    model.train()
    for epoch in range(tot_epochs):
        # forward pass
        optimizer.zero_grad()
        y_pred = model(x_train).squeeze()
        # Compute Loss
        loss = loss_form(y_pred, y_train)

        # Unlike train_nn, the L1 penalty covers the weights of every Linear layer.
        for module in model.children():
            if isinstance(module, torch.nn.Linear):
                loss += l1 * torch.norm(module.weight, 1)

        # backward pass
        loss.backward()
        optimizer.step()

        if epoch > prune_epochs and need_pruning:
            # NOTE(review): the optimizer keeps the parameters it was built with.
            # This is only correct if prune_equal_fanin modifies `model` in place.
            # Confirm.
            model = prune_equal_fanin(model, 2, validate=True, device=device)
            need_pruning = False

        # compute accuracy
        if epoch % 500 == 0 and verbose:
            y_pred_d = y_pred > 0.5
            accuracy = y_pred_d.eq(y_train).sum().item() / y_train.size(0)
            print(f'\t Epoch {epoch}: train accuracy: {accuracy:.4f}')

    return model

In [6]:
def c_to_y(method, need_pruning, l1=0.0001, lr=0.0001, relu=False, verbose=False):
    """Cross-validate one explanation method on the concept -> day_28_flg task.

    For every split produced by the module-level ``skf`` on ``x``/``y``, a
    black-box network is trained on the train+val portion. ``method`` then
    explains the black box's predictions for class 1:

    * ``'tree'`` / ``'tree-3'``: an unbounded / depth-3 decision tree fitted on
      the black-box labels;
    * ``'brl'``: an ``XBRLClassifier`` fitted on the black-box labels;
    * ``'psi'``: a network from ``train_psi_nn``;
    * anything else: a network from ``train_nn``, whose local explanations are
      combined by ``logic.relu_nn.combine_local_explanations(method=method)``.

    Args:
        method: name of the explanation method (see above).
        need_pruning: forwarded to ``train_psi_nn`` / ``train_nn``.
        l1, lr, relu: forwarded to ``train_nn`` only. The other methods ignore them.
        verbose: print per-split details.

    Returns:
        A DataFrame with one row per split, also written to
        ``results_dir/results_{method}.csv``. Its ``model_accuracy`` column holds
        the black-box test accuracy.
    """
    methods = []
    splits = []
    explanations = []
    model_accuracies = []
    explanation_accuracies = []
    elapsed_times = []
    explanation_fidelities = []
    explanation_complexities = []
    target_class = 1  # every method explains the positive class
    for split, (trainval_index, test_index) in enumerate(skf.split(x.cpu().detach().numpy(), y.cpu().detach().numpy())):
        print(f'Split [{split+1}/{n_splits}]')
        x_trainval, x_test = torch.FloatTensor(x[trainval_index]), torch.FloatTensor(x[test_index])
        y_trainval, y_test = torch.LongTensor(y[trainval_index]), torch.LongTensor(y[test_index])
        # x_val feeds combine_local_explanations; x_train is only used for its width.
        x_train, x_val, _, _ = train_test_split(x_trainval, y_trainval, test_size=0.3, random_state=seed)

        model = train_nn(x_trainval, y_trainval, False, seed, device, l1=0, lr=0.0005, relu=False, verbose=False)
        # Inference only: no autograd graph is needed for the black-box labels.
        with torch.no_grad():
            y_preds_trainval = model(x_trainval.to(device)).cpu().argmax(dim=1)
            y_preds_val = model(x_val.to(device)).cpu().argmax(dim=1)
            y_preds_test = model(x_test.to(device)).cpu().argmax(dim=1)

        black_box_accuracy = accuracy_score(y_test, y_preds_test)
        print(f'\t Black box accuracy: {black_box_accuracy:.4f}')

        explanation = ''
        explanation_accuracy = 0
        explanation_fidelity = 0
        explanation_complexity = 0
        model_accuracy_printed = False
        start = time.time()

        if 'tree' in method:
            if method == 'tree-3':
                classifier = DecisionTreeClassifier(random_state=seed, max_depth=3)
            else:
                classifier = DecisionTreeClassifier(random_state=seed)
            classifier.fit(x_trainval.cpu().detach().numpy(), y_preds_trainval.cpu().detach().numpy())
            y_preds = classifier.predict(x_test.cpu().detach().numpy())

            explanation = tree_to_formula(classifier, concepts, target_class)

            # compute metrics
            model_accuracy = accuracy_score(y_test, y_preds)
            explanation_accuracy = model_accuracy
            explanation_fidelity = accuracy_score(y_preds_test.cpu().detach().numpy(), y_preds)
            explanation_complexity = lens.logic.complexity(explanation)

        elif method == 'brl':
            train_dataset = TensorDataset(x_trainval, y_preds_trainval)
            # BRL is fitted on placeholder names; real concept names are
            # substituted into the explanation afterwards.
            dummy_concepts = [f'feature{i:010}' for i in range(len(concepts))]
            model = XBRLClassifier(name=os.path.join(results_dir, f'{method}_{split}'),
                                   n_classes=1,
                                   n_features=x_train.shape[1], feature_names=dummy_concepts,
                                   class_names=['~day_28_flg', 'day_28_flg'], discretize=True)
            model.fit(train_dataset, metric=accuracy_score, save=True)
            # The method's own timing is discarded: elapsed_time below measures
            # the wall-clock time of the whole branch, including fit.
            global_explanation, _ = model.get_global_explanation(0, concepts=dummy_concepts, return_time=True)

            # compute metrics
            y_preds = model.forward(x_test)
            model_accuracy = accuracy_score(y_test, y_preds)
            explanation = logic.base.replace_names(global_explanation, concepts)
            explanation_accuracy, y_formula = logic.base.test_explanation(global_explanation,
                                                                          target_class,
                                                                          x=x_test, y=y_test,
                                                                          metric=accuracy_score)
            explanation_fidelity = accuracy_score(y_preds_test.cpu().detach().numpy(), y_formula)
            explanation_complexity = lens.logic.complexity(global_explanation)

        else:
            if method == 'psi':
                model = train_psi_nn(x_trainval, y_preds_trainval.eq(target_class), need_pruning, seed, device, verbose)
                y_preds = model(x_test.to(device)).cpu().detach().numpy() > 0.5
            else:
                model = train_nn(x_trainval, y_preds_trainval, need_pruning, seed, device, l1, lr, relu, verbose)
                y_preds = model(x_test.to(device)).cpu().detach().numpy().argmax(axis=1)
            # Agreement of the explainer network with the black box on the test split.
            model_accuracy = accuracy_score(y_preds_test.cpu().detach().numpy(), y_preds)

            print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
            model_accuracy_printed = True

            if method == 'psi':
                global_explanation = logic.generate_fol_explanations(model, device)[0]
                print(global_explanation)
            else:
                global_explanation, _, _ = logic.relu_nn.combine_local_explanations(model,
                                                                                   x_val.to(device),
                                                                                   y_preds_val.to(device),
                                                                                   topk_explanations=1,
                                                                                   target_class=target_class,
                                                                                   method=method, device=device)

            if global_explanation:
                explanation = logic.base.replace_names(global_explanation, concepts)
                explanation_accuracy, y_formula = logic.base.test_explanation(global_explanation,
                                                                              target_class,
                                                                              x=x_test, y=y_test,
                                                                              metric=accuracy_score)
                explanation_fidelity = accuracy_score(y_preds_test.cpu().detach().numpy(), y_formula)
                explanation_complexity = lens.logic.complexity(global_explanation)

        elapsed_time = time.time() - start

        if verbose:
            # Avoid repeating the accuracy line already printed for NN methods.
            if not model_accuracy_printed:
                print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
            print(f'\t Class {target_class} - Global explanation: "{explanation}" - Accuracy: {explanation_accuracy:.4f}')
            print(f'\t Elapsed time {elapsed_time}')
            print(f'\t Fidelity: "{explanation_fidelity:.4f}" - Complexity: "{explanation_complexity}"')

        methods.append(method)
        splits.append(split)
        explanations.append(explanation)
        model_accuracies.append(black_box_accuracy)
        explanation_accuracies.append(explanation_accuracy)
        elapsed_times.append(elapsed_time)
        explanation_fidelities.append(explanation_fidelity)
        explanation_complexities.append(explanation_complexity)

    explanation_consistency = lens.logic.formula_consistency(explanations)
    print(f'Consistency of explanations: {explanation_consistency:.4f}')

    results = pd.DataFrame({
        'method': methods,
        'split': splits,
        'explanation': explanations,
        'model_accuracy': model_accuracies,
        'explanation_accuracy': explanation_accuracies,
        'explanation_fidelity': explanation_fidelities,
        'explanation_complexity': explanation_complexities,
        'explanation_consistency': explanation_consistency,
        'elapsed_time': elapsed_times,
    })
    results.to_csv(os.path.join(results_dir, f'results_{method}.csv'))

    return results

# General pruning (LeakyReLU network with L1 regularisation and feature pruning)

In [6]:
method = 'pruning'
need_pruning = True
relu = False
# LeakyReLU network: first-layer L1 penalty plus one-shot feature pruning.
results_pruning = c_to_y(method=method, need_pruning=need_pruning,
                         l1=1e-6, lr=5e-4, relu=relu, verbose=True)
results_pruning

Split [1/10]
	 Black box accuracy: 0.8652
	 Epoch 0: train accuracy: 0.1120
	 Epoch 500: train accuracy: 0.9700
	 Epoch 1000: train accuracy: 0.9887
	 Epoch 1500: train accuracy: 0.9887
	 Epoch 2000: train accuracy: 0.9887
Pruned 85/90 features
	 Epoch 2500: train accuracy: 0.9218
	 Epoch 3000: train accuracy: 0.9218
	 Epoch 3500: train accuracy: 0.9218
	 Epoch 4000: train accuracy: 0.9218
	 Epoch 4500: train accuracy: 0.9218
	 Epoch 5000: train accuracy: 0.9218
	 Model's accuracy: 0.9551
	 Model's accuracy: 0.9551
	 Class 1 - Global explanation: "afib_flg & stroke_flg & ~abg_count_NORMAL & ~sapsi_first_HIGH & ~sodium_first_LOW" - Accuracy: 0.8427
	 Elapsed time 15.418745040893555
	 Fidelity: "0.9551" - Complexity: "5"
Split [2/10]
	 Black box accuracy: 0.8652
	 Epoch 0: train accuracy: 0.1176
	 Epoch 500: train accuracy: 0.9844
	 Epoch 1000: train accuracy: 0.9919
	 Epoch 1500: train accuracy: 0.9919
	 Epoch 2000: train accuracy: 0.9919
Pruned 86/90 features
	 Epoch 2500: train accura

Unnamed: 0,method,split,explanation,model_accuracy,explanation_accuracy,explanation_fidelity,explanation_complexity,explanation_consistency,elapsed_time
0,pruning,0,afib_flg & stroke_flg & ~abg_count_NORMAL & ~s...,0.865169,0.842697,0.955056,5,0.411765,15.418745
1,pruning,1,afib_flg & stroke_flg & ~sapsi_first_HIGH & ~w...,0.865169,0.853933,0.921348,4,0.411765,14.452854
2,pruning,2,stroke_flg & age_HIGH & ~hgb_first_LOW & ~map_...,0.876404,0.876404,0.932584,7,0.411765,15.885037
3,pruning,3,stroke_flg & age_HIGH & ~hgb_first_LOW & ~saps...,0.859551,0.859551,0.955056,7,0.411765,14.851997
4,pruning,4,afib_flg & stroke_flg & age_HIGH & ~hgb_first_...,0.825843,0.848315,0.910112,7,0.411765,15.033449
5,pruning,5,stroke_flg & age_HIGH & ~afib_flg & ~hgb_first...,0.859551,0.88764,0.91573,7,0.411765,14.964742
6,pruning,6,stroke_flg & age_HIGH & ~afib_flg & ~hgb_first...,0.80791,0.836158,0.881356,6,0.411765,14.962678
7,pruning,7,stroke_flg & age_HIGH & ~afib_flg & ~abg_count...,0.870056,0.847458,0.954802,12,0.411765,15.800543
8,pruning,8,stroke_flg & age_HIGH & ~afib_flg & ~age_LOW &...,0.830508,0.853107,0.875706,6,0.411765,15.153544
9,pruning,9,stroke_flg & age_HIGH & ~afib_flg & ~bun_first...,0.853107,0.870056,0.926554,9,0.411765,16.495365


# ReLU network (`weights` method)

In [7]:
method = 'weights'
need_pruning = False
relu = True
# Pass the variables above (not literals) so the call cannot drift from them.
results_weights = c_to_y(method, need_pruning, l1=0.0000001, lr=0.0005, relu=relu, verbose=True)
results_weights

Split [1/10]
	 Black box accuracy: 0.8652
	 Epoch 0: train accuracy: 0.1120
	 Epoch 500: train accuracy: 0.9931
	 Epoch 1000: train accuracy: 0.9931
	 Epoch 1500: train accuracy: 0.9931
	 Epoch 2000: train accuracy: 0.9931
	 Epoch 2500: train accuracy: 0.9931
	 Epoch 3000: train accuracy: 0.9931
	 Epoch 3500: train accuracy: 0.9931
	 Epoch 4000: train accuracy: 0.9931
	 Epoch 4500: train accuracy: 0.9931
	 Epoch 5000: train accuracy: 0.9931
	 Model's accuracy: 0.9775
	 Model's accuracy: 0.9775
	 Class 1 - Global explanation: "stroke_flg & ~afib_flg & ~abg_count_NORMAL & ~hgb_first_LOW & ~sapsi_first_HIGH & ~sodium_first_LOW & ~weight_first_NORMAL" - Accuracy: 0.8371
	 Elapsed time 15.642074584960938
	 Fidelity: "0.8933" - Complexity: "7"
Split [2/10]
	 Black box accuracy: 0.8652
	 Epoch 0: train accuracy: 0.1176
	 Epoch 500: train accuracy: 0.9931
	 Epoch 1000: train accuracy: 0.9937
	 Epoch 1500: train accuracy: 0.9937
	 Epoch 2000: train accuracy: 0.9937
	 Epoch 2500: train accuracy:

Unnamed: 0,method,split,explanation,model_accuracy,explanation_accuracy,explanation_fidelity,explanation_complexity,explanation_consistency,elapsed_time
0,weights,0,stroke_flg & ~afib_flg & ~abg_count_NORMAL & ~...,0.865169,0.837079,0.893258,7,0.355556,15.642075
1,weights,1,stroke_flg & ~afib_flg & ~weight_first_NORMAL,0.865169,0.825843,0.882022,3,0.355556,15.179163
2,weights,2,stroke_flg & ~sapsi_first_LOW & ~sapsi_first_H...,0.876404,0.865169,0.932584,4,0.355556,19.137606
3,weights,3,stroke_flg & ~abg_count_NORMAL & ~age_LOW & ~b...,0.859551,0.859551,0.94382,8,0.355556,16.15541
4,weights,4,stroke_flg & age_HIGH & ~abg_count_NORMAL & ~c...,0.825843,0.859551,0.94382,7,0.355556,21.770095
5,weights,5,stroke_flg & ~creatinine_first_NORMAL & ~hgb_f...,0.859551,0.825843,0.898876,7,0.355556,17.686428
6,weights,6,stroke_flg & ~resp_flg & ~creatinine_first_NOR...,0.80791,0.830508,0.864407,9,0.355556,16.702914
7,weights,7,afib_flg & age_HIGH & ~stroke_flg & ~abg_count...,0.870056,0.841808,0.903955,12,0.355556,23.076098
8,weights,8,stroke_flg & ~sodium_first_LOW,0.830508,0.847458,0.881356,2,0.355556,15.923318
9,weights,9,~stroke_flg & ~hgb_first_LOW & ~sapsi_first_LO...,0.853107,0.514124,0.514124,5,0.355556,23.292561


# Psi network

In [8]:
method = 'psi'
need_pruning = True
relu = False
# l1/lr are not passed: the psi branch of c_to_y never forwards them to train_psi_nn.
results_psi = c_to_y(method, need_pruning, relu=relu, verbose=True)
results_psi

Split [1/10]
	 Black box accuracy: 0.8652
	 Epoch 0: train accuracy: 0.1120
	 Epoch 500: train accuracy: 0.9743
	 Epoch 1000: train accuracy: 0.9775
	 Epoch 1500: train accuracy: 0.9756
	 Epoch 2000: train accuracy: 0.9750
	 Epoch 2500: train accuracy: 0.9080
	 Epoch 3000: train accuracy: 0.9074
	 Epoch 3500: train accuracy: 0.9074
	 Epoch 4000: train accuracy: 0.9074
	 Epoch 4500: train accuracy: 0.9074
	 Epoch 5000: train accuracy: 0.9074
	 Model's accuracy: 0.9213
(feature0000000009)
	 Model's accuracy: 0.9213
	 Class 1 - Global explanation: "(stroke_flg)" - Accuracy: 0.8427
	 Elapsed time 10.046723365783691
	 Fidelity: "0.9213" - Complexity: "1"
Split [2/10]
	 Black box accuracy: 0.8652
	 Epoch 0: train accuracy: 0.1176
	 Epoch 500: train accuracy: 0.9637
	 Epoch 1000: train accuracy: 0.9700
	 Epoch 1500: train accuracy: 0.9700
	 Epoch 2000: train accuracy: 0.9768
	 Epoch 2500: train accuracy: 0.9043
	 Epoch 3000: train accuracy: 0.9036
	 Epoch 3500: train accuracy: 0.9036
	 Epoch 

Unnamed: 0,method,split,explanation,model_accuracy,explanation_accuracy,explanation_fidelity,explanation_complexity,explanation_consistency,elapsed_time
0,psi,0,(stroke_flg),0.865169,0.842697,0.921348,1,0.433333,10.046723
1,psi,1,(stroke_flg),0.865169,0.825843,0.882022,1,0.433333,10.293958
2,psi,2,(stroke_flg & age_HIGH),0.876404,0.876404,0.932584,2,0.433333,10.179928
3,psi,3,(stroke_flg & ~sapsi_first_LOW),0.859551,0.859551,0.94382,2,0.433333,10.093949
4,psi,4,(age_HIGH),0.825843,0.741573,0.814607,1,0.433333,10.063174
5,psi,5,(stroke_flg & ~sapsi_first_LOW),0.859551,0.837079,0.910112,2,0.433333,10.146026
6,psi,6,(stroke_flg),0.80791,0.819209,0.875706,1,0.433333,10.174666
7,psi,7,(stroke_flg),0.870056,0.858757,0.943503,1,0.433333,10.292563
8,psi,8,(stroke_flg),0.830508,0.847458,0.881356,1,0.433333,10.045091
9,psi,9,(stroke_flg),0.853107,0.881356,0.926554,1,0.433333,10.057718


# Decision tree

In [9]:
method = 'tree'
need_pruning = False
relu = False
# relu must be passed by keyword: c_to_y's third positional parameter is l1.
results_tree = c_to_y(method, need_pruning, relu=relu, verbose=True)
results_tree

Split [1/10]
	 Black box accuracy: 0.8652
	 Model's accuracy: 0.8596
	 Class 1 - Global explanation: "(stroke_flg <= 0.50 & afib_flg <= 0.50 & sapsi_first_HIGH <= 0.50 & bun_first_NORMAL <= 0.50 & tco2_first_HIGH <= 0.50 & abg_count_HIGH <= 0.50 & hr_1st_HIGH <= 0.50 & chloride_first_LOW > 0.50 & hr_1st_LOW > 0.50 & pco2_first_LOW <= 0.50) | (stroke_flg <= 0.50 & afib_flg <= 0.50 & sapsi_first_HIGH <= 0.50 & bun_first_NORMAL <= 0.50 & tco2_first_HIGH <= 0.50 & abg_count_HIGH <= 0.50 & hr_1st_HIGH > 0.50 & platelet_first_HIGH <= 0.50 & sofa_first_HIGH <= 0.50 & age_HIGH <= 0.50 & chloride_first_LOW > 0.50 & mal_flg > 0.50) | (stroke_flg <= 0.50 & afib_flg <= 0.50 & sapsi_first_HIGH <= 0.50 & bun_first_NORMAL <= 0.50 & tco2_first_HIGH <= 0.50 & abg_count_HIGH <= 0.50 & hr_1st_HIGH > 0.50 & platelet_first_HIGH <= 0.50 & sofa_first_HIGH <= 0.50 & age_HIGH > 0.50 & tco2_first_LOW > 0.50) | (stroke_flg <= 0.50 & afib_flg <= 0.50 & sapsi_first_HIGH <= 0.50 & bun_first_NORMAL <= 0.50 & tco2_fi

	 Black box accuracy: 0.8652
	 Model's accuracy: 0.8596
	 Class 1 - Global explanation: "(stroke_flg <= 0.50 & afib_flg <= 0.50 & age_HIGH <= 0.50 & hr_1st_HIGH > 0.50 & iv_day_1_NORMAL <= 0.50 & sofa_first_HIGH > 0.50 & po2_first_NORMAL > 0.50) | (stroke_flg <= 0.50 & afib_flg <= 0.50 & age_HIGH <= 0.50 & hr_1st_HIGH > 0.50 & iv_day_1_NORMAL > 0.50 & potassium_first_LOW <= 0.50) | (stroke_flg <= 0.50 & afib_flg <= 0.50 & age_HIGH > 0.50 & sofa_first_HIGH <= 0.50 & tco2_first_HIGH <= 0.50 & chloride_first_HIGH <= 0.50 & hr_1st_HIGH <= 0.50 & abg_count_LOW <= 0.50 & potassium_first_NORMAL <= 0.50) | (stroke_flg <= 0.50 & afib_flg <= 0.50 & age_HIGH > 0.50 & sofa_first_HIGH <= 0.50 & tco2_first_HIGH <= 0.50 & chloride_first_HIGH <= 0.50 & hr_1st_HIGH <= 0.50 & abg_count_LOW > 0.50 & chf_flg <= 0.50 & sapsi_first_HIGH <= 0.50 & bun_first_NORMAL > 0.50 & chloride_first_LOW > 0.50) | (stroke_flg <= 0.50 & afib_flg <= 0.50 & age_HIGH > 0.50 & sofa_first_HIGH <= 0.50 & tco2_first_HIGH <= 0.50

	 Black box accuracy: 0.8764
	 Model's accuracy: 0.8596
	 Class 1 - Global explanation: "(stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & bun_first_NORMAL > 0.50 & mal_flg <= 0.50 & afib_flg > 0.50 & platelet_first_NORMAL > 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & bun_first_NORMAL > 0.50 & mal_flg > 0.50 & renal_flg <= 0.50 & sapsi_first_NORMAL > 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH > 0.50 & hr_1st_HIGH > 0.50 & creatinine_first_HIGH <= 0.50) | (stroke_flg <= 0.50 & age_HIGH > 0.50 & resp_flg <= 0.50 & chloride_first_HIGH <= 0.50 & hr_1st_HIGH <= 0.50 & afib_flg <= 0.50 & abg_count_LOW <= 0.50) | (stroke_flg <= 0.50 & age_HIGH > 0.50 & resp_flg <= 0.50 & chloride_first_HIGH <= 0.50 & hr_1st_HIGH <= 0.50 & afib_flg <= 0.50 & abg_count_LOW > 0.50 & bun_first_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & bun_first_NORMAL <= 0.50 & chf_flg > 0.50 & mal_flg > 0.50 & day_icu_intime_num_LOW > 0.50) | (stroke_flg <=

	 Black box accuracy: 0.8596
	 Model's accuracy: 0.8652
	 Class 1 - Global explanation: "(stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH > 0.50 & afib_flg <= 0.50 & icu_los_day_NORMAL > 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH > 0.50 & afib_flg > 0.50) | (stroke_flg <= 0.50 & age_HIGH > 0.50 & afib_flg <= 0.50 & sofa_first_HIGH <= 0.50 & tco2_first_HIGH <= 0.50 & hr_1st_HIGH <= 0.50 & bun_first_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & resp_flg <= 0.50 & chloride_first_HIGH <= 0.50 & hr_1st_LOW > 0.50 & hour_icu_intime_LOW > 0.50 & chf_flg <= 0.50 & tco2_first_NORMAL <= 0.50 & map_1st_LOW <= 0.50) | (stroke_flg <= 0.50 & age_HIGH > 0.50 & afib_flg <= 0.50 & sofa_first_HIGH <= 0.50 & tco2_first_HIGH <= 0.50 & hr_1st_HIGH <= 0.50 & bun_first_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & resp_flg <= 0.50 & chloride_first_HIGH <= 0.50 & hr_1st_LOW > 0.50 & hour_icu_intime_LOW > 0.50 & chf_flg > 0.50 & day_icu_intime_num_HIGH <= 0.50) | (stroke_flg <= 0.50 & a

	 Black box accuracy: 0.8258
	 Model's accuracy: 0.8596
	 Class 1 - Global explanation: "(stroke_flg <= 0.50 & age_HIGH <= 0.50 & hr_1st_HIGH <= 0.50 & afib_flg <= 0.50 & mal_flg > 0.50 & hour_icu_intime_NORMAL > 0.50 & aline_flg <= 0.50 & hr_1st_LOW > 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & hr_1st_HIGH <= 0.50 & afib_flg > 0.50 & resp_flg > 0.50 & map_1st_LOW <= 0.50 & service_num_LOW <= 0.50 & hour_icu_intime_NORMAL > 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & hr_1st_HIGH <= 0.50 & afib_flg > 0.50 & resp_flg > 0.50 & map_1st_LOW > 0.50 & day_icu_intime_num_HIGH <= 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & hr_1st_HIGH > 0.50 & sapsi_first_HIGH <= 0.50 & platelet_first_HIGH <= 0.50 & map_1st_LOW <= 0.50 & mal_flg > 0.50 & tco2_first_LOW > 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & hr_1st_HIGH > 0.50 & sapsi_first_HIGH <= 0.50 & platelet_first_HIGH <= 0.50 & map_1st_LOW > 0.50 & chloride_first_NORMAL <= 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & hr_1st_

	 Black box accuracy: 0.8596
	 Model's accuracy: 0.8596
	 Class 1 - Global explanation: "(stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & platelet_first_HIGH <= 0.50 & chloride_first_NORMAL <= 0.50 & iv_day_1_LOW <= 0.50 & hr_1st_NORMAL <= 0.50 & hour_icu_intime_NORMAL > 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & platelet_first_HIGH <= 0.50 & chloride_first_NORMAL <= 0.50 & iv_day_1_LOW > 0.50 & afib_flg > 0.50 & day_icu_intime_num_LOW > 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & platelet_first_HIGH <= 0.50 & chloride_first_NORMAL > 0.50 & sofa_first_HIGH <= 0.50 & afib_flg <= 0.50 & hr_1st_HIGH > 0.50 & mal_flg > 0.50 & tco2_first_LOW > 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & platelet_first_HIGH <= 0.50 & chloride_first_NORMAL > 0.50 & sofa_first_HIGH <= 0.50 & afib_flg > 0.50 & hour_icu_intime_NORMAL > 0.50 & tco2_first_NORMAL <= 0.50 & copd_flg <= 0.50) | (stroke_

	 Black box accuracy: 0.8079
	 Model's accuracy: 0.8079
	 Class 1 - Global explanation: "(stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & afib_flg <= 0.50 & sofa_first_HIGH <= 0.50 & bun_first_NORMAL <= 0.50 & chloride_first_LOW > 0.50 & map_1st_NORMAL <= 0.50 & tco2_first_NORMAL <= 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & afib_flg <= 0.50 & sofa_first_HIGH <= 0.50 & bun_first_NORMAL > 0.50 & cad_flg > 0.50 & sodium_first_HIGH > 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & afib_flg <= 0.50 & sofa_first_HIGH > 0.50 & hour_icu_intime_NORMAL > 0.50 & resp_flg > 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & afib_flg > 0.50 & resp_flg > 0.50 & chloride_first_NORMAL <= 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & afib_flg > 0.50 & resp_flg > 0.50 & chloride_first_NORMAL > 0.50 & day_icu_intime_num_NORMAL > 0.50 & gender_num > 0.50 & age_LOW <= 0.50

	 Black box accuracy: 0.8701
	 Model's accuracy: 0.8814
	 Class 1 - Global explanation: "(stroke_flg <= 0.50 & afib_flg <= 0.50 & sapsi_first_HIGH <= 0.50 & bun_first_LOW <= 0.50 & age_NORMAL <= 0.50 & sofa_first_LOW <= 0.50 & creatinine_first_NORMAL <= 0.50 & sapsi_first_LOW <= 0.50 & chloride_first_HIGH <= 0.50 & resp_flg <= 0.50 & po2_first_NORMAL <= 0.50) | (stroke_flg <= 0.50 & afib_flg <= 0.50 & sapsi_first_HIGH <= 0.50 & bun_first_LOW <= 0.50 & age_NORMAL <= 0.50 & sofa_first_LOW <= 0.50 & creatinine_first_NORMAL <= 0.50 & sapsi_first_LOW <= 0.50 & chloride_first_HIGH <= 0.50 & resp_flg > 0.50) | (stroke_flg <= 0.50 & afib_flg <= 0.50 & sapsi_first_HIGH <= 0.50 & bun_first_LOW > 0.50 & tco2_first_HIGH <= 0.50 & platelet_first_HIGH <= 0.50 & sofa_first_HIGH <= 0.50 & age_HIGH > 0.50 & hr_1st_HIGH <= 0.50 & mal_flg > 0.50 & resp_flg > 0.50 & tco2_first_LOW <= 0.50 & hr_1st_NORMAL <= 0.50 & gender_num <= 0.50) | (stroke_flg <= 0.50 & afib_flg <= 0.50 & sapsi_first_HIGH <= 0.50 & bu

	 Black box accuracy: 0.8305
	 Model's accuracy: 0.8588
	 Class 1 - Global explanation: "(stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & afib_flg <= 0.50 & sofa_first_HIGH <= 0.50 & mal_flg > 0.50 & hour_icu_intime_NORMAL > 0.50 & hr_1st_LOW > 0.50 & aline_flg <= 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & afib_flg <= 0.50 & sofa_first_HIGH > 0.50 & day_icu_intime_num_HIGH > 0.50 & map_1st_NORMAL <= 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & afib_flg > 0.50 & bun_first_LOW <= 0.50 & day_icu_intime_num_NORMAL > 0.50 & sodium_first_NORMAL > 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 & afib_flg > 0.50 & bun_first_LOW > 0.50 & day_icu_intime_num_LOW > 0.50 & chloride_first_NORMAL <= 0.50) | (stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi_first_HIGH > 0.50 & hr_1st_HIGH > 0.50 & hour_icu_intime_LOW <= 0.50) | (stroke_flg <= 0.50 & age_HIGH > 0.50 & afib_flg <= 0.50 & sofa_first

	 Black box accuracy: 0.8531
	 Model's accuracy: 0.8418
	 Class 1 - Global explanation: "(age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi_first_HIGH <= 0.50 & afib_flg <= 0.50 & iv_day_1_NORMAL <= 0.50 & sofa_first_HIGH > 0.50 & map_1st_LOW > 0.50 & hr_1st_HIGH > 0.50) | (age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi_first_HIGH <= 0.50 & afib_flg <= 0.50 & iv_day_1_NORMAL > 0.50 & hr_1st_LOW <= 0.50 & hr_1st_HIGH > 0.50 & potassium_first_NORMAL > 0.50) | (age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi_first_HIGH <= 0.50 & afib_flg <= 0.50 & iv_day_1_NORMAL > 0.50 & hr_1st_LOW > 0.50 & platelet_first_NORMAL <= 0.50) | (age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi_first_HIGH <= 0.50 & afib_flg > 0.50 & resp_flg > 0.50 & chloride_first_NORMAL <= 0.50 & day_icu_intime_num_HIGH <= 0.50) | (age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi_first_HIGH <= 0.50 & afib_flg > 0.50 & resp_flg > 0.50 & chloride_first_NORMAL > 0.50 & day_icu_intime_num_NORMAL > 0.50 & bun_first_LOW <= 0.50) | (age_HIGH <=

Unnamed: 0,method,split,explanation,model_accuracy,explanation_accuracy,explanation_fidelity,explanation_complexity,explanation_consistency,elapsed_time
0,tree,0,(stroke_flg <= 0.50 & afib_flg <= 0.50 & sapsi...,0.865169,0.859551,0.91573,432,0.603597,0.009974
1,tree,1,(stroke_flg <= 0.50 & afib_flg <= 0.50 & age_H...,0.865169,0.859551,0.926966,432,0.603597,0.007979
2,tree,2,(stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi...,0.876404,0.859551,0.91573,361,0.603597,0.008977
3,tree,3,(stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi...,0.859551,0.865169,0.949438,418,0.603597,0.006981
4,tree,4,(stroke_flg <= 0.50 & age_HIGH <= 0.50 & hr_1s...,0.825843,0.859551,0.932584,513,0.603597,0.008976
5,tree,5,(stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi...,0.859551,0.859551,0.94382,505,0.603597,0.010001
6,tree,6,(stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi...,0.80791,0.80791,0.920904,477,0.603597,0.008976
7,tree,7,(stroke_flg <= 0.50 & afib_flg <= 0.50 & sapsi...,0.870056,0.881356,0.943503,370,0.603597,0.008985
8,tree,8,(stroke_flg <= 0.50 & age_HIGH <= 0.50 & sapsi...,0.830508,0.858757,0.903955,413,0.603597,0.008977
9,tree,9,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi...,0.853107,0.841808,0.954802,400,0.603597,0.008976


In [7]:
# Run the 'tree-3' method over every split (pruning and ReLU both disabled
# for this method) and display its per-split results table.
method, need_pruning, relu = 'tree-3', False, False
results_tree = c_to_y(method, need_pruning, relu, verbose=True)
results_tree

Split [1/10]
	 Black box accuracy: 0.8652
	 Model's accuracy: 0.8596
	 Class 1 - Global explanation: "(stroke_flg > 0.50 & age_HIGH > 0.50 & sapsi_first_LOW <= 0.50)" - Accuracy: 0.8596
	 Elapsed time 0.004986763000488281
	 Fidelity: "0.9494" - Complexity: "3"
Split [2/10]
	 Black box accuracy: 0.8652
	 Model's accuracy: 0.8764
	 Class 1 - Global explanation: "(stroke_flg > 0.50 & age_HIGH > 0.50 & weight_first_LOW > 0.50)" - Accuracy: 0.8764
	 Elapsed time 0.00594639778137207
	 Fidelity: "0.9326" - Complexity: "3"
Split [3/10]
	 Black box accuracy: 0.8764
	 Model's accuracy: 0.8764
	 Class 1 - Global explanation: "(stroke_flg > 0.50 & age_HIGH <= 0.50 & bun_first_LOW <= 0.50) | (stroke_flg > 0.50 & age_HIGH > 0.50 & sapsi_first_LOW <= 0.50)" - Accuracy: 0.8764
	 Elapsed time 0.007995367050170898
	 Fidelity: "0.9326" - Complexity: "6"
Split [4/10]
	 Black box accuracy: 0.8596
	 Model's accuracy: 0.8652
	 Class 1 - Global explanation: "(stroke_flg > 0.50 & age_HIGH > 0.50 & sapsi_first_

Unnamed: 0,method,split,explanation,model_accuracy,explanation_accuracy,explanation_fidelity,explanation_complexity,explanation_consistency,elapsed_time
0,tree-3,0,(stroke_flg > 0.50 & age_HIGH > 0.50 & sapsi_f...,0.865169,0.859551,0.949438,3,0.285714,0.004987
1,tree-3,1,(stroke_flg > 0.50 & age_HIGH > 0.50 & weight_...,0.865169,0.876404,0.932584,3,0.285714,0.005946
2,tree-3,2,(stroke_flg > 0.50 & age_HIGH <= 0.50 & bun_fi...,0.876404,0.876404,0.932584,6,0.285714,0.007995
3,tree-3,3,(stroke_flg > 0.50 & age_HIGH > 0.50 & sapsi_f...,0.859551,0.865169,0.960674,3,0.285714,0.004992
4,tree-3,4,(stroke_flg > 0.50 & age_HIGH > 0.50 & po2_fir...,0.825843,0.859551,0.94382,6,0.285714,0.006982
5,tree-3,5,(stroke_flg <= 0.50 & age_HIGH > 0.50 & sodium...,0.859551,0.865169,0.949438,6,0.285714,0.005984
6,tree-3,6,(stroke_flg > 0.50 & age_HIGH > 0.50 & sapsi_f...,0.80791,0.841808,0.920904,3,0.285714,0.006984
7,tree-3,7,(stroke_flg > 0.50 & age_HIGH <= 0.50 & liver_...,0.870056,0.870056,0.977401,9,0.285714,0.005995
8,tree-3,8,(stroke_flg > 0.50 & age_HIGH > 0.50 & sapsi_f...,0.830508,0.887006,0.920904,3,0.285714,0.005983
9,tree-3,9,(age_HIGH > 0.50 & stroke_flg <= 0.50 & afib_f...,0.853107,0.870056,0.960452,6,0.285714,0.005975


# BRL (Bayesian Rule Lists)

In [12]:
# Run the BRL baseline over every split (pruning and ReLU both disabled for
# this method). Slow: roughly 550 s per split in the recorded run below.
method, need_pruning, relu = 'brl', False, False
results_brl = c_to_y(method, need_pruning, relu, verbose=True)
results_brl

Split [1/10]
	 Black box accuracy: 0.8652
Labels (1598, 1)
Discretized features
Completed model 1/1!
Train_acc: 0.9, Val_acc: 0.0
	 Model's accuracy: 0.8596
	 Class 1 - Global explanation: "(stroke_flg & age_HIGH & sodium_first_HIGH) | (stroke_flg & age_HIGH & ~sapsi_first_LOW) | (stroke_flg & sodium_first_HIGH & ~hgb_first_HIGH) | (stroke_flg & ~hgb_first_HIGH & ~sapsi_first_LOW)" - Accuracy: 0.8652
	 Elapsed time 551.2550559043884
	 Fidelity: "0.9326" - Complexity: "12"
Split [2/10]
	 Black box accuracy: 0.8652
Labels (1598, 1)
Discretized features
Completed model 1/1!
Train_acc: 0.9, Val_acc: 0.0
	 Model's accuracy: 0.8652
	 Class 1 - Global explanation: "(stroke_flg & age_HIGH & wbc_first_HIGH & ~sapsi_first_LOW) | (stroke_flg & age_HIGH & weight_first_LOW & ~sapsi_first_LOW) | (stroke_flg & wbc_first_HIGH & ~hgb_first_HIGH & ~sapsi_first_LOW) | (stroke_flg & weight_first_LOW & ~hgb_first_HIGH & ~sapsi_first_LOW)" - Accuracy: 0.8652
	 Elapsed time 550.8910205364227
	 Fidelity: "0.9

Unnamed: 0,method,split,explanation,model_accuracy,explanation_accuracy,explanation_fidelity,explanation_complexity,explanation_consistency,elapsed_time
0,brl,0,(stroke_flg & age_HIGH & sodium_first_HIGH) | ...,0.865169,0.865169,0.932584,12,0.222222,551.255056
1,brl,1,(stroke_flg & age_HIGH & wbc_first_HIGH & ~sap...,0.865169,0.865169,0.932584,16,0.222222,550.891021
2,brl,2,(renal_flg & stroke_flg & ~abg_count_LOW) | (r...,0.876404,0.876404,0.932584,28,0.222222,545.325771
3,brl,3,(stroke_flg & age_HIGH & weight_first_LOW & ~s...,0.859551,0.865169,0.960674,8,0.222222,526.027836
4,brl,4,(stroke_flg & hgb_first_NORMAL & ~sapsi_first_...,0.825843,0.853933,0.949438,15,0.222222,540.576839
5,brl,5,(chf_flg & stroke_flg & age_HIGH & ~sodium_fir...,0.859551,0.865169,0.949438,16,0.222222,559.686333
6,brl,6,(afib_flg & resp_flg & ~(sapsi_first_LOW & ~tc...,0.80791,0.824859,0.937853,39,0.222222,536.183713
7,brl,7,(stroke_flg & age_HIGH & ~sodium_first_LOW) | ...,0.870056,0.864407,0.983051,13,0.222222,548.4927
8,brl,8,(stroke_flg & ~sapsi_first_LOW & ~(hr_1st_NORM...,0.830508,0.847458,0.949153,41,0.222222,527.914999
9,brl,9,(creatinine_first_LOW & sapsi_first_HIGH & ~(a...,0.853107,0.841808,0.954802,41,0.222222,547.931929


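# Summary

Mean and standard error (SEM) of each metric across the 10 data splits, one row per method, aggregated from the per-method results CSV files found under `results_dir`.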

In [8]:
import glob

# Metrics to aggregate across the splits of every method.
cols = ['model_accuracy', 'explanation_accuracy', 'explanation_fidelity',
        'explanation_complexity', 'elapsed_time', 'explanation_consistency']
mean_cols = [f'{c}_mean' for c in cols]
sem_cols = [f'{c}_sem' for c in cols]

# Every method writes a results*.csv (one row per split) somewhere under
# results_dir. Sort the paths: glob's order is filesystem-dependent, and a
# sorted list keeps the row order of the summary stable across runs/OSes.
files = sorted(glob.glob(os.path.join(results_dir, '**', 'results*.csv'), recursive=True))
if not files:
    raise FileNotFoundError(f'No results*.csv files found under {results_dir!r}')

df_list = []
for file in files:
    df = pd.read_csv(file, index_col=0)
    # All rows of a results file share the same method name; take the first
    # one positionally rather than relying on the index containing label 0.
    method = df['method'].iloc[0]
    # .mean()/.sem() return Series indexed by metric name. A Series has no
    # `.columns`, so suffix the index labels to say which statistic each is.
    df_mean = df[cols].mean().add_suffix('_mean')
    df_sem = df[cols].sem().add_suffix('_sem')
    method_summary = pd.concat([df_mean, df_sem])
    method_summary.name = method
    df_list.append(method_summary)

# One row per method; columns are all means followed by all standard errors.
summary = pd.concat(df_list, axis=1).T
summary = summary[mean_cols + sem_cols]
summary

Unnamed: 0,model_accuracy_mean,explanation_accuracy_mean,explanation_fidelity_mean,explanation_complexity_mean,elapsed_time_mean,explanation_consistency_mean,model_accuracy_sem,explanation_accuracy_sem,explanation_fidelity_sem,explanation_complexity_sem,elapsed_time_sem,explanation_consistency_sem
brl,0.851327,0.856954,0.948216,22.9,543.42862,0.222222,0.007047,0.004787,0.005007,4.132392,3.390933,0.0
pruning,0.851327,0.857532,0.922831,7.0,15.301895,0.411765,0.007047,0.005077,0.009045,0.699206,0.190695,0.0
psi,0.851327,0.838993,0.903161,1.3,10.13938,0.433333,0.007047,0.012539,0.012733,0.152753,0.030241,0.0
tree-3,0.851327,0.867117,0.94482,4.8,0.006182,0.285714,0.007047,0.00387,0.005816,0.663325,0.000291,0.0
tree,0.851327,0.855275,0.930743,432.1,0.00888,0.603597,0.007047,0.006061,0.005318,16.459006,0.000277,0.0
weights,0.851327,0.810693,0.865822,6.4,18.456567,0.355556,0.007047,0.03326,0.040017,0.945163,1.001188,0.0


In [9]:
# Persist the per-method summary table at the top level of results_dir.
summary.to_csv(path_or_buf=os.path.join(results_dir, 'summary.csv'))