In [702]:
# Reload edited modules automatically before each cell executes (autoreload
# mode 2), so edits to project modules take effect without restarting the
# kernel. The extension must be (re)loaded before %autoreload is configured.
%reload_ext autoreload
%autoreload 2

#### 0. Select the factual instances (rows with "Risk" = 1)

In [703]:
from test_dataset.german_credit import GermanCreditDataset

# Load the German Credit data through the project's dataset wrapper,
# then preview the first three rows.
dataset = GermanCreditDataset()
df = dataset.get_dataframe()
df.head(n=3)

Downcasting behavior in `replace` is deprecated and will be removed in a future version. To retain the old behavior, explicitly call `result.infer_objects(copy=False)`. To opt-in to the future behavior, set `pd.set_option('future.no_silent_downcasting', True)`


Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
0,67,1,2,1,0,1,1169,6,5,0
1,22,0,2,1,1,2,5951,48,5,1
2,49,1,1,1,1,0,2096,12,3,0


In [704]:
# Factual instances to explain: rows whose target 'Risk' equals 1.
# A fixed random_state makes "Restart & Run All" draw the same rows every time;
# without it each run explains a different random subset.
N_FACTUALS = 5
SAMPLE_SEED = 42
df_Risk_1 = df[df['Risk'] == 1].sample(N_FACTUALS, random_state=SAMPLE_SEED)
df_Risk_1.head()

Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
192,36,1,2,1,1,2,3915,27,0,1
545,43,1,2,0,1,1,1333,24,1,1
653,42,1,3,1,2,2,8086,36,1,1
918,33,1,2,1,2,1,2359,24,4,1
935,30,1,3,1,2,2,1919,30,5,1


In [705]:
# Feature-only view of the sampled rows: same rows, target column removed.
df_without_target = df_Risk_1.drop('Risk', axis=1).copy()
df_without_target.head()

Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose
192,36,1,2,1,1,2,3915,27,0
545,43,1,2,0,1,1,1333,24,1
653,42,1,3,1,2,2,8086,36,1
918,33,1,2,1,2,1,2359,24,4
935,30,1,3,1,2,2,1919,30,5


In [706]:
# Column labels of the feature-only frame, in their original order.
feature_names = df_without_target.keys()
feature_names

Index(['Age', 'Sex', 'Job', 'Housing', 'Saving accounts', 'Checking account',
       'Credit amount', 'Duration', 'Purpose'],
      dtype='object')

#### 1. Initialize the data interface

In [707]:
from xai_cola import data_interface 
# Wrap the feature-only frame so the xai_cola components can consume it.
# NOTE(review): df_without_target no longer has a 'Risk' column, yet
# target_name='Risk' is passed — confirm PandasData expects the target name
# for a frame that does not contain the target column.
data1 = data_interface.PandasData(df_without_target, target_name='Risk')

#### 2. Initialize the model interface

In [708]:
import joblib

# Pre-trained classifier (presumably LightGBM, per the file name) used by the
# explainer and refiner below.
# NOTE: joblib.load unpickles the file and can execute arbitrary code —
# only load model files from a trusted source.
MODEL_PATH = 'lgbm_GremanCredit.pkl'  # existing file name, spelling included
lgbmcClassifier = joblib.load(MODEL_PATH)
print(f'----{MODEL_PATH} model has been loaded----')

----lgbm_GremanCredit.pkl model has been loaded----


In [709]:
from xai_cola import ml_model_interface

# Wrap the loaded classifier for xai_cola. The "sklearn" backend presumably
# relies on a scikit-learn-style predict API — TODO confirm.
ml_model1 = ml_model_interface.Model(backend="sklearn", model=lgbmcClassifier)

#### 3. Choose the counterfactual explanation algorithm

In [710]:
# NOTE(review): DiCE comes from a top-level `counterfactual_explainer` module,
# while the other components are imported from `xai_cola` — confirm this is
# the intended import path.
from counterfactual_explainer import DiCE
# Counterfactual generator operating on the wrapped model.
explainer = DiCE(ml_model=ml_model1)

In [711]:
# Ask the explainer for one counterfactual per factual row of class 1.
factual, counterfactual = explainer.generate_counterfactuals(
    data=data1,
    factual_class=1,
    total_cfs=1,
)
# Show both results so they can be compared row by row.
print('Factual:', factual)
print('Counterfactual:', counterfactual)

100%|██████████| 5/5 [00:00<00:00,  7.32it/s]

Factual: [[  36    1    2    1    1    2 3915   27    0]
 [  43    1    2    0    1    1 1333   24    1]
 [  42    1    3    1    2    2 8086   36    1]
 [  33    1    2    1    2    1 2359   24    4]
 [  30    1    3    1    2    2 1919   30    5]]
Counterfactual: [[  36    1    2    1    1    2 7203   27    0]
 [  43    1    2    0    1    1 5497   24    1]
 [  42    1    3    1    2    2 7928   36    1]
 [  39    1    2    1    2    1 2359   27    4]
 [  30    1    3    1    2    2 7942   30    5]]





#### 4. Choose a refinement policy and limit the number of actions

In [712]:
from xai_cola.counterfactual_limited_actions import COLA

# Refiner that post-processes the explainer's counterfactuals.
refiner = COLA(data=data1, ml_model=ml_model1,
               x_factual=factual, x_counterfactual=counterfactual)

# "ot" is reported as Optimal Transport matching and "pshap" as the attributor
# in this cell's printed output; Avalues_method is passed through as "max".
refiner.set_policy(matcher="ot", attributor="pshap", Avalues_method="max")

You choose the Policy: pshap With Optimal Transport Matching, Avalues_method is max


In [713]:
# Action budget passed to COLA. In the highlighted results later in this
# notebook, the action-limited counterfactuals change 5 feature values in total
# across all rows, which suggests a global budget — TODO confirm in COLA docs.
LIMITED_ACTIONS = 5
# Bind the result to a new name rather than rebinding `factual`: the COLA cell
# above reads `factual` as x_factual, so overwriting it here would make a
# re-run of that cell silently use a different object.
refined_factual, ce, ace = refiner.get_refined_counterfactual(limited_actions=LIMITED_ACTIONS)

INFO:shap:num_full_subsets = 3
INFO:shap:phi = [-0.21057749 -0.0274783   0.0594797   0.21461199  0.10654806  0.16918744]
INFO:shap:num_full_subsets = 3
INFO:shap:phi = [-0.03291602  0.2184701  -0.005328    0.45929694  0.01910381  0.01044684]
INFO:shap:num_full_subsets = 2
INFO:shap:phi = [ 0.23467718  0.31134386 -0.02635609 -0.00718227]
INFO:shap:num_full_subsets = 3
INFO:shap:phi = [ 0.07351582 -0.04075535  0.21608805  0.20190609  0.02129067  0.15928859]
INFO:shap:num_full_subsets = 4
INFO:shap:phi = [-0.03664901 -0.04346676 -0.08997685  0.00110699 -0.01324027  0.56155565
  0.01243991  0.02509277]


#### 5. Highlight the changes in the generated counterfactuals

In [714]:
# Produce display versions of the factual rows, the matched counterfactuals and
# the action-limited counterfactuals, where changed cells read "old -> new".
# NOTE(review): this rebinds `factual`, which the COLA cell (In [712]) reads as
# x_factual; re-running that cell afterwards would pass this object instead of
# the explainer's array. Consider giving this result a distinct name.
factual, refine_ce, refine_ace = refiner.highlight_changes()

In [715]:
# Show, in order: the factual rows, the matched full counterfactuals, and the
# action-limited counterfactuals returned by highlight_changes().
print("factual")
display(factual)
print("factual -> corresponding counterfactual")
display(refine_ce)
print("factual -> action-limited counterfactual")
display(refine_ace)

factual


Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
0,36,1,2,1,1,2,3915,27,0,1
1,43,1,2,0,1,1,1333,24,1,1
2,42,1,3,1,2,2,8086,36,1,1
3,33,1,2,1,2,1,2359,24,4,1
4,30,1,3,1,2,2,1919,30,5,1


factaul -> corresponding counterfactual


Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
0,36 -> 42,1,2 -> 3,1,1 -> 2,2,3915 -> 7928,27 -> 36,0 -> 1,1 -> 0
1,43 -> 39,1,2,0 -> 1,1 -> 2,1,1333 -> 2359,24 -> 27,1 -> 4,1 -> 0
2,42 -> 30,1,3,1,2,2,8086 -> 7942,36 -> 30,1 -> 5,1 -> 0
3,33 -> 36,1,2,1,2 -> 1,1 -> 2,2359 -> 7203,24 -> 27,4 -> 0,1 -> 0
4,30 -> 43,1,3 -> 2,1 -> 0,2 -> 1,2 -> 1,1919 -> 5497,30 -> 24,5 -> 1,1 -> 0


factual -> action-limited counterfactual


Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
0,36,1,2,1,1,2,3915,27,0,1
1,43,1,2,0 -> 1,1,1,1333 -> 2359,24,1,1 -> 0
2,42 -> 30,1,3,1,2,2,8086 -> 7942,36,1,1 -> 0
3,33,1,2,1,2,1,2359,24,4,1
4,30,1,3,1,2,2,1919 -> 5497,30,5,1 -> 0
