In [68]:
import sys
sys.path.append('..')
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from sympy import simplify_logic
import time
from sklearn.metrics import accuracy_score
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.tree import _tree, export_text
from torchvision import datasets, transforms
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn.functional as F
from sklearn.preprocessing import OneHotEncoder, MinMaxScaler
from sklearn.model_selection import StratifiedKFold

from deep_logic.utils.base import validate_network, set_seed, tree_to_formula
from deep_logic.utils.relunn import get_reduced_model, prune_features
from deep_logic.utils.sigmoidnn import prune_equal_fanin
from deep_logic import logic

results_dir = 'results/celldiff'
if not os.path.isdir(results_dir):
    os.makedirs(results_dir)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
n_rep = 10
tot_epochs = 15001
prune_epochs = 5000

In [90]:
gene_expression_matrix = pd.read_csv('./data/celldiff/data_matrix.csv', index_col=0)
clustering_labels = pd.read_csv('./data/celldiff/cluster_labels.csv', index_col=0)
biomarkers = pd.read_csv('./data/celldiff/markers.csv', index_col=0)

task_id = 2
labels = (clustering_labels.values==task_id)

scaler = MinMaxScaler((0, 1))
scaler.fit(gene_expression_matrix.values)
data_normalized = scaler.transform(gene_expression_matrix.values)

x = torch.FloatTensor(data_normalized)
y = torch.FloatTensor(labels).to(torch.long).squeeze()
print(x.shape)
print(y.shape)

torch.Size([1422, 1845])
torch.Size([1422])


In [3]:
concepts = list(gene_expression_matrix.columns)
concepts[:10]

['alas2',
 'si:xx-by187g17.1',
 'si:ch211-5k11.6',
 'tspan36',
 'ba1l',
 'mmp9',
 'il21r.1',
 'plekhf1',
 'slc3a2a',
 'si:dkey-222f2.7']

In [4]:
n_splits = 10
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=0)

In [108]:
def train_nn(x_train, y_train, need_pruning, seed, device):
    set_seed(seed)
    x_train = x_train.to(device)
    y_train = y_train.to(device)
    layers = [
        torch.nn.Linear(x_train.size(1), 200),
        torch.nn.LeakyReLU(),
        torch.nn.Linear(200, 80),
        torch.nn.LeakyReLU(),
        torch.nn.Linear(80, 30),
        torch.nn.LeakyReLU(),
        torch.nn.Linear(30, 2),
        torch.nn.Softmax(dim=1),
    ]
    model = torch.nn.Sequential(*layers).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.00001)
    loss_form = torch.nn.CrossEntropyLoss()
    model.train()
    for epoch in range(tot_epochs):
        # forward pass
        optimizer.zero_grad()
        y_pred = model(x_train)
        # Compute Loss
        loss = loss_form(y_pred, y_train)

        for module in model.children():
            if isinstance(module, torch.nn.Linear):
                loss += 0.000001 * torch.norm(module.weight, 1)
                loss += 0.000001 * torch.norm(module.bias, 1)
                break

        # backward pass
        loss.backward()
        optimizer.step()

        if epoch > prune_epochs and need_pruning and epoch % 1000 == 1:
            prune_features(model, n_classes=1, device=device)
            need_pruning = True
            
        # compute accuracy
        if epoch % 500 == 0:
            y_pred_d = torch.argmax(y_pred, dim=1)
            accuracy = y_pred_d.eq(y_train).sum().item() / y_train.size(0)
            print(f'\t Epoch {epoch}: train accuracy: {accuracy:.4f}')
    
    return model

# General pruning

In [None]:
need_pruning = True
method = 'pruning'
methods = []
splits = []
explanations = []
explanations_inv = []
model_accuracies = []
explanation_accuracies = []
explanation_accuracies_inv = []
elapsed_times = []
elapsed_times_inv = []

for split, (train_index, test_index) in enumerate(skf.split(x.cpu().detach().numpy(), y.cpu().detach().numpy())):
    print(f'Split [{split+1}/{n_splits}]')
    x_train, x_test = torch.FloatTensor(x[train_index]), torch.FloatTensor(x[test_index])
    y_train, y_test = torch.LongTensor(y[train_index]), torch.LongTensor(y[test_index])
    
#     if split not in [5]: continue
    
    model = train_nn(x_train, y_train, need_pruning, split, device)
    
    y_preds = model(x_test.to(device)).cpu().detach().numpy()
    model_accuracy = accuracy_score(y_test.cpu().detach().numpy(), y_preds.argmax(axis=1))
    print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
    
    # positive class
    target_class = 1
    start = time.time()
    global_explanation, _, counter = logic.relunn.combine_local_explanations(model, 
                                                                       x_train.to(device), y_train.to(device), 
                                                                       target_class=target_class,
                                                                       topk_explanations=1,
                                                                       method=method, device=device)
    elapsed_time = time.time() - start
    if global_explanation:
        explanation_accuracy, _ = logic.base.test_explanation(global_explanation, target_class, x_test, y_test)
        explanation = logic.base.replace_names(global_explanation, concepts)
    print(f'\t Class {target_class} - Global explanation: "{explanation}" - Accuracy: {explanation_accuracy:.4f}')
    print(f'\t Elapsed time {elapsed_time}')
        
    # negative class
    target_class = 0
    start = time.time()
    global_explanation_inv, _, counter_inv = logic.relunn.combine_local_explanations(model, 
                                                                           x_train.to(device), y_train.to(device), 
                                                                           target_class=target_class,
                                                                           topk_explanations=1,
                                                                           method=method, device=device)
    elapsed_time_inv = time.time() - start
    if global_explanation_inv:
        explanation_accuracy_inv, _ = logic.base.test_explanation(global_explanation_inv, target_class, x_test, y_test)
        explanation_inv = logic.base.replace_names(global_explanation_inv, concepts)
    print(f'\t Class {target_class} - Global explanation: "{explanation_inv}" - Accuracy: {explanation_accuracy_inv:.4f}')
    print(f'\t Elapsed time {elapsed_time_inv}')
    
    methods.append(method)
    splits.append(split)
    explanations.append(explanation)
    explanations_inv.append(explanation_inv)
    model_accuracies.append(model_accuracy)
    explanation_accuracies.append(explanation_accuracy)
    explanation_accuracies_inv.append(explanation_accuracy_inv)
    elapsed_times.append(elapsed_time)
    elapsed_times_inv.append(elapsed_time_inv)

Split [1/10]
	 Epoch 0: train accuracy: 0.1665
	 Epoch 500: train accuracy: 0.8327
	 Epoch 1000: train accuracy: 0.9656
	 Epoch 1500: train accuracy: 0.9914
	 Epoch 2000: train accuracy: 0.9977
	 Epoch 2500: train accuracy: 0.9984
	 Epoch 3000: train accuracy: 0.9992
	 Epoch 3500: train accuracy: 1.0000
	 Epoch 4000: train accuracy: 1.0000
	 Epoch 4500: train accuracy: 1.0000
	 Epoch 5000: train accuracy: 1.0000
	 Epoch 5500: train accuracy: 1.0000
	 Epoch 6000: train accuracy: 1.0000
	 Epoch 6500: train accuracy: 1.0000
	 Epoch 7000: train accuracy: 1.0000
	 Epoch 7500: train accuracy: 1.0000
	 Epoch 8000: train accuracy: 1.0000
	 Epoch 8500: train accuracy: 1.0000
	 Epoch 9000: train accuracy: 1.0000
	 Epoch 9500: train accuracy: 1.0000
	 Epoch 10000: train accuracy: 1.0000
	 Epoch 10500: train accuracy: 1.0000
	 Epoch 11000: train accuracy: 1.0000
	 Epoch 11500: train accuracy: 1.0000
	 Epoch 12000: train accuracy: 1.0000
	 Epoch 12500: train accuracy: 1.0000
	 Epoch 13000: train ac

In [None]:
results_pruning = pd.DataFrame({
    'method': methods,
    'split': splits,
    'explanation': explanations,
    'explanation_inv': explanations_inv,
    'model_accuracy': model_accuracies,
    'explanation_accuracy': explanation_accuracies,
    'explanation_accuracy_inv': explanation_accuracies_inv,
    'elapsed_time': elapsed_times,
    'elapsed_time_inv': elapsed_times_inv,
})
results_pruning.to_csv(os.path.join(results_dir, 'results_pruning.csv'))
results_pruning

# LIME

In [None]:
# need_pruning = False
# method = 'lime'
# methods = []
# splits = []
# explanations = []
# explanations_inv = []
# model_accuracies = []
# explanation_accuracies = []
# explanation_accuracies_inv = []
# elapsed_times = []
# elapsed_times_inv = []
# for seed in range(n_rep):
#     print(f'Seed [{seed+1}/{n_rep}]')
    
#     model = train_nn(x_train, y_train, need_pruning, seed, device)
    
#     y_preds = model(x_test.to(device)).cpu().detach().numpy()
#     model_accuracy = accuracy_score(y_test.cpu().detach().numpy(), y_preds.argmax(axis=1))
#     print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
    
#     # positive class
#     target_class = 1
#     start = time.time()
#     global_explanation, _, _ = logic.relunn.combine_local_explanations(model, 
#                                                                        x_train.to(device), y_train.to(device), 
#                                                                        target_class=target_class,
#                                                                        method=method, device=device)
#     elapsed_time = time.time() - start
#     if global_explanation:
#         explanation_accuracy, _ = logic.base.test_explanation(global_explanation, target_class, x_test, y_test)
#         explanation = logic.base.replace_names(global_explanation, concepts)
#     print(f'\t Class {target_class} - Global explanation: "{explanation}" - Accuracy: {explanation_accuracy:.4f}')
#     print(f'\t Elapsed time {elapsed_time}')
        
#     # negative class
#     target_class = 0
#     start = time.time()
#     global_explanation_inv, _, _ = logic.relunn.combine_local_explanations(model, 
#                                                                            x_train.to(device), y_train.to(device), 
#                                                                            target_class=target_class,
#                                                                            method=method, device=device)
#     elapsed_time_inv = time.time() - start
#     if global_explanation_inv:
#         explanation_accuracy_inv, _ = logic.base.test_explanation(global_explanation_inv, target_class, x_test, y_test)
#         explanation_inv = logic.base.replace_names(global_explanation_inv, concepts)
#     print(f'\t Class {target_class} - Global explanation: "{explanation_inv}" - Accuracy: {explanation_accuracy_inv:.4f}')
#     print(f'\t Elapsed time {elapsed_time_inv}')
    
#     methods.append(method)
#     splits.append(seed)
#     explanations.append(explanation)
#     explanations_inv.append(explanation_inv)
#     model_accuracies.append(model_accuracy)
#     explanation_accuracies.append(explanation_accuracy)
#     explanation_accuracies_inv.append(explanation_accuracy_inv)
#     elapsed_times.append(elapsed_time)
#     elapsed_times_inv.append(elapsed_time_inv)

In [None]:
# results_lime = pd.DataFrame({
#     'method': methods,
#     'split': splits,
#     'explanation': explanations,
#     'explanation_inv': explanations_inv,
#     'model_accuracy': model_accuracies,
#     'explanation_accuracy': explanation_accuracies,
#     'explanation_accuracy_inv': explanation_accuracies_inv,
#     'elapsed_time': elapsed_times,
#     'elapsed_time_inv': elapsed_times_inv,
# })
# results_lime.to_csv(os.path.join(results_dir, 'results_lime.csv'))
# results_lime

# Weights

In [93]:
need_pruning = False
method = 'weights'
methods = []
splits = []
explanations = []
explanations_inv = []
model_accuracies = []
explanation_accuracies = []
explanation_accuracies_inv = []
elapsed_times = []
elapsed_times_inv = []

for split, (train_index, test_index) in enumerate(skf.split(x.cpu().detach().numpy(), y.cpu().detach().numpy())):
    print(f'Split [{split+1}/{n_splits}]')
    x_train, x_test = torch.FloatTensor(x[train_index]), torch.FloatTensor(x[test_index])
    y_train, y_test = torch.LongTensor(y[train_index]), torch.LongTensor(y[test_index])
    
    model = train_nn(x_train, y_train, need_pruning, split, device)
    
    y_preds = model(x_test.to(device)).cpu().detach().numpy()
    model_accuracy = accuracy_score(y_test.cpu().detach().numpy(), y_preds.argmax(axis=1))
    print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
    
    # positive class
    target_class = 1
    start = time.time()
    global_explanation, _, _ = logic.relunn.combine_local_explanations(model, 
                                                                       x_train.to(device), y_train.to(device),
                                                                       topk_explanations=1, 
                                                                       target_class=target_class,
                                                                       method=method, device=device)
    elapsed_time = time.time() - start
    if global_explanation:
        explanation_accuracy, _ = logic.base.test_explanation(global_explanation, target_class, x_test, y_test)
        explanation = logic.base.replace_names(global_explanation, concepts)
    print(f'\t Class {target_class} - Global explanation: "{explanation}" - Accuracy: {explanation_accuracy:.4f}')
    print(f'\t Elapsed time {elapsed_time}')
        
    # negative class
    target_class = 0
    start = time.time()
    global_explanation_inv, _, _ = logic.relunn.combine_local_explanations(model, 
                                                                           x_train.to(device), y_train.to(device), 
                                                                           topk_explanations=1, 
                                                                           target_class=target_class,
                                                                           method=method, device=device)
    elapsed_time_inv = time.time() - start
    if global_explanation_inv:
        explanation_accuracy_inv, _ = logic.base.test_explanation(global_explanation_inv, target_class, x_test, y_test)
        explanation_inv = logic.base.replace_names(global_explanation_inv, concepts)
    print(f'\t Class {target_class} - Global explanation: "{explanation_inv}" - Accuracy: {explanation_accuracy_inv:.4f}')
    print(f'\t Elapsed time {elapsed_time_inv}')
    
    methods.append(method)
    splits.append(split)
    explanations.append(explanation)
    explanations_inv.append(explanation_inv)
    model_accuracies.append(model_accuracy)
    explanation_accuracies.append(explanation_accuracy)
    explanation_accuracies_inv.append(explanation_accuracy_inv)
    elapsed_times.append(elapsed_time)
    elapsed_times_inv.append(elapsed_time_inv)

Split [1/10]
	 Epoch 0: train accuracy: 0.1665
	 Epoch 500: train accuracy: 0.8327
	 Epoch 1000: train accuracy: 0.9586
	 Epoch 1500: train accuracy: 0.9898
	 Epoch 2000: train accuracy: 0.9969
	 Epoch 2500: train accuracy: 0.9977
	 Epoch 3000: train accuracy: 0.9977
	 Epoch 3500: train accuracy: 0.9984
	 Epoch 4000: train accuracy: 0.9984
	 Epoch 4500: train accuracy: 0.9984
	 Epoch 5000: train accuracy: 0.9984
	 Epoch 5500: train accuracy: 0.9984
	 Epoch 6000: train accuracy: 0.9984
	 Epoch 6500: train accuracy: 0.9984
	 Epoch 7000: train accuracy: 0.9984
	 Epoch 7500: train accuracy: 0.9984
	 Epoch 8000: train accuracy: 0.9984
	 Epoch 8500: train accuracy: 0.9984
	 Epoch 9000: train accuracy: 0.9984
	 Epoch 9500: train accuracy: 0.9984
	 Epoch 10000: train accuracy: 0.9984
	 Epoch 10500: train accuracy: 0.9992
	 Epoch 11000: train accuracy: 0.9992
	 Epoch 11500: train accuracy: 0.9992
	 Epoch 12000: train accuracy: 0.9992
	 Epoch 12500: train accuracy: 0.9992
	 Epoch 13000: train ac

feature0000000567 & ~feature0000000125 & ~feature0000000230 & ~feature0000000315 & ~feature0000000710 & ~feature0000001476 & ~feature0000001740
	 Class 0 - Global explanation: "otud4.4 & ~si:ch211-103n10.5 & ~im:7151449 & ~oraov1 & ~anxa11a & ~clmn & ~kdelc2" - Accuracy: 0.4014
	 Elapsed time 10.155855894088745
Split [6/10]
	 Epoch 0: train accuracy: 0.8336
	 Epoch 500: train accuracy: 0.8336
	 Epoch 1000: train accuracy: 0.9617
	 Epoch 1500: train accuracy: 0.9930
	 Epoch 2000: train accuracy: 0.9977
	 Epoch 2500: train accuracy: 0.9984
	 Epoch 3000: train accuracy: 0.9992
	 Epoch 3500: train accuracy: 0.9992
	 Epoch 4000: train accuracy: 0.9992
	 Epoch 4500: train accuracy: 0.9992
	 Epoch 5000: train accuracy: 0.9992
	 Epoch 5500: train accuracy: 0.9992
	 Epoch 6000: train accuracy: 0.9992
	 Epoch 6500: train accuracy: 0.9992
	 Epoch 7000: train accuracy: 0.9992
	 Epoch 7500: train accuracy: 0.9992
	 Epoch 8000: train accuracy: 0.9992
	 Epoch 8500: train accuracy: 0.9992
	 Epoch 9000

	 Epoch 500: train accuracy: 0.8328
	 Epoch 1000: train accuracy: 0.9594
	 Epoch 1500: train accuracy: 0.9938
	 Epoch 2000: train accuracy: 0.9984
	 Epoch 2500: train accuracy: 0.9984
	 Epoch 3000: train accuracy: 0.9992
	 Epoch 3500: train accuracy: 0.9992
	 Epoch 4000: train accuracy: 1.0000
	 Epoch 4500: train accuracy: 1.0000
	 Epoch 5000: train accuracy: 1.0000
	 Epoch 5500: train accuracy: 1.0000
	 Epoch 6000: train accuracy: 1.0000
	 Epoch 6500: train accuracy: 1.0000
	 Epoch 7000: train accuracy: 1.0000
	 Epoch 7500: train accuracy: 1.0000
	 Epoch 8000: train accuracy: 1.0000
	 Epoch 8500: train accuracy: 1.0000
	 Epoch 9000: train accuracy: 1.0000
	 Epoch 9500: train accuracy: 1.0000
	 Epoch 10000: train accuracy: 1.0000
	 Epoch 10500: train accuracy: 1.0000
	 Epoch 11000: train accuracy: 1.0000
	 Epoch 11500: train accuracy: 1.0000
	 Epoch 12000: train accuracy: 1.0000
	 Epoch 12500: train accuracy: 1.0000
	 Epoch 13000: train accuracy: 1.0000
	 Epoch 13500: train accuracy: 1

In [94]:
results_weights = pd.DataFrame({
    'method': methods,
    'split': splits,
    'explanation': explanations,
    'explanation_inv': explanations_inv,
    'model_accuracy': model_accuracies,
    'explanation_accuracy': explanation_accuracies,
    'explanation_accuracy_inv': explanation_accuracies_inv,
    'elapsed_time': elapsed_times,
    'elapsed_time_inv': elapsed_times_inv,
})
results_weights.to_csv(os.path.join(results_dir, 'results_weights.csv'))
results_weights

Unnamed: 0,method,split,explanation,explanation_inv,model_accuracy,explanation_accuracy,explanation_accuracy_inv,elapsed_time,elapsed_time_inv
0,weights,0,pcxa & nr2c2 & ~si:ch211-103n10.5 & ~im:715144...,otud4.4 & ~pcxa & ~clmn & ~kdelc2,0.958042,0.776224,0.251748,2.684793,14.119261
1,weights,1,pcxa & ~si:ch211-103n10.5 & ~kdelc2,pcxa & ~si:ch211-103n10.5 & ~kdelc2,0.944056,0.832168,0.167832,2.138284,9.850016
2,weights,2,pcxa & ~si:ch211-103n10.5 & ~im:7151449 & ~otu...,~si:ch211-103n10.5 & ~im:7151449 & ~otud4.4 & ...,0.964789,0.823944,0.542254,2.224087,10.789901
3,weights,3,si:ch211-132g1.1 & pcxa & ~CD37 & ~si:ch211-10...,eif2ak1 & dpysl5b & SLC16A3 & asb12a & kat6b &...,0.971831,0.901408,0.323944,2.362684,12.218341
4,weights,4,si:ch211-132g1.1 & pcxa & nr2c2 & ~si:ch211-10...,otud4.4 & ~si:ch211-103n10.5 & ~im:7151449 & ~...,0.985915,0.887324,0.401408,2.850381,10.155856
5,weights,5,si:ch211-132g1.1 & pcxa & nr2c2 & ~si:ch211-13...,oraov1 & dpysl5b & asb12a & kat6b & ~CD37 & ~s...,0.93662,0.866197,0.323944,3.21544,14.723578
6,weights,6,pcxa & ~si:ch211-103n10.5 & ~im:7151449 & ~clm...,~si:ch211-103n10.5 & ~im:7151449 & ~clmn & ~kd...,0.971831,0.795775,0.387324,1.922905,9.374942
7,weights,7,si:ch211-132g1.1 & pcxa & nr2c2 & kat6b & ~si:...,dpysl5b & SLC16A3 & asb12a & kat6b & ~si:ch211...,0.957746,0.908451,0.65493,2.97701,15.136543
8,weights,8,si:ch211-132g1.1 & pcxa & nr2c2 & ~si:ch211-10...,smim14 & otud4.4 & pcxa & nr2c2 & ~si:ch211-13...,0.943662,0.887324,0.176056,3.250313,23.092312
9,weights,9,pcxa & ~si:ch211-103n10.5 & ~im:7151449 & ~otu...,asb12a & kat6b & kdelc2 & ~CD37 & ~si:ch211-10...,0.964789,0.830986,0.323944,2.021565,12.018875


# Psi network

In [95]:
def train_psi_nn(x_train, y_train, need_pruning, seed, device):
    set_seed(seed)
    x_train = x_train.to(device)
    y_train = y_train.to(device).to(torch.float)
    layers = [
        torch.nn.Linear(x_train.size(1), 10),
        torch.nn.Sigmoid(),
        torch.nn.Linear(10, 4),
        torch.nn.Sigmoid(),
        torch.nn.Linear(4, 1),
        torch.nn.Sigmoid(),
    ]
    model = torch.nn.Sequential(*layers).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_form = torch.nn.BCELoss()
    model.train()
    for epoch in range(tot_epochs):
        # forward pass
        optimizer.zero_grad()
        y_pred = model(x_train).squeeze()
        # Compute Loss
        loss = loss_form(y_pred, y_train)

        for module in model.children():
            if isinstance(module, torch.nn.Linear):
                loss += 0.00001 * torch.norm(module.weight, 1)

        # backward pass
        loss.backward()
        optimizer.step()

        if epoch > prune_epochs and need_pruning:
            model = prune_equal_fanin(model, 2, validate=True, device=device)
            need_pruning = False
            
        # compute accuracy
        if epoch % 500 == 0:
            y_pred_d = y_pred > 0.5
            accuracy = y_pred_d.eq(y_train).sum().item() / y_train.size(0)
            print(f'\t Epoch {epoch}: train accuracy: {accuracy:.4f}')
    
    return model

In [100]:
need_pruning = True
method = 'psi'
methods = []
splits = []
explanations = []
explanations_inv = []
model_accuracies = []
explanation_accuracies = []
explanation_accuracies_inv = []
elapsed_times = []
elapsed_times_inv = []

for split, (train_index, test_index) in enumerate(skf.split(x.cpu().detach().numpy(), y.cpu().detach().numpy())):
    print(f'Split [{split+1}/{n_splits}]')
    x_train, x_test = torch.FloatTensor(x[train_index]), torch.FloatTensor(x[test_index])
    y_train, y_test = torch.LongTensor(y[train_index]), torch.LongTensor(y[test_index])
    
    # positive class
    target_class = 1
    model = train_psi_nn(x_train, y_train, need_pruning, split, device)
    
    y_preds = model(x_test.to(device)).cpu().detach().numpy()
    model_accuracy = accuracy_score(y_test.cpu().detach().numpy(), y_preds > 0.5)
    print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
    
    start = time.time()
    global_explanation = logic.generate_fol_explanations(model, device)[0]
    elapsed_time = time.time() - start
    if global_explanation not in ['(False)', '(True)', '']:
        explanation_accuracy, _ = logic.base.test_explanation(global_explanation, target_class, x_test, y_test)
        explanation = logic.base.replace_names(global_explanation, concepts)
        print(f'\t Class {target_class} - Global explanation: "{explanation}" - Accuracy: {explanation_accuracy:.4f}')
        print(f'\t Elapsed time {elapsed_time}')
    else:
        explanation_accuracy = 0
        explanation = global_explanation
        
    # negative class
    target_class = 0
    model = train_psi_nn(x_train, y_train.eq(target_class), need_pruning, split, device)
    
    y_preds = model(x_test.to(device)).cpu().detach().numpy()
    model_accuracy = accuracy_score(y_test.eq(target_class).cpu().detach().numpy(), y_preds > 0.5)
    print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
    
    start = time.time()
    global_explanation_inv = logic.generate_fol_explanations(model, device)[0]
    elapsed_time_inv = time.time() - start
    if global_explanation_inv not in ['(False)', '(True)', '']:
        explanation_accuracy_inv, _ = logic.base.test_explanation(global_explanation_inv, 
                                                                  target_class, x_test, y_test)
        explanation_inv = logic.base.replace_names(global_explanation_inv, concepts)
        print(f'\t Class {target_class} - Global explanation: "{explanation_inv}" - Accuracy: {explanation_accuracy_inv:.4f}')
        print(f'\t Elapsed time {elapsed_time_inv}')
    else:
        explanation_accuracy_inv = 0
        explanation_inv = global_explanation_inv
    
    methods.append(method)
    splits.append(split)
    explanations.append(explanation)
    explanations_inv.append(explanation_inv)
    model_accuracies.append(model_accuracy)
    explanation_accuracies.append(explanation_accuracy)
    explanation_accuracies_inv.append(explanation_accuracy_inv)
    elapsed_times.append(elapsed_time)
    elapsed_times_inv.append(elapsed_time_inv)

Split [1/10]
	 Epoch 0: train accuracy: 0.8335
	 Epoch 500: train accuracy: 0.8335
	 Epoch 1000: train accuracy: 0.9992
	 Epoch 1500: train accuracy: 0.9992
	 Epoch 2000: train accuracy: 0.9992
	 Epoch 2500: train accuracy: 1.0000
	 Epoch 3000: train accuracy: 1.0000
	 Epoch 3500: train accuracy: 1.0000
	 Epoch 4000: train accuracy: 1.0000
	 Epoch 4500: train accuracy: 1.0000
	 Epoch 5000: train accuracy: 1.0000
	 Epoch 5500: train accuracy: 0.8335
	 Epoch 6000: train accuracy: 0.8335
	 Epoch 6500: train accuracy: 0.8335
	 Epoch 7000: train accuracy: 0.8335
	 Epoch 7500: train accuracy: 0.8335
	 Epoch 8000: train accuracy: 0.8335
	 Epoch 8500: train accuracy: 0.8335
	 Epoch 9000: train accuracy: 0.8335
	 Epoch 9500: train accuracy: 0.8335
	 Epoch 10000: train accuracy: 0.8335
	 Epoch 10500: train accuracy: 0.8335
	 Epoch 11000: train accuracy: 0.8335
	 Epoch 11500: train accuracy: 0.8335
	 Epoch 12000: train accuracy: 0.8335
	 Epoch 12500: train accuracy: 0.8335
	 Epoch 13000: train ac

	 Epoch 6000: train accuracy: 0.8438
	 Epoch 6500: train accuracy: 0.8445
	 Epoch 7000: train accuracy: 0.8445
	 Epoch 7500: train accuracy: 0.8469
	 Epoch 8000: train accuracy: 0.8477
	 Epoch 8500: train accuracy: 0.8469
	 Epoch 9000: train accuracy: 0.8477
	 Epoch 9500: train accuracy: 0.8477
	 Epoch 10000: train accuracy: 0.8469
	 Epoch 10500: train accuracy: 0.8469
	 Epoch 11000: train accuracy: 0.8469
	 Epoch 11500: train accuracy: 0.8469
	 Epoch 12000: train accuracy: 0.8469
	 Epoch 12500: train accuracy: 0.8469
	 Epoch 13000: train accuracy: 0.8469
	 Epoch 13500: train accuracy: 0.8469
	 Epoch 14000: train accuracy: 0.8469
	 Epoch 14500: train accuracy: 0.8469
	 Epoch 15000: train accuracy: 0.8469
	 Model's accuracy: 0.8310
	 Class 1 - Global explanation: "(clmn)" - Accuracy: 0.8310
	 Elapsed time 0.010969400405883789
	 Epoch 0: train accuracy: 0.8336
	 Epoch 500: train accuracy: 0.8336
	 Epoch 1000: train accuracy: 0.9922
	 Epoch 1500: train accuracy: 0.9992
	 Epoch 2000: train

	 Epoch 14500: train accuracy: 0.8484
	 Epoch 15000: train accuracy: 0.8477
	 Model's accuracy: 0.8310
	 Class 1 - Global explanation: "(clmn | (im:7151449 & ~si:ch211-103n10.5))" - Accuracy: 0.8451
	 Elapsed time 0.02296733856201172
	 Epoch 0: train accuracy: 0.1664
	 Epoch 500: train accuracy: 0.8336
	 Epoch 1000: train accuracy: 0.9820
	 Epoch 1500: train accuracy: 0.9992
	 Epoch 2000: train accuracy: 0.9992
	 Epoch 2500: train accuracy: 0.9992
	 Epoch 3000: train accuracy: 0.9992
	 Epoch 3500: train accuracy: 0.9992
	 Epoch 4000: train accuracy: 0.9992
	 Epoch 4500: train accuracy: 0.9992
	 Epoch 5000: train accuracy: 1.0000
	 Epoch 5500: train accuracy: 0.8430
	 Epoch 6000: train accuracy: 0.8438
	 Epoch 6500: train accuracy: 0.8453
	 Epoch 7000: train accuracy: 0.8477
	 Epoch 7500: train accuracy: 0.8484
	 Epoch 8000: train accuracy: 0.8484
	 Epoch 8500: train accuracy: 0.8477
	 Epoch 9000: train accuracy: 0.8477
	 Epoch 9500: train accuracy: 0.8477
	 Epoch 10000: train accuracy:

	 Epoch 3000: train accuracy: 1.0000
	 Epoch 3500: train accuracy: 1.0000
	 Epoch 4000: train accuracy: 1.0000
	 Epoch 4500: train accuracy: 1.0000
	 Epoch 5000: train accuracy: 1.0000
	 Epoch 5500: train accuracy: 0.8383
	 Epoch 6000: train accuracy: 0.8328
	 Epoch 6500: train accuracy: 0.8328
	 Epoch 7000: train accuracy: 0.8414
	 Epoch 7500: train accuracy: 0.8445
	 Epoch 8000: train accuracy: 0.8453
	 Epoch 8500: train accuracy: 0.8469
	 Epoch 9000: train accuracy: 0.8477
	 Epoch 9500: train accuracy: 0.8469
	 Epoch 10000: train accuracy: 0.8469
	 Epoch 10500: train accuracy: 0.8469
	 Epoch 11000: train accuracy: 0.8469
	 Epoch 11500: train accuracy: 0.8469
	 Epoch 12000: train accuracy: 0.8469
	 Epoch 12500: train accuracy: 0.8469
	 Epoch 13000: train accuracy: 0.8469
	 Epoch 13500: train accuracy: 0.8469
	 Epoch 14000: train accuracy: 0.8469
	 Epoch 14500: train accuracy: 0.8469
	 Epoch 15000: train accuracy: 0.8469
	 Model's accuracy: 0.8239
	 Class 0 - Global explanation: "(~cl

In [101]:
results_psi = pd.DataFrame({
    'method': methods,
    'split': splits,
    'explanation': explanations,
    'explanation_inv': explanations_inv,
    'model_accuracy': model_accuracies,
    'explanation_accuracy': explanation_accuracies,
    'explanation_accuracy_inv': explanation_accuracies_inv,
    'elapsed_time': elapsed_times,
    'elapsed_time_inv': elapsed_times_inv,
})
results_psi.to_csv(os.path.join(results_dir, 'results_psi.csv'))
results_psi

Unnamed: 0,method,split,explanation,explanation_inv,model_accuracy,explanation_accuracy,explanation_accuracy_inv,elapsed_time,elapsed_time_inv
0,psi,0,(False),(~clmn & (si:ch211-103n10.5 | oraov1)),0.839161,0.0,0.531469,0.0288,0.029887
1,psi,1,(im:7151449 | (~si:ch211-103n10.5 & ~anxa11a)),(si:ch211-103n10.5),0.825175,0.321678,0.300699,0.021939,0.015957
2,psi,2,(im:7151449 & ~si:ch211-103n10.5),(si:ch211-103n10.5 | ~im:7151449),0.816901,0.816901,0.816901,0.01396,0.016954
3,psi,3,(clmn),(True),0.830986,0.830986,0.0,0.010969,0.046916
4,psi,4,(im:7151449 | clmn | ~si:ch211-103n10.5),(si:ch211-103n10.5),0.830986,0.28169,0.309859,0.026442,0.009973
5,psi,5,(False),(si:ch211-103n10.5),0.830986,0.0,0.295775,0.030951,0.022939
6,psi,6,(clmn | (im:7151449 & ~si:ch211-103n10.5)),(~clmn),0.816901,0.84507,0.830986,0.022967,0.011966
7,psi,7,(clmn | (im:7151449 & ~si:ch211-103n10.5)),(si:ch211-103n10.5),0.84507,0.84507,0.309859,0.025897,0.019982
8,psi,8,(pcxa | clmn),(~clmn),0.84507,0.732394,0.838028,0.016988,0.017548
9,psi,9,(clmn),(~clmn),0.823944,0.823944,0.823944,0.013999,0.027928


# Decision tree

In [91]:
need_pruning = False
method = 'decision_tree'
methods = []
splits = []
explanations = []
explanations_inv = []
model_accuracies = []
explanation_accuracies = []
explanation_accuracies_inv = []
elapsed_times = []
elapsed_times_inv = []

for split, (train_index, test_index) in enumerate(skf.split(x.cpu().detach().numpy(), y.cpu().detach().numpy())):
    print(f'Split [{split+1}/{n_splits}]')
    x_train, x_test = x[train_index], x[test_index]
    y_train, y_test = y[train_index], y[test_index]
    
    classifier = DecisionTreeClassifier(random_state=split)
    classifier.fit(x_train.cpu().detach().numpy(), y_train.cpu().detach().numpy())
    y_preds = classifier.predict(x_test.cpu().detach().numpy())
    model_accuracy = accuracy_score(y_test.cpu().detach().numpy(), y_preds)
    print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
    
    target_class = 1
    start = time.time()
    explanation = tree_to_formula(classifier, concepts, target_class)
    elapsed_time = time.time() - start
    print(f'\t Class {target_class} - Global explanation: {explanation}')
    print(f'\t Elapsed time {elapsed_time}')
    
    target_class = 0
    start = time.time()
    explanation_inv = tree_to_formula(classifier, concepts, target_class)
    elapsed_time_inv = time.time() - start
    print(f'\t Class {target_class} - Global explanation: {explanation_inv}')
    print(f'\t Elapsed time {elapsed_time_inv}')
    
    methods.append(method)
    splits.append(split)
    explanations.append(explanation)
    explanations_inv.append(explanation_inv)
    model_accuracies.append(model_accuracy)
    explanation_accuracies.append(model_accuracy)
    explanation_accuracies_inv.append(model_accuracy)
    elapsed_times.append(0)
    elapsed_times_inv.append(0)

Split [1/10]
	 Model's accuracy: 0.8881
	 Class 1 - Global explanation: (nr2c2 <= 0.74 & top2a <= 0.32 & zdhhc5a <= 0.56 & TMC8 <= 0.79 & kdm8 <= 0.90 & ypel3 <= 0.92 & amd1 <= 0.77 & TMC8 <= 0.70 & slc25a43 <= 0.80 & adcy2b <= 0.99 & fbxo44 <= 0.91 & ints5 <= 0.90 & ap3m1 <= 0.96 & si:dkey-83k24.5 <= 0.92 & si:ch73-236c18.6 <= 0.95 & yme1l1a <= 0.96 & vps26b <= 0.89 & abca2 <= 0.86 & dusp1 > 0.78 & sdr16c5a > 0.27) | (nr2c2 <= 0.74 & top2a <= 0.32 & zdhhc5a <= 0.56 & TMC8 <= 0.79 & kdm8 <= 0.90 & ypel3 <= 0.92 & amd1 <= 0.77 & TMC8 <= 0.70 & slc25a43 <= 0.80 & adcy2b <= 0.99 & fbxo44 <= 0.91 & ints5 <= 0.90 & ap3m1 <= 0.96 & si:dkey-83k24.5 <= 0.92 & si:ch73-236c18.6 <= 0.95 & yme1l1a <= 0.96 & vps26b <= 0.89 & abca2 > 0.86 & ftr41 <= 0.46) | (nr2c2 <= 0.74 & top2a <= 0.32 & zdhhc5a <= 0.56 & TMC8 <= 0.79 & kdm8 <= 0.90 & ypel3 <= 0.92 & amd1 <= 0.77 & TMC8 <= 0.70 & slc25a43 <= 0.80 & adcy2b <= 0.99 & fbxo44 <= 0.91 & ints5 <= 0.90 & ap3m1 <= 0.96 & si:dkey-83k24.5 <= 0.92 & si:ch73-

	 Model's accuracy: 0.9155
	 Class 1 - Global explanation: (nr2c2 <= 0.76 & zdhhc5a <= 0.53 & top2a <= 0.35 & kdm8 <= 0.90 & TMC8 <= 0.70 & slc25a43 <= 0.80 & atp1b2b <= 0.64 & si:dkey-83k24.5 <= 0.90 & ypel3 <= 0.92 & RFX7 <= 0.91 & ginm1 <= 0.93 & zgc:162396 <= 0.89 & snx5 <= 0.99 & perp <= 0.83 & gfi1ab <= 1.00 & ctsh <= 0.82 & SVIP <= 0.99 & si:ch211-14a17.10 <= 1.00 & HHEX <= 0.48 & vps26b <= 0.89 & srsf6b > 0.81 & si:dkey-56e3.3 <= 0.26) | (nr2c2 <= 0.76 & zdhhc5a <= 0.53 & top2a <= 0.35 & kdm8 <= 0.90 & TMC8 <= 0.70 & slc25a43 <= 0.80 & atp1b2b <= 0.64 & si:dkey-83k24.5 <= 0.90 & ypel3 <= 0.92 & RFX7 <= 0.91 & ginm1 <= 0.93 & zgc:162396 <= 0.89 & snx5 <= 0.99 & perp <= 0.83 & gfi1ab <= 1.00 & ctsh <= 0.82 & SVIP <= 0.99 & si:ch211-14a17.10 <= 1.00 & HHEX <= 0.48 & vps26b > 0.89 & ppp1r13bb > 0.27) | (nr2c2 <= 0.76 & zdhhc5a <= 0.53 & top2a <= 0.35 & kdm8 <= 0.90 & TMC8 <= 0.70 & slc25a43 <= 0.80 & atp1b2b <= 0.64 & si:dkey-83k24.5 <= 0.90 & ypel3 <= 0.92 & RFX7 <= 0.91 & ginm1 <

	 Model's accuracy: 0.9085
	 Class 1 - Global explanation: (nr2c2 <= 0.78 & top2a <= 0.32 & zdhhc5a <= 0.59 & TMC8 <= 0.70 & kdm8 <= 0.84 & zgc:66475 <= 0.71 & si:dkey-83k24.5 <= 0.90 & flot2a <= 0.81 & abca2 <= 0.86 & gfi1ab <= 1.00 & RFX7 <= 0.91 & nid2a <= 0.98 & sytl3 <= 0.94 & yme1l1a <= 0.96 & ctsh <= 0.82 & cxcr4a <= 0.89 & slc25a43 <= 0.41 & vps26b > 0.86 & ppp1r3cb > 0.24) | (nr2c2 <= 0.78 & top2a <= 0.32 & zdhhc5a <= 0.59 & TMC8 <= 0.70 & kdm8 <= 0.84 & zgc:66475 <= 0.71 & si:dkey-83k24.5 <= 0.90 & flot2a <= 0.81 & abca2 <= 0.86 & gfi1ab <= 1.00 & RFX7 <= 0.91 & nid2a <= 0.98 & sytl3 <= 0.94 & yme1l1a <= 0.96 & ctsh <= 0.82 & cxcr4a <= 0.89 & slc25a43 > 0.41 & tex15 > 0.77) | (nr2c2 <= 0.78 & top2a <= 0.32 & zdhhc5a <= 0.59 & TMC8 <= 0.70 & kdm8 <= 0.84 & zgc:66475 <= 0.71 & si:dkey-83k24.5 <= 0.90 & flot2a <= 0.81 & abca2 <= 0.86 & gfi1ab <= 1.00 & RFX7 <= 0.91 & nid2a <= 0.98 & sytl3 <= 0.94 & yme1l1a <= 0.96 & ctsh <= 0.82 & cxcr4a > 0.89) | (nr2c2 <= 0.78 & top2a <= 0.32 

	 Model's accuracy: 0.9225
	 Class 1 - Global explanation: (nr2c2 <= 0.76 & top2a <= 0.32 & zdhhc5a <= 0.59 & TMC8 <= 0.70 & slc25a43 <= 0.80 & kdm8 <= 0.90 & zgc:66475 <= 0.71 & si:dkey-83k24.5 <= 0.90 & abca2 <= 0.86 & fbxo44 <= 0.91 & stx5a <= 0.92 & RFX7 <= 0.90 & rps6ka3a <= 0.93 & si:dkey-109a10.2 <= 0.95 & si:ch73-248e21.5 <= 0.98 & HHEX <= 0.48 & pogza <= 0.49 & psme4b > 0.91 & si:dkey-77g12.4 > 0.33) | (nr2c2 <= 0.76 & top2a <= 0.32 & zdhhc5a <= 0.59 & TMC8 <= 0.70 & slc25a43 <= 0.80 & kdm8 <= 0.90 & zgc:66475 <= 0.71 & si:dkey-83k24.5 <= 0.90 & abca2 <= 0.86 & fbxo44 <= 0.91 & stx5a <= 0.92 & RFX7 <= 0.90 & rps6ka3a <= 0.93 & si:dkey-109a10.2 <= 0.95 & si:ch73-248e21.5 <= 0.98 & HHEX <= 0.48 & pogza > 0.49 & ints5 <= 0.18) | (nr2c2 <= 0.76 & top2a <= 0.32 & zdhhc5a <= 0.59 & TMC8 <= 0.70 & slc25a43 <= 0.80 & kdm8 <= 0.90 & zgc:66475 <= 0.71 & si:dkey-83k24.5 <= 0.90 & abca2 <= 0.86 & fbxo44 <= 0.91 & stx5a <= 0.92 & RFX7 <= 0.90 & rps6ka3a <= 0.93 & si:dkey-109a10.2 <= 0.95 &

	 Model's accuracy: 0.9155
	 Class 1 - Global explanation: (nr2c2 <= 0.76 & top2a <= 0.32 & zdhhc5a <= 0.59 & TMC8 <= 0.79 & kdm8 <= 0.90 & zgc:66475 <= 0.71 & qrsl1 <= 0.77 & HHEX <= 0.32 & vps72 <= 0.85 & ypel3 <= 0.92 & asb12a <= 0.17) | (nr2c2 <= 0.76 & top2a <= 0.32 & zdhhc5a <= 0.59 & TMC8 <= 0.79 & kdm8 <= 0.90 & zgc:66475 <= 0.71 & qrsl1 <= 0.77 & HHEX <= 0.32 & vps72 <= 0.85 & ypel3 <= 0.92 & asb12a > 0.17 & rps6ka3a <= 0.93 & pfdn2 <= 0.98 & si:ch211-265g21.1 <= 0.97 & perp <= 0.83 & aqp1a.1 <= 0.84 & illr1 <= 0.84 & pogza <= 0.49 & si:ch73-236c18.6 <= 0.94 & srsf6b <= 0.81 & TMC8 > 0.72 & brms1la > 0.33) | (nr2c2 <= 0.76 & top2a <= 0.32 & zdhhc5a <= 0.59 & TMC8 <= 0.79 & kdm8 <= 0.90 & zgc:66475 <= 0.71 & qrsl1 <= 0.77 & HHEX <= 0.32 & vps72 <= 0.85 & ypel3 <= 0.92 & asb12a > 0.17 & rps6ka3a <= 0.93 & pfdn2 <= 0.98 & si:ch211-265g21.1 <= 0.97 & perp <= 0.83 & aqp1a.1 <= 0.84 & illr1 <= 0.84 & pogza <= 0.49 & si:ch73-236c18.6 <= 0.94 & srsf6b > 0.81 & rhoae <= 0.19) | (nr2c2 

	 Model's accuracy: 0.9296
	 Class 1 - Global explanation: (nr2c2 <= 0.74 & top2a <= 0.32 & zdhhc5a <= 0.59 & TMC8 <= 0.70 & kdm8 <= 0.90 & ypel3 <= 0.92 & zgc:66475 <= 0.71 & HHEX <= 0.48 & map2k4a <= 0.82 & abca2 <= 0.86 & pfdn2 <= 0.98 & rin2 <= 0.96 & si:ch211-284o19.8 <= 0.90 & SVIP <= 0.99 & psme4b > 0.89 & zmp:0000001167 > 0.59) | (nr2c2 <= 0.74 & top2a <= 0.32 & zdhhc5a <= 0.59 & TMC8 <= 0.70 & kdm8 <= 0.90 & ypel3 <= 0.92 & zgc:66475 <= 0.71 & HHEX <= 0.48 & map2k4a <= 0.82 & abca2 <= 0.86 & pfdn2 <= 0.98 & rin2 <= 0.96 & si:ch211-284o19.8 <= 0.90 & SVIP > 0.99) | (nr2c2 <= 0.74 & top2a <= 0.32 & zdhhc5a <= 0.59 & TMC8 <= 0.70 & kdm8 <= 0.90 & ypel3 <= 0.92 & zgc:66475 <= 0.71 & HHEX <= 0.48 & map2k4a <= 0.82 & abca2 <= 0.86 & pfdn2 <= 0.98 & rin2 <= 0.96 & si:ch211-284o19.8 > 0.90) | (nr2c2 <= 0.74 & top2a <= 0.32 & zdhhc5a <= 0.59 & TMC8 <= 0.70 & kdm8 <= 0.90 & ypel3 <= 0.92 & zgc:66475 <= 0.71 & HHEX <= 0.48 & map2k4a <= 0.82 & abca2 <= 0.86 & pfdn2 <= 0.98 & rin2 > 0.96) 

In [92]:
results_tree = pd.DataFrame({
    'method': methods,
    'split': splits,
    'explanation': explanations,
    'explanation_inv': explanations_inv,
    'model_accuracy': model_accuracies,
    'explanation_accuracy': explanation_accuracies,
    'explanation_accuracy_inv': explanation_accuracies_inv,
    'elapsed_time': elapsed_times,
    'elapsed_time_inv': elapsed_times_inv,
})
results_tree.to_csv(os.path.join(results_dir, 'results_tree.csv'))
results_tree

Unnamed: 0,method,split,explanation,explanation_inv,model_accuracy,explanation_accuracy,explanation_accuracy_inv,elapsed_time,elapsed_time_inv
0,decision_tree,0,(nr2c2 <= 0.74 & top2a <= 0.32 & zdhhc5a <= 0....,(nr2c2 <= 0.74 & top2a <= 0.32 & zdhhc5a <= 0....,0.888112,0.888112,0.888112,0,0
1,decision_tree,1,(nr2c2 <= 0.76 & top2a <= 0.32 & zdhhc5a <= 0....,(nr2c2 <= 0.76 & top2a <= 0.32 & zdhhc5a <= 0....,0.909091,0.909091,0.909091,0,0
2,decision_tree,2,(nr2c2 <= 0.76 & zdhhc5a <= 0.53 & top2a <= 0....,(nr2c2 <= 0.76 & zdhhc5a <= 0.53 & top2a <= 0....,0.915493,0.915493,0.915493,0,0
3,decision_tree,3,(nr2c2 <= 0.78 & top2a <= 0.32 & zdhhc5a <= 0....,(nr2c2 <= 0.78 & top2a <= 0.32 & zdhhc5a <= 0....,0.908451,0.908451,0.908451,0,0
4,decision_tree,4,(nr2c2 <= 0.76 & top2a <= 0.32 & zdhhc5a <= 0....,(nr2c2 <= 0.76 & top2a <= 0.32 & zdhhc5a <= 0....,0.943662,0.943662,0.943662,0,0
5,decision_tree,5,(nr2c2 <= 0.76 & top2a <= 0.32 & zdhhc5a <= 0....,(nr2c2 <= 0.76 & top2a <= 0.32 & zdhhc5a <= 0....,0.922535,0.922535,0.922535,0,0
6,decision_tree,6,(nr2c2 <= 0.74 & top2a <= 0.32 & zdhhc5a <= 0....,(nr2c2 <= 0.74 & top2a <= 0.32 & zdhhc5a <= 0....,0.943662,0.943662,0.943662,0,0
7,decision_tree,7,(nr2c2 <= 0.76 & top2a <= 0.32 & zdhhc5a <= 0....,(nr2c2 <= 0.76 & top2a <= 0.32 & zdhhc5a <= 0....,0.915493,0.915493,0.915493,0,0
8,decision_tree,8,(nr2c2 <= 0.76 & top2a <= 0.32 & zdhhc5a <= 0....,(nr2c2 <= 0.76 & top2a <= 0.32 & zdhhc5a <= 0....,0.915493,0.915493,0.915493,0,0
9,decision_tree,9,(nr2c2 <= 0.74 & top2a <= 0.32 & zdhhc5a <= 0....,(nr2c2 <= 0.74 & top2a <= 0.32 & zdhhc5a <= 0....,0.929577,0.929577,0.929577,0,0


# Summary

In [102]:
cols = ['model_accuracy', 'explanation_accuracy', 'explanation_accuracy_inv', 'elapsed_time', 'elapsed_time_inv']
mean_cols = [f'{c}_mean' for c in cols]
sem_cols = [f'{c}_sem' for c in cols]

# pruning
df_mean = results_pruning[cols].mean()
df_sem = results_pruning[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_pruning = pd.concat([df_mean, df_sem])
summary_pruning.name = 'pruning'

# # lime
# df_mean = results_lime[cols].mean()
# df_sem = results_lime[cols].sem()
# df_mean.columns = mean_cols
# df_sem.columns = sem_cols
# summary_lime = pd.concat([df_mean, df_sem])
# summary_lime.name = 'lime'

# weights
df_mean = results_weights[cols].mean()
df_sem = results_weights[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_weights = pd.concat([df_mean, df_sem])
summary_weights.name = 'weights'

# psi
df_mean = results_psi[cols].mean()
df_sem = results_psi[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_psi = pd.concat([df_mean, df_sem])
summary_psi.name = 'psi'

# tree
df_mean = results_tree[cols].mean()
df_sem = results_tree[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_tree = pd.concat([df_mean, df_sem])
summary_tree.name = 'tree'

summary = pd.concat([summary_pruning, 
#                      summary_lime, 
                     summary_weights, 
                     summary_psi, 
                     summary_tree], axis=1).T
summary.columns = mean_cols + sem_cols
summary

Unnamed: 0,model_accuracy_mean,explanation_accuracy_mean,explanation_accuracy_inv_mean,elapsed_time_mean,elapsed_time_inv_mean,model_accuracy_sem,explanation_accuracy_sem,explanation_accuracy_inv_sem,elapsed_time_sem,elapsed_time_inv_sem
pruning,0.928957,0.872673,0.385398,1.498599,6.575522,0.005418,0.014966,0.024985,0.218056,0.949549
weights,0.959928,0.85098,0.355338,2.564746,13.147962,0.004801,0.014468,0.048169,0.156529,1.28251
psi,0.830518,0.549773,0.505752,0.021291,0.022005,0.003234,0.113615,0.096239,0.002193,0.003416
tree,0.919157,0.919157,0.919157,0.0,0.0,0.005314,0.005314,0.005314,0.0,0.0


In [104]:
summary.to_csv(os.path.join(results_dir, 'summary.csv'))