In [1]:
# Reload the IPython autoreload extension and switch it to mode 2, which
# re-imports all modules before each cell is executed.
%reload_ext autoreload
%autoreload 2

In [21]:
from xai_cola import data_interface 
from xai_cola import ml_model_interface
from counterfactual_explainer import DiCE,DisCount
from xai_cola.cola_policy.matching import CounterfactualExactMatchingPolicy
from xai_cola.counterfactual_limited_actions import COLA

#### 0. Pick the factual data (rows with "Risk" = 1)

In [None]:
# Instantiate the project's German Credit dataset wrapper and pull its data
# out as `df` (used as a pandas-style frame by the cells below).
from datasets.german_credit import GermanCreditDataset
dataset = GermanCreditDataset()
df = dataset.get_dataframe()
# Last expression of the cell: previews the first 3 rows.
df.head(3)

In [None]:
# Pick 5 samples with Risk = 1 to use as the factual instances.
# A fixed random_state makes the same rows come back on every
# "Restart Kernel -> Run All"; change the seed to draw a different sample.
df_Risk_1 = df[df['Risk'] == 1].sample(5, random_state=42)

# Drop the target column.
# Normally, the input data doesn't contain the target column.
df_without_target = df_Risk_1.drop(columns=['Risk']).copy()
feature_names = df_without_target.columns
# Last expression of the cell: previews the selected rows without the target.
df_without_target.head()

#### 1. Initialize data interface

In [29]:
# Wrap the selected factual rows in the library's pandas data interface.
# NOTE(review): 'Risk' is passed as target_name although that column was
# dropped from df_without_target above — presumably PandasData accepts a
# frame without the target column; confirm against its documentation.
data = data_interface.PandasData(df_without_target, target_name='Risk')

#### 2. Initialize model interface

In [30]:
# Load a serialized model from disk and wrap it in the library's model
# interface using the "sklearn" backend.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code — only load model files that come from a trusted source.
# NOTE(review): the path is relative to the notebook's working directory and
# the file name is spelled 'Greman'; confirm it matches the file on disk.
import joblib
lgbmcClassifier = joblib.load('lgbm_GremanCredit.pkl')
ml_model = ml_model_interface.Model(model=lgbmcClassifier, backend="sklearn")

#### 3. Choose the counterfactual explanation algorithm

In [None]:
from counterfactual_explainer import DiCE,DisCount,ARecourseS,AlibiCounterfactualInstances

# Build a DiCE explainer around the wrapped model.
explainer = DiCE(ml_model=ml_model)

# Ask the explainer for the factual / counterfactual pair.
# Presumably total_cfs is the number of counterfactuals per factual row and
# features_to_keep lists the features that must not change — confirm against
# the explainer's documentation.
factual, counterfactual = explainer.generate_counterfactuals(
    data=data,
    factual_class=1,
    total_cfs=1,
    features_to_keep=['Age', 'Sex'],
)

In [None]:
# NOTE(review): this repeats the generate_counterfactuals call from the
# previous cell with identical arguments. Running both cells does the work
# twice and rebinds `factual` / `counterfactual` to the second result (which
# may differ if the explainer is stochastic — confirm). Consider keeping only
# one of the two cells.
factual, counterfactual = explainer.generate_counterfactuals(data=data,
                                                             factual_class=1,
                                                             total_cfs=1,
                                                             features_to_keep=['Age','Sex'])

In [None]:
import numpy as np

# Render both arrays with numpy's array2string and emit them in one print
# call: each argument lands on its own line, and the empty string yields the
# blank spacer line between the two sections.
print(
    'factual',
    np.array2string(factual, separator=' ', suppress_small=True),
    '',
    'counterfactual',
    np.array2string(counterfactual, separator='   ', suppress_small=True),
    sep='\n',
)


#### 4. Choose the policy and limit the number of actions

In [None]:
from xai_cola.counterfactual_limited_actions import COLA

# Create the COLA refiner from the data interface, the wrapped model and the
# factual / counterfactual pair produced by the explainer.
refiner = COLA(
    data=data,
    ml_model=ml_model,
    x_factual=factual,
    x_counterfactual=counterfactual,
)

# Select the refinement policy by name: which matcher, which attributor, and
# which Avalues_method the refiner should use.
refiner.set_policy(
    matcher="ect",
    attributor="pshap",
    Avalues_method="max",
)

In [None]:
""" Here! control the limited actions """
# limited_actions is the value to tune in this cell; it is hard-coded to 4.
# NOTE(review): this rebinds `factual`, which until now held the explainer's
# output. Cells above that read `factual` (the array printout and the COLA
# constructor) will see this new value if they are re-run after this cell.
factual, ce, ace = refiner.get_refined_counterfactual(limited_actions=4)

#### 5. Highlight the generated counterfactuals

In [36]:
# Unpack the three objects returned by highlight_changes(); going by the
# names, these are highlighted views of the factual, the counterfactual and
# the action-limited counterfactual (confirm against the COLA documentation).
# refine_factual is not read by any cell visible below.
refine_factual, refine_ce, refine_ace = refiner.highlight_changes()

In [None]:
# Store the result of query_minimum_actions(). Because this is an assignment
# the cell shows no value, and `actions` is not read by any cell visible here.
actions = refiner.query_minimum_actions()

In [None]:
# Show the factual rows followed by the highlighted counterfactual and the
# highlighted action-limited counterfactual, each under a plain-text label.
# `display` is the IPython rich-display function available in notebook cells.
print("factual")
display(factual)
print("factual -> corresponding counterfactual")
display(refine_ce)
print("factual -> action-limited counterfactual")
display(refine_ace)

In [None]:
# Bare last expression: renders refine_ace as this cell's output. The same
# object is already shown by display(refine_ace) in the previous cell.
refine_ace

In [None]:
# Call the refiner's heatmap method. As the cell's last expression, any
# non-None return value is also echoed as the cell output.
refiner.heatmap()