In [1]:
import sys
sys.path.append('..')
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from sympy import simplify_logic
import time
from sklearn.metrics import accuracy_score
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.tree import _tree, export_text

import lens
from lens.utils.base import validate_network, set_seed, tree_to_formula
from lens.utils.layer import prune_logic_layers
from lens import logic

# Experiment configuration shared by the cells below.
results_dir = 'results_ll/xor'
# exist_ok=True creates the folder only when it is missing, without a separate
# existence check that could race with another process creating it.
os.makedirs(results_dir, exist_ok=True)

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
concepts = ['c1', 'c2']  # names substituted into the extracted formulas
n_rep = 10               # number of seeds c_to_y iterates over
tot_epochs = 2001        # number of training epochs used by train_nn

In [2]:
# XOR problem: the label is 1 exactly when the two binary inputs differ.
x_train = torch.tensor(
    [[c1, c2] for c1 in (0, 1) for c2 in (0, 1)],
    dtype=torch.float32,
)
y_train = torch.logical_xor(x_train[:, 0], x_train[:, 1]).to(torch.float32)

# All four possible inputs are already in the training set, so the
# evaluation data is the very same pair of tensors.
x_test, y_test = x_train, y_train

In [3]:
def train_nn(x_train, y_train, seed, device, verbose=False, epochs=None):
    """Train the XOR network (logic layers around a small MLP) and return it.

    The model is ``XLogic(2, 5) -> LeakyReLU -> Linear(5, 5) -> LeakyReLU ->
    Linear(5, 1) -> XLogic(1, 1, sigmoid)``, optimised full-batch with AdamW
    (lr=0.01) on a binary cross-entropy loss.

    Args:
        x_train: input tensor; the first layer expects 2 features per sample.
        y_train: float tensor of 0/1 targets, one per sample.
        seed: value passed to ``set_seed`` before the model is built.
        device: torch device the data and the model are moved to.
        verbose: when True, print the training accuracy every 100 epochs.
        epochs: number of optimisation steps; ``None`` (default) uses the
            module-level ``tot_epochs``.

    Returns:
        The trained ``torch.nn.Sequential`` model, left on ``device`` and in
        training mode.
    """
    n_epochs = tot_epochs if epochs is None else epochs

    set_seed(seed)
    x_train = x_train.to(device)
    y_train = y_train.to(device)
    model = torch.nn.Sequential(
        lens.nn.XLogic(2, 5, activation='identity', first=True),
        torch.nn.LeakyReLU(),
        torch.nn.Linear(5, 5),
        torch.nn.LeakyReLU(),
        torch.nn.Linear(5, 1),
        lens.nn.XLogic(1, 1, activation='sigmoid', top=True),
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
    loss_form = torch.nn.BCELoss()
    model.train()
    for epoch in range(n_epochs):
        optimizer.zero_grad()
        # Drop only the trailing output dimension (the model is expected to
        # emit one value per sample after Linear(5, 1) / XLogic(1, 1)). A bare
        # squeeze() would also collapse the batch dimension for a one-sample
        # batch and make the prediction shape disagree with y_train.
        y_pred = model(x_train).squeeze(-1)
        loss = loss_form(y_pred, y_train)

        loss.backward()
        optimizer.step()

        if verbose and epoch % 100 == 0:
            # Accuracy of the predictions made before this optimisation step.
            y_pred_d = y_pred > 0.5
            accuracy = y_pred_d.eq(y_train).sum().item() / y_train.size(0)
            print(f'Epoch {epoch}: train accuracy: {accuracy:.4f}')

    return model

In [4]:
def c_to_y(method, verbose=False):
    """Run the concepts-to-label experiment once per seed and collect results.

    For every seed in ``range(n_rep)`` a model is fitted on the module-level
    ``x_train`` / ``y_train``, scored on ``x_test`` / ``y_test``, and a
    formula for class 1 is extracted:

    * ``method == 'tree'``: a ``DecisionTreeClassifier`` is fitted and
      converted with ``tree_to_formula``. The formula is not evaluated in this
      branch, so explanation accuracy, fidelity and complexity stay at 0.
    * any other value: a network is trained with ``train_nn`` and explained
      with ``lens.logic.explain_class``. When an explanation is returned, its
      accuracy, fidelity and complexity are computed; otherwise they stay
      at 0 and the explanation is the empty string.

    Args:
        method: ``'tree'`` for the decision-tree baseline; anything else
            selects the neural network. The value is also stored in the
            ``method`` column and used in the CSV file name.
        verbose: when True, print accuracy, explanation and timing per seed.

    Returns:
        pandas.DataFrame with one row per seed. The same frame is written to
        ``<results_dir>/results_<method>.csv``.
    """
    columns = [
        'method',
        'split',
        'explanation',
        'model_accuracy',
        'explanation_accuracy',
        'explanation_fidelity',
        'explanation_complexity',
        'elapsed_time',
    ]
    target_class = 1
    records = []
    for seed in range(n_rep):
        # Reset every per-seed metric so a run without an evaluated
        # explanation never raises NameError or reuses the previous seed's
        # values.
        explanation = ''
        explanation_accuracy = 0
        explanation_fidelity = 0
        explanation_complexity = 0

        print(f'Seed [{seed+1}/{n_rep}]')

        if method == 'tree':
            classifier = DecisionTreeClassifier(random_state=seed)
            classifier.fit(x_train.detach().numpy(), y_train.detach().numpy())
            y_preds = classifier.predict(x_test.detach().numpy())
            model_accuracy = accuracy_score(y_test.detach().numpy(), y_preds)

            # Time only the extraction of the formula that is reported.
            start = time.perf_counter()
            explanation = tree_to_formula(classifier, concepts, target_class)
            elapsed_time = time.perf_counter() - start

        else:
            model = train_nn(x_train, y_train, seed, device, verbose=False)
            y_preds = model(x_test.to(device)).cpu().detach().numpy() > 0.5
            model_accuracy = accuracy_score(y_test.cpu().detach().numpy(), y_preds)

            # positive class
            start = time.perf_counter()
            class_explanation, _ = lens.logic.explain_class(
                model.cpu(), x_train.cpu(), y_train.cpu(),
                binary=True, target_class=target_class,
                topk_explanations=10,
            )
            elapsed_time = time.perf_counter() - start

            if class_explanation:
                explanation = logic.base.replace_names(class_explanation, concepts)
                explanation_accuracy, y_formula = logic.base.test_explanation(
                    class_explanation,
                    target_class=target_class,
                    x=x_train, y=y_train,
                    metric=accuracy_score,
                )
                explanation_fidelity = lens.logic.fidelity(y_formula, y_preds)
                explanation_complexity = lens.logic.complexity(class_explanation)

        if verbose:
            print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
            print(f'\t Class 1 - Global explanation: "{explanation}" - Accuracy: {explanation_accuracy:.4f}')
            print(f'\t Elapsed time {elapsed_time}')

        records.append({
            'method': method,
            'split': seed,
            'explanation': explanation,
            'model_accuracy': model_accuracy,
            'explanation_accuracy': explanation_accuracy,
            'explanation_fidelity': explanation_fidelity,
            'explanation_complexity': explanation_complexity,
            'elapsed_time': elapsed_time,
        })

    # Passing the column list keeps the column order fixed even with no rows.
    results = pd.DataFrame(records, columns=columns)
    results.to_csv(os.path.join(results_dir, f'results_{method}.csv'))

    return results

# General pruning

In [5]:
# Run the experiment with the logic-layer network (any method other than
# 'tree') for every seed, then display the per-seed results table.
results_pruning = c_to_y('logic_layer', verbose=False)
results_pruning

Seed [1/10]
Seed [2/10]
Seed [3/10]
Seed [4/10]
Seed [5/10]
Seed [6/10]
Seed [7/10]
Seed [8/10]
Seed [9/10]
Seed [10/10]


Unnamed: 0,method,split,explanation,model_accuracy,explanation_accuracy,explanation_fidelity,explanation_complexity,elapsed_time
0,logic_layer,0,(c1 & ~c2) | (c2 & ~c1),1.0,1.0,1.0,4,0.023936
1,logic_layer,1,(c1 & ~c2) | (c2 & ~c1),1.0,1.0,1.0,4,0.021942
2,logic_layer,2,(c1 & ~c2) | (c2 & ~c1),1.0,1.0,1.0,4,0.015993
3,logic_layer,3,(c1 & ~c2) | (c2 & ~c1),1.0,1.0,1.0,4,0.016955
4,logic_layer,4,(c1 & ~c2) | (c2 & ~c1),1.0,1.0,1.0,4,0.013962
5,logic_layer,5,(c1 & ~c2) | (c2 & ~c1),1.0,1.0,1.0,4,0.016948
6,logic_layer,6,(c1 & ~c2) | (c2 & ~c1),1.0,1.0,1.0,4,0.012965
7,logic_layer,7,(c1 & ~c2) | (c2 & ~c1),1.0,1.0,1.0,4,0.019945
8,logic_layer,8,(c1 & ~c2) | (c2 & ~c1),1.0,1.0,1.0,4,0.015952
9,logic_layer,9,(c1 & ~c2) | (c2 & ~c1),1.0,1.0,1.0,4,0.018949
