In [1]:
# Reload edited modules (e.g. xai_cola) automatically before each cell executes,
# so library changes are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [5]:
from xai_cola.ce_sparsifier import COLA
from xai_cola.ce_sparsifier.data import COLAData
from xai_cola.ce_sparsifier.models import Model
from xai_cola.ce_generator import DiCE, DisCount

#### 0. Select the factual instances (rows with "Risk" = 1)

In [6]:
from datasets.german_credit import GermanCreditDataset

# Load the German Credit data as a single DataFrame (categorical columns appear
# integer-encoded, see the preview) and preview it with rich display.
dataset = GermanCreditDataset()
df = dataset.get_dataframe()
df.head(3)

   Age  Sex  Job  Housing  Saving accounts  Checking account  Credit amount  \
0   67    1    2        1                0                 1           1169   
1   22    0    2        1                1                 2           5951   
2   49    1    1        1                1                 0           2096   

   Duration  Purpose  Risk  
0         6        5     0  
1        48        5     1  
2        12        3     0  


Downcasting behavior in `replace` is deprecated and will be removed in a future version. To retain the old behavior, explicitly call `result.infer_objects(copy=False)`. To opt-in to the future behavior, set `pd.set_option('future.no_silent_downcasting', True)`


In [7]:
# Pick N_FACTUALS rows with Risk == 1 to explain. A fixed random_state makes the
# sample, and therefore every downstream result, reproducible on Restart & Run All.
N_FACTUALS = 5
SAMPLE_SEED = 42
df_Risk_1 = df[df['Risk'] == 1].sample(n=N_FACTUALS, random_state=SAMPLE_SEED)

# Keep the target column: the full frame (features + 'Risk') is passed to COLAData.
feature_names = df_Risk_1.drop(columns=['Risk']).columns
df_Risk_1

     Age  Sex  Job  Housing  Saving accounts  Checking account  Credit amount  \
951   24    1    2        1                1                 1           2145   
858   29    0    2        1                1                 1           3959   
846   68    1    2        2                0                 0           6761   
355   23    1    1        1                1                 2           1246   
761   24    0    2        2                1                 1           2124   

     Duration  Purpose  Risk  
951        36        0     1  
858        15        1     1  
846        18        1     1  
355        24        1     1  
761        18        4     1  


#### 1. Initialize data interface

In [8]:
# Wrap the factual rows in COLA's data interface. The label stays in the frame and
# is declared via `label_column`, so it is not listed among the features below.
data = COLAData(
    factual_data=df_Risk_1,
    label_column='Risk',
    transform=None,
    # 'Sex' and 'Job' are deliberately not listed (presumably treated as categorical).
    # 'Risk' is the label, not a feature, so it is excluded here.
    numerical_features=['Age', 'Housing', 'Saving accounts', 'Checking account',
                        'Credit amount', 'Duration', 'Purpose'],
)
data

COLAData(factual: 5 rows, features: 9, label: Risk, no counterfactual)


In [9]:
# Factual rows held by COLAData (the 'Risk' label column is included).
data.get_factual_all()

Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
951,24,1,2,1,1,1,2145,36,0,1
858,29,0,2,1,1,1,3959,15,1,1
846,68,1,2,2,0,0,6761,18,1,1
355,23,1,1,1,1,2,1246,24,1,1
761,24,0,2,2,1,1,2124,18,4,1


#### 2. Initialize model interface

In [10]:
import joblib

# SECURITY: joblib.load unpickles the file and can execute arbitrary code;
# only load model files that come from a trusted source.
MODEL_PATH = 'lgbm_GremanCredit.pkl'  # keep in sync with the file name on disk
lgbmcClassifier = joblib.load(MODEL_PATH)

# Wrap the loaded classifier in COLA's model interface using the sklearn backend.
ml_model = Model(model=lgbmcClassifier, backend="sklearn")

#### 3. Choose the counterfactual explanation algorithm

In [11]:
from xai_cola.ce_generator import (
    AlibiCounterfactualInstances,
    ARecourseS,
    DiCE,
    DisCount,
)

# Generate counterfactuals with DiCE: one per factual row (total_cfs=1), for factual
# rows of class 1 (presumably the class to move away from), keeping 'Age' and 'Sex'
# fixed.
explainer = DiCE(ml_model=ml_model)
factual, counterfactual = explainer.generate_counterfactuals(
    total_cfs=1,
    factual_class=1,
    features_to_keep=['Age', 'Sex'],
    data=data,
)

Could not find the number of physical cores for the following reason:
[WinError 2] 系统找不到指定的文件。
  File "c:\Users\ZhuLi\Miniconda3\envs\cola\lib\site-packages\joblib\externals\loky\backend\context.py", line 257, in _count_physical_cores
    cpu_info = subprocess.run(
  File "c:\Users\ZhuLi\Miniconda3\envs\cola\lib\subprocess.py", line 503, in run
    with Popen(*popenargs, **kwargs) as process:
  File "c:\Users\ZhuLi\Miniconda3\envs\cola\lib\subprocess.py", line 971, in __init__
    self._execute_child(args, executable, preexec_fn, close_fds,
  File "c:\Users\ZhuLi\Miniconda3\envs\cola\lib\subprocess.py", line 1456, in _execute_child
    hp, ht, pid, tid = _winapi.CreateProcess(executable, args,
100%|██████████| 5/5 [00:00<00:00, 13.00it/s]


In [12]:
# Factual rows as returned by DiCE (same rows and index labels as the COLAData factuals).
factual

Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
951,24,1,2,1,1,1,2145,36,0,1
858,29,0,2,1,1,1,3959,15,1,1
846,68,1,2,2,0,0,6761,18,1,1
355,23,1,1,1,1,2,1246,24,1,1
761,24,0,2,2,1,1,2124,18,4,1


In [13]:
# One DiCE counterfactual per factual row. NOTE: the frame has a fresh 0..n-1 index;
# in the saved output the rows line up with `factual` by position, not by index label.
counterfactual

Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
0,24,1,2,1,0,1,2145,27,0,0
1,29,0,2,1,1,1,5257,15,4,0
2,68,1,2,2,0,2,6761,30,1,0
3,23,1,1,1,1,2,5741,24,1,0
4,24,0,2,2,1,1,4507,18,4,0


In [14]:
# Metadata summary; 'has_counterfactual' is still False at this point.
data.summary()

{'factual_samples': 5,
 'feature_count': 9,
 'label_column': 'Risk',
 'all_columns': ['Age',
  'Sex',
  'Job',
  'Housing',
  'Saving accounts',
  'Checking account',
  'Credit amount',
  'Duration',
  'Purpose',
  'Risk'],
 'has_counterfactual': False}

#### 4. Choose a policy and limit the number of actions

In [15]:
# Attach the DiCE counterfactuals to the COLAData object before building COLA.
# with_target_column=True: the counterfactual frame also carries the 'Risk' column.
data.add_counterfactuals(counterfactual, with_target_column=True)

In [16]:
# The repr now also reports the attached counterfactual rows.
data

COLAData(factual: 5 rows, features: 9, label: Risk, counterfactual: 5 rows)

In [17]:
# Summary now reports has_counterfactual=True and the counterfactual sample count.
data.summary()

{'factual_samples': 5,
 'feature_count': 9,
 'label_column': 'Risk',
 'all_columns': ['Age',
  'Sex',
  'Job',
  'Housing',
  'Saving accounts',
  'Checking account',
  'Credit amount',
  'Duration',
  'Purpose',
  'Risk'],
 'has_counterfactual': True,
 'counterfactual_samples': 5}

In [20]:
# Build the sparsifier. COLA reads both the factual and the counterfactual rows
# from `data`, and uses the wrapped model for predictions.
sparsifier = COLA(data=data, ml_model=ml_model)

# Policy: Coarsened Exact Matching ("cem") to pair factuals with counterfactuals,
# "pshap" attribution, "max" aggregation of A-values, and a fixed random_state
# (presumably for reproducibility).
sparsifier.set_policy(
    attributor="pshap",
    matcher="cem",
    Avalues_method="max",
    random_state=1,
)

Policy set: pshap with Coarsened Exact Matching with Optimal Transport, Avalues_method: max


In [49]:
# Ask how many actions (feature changes) are needed when only these features may
# vary. Check the printed messages: some samples may not reach the target even
# when every available action is used.
actionable_features = ['Saving accounts', 'Checking account', 'Credit amount']
sparsifier.query_minimum_actions(features_to_vary=actionable_features)

Using all available actions (10) from the specified features.
Some samples may not reach the target prediction. Consider including more features.


10

In [43]:
""" Here! control the limited actions """
# Set random seed for reproducibilit

refined_cf = sparsifier.get_refined_counterfactual(limited_actions=10)
refined_cf

invalid value encountered in divide
invalid value encountered in divide
invalid value encountered in divide
invalid value encountered in divide
invalid value encountered in divide
invalid value encountered in divide
invalid value encountered in divide
invalid value encountered in divide
invalid value encountered in divide
INFO:shap:num_full_subsets = 4
INFO:shap:phi = [ 0.03864962 -0.02397926  0.00273289 -0.03637535  0.07869994  0.03143636
  0.13211411  0.32045498 -0.02995369]
INFO:shap:num_full_subsets = 4
INFO:shap:phi = [ 0.22050441  0.08125722  0.00231351 -0.04526418  0.06916251  0.07582101
 -0.01769772 -0.04156882  0.10545774]
INFO:shap:num_full_subsets = 4
INFO:shap:phi = [ 0.02325676 -0.03487256  0.00973301  0.02773512 -0.02764825  0.01529155
  0.23533257  0.14020705  0.03242187]
INFO:shap:num_full_subsets = 4
INFO:shap:phi = [ 0.04049728 -0.04470557 -0.02715674 -0.03411176  0.09687787 -0.04797131
  0.40558793  0.05047099  0.11682684]
INFO:shap:num_full_subsets = 4
INFO:shap:phi

nan


invalid value encountered in divide
invalid value encountered in divide
invalid value encountered in divide
invalid value encountered in divide
invalid value encountered in divide
invalid value encountered in divide


nan


invalid value encountered in divide
invalid value encountered in divide
invalid value encountered in divide


Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
951,24,1,2,1,1,1,2145,27,0,1
858,29,0,2,1,1,1,3959,15,0,1
846,68,1,2,2,0,0,6761,27,1,0
355,23,1,1,1,1,2,2145,24,1,1
761,24,0,2,2,1,1,2124,18,4,1


In [47]:
# Run the whole pipeline with only three actionable features.
# NOTE(review): query_minimum_actions reported only 10 available actions for these
# features, so limited_actions=15 presumably just means "use all of them"; confirm.
factual_df, counterfactual_df, refined_cf_df = sparsifier.get_all_results(
    features_to_vary=['Saving accounts', 'Checking account', 'Credit amount'],
    limited_actions=15,
)

invalid value encountered in divide
invalid value encountered in divide
invalid value encountered in divide
invalid value encountered in divide
invalid value encountered in divide
invalid value encountered in divide
invalid value encountered in divide
invalid value encountered in divide
invalid value encountered in divide
INFO:shap:num_full_subsets = 4
INFO:shap:phi = [ 0.03864962 -0.02397926  0.00273289 -0.03637535  0.07869994  0.03143636
  0.13211411  0.32045498 -0.02995369]
INFO:shap:num_full_subsets = 4
INFO:shap:phi = [ 0.22050441  0.08125722  0.00231351 -0.04526418  0.06916251  0.07582101
 -0.01769772 -0.04156882  0.10545774]
INFO:shap:num_full_subsets = 4
INFO:shap:phi = [ 0.02325676 -0.03487256  0.00973301  0.02773512 -0.02764825  0.01529155
  0.23533257  0.14020705  0.03242187]
INFO:shap:num_full_subsets = 4
INFO:shap:phi = [ 0.04049728 -0.04470557 -0.02715674 -0.03411176  0.09687787 -0.04797131
  0.40558793  0.05047099  0.11682684]
INFO:shap:num_full_subsets = 4
INFO:shap:phi

nan


invalid value encountered in divide
invalid value encountered in divide
invalid value encountered in divide
invalid value encountered in divide
invalid value encountered in divide
invalid value encountered in divide
invalid value encountered in divide
invalid value encountered in divide
invalid value encountered in divide


nan




#### 5. Highlight the changes in the generated counterfactuals

In [48]:
# Compare factuals with counterfactuals cell by cell; changed cells are rendered
# as "old -> new".
# NOTE(review): presumably refine_ce is the matched counterfactual and refine_ace the
# action-limited (refined) one. In the saved output every refine_ce row has the same
# target values (Age 24, Sex 1, ..., Risk 0), i.e. all factuals were matched to one
# counterfactual row. Age/Sex therefore appear changed there even though DiCE kept
# them fixed. refine_ace only changes the features_to_vary columns (plus Risk).
refine_factual, refine_ce, refine_ace = sparsifier.highlight_changes_comparison()
display(refine_ce,refine_ace)  # shows changes as "old -> new", e.g. "1553 -> 1103"
# Optional export: refine_ce.to_html('comparison.html')

Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
951,24,1,2,1,1 -> 0,1,2145,36 -> 27,0,1 -> 0
858,29 -> 24,0 -> 1,2,1,1 -> 0,1,3959 -> 2145,15 -> 27,1 -> 0,1 -> 0
846,68 -> 24,1,2,2 -> 1,0,0 -> 1,6761 -> 2145,18 -> 27,1 -> 0,1 -> 0
355,23 -> 24,1,1 -> 2,1,1 -> 0,2 -> 1,1246 -> 2145,24 -> 27,1 -> 0,1 -> 0
761,24,0 -> 1,2,2 -> 1,1 -> 0,1,2124 -> 2145,18 -> 27,4 -> 0,1 -> 0


Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
951,24,1,2,1,1 -> 0,1,2145,36,0,1
858,29,0,2,1,1 -> 0,1,3959 -> 2145,15,1,1 -> 0
846,68,1,2,2,0,0 -> 1,6761 -> 2145,18,1,1 -> 0
355,23,1,1,1,1 -> 0,2 -> 1,1246 -> 2145,24,1,1 -> 0
761,24,0,2,2,1 -> 0,1,2124 -> 2145,18,4,1 -> 0


In [34]:
# Final view: changed cells show only the new value (e.g. "1103" instead of
# "1553 -> 1103").
# The factual output gets its own name so it does not overwrite the `factual_df`
# DataFrame returned by get_all_results earlier.
final_factual, ce_style, ace_style = sparsifier.highlight_changes_final()
display(ce_style, ace_style)
# Optional export: ce_style.to_html('final.html')

Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
951,24,1,2,1,0,1,2145,27,0,0
858,24,1,2,1,0,1,2145,27,0,0
846,24,1,2,1,0,1,2145,27,0,0
355,24,1,2,1,0,1,2145,27,0,0
761,24,1,2,1,0,1,2145,27,0,0


Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
951,24,1,2,1,0,1,2145,27,0,0
858,29,0,2,1,0,1,2145,27,0,0
846,68,1,2,2,0,1,2145,27,0,0
355,23,1,1,1,0,1,2145,27,0,0
761,24,0,2,2,0,1,2145,27,4,0


In [None]:
# Visualise the changes as a heatmap. To also save the figures, e.g.:
#   sparsifier.heatmap(save_path='./results/', save_mode='separate')  # or save_mode='combined'
sparsifier.heatmap()