In [1]:
import sys
sys.path.append('..')
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder, MinMaxScaler, KBinsDiscretizer
from sklearn.impute import SimpleImputer
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold
import time
from sklearn.metrics import accuracy_score
from sympy import simplify_logic
import re

from deep_logic.utils.base import validate_network, set_seed, tree_to_formula
from deep_logic.utils.relu_nn import get_reduced_model, prune_features
from deep_logic.utils.psi_nn import prune_equal_fanin
from deep_logic import logic

set_seed(0)

#%%

data = pd.read_csv('data/vdem/V-Dem-CY-Core-v10.csv')

In [2]:
data['country_name_year'] = data['country_name'] + '_' + data['year'].astype(str)
data['country_name_year']

0                   Mexico_1789
1                   Mexico_1790
2                   Mexico_1791
3                   Mexico_1792
4                   Mexico_1793
                  ...          
27008    Piedmont-Sardinia_1857
27009    Piedmont-Sardinia_1858
27010    Piedmont-Sardinia_1859
27011    Piedmont-Sardinia_1860
27012    Piedmont-Sardinia_1861
Name: country_name_year, Length: 27013, dtype: object

In [3]:
data_2000 = data[data['year']>2000].iloc[:, 12:-1].dropna(axis=1)
data_2000

Unnamed: 0,v2x_polyarchy,v2x_polyarchy_codelow,v2x_polyarchy_codehigh,v2x_polyarchy_sd,v2x_delibdem,v2x_delibdem_codelow,v2x_delibdem_codehigh,v2x_delibdem_sd,v2x_egaldem,v2x_egaldem_codelow,...,v2xcl_slave_sd,v2xel_elecparl,v2xel_elecpres,v2xex_elecleg,v2xlg_leginter,v2xme_altinf,v2xme_altinf_codelow,v2xme_altinf_codehigh,v2xme_altinf_sd,v2x_divparctrl
212,0.698,0.632,0.763,0.067,0.539,0.467,0.626,0.080,0.411,0.349,...,0.556,0,0,1.00,0.0,0.767,0.615,0.855,0.653,1.616
213,0.711,0.667,0.778,0.055,0.543,0.475,0.608,0.064,0.418,0.359,...,0.556,0,0,1.00,0.0,0.808,0.683,0.887,0.649,1.616
214,0.715,0.668,0.762,0.047,0.554,0.497,0.606,0.053,0.418,0.378,...,0.556,1,0,1.00,0.0,0.808,0.683,0.887,0.649,1.616
215,0.720,0.668,0.746,0.039,0.566,0.519,0.605,0.043,0.423,0.386,...,0.556,0,0,1.00,0.0,0.808,0.683,0.887,0.649,1.616
216,0.720,0.668,0.746,0.039,0.585,0.522,0.615,0.047,0.423,0.386,...,0.556,0,0,1.00,0.0,0.808,0.683,0.887,0.649,1.616
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
25618,0.316,0.283,0.348,0.035,0.215,0.173,0.278,0.052,0.255,0.214,...,0.592,1,1,0.85,0.0,0.516,0.381,0.627,0.633,1.602
25619,0.267,0.236,0.286,0.028,0.194,0.152,0.233,0.041,0.234,0.207,...,0.592,1,1,0.85,0.0,0.516,0.381,0.627,0.633,-0.684
25620,0.260,0.236,0.282,0.025,0.168,0.137,0.208,0.034,0.213,0.188,...,0.599,0,0,0.85,0.0,0.437,0.321,0.579,0.628,-0.448
25621,0.262,0.245,0.284,0.020,0.213,0.181,0.264,0.041,0.214,0.192,...,0.608,0,0,0.85,0.0,0.502,0.374,0.620,0.625,-0.586


In [4]:
high_level_indicators = [
    'v2x_polyarchy',
    #'v2x_libdem',
    #'v2x_partipdem',
    'v2x_delibdem',
    'v2x_egaldem'
]
mid_level_indicators = [
    'v2x_api',
    'v2x_mpi',
    'v2x_freexp_altinf',
    'v2x_frassoc_thick',
    'v2x_suffr',
    'v2xel_frefair',
    'v2x_elecoff',
    #'v2x_liberal',
    'v2xcl_rol',
    #'v2x_jucon',
    #'v2xlg_legcon',
    #'v2x_partip',
    'v2x_cspart',
    #'v2xdd_dd',
    #'v2xel_locelec',
    #'v2xel_regelec',
    'v2xdl_delib',
    'v2x_egal',
    'v2xeg_eqprotec',
    'v2xeg_eqaccess',
    'v2xeg_eqdr',
]

drop_list = ['codelow', 'codehigh', 'sd', 'osp', 'nr', 'mean']
low_level_indicators = []
for f in data_2000.columns:
    if f.endswith('_ord') and f not in high_level_indicators and f not in mid_level_indicators:
        low_level_indicators.append(f)


low_level_indicators_continuous = []
for f in data_2000.columns:
    if f.endswith('_codehigh') or f.endswith('_codelow') and f not in high_level_indicators and f not in mid_level_indicators:
        low_level_indicators_continuous.append(f)

print(f'Main {len(high_level_indicators)} - Area {len(mid_level_indicators)} - Raw {len(low_level_indicators)}')

Main 3 - Area 14 - Raw 62


In [5]:
data_low_continuous = data_2000[low_level_indicators_continuous]
data_low_continuous.shape

(3385, 464)

In [6]:
data_low_raw = data_2000[low_level_indicators]
one_hots = []
for indicator in low_level_indicators:
    c = data_low_raw[indicator].values
    n_bins = int(c.max())
    kbin = KBinsDiscretizer(n_bins=n_bins, encode='onehot-dense', strategy='uniform')
    c1h = kbin.fit_transform(c.reshape(-1, 1))
    one_hots.append(c1h)

new_indicator_names = []
for clist, cname in zip(one_hots, low_level_indicators):
    if clist.shape[1] > 1:
        for i in range(clist.shape[1]):
            new_indicator_names.append(f'{cname}_{i}')
    else:
        new_indicator_names.append(f'{cname}')

data_low = pd.DataFrame(np.hstack(one_hots), columns=new_indicator_names)
data_low.describe()

Unnamed: 0,v2psbars_ord_0,v2psbars_ord_1,v2psbars_ord_2,v2psbars_ord_3,v2psparban_ord_0,v2psparban_ord_1,v2psparban_ord_2,v2psparban_ord_3,v2psorgs_ord_0,v2psorgs_ord_1,...,v2pepwrgen_ord_1,v2pepwrgen_ord_2,v2peedueq_ord_0,v2peedueq_ord_1,v2peedueq_ord_2,v2peedueq_ord_3,v2pehealth_ord_0,v2pehealth_ord_1,v2pehealth_ord_2,v2pehealth_ord_3
count,3385.0,3385.0,3385.0,3385.0,3385.0,3385.0,3385.0,3385.0,3385.0,3385.0,...,3385.0,3385.0,3385.0,3385.0,3385.0,3385.0,3385.0,3385.0,3385.0,3385.0
mean,0.056425,0.037518,0.080059,0.825997,0.053471,0.029838,0.028951,0.88774,0.044018,0.15096,...,0.230133,0.730871,0.05997,0.312851,0.132939,0.494239,0.038109,0.322304,0.145643,0.493944
std,0.230775,0.190057,0.271425,0.379168,0.225004,0.170164,0.167694,0.315733,0.205165,0.358063,...,0.42098,0.443572,0.237467,0.463723,0.339559,0.500041,0.191488,0.467428,0.3528,0.500037
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50%,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,...,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
75%,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,...,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0
max,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


In [7]:
data_mid = data_2000[mid_level_indicators] > 0.5
data_mid.describe()

Unnamed: 0,v2x_api,v2x_mpi,v2x_freexp_altinf,v2x_frassoc_thick,v2x_suffr,v2xel_frefair,v2x_elecoff,v2xcl_rol,v2x_cspart,v2xdl_delib,v2x_egal,v2xeg_eqprotec,v2xeg_eqaccess,v2xeg_eqdr
count,3385,3385,3385,3385,3385,3385,3385,3385,3385,3385,3385,3385,3385,3385
unique,2,2,2,2,2,2,2,2,2,2,2,2,2,2
top,True,False,True,True,True,True,True,True,True,True,True,True,True,True
freq,2696,2196,2605,2622,3322,1909,2934,2607,2669,2464,2398,2527,2392,2073


In [8]:
data_high = data_2000[high_level_indicators].iloc[:, 0] > 0.5
data_high.describe()

count     3385
unique       2
top       True
freq      1799
Name: v2x_polyarchy, dtype: object

In [9]:
c0c = torch.FloatTensor(data_low_continuous.values)
c0 = torch.FloatTensor(data_low.values)
c1 = torch.FloatTensor(data_mid.values)
c2 = torch.FloatTensor(data_high.values)
print(c0c.shape)
print(c0.shape)
print(c1.shape)
print(c2.shape)

torch.Size([3385, 464])
torch.Size([3385, 241])
torch.Size([3385, 14])
torch.Size([3385])


# Low-to-mid level explanations

In [10]:
results_dir = 'results/vdem_low2mid'
if not os.path.isdir(results_dir):
    os.makedirs(results_dir)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
mid_concepts = list(data_mid.columns)
low_concepts = list(data_low.columns)
n_rep = 10
tot_epochs = 6001
prune_epochs = 3001

n_splits = 10
seed = 42
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)

x = c0
xh = c1
y = c2

print(low_concepts[:5])
print(mid_concepts[:5])

['v2psbars_ord_0', 'v2psbars_ord_1', 'v2psbars_ord_2', 'v2psbars_ord_3', 'v2psparban_ord_0']
['v2x_api', 'v2x_mpi', 'v2x_freexp_altinf', 'v2x_frassoc_thick', 'v2x_suffr']


## Train loops

In [11]:
def train_nn(x_train, y_train, need_pruning, seed, device, level, relu=False, verbose=False):
    set_seed(seed)
    x_train = x_train.to(device)
    y_train = y_train.to(device)
    
    if level == 'mid':
        layers = [
            torch.nn.Linear(x_train.size(1), 50),
            torch.nn.ReLU() if relu else torch.nn.LeakyReLU(),
            torch.nn.Linear(50, 30),
            torch.nn.ReLU() if relu else torch.nn.LeakyReLU(),
            torch.nn.Linear(30, y_train.size(1)),
            torch.nn.Sigmoid(),
        ]
        loss_form = torch.nn.BCELoss()
    else:
        layers = [
            torch.nn.Linear(x_train.size(1), 20),
            torch.nn.ReLU() if relu else torch.nn.LeakyReLU(),
            torch.nn.Linear(20, 10),
            torch.nn.ReLU() if relu else torch.nn.LeakyReLU(),
            torch.nn.Linear(10, 1),
            torch.nn.Sigmoid(),
        ]
        loss_form = torch.nn.BCELoss()
            
    model = torch.nn.Sequential(*layers).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
    model.train()
    for epoch in range(tot_epochs):
        # forward pass
        optimizer.zero_grad()
        y_pred = model(x_train).squeeze()
        # Compute Loss
        loss = loss_form(y_pred, y_train)

        for module in model.children():
            if isinstance(module, torch.nn.Linear):
#                 if level == 'mid':
#                     loss += 0.0001 * torch.norm(module.weight, 1)
#                     loss += 0.0001 * torch.norm(module.bias, 1)
                if level == 'high':
                    loss += 0.001 * torch.norm(module.weight, 1)
                    loss += 0.001 * torch.norm(module.bias, 1)
                break

        # backward pass
        loss.backward()
        optimizer.step()

        if epoch > prune_epochs and need_pruning and level == 'high':
            prune_features(model, n_classes=1, device=device)
            need_pruning = False
            
        # compute accuracy
        if epoch % 500 == 0 and verbose:
            if level == 'mid':
                y_pred_d = y_pred > 0.5
                accuracy = y_pred_d.eq(y_train).sum().item() / (y_train.size(0) * y_train.size(1))
                print(f'\t Epoch {epoch}: train accuracy: {accuracy:.4f}')
            else:
                y_pred_d = y_pred > 0.5
                accuracy = y_pred_d.eq(y_train).sum().item() / y_train.size(0)
                print(f'\t Epoch {epoch}: train accuracy: {accuracy:.4f}')
#             if level == 'mid':
#                 y_pred_d = y_pred > 0.5
#                 accuracy = y_pred_d.eq(y_train).sum().item() / (y_train.size(0) * y_train.size(1))
#                 print(f'\t Epoch {epoch}: train accuracy: {accuracy:.4f}')
#             else:
#                 y_pred_d = torch.argmax(y_pred, dim=1)
#                 accuracy = y_pred_d.eq(y_train).sum().item() / y_train.size(0)
#                 print(f'\t Epoch {epoch}: train accuracy: {accuracy:.4f}')
    
    return model

In [17]:
def train_psi_nn(x_train, y_train, need_pruning, seed, device, level, verbose=False):
    set_seed(seed)
    x_train = x_train.to(device)
    y_train = y_train.to(device).to(torch.float)

    if level == 'mid':
        layers = [
            torch.nn.Linear(x_train.size(1), 50),
            torch.nn.Sigmoid(),
            torch.nn.Linear(50, 30),
            torch.nn.Sigmoid(),
            torch.nn.Linear(30, y_train.size(1)),
            torch.nn.Sigmoid(),
        ]
    else:
        layers = [
            torch.nn.Linear(x_train.size(1), 10),
            torch.nn.Sigmoid(),
            torch.nn.Linear(10, 4),
            torch.nn.Sigmoid(),
            torch.nn.Linear(4, 1),
            torch.nn.Sigmoid(),
        ]
            
    model = torch.nn.Sequential(*layers).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_form = torch.nn.BCELoss()
    model.train()
    for epoch in range(tot_epochs):
        # forward pass
        optimizer.zero_grad()
        y_pred = model(x_train).squeeze()
        # Compute Loss
        loss = loss_form(y_pred, y_train)

        for module in model.children():
            if isinstance(module, torch.nn.Linear):
                loss += 0.00001 * torch.norm(module.weight, 1)

        # backward pass
        loss.backward()
        optimizer.step()

        if epoch > prune_epochs and need_pruning:
            model = prune_equal_fanin(model, 2, validate=True, device=device)
            need_pruning = False
            
        # compute accuracy
        if epoch % 500 == 0 and verbose:
            if level == 'mid':
                y_pred_d = y_pred > 0.5
                accuracy = y_pred_d.eq(y_train).sum().item() / (y_train.size(0) * y_train.size(1))
                print(f'\t Epoch {epoch}: train accuracy: {accuracy:.4f}')
            else:
                y_pred_d = y_pred > 0.5
                accuracy = y_pred_d.eq(y_train).sum().item() / y_train.size(0)
                print(f'\t Epoch {epoch}: train accuracy: {accuracy:.4f}')
    
    return model

In [13]:
def c_to_y(method, need_pruning, relu, verbose=False):
    methods = []
    splits = []
    explanations = []
    explanations_inv = []
    model_accuracies_mid = []
    model_accuracies_top = []
    explanation_accuracies = []
    explanation_accuracies_inv = []
    elapsed_times = []
    elapsed_times_inv = []
    #for seed in range(n_rep):
    for split, (train_index, test_index) in enumerate(skf.split(x.cpu().detach().numpy(), y.cpu().detach().numpy())):
        print(f'Split [{split+1}/{n_splits}]')
        x_train, x_test = torch.FloatTensor(x[train_index]), torch.FloatTensor(x[test_index])
        xh_train, xh_test = torch.FloatTensor(xh[train_index]), torch.FloatTensor(xh[test_index])
        y_train, y_test = torch.FloatTensor(y[train_index]), torch.FloatTensor(y[test_index])
    
        explanation, explanation_inv = '', ''
        explanation_accuracy, explanation_accuracy_inv = 0, 0
        elapsed_time, elapsed_time_inv = 0, 0
        
        if method == 'tree':
            classifier = DecisionTreeRegressor(random_state=seed)
            classifier.fit(x_train.cpu().detach().numpy(), xh_train.cpu().detach().numpy())
            xh_train_preds = classifier.predict(x_train.cpu().detach().numpy())
            xh_test_preds = classifier.predict(x_test.cpu().detach().numpy())
            
            classifier_h = DecisionTreeClassifier(random_state=seed)
            classifier_h.fit(xh_train_preds, y_train)
            y_preds = classifier_h.predict(xh_test_preds)
            
            model_accuracy_mid = accuracy_score(xh_test.cpu().detach().numpy(), xh_test_preds>0.5)
            model_accuracy_top = accuracy_score(y_test.cpu().detach().numpy(), y_preds)

            target_class = 1
            start = time.time()
            explanation = tree_to_formula(classifier_h, mid_concepts, target_class)
            elapsed_time = time.time() - start
            explanation_accuracy = model_accuracy_top

#             target_class_inv = 0
#             start = time.time()
#             explanation_inv = tree_to_formula(classifier, concepts, target_class_inv)
#             elapsed_time_inv = time.time() - start
#             explanation_accuracy_inv = model_accuracy
        
        else:
            if method == 'psi':
                # positive class
                target_class = 1
                level = 'mid'
                model = train_psi_nn(x_train, xh_train, need_pruning, split, device, level, verbose)
                xh_train_preds = model(x_train.to(device)).cpu().detach().numpy()
                xh_test_preds = model(x_test.to(device)).cpu().detach().numpy()
                xh_train_preds = torch.FloatTensor(xh_train_preds)
                xh_test_preds = torch.FloatTensor(xh_test_preds)

                level = 'high'
                model_h = train_psi_nn(xh_train_preds, y_train, need_pruning, split, device, level, verbose)
                y_train_preds = model_h(xh_train_preds.to(device)).cpu().detach().numpy()
                y_test_preds = model_h(xh_test_preds.to(device)).cpu().detach().numpy()

                model_accuracy_mid = accuracy_score(xh_test.cpu().detach().numpy(), xh_test_preds>0.5)
                model_accuracy_top = accuracy_score(y_test.cpu().detach().numpy(), y_test_preds>0.5)

            else:
                level = 'mid'
                model = train_nn(x_train, xh_train, need_pruning, seed, device, level, relu, verbose)
                xh_train_preds = model(x_train.to(device)).cpu().detach().numpy()
                xh_test_preds = model(x_test.to(device)).cpu().detach().numpy()
                xh_train_preds = torch.FloatTensor(xh_train_preds)
                xh_test_preds = torch.FloatTensor(xh_test_preds)

                level = 'high'
                model_h = train_nn(xh_train_preds, y_train, need_pruning, seed, device, level, relu, verbose)
                y_train_preds = model_h(xh_train_preds.to(device)).cpu().detach().numpy()
                y_test_preds = model_h(xh_test_preds.to(device)).cpu().detach().numpy()

                model_accuracy_mid = accuracy_score(xh_test.cpu().detach().numpy(), xh_test_preds>0.5)
                model_accuracy_top = accuracy_score(y_test.cpu().detach().numpy(), y_test_preds>0.5)

            # positive class
            target_class = 1
            start = time.time()
            if method == 'psi':
                global_explanation = logic.generate_fol_explanations(model_h, device)[0]
            else:
                global_explanation, _, _ = logic.relu_nn.combine_local_explanations(model_h, 
                                                                                    xh_train.to(device), y_train.to(device), 
                                                                                    target_class=target_class,
                                                                                    topk_explanations=10,
                                                                                    method=method, device=device)
            elapsed_time = time.time() - start

            if global_explanation:
                explanation_accuracy, _ = logic.base.test_explanation(global_explanation, target_class, xh_test, y_test)
                explanation = logic.base.replace_names(global_explanation, mid_concepts)

#         # negative class
#         target_class_inv = 0
#         if method == 'psi':
#             model = train_psi_nn(x_train, y_train.eq(target_class_inv), need_pruning, seed, device, verbose)

#         start = time.time()
#         if method == 'psi':
#             global_explanation_inv = logic.generate_fol_explanations(model, device)[0]
#         else:
#             global_explanation_inv, _, _ = logic.relunn.combine_local_explanations(model, 
#                                                                                    x_train.to(device), 
#                                                                                    y_train.to(device), 
#                                                                                    target_class=target_class_inv,
#                                                                                    method=method, device=device)
#         elapsed_time_inv = time.time() - start
#         if global_explanation_inv:
#             explanation_accuracy_inv, _ = logic.base.test_explanation(global_explanation_inv, 
#                                                                       target_class_inv, x_test, y_test)
#             explanation_inv = logic.base.replace_names(global_explanation_inv, concepts)
        
        if verbose:
            print(f'\t Model\'s accuracy (mid): {model_accuracy_mid:.4f} - Model\'s accuracy (top): {model_accuracy_top:.4f}')
            print(f'\t Class {target_class} - Global explanation: "{explanation}" - Accuracy: {explanation_accuracy:.4f}')
            print(f'\t Elapsed time {elapsed_time}')
#             print(f'\t Class {target_class} - Global explanation: "{explanation_inv}" - Accuracy: {explanation_accuracy_inv:.4f}')
#             print(f'\t Elapsed time {elapsed_time_inv}')

        methods.append(method)
        splits.append(split)
        explanations.append(explanation)
        explanations_inv.append(explanation_inv)
        model_accuracies_mid.append(model_accuracy_mid)
        model_accuracies_top.append(model_accuracy_top)
        explanation_accuracies.append(explanation_accuracy)
        explanation_accuracies_inv.append(explanation_accuracy_inv)
        elapsed_times.append(elapsed_time)
        elapsed_times_inv.append(elapsed_time_inv)
    
    results = pd.DataFrame({
        'method': methods,
        'split': splits,
        'explanation': explanations,
        'explanation_inv': explanations_inv,
        'model_accuracy_mid': model_accuracies_mid,
        'model_accuracy_top': model_accuracies_top,
        'explanation_accuracy': explanation_accuracies,
        'explanation_accuracy_inv': explanation_accuracies_inv,
        'elapsed_time': elapsed_times,
        'elapsed_time_inv': elapsed_times_inv,
    })
    results.to_csv(os.path.join(results_dir, f'results_{method}.csv'))
    
    return results

## General pruning

In [14]:
method = 'pruning'
need_pruning = True
relu = False
results_pruning = c_to_y(method, need_pruning, relu, verbose=False)
results_pruning

Split [1/10]
Split [2/10]
Split [3/10]
Split [4/10]
Split [5/10]
Split [6/10]
Split [7/10]
Split [8/10]
Split [9/10]
Split [10/10]


Unnamed: 0,method,split,explanation,explanation_inv,model_accuracy_mid,model_accuracy_top,explanation_accuracy,explanation_accuracy_inv,elapsed_time,elapsed_time_inv
0,pruning,0,(v2x_freexp_altinf & v2x_frassoc_thick & v2x_s...,,0.755162,0.926254,94.690265,0,1.705521,0
1,pruning,1,v2x_freexp_altinf & v2xel_frefair & v2x_elecof...,,0.758112,0.917404,94.690265,0,0.101728,0
2,pruning,2,v2x_freexp_altinf & v2xel_frefair & v2x_elecof...,,0.749263,0.917404,93.510324,0,0.10073,0
3,pruning,3,(v2x_freexp_altinf & v2xel_frefair & v2x_eleco...,,0.743363,0.905605,94.100295,0,0.125663,0
4,pruning,4,v2x_freexp_altinf & v2x_frassoc_thick & v2xel_...,,0.746313,0.935103,96.165192,0,0.111265,0
5,pruning,5,v2x_freexp_altinf & v2x_frassoc_thick & v2xel_...,,0.715976,0.908284,93.491124,0,0.114686,0
6,pruning,6,(v2x_freexp_altinf & v2xel_frefair & v2x_eleco...,,0.727811,0.914201,93.786982,0,0.212631,0
7,pruning,7,v2x_freexp_altinf & v2x_frassoc_thick & v2xel_...,,0.733728,0.923077,96.153846,0,0.126662,0
8,pruning,8,v2x_mpi & v2x_freexp_altinf & v2xcl_rol,,0.695266,0.849112,80.177515,0,0.096741,0
9,pruning,9,v2x_freexp_altinf & v2xel_frefair & v2x_elecof...,,0.772189,0.920118,92.899408,0,0.10171,0


## ReLU

In [15]:
method = 'weights'
need_pruning = False
relu = True
results_weights = c_to_y(method, need_pruning, relu, False)
results_weights

Split [1/10]
Split [2/10]
Split [3/10]
Split [4/10]
Split [5/10]
Split [6/10]
Split [7/10]
Split [8/10]
Split [9/10]
Split [10/10]


Unnamed: 0,method,split,explanation,explanation_inv,model_accuracy_mid,model_accuracy_top,explanation_accuracy,explanation_accuracy_inv,elapsed_time,elapsed_time_inv
0,weights,0,v2xel_frefair | (v2x_freexp_altinf & ~v2x_mpi),,0.749263,0.935103,74.631268,0,0.879821,0
1,weights,1,v2xel_frefair | v2xcl_rol | (v2x_freexp_altinf...,,0.749263,0.935103,66.666667,0,1.559515,0
2,weights,2,v2xel_frefair | (v2x_freexp_altinf & ~v2x_mpi)...,,0.746313,0.932153,73.156342,0,0.332482,0
3,weights,3,v2xel_frefair | (v2x_freexp_altinf & v2x_frass...,,0.752212,0.914454,74.336283,0,4.748825,0
4,weights,4,(v2xel_frefair & v2xcl_rol) | (v2x_freexp_alti...,,0.746313,0.941003,80.825959,0,0.464865,0
5,weights,5,(v2xel_frefair & v2xcl_rol & ~v2x_mpi) | (v2xe...,,0.718935,0.893491,43.786982,0,3.08857,0
6,weights,6,(v2x_mpi & v2xel_frefair) | (v2x_suffr & v2xel...,,0.745562,0.926036,74.852071,0,0.651257,0
7,weights,7,v2xel_frefair | v2xcl_rol | (v2x_freexp_altinf...,,0.757396,0.934911,70.118343,0,802.924447,0
8,weights,8,(v2x_frassoc_thick & ~v2x_mpi) | (v2x_suffr & ...,,0.721893,0.940828,46.153846,0,12.794576,0
9,weights,9,v2xel_frefair | (v2xcl_rol & ~v2x_mpi),,0.751479,0.911243,76.035503,0,0.308771,0


## Psi network

In [20]:
method = 'psi'
need_pruning = True
relu = False
results_psi = c_to_y(method, need_pruning, relu, verbose=False)
results_psi

Split [1/10]
Split [2/10]
Split [3/10]
Split [4/10]
Split [5/10]
Split [6/10]
Split [7/10]
Split [8/10]
Split [9/10]
Split [10/10]


Unnamed: 0,method,split,explanation,explanation_inv,model_accuracy_mid,model_accuracy_top,explanation_accuracy,explanation_accuracy_inv,elapsed_time,elapsed_time_inv
0,psi,0,(v2x_mpi),,0.345133,0.867257,81.415929,0,0.00897,0
1,psi,1,(v2x_freexp_altinf & v2xcl_rol),,0.336283,0.849558,82.890855,0,0.037899,0
2,psi,2,(v2xeg_eqaccess & (v2xdl_delib | ~v2x_suffr)),,0.353982,0.834808,82.890855,0,0.048869,0
3,psi,3,(v2x_elecoff),,0.265487,0.764012,68.141593,0,0.026928,0
4,psi,4,(v2x_frassoc_thick & v2xeg_eqaccess),,0.168142,0.820059,87.610619,0,0.032912,0
5,psi,5,(v2x_frassoc_thick & v2xeg_eqaccess),,0.378698,0.786982,79.585799,0,0.028923,0
6,psi,6,(v2x_egal),,0.346154,0.733728,71.005917,0,0.026928,0
7,psi,7,(v2xel_frefair),,0.316568,0.781065,92.011834,0,0.024933,0
8,psi,8,(v2x_freexp_altinf & v2x_cspart & (v2xel_frefa...,,0.304734,0.849112,88.16568,0,0.037898,0
9,psi,9,(v2xeg_eqdr),,0.331361,0.636095,64.792899,0,0.007978,0


## Decision tree

In [21]:
method = 'tree'
need_pruning = False
relu = False
results_tree = c_to_y(method, need_pruning, relu, verbose=False)
results_tree

Split [1/10]
Split [2/10]
Split [3/10]
Split [4/10]
Split [5/10]
Split [6/10]
Split [7/10]
Split [8/10]
Split [9/10]
Split [10/10]


Unnamed: 0,method,split,explanation,explanation_inv,model_accuracy_mid,model_accuracy_top,explanation_accuracy,explanation_accuracy_inv,elapsed_time,elapsed_time_inv
0,tree,0,(v2xel_frefair <= 0.33 & v2xcl_rol > 0.90 & v2...,,0.749263,0.935103,0.935103,0,0.000997,0
1,tree,1,(v2xel_frefair <= 0.33 & v2xdl_delib > 0.58 & ...,,0.784661,0.932153,0.932153,0,0.000997,0
2,tree,2,(v2xel_frefair <= 0.58 & v2xcl_rol > 0.88 & v2...,,0.746313,0.932153,0.932153,0,0.002004,0
3,tree,3,(v2xel_frefair <= 0.88 & v2xcl_rol > 0.90 & v2...,,0.734513,0.914454,0.914454,0,0.000998,0
4,tree,4,(v2xel_frefair <= 0.44 & v2xcl_rol > 0.90 & v2...,,0.722714,0.917404,0.917404,0,0.000997,0
5,tree,5,(v2xel_frefair <= 0.46 & v2xdl_delib > 0.61 & ...,,0.745562,0.905325,0.905325,0,0.000999,0
6,tree,6,(v2xel_frefair <= 0.47 & v2xcl_rol > 0.90 & v2...,,0.760355,0.899408,0.899408,0,0.000997,0
7,tree,7,(v2xel_frefair <= 0.33 & v2xcl_rol > 0.88 & v2...,,0.769231,0.926036,0.926036,0,0.000999,0
8,tree,8,(v2xel_frefair <= 0.88 & v2xcl_rol > 0.90 & v2...,,0.754438,0.928994,0.928994,0,0.001015,0
9,tree,9,(v2xel_frefair <= 0.58 & v2xcl_rol > 0.88 & v2...,,0.789941,0.920118,0.920118,0,0.000997,0


In [22]:
cols = ['model_accuracy_mid', 'model_accuracy_top', 'explanation_accuracy', 
        'explanation_accuracy_inv', 'elapsed_time', 'elapsed_time_inv']
mean_cols = [f'{c}_mean' for c in cols]
sem_cols = [f'{c}_sem' for c in cols]

# pruning
df_mean = results_pruning[cols].mean()
df_sem = results_pruning[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_pruning = pd.concat([df_mean, df_sem])
summary_pruning.name = 'pruning'

# # lime
# df_mean = results_lime[cols].mean()
# df_sem = results_lime[cols].sem()
# df_mean.columns = mean_cols
# df_sem.columns = sem_cols
# summary_lime = pd.concat([df_mean, df_sem])
# summary_lime.name = 'lime'

# weights
df_mean = results_weights[cols].mean()
df_sem = results_weights[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_weights = pd.concat([df_mean, df_sem])
summary_weights.name = 'weights'

# psi
df_mean = results_psi[cols].mean()
df_sem = results_psi[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_psi = pd.concat([df_mean, df_sem])
summary_psi.name = 'psi'

# tree
df_mean = results_tree[cols].mean()
df_sem = results_tree[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_tree = pd.concat([df_mean, df_sem])
summary_tree.name = 'tree'

summary = pd.concat([summary_pruning, 
#                      summary_lime, 
                     summary_weights, summary_psi, summary_tree], axis=1).T
summary.columns = mean_cols + sem_cols
summary

Unnamed: 0,model_accuracy_mid_mean,model_accuracy_top_mean,explanation_accuracy_mean,explanation_accuracy_inv_mean,elapsed_time_mean,elapsed_time_inv_mean,model_accuracy_mid_sem,model_accuracy_top_sem,explanation_accuracy_sem,explanation_accuracy_inv_sem,elapsed_time_sem,elapsed_time_inv_sem
pruning,0.739718,0.911656,92.966522,0.0,0.279734,0.0,0.007068,0.007451,1.462212,0.0,0.158782,0.0
weights,0.743863,0.926433,68.056326,0.0,82.775313,0.0,0.004066,0.004869,4.022843,0.0,80.025689,0.0
psi,0.314654,0.792268,79.851198,0.0,0.028224,0.0,0.018945,0.021982,2.86807,0.0,0.003987,0.0
tree,0.755699,0.921115,0.921115,0.0,0.0011,0.0,0.006656,0.003808,0.003808,0.0,0.0001,0.0


In [23]:
summary.to_csv(os.path.join(results_dir, 'summary.csv'))