In [1]:
import sys
sys.path.append('..')
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder, MinMaxScaler, KBinsDiscretizer
from sklearn.impute import SimpleImputer
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold
import time
from sklearn.metrics import accuracy_score
from sympy import simplify_logic
from sklearn.model_selection import train_test_split

import deep_logic as dl
from deep_logic.utils.base import validate_network, set_seed, tree_to_formula
from deep_logic.utils.layer import prune_logic_layers
from deep_logic import logic

set_seed(0)

#%%

data = pd.read_csv('data/mimic-ii/full_cohort_data.csv')
# data.drop('hgb_first')
fs = [
    'aline_flg',
    'gender_num',
    # 'hosp_exp_flg',
    # 'icu_exp_flg',
    # 'day_28_flg',
    # 'censor_flg',
    'sepsis_flg', 'chf_flg', 'afib_flg',
    'renal_flg', 'liver_flg', 'copd_flg', 'cad_flg', 'stroke_flg',
    'mal_flg', 'resp_flg',
]
features = fs
data1 = data[fs].values
imp_mean = SimpleImputer(missing_values=np.nan, strategy='mean')
data1 = imp_mean.fit_transform(data1)

f2 = fs.copy()
f2.append('day_icu_intime')
f2.append('service_unit')
f2.append('day_28_flg')
f2.append('hospital_los_day')
f2.append('icu_exp_flg')
f2.append('hosp_exp_flg')
f2.append('censor_flg')
f2.append('mort_day_censored')
f2 = data.columns.difference(f2)
data2 = data[f2].values
imp_mean = SimpleImputer(missing_values=np.nan, strategy='mean')
data2 = imp_mean.fit_transform(data2)
scaler = MinMaxScaler((0, 1))
data2 = scaler.fit_transform(data2)
features = features + list(f2)
est = KBinsDiscretizer(n_bins=3, encode='onehot-dense', strategy='uniform')
data2d = est.fit_transform(data2)
f2d = []
for feature in f2:
    #f2d.append(feature + '_VLOW')
    f2d.append(feature + '_LOW')
    f2d.append(feature + '_NORMAL')
    f2d.append(feature + '_HIGH')
    #f2d.append(feature + '_VHIGH')
features = fs + f2d

datax = np.hstack((data1, data2d))
datay = data['day_28_flg'].values

x = torch.FloatTensor(datax)
y = torch.FloatTensor(datay)
print(x.shape)
print(y.shape)

torch.Size([1776, 90])
torch.Size([1776])


In [2]:
results_dir = 'results_ll/mimic'
if not os.path.isdir(results_dir):
    os.makedirs(results_dir)


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# concepts = [f'c{i:03}' for i in range(x.shape[1])]
concepts =  features
n_rep = 10
tot_epochs = 5001
prune_epochs = 2001
seed = 42

n_splits = 10
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)

In [3]:
def train_nn(x_train, y_train, seed, device, l1=0.001, lr=0.001, verbose=False):
    set_seed(seed)
    x_train = x_train.to(device)
    y_train = y_train.to(device)
    layers = [
        dl.nn.XLogic(x_train.size(1), activation='identity', first=True),
        torch.nn.Linear(x_train.size(1), 50),
        torch.nn.LeakyReLU(),
        torch.nn.Linear(50, 30),
        torch.nn.LeakyReLU(),
        torch.nn.Linear(30, 1),
        torch.nn.LeakyReLU(),
        dl.nn.XLogic(1, activation='sigmoid', top=True),
    ]
    model = torch.nn.Sequential(*layers).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_form = torch.nn.BCELoss()
    model.train()
    need_pruning = True
    for epoch in range(tot_epochs):
        optimizer.zero_grad()
        y_pred = model(x_train).squeeze()
        loss = loss_form(y_pred, y_train)

        for module in model.children():
            if isinstance(module, dl.nn.XLogic):
                loss += l1 * torch.norm(module.weight, 1)
                loss += l1 * torch.norm(module.bias, 1)

        loss.backward()
        optimizer.step()

        if epoch > prune_epochs and need_pruning:
            dl.utils.layer.prune_logic_layers(model.to(device), fan_in=5, device=device)
            need_pruning = False

        # compute accuracy
        if epoch % 100 == 0 and verbose:
            y_pred_d = y_pred > 0.5
            accuracy = y_pred_d.eq(y_train).sum().item() / y_train.size(0)
            print(f'Epoch {epoch}: train accuracy: {accuracy:.4f}')
    
    return model

In [4]:
def c_to_y(method, verbose=False):
    methods = []
    splits = []
    explanations = []
    model_accuracies = []
    explanation_accuracies = []
    explanation_fidelities = []
    explanation_complexities = []
    elapsed_times = []
    for split, (trainval_index, test_index) in enumerate(skf.split(x.cpu().detach().numpy(), y.cpu().detach().numpy())):
        print(f'Split [{split+1}/{n_splits}]')
        x_trainval, x_test = torch.FloatTensor(x[trainval_index]), torch.FloatTensor(x[test_index])
        y_trainval, y_test = torch.FloatTensor(y[trainval_index]), torch.FloatTensor(y[test_index])
        x_train, x_val, y_train, y_val = train_test_split(x_trainval, y_trainval, test_size=0.3, random_state=42)
    
        explanation, explanation_inv = '', ''
        explanation_accuracy, explanation_accuracy_inv = 0, 0
        
        if method == 'tree':
            classifier = DecisionTreeClassifier(random_state=seed)
            classifier.fit(x_train.detach().numpy(), y_train.detach().numpy())
            y_preds = classifier.predict(x_test.detach().numpy())
            model_accuracy = accuracy_score(y_test.detach().numpy(), y_preds)
            explanation_accuracy = model_accuracy

            target_class = 1
            start = time.time()
            explanation = tree_to_formula(classifier, concepts, target_class)
            elapsed_time = time.time() - start
            explanation_fidelity = 1.
            explanation_complexity = dl.logic.complexity(explanation)

            target_class_inv = 0
            start = time.time()
            explanation_inv = tree_to_formula(classifier, concepts, target_class_inv)
            elapsed_time = time.time() - start
        
        else:
            model = train_nn(x_trainval, y_trainval, seed, device, verbose=verbose)#=False)
            y_preds = model(x_test.to(device)).cpu().detach().numpy() > 0.5
            model_accuracy = accuracy_score(y_test.cpu().detach().numpy(), y_preds)

            # positive class
            start = time.time()
            class_explanation, class_explanations = dl.logic.explain_class(model.cpu(), 
                                                                           x_trainval.cpu(), 
                                                                           y_trainval.cpu(), 
                                                                           binary=True, target_class=0,
                                                                           topk_explanations=4)
            elapsed_time = time.time() - start
            
            if class_explanation:
                explanation = logic.base.replace_names(class_explanation, concepts)
                explanation_accuracy, y_formula = logic.base.test_explanation(class_explanation, 
                                                                              target_class=0, 
                                                                              x=x_test, y=y_test,
                                                                              metric=accuracy_score)
                explanation_fidelity = dl.logic.fidelity(y_formula, 1-y_preds)
                explanation_complexity = dl.logic.complexity(class_explanation)
        
        if verbose:
            print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
            print(f'\t Class 1 - Global explanation: "{explanation}" - Accuracy: {explanation_accuracy:.4f}')
            print(f'\t Fidelity: "{explanation_fidelity:.4f}" - Complexity: "{explanation_complexity}"')
            print(f'\t Elapsed time {elapsed_time}')

        methods.append(method)
        splits.append(seed)
        explanations.append(explanation)
        model_accuracies.append(model_accuracy)
        explanation_accuracies.append(explanation_accuracy)
        explanation_fidelities.append(explanation_fidelity)
        explanation_complexities.append(explanation_complexity)
        elapsed_times.append(elapsed_time)
    
    explanation_consistency = dl.logic.formula_consistency(explanations)
    print(f'Consistency of explanations: {explanation_consistency:.4f}')
    
    results = pd.DataFrame({
        'method': methods,
        'split': splits,
        'explanation': explanations,
        'model_accuracy': model_accuracies,
        'explanation_accuracy': explanation_accuracies,
        'explanation_fidelity': explanation_fidelities,
        'explanation_complexity': explanation_complexities,
        'explanation_consistency': explanation_consistency,
        'elapsed_time': elapsed_times,
    })
    results.to_csv(os.path.join(results_dir, f'results_{method}.csv'))
    
    return results

# General pruning

In [5]:
results_pruning = c_to_y(method='logic_layer', verbose=True)
results_pruning

Split [1/10]
Epoch 0: train accuracy: 0.1589
Epoch 100: train accuracy: 0.8411
Epoch 200: train accuracy: 0.8411
Epoch 300: train accuracy: 0.8411
Epoch 400: train accuracy: 0.8467
Epoch 500: train accuracy: 0.8561
Epoch 600: train accuracy: 0.8586
Epoch 700: train accuracy: 0.8648
Epoch 800: train accuracy: 0.8730
Epoch 900: train accuracy: 0.8748
Epoch 1000: train accuracy: 0.8780
Epoch 1100: train accuracy: 0.8767
Epoch 1200: train accuracy: 0.8748
Epoch 1300: train accuracy: 0.8742
Epoch 1400: train accuracy: 0.8736
Epoch 1500: train accuracy: 0.8748
Epoch 1600: train accuracy: 0.8730
Epoch 1700: train accuracy: 0.8761
Epoch 1800: train accuracy: 0.8736
Epoch 1900: train accuracy: 0.8748
Epoch 2000: train accuracy: 0.8817
Epoch 2100: train accuracy: 0.8911
Epoch 2200: train accuracy: 0.8936
Epoch 2300: train accuracy: 0.8974
Epoch 2400: train accuracy: 0.8974
Epoch 2500: train accuracy: 0.8967
Epoch 2600: train accuracy: 0.8974
Epoch 2700: train accuracy: 0.8974
Epoch 2800: train a

Epoch 100: train accuracy: 0.8404
Epoch 200: train accuracy: 0.8404
Epoch 300: train accuracy: 0.8404
Epoch 400: train accuracy: 0.8461
Epoch 500: train accuracy: 0.8611
Epoch 600: train accuracy: 0.8686
Epoch 700: train accuracy: 0.8748
Epoch 800: train accuracy: 0.8817
Epoch 900: train accuracy: 0.8798
Epoch 1000: train accuracy: 0.8811
Epoch 1100: train accuracy: 0.8842
Epoch 1200: train accuracy: 0.8849
Epoch 1300: train accuracy: 0.8842
Epoch 1400: train accuracy: 0.8849
Epoch 1500: train accuracy: 0.8867
Epoch 1600: train accuracy: 0.8867
Epoch 1700: train accuracy: 0.8867
Epoch 1800: train accuracy: 0.8905
Epoch 1900: train accuracy: 0.8911
Epoch 2000: train accuracy: 0.8930
Epoch 2100: train accuracy: 0.8949
Epoch 2200: train accuracy: 0.8961
Epoch 2300: train accuracy: 0.8986
Epoch 2400: train accuracy: 0.9043
Epoch 2500: train accuracy: 0.9086
Epoch 2600: train accuracy: 0.9155
Epoch 2700: train accuracy: 0.9199
Epoch 2800: train accuracy: 0.9205
Epoch 2900: train accuracy: 0

Epoch 100: train accuracy: 0.8405
Epoch 200: train accuracy: 0.8405
Epoch 300: train accuracy: 0.8405
Epoch 400: train accuracy: 0.8443
Epoch 500: train accuracy: 0.8562
Epoch 600: train accuracy: 0.8618
Epoch 700: train accuracy: 0.8730
Epoch 800: train accuracy: 0.8762
Epoch 900: train accuracy: 0.8774
Epoch 1000: train accuracy: 0.8787
Epoch 1100: train accuracy: 0.8831
Epoch 1200: train accuracy: 0.8856
Epoch 1300: train accuracy: 0.8868
Epoch 1400: train accuracy: 0.8893
Epoch 1500: train accuracy: 0.8887
Epoch 1600: train accuracy: 0.8924
Epoch 1700: train accuracy: 0.8906
Epoch 1800: train accuracy: 0.8924
Epoch 1900: train accuracy: 0.8943
Epoch 2000: train accuracy: 0.8949
Epoch 2100: train accuracy: 0.8962
Epoch 2200: train accuracy: 0.8974
Epoch 2300: train accuracy: 0.8993
Epoch 2400: train accuracy: 0.9006
Epoch 2500: train accuracy: 0.8999
Epoch 2600: train accuracy: 0.9012
Epoch 2700: train accuracy: 0.9018
Epoch 2800: train accuracy: 0.9043
Epoch 2900: train accuracy: 0

Unnamed: 0,method,split,explanation,model_accuracy,explanation_accuracy,explanation_fidelity,explanation_complexity,explanation_consistency,elapsed_time
0,logic_layer,42,(bun_first_LOW & hr_1st_NORMAL & ~chloride_fir...,0.865169,0.808989,0.898876,8,1.0,2.818158
1,logic_layer,42,(bun_first_LOW & hr_1st_NORMAL & ~chloride_fir...,0.831461,0.820225,0.831461,8,1.0,2.718221
2,logic_layer,42,(bun_first_LOW & hr_1st_NORMAL & ~chloride_fir...,0.842697,0.752809,0.842697,8,1.0,2.979811
3,logic_layer,42,(bun_first_LOW & hr_1st_NORMAL & ~chloride_fir...,0.865169,0.808989,0.831461,8,1.0,2.742536
4,logic_layer,42,(bun_first_LOW & hr_1st_NORMAL & ~chloride_fir...,0.825843,0.808989,0.814607,8,1.0,2.926845
5,logic_layer,42,(bun_first_LOW & hr_1st_NORMAL & ~chloride_fir...,0.859551,0.775281,0.769663,8,1.0,2.97202
6,logic_layer,42,(bun_first_LOW & hr_1st_NORMAL & ~chloride_fir...,0.819209,0.785311,0.79661,8,1.0,2.585062
7,logic_layer,42,(bun_first_LOW & hr_1st_NORMAL & ~chloride_fir...,0.847458,0.813559,0.774011,8,1.0,3.086829
8,logic_layer,42,(bun_first_LOW & hr_1st_NORMAL & ~chloride_fir...,0.836158,0.774011,0.870056,8,1.0,2.752565
9,logic_layer,42,(bun_first_LOW & hr_1st_NORMAL & ~chloride_fir...,0.80791,0.80226,0.824859,8,1.0,2.730641


In [6]:
results_pruning = c_to_y(method='tree', verbose=False)
results_pruning

Split [1/10]
Split [2/10]
Split [3/10]
Split [4/10]
Split [5/10]
Split [6/10]
Split [7/10]
Split [8/10]
Split [9/10]
Split [10/10]
Consistency of explanations: 0.7510


Unnamed: 0,method,split,explanation,model_accuracy,explanation_accuracy,explanation_fidelity,explanation_complexity,explanation_consistency,elapsed_time
0,tree,42,(age_HIGH <= 0.50 & resp_flg <= 0.50 & stroke_...,0.752809,0.752809,1.0,838,0.750955,0.002992
1,tree,42,(stroke_flg <= 0.50 & age_HIGH <= 0.50 & copd_...,0.775281,0.775281,1.0,994,0.750955,0.002952
2,tree,42,(stroke_flg <= 0.50 & age_HIGH <= 0.50 & resp_...,0.808989,0.808989,1.0,993,0.750955,0.00297
3,tree,42,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & sapsi...,0.758427,0.758427,1.0,1002,0.750955,0.001994
4,tree,42,(age_HIGH <= 0.50 & sapsi_first_HIGH <= 0.50 &...,0.735955,0.735955,1.0,927,0.750955,0.002992
5,tree,42,(stroke_flg <= 0.50 & age_HIGH <= 0.50 & mal_f...,0.792135,0.792135,1.0,903,0.750955,0.00399
6,tree,42,(stroke_flg <= 0.50 & age_HIGH <= 0.50 & mal_f...,0.774011,0.774011,1.0,864,0.750955,0.003989
7,tree,42,(stroke_flg <= 0.50 & age_HIGH <= 0.50 & copd_...,0.79096,0.79096,1.0,987,0.750955,0.002917
8,tree,42,(age_HIGH <= 0.50 & stroke_flg <= 0.50 & resp_...,0.779661,0.779661,1.0,811,0.750955,0.001996
9,tree,42,(age_HIGH <= 0.50 & age_NORMAL <= 0.50 & chlor...,0.785311,0.785311,1.0,898,0.750955,0.001995
