In [1]:
import sys
sys.path.append('..')
import argparse
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from sympy import simplify_logic
import time
from sklearn.metrics import accuracy_score
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.tree import _tree, export_text
from torchvision import datasets, transforms
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn.functional as F
from sklearn.preprocessing import OneHotEncoder

from deep_logic.utils.base import validate_network, set_seed, tree_to_formula
from deep_logic.utils.relunn import get_reduced_model, prune_features
from deep_logic.utils.sigmoidnn import prune_equal_fanin
from deep_logic import logic

from dSprites.dSprites_loader import load_dsprites, concept_filters, get_shape_scale
from dSprites.dSprites_style_I2C import i2c_style

results_dir = './results/dsprites'
if not os.path.isdir(results_dir):
    os.makedirs(results_dir)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
n_rep = 10
tot_epochs = 10001
prune_epochs = 60001

In [2]:
def get_data(path):
    # Load dataset specified by schema
    dataset_schema = 'small_skip'
    # dataset_schema = 'big_skip'

    # Get filtered concept values
    shape_range, scale_range, rot_range, x_pos_range, y_pos_range = concept_filters(dataset_schema)

    # Define function for filtering out specified concept values only
    def c_filter_fn(concepts):
        in_shape_range = (concepts[1] in shape_range)
        in_scale_range = (concepts[2] in scale_range)
        in_rot_range = (concepts[3] in rot_range)
        in_x_range = (concepts[4] in x_pos_range)
        in_y_range = (concepts[5] in y_pos_range)

        return (in_shape_range and in_scale_range and in_rot_range and in_x_range and in_y_range)

    label_fn = get_shape_scale(shape_range, scale_range)
    
    # Load dataset
    x_train, y_train, x_val, y_val, x_test, y_test, c_train, c_val, c_test, c_names = load_dsprites(path,
                                                                               c_filter_fn=c_filter_fn,
                                                                               label_fn=label_fn,
                                                                               train_test_split_flag=True)


    return x_train, y_train, x_val, y_val, x_test, y_test, c_train, c_val, c_test, c_names

In [3]:
# Load the model, as well as input, label, and concept data
data = get_data(path='./dSprites/data/dsprites.npz')
x_train, y_train, x_val, y_val, x_test, y_test, c_train, c_val, c_test, c_names = data
x_train = torch.FloatTensor(x_train)
y_train = torch.FloatTensor(y_train)
x_val = torch.FloatTensor(x_val)
y_val = torch.FloatTensor(y_val)
x_test = torch.FloatTensor(x_test)
y_test = torch.FloatTensor(y_test)
c_train = torch.FloatTensor(c_train)
c_val = torch.FloatTensor(c_val)
c_test = torch.FloatTensor(c_test)
print("Data loaded successfully...")

x_train shape: (20993, 3, 64, 64)
c_train shape: (20993, 50)
y_train shape: (20993, 18)
Number of images in x_train 20993
Number of images in x_val 10341
Number of images in x_test 5530
Data loaded successfully...


In [4]:
args = {
    'models_dir': './models/dSprites/', 
    'model_style': 'CBM',
    'seed': 0,
    'batch_size': 128,
}

if not os.path.isdir(args['models_dir']):
    os.makedirs(args['models_dir'])

In [5]:
train_dataset = TensorDataset(x_train, c_train)
train_dataloader = DataLoader(train_dataset, batch_size=args['batch_size'], shuffle=False)
val_dataset = TensorDataset(x_val, c_val)
val_dataloader = DataLoader(val_dataset, batch_size=args['batch_size'], shuffle=False)
test_dataset = TensorDataset(x_test, c_test)
test_dataloader = DataLoader(test_dataset, batch_size=args['batch_size'], shuffle=False)

c_predictions_train, c_predictions_val, c_predictions_test = i2c_style(train_dataloader, val_dataloader, test_dataloader, device, args)    
c_predictions_train = torch.FloatTensor(c_predictions_train)
c_predictions_val = torch.FloatTensor(c_predictions_val)
c_predictions_test = torch.FloatTensor(c_predictions_test)

In [6]:
x_train = c_val
x_test = c_test
x_train.shape

torch.Size([10341, 50])

In [7]:
y_train = y_val[:, 0].to(torch.long)
y_test = y_test[:, 0].to(torch.long)
y_train.shape

torch.Size([10341])

In [8]:
colors = ['white']
shapes = ['square', 'ellipse', 'heart']
scale = ['very small', 'small', 's-medium', 'b-medium', 'big', 'very big']
rotation = ['0°', '5°', '10°', '15°', '20°', '25°', '30°', '35°']
x_pos = ['x0', 'x2', 'x4', 'x6', 'x8', 'x10', 'x12', 'x14', 'x16', 'x18', 'x20', 'x22', 'x24', 'x26', 'x28', 'x30']
y_pos = ['y0', 'y2', 'y4', 'y6', 'y8', 'y10', 'y12', 'y14', 'y16', 'y18', 'y20', 'y22', 'y24', 'y26', 'y28', 'y30']
concepts = colors + shapes + scale + rotation + x_pos + y_pos
len(concepts)

50

In [19]:
def train_nn(x_train, y_train, need_pruning, seed, device, relu=False, verbose=False):
    set_seed(seed)
    x_train = x_train.to(device)
    y_train = y_train.to(device)
    layers = [
        torch.nn.Linear(x_train.size(1), 30),
#         torch.nn.ReLU() if relu else torch.nn.LeakyReLU(),
#         torch.nn.Linear(100, 50),
        torch.nn.ReLU() if relu else torch.nn.LeakyReLU(),
        torch.nn.Linear(30, 10),
        torch.nn.ReLU() if relu else torch.nn.LeakyReLU(),
        torch.nn.Linear(10, 2),
        torch.nn.Softmax(dim=1),
    ]
    model = torch.nn.Sequential(*layers).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001)
    loss_form = torch.nn.CrossEntropyLoss()
    model.train()
    for epoch in range(tot_epochs):
        # forward pass
        optimizer.zero_grad()
        y_pred = model(x_train)
        # Compute Loss
        loss = loss_form(y_pred, y_train)

        for module in model.children():
            if isinstance(module, torch.nn.Linear):
                loss += 0.0007 * torch.norm(module.weight, 1)
                loss += 0.0007 * torch.norm(module.bias, 1)
                break

        # backward pass
        loss.backward()
        optimizer.step()

        if epoch > prune_epochs and need_pruning:
            prune_features(model, n_classes=1, device=device)
            need_pruning = False
            
        # compute accuracy
        if epoch % 500 == 0 and verbose:
            y_pred_d = torch.argmax(y_pred, dim=1)
            accuracy = y_pred_d.eq(y_train).sum().item() / y_train.size(0)
            print(f'\t Epoch {epoch}: train accuracy: {accuracy:.4f}')
    
    return model

In [10]:
def train_psi_nn(x_train, y_train, need_pruning, seed, device, verbose=False):
    set_seed(seed)
    x_train = x_train.to(device)
    y_train = y_train.to(device).to(torch.float)
    layers = [
        torch.nn.Linear(x_train.size(1), 10),
        torch.nn.Sigmoid(),
        torch.nn.Linear(10, 4),
        torch.nn.Sigmoid(),
        torch.nn.Linear(4, 1),
        torch.nn.Sigmoid(),
    ]
    model = torch.nn.Sequential(*layers).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_form = torch.nn.BCELoss()
    model.train()
    for epoch in range(tot_epochs):
        # forward pass
        optimizer.zero_grad()
        y_pred = model(x_train).squeeze()
        # Compute Loss
        loss = loss_form(y_pred, y_train)

        for module in model.children():
            if isinstance(module, torch.nn.Linear):
                loss += 0.00001 * torch.norm(module.weight, 1)

        # backward pass
        loss.backward()
        optimizer.step()

        if epoch > prune_epochs and need_pruning:
            model = prune_equal_fanin(model, 2, validate=True, device=device)
            need_pruning = False
            
        # compute accuracy
        if epoch % 500 == 0 and verbose:
            y_pred_d = y_pred > 0.5
            accuracy = y_pred_d.eq(y_train).sum().item() / y_train.size(0)
            print(f'\t Epoch {epoch}: train accuracy: {accuracy:.4f}')
    
    return model

In [21]:
def c_to_y(method, need_pruning, relu, verbose=False):
    methods = []
    splits = []
    explanations = []
    explanations_inv = []
    model_accuracies = []
    explanation_accuracies = []
    explanation_accuracies_inv = []
    elapsed_times = []
    elapsed_times_inv = []
    for seed in range(n_rep):
        explanation, explanation_inv = '', ''
        explanation_accuracy, explanation_accuracy_inv = 0, 0
        elapsed_time, elapsed_time_inv = 0, 0
        
        print(f'Seed [{seed+1}/{n_rep}]')
        
        if method == 'tree':
            classifier = DecisionTreeClassifier(random_state=seed)
            classifier.fit(x_train.detach().numpy(), y_train.detach().numpy())
            y_preds = classifier.predict(x_test.detach().numpy())
            model_accuracy = accuracy_score(y_test.detach().numpy(), y_preds)

            target_class = 1
            start = time.time()
            explanation = tree_to_formula(classifier, concepts, target_class)
            elapsed_time = time.time() - start

            target_class_inv = 0
            start = time.time()
            explanation_inv = tree_to_formula(classifier, concepts, target_class_inv)
            elapsed_time = time.time() - start
        
        else:
            if method == 'psi':
                # positive class
                target_class = 1
                model = train_psi_nn(x_train, y_train.eq(target_class), need_pruning, seed, device, verbose)
                y_preds = model(x_test.to(device)).cpu().detach().numpy()
                model_accuracy = accuracy_score(y_test.cpu().detach().numpy(), y_preds>0.5)
                
            else:
                model = train_nn(x_train, y_train, need_pruning, seed, device, relu, verbose)
                y_preds = model(x_test.to(device)).cpu().detach().numpy()
                model_accuracy = accuracy_score(y_test.cpu().detach().numpy(), y_preds.argmax(axis=1))

            # positive class
            target_class = 1
            start = time.time()
            if method == 'psi':
                global_explanation = logic.generate_fol_explanations(model, device)[0]
            else:
                global_explanation, _, _ = logic.relunn.combine_local_explanations(model, 
                                                                                   x_train.to(device), y_train.to(device), 
                                                                                   target_class=target_class,
                                                                                   method=method, device=device)
            elapsed_time = time.time() - start
            
            if global_explanation:
                explanation_accuracy, _ = logic.base.test_explanation(global_explanation, target_class, x_test, y_test)
                explanation = logic.base.replace_names(global_explanation, concepts)

#             # negative class
#             target_class_inv = 0
#             if method == 'psi':
#                 model = train_psi_nn(x_train, y_train.eq(target_class_inv), need_pruning, seed, device, verbose)
            
#             start = time.time()
#             if method == 'psi':
#                 global_explanation_inv = logic.generate_fol_explanations(model, device)[0]
#             else:
#                 global_explanation_inv, _, _ = logic.relunn.combine_local_explanations(model, 
#                                                                                        x_train.to(device), 
#                                                                                        y_train.to(device), 
#                                                                                        target_class=target_class_inv,
#                                                                                        method=method, device=device)
#             elapsed_time_inv = time.time() - start
#             if global_explanation_inv:
#                 explanation_accuracy_inv, _ = logic.base.test_explanation(global_explanation_inv, 
#                                                                           target_class_inv, x_test, y_test)
#                 explanation_inv = logic.base.replace_names(global_explanation_inv, concepts)
        
        if verbose:
            print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
            print(f'\t Class {target_class} - Global explanation: "{explanation}" - Accuracy: {explanation_accuracy:.4f}')
            print(f'\t Elapsed time {elapsed_time}')
#             print(f'\t Class {target_class} - Global explanation: "{explanation_inv}" - Accuracy: {explanation_accuracy_inv:.4f}')
#             print(f'\t Elapsed time {elapsed_time_inv}')

        methods.append(method)
        splits.append(seed)
        explanations.append(explanation)
        explanations_inv.append(explanation_inv)
        model_accuracies.append(model_accuracy)
        explanation_accuracies.append(explanation_accuracy)
        explanation_accuracies_inv.append(explanation_accuracy_inv)
        elapsed_times.append(elapsed_time)
        elapsed_times_inv.append(elapsed_time_inv)
    
    results = pd.DataFrame({
        'method': methods,
        'split': splits,
        'explanation': explanations,
        'explanation_inv': explanations_inv,
        'model_accuracy': model_accuracies,
        'explanation_accuracy': explanation_accuracies,
        'explanation_accuracy_inv': explanation_accuracies_inv,
        'elapsed_time': elapsed_times,
        'elapsed_time_inv': elapsed_times_inv,
    })
    results.to_csv(os.path.join(results_dir, f'results_{method}.csv'))
    
    return results

# General pruning

In [22]:
method = 'pruning'
need_pruning = True
relu = False
results_pruning = c_to_y(method, need_pruning, relu, True)
results_pruning

Seed [1/10]
	 Epoch 0: train accuracy: 0.0542
	 Epoch 500: train accuracy: 0.9462
	 Epoch 1000: train accuracy: 0.9458
	 Epoch 1500: train accuracy: 0.9458
	 Epoch 2000: train accuracy: 0.9458
	 Epoch 2500: train accuracy: 0.9458
	 Epoch 3000: train accuracy: 0.9458
	 Epoch 3500: train accuracy: 0.9458
	 Epoch 4000: train accuracy: 0.9458
	 Epoch 4500: train accuracy: 1.0000
	 Epoch 5000: train accuracy: 1.0000
	 Epoch 5500: train accuracy: 1.0000
	 Epoch 6000: train accuracy: 1.0000
	 Epoch 6500: train accuracy: 1.0000
	 Epoch 7000: train accuracy: 1.0000
	 Epoch 7500: train accuracy: 1.0000
	 Epoch 8000: train accuracy: 1.0000
	 Epoch 8500: train accuracy: 1.0000
	 Epoch 9000: train accuracy: 1.0000
	 Epoch 9500: train accuracy: 1.0000
	 Epoch 10000: train accuracy: 1.0000
560
	 Model's accuracy: 1.0000
	 Class 1 - Global explanation: "(35° & x30 & y30 & ~ellipse & ~b-medium & ~big) | (x2 & y28 & ~small & ~b-medium & ~10° & ~15° & ~20° & ~25° & ~30° & ~35°)" - Accuracy: 0.9412
	 Elap

KeyboardInterrupt: 

# LIME

In [None]:
# method = 'lime'
# need_pruning = False
# relu = False
# results_lime = c_to_y(method, need_pruning, relu)
# results_lime

# ReLU

In [None]:
method = 'weights'
need_pruning = False
relu = True
results_weights = c_to_y(method, need_pruning, relu)
results_weights

# Psi network

In [None]:
method = 'psi'
need_pruning = True
relu = False
results_psi = c_to_y(method, need_pruning, relu, verbose=True)
results_psi

# Decision tree

In [17]:
method = 'tree'
need_pruning = False
relu = False
results_tree = c_to_y(method, need_pruning, relu, True)
results_tree

Seed [1/10]
	 Model's accuracy: 1.0000
	 Class 1 - Global explanation: "(very small > 0.50 & square > 0.50)" - Accuracy: 0.0000
	 Elapsed time 0.0
Seed [2/10]
	 Model's accuracy: 1.0000
	 Class 1 - Global explanation: "(very small > 0.50 & square > 0.50)" - Accuracy: 0.0000
	 Elapsed time 0.0
Seed [3/10]
	 Model's accuracy: 1.0000
	 Class 1 - Global explanation: "(very small > 0.50 & square > 0.50)" - Accuracy: 0.0000
	 Elapsed time 0.0
Seed [4/10]
	 Model's accuracy: 1.0000
	 Class 1 - Global explanation: "(very small > 0.50 & square > 0.50)" - Accuracy: 0.0000
	 Elapsed time 0.0
Seed [5/10]
	 Model's accuracy: 1.0000
	 Class 1 - Global explanation: "(very small > 0.50 & square > 0.50)" - Accuracy: 0.0000
	 Elapsed time 0.0
Seed [6/10]
	 Model's accuracy: 1.0000
	 Class 1 - Global explanation: "(very small > 0.50 & square > 0.50)" - Accuracy: 0.0000
	 Elapsed time 0.0009970664978027344
Seed [7/10]
	 Model's accuracy: 1.0000
	 Class 1 - Global explanation: "(very small > 0.50 & square 

ValueError: arrays must all be same length

# Summary

In [None]:
cols = ['model_accuracy', 'explanation_accuracy', 'explanation_accuracy_inv', 'elapsed_time', 'elapsed_time_inv']
mean_cols = [f'{c}_mean' for c in cols]
sem_cols = [f'{c}_sem' for c in cols]

# pruning
df_mean = results_pruning[cols].mean()
df_sem = results_pruning[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_pruning = pd.concat([df_mean, df_sem])
summary_pruning.name = 'pruning'

# # lime
# df_mean = results_lime[cols].mean()
# df_sem = results_lime[cols].sem()
# df_mean.columns = mean_cols
# df_sem.columns = sem_cols
# summary_lime = pd.concat([df_mean, df_sem])
# summary_lime.name = 'lime'

# weights
df_mean = results_weights[cols].mean()
df_sem = results_weights[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_weights = pd.concat([df_mean, df_sem])
summary_weights.name = 'weights'

# psi
df_mean = results_psi[cols].mean()
df_sem = results_psi[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_psi = pd.concat([df_mean, df_sem])
summary_psi.name = 'psi'

# tree
df_mean = results_tree[cols].mean()
df_sem = results_tree[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_tree = pd.concat([df_mean, df_sem])
summary_tree.name = 'tree'

summary = pd.concat([summary_pruning, 
#                      summary_lime, 
                     summary_weights, summary_psi, summary_tree], axis=1).T
summary.columns = mean_cols + sem_cols
summary

In [None]:
summary.to_csv(os.path.join(results_dir, 'summary.csv'))