In [1]:
import sys
sys.path.append('..')
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from sympy import simplify_logic
import time
from sklearn.metrics import accuracy_score
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.tree import _tree, export_text

from deep_logic.utils.base import validate_network, set_seed, tree_to_formula
from deep_logic.utils.relunn import get_reduced_model, prune_features
from deep_logic.utils.sigmoidnn import prune_equal_fanin
from deep_logic import logic

results_dir = 'results/xor'
if not os.path.isdir(results_dir):
    os.makedirs(results_dir)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
concepts = ['c1', 'c2', 'c3', 'c4']
n_rep = 10
tot_epochs = 4001
prune_epochs = 2000

In [2]:
# XOR problem
x_train = torch.tensor([
    [0, 0, 0, 1],
    [0, 1, 0, 1],
    [1, 0, 0, 1],
    [1, 1, 0, 1],
], dtype=torch.float)
y_train = torch.tensor([0, 1, 1, 0], dtype=torch.long)


x_test = torch.tensor([
    [0, 0.95, 0, 1],
    [0, 0.9, 0.2, 0.7],
    [0.05, 1, 0.5, 1],
    [0.1, 0.8, 0.1, 0.8],
    [0.45, 1, 0.95, 0.9],
    [0, 0.35, 0.95, 0.9],
    [0.95, 0.9, 0.75, 0.2],
    [0.75, 0.2, 0.75, 0.15],
    [0.75, 0.15, 0.75, 0.2],
], dtype=torch.float)
y_test = torch.tensor([1, 1, 1, 1, 1, 0, 0, 1, 1], dtype=torch.long)

In [3]:
def train_nn(x_train, y_train, need_pruning, seed, device, relu=False):
    set_seed(seed)
    x_train = x_train.to(device)
    y_train = y_train.to(device)
    layers = [
        torch.nn.Linear(x_train.size(1), 10),
        torch.nn.ReLU() if relu else torch.nn.LeakyReLU(),
        torch.nn.Linear(10, 4),
        torch.nn.ReLU() if relu else torch.nn.LeakyReLU(),
        torch.nn.Linear(4, 2),
        torch.nn.Softmax(dim=1),
    ]
    model = torch.nn.Sequential(*layers).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_form = torch.nn.CrossEntropyLoss()
    model.train()
    for epoch in range(tot_epochs):
        # forward pass
        optimizer.zero_grad()
        y_pred = model(x_train)
        # Compute Loss
        loss = loss_form(y_pred, y_train)

        for module in model.children():
            if isinstance(module, torch.nn.Linear):
                loss += 0.004 * torch.norm(module.weight, 1)
                loss += 0.004 * torch.norm(module.bias, 1)
                break

        # backward pass
        loss.backward()
        optimizer.step()

        if epoch > prune_epochs and need_pruning:
            prune_features(model, n_classes=1, device=device)
            need_pruning = False
            
        # compute accuracy
        if epoch % 500 == 0:
            y_pred_d = torch.argmax(y_pred, dim=1)
            accuracy = y_pred_d.eq(y_train).sum().item() / y_train.size(0)
            print(f'\t Epoch {epoch}: train accuracy: {accuracy:.4f}')
    
    return model

# General pruning

In [4]:
need_pruning = True
method = 'pruning'
methods = []
splits = []
explanations = []
explanations_inv = []
model_accuracies = []
explanation_accuracies = []
explanation_accuracies_inv = []
elapsed_times = []
elapsed_times_inv = []
for seed in range(n_rep):
    explanation, explanation_inv = '', ''
    explanation_accuracy, explanation_accuracy_inv = 0, 0
    print(f'Seed [{seed+1}/{n_rep}]')
    
    model = train_nn(x_train, y_train, need_pruning, seed, device)
    
    y_preds = model(x_test.to(device)).cpu().detach().numpy()
    model_accuracy = accuracy_score(y_test.cpu().detach().numpy(), y_preds.argmax(axis=1))
    print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
    
    # positive class
    target_class = 1
    start = time.time()
    global_explanation, _, _ = logic.relunn.combine_local_explanations(model, 
                                                                       x_train.to(device), y_train.to(device), 
                                                                       target_class=target_class,
                                                                       method=method, device=device)
    elapsed_time = time.time() - start
    if global_explanation:
        explanation_accuracy, _ = logic.base.test_explanation(global_explanation, target_class, x_test, y_test)
        explanation = logic.base.replace_names(global_explanation, concepts)
    print(f'\t Class {target_class} - Global explanation: "{explanation}" - Accuracy: {explanation_accuracy:.4f}')
    print(f'\t Elapsed time {elapsed_time}')
        
    # negative class
    target_class = 0
    start = time.time()
    global_explanation_inv, _, _ = logic.relunn.combine_local_explanations(model, 
                                                                           x_train.to(device), y_train.to(device), 
                                                                           target_class=target_class,
                                                                           method=method, device=device)
    elapsed_time_inv = time.time() - start
    if global_explanation_inv:
        explanation_accuracy_inv, _ = logic.base.test_explanation(global_explanation_inv, target_class, x_test, y_test)
        explanation_inv = logic.base.replace_names(global_explanation_inv, concepts)
    print(f'\t Class {target_class} - Global explanation: "{explanation_inv}" - Accuracy: {explanation_accuracy_inv:.4f}')
    print(f'\t Elapsed time {elapsed_time_inv}')
    
    methods.append(method)
    splits.append(seed)
    explanations.append(explanation)
    explanations_inv.append(explanation_inv)
    model_accuracies.append(model_accuracy)
    explanation_accuracies.append(explanation_accuracy)
    explanation_accuracies_inv.append(explanation_accuracy_inv)
    elapsed_times.append(elapsed_time)
    elapsed_times_inv.append(elapsed_time_inv)

Seed [1/10]
	 Epoch 0: train accuracy: 0.5000
	 Epoch 500: train accuracy: 1.0000
	 Epoch 1000: train accuracy: 1.0000
	 Epoch 1500: train accuracy: 1.0000
	 Epoch 2000: train accuracy: 1.0000
	 Epoch 2500: train accuracy: 1.0000
	 Epoch 3000: train accuracy: 1.0000
	 Epoch 3500: train accuracy: 1.0000
	 Epoch 4000: train accuracy: 1.0000
	 Model's accuracy: 1.0000
	 Class 1 - Global explanation: "(c1 & ~c2) | (c2 & ~c1)" - Accuracy: 1.0000
	 Elapsed time 0.013961553573608398
	 Class 0 - Global explanation: "(c1 & c2) | (~c1 & ~c2)" - Accuracy: 1.0000
	 Elapsed time 0.012965917587280273
Seed [2/10]
	 Epoch 0: train accuracy: 0.5000
	 Epoch 500: train accuracy: 1.0000
	 Epoch 1000: train accuracy: 1.0000
	 Epoch 1500: train accuracy: 1.0000
	 Epoch 2000: train accuracy: 1.0000
	 Epoch 2500: train accuracy: 1.0000
	 Epoch 3000: train accuracy: 1.0000
	 Epoch 3500: train accuracy: 1.0000
	 Epoch 4000: train accuracy: 1.0000
	 Model's accuracy: 1.0000
	 Class 1 - Global explanation: "(c1 &

In [5]:
results_pruning = pd.DataFrame({
    'method': methods,
    'split': splits,
    'explanation': explanations,
    'explanation_inv': explanations_inv,
    'model_accuracy': model_accuracies,
    'explanation_accuracy': explanation_accuracies,
    'explanation_accuracy_inv': explanation_accuracies_inv,
    'elapsed_time': elapsed_times,
    'elapsed_time_inv': elapsed_times_inv,
})
results_pruning.to_csv(os.path.join(results_dir, 'results_pruning.csv'))
results_pruning

Unnamed: 0,method,split,explanation,explanation_inv,model_accuracy,explanation_accuracy,explanation_accuracy_inv,elapsed_time,elapsed_time_inv
0,pruning,0,(c1 & ~c2) | (c2 & ~c1),(c1 & c2) | (~c1 & ~c2),1.0,1.0,1.0,0.013962,0.012966
1,pruning,1,(c1 & ~c2) | (c2 & ~c1),(c1 & c2) | (~c1 & ~c2),1.0,1.0,1.0,0.016954,0.018946
2,pruning,2,(c1 & ~c2) | (c2 & ~c1),(c1 & c2) | (~c1 & ~c2),1.0,1.0,1.0,0.013963,0.017978
3,pruning,3,(c1 & ~c2) | (c2 & ~c1),(c1 & c2) | (~c1 & ~c2),1.0,1.0,1.0,0.020945,0.031909
4,pruning,4,(c1 & ~c2) | (c2 & ~c1),(c1 & c2) | (~c1 & ~c2),1.0,1.0,1.0,0.013964,0.014958
5,pruning,5,(c1 & ~c2) | (c2 & ~c1),(c1 & c2) | (~c1 & ~c2),1.0,1.0,1.0,0.012966,0.013962
6,pruning,6,(c1 & ~c2) | (c2 & ~c1),(c1 & c2) | (~c1 & ~c2),1.0,1.0,1.0,0.013963,0.013962
7,pruning,7,(c1 & ~c2) | (c2 & ~c1),(c1 & c2) | (~c1 & ~c2),1.0,1.0,1.0,0.016955,0.016955
8,pruning,8,(c1 & ~c2) | (c2 & ~c1),(c1 & c2) | (~c1 & ~c2),1.0,1.0,1.0,0.016955,0.015957
9,pruning,9,(c1 & ~c2) | (c2 & ~c1),(c1 & c2) | (~c1 & ~c2),1.0,1.0,1.0,0.013967,0.017956


# LIME

In [6]:
need_pruning = False
method = 'lime'
methods = []
splits = []
explanations = []
explanations_inv = []
model_accuracies = []
explanation_accuracies = []
explanation_accuracies_inv = []
elapsed_times = []
elapsed_times_inv = []
for seed in range(n_rep):
    print(f'Seed [{seed+1}/{n_rep}]')
    
    model = train_nn(x_train, y_train, need_pruning, seed, device)
    
    y_preds = model(x_test.to(device)).cpu().detach().numpy()
    model_accuracy = accuracy_score(y_test.cpu().detach().numpy(), y_preds.argmax(axis=1))
    print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
    
    # positive class
    target_class = 1
    start = time.time()
    global_explanation, _, _ = logic.relunn.combine_local_explanations(model, 
                                                                       x_train.to(device), y_train.to(device), 
                                                                       target_class=target_class,
                                                                       method=method, device=device)
    elapsed_time = time.time() - start
    if global_explanation:
        explanation_accuracy, _ = logic.base.test_explanation(global_explanation, target_class, x_test, y_test)
        explanation = logic.base.replace_names(global_explanation, concepts)
    print(f'\t Class {target_class} - Global explanation: "{explanation}" - Accuracy: {explanation_accuracy:.4f}')
    print(f'\t Elapsed time {elapsed_time}')
        
    # negative class
    target_class = 0
    start = time.time()
    global_explanation_inv, _, _ = logic.relunn.combine_local_explanations(model, 
                                                                           x_train.to(device), y_train.to(device), 
                                                                           target_class=target_class,
                                                                           method=method, device=device)
    elapsed_time_inv = time.time() - start
    if global_explanation_inv:
        explanation_accuracy_inv, _ = logic.base.test_explanation(global_explanation_inv, target_class, x_test, y_test)
        explanation_inv = logic.base.replace_names(global_explanation_inv, concepts)
    print(f'\t Class {target_class} - Global explanation: "{explanation_inv}" - Accuracy: {explanation_accuracy_inv:.4f}')
    print(f'\t Elapsed time {elapsed_time_inv}')
    
    methods.append(method)
    splits.append(seed)
    explanations.append(explanation)
    explanations_inv.append(explanation_inv)
    model_accuracies.append(model_accuracy)
    explanation_accuracies.append(explanation_accuracy)
    explanation_accuracies_inv.append(explanation_accuracy_inv)
    elapsed_times.append(elapsed_time)
    elapsed_times_inv.append(elapsed_time_inv)

Seed [1/10]
	 Epoch 0: train accuracy: 0.5000
	 Epoch 500: train accuracy: 1.0000
	 Epoch 1000: train accuracy: 1.0000
	 Epoch 1500: train accuracy: 1.0000
	 Epoch 2000: train accuracy: 1.0000
	 Epoch 2500: train accuracy: 1.0000
	 Epoch 3000: train accuracy: 1.0000
	 Epoch 3500: train accuracy: 1.0000
	 Epoch 4000: train accuracy: 1.0000
	 Model's accuracy: 1.0000
	 Class 1 - Global explanation: "(c1 & ~c2) | (c2 & ~c1)" - Accuracy: 1.0000
	 Elapsed time 0.6768643856048584
	 Class 0 - Global explanation: "(c1 & c2) | (~c1 & ~c2)" - Accuracy: 1.0000
	 Elapsed time 0.7114787101745605
Seed [2/10]
	 Epoch 0: train accuracy: 0.5000
	 Epoch 500: train accuracy: 1.0000
	 Epoch 1000: train accuracy: 1.0000
	 Epoch 1500: train accuracy: 1.0000
	 Epoch 2000: train accuracy: 1.0000
	 Epoch 2500: train accuracy: 1.0000
	 Epoch 3000: train accuracy: 1.0000
	 Epoch 3500: train accuracy: 1.0000
	 Epoch 4000: train accuracy: 1.0000
	 Model's accuracy: 1.0000
	 Class 1 - Global explanation: "(c1 & ~c2

In [7]:
results_lime = pd.DataFrame({
    'method': methods,
    'split': splits,
    'explanation': explanations,
    'explanation_inv': explanations_inv,
    'model_accuracy': model_accuracies,
    'explanation_accuracy': explanation_accuracies,
    'explanation_accuracy_inv': explanation_accuracies_inv,
    'elapsed_time': elapsed_times,
    'elapsed_time_inv': elapsed_times_inv,
})
results_lime.to_csv(os.path.join(results_dir, 'results_lime.csv'))
results_lime

Unnamed: 0,method,split,explanation,explanation_inv,model_accuracy,explanation_accuracy,explanation_accuracy_inv,elapsed_time,elapsed_time_inv
0,lime,0,(c1 & ~c2) | (c2 & ~c1),(c1 & c2) | (~c1 & ~c2),1.0,1.0,1.0,0.676864,0.711479
1,lime,1,(c1 & ~c2) | (c2 & ~c1),(c1 & c2) | (~c1 & ~c2),1.0,1.0,1.0,0.747026,0.695234
2,lime,2,(c1 & ~c2) | (c2 & ~c1),(c1 & c2) | (~c1 & ~c2),1.0,1.0,1.0,0.69614,0.717908
3,lime,3,(c1 & ~c2) | (c2 & ~c1),(c1 & c2) | (~c1 & ~c2),1.0,1.0,1.0,0.690565,0.711605
4,lime,4,(c1 & ~c2) | (c2 & ~c1),(c1 & c2) | (~c1 & ~c2),1.0,1.0,1.0,0.66516,0.666279
5,lime,5,(c1 & ~c2) | (c2 & ~c1),(c1 & c2) | (~c1 & ~c2),1.0,1.0,1.0,0.651258,0.673329
6,lime,6,(c1 & ~c2) | (c2 & ~c1),(c1 & c2) | (~c1 & ~c2),1.0,1.0,1.0,0.730987,0.721074
7,lime,7,(c1 & ~c2) | (c2 & ~c1),(c1 & c2) | (~c1 & ~c2),1.0,1.0,1.0,0.667218,0.646298
8,lime,8,(c1 & ~c2) | (c2 & ~c1),(c1 & c2) | (~c1 & ~c2),1.0,1.0,1.0,0.725228,0.681281
9,lime,9,(c1 & ~c2) | (c2 & ~c1),(c1 & c2) | (~c1 & ~c2),1.0,1.0,1.0,0.71409,0.684914


# Weights

In [8]:
need_pruning = False
method = 'weights'
methods = []
splits = []
explanations = []
explanations_inv = []
model_accuracies = []
explanation_accuracies = []
explanation_accuracies_inv = []
elapsed_times = []
elapsed_times_inv = []
for seed in range(n_rep):
    print(f'Seed [{seed+1}/{n_rep}]')
    
    model = train_nn(x_train, y_train, need_pruning, seed, device, relu=True)
    
    y_preds = model(x_test.to(device)).cpu().detach().numpy()
    model_accuracy = accuracy_score(y_test.cpu().detach().numpy(), y_preds.argmax(axis=1))
    print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
    
    # positive class
    target_class = 1
    start = time.time()
    global_explanation, _, _ = logic.relunn.combine_local_explanations(model, 
                                                                       x_train.to(device), y_train.to(device), 
                                                                       target_class=target_class,
                                                                       method=method, device=device)
    elapsed_time = time.time() - start
    if global_explanation:
        explanation_accuracy, _ = logic.base.test_explanation(global_explanation, target_class, x_test, y_test)
        explanation = logic.base.replace_names(global_explanation, concepts)
    print(f'\t Class {target_class} - Global explanation: "{explanation}" - Accuracy: {explanation_accuracy:.4f}')
    print(f'\t Elapsed time {elapsed_time}')
        
    # negative class
    target_class = 0
    start = time.time()
    global_explanation_inv, _, _ = logic.relunn.combine_local_explanations(model, 
                                                                           x_train.to(device), y_train.to(device), 
                                                                           target_class=target_class,
                                                                           method=method, device=device)
    elapsed_time_inv = time.time() - start
    if global_explanation_inv:
        explanation_accuracy_inv, _ = logic.base.test_explanation(global_explanation_inv, target_class, x_test, y_test)
        explanation_inv = logic.base.replace_names(global_explanation_inv, concepts)
    print(f'\t Class {target_class} - Global explanation: "{explanation_inv}" - Accuracy: {explanation_accuracy_inv:.4f}')
    print(f'\t Elapsed time {elapsed_time_inv}')
    
    methods.append(method)
    splits.append(seed)
    explanations.append(explanation)
    explanations_inv.append(explanation_inv)
    model_accuracies.append(model_accuracy)
    explanation_accuracies.append(explanation_accuracy)
    explanation_accuracies_inv.append(explanation_accuracy_inv)
    elapsed_times.append(elapsed_time)
    elapsed_times_inv.append(elapsed_time_inv)

Seed [1/10]
	 Epoch 0: train accuracy: 0.5000
	 Epoch 500: train accuracy: 1.0000
	 Epoch 1000: train accuracy: 1.0000
	 Epoch 1500: train accuracy: 1.0000
	 Epoch 2000: train accuracy: 1.0000
	 Epoch 2500: train accuracy: 1.0000
	 Epoch 3000: train accuracy: 1.0000
	 Epoch 3500: train accuracy: 1.0000
	 Epoch 4000: train accuracy: 1.0000
	 Model's accuracy: 0.8889
	 Class 1 - Global explanation: "c1 | c2" - Accuracy: 0.8889
	 Elapsed time 0.013963699340820312
	 Class 0 - Global explanation: "(c1 & c2) | (~c1 & ~c2)" - Accuracy: 1.0000
	 Elapsed time 0.01695418357849121
Seed [2/10]
	 Epoch 0: train accuracy: 0.5000
	 Epoch 500: train accuracy: 1.0000
	 Epoch 1000: train accuracy: 1.0000
	 Epoch 1500: train accuracy: 1.0000
	 Epoch 2000: train accuracy: 1.0000
	 Epoch 2500: train accuracy: 1.0000
	 Epoch 3000: train accuracy: 1.0000
	 Epoch 3500: train accuracy: 1.0000
	 Epoch 4000: train accuracy: 1.0000
	 Model's accuracy: 1.0000
	 Class 1 - Global explanation: "(c1 & ~c2) | (c2 & ~c1

In [9]:
results_weights = pd.DataFrame({
    'method': methods,
    'split': splits,
    'explanation': explanations,
    'explanation_inv': explanations_inv,
    'model_accuracy': model_accuracies,
    'explanation_accuracy': explanation_accuracies,
    'explanation_accuracy_inv': explanation_accuracies_inv,
    'elapsed_time': elapsed_times,
    'elapsed_time_inv': elapsed_times_inv,
})
results_weights.to_csv(os.path.join(results_dir, 'results_weights.csv'))
results_weights

Unnamed: 0,method,split,explanation,explanation_inv,model_accuracy,explanation_accuracy,explanation_accuracy_inv,elapsed_time,elapsed_time_inv
0,weights,0,c1 | c2,(c1 & c2) | (~c1 & ~c2),0.888889,0.888889,1.0,0.013964,0.016954
1,weights,1,(c1 & ~c2) | (c2 & ~c1),(c1 & c2) | (~c1 & ~c2),1.0,1.0,1.0,0.015959,0.015957
2,weights,2,~c1 | ~c2,c1 | ~c2,1.0,0.888889,0.777778,0.014962,0.014959
3,weights,3,~c1 | ~c2,(c1 & c2) | (~c1 & ~c2),1.0,0.888889,1.0,0.009974,0.01496
4,weights,4,(c1 & ~c2) | (c2 & ~c1),(c1 & c2) | (~c1 & ~c2),1.0,1.0,1.0,0.011968,0.012964
5,weights,5,(c1 & ~c2) | (c2 & ~c1),(c1 & c2) | (~c1 & ~c2),1.0,1.0,1.0,0.013964,0.013963
6,weights,6,(c1 & ~c2) | (c2 & ~c1),(c1 & c2) | (~c1 & ~c2),1.0,1.0,1.0,0.011968,0.012965
7,weights,7,~c1 | ~c2,c2 | ~c1,1.0,0.888889,0.444444,0.01297,0.011024
8,weights,8,(c1 & ~c2) | (c2 & ~c1),(c1 & c2) | (~c1 & ~c2),1.0,1.0,1.0,0.015958,0.015957
9,weights,9,c1 | c2,~c2,0.888889,0.888889,0.666667,0.015958,0.013963


# Psi network

In [10]:
def train_psi_nn(x_train, y_train, need_pruning, seed, device):
    set_seed(seed)
    x_train = x_train.to(device)
    y_train = y_train.to(device).to(torch.float)
    layers = [
        torch.nn.Linear(x_train.size(1), 10),
        torch.nn.Sigmoid(),
        torch.nn.Linear(10, 4),
        torch.nn.Sigmoid(),
        torch.nn.Linear(4, 1),
        torch.nn.Sigmoid(),
    ]
    model = torch.nn.Sequential(*layers).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_form = torch.nn.BCELoss()
    model.train()
    for epoch in range(tot_epochs):
        # forward pass
        optimizer.zero_grad()
        y_pred = model(x_train).squeeze()
        # Compute Loss
        loss = loss_form(y_pred, y_train)

        for module in model.children():
            if isinstance(module, torch.nn.Linear):
                loss += 0.00001 * torch.norm(module.weight, 1)

        # backward pass
        loss.backward()
        optimizer.step()

        if epoch > prune_epochs and need_pruning:
            model = prune_equal_fanin(model, 2, validate=True, device=device)
            need_pruning = False
            
        # compute accuracy
        if epoch % 500 == 0:
            y_pred_d = y_pred > 0.5
            accuracy = y_pred_d.eq(y_train).sum().item() / y_train.size(0)
            print(f'\t Epoch {epoch}: train accuracy: {accuracy:.4f}')
    
    return model

In [11]:
need_pruning = True
method = 'psi'
methods = []
splits = []
explanations = []
explanations_inv = []
model_accuracies = []
explanation_accuracies = []
explanation_accuracies_inv = []
elapsed_times = []
elapsed_times_inv = []
for seed in range(n_rep):
    print(f'Seed [{seed+1}/{n_rep}]')
    
    # positive class
    target_class = 1
    model = train_psi_nn(x_train, y_train, need_pruning, seed, device)
    
    y_preds = model(x_test.to(device)).cpu().detach().numpy()
    model_accuracy = accuracy_score(y_test.cpu().detach().numpy(), y_preds > 0.5)
    print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
    
    start = time.time()
    global_explanation = logic.generate_fol_explanations(model, device)[0]
    elapsed_time = time.time() - start
    explanation_accuracy, _ = logic.base.test_explanation(global_explanation, target_class, x_test, y_test)
    explanation = logic.base.replace_names(global_explanation, concepts)
    print(f'\t Class {target_class} - Global explanation: "{explanation}" - Accuracy: {explanation_accuracy:.4f}')
    print(f'\t Elapsed time {elapsed_time}')
        
    # negative class
    target_class = 0
    model = train_psi_nn(x_train, y_train.eq(target_class), need_pruning, seed, device)
    
    y_preds = model(x_test.to(device)).cpu().detach().numpy()
    model_accuracy = accuracy_score(y_test.eq(target_class).cpu().detach().numpy(), y_preds > 0.5)
    print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
    
    start = time.time()
    global_explanation_inv = logic.generate_fol_explanations(model, device)[0]
    elapsed_time_inv = time.time() - start
    explanation_accuracy_inv, _ = logic.base.test_explanation(global_explanation_inv, 
                                                              target_class, x_test, y_test)
    explanation_inv = logic.base.replace_names(global_explanation_inv, concepts)
    print(f'\t Class {target_class} - Global explanation: "{explanation_inv}" - Accuracy: {explanation_accuracy_inv:.4f}')
    print(f'\t Elapsed time {elapsed_time_inv}')
    
    methods.append(method)
    splits.append(seed)
    explanations.append(explanation)
    explanations_inv.append(explanation_inv)
    model_accuracies.append(model_accuracy)
    explanation_accuracies.append(explanation_accuracy)
    explanation_accuracies_inv.append(explanation_accuracy_inv)
    elapsed_times.append(elapsed_time)
    elapsed_times_inv.append(elapsed_time_inv)

Seed [1/10]
	 Epoch 0: train accuracy: 0.5000
	 Epoch 500: train accuracy: 0.7500
	 Epoch 1000: train accuracy: 0.7500
	 Epoch 1500: train accuracy: 0.7500
	 Epoch 2000: train accuracy: 1.0000
	 Epoch 2500: train accuracy: 0.5000
	 Epoch 3000: train accuracy: 1.0000
	 Epoch 3500: train accuracy: 1.0000
	 Epoch 4000: train accuracy: 1.0000
	 Model's accuracy: 1.0000
	 Class 1 - Global explanation: "((c1 & ~c2) | (c2 & ~c1))" - Accuracy: 1.0000
	 Elapsed time 0.024934053421020508
	 Epoch 0: train accuracy: 0.5000
	 Epoch 500: train accuracy: 0.5000
	 Epoch 1000: train accuracy: 0.5000
	 Epoch 1500: train accuracy: 1.0000
	 Epoch 2000: train accuracy: 1.0000
	 Epoch 2500: train accuracy: 0.7500
	 Epoch 3000: train accuracy: 0.7500
	 Epoch 3500: train accuracy: 0.7500
	 Epoch 4000: train accuracy: 0.7500
	 Model's accuracy: 0.8889
	 Class 0 - Global explanation: "(~c1 & ~c2)" - Accuracy: 0.8889
	 Elapsed time 0.03690052032470703
Seed [2/10]
	 Epoch 0: train accuracy: 0.5000
	 Epoch 500: tr

	 Epoch 2500: train accuracy: 0.7500
	 Epoch 3000: train accuracy: 0.7500
	 Epoch 3500: train accuracy: 0.7500
	 Epoch 4000: train accuracy: 0.7500
	 Model's accuracy: 0.7778
	 Class 0 - Global explanation: "(c1 & c2)" - Accuracy: 0.8889
	 Elapsed time 0.027927875518798828
Seed [10/10]
	 Epoch 0: train accuracy: 0.5000
	 Epoch 500: train accuracy: 0.7500
	 Epoch 1000: train accuracy: 0.7500
	 Epoch 1500: train accuracy: 0.7500
	 Epoch 2000: train accuracy: 0.7500
	 Epoch 2500: train accuracy: 0.7500
	 Epoch 3000: train accuracy: 0.7500
	 Epoch 3500: train accuracy: 0.7500
	 Epoch 4000: train accuracy: 0.7500
	 Model's accuracy: 0.8889
	 Class 1 - Global explanation: "(c1 | c2)" - Accuracy: 0.8889
	 Elapsed time 0.036901235580444336
	 Epoch 0: train accuracy: 0.5000
	 Epoch 500: train accuracy: 0.5000
	 Epoch 1000: train accuracy: 0.7500
	 Epoch 1500: train accuracy: 0.7500
	 Epoch 2000: train accuracy: 0.7500
	 Epoch 2500: train accuracy: 0.7500
	 Epoch 3000: train accuracy: 0.7500
	 E

In [12]:
results_psi = pd.DataFrame({
    'method': methods,
    'split': splits,
    'explanation': explanations,
    'explanation_inv': explanations_inv,
    'model_accuracy': model_accuracies,
    'explanation_accuracy': explanation_accuracies,
    'explanation_accuracy_inv': explanation_accuracies_inv,
    'elapsed_time': elapsed_times,
    'elapsed_time_inv': elapsed_times_inv,
})
results_psi.to_csv(os.path.join(results_dir, 'results_psi.csv'))
results_psi

Unnamed: 0,method,split,explanation,explanation_inv,model_accuracy,explanation_accuracy,explanation_accuracy_inv,elapsed_time,elapsed_time_inv
0,psi,0,((c1 & ~c2) | (c2 & ~c1)),(~c1 & ~c2),0.888889,1.0,0.888889,0.024934,0.036901
1,psi,1,((c1 & ~c2) | (c2 & ~c1)),((c1 & c2) | (~c1 & ~c2)),0.777778,1.0,1.0,0.02593,0.038995
2,psi,2,((c1 & ~c2) | (c2 & ~c1)),((c1 & c2) | (~c1 & ~c2)),1.0,1.0,1.0,0.0369,0.04089
3,psi,3,((c1 & ~c2) | (c2 & ~c1)),(c2 | ~c1),0.444444,1.0,0.444444,0.033176,0.038895
4,psi,4,((c1 & ~c2) | (c2 & ~c1)),((c1 & c2) | (~c1 & ~c2)),0.777778,1.0,1.0,0.02992,0.044882
5,psi,5,(c2 & ~c1),((c1 & c2) | (~c1 & ~c2)),1.0,0.777778,1.0,0.03391,0.027925
6,psi,6,(c2 & ~c1),(~c1 & ~c2),0.444444,0.777778,0.888889,0.034907,0.02992
7,psi,7,(c1 & ~c2),((c1 & c2) | (~c1 & ~c2)),1.0,0.444444,1.0,0.030917,0.029919
8,psi,8,((c1 & ~c2) | (c2 & ~c1)),(c1 & c2),0.777778,1.0,0.888889,0.033419,0.027928
9,psi,9,(c1 | c2),(c1 | ~c2),0.777778,0.888889,0.777778,0.036901,0.034907


# Decision tree

In [13]:
need_pruning = False
method = 'decision_tree'
methods = []
splits = []
explanations = []
explanations_inv = []
model_accuracies = []
explanation_accuracies = []
explanation_accuracies_inv = []
elapsed_times = []
elapsed_times_inv = []
for seed in range(n_rep):
    print(f'Seed [{seed+1}/{n_rep}]')
    
    classifier = DecisionTreeClassifier(random_state=seed)
    classifier.fit(x_train.detach().numpy(), y_train.detach().numpy())
    y_preds = classifier.predict(x_test.detach().numpy())
    model_accuracy = accuracy_score(y_test.detach().numpy(), y_preds)
    print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
    
    target_class = 1
    start = time.time()
    explanation = tree_to_formula(classifier, concepts, target_class)
    elapsed_time = time.time() - start
    print(f'\t Class {target_class} - Global explanation: {explanation}')
    print(f'\t Elapsed time {elapsed_time}')
    
    target_class = 0
    start = time.time()
    explanation_inv = tree_to_formula(classifier, concepts, target_class)
    elapsed_time = time.time() - start
    print(f'\t Class {target_class} - Global explanation: {explanation_inv}')
    print(f'\t Elapsed time {elapsed_time_inv}')
    
    methods.append(method)
    splits.append(seed)
    explanations.append(explanation)
    explanations_inv.append(explanation_inv)
    model_accuracies.append(model_accuracy)
    explanation_accuracies.append(model_accuracy)
    explanation_accuracies_inv.append(model_accuracy)
    elapsed_times.append(0)
    elapsed_times_inv.append(0)

Seed [1/10]
	 Model's accuracy: 1.0000
	 Class 1 - Global explanation: (c2 <= 0.50 & c1 > 0.50) | (c2 > 0.50 & c1 <= 0.50)
	 Elapsed time 0.0
	 Class 0 - Global explanation: (c2 <= 0.50 & c1 <= 0.50) | (c2 > 0.50 & c1 > 0.50)
	 Elapsed time 0.03490710258483887
Seed [2/10]
	 Model's accuracy: 1.0000
	 Class 1 - Global explanation: (c1 <= 0.50 & c2 > 0.50) | (c1 > 0.50 & c2 <= 0.50)
	 Elapsed time 0.0
	 Class 0 - Global explanation: (c1 <= 0.50 & c2 <= 0.50) | (c1 > 0.50 & c2 > 0.50)
	 Elapsed time 0.03490710258483887
Seed [3/10]
	 Model's accuracy: 1.0000
	 Class 1 - Global explanation: (c1 <= 0.50 & c2 > 0.50) | (c1 > 0.50 & c2 <= 0.50)
	 Elapsed time 0.0
	 Class 0 - Global explanation: (c1 <= 0.50 & c2 <= 0.50) | (c1 > 0.50 & c2 > 0.50)
	 Elapsed time 0.03490710258483887
Seed [4/10]
	 Model's accuracy: 1.0000
	 Class 1 - Global explanation: (c1 <= 0.50 & c2 > 0.50) | (c1 > 0.50 & c2 <= 0.50)
	 Elapsed time 0.0
	 Class 0 - Global explanation: (c1 <= 0.50 & c2 <= 0.50) | (c1 > 0.50 & c2

In [14]:
results_tree = pd.DataFrame({
    'method': methods,
    'split': splits,
    'explanation': explanations,
    'explanation_inv': explanations_inv,
    'model_accuracy': model_accuracies,
    'explanation_accuracy': explanation_accuracies,
    'explanation_accuracy_inv': explanation_accuracies_inv,
    'elapsed_time': elapsed_times,
    'elapsed_time_inv': elapsed_times_inv,
})
results_tree.to_csv(os.path.join(results_dir, 'results_tree.csv'))
results_tree

Unnamed: 0,method,split,explanation,explanation_inv,model_accuracy,explanation_accuracy,explanation_accuracy_inv,elapsed_time,elapsed_time_inv
0,decision_tree,0,(c2 <= 0.50 & c1 > 0.50) | (c2 > 0.50 & c1 <= ...,(c2 <= 0.50 & c1 <= 0.50) | (c2 > 0.50 & c1 > ...,1.0,1.0,1.0,0,0
1,decision_tree,1,(c1 <= 0.50 & c2 > 0.50) | (c1 > 0.50 & c2 <= ...,(c1 <= 0.50 & c2 <= 0.50) | (c1 > 0.50 & c2 > ...,1.0,1.0,1.0,0,0
2,decision_tree,2,(c1 <= 0.50 & c2 > 0.50) | (c1 > 0.50 & c2 <= ...,(c1 <= 0.50 & c2 <= 0.50) | (c1 > 0.50 & c2 > ...,1.0,1.0,1.0,0,0
3,decision_tree,3,(c1 <= 0.50 & c2 > 0.50) | (c1 > 0.50 & c2 <= ...,(c1 <= 0.50 & c2 <= 0.50) | (c1 > 0.50 & c2 > ...,1.0,1.0,1.0,0,0
4,decision_tree,4,(c1 <= 0.50 & c2 > 0.50) | (c1 > 0.50 & c2 <= ...,(c1 <= 0.50 & c2 <= 0.50) | (c1 > 0.50 & c2 > ...,1.0,1.0,1.0,0,0
5,decision_tree,5,(c2 <= 0.50 & c1 > 0.50) | (c2 > 0.50 & c1 <= ...,(c2 <= 0.50 & c1 <= 0.50) | (c2 > 0.50 & c1 > ...,1.0,1.0,1.0,0,0
6,decision_tree,6,(c1 <= 0.50 & c2 > 0.50) | (c1 > 0.50 & c2 <= ...,(c1 <= 0.50 & c2 <= 0.50) | (c1 > 0.50 & c2 > ...,1.0,1.0,1.0,0,0
7,decision_tree,7,(c2 <= 0.50 & c1 > 0.50) | (c2 > 0.50 & c1 <= ...,(c2 <= 0.50 & c1 <= 0.50) | (c2 > 0.50 & c1 > ...,1.0,1.0,1.0,0,0
8,decision_tree,8,(c2 <= 0.50 & c1 > 0.50) | (c2 > 0.50 & c1 <= ...,(c2 <= 0.50 & c1 <= 0.50) | (c2 > 0.50 & c1 > ...,1.0,1.0,1.0,0,0
9,decision_tree,9,(c1 <= 0.50 & c2 > 0.50) | (c1 > 0.50 & c2 <= ...,(c1 <= 0.50 & c2 <= 0.50) | (c1 > 0.50 & c2 > ...,1.0,1.0,1.0,0,0


# Summary

In [15]:
cols = ['model_accuracy', 'explanation_accuracy', 'explanation_accuracy_inv', 'elapsed_time', 'elapsed_time_inv']
mean_cols = [f'{c}_mean' for c in cols]
sem_cols = [f'{c}_sem' for c in cols]

# pruning
df_mean = results_pruning[cols].mean()
df_sem = results_pruning[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_pruning = pd.concat([df_mean, df_sem])
summary_pruning.name = 'pruning'

# lime
df_mean = results_lime[cols].mean()
df_sem = results_lime[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_lime = pd.concat([df_mean, df_sem])
summary_lime.name = 'lime'

# weights
df_mean = results_weights[cols].mean()
df_sem = results_weights[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_weights = pd.concat([df_mean, df_sem])
summary_weights.name = 'weights'

# psi
df_mean = results_psi[cols].mean()
df_sem = results_psi[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_psi = pd.concat([df_mean, df_sem])
summary_psi.name = 'psi'

# tree
df_mean = results_tree[cols].mean()
df_sem = results_tree[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_tree = pd.concat([df_mean, df_sem])
summary_tree.name = 'tree'

summary = pd.concat([summary_pruning, summary_lime, summary_weights, summary_psi, summary_tree], axis=1).T
summary.columns = mean_cols + sem_cols
summary

Unnamed: 0,model_accuracy_mean,explanation_accuracy_mean,explanation_accuracy_inv_mean,elapsed_time_mean,elapsed_time_inv_mean,model_accuracy_sem,explanation_accuracy_sem,explanation_accuracy_inv_sem,elapsed_time_sem,elapsed_time_inv_sem
pruning,1.0,1.0,1.0,0.015459,0.017555,0.0,0.0,0.0,0.000776,0.001715
lime,1.0,1.0,1.0,0.696454,0.69094,0.0,0.0,0.0,0.010102,0.007848
weights,0.977778,0.944444,0.888889,0.013764,0.014367,0.014815,0.018519,0.061975,0.000645,0.000557
psi,0.788889,0.888889,0.888889,0.032091,0.035116,0.065105,0.057378,0.054935,0.001315,0.001882
tree,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [16]:
summary.to_csv(os.path.join(results_dir, 'summary.csv'))