In [1]:
import sys
sys.path.append('..')
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder, MinMaxScaler, KBinsDiscretizer
from sklearn.impute import SimpleImputer
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold
import time
from sklearn.metrics import accuracy_score
from sympy import simplify_logic
from sklearn.model_selection import train_test_split

import deep_logic as dl
from deep_logic.utils.base import validate_network, set_seed, tree_to_formula
from deep_logic.utils.layer import prune_logic_layers
from deep_logic import logic

set_seed(0)

#%%

data = pd.read_csv('data/mimic-ii/full_cohort_data.csv')
# data.drop('hgb_first')
fs = [
    'aline_flg',
    'gender_num',
    # 'hosp_exp_flg',
    # 'icu_exp_flg',
    # 'day_28_flg',
    # 'censor_flg',
    'sepsis_flg', 'chf_flg', 'afib_flg',
    'renal_flg', 'liver_flg', 'copd_flg', 'cad_flg', 'stroke_flg',
    'mal_flg', 'resp_flg',
]
features = fs
data1 = data[fs].values
imp_mean = SimpleImputer(missing_values=np.nan, strategy='mean')
data1 = imp_mean.fit_transform(data1)

f2 = fs.copy()
f2.append('day_icu_intime')
f2.append('service_unit')
f2.append('day_28_flg')
f2.append('hospital_los_day')
f2.append('icu_exp_flg')
f2.append('hosp_exp_flg')
f2.append('censor_flg')
f2.append('mort_day_censored')
f2 = data.columns.difference(f2)
data2 = data[f2].values
imp_mean = SimpleImputer(missing_values=np.nan, strategy='mean')
data2 = imp_mean.fit_transform(data2)
scaler = MinMaxScaler((0, 1))
data2 = scaler.fit_transform(data2)
features = features + list(f2)
est = KBinsDiscretizer(n_bins=3, encode='onehot-dense', strategy='uniform')
data2d = est.fit_transform(data2)
f2d = []
for feature in f2:
    #f2d.append(feature + '_VLOW')
    f2d.append(feature + '_LOW')
    f2d.append(feature + '_NORMAL')
    f2d.append(feature + '_HIGH')
    #f2d.append(feature + '_VHIGH')
features = fs + f2d

datax = np.hstack((data1, data2d))
datay = data['day_28_flg'].values

x = torch.FloatTensor(datax)
y = torch.FloatTensor(datay)
print(x.shape)
print(y.shape)

torch.Size([1776, 90])
torch.Size([1776])


In [2]:
results_dir = 'results_ll/mimic'
if not os.path.isdir(results_dir):
    os.makedirs(results_dir)


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# concepts = [f'c{i:03}' for i in range(x.shape[1])]
concepts =  features
n_rep = 10
tot_epochs = 5001
prune_epochs = 2001
seed = 42

n_splits = 10
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)

In [3]:
def train_nn(x_train, y_train, seed, device, l1=0.001, lr=0.001, verbose=False):
    set_seed(seed)
    x_train = x_train.to(device)
    y_train = y_train.to(device)
    layers = [
        dl.nn.XLogic(x_train.size(1), 50, first=True),
        torch.nn.LeakyReLU(),
        torch.nn.Linear(50, 30),
        torch.nn.LeakyReLU(),
        torch.nn.Linear(30, 1),
        torch.nn.LeakyReLU(),
        dl.nn.XLogic(1, 1, top=True),
    ]
    model = torch.nn.Sequential(*layers).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_form = torch.nn.BCELoss()
    model.train()
    need_pruning = True
    for epoch in range(tot_epochs):
        optimizer.zero_grad()
        y_pred = model(x_train).squeeze()
        loss = loss_form(y_pred, y_train)

        for module in model.children():
            if isinstance(module, dl.nn.XLogic):
                loss += l1 * torch.norm(module.weight, 1)
                loss += l1 * torch.norm(module.bias, 1)

        loss.backward()
        optimizer.step()

        if epoch > prune_epochs and need_pruning:
            dl.utils.layer.prune_logic_layers(model.to(device), fan_in=5, device=device)
            need_pruning = False

        # compute accuracy
        if epoch % 100 == 0 and verbose:
            y_pred_d = y_pred > 0.5
            accuracy = y_pred_d.eq(y_train).sum().item() / y_train.size(0)
            print(f'Epoch {epoch}: train accuracy: {accuracy:.4f}')
    
    return model

In [4]:
def c_to_y(method, verbose=False):
    methods = []
    splits = []
    explanations = []
    model_accuracies = []
    explanation_accuracies = []
    explanation_fidelities = []
    explanation_complexities = []
    elapsed_times = []
    for split, (trainval_index, test_index) in enumerate(skf.split(x.cpu().detach().numpy(), y.cpu().detach().numpy())):
        print(f'Split [{split+1}/{n_splits}]')
        x_trainval, x_test = torch.FloatTensor(x[trainval_index]), torch.FloatTensor(x[test_index])
        y_trainval, y_test = torch.FloatTensor(y[trainval_index]), torch.FloatTensor(y[test_index])
        x_train, x_val, y_train, y_val = train_test_split(x_trainval, y_trainval, test_size=0.3, random_state=42)
    
        explanation, explanation_inv = '', ''
        explanation_accuracy, explanation_accuracy_inv = 0, 0
        
        if method == 'tree':
            classifier = DecisionTreeClassifier(random_state=seed)
            classifier.fit(x_train.detach().numpy(), y_train.detach().numpy())
            y_preds = classifier.predict(x_test.detach().numpy())
            model_accuracy = accuracy_score(y_test.detach().numpy(), y_preds)
            explanation_accuracy = model_accuracy

            target_class = 1
            start = time.time()
            explanation = tree_to_formula(classifier, concepts, target_class)
            elapsed_time = time.time() - start
            explanation_fidelity = 1.
            explanation_complexity = dl.logic.complexity(explanation)

            target_class_inv = 0
            start = time.time()
            explanation_inv = tree_to_formula(classifier, concepts, target_class_inv)
            elapsed_time = time.time() - start
        
        else:
            model = train_nn(x_trainval, y_trainval, seed, device, verbose=False)
            y_preds = model(x_test.to(device)).cpu().detach().numpy() > 0.5
            model_accuracy = accuracy_score(y_test.cpu().detach().numpy(), y_preds)

            # positive class
            start = time.time()
            class_explanation, class_explanations = dl.logic.explain_class(model.cpu(), 
                                                                           x_trainval.cpu(), 
                                                                           y_trainval.cpu(), 
                                                                           binary=True, target_class=1,
                                                                           topk_explanations=3)
            elapsed_time = time.time() - start
            
            if class_explanation:
                explanation = logic.base.replace_names(class_explanation, concepts)
                explanation_accuracy, y_formula = logic.base.test_explanation(class_explanation, 
                                                                              target_class=1, 
                                                                              x=x_test, y=y_test,
                                                                              metric=accuracy_score)
                explanation_fidelity = dl.logic.fidelity(y_formula, y_preds)
                explanation_complexity = dl.logic.complexity(class_explanation)
        
        if verbose:
            print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
            print(f'\t Class 1 - Global explanation: "{explanation}" - Accuracy: {explanation_accuracy:.4f}')
            print(f'\t Fidelity: "{explanation_fidelity:.4f}" - Complexity: "{explanation_complexity}"')
            print(f'\t Elapsed time {elapsed_time}')

        methods.append(method)
        splits.append(seed)
        explanations.append(explanation)
        model_accuracies.append(model_accuracy)
        explanation_accuracies.append(explanation_accuracy)
        explanation_fidelities.append(explanation_fidelity)
        explanation_complexities.append(explanation_complexity)
        elapsed_times.append(elapsed_time)
    
    explanation_consistency = dl.logic.formula_consistency(explanations)
    print(f'Consistency of explanations: {explanation_consistency:.4f}')
    
    results = pd.DataFrame({
        'method': methods,
        'split': splits,
        'explanation': explanations,
        'model_accuracy': model_accuracies,
        'explanation_accuracy': explanation_accuracies,
        'explanation_fidelity': explanation_fidelities,
        'explanation_complexity': explanation_complexities,
        'explanation_consistency': explanation_consistency,
        'elapsed_time': elapsed_times,
    })
    results.to_csv(os.path.join(results_dir, f'results_{method}.csv'))
    
    return results

# General pruning

In [5]:
results_pruning = c_to_y(method='logic_layer', verbose=True)
results_pruning

Split [1/10]
	 Model's accuracy: 0.8596
	 Class 1 - Global explanation: "stroke_flg & age_HIGH & ~sapsi_first_LOW" - Accuracy: 0.8596
	 Fidelity: "1.0000" - Complexity: "3"
	 Elapsed time 0.3629953861236572
Split [2/10]
	 Model's accuracy: 0.8820
	 Class 1 - Global explanation: "(stroke_flg & abg_count_LOW & age_HIGH & ~age_LOW & ~sapsi_first_LOW) | (age_HIGH & ~stroke_flg & ~abg_count_LOW & ~age_LOW & ~sapsi_first_LOW)" - Accuracy: 0.8820
	 Fidelity: "1.0000" - Complexity: "10"
	 Elapsed time 0.3061814308166504
Split [3/10]
	 Model's accuracy: 0.8764
	 Class 1 - Global explanation: "stroke_flg & age_HIGH & ~sapsi_first_LOW" - Accuracy: 0.8764
	 Fidelity: "1.0000" - Complexity: "3"
	 Elapsed time 0.26928281784057617
Split [4/10]
	 Model's accuracy: 0.8652
	 Class 1 - Global explanation: "(stroke_flg & abg_count_LOW & age_HIGH & ~age_LOW & ~sapsi_first_LOW) | (age_HIGH & ~stroke_flg & ~abg_count_LOW & ~age_LOW & ~sapsi_first_LOW)" - Accuracy: 0.8652
	 Fidelity: "1.0000" - Complexity: "1

Unnamed: 0,method,split,explanation,model_accuracy,explanation_accuracy,explanation_fidelity,explanation_complexity,explanation_consistency,elapsed_time
0,logic_layer,42,stroke_flg & age_HIGH & ~sapsi_first_LOW,0.859551,0.859551,1.0,3,0.6,0.362995
1,logic_layer,42,(stroke_flg & abg_count_LOW & age_HIGH & ~age_...,0.882022,0.882022,1.0,10,0.6,0.306181
2,logic_layer,42,stroke_flg & age_HIGH & ~sapsi_first_LOW,0.876404,0.876404,1.0,3,0.6,0.269283
3,logic_layer,42,(stroke_flg & abg_count_LOW & age_HIGH & ~age_...,0.865169,0.865169,1.0,10,0.6,0.312477
4,logic_layer,42,stroke_flg & age_HIGH & ~sapsi_first_LOW,0.859551,0.859551,1.0,3,0.6,0.273816
5,logic_layer,42,stroke_flg & age_HIGH & ~sapsi_first_LOW,0.870787,0.870787,1.0,3,0.6,0.370012
6,logic_layer,42,stroke_flg & age_HIGH & ~sapsi_first_LOW,0.841808,0.841808,1.0,3,0.6,0.256406
7,logic_layer,42,(stroke_flg & age_HIGH & ~age_LOW & ~sapsi_fir...,0.870056,0.870056,1.0,8,0.6,0.337131
8,logic_layer,42,stroke_flg & age_HIGH & ~sapsi_first_LOW,0.887006,0.887006,1.0,3,0.6,0.307857
9,logic_layer,42,stroke_flg & age_HIGH & ~sapsi_first_LOW,0.870056,0.870056,1.0,3,0.6,0.285181


# Decision tree

In [6]:
results_tree = c_to_y(method='tree', verbose=False)
results_tree

Split [1/10]
Split [2/10]
Split [3/10]
Split [4/10]
Split [5/10]
Split [6/10]
Split [7/10]
Split [8/10]
Split [9/10]
Split [10/10]
Consistency of explanations: 0.7510


Unnamed: 0,method,split,explanation,model_accuracy,explanation_accuracy,explanation_fidelity,explanation_complexity,explanation_consistency,elapsed_time
0,tree,42,(age_HIGH <= 0.50 & resp_flg <= 0.50 & stroke_...,0.752809,0.752809,1.0,838,0.750955,0.002992
1,tree,42,(stroke_flg <= 0.50 & age_HIGH <= 0.50 & copd_...,0.775281,0.775281,1.0,994,0.750955,0.004987
2,tree,42,(stroke_flg <= 0.50 & age_HIGH <= 0.50 & resp_...,0.808989,0.808989,1.0,993,0.750955,0.00299
3,tree,42,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi...,0.758427,0.758427,1.0,1002,0.750955,0.002991
4,tree,42,(age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 &...,0.735955,0.735955,1.0,927,0.750955,0.004987
5,tree,42,(stroke_flg <= 0.50 & age_HIGH <= 0.50 & mal_f...,0.792135,0.792135,1.0,903,0.750955,0.004003
6,tree,42,(stroke_flg <= 0.50 & age_HIGH <= 0.50 & mal_f...,0.774011,0.774011,1.0,864,0.750955,0.002991
7,tree,42,(stroke_flg <= 0.50 & age_HIGH <= 0.50 & copd_...,0.79096,0.79096,1.0,987,0.750955,0.004986
8,tree,42,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & resp_...,0.779661,0.779661,1.0,811,0.750955,0.002992
9,tree,42,(age_HIGH <= 0.50 & age_NORMAL <= 0.50 & chlor...,0.785311,0.785311,1.0,898,0.750955,0.003988


# Summary

In [7]:
cols = ['model_accuracy', 'explanation_accuracy', 'explanation_fidelity', 
        'explanation_complexity', 'elapsed_time', 'explanation_consistency']
mean_cols = [f'{c}_mean' for c in cols]
sem_cols = [f'{c}_sem' for c in cols]

# pruning
df_mean = results_pruning[cols].mean()
df_sem = results_pruning[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_pruning = pd.concat([df_mean, df_sem])
summary_pruning.name = 'pruning'

# tree
df_mean = results_tree[cols].mean()
df_sem = results_tree[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_tree = pd.concat([df_mean, df_sem])
summary_tree.name = 'tree'

summary = pd.concat([summary_pruning,
                     summary_tree], axis=1).T
summary.columns = mean_cols + sem_cols
summary.to_csv(os.path.join(results_dir, 'summary.csv'))

summary

Unnamed: 0,model_accuracy_mean,explanation_accuracy_mean,explanation_fidelity_mean,explanation_complexity_mean,elapsed_time_mean,explanation_consistency_mean,model_accuracy_sem,explanation_accuracy_sem,explanation_fidelity_sem,explanation_complexity_sem,elapsed_time_sem,explanation_consistency_sem
pruning,0.868241,0.868241,1.0,4.9,0.308134,0.6,0.004056,0.004056,0.0,0.982627,0.012289,3.700743e-17
tree,0.775354,0.775354,1.0,921.7,0.003791,0.750955,0.006768,0.006768,0.0,22.24912,0.00029,0.0
