In [1]:
%reload_ext autoreload
%autoreload 2

%matplotlib inline

# To prevent automatic figure display when execution of the cell ends
%config InlineBackend.close_figures=False 

In [1]:
import os
import pandas as pd
import numpy as np
import seaborn as sns
from scipy import stats

import torch.optim as optim
import torch
import torch.nn as nn

from sklearn.preprocessing import LabelEncoder

from models.mlp import BlackBoxModel

import matplotlib.pyplot as plt
from ipywidgets import widgets
from IPython.display import display,clear_output

import warnings
warnings.filterwarnings("ignore")

In [2]:
from dataset import GermanCreditDataset
from experiment import GermanCreditExperiment

In [3]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report

# Load the German Credit data, build a standardized train/test split
# (random_state=32), then train and evaluate the experiment's models.
german_credit = GermanCreditDataset()

df = german_credit.get_dataframe()

df_X, df_y = german_credit.get_Xy()

X_train, X_test, y_train, y_test = german_credit.get_standardized_train_test_split(random_state=32)

german_credit_experiment = GermanCreditExperiment()
german_credit_experiment.train_and_evaluate_models(X_train, X_test, y_train, y_test)

# Print each model's own accuracy. model_reports is keyed by model name and each
# report exposes an 'accuracy' entry (presumably computed on the test split
# passed above). Index with the loop's model_name, not model_names[0], otherwise
# every line repeats the first model's score.
for model_name in german_credit_experiment.model_names:
    accuracy = german_credit_experiment.model_reports[model_name]['accuracy']
    print(f'{model_name}: {accuracy}')

RandomForestClassifier: 0.69


In [4]:
import shap
from explainers import pshap
from utils.benchmarking import *
import ot

In [5]:
# Benchmark configuration.
# sample_num: upper bound on rows drawn for the explain and baseline sets.
# max_round:  number of repeated benchmarking rounds averaged in the plots.
sample_num, max_round = 100, 50
# NOTE(review): not referenced in the cells shown; presumably a counterfactual
# search iteration cap / number of projections used elsewhere.
ce_max_iter, n_proj = 50, 10
ci_factor = 1.96  # two-sided 95% normal quantile used to size the CI bands
delta = 0.05      # forwarded to WassersteinDivergence.distance and the benchmark routine

In [6]:
# Model under explanation; every model call in this cell explains column 0 of
# predict_proba.
target_name = german_credit.target_name
model = german_credit_experiment.models[0]

# Seed for the subsampling below (same value as the train/test split) so the
# explain/baseline sets -- and all downstream SHAP values -- are reproducible.
sampling_seed = 32

X_test_ext = X_test.copy()
X_test_ext[target_name] = model.predict_proba(X_test.values)[:, 0]

# Baseline pool: test rows whose column-0 probability exceeds 0.9.
# Explain pool: the whole test split.
df_baseline = X_test[X_test_ext[target_name] > 0.9]
df_explain = X_test

# Both sets are sampled to the same size, capped at sample_num.
max_len = min(df_baseline.shape[0], df_explain.shape[0], sample_num)
# Fail loudly here rather than with a division by zero inside ot.emd / SHAP.
assert max_len > 0, 'no test rows with predict_proba[:, 0] > 0.9 -- baseline set is empty'

df_baseline = df_baseline.sample(max_len, random_state=sampling_seed)
df_explain = df_explain.sample(max_len, random_state=sampling_seed)

X_baseline = df_baseline.values
y_baseline = model.predict_proba(X_baseline)[:, 0]
X_explain = df_explain.values
y_explain = model.predict_proba(X_explain)[:, 0]

# Exact optimal-transport plan between explain rows and baseline rows, with
# uniform weights on both sides (ot.dist uses squared Euclidean cost by default).
ot_cost = ot.dist(X_explain, X_baseline)
matrix_mu = ot.emd(
    np.ones(X_explain.shape[0]) / X_explain.shape[0],
    np.ones(X_baseline.shape[0]) / X_baseline.shape[0],
    ot_cost,
)

# Standard KernelSHAP using the sampled baseline rows as background data.
shap_explainer = shap.KernelExplainer(lambda X: model.predict_proba(X)[:, 0], X_baseline)
# Joint-probability SHAP, given the OT plan via joint_probs (presumably weights
# explain/baseline row pairs -- see pshap.JointProbabilityExplainer).
# Alternative tried earlier: a KernelExplainer over X_train.sample(max_len).
jp_explainer = pshap.JointProbabilityExplainer(model)

shap_values_baseline = shap_explainer.shap_values(X_explain)
shap_values_jp = jp_explainer.shap_values(X_explain, X_baseline, joint_probs=matrix_mu)

  0%|          | 0/45 [00:00<?, ?it/s]

In [7]:
# x-axis of the benchmark: how many explain/baseline pairs get changed.
num_pairs_list = [25, 50, 75, 100, 150, 200, 300, 400, 500]

# Turn pyplot interactive mode off and create a single output widget that each
# benchmark round clears and redraws into.
plt.ioff()
out = widgets.Output()
vbox = widgets.VBox([out])
display(vbox)

# Metric histories: one inner list per round, one value per num_pairs entry.
# Suffix _bs = KernelSHAP with the sampled baseline, _jp = joint-probability SHAP.
ot_list_bs, ot_list_jp = [], []
exp_list_bs, exp_list_jp = [], []
mmd_list_bs, mmd_list_jp = [], []

def _mean_and_ci(history, z):
    """Return per-x means over rounds and the half-width of a normal CI.

    history: one inner sequence per completed round, each holding one value per
    entry of num_pairs_list. Returns (means, half_widths), each shaped like one
    round. The standard error needs at least two rounds; before that the band
    half-width is zero.
    """
    values = np.asarray(history, dtype=float)
    means = values.mean(axis=0)
    if values.shape[0] < 2:
        return means, np.zeros_like(means)
    # stats.sem is already std / sqrt(n_rounds) per x -- no extra 1/sqrt(t+1).
    return means, z * stats.sem(values, axis=0)


def _plot_metric(ax, x, means_bs, ci_bs, means_jp, ci_jp, ylabel):
    """Draw SHAP vs JP-SHAP mean curves with shaded CI bands on one axis."""
    for means, ci, label in ((means_bs, ci_bs, 'SHAP'), (means_jp, ci_jp, 'JP-SHAP')):
        ax.plot(x, means, label=label, marker='o')
        ax.fill_between(x, means - ci, means + ci, alpha=0.2)
    ax.set_xlabel('Number of Changes')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True)


# Distances between the explained and baseline prediction vectors before any
# changes. Their inputs never change across rounds, so compute them once; they
# are kept for reference and are not added to `results`.
ot_start, _ = WassersteinDivergence().distance(
    torch.FloatTensor(y_explain),
    torch.FloatTensor(y_baseline),
    delta=delta,
)
kl_start = compute_kl_divergence(y_explain, y_baseline)
mmd_start = compute_mmd(y_explain, y_baseline)

for t in range(max_round):
    # One benchmark run per budget of changed pairs. Arguments are identical
    # every round, so the routine is presumably stochastic -- hence averaging.
    results = [
        counterfactual_ability_performance_benchmarking(
            model=model,
            df_explain=df_explain,
            df_baseline=df_baseline,
            y_baseline=y_baseline,
            shap_values_baseline=shap_values_baseline,
            shap_values_jp=shap_values_jp,
            num_pairs=num_pairs,
            delta=delta,
        )
        for num_pairs in num_pairs_list
    ]

    new_ot_bs = [result['OT_bs'] for result in results]
    new_ot_jp = [result['OT_jp'] for result in results]
    new_exp_bs = [result['EXP_bs'] for result in results]
    new_exp_jp = [result['EXP_jp'] for result in results]
    new_mmd_bs = [result['MMD_bs'] for result in results]
    new_mmd_jp = [result['MMD_jp'] for result in results]

    ot_list_bs.append(new_ot_bs)
    ot_list_jp.append(new_ot_jp)
    exp_list_bs.append(new_exp_bs)
    exp_list_jp.append(new_exp_jp)
    mmd_list_bs.append(new_mmd_bs)
    mmd_list_jp.append(new_mmd_jp)

    # Per-x mean and 95% CI half-width over all rounds completed so far.
    ot_means_bs, ot_ci_bs = _mean_and_ci(ot_list_bs, ci_factor)
    ot_means_jp, ot_ci_jp = _mean_and_ci(ot_list_jp, ci_factor)
    exp_means_bs, exp_ci_bs = _mean_and_ci(exp_list_bs, ci_factor)
    exp_means_jp, exp_ci_jp = _mean_and_ci(exp_list_jp, ci_factor)
    mmd_means_bs, mmd_ci_bs = _mean_and_ci(mmd_list_bs, ci_factor)
    mmd_means_jp, mmd_ci_jp = _mean_and_ci(mmd_list_jp, ci_factor)

    fig, axes = plt.subplots(1, 3, figsize=(16, 4))
    x_labels = num_pairs_list

    # Panels: OT distance, MMD, expectation difference.
    _plot_metric(axes[0], x_labels, ot_means_bs, ot_ci_bs, ot_means_jp, ot_ci_jp, 'OT Distance')
    _plot_metric(axes[1], x_labels, mmd_means_bs, mmd_ci_bs, mmd_means_jp, mmd_ci_jp, 'MMD')
    _plot_metric(axes[2], x_labels, exp_means_bs, exp_ci_bs, exp_means_jp, exp_ci_jp, 'Exp Diff')

    # Widen the horizontal gap between panels.
    fig.subplots_adjust(wspace=0.3)

    with out:
        clear_output(wait=True)
        print(f'Round:{t}')
        display(fig)

    # Close each round's figure so figures do not accumulate over max_round rounds.
    plt.close(fig)


VBox(children=(Output(),))