In [17]:
# Reload the autoreload extension (safe to re-run in a live kernel).
%reload_ext autoreload
# Mode 2: reload imported modules before each cell runs, so edits to local
# source files take effect without restarting the kernel.
%autoreload 2

#### 0. Pick the factual instances (rows with "Risk" = 1)

In [18]:
from test_dataset.german_credit import GermanCreditDataset

# Instantiate the dataset wrapper and pull its contents out as a DataFrame.
dataset = GermanCreditDataset()
df = dataset.get_dataframe()

# Preview the first three rows.
df.head(n=3)

Downcasting behavior in `replace` is deprecated and will be removed in a future version. To retain the old behavior, explicitly call `result.infer_objects(copy=False)`. To opt-in to the future behavior, set `pd.set_option('future.no_silent_downcasting', True)`


Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
0,67,1,2,1,0,1,1169,6,5,0
1,22,0,2,1,1,2,5951,48,5,1
2,49,1,1,1,1,0,2096,12,3,0


In [19]:
# Keep only the rows with Risk == 1 -- the factual instances to explain -- and
# draw five of them. The draw is seeded so that Restart & Run All picks the same
# five rows every time; without random_state each run explains a different sample.
SAMPLE_SEED = 42
df_Risk_1 = df[df['Risk'] == 1].sample(n=5, random_state=SAMPLE_SEED)
df_Risk_1.head()

Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
169,31,1,2,1,1,2,1935,24,0,1
249,22,0,2,2,1,0,433,18,5,1
973,36,1,2,2,1,1,7297,60,0,1
700,29,0,1,2,3,0,1123,12,4,1
37,37,1,2,1,1,3,2100,18,5,1


In [20]:
# Feature-only copy of the sampled factuals: the 'Risk' label column is removed.
df_without_target = df_Risk_1.drop(['Risk'], axis='columns').copy()
df_without_target.head(5)

Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose
169,31,1,2,1,1,2,1935,24,0
249,22,0,2,2,1,0,433,18,5
973,36,1,2,2,1,1,7297,60,0
700,29,0,1,2,3,0,1123,12,4
37,37,1,2,1,1,3,2100,18,5


In [21]:
# Keep the feature column labels (a pandas Index) under a name of their own,
# then display them as the cell output.
feature_names = df_without_target.columns
feature_names

Index(['Age', 'Sex', 'Job', 'Housing', 'Saving accounts', 'Checking account',
       'Credit amount', 'Duration', 'Purpose'],
      dtype='object')

#### 1. Initialize data interface

In [22]:
from xai_cola import data_interface

# Wrap the feature-only frame in the library's pandas data interface.
# NOTE(review): 'Risk' is passed as target_name although the frame no longer
# contains that column -- confirm this is what PandasData expects.
data1 = data_interface.PandasData(
    df_without_target,
    target_name='Risk',
)

#### 2. Initialize model interface

In [23]:
import joblib

# Single source of truth for the model file, so the load call and the status
# message below cannot drift apart.
MODEL_PATH = 'lgbm_GremanCredit.pkl'

# NOTE(security): joblib.load unpickles the file, which can execute arbitrary
# code. Only load model files that come from a trusted source.
lgbmcClassifier = joblib.load(MODEL_PATH)
print(f'----{MODEL_PATH} model has been loaded----')

----lgbm_GremanCredit.pkl model has been loaded----


In [24]:
from xai_cola import ml_model_interface

# Wrap the loaded classifier in the library's model interface, using the
# "sklearn" backend.
ml_model1 = ml_model_interface.Model(
    model=lgbmcClassifier,
    backend="sklearn",
)

#### 3. Choose the counterfactual explanation algorithm

In [25]:
from counterfactual_explainer import DiCE

# Build the DiCE counterfactual explainer on top of the wrapped model.
explainer = DiCE(
    ml_model=ml_model1,
)

In [26]:
# Ask the explainer for counterfactuals of the wrapped factual rows.
# total_cfs=1 presumably requests one counterfactual per factual (the printed
# arrays have one row each per instance) -- confirm against the explainer docs.
factual, counterfactual = explainer.generate_counterfactuals(
    data=data1,
    factual_class=1,
    total_cfs=1,
)

# Show both arrays as plain text for a quick side-by-side check.
print(f'Factual: {factual}')
print(f'Counterfactual: {counterfactual}')

100%|██████████| 5/5 [00:00<00:00, 11.74it/s]

Factual: [[  31    1    2    1    1    2 1935   24    0]
 [  22    0    2    2    1    0  433   18    5]
 [  36    1    2    2    1    1 7297   60    0]
 [  29    0    1    2    3    0 1123   12    4]
 [  37    1    2    1    1    3 2100   18    5]]
Counterfactual: [[  31    0    2    1    1    0 1935   24    0]
 [  37    0    2    2    3    0  433   18    5]
 [  36    1    2    2    3    1 7297   42    0]
 [  29    0    1    1    3    0 3989   12    4]
 [  29    1    2    1    1    0 2100   18    5]]





#### 4. Choose the policy and limit the number of actions

In [27]:
from xai_cola.counterfactual_limited_actions import COLA

# Build the refiner from the data/model wrappers and the factual/counterfactual
# pair returned by the explainer.
refiner = COLA(
    data=data1, ml_model=ml_model1,
    x_factual=factual, x_counterfactual=counterfactual,
)

# Policy: "ot" matcher (reported as Optimal Transport Matching in the message
# this call prints), "pshap" attributor, and "max" as the Avalues method.
refiner.set_policy(matcher="ot", attributor="pshap", Avalues_method="max")

You choose the Policy: pshap With Optimal Transport Matching, Avalues_method is max


In [28]:
# Compute the refined counterfactuals with limited_actions=5 (presumably a cap
# on how many feature changes are allowed -- confirm against the xai_cola docs).
# NOTE(review): this rebinds `factual`, which the COLA(...) cell passed as
# x_factual; re-running that cell after this one would feed it this value
# instead of the explainer's output.
factual, ce, ace = refiner.get_refined_counterfactual(limited_actions=5)

INFO:shap:num_full_subsets = 2
INFO:shap:phi = [ 0.10845291  0.32940638 -0.02482684  0.09001071 -0.01369091]
INFO:shap:num_full_subsets = 1
INFO:shap:phi = [ 0.40606964 -0.00211596]
INFO:shap:num_full_subsets = 1
INFO:shap:phi = [0.23691719 0.45963417]
INFO:shap:num_full_subsets = 3
INFO:shap:phi = [-0.0647144   0.29240912  0.04947976  0.06077738 -0.00799271  0.02456827
  0.10825887]
INFO:shap:num_full_subsets = 4
INFO:shap:phi = [-0.00231767 -0.06277694 -0.03997775  0.05313482  0.09225603  0.20100465
  0.14036763 -0.02769859]


#### 5. Highlight the generated counterfactuals

In [29]:
# highlight_changes() returns three objects; the display cell renders them as
# the factual table and two "old -> new" comparison tables.
# NOTE(review): `factual` is rebound once more here, so its meaning depends on
# which cells have been run.
factual, refine_ce, refine_ace = refiner.highlight_changes()

In [32]:
# Render the three result tables, each preceded by a plain-text label.
print("factual")
display(factual)

# Label typo fixed ("factaul" -> "factual").
print("factual -> corresponding counterfactual")
display(refine_ce)

print("factual -> action-limited counterfactual")
display(refine_ace)

factual


Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
0,31,1,2,1,1,2,1935,24,0,1
1,22,0,2,2,1,0,433,18,5,0
2,36,1,2,2,1,1,7297,60,0,1
3,29,0,1,2,3,0,1123,12,4,1
4,37,1,2,1,1,3,2100,18,5,1


factaul -> corresponding counterfactual


Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
0,31 -> 29,1,2,1,1,2 -> 0,1935 -> 2100,24 -> 18,0 -> 5,1 -> 0
1,22 -> 37,0,2,2,1 -> 3,0,433,18,5,0
2,36,1,2,2,1 -> 3,1,7297,60 -> 42,0,1 -> 0
3,29 -> 31,0,1 -> 2,2 -> 1,3 -> 1,0,1123 -> 1935,12 -> 24,4 -> 0,1 -> 0
4,37 -> 29,1 -> 0,2 -> 1,1,1 -> 3,3 -> 0,2100 -> 3989,18 -> 12,5 -> 4,1 -> 0


factual -> action-limited counterfactual


Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
0,31,1,2,1,1,2 -> 0,1935,24,0,1 -> 0
1,22 -> 37,0,2,2,1,0,433,18,5,0
2,36,1,2,2,1 -> 3,1,7297,60 -> 42,0,1 -> 0
3,29,0,1 -> 2,2,3,0,1123,12,4,1 -> 0
4,37,1,2,1,1,3,2100,18,5,1
