In [1]:
%reload_ext autoreload
%autoreload 2

%matplotlib inline

# To prevent automatic figure display when execution of the cell ends
%config InlineBackend.close_figures=False 

In [2]:
import os
import pandas as pd
import numpy as np
import seaborn as sns
from scipy import stats

import torch.optim as optim
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn

from sklearn.preprocessing import LabelEncoder

from models.mlp import BlackBoxModel

import matplotlib.pyplot as plt
from ipywidgets import widgets
from IPython.display import display,clear_output

import warnings
warnings.filterwarnings("ignore")

In [3]:
data_path = 'data/'

df_ = pd.read_csv(os.path.join(data_path, 'german_credit_data.csv'))
df = df_.copy()

In [4]:
target_name = 'Risk'
target = df[target_name].replace({'good': 0, 'bad': 1})

df[target_name] = target

In [5]:
# Initialize a label encoder
label_encoder = LabelEncoder()
label_mappings = {}


# Convert categorical columns to numerical representations using label encoding
for column in df.columns:
    if column is not target_name and df[column].dtype == 'object':
        # Handle missing values by filling with a placeholder and then encoding
        df[column] = df[column].fillna('Unknown')
        df[column] = label_encoder.fit_transform(df[column])
        label_mappings[column] = dict(zip(label_encoder.classes_, range(len(label_encoder.classes_))))


# For columns with NaN values that are numerical, we will impute them with the median of the column
for column in df.columns:
    if df[column].isna().any():
        median_val = df[column].median()
        df[column].fillna(median_val, inplace=True)

# Display the first few rows of the transformed dataframe
df.head()

Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
0,67,1,2,1,0,1,1169,6,5,0
1,22,0,2,1,1,2,5951,48,5,1
2,49,1,1,1,1,0,2096,12,3,0
3,45,1,2,0,1,1,7882,42,4,0
4,53,1,2,0,1,1,4870,24,1,1


In [6]:
features = [
    'Age', 
    'Sex', 
    'Job', 
    'Housing', 
    'Saving accounts', 
    'Checking account',
    'Credit amount', 
    'Duration', 
    'Purpose', 
]

df_X = df[features].copy()
df_y = target

In [7]:
seed = 42

np.random.seed(seed)  # for reproducibility


# Split the dataset into training and testing sets (80% train, 20% test)
X_train, X_test, y_train, y_test = train_test_split(df_X, df_y, test_size=0.2, random_state=seed)

df_train = X_train.copy()

std = X_train.std()
mean = X_train.mean()

X_train = (X_train - mean) / std
X_test = (X_test - mean) / std

# X_train, X_test, y_train, y_test = X_train.values, X_test.values, y_train.values, y_test.values

# Convert to PyTorch tensors
X_train_tensor = torch.FloatTensor(X_train.values)
y_train_tensor = torch.FloatTensor(y_train.values).view(-1, 1)
X_test_tensor = torch.FloatTensor(X_test.values)
y_test_tensor = torch.FloatTensor(y_test.values).view(-1, 1)

# Initialize the model, loss function, and optimizer
model = BlackBoxModel(input_dim=X_train.shape[1])
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Training loop
num_epochs = 300
for epoch in range(num_epochs):
    # Forward pass
    outputs = model(X_train_tensor)
    loss = criterion(outputs, y_train_tensor)
    
    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Evaluate on test set
model.eval()
with torch.no_grad():
    test_outputs = model(X_test_tensor)
    test_loss = criterion(test_outputs, y_test_tensor)

    # Convert outputs to binary using 0.5 as threshold
    y_pred_tensor = (test_outputs > 0.5).float()
    correct_predictions = (y_pred_tensor == y_test_tensor).float().sum()
    accuracy = correct_predictions / y_test_tensor.shape[0]

test_accuracy = accuracy.item()
print('Test Accuracy:', test_accuracy)

Test Accuracy: 0.6949999928474426


In [8]:
from explainers import pshap, dce
import shap

from utils.benchmarking import *

In [31]:
sample_num = 50
max_round = 50
ce_max_iter = 50
ci_factor = 1.96  # 95% confidence interval factor
n_proj = 10
delta = 0.05

explain_columns = [
    'Age', 
    'Sex', 
    'Job', 
    'Housing', 
    'Saving accounts', 
    'Checking account',
    'Credit amount', 
    'Duration', 
    'Purpose', 
]

y_target = torch.distributions.beta.Beta(0.1, 0.9).sample((sample_num,))

# Initialize interactive output display
plt.ioff()
out = widgets.Output()
vbox = widgets.VBox([out])
display(vbox)

# Lists to store accuracies over iterations
ot_list_bs = []
ot_list_jp = []
kl_list_bs = []
kl_list_jp = []
mmd_list_bs = []
mmd_list_jp = []

for t in range(max_round):
    indice = (X_test.sample(sample_num)).index
    df_explain = X_test.loc[indice]
    X_explain = df_explain.values
    y_explain = model.predict_proba(X_explain)

    explainer = dce.DistributionalCounterfactualExplainer(
        model=model, 
        df_X=df_explain, 
        explain_columns=explain_columns,
        y_target=y_target, 
        lr=1e-1, 
        n_proj=n_proj,
        delta=0.05)

    explainer.optimize(U_1=0.4, U_2=0.2, l=0.2, r=1, max_iter=ce_max_iter, tau=1e3)

    X_baseline = explainer.best_X.detach().numpy()
    # y_baseline = explainer.best_y.detach().numpy().flatten()
    y_baseline = y_target
    df_baseline = pd.DataFrame(X_baseline, columns=df_X.columns, index=indice)

    matrix_mu = get_ot_plan(explainer.swd.mu_list)

    shap_explainer = shap.KernelExplainer(model.predict_proba, X_baseline)
    jp_explainer = pshap.JointProbabilityExplainer(model)

    shap_values_baseline = shap_explainer.shap_values(X_explain)
    shap_values_jp = jp_explainer.shap_values(X_explain, X_baseline, 
                                              joint_probs=matrix_mu.detach().numpy())

    ranked_features_baseline = get_ranked_features(shap_values_baseline, columns=df_explain.columns.to_list())
    ranked_features_jp = get_ranked_features(shap_values_jp, columns=df_explain.columns.to_list())

    baseline_change_columns = []
    jp_change_columns = []

    ot_start, _ = WassersteinDivergence().distance(
        torch.FloatTensor(y_explain), 
        torch.FloatTensor(y_baseline),
        delta=delta,
        )
    kl_start = compute_kl_divergence(
            y_explain, 
            y_baseline,
        )
    mmd_start = compute_mmd(
            y_explain, 
            y_baseline,
        )
    results = [{
        'OT_bs': ot_start, 'OT_jp': ot_start,
        'KL_bs': kl_start, 'KL_jp': kl_start,
        'MMD_bs': mmd_start, 'MMD_jp': mmd_start,
        }]

    for bs_change_col, jp_change_col in zip(ranked_features_baseline, ranked_features_jp):
        baseline_change_columns.append(bs_change_col)
        jp_change_columns.append(jp_change_col)
        result = counterfactual_ability_performance_benchmarking(
            model=model, df_baseline=df_baseline, 
            df_explain=df_explain, y_baseline=y_baseline,
            baseline_change_columns=baseline_change_columns,
            jp_change_columns=jp_change_columns,
            delta=delta
        )
        results.append(result)

    new_ot_bs = [result['OT_bs'] for result in results]
    new_ot_jp = [result['OT_jp'] for result in results]

    ot_list_bs.append(new_ot_bs)
    ot_list_jp.append(new_ot_jp)


    # Compute mean and confidence intervals
    ot_means_bs = np.mean(ot_list_bs, axis=0)
    ot_means_jp = np.mean(ot_list_jp, axis=0)
    ot_std_err_bs = stats.sem(ot_means_bs, axis=0)
    ot_std_err_jp = stats.sem(ot_means_jp, axis=0)
    ot_ci_bs = ot_std_err_bs * ci_factor
    ot_ci_jp = ot_std_err_jp * ci_factor

    fig, ax = plt.subplots(figsize=(10, 6))
    x_labels = np.arange(len(ot_means_bs))
    ax.plot(x_labels, ot_means_bs, label='SHAP', marker='o')
    ax.fill_between(x_labels, ot_means_bs - ot_ci_bs, ot_means_bs + ot_ci_bs, alpha=0.2)
    ax.plot(x_labels, ot_means_jp, label='JP-SHAP', marker='o')
    ax.fill_between(x_labels, ot_means_jp - ot_ci_jp, ot_means_jp + ot_ci_jp, alpha=0.2)
    ax.set_xlabel('Number of Dropped Features')
    ax.set_ylabel('OT distance')
    ax.set_title('Comparison of Baseline and JointProb Accuracies with Confidence Intervals')
    ax.legend()
    ax.grid(True)

    with out:
        clear_output(wait=True);
        print(f'Round:{t}')
        display(ax.figure);

VBox(children=(Output(),))

INFO:root:Optimization started
INFO:root:U_1-Qu_upper=0.4, U_2-Qv_upper=-0.18726447762824755
INFO:root:eta=1, l=0.24000000000000002, r=1
INFO:root:Iter 1: Q = 0.01807350106537342, term1 = 0.09819041937589645, term2 = 0.01807350106537342
INFO:root:U_1-Qu_upper=0.31783670471164466, U_2-Qv_upper=0.10142868467267963
INFO:root:eta=0.8161407969676816, l=0.278, r=1
INFO:root:Iter 2: Q = 0.004943922162055969, term1 = 0.009668862447142601, term2 = 0.0038794935680925846
INFO:root:U_1-Qu_upper=0.3823243954704225, U_2-Qv_upper=0.16173315693365764
INFO:root:eta=0.7853695095489219, l=0.31410000000000005, r=1
INFO:root:Iter 3: Q = 0.0023495235946029425, term1 = 0.0031245320569723845, term2 = 0.002137724542990327
INFO:root:U_1-Qu_upper=0.4, U_2-Qv_upper=0.14291972126926664
INFO:root:eta=0.8194417462135776, l=0.34839500000000007, r=1
INFO:root:Iter 4: Q = 0.0018002556171268225, term1 = 0.0024098954163491726, term2 = 0.0016659258399158716
INFO:root:U_1-Qu_upper=0.4, U_2-Qv_upper=0.14787352865582387
INFO

  0%|          | 0/50 [00:00<?, ?it/s]

INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([-0.0155534 , -0.04264358, -0.00688771, -0.01754064, -0.12466388,
        0.05630807, -0.00805143,  0.05494967,  0.00785664])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([ 0.14950333, -0.02445972, -0.01392978,  0.12094955,  0.07690125,
        0.24774702,  0.0095204 ,  0.23956535,  0.04101431])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([ 0.0029787 ,  0.00739639,  0.00265684, -0.01339063,  0.00675156,
       -0.11293758,  0.01756987,  0.00977224, -0.02280264])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([ 0.04008007, -0.01993001,  0.03837514,  0.02135263,  0.25260716,
        0.19539749, -0.02210714,  0.08047508,  0.06130954])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([-0.03094118, -0.00404881,  0.02735839,  0.00686645,  0.01877464,
        0.01035666, -0.00214344, -0.10087721, -0.02713082])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([ 0.05827351,  0.01382481, -0.05927857, -0.0066519 ,

INFO:root:Optimization started
INFO:root:U_1-Qu_upper=0.4, U_2-Qv_upper=-0.1203595455441609
INFO:root:eta=1, l=0.24000000000000002, r=1
INFO:root:Iter 1: Q = 0.040830474346876144, term1 = 0.32375338673591614, term2 = 0.040830474346876144
INFO:root:U_1-Qu_upper=0.21761589480864535, U_2-Qv_upper=0.05653851795479056
INFO:root:eta=0.843266160801437, l=0.278, r=1
INFO:root:Iter 2: Q = 0.006259383168071508, term1 = 0.01318289153277874, term2 = 0.004972543567419052
INFO:root:U_1-Qu_upper=0.3472518096847668, U_2-Qv_upper=0.1087896030404263
INFO:root:eta=0.8277654370777089, l=0.31410000000000005, r=1
INFO:root:Iter 3: Q = 0.0034744669683277607, term1 = 0.00576006667688489, term2 = 0.002998898271471262
INFO:root:U_1-Qu_upper=0.39337901702765354, U_2-Qv_upper=0.12617784262011683
INFO:root:eta=0.8334246182182814, l=0.34839500000000007, r=1
INFO:root:Iter 4: Q = 0.0028837560676038265, term1 = 0.003384585026651621, term2 = 0.002783656120300293
INFO:root:U_1-Qu_upper=0.4, U_2-Qv_upper=0.1322060004665

  0%|          | 0/50 [00:00<?, ?it/s]

INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([ 0.06298339,  0.05924195,  0.0984426 ,  0.03990148,  0.1252955 ,
        0.268075  ,  0.08593306,  0.18282354, -0.04389162])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([ 0.01266337,  0.00728919, -0.02269363, -0.00431856, -0.01577322,
       -0.0778098 , -0.01145301, -0.02046999,  0.01343406])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([ 0.01885007, -0.01613427, -0.00938137,  0.02455362, -0.15566527,
        0.04235607, -0.07454474, -0.00171315,  0.0969142 ])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([ 0.00726286,  0.0213298 ,  0.05740182,  0.01845273,  0.0644073 ,
        0.19022683, -0.08776943,  0.08331497,  0.01225951])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([ 0.07642117,  0.03508926,  0.05033981, -0.02027953,  0.03763568,
        0.18429165, -0.02411269,  0.08994434, -0.0202395 ])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([ 0.07500293, -0.03035976,  0.07594482,  0.0894959 ,

INFO:root:Optimization started
INFO:root:U_1-Qu_upper=0.4, U_2-Qv_upper=-0.2168752410261009
INFO:root:eta=1, l=0.24000000000000002, r=1
INFO:root:Iter 1: Q = 0.019287439063191414, term1 = 0.4289889335632324, term2 = 0.019287439063191414
INFO:root:U_1-Qu_upper=0.1197647294560168, U_2-Qv_upper=0.12386597724113016
INFO:root:eta=0.6136031291807548, l=0.24000000000000002, r=0.962
INFO:root:Iter 2: Q = 0.042575255036354065, term1 = 0.06196283549070358, term2 = 0.030366551131010056
INFO:root:U_1-Qu_upper=0.21580228728458645, U_2-Qv_upper=-0.0007252797225084928
INFO:root:eta=0.962, l=0.2761, r=0.962
INFO:root:Iter 3: Q = 0.03238922357559204, term1 = 0.13888181746006012, term2 = 0.028182655572891235
INFO:root:U_1-Qu_upper=0.21114357240089823, U_2-Qv_upper=0.12858719281586384
INFO:root:eta=0.7023886707283424, l=0.31039500000000003, r=0.962
INFO:root:Iter 4: Q = 0.006497687194496393, term1 = 0.016541732475161552, term2 = 0.002241892972961068
INFO:root:U_1-Qu_upper=0.34395422734105197, U_2-Qv_uppe

  0%|          | 0/50 [00:00<?, ?it/s]

INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([ 0.08042192,  0.02647359,  0.06665741, -0.04341948,  0.05942158,
        0.18737797, -0.00180852,  0.0457666 , -0.0634067 ])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([ 0.0308348 , -0.04372452,  0.06297384, -0.0166562 ,  0.04277172,
       -0.02175192, -0.03243512, -0.01826338, -0.05899047])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([-0.00220941, -0.03457524,  0.01559429,  0.06200069, -0.01501298,
       -0.20853153,  0.01772508,  0.01657456, -0.01991128])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([-0.09853447, -0.05659717,  0.02532778, -0.01738471, -0.0095426 ,
       -0.25338904, -0.02805142,  0.15271365,  0.1177883 ])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([ 0.0788628 , -0.01597217,  0.07858764,  0.07600598,  0.09574621,
        0.20681101,  0.1235088 , -0.1697134 ,  0.12854921])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([ 0.04817426,  0.00540589,  0.0005941 ,  0.00564683,

INFO:root:Optimization started
INFO:root:U_1-Qu_upper=0.4, U_2-Qv_upper=-0.26855354463979325
INFO:root:eta=1, l=0.24000000000000002, r=1
INFO:root:Iter 1: Q = 0.10203004628419876, term1 = 0.2889133095741272, term2 = 0.10203004628419876
INFO:root:U_1-Qu_upper=0.09565100603419319, U_2-Qv_upper=-0.19644897565991992
INFO:root:eta=1, l=0.278, r=1
INFO:root:Iter 2: Q = 0.0337105356156826, term1 = 0.20086240768432617, term2 = 0.0337105356156826
INFO:root:U_1-Qu_upper=0.05171587675497574, U_2-Qv_upper=-0.03856225875750349
INFO:root:eta=1, l=0.31410000000000005, r=1
INFO:root:Iter 3: Q = 0.02320035919547081, term1 = 0.06690683960914612, term2 = 0.02320035919547081
INFO:root:U_1-Qu_upper=0.0479048005651217, U_2-Qv_upper=-0.020820769565614655
INFO:root:eta=1, l=0.34839500000000007, r=1
INFO:root:Iter 4: Q = 0.012274065054953098, term1 = 0.06800410151481628, term2 = 0.012274065054953098
INFO:root:U_1-Qu_upper=0.04320633375080268, U_2-Qv_upper=0.055239160313672736
INFO:root:eta=0.6343752103817275, 

  0%|          | 0/50 [00:00<?, ?it/s]

INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([-0.0041333 ,  0.00517816, -0.05227712, -0.01068676,  0.02360754,
        0.09213271, -0.0167586 ,  0.09840986, -0.13270673])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([ 0.09865934, -0.02675687,  0.05552819,  0.01443412,  0.0645497 ,
        0.10455961,  0.00164496, -0.04853508,  0.15458878])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([-0.05944746, -0.00431042, -0.03523294, -0.01935776,  0.02957516,
        0.07422844,  0.02704786, -0.13105594, -0.01731098])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([ 0.07653934,  0.00360907,  0.02297942,  0.01557833,  0.01670107,
       -0.10063253, -0.00554473, -0.15234426, -0.01099603])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([-0.24319524,  0.01222667, -0.02486972,  0.01355588,  0.05429512,
        0.11852158, -0.01237505,  0.06397077, -0.1179083 ])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([ 0.08593372, -0.02695537,  0.08024406,  0.09245538,

INFO:root:Optimization started
INFO:root:U_1-Qu_upper=0.4, U_2-Qv_upper=-0.08788670811518345
INFO:root:eta=1, l=0.24000000000000002, r=1
INFO:root:Iter 1: Q = 0.023305829614400864, term1 = 0.049347326159477234, term2 = 0.023305829614400864
INFO:root:U_1-Qu_upper=0.3232331808073071, U_2-Qv_upper=0.02948550805986791
INFO:root:eta=0.9364678231327337, l=0.278, r=1
INFO:root:Iter 2: Q = 0.011633451096713543, term1 = 0.009077822789549828, term2 = 0.011806830763816833
INFO:root:U_1-Qu_upper=0.32671124733859513, U_2-Qv_upper=0.04654838260528041
INFO:root:eta=0.9099609774406467, l=0.31410000000000005, r=1
INFO:root:Iter 3: Q = 0.0099403141066432, term1 = 0.0049720648676157, term2 = 0.01043191272765398
INFO:root:U_1-Qu_upper=0.35285696485673335, U_2-Qv_upper=0.06295964091062778
INFO:root:eta=0.8961464811610723, l=0.34839500000000007, r=1
INFO:root:Iter 4: Q = 0.005159604363143444, term1 = 0.023079432547092438, term2 = 0.0030828937888145447
INFO:root:U_1-Qu_upper=0.3575858753082255, U_2-Qv_upper=

  0%|          | 0/50 [00:00<?, ?it/s]

INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([-0.01358615, -0.00011659, -0.00167038, -0.04171755,  0.00896324,
       -0.07783062,  0.00837148, -0.02845514,  0.02073683])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([ 0.01624986, -0.03300212, -0.00230335, -0.02904053, -0.16201416,
        0.05598692, -0.03863981,  0.01875894,  0.05399462])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([-0.04213625, -0.0061974 , -0.06528808, -0.02969178,  0.02575366,
        0.04504609,  0.04335983, -0.08248778, -0.01382123])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([ 0.0178697 ,  0.21286938,  0.08344063,  0.11659529,  0.0612786 ,
        0.18734066, -0.01551891, -0.00411744,  0.21186374])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([ 0.06744827,  0.00088125, -0.00017774, -0.00021865,  0.029784  ,
       -0.10487014,  0.00647345, -0.11871924, -0.0043104 ])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([ 0.04216534, -0.03101845,  0.00369076,  0.00663235,

INFO:root:Optimization started
INFO:root:U_1-Qu_upper=0.4, U_2-Qv_upper=-0.1731089655148123
INFO:root:eta=1, l=0.24000000000000002, r=1
INFO:root:Iter 1: Q = 0.049087196588516235, term1 = 0.18910488486289978, term2 = 0.049087196588516235
INFO:root:U_1-Qu_upper=0.22658652268724708, U_2-Qv_upper=-0.02608050518894081
INFO:root:eta=1, l=0.278, r=1
INFO:root:Iter 2: Q = 0.03003225103020668, term1 = 0.07191146910190582, term2 = 0.03003225103020668
INFO:root:U_1-Qu_upper=0.17530614499145547, U_2-Qv_upper=0.014290161171928994
INFO:root:eta=0.9455817648829, l=0.31410000000000005, r=1
INFO:root:Iter 3: Q = 0.009681876748800278, term1 = 0.03897411748766899, term2 = 0.007996108382940292
INFO:root:U_1-Qu_upper=0.213636219485413, U_2-Qv_upper=0.07659716700611362
INFO:root:eta=0.8189801749392219, l=0.34839500000000007, r=1
INFO:root:Iter 4: Q = 0.006722039543092251, term1 = 0.009399661794304848, term2 = 0.006130202207714319
INFO:root:U_1-Qu_upper=0.34450930235436705, U_2-Qv_upper=0.12250302625149213


  0%|          | 0/50 [00:00<?, ?it/s]

INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([ 0.00696459, -0.01038862,  0.00801603, -0.02563479, -0.02375769,
       -0.08468649, -0.01537477,  0.01733016,  0.04361643])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([ 0.04261944, -0.00563869,  0.04678883,  0.02678284,  0.04045016,
        0.14858402, -0.02183209, -0.09339789, -0.00323927])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([ 0.06672161,  0.02730948, -0.04648178,  0.01143666,  0.0604917 ,
       -0.13708417,  0.01232961, -0.01894015,  0.0960841 ])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([ 0.01129118, -0.00231044,  0.02721468, -0.02070216,  0.03045318,
       -0.12051045, -0.01572668, -0.0120345 , -0.00428798])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([0.04608702, 0.06090929, 0.03917563, 0.03232773, 0.13892427,
       0.21839362, 0.07347132, 0.01482475, 0.15852662])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([ 0.00484854, -0.00987415, -0.0040356 ,  0.15843381,  0.02161

INFO:root:Optimization started
INFO:root:U_1-Qu_upper=0.4, U_2-Qv_upper=-0.2436933056031318
INFO:root:eta=1, l=0.24000000000000002, r=1
INFO:root:Iter 1: Q = 0.06376748532056808, term1 = 0.36194437742233276, term2 = 0.06376748532056808
INFO:root:U_1-Qu_upper=0.1300503324537403, U_2-Qv_upper=-0.06428467465295401
INFO:root:eta=1, l=0.278, r=1
INFO:root:Iter 2: Q = 0.027485191822052002, term1 = 0.05276740714907646, term2 = 0.027485191822052002
INFO:root:U_1-Qu_upper=0.11999978637231146, U_2-Qv_upper=-0.03253861268504804
INFO:root:eta=1, l=0.31410000000000005, r=1
INFO:root:Iter 3: Q = 0.019097616896033287, term1 = 0.08224105834960938, term2 = 0.019097616896033287
INFO:root:U_1-Qu_upper=0.08796111889989833, U_2-Qv_upper=0.03734778397443131
INFO:root:eta=0.795570430827623, l=0.34839500000000007, r=1
INFO:root:Iter 4: Q = 0.015444628894329071, term1 = 0.019732195883989334, term2 = 0.014342897571623325
INFO:root:U_1-Qu_upper=0.29766932052026474, U_2-Qv_upper=0.02977655824238082
INFO:root:eta=

KeyboardInterrupt: 

In [50]:
shap_values_baseline

array([[ 6.96459464e-03, -1.03886152e-02,  8.01602907e-03,
        -2.56347877e-02, -2.37576940e-02, -8.46864922e-02,
        -1.53747743e-02,  1.73301597e-02,  4.36164318e-02],
       [ 4.26194391e-02, -5.63869162e-03,  4.67888274e-02,
         2.67828396e-02,  4.04501599e-02,  1.48584018e-01,
        -2.18320944e-02, -9.33978864e-02, -3.23927132e-03],
       [ 6.67216057e-02,  2.73094819e-02, -4.64817752e-02,
         1.14366596e-02,  6.04917029e-02, -1.37084167e-01,
         1.23296060e-02, -1.89401525e-02,  9.60841017e-02],
       [ 1.12911840e-02, -2.31044319e-03,  2.72146781e-02,
        -2.07021629e-02,  3.04531785e-02, -1.20510452e-01,
        -1.57266790e-02, -1.20344965e-02, -4.28797583e-03],
       [ 4.60870238e-02,  6.09092900e-02,  3.91756321e-02,
         3.23277345e-02,  1.38924266e-01,  2.18393623e-01,
         7.34713217e-02,  1.48247531e-02,  1.58526619e-01],
       [ 4.84853800e-03, -9.87414624e-03, -4.03559622e-03,
         1.58433811e-01,  2.16162522e-02, -1.064179

In [64]:
P_jp = np.abs(shap_values_jp) / np.abs(shap_values_jp).sum()

num_pairs = 5

flat_indices = np.random.choice(a=P_jp.size, size=num_pairs, p=P_jp.flatten(), replace=True)
flat_indices = np.unique(flat_indices)

i_indices, j_indices = np.unravel_index(flat_indices, P_jp.shape)

In [65]:
i_indices

array([ 1,  4, 17, 18, 25], dtype=int64)

In [66]:
j_indices

array([5, 0, 0, 1, 6], dtype=int64)

In [62]:
flat_indices

array([296, 388, 419, 277, 257])

In [53]:
X_explain.shape

(50, 9)

In [39]:
x_index = 32
x = X_explain[x_index]
weights = matrix_mu[x_index]
weights = weights / (weights.sum())

In [40]:
weights

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [48]:
np.unique(np.random.choice(X_baseline.shape[0], size=5, p=weights, replace=True))

array([14])

In [None]:
jp_explainer.weighted_explainer.explain_instance(x, X_baseline, weights, num_samples)
for x, weights in zip(X, joint_probs)

  0%|          | 0/50 [00:00<?, ?it/s]

INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([ 0.01237254, -0.02367189,  0.0093609 , -0.00038501,  0.12230476,
        0.19996587, -0.00462628,  0.04052322, -0.00833499])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([-0.02146724, -0.0098594 ,  0.00459034, -0.02622162,  0.027334  ,
        0.07008412, -0.02732191, -0.0777325 , -0.00662901])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([ 0.04136242, -0.01608276,  0.00981311,  0.01307485,  0.08944378,
        0.29564102, -0.0083814 ,  0.03021186,  0.07595169])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([-0.03608801,  0.02836259, -0.05274677, -0.0144207 ,  0.04569356,
        0.20379475,  0.08062586, -0.2238143 ,  0.08847952])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([-0.00010158, -0.01967849, -0.00641504,  0.00149295,  0.02102222,
       -0.08784012,  0.00503834, -0.02471052, -0.00762063])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([-0.0001621 , -0.02186968,  0.00620607, -0.00627852,

In [12]:
y_explain.round(3)

array([0.199, 0.   , 0.713, 0.   , 0.   , 0.   , 0.001, 0.975, 0.006,
       0.   , 0.   , 0.34 , 0.003, 0.378, 0.   , 0.55 , 0.556, 0.001,
       0.001, 0.389, 0.   , 0.101, 0.812, 0.   , 0.   , 0.002, 1.   ,
       0.003, 0.606, 0.023, 0.91 , 0.   , 0.003, 0.   , 0.452, 0.999,
       0.208, 0.   , 0.   , 0.139, 0.044, 0.038, 0.   , 0.001, 0.618,
       0.   , 0.505, 0.87 , 0.998, 0.072], dtype=float32)

In [22]:
y_baseline.round(3)

array([0.019, 0.705, 0.006, 0.103, 0.024, 0.   , 0.   , 0.008, 0.   ,
       0.001, 0.001, 0.012, 0.   , 0.002, 0.   , 0.843, 0.993, 0.   ,
       0.023, 0.411, 0.083, 0.005, 0.019, 0.999, 0.   , 0.509, 0.096,
       0.   , 0.016, 0.   , 0.509, 0.03 , 0.327, 0.   , 0.516, 0.   ,
       0.   , 0.015, 0.048, 0.003, 0.   , 0.   , 0.   , 0.   , 0.   ,
       0.004, 0.028, 0.   , 0.   , 1.   ], dtype=float32)

In [41]:
from scipy.stats import gaussian_kde, entropy

In [None]:
# Set up the matplotlib figure and axes
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(16, 8))

# Plot the first SHAP summary plot on the first axis
plt.sca(axs[0])  # Set the current axis to the first subplot
shap.summary_plot(shap_values_baseline, X_explain, feature_names=df_explain.columns, plot_type='violin', show=False)
axs[0].set_title('Baseline Model')

# Plot the second SHAP summary plot on the second axis
plt.sca(axs[1])  # Set the current axis to the second subplot
shap.summary_plot(shap_values_jp, X_explain, feature_names=df_explain.columns, plot_type='violin', show=False)
axs[1].set_title('JP Model')

# Adjust layout
plt.tight_layout()

# Show the plot
plt.show()

In [None]:
factual_X = df[df_X.columns].loc[indice].copy()
counterfactual_X = pd.DataFrame(explainer.best_X.detach().numpy() * std[df_X.columns].values + mean[df_X.columns].values, columns=df_X.columns, index=indice)

# dtype_dict = df.dtypes.apply(lambda x: x.name).to_dict()
# for k, v in dtype_dict.items():
#     if k in counterfactual_X.columns:
#         if v[:3] == 'int':
#             counterfactual_X[k] = counterfactual_X[k].round().astype(v)
#         else:
#             counterfactual_X[k] = counterfactual_X[k].astype(v)

factual_y = pd.DataFrame(y.detach().numpy(),columns=[target_name], index=factual_X.index)
counterfactual_y = pd.DataFrame(explainer.y.detach().numpy(),columns=[target_name], index=factual_X.index)

In [None]:
# Now, reverse the label encoding using the label_mappings
for dft in [factual_X, counterfactual_X]:
    for column, mapping in label_mappings.items():
        if column in dft.columns:
            # Invert the label mapping dictionary
            inv_mapping = {v: k for k, v in mapping.items()}
            # Map the encoded labels back to the original strings
            dft[column] = dft[column].map(inv_mapping)

In [None]:
factual = factual_X
counterfactual = counterfactual_X

factual[target_name] = factual_y
counterfactual[target_name] = counterfactual_y

In [None]:
factual.head(5)

In [None]:
counterfactual.head(5)

In [None]:
import matplotlib.pyplot as plt

matrix_nu = explainer.wd.nu.detach().numpy()

mu_avg = torch.zeros_like(explainer.swd.mu_list[0])
for mu in explainer.swd.mu_list:
    mu_avg += mu

total_sum = mu_avg.sum()

matrix_mu = mu_avg / total_sum

# Determine the global minimum and maximum values across both matrices
vmin = min(matrix_mu.min(), matrix_nu.min())
vmax = max(matrix_mu.max(), matrix_nu.max())

# Create a figure and a set of subplots
fig, axs = plt.subplots(1, 2, figsize=(20, 8))  # 1 row, 2 columns

# First subplot for matrix_mu
im_mu = axs[0].imshow(matrix_mu, cmap='viridis', vmin=vmin, vmax=vmax)
axs[0].set_title("Heatmap of $\mu$")

# Second subplot for matrix_nu
im_nu = axs[1].imshow(matrix_nu, cmap='viridis', vmin=vmin, vmax=vmax)
axs[1].set_title("Heatmap of $\\nu$")

# Create a colorbar for the whole figure
fig.colorbar(im_mu, ax=axs, orientation='vertical')

# Display the plots
plt.show()