In [1]:
# Reload the autoreload extension and use mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [2]:
from xai_cola.ce_sparsifier import COLA
from xai_cola.ce_sparsifier.data import COLAData
from xai_cola.ce_sparsifier.models import Model
from xai_cola.ce_generator import DiCE, DisCount

Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)


#### 0. Select the factual instances ("Risk" = 1)

In [3]:
# Load the German Credit dataset through the `datasets.german_credit` wrapper.
from datasets.german_credit import GermanCreditDataset
dataset = GermanCreditDataset()
# Train/test split as provided by the dataset object; used below to fit the classifier pipeline.
X_train, y_train, X_test, y_test = dataset.get_original_train_test_split()
# Full frame including the 'Risk' column; the factual instances are sampled from it in the next cell.
df = dataset.get_dataframe()
print(df.head(3))

   Age  Sex  Job  Housing  Saving accounts  Checking account  Credit amount  \
0   67    1    2        1                0                 1           1169   
1   22    0    2        1                1                 2           5951   
2   49    1    1        1                1                 0           2096   

   Duration  Purpose  Risk  
0         6        5     0  
1        48        5     1  
2        12        3     0  


Downcasting behavior in `replace` is deprecated and will be removed in a future version. To retain the old behavior, explicitly call `result.infer_objects(copy=False)`. To opt-in to the future behavior, set `pd.set_option('future.no_silent_downcasting', True)`


In [4]:
# Pick 5 factual samples with Risk = 1.
# A fixed random_state makes the selection reproducible under Restart & Run All.
df_Risk_1 = df[df['Risk'] == 1].sample(n=5, random_state=42)

# Keep the target column for now: the full dataframe (features + label) is passed to COLAData.
feature_names = df_Risk_1.drop(columns=['Risk']).columns
print(df_Risk_1.head())

     Age  Sex  Job  Housing  Saving accounts  Checking account  Credit amount  \
842   23    0    2        1                1                 0           1943   
747   37    0    1        1                1                 1           1274   
972   29    0    0        2                1                 1           1193   
761   24    0    2        2                1                 1           2124   
4     53    1    2        0                1                 1           4870   

     Duration  Purpose  Risk  
842        18        6     1  
747        12        1     1  
972        24        1     1  
761        18        4     1  
4          24        1     1  


#### 1. Initialize data interface

In [5]:
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from lightgbm import LGBMClassifier
from xai_cola.ce_sparsifier.utils import create_pipeline_with_column_names, ensure_column_order

# Feature groups used by the preprocessing step.
numerical_features = ['Age', 'Housing', 'Credit amount', 'Duration']
categorical_features = ['Sex', 'Job', 'Saving accounts', 'Checking account', 'Purpose']

# Step 1: build the preprocessor.
preprocessor = ColumnTransformer(
    transformers=[
        ('num', StandardScaler(), numerical_features),  # standardize numerical features
        ('cat', OneHotEncoder(drop='first', handle_unknown='ignore'), categorical_features)  # one-hot encode categorical features
    ],
    remainder='passthrough'  # keep any other columns
)

# Step 2: create the LightGBM classifier (parameters can be customized).
lgbm_clf = LGBMClassifier(
    n_estimators=100,
    learning_rate=0.1,
    max_depth=5,
    random_state=42,
    verbose=-1  # silence training output
)

# Step 3: assemble the Pipeline.
pipe = Pipeline([
    ('preprocessor', preprocessor),
    ('classifier', lgbm_clf)
])

# Step 4: fit the Pipeline on the raw data.
# Categorical features are cast to str before fitting so the encoder learns string
# categories. The cast rebinds X_train / X_test to converted copies instead of
# assigning column-by-column, so the frames returned by the dataset object are not
# mutated in place.
# NOTE(review): later cells emit "Found unknown categories ... during transform"
# warnings, which suggests some inference data reaches the encoder in a non-string
# dtype; with handle_unknown='ignore' such values are encoded as all zeros - confirm.
X_train = X_train.astype({c: str for c in categorical_features})
X_test = X_test.astype({c: str for c in categorical_features})
pipe.fit(X_train, y_train)

Could not find the number of physical cores for the following reason:
[WinError 2] 系统找不到指定的文件。
  File "c:\Users\ZhuLi\miniconda3\envs\cola\lib\site-packages\joblib\externals\loky\backend\context.py", line 257, in _count_physical_cores
    cpu_info = subprocess.run(
  File "c:\Users\ZhuLi\miniconda3\envs\cola\lib\subprocess.py", line 503, in run
    with Popen(*popenargs, **kwargs) as process:
  File "c:\Users\ZhuLi\miniconda3\envs\cola\lib\subprocess.py", line 971, in __init__
    self._execute_child(args, executable, preexec_fn, close_fds,
  File "c:\Users\ZhuLi\miniconda3\envs\cola\lib\subprocess.py", line 1456, in _execute_child
    hp, ht, pid, tid = _winapi.CreateProcess(executable, args,


In [6]:
# Feature groups consumed by the COLAData and DiCE calls below.
# These repeat the lists already defined in the pipeline-training cell.
numerical_features = [
    'Age',
    'Housing',
    'Credit amount',
    'Duration',
]
categorical_features = [
    'Sex',
    'Job',
    'Saving accounts',
    'Checking account',
    'Purpose',
]

In [7]:
# Build the COLAData container from the sampled factual rows (features + 'Risk' label).
# Only the numerical features are declared here; the other columns are not listed explicitly.
data = COLAData(
    factual_data=df_Risk_1, 
    label_column='Risk',
    numerical_features=numerical_features
)
print(data)

COLAData(factual: 5 rows, features: 9, label: Risk, no counterfactual)


In [8]:
# Display the factual rows held by COLAData, including the 'Risk' label column.
data.get_factual_all()

Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
842,23,0,2,1,1,0,1943,18,6,1
747,37,0,1,1,1,1,1274,12,1,1
972,29,0,0,2,1,1,1193,24,1,1
761,24,0,2,2,1,1,2124,18,4,1
4,53,1,2,0,1,1,4870,24,1,1


#### 2. Initialize model interface

In [9]:
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder, StandardScaler
import joblib
# Wrap the fitted preprocessing + LightGBM pipeline (`pipe`, trained above) in the
# xai_cola Model interface using the "sklearn" backend.
# Disabled alternative: load a pre-trained classifier with
# joblib.load('lgbm_GremanCredit.pkl') and pass it to Model instead of the pipeline.

ml_model = Model(model=pipe, backend="sklearn")

#### 3. Choose the counterfactual explanation algorithm

In [10]:
from xai_cola.ce_generator import DiCE, DisCount, ARecourseS, AlibiCounterfactualInstances
# Generate counterfactuals with DiCE for the factual rows stored in `data`.
explainer = DiCE(ml_model=ml_model)
factual, counterfactual = explainer.generate_counterfactuals(
    data=data,
    factual_class=1,  # the factual rows were selected with Risk == 1
    total_cfs=1,  # presumably one counterfactual per factual instance - confirm
    features_to_keep=['Age','Sex'],  # presumably keeps these features fixed - confirm
    continuous_features=numerical_features
)

100%|██████████| 5/5 [00:00<00:00, 12.94it/s]


In [11]:
import pandas as pd
# Inspect the categories learned by the fitted one-hot encoder: for each categorical
# feature print the category array, its container type and the element types.
ohe = pipe.named_steps['preprocessor'].named_transformers_['cat']
for name, cats in zip(categorical_features, ohe.categories_):
    print(name, type(cats), cats)
    print([type(v) for v in cats[:20]])
# Check the categories of EVERY feature for missing values. Reading the loop
# variable after the loop would only inspect the last feature ('Purpose').
any(pd.isna(v) for cats in ohe.categories_ for v in cats)

Sex <class 'numpy.ndarray'> ['0' '1']
[<class 'str'>, <class 'str'>]
Job <class 'numpy.ndarray'> ['0' '1' '2' '3']
[<class 'str'>, <class 'str'>, <class 'str'>, <class 'str'>]
Saving accounts <class 'numpy.ndarray'> ['0' '1' '2' '3' '4']
[<class 'str'>, <class 'str'>, <class 'str'>, <class 'str'>, <class 'str'>]
Checking account <class 'numpy.ndarray'> ['0' '1' '2' '3']
[<class 'str'>, <class 'str'>, <class 'str'>, <class 'str'>]
Purpose <class 'numpy.ndarray'> ['0' '1' '2' '3' '4' '5' '6' '7']
[<class 'str'>, <class 'str'>, <class 'str'>, <class 'str'>, <class 'str'>, <class 'str'>, <class 'str'>, <class 'str'>]


False

In [12]:
# Display the factual frame returned by generate_counterfactuals().
factual

Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
842,23,0,2,1,1,0,1943,18,6,1
747,37,0,1,1,1,1,1274,12,1,1
972,29,0,0,2,1,1,1193,24,1,1
761,24,0,2,2,1,1,2124,18,4,1
4,53,1,2,0,1,1,4870,24,1,1


In [13]:
# Display the counterfactual frame returned by generate_counterfactuals().
counterfactual

Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
0,23,0,2,1,1,0,4845,18,6,0
1,37,0,0,1,1,1,1274,24,1,0
2,29,0,0,2,1,1,3539,24,1,0
3,24,0,2,2,1,0,2124,14,4,0
4,53,1,2,0,1,1,2631,24,1,0


In [14]:
# Summarize the COLAData object; no counterfactuals have been attached at this point.
data.summary()

{'factual_samples': 5,
 'feature_count': 9,
 'label_column': 'Risk',
 'all_columns': ['Age',
  'Sex',
  'Job',
  'Housing',
  'Saving accounts',
  'Checking account',
  'Credit amount',
  'Duration',
  'Purpose',
  'Risk'],
 'has_counterfactual': False}

#### 4. Choose the policy and limit the number of actions

In [15]:
# Attach the DiCE counterfactuals to the COLAData object before building the sparsifier.
# with_target_column=True: the counterfactual frame displayed above includes the 'Risk'
# column (presumably what this flag signals - confirm).
data.add_counterfactuals(counterfactual, with_target_column=True)

In [16]:
# Display the COLAData repr to confirm the counterfactual rows were attached.
data

COLAData(factual: 5 rows, features: 9, label: Risk, counterfactual: 5 rows)

In [17]:
# Summarize the COLAData object again, now that counterfactuals have been added.
data.summary()

{'factual_samples': 5,
 'feature_count': 9,
 'label_column': 'Risk',
 'all_columns': ['Age',
  'Sex',
  'Job',
  'Housing',
  'Saving accounts',
  'Checking account',
  'Credit amount',
  'Duration',
  'Purpose',
  'Risk'],
 'has_counterfactual': True,
 'counterfactual_samples': 5}

In [27]:
# Initialize COLA - it will automatically extract factual and counterfactual from data
sparsifier = COLA(
    data=data,
    ml_model=ml_model
)

# Configure the refinement policy. The printed confirmation labels matcher="ect"
# as "Exact Matching" and echoes the attributor and Avalues_method.
sparsifier.set_policy(
    matcher="ect",
    attributor="pshap",
    Avalues_method="max",
    random_state=1  # the only seed passed to the sparsifier in this notebook
)

Policy set: pshap with Exact Matching, Avalues_method: max


In [28]:
# 1. Check the dimensions of the data held by COLAData
print("=" * 50, "Data Dimensions Check", "=" * 50, sep="\n")
print(f"Original features: {len(numerical_features) + len(categorical_features)}")
print(f"Factual data shape (from COLAData): {data.get_factual_features().shape}")
print(f"Counterfactual data shape (from COLAData): {data.get_counterfactual_features().shape}")

# 2. Check the dimensions after the Pipeline's preprocessing step
sample_data = data.get_factual_features().iloc[:1]
transformed = pipe.named_steps['preprocessor'].transform(sample_data)
print(f"\nTransformed data shape: {transformed.shape}")
print(f"Features after preprocessing: {transformed.shape[1]}")

# 3. Check the dimensions of varphi (the attribution array).
#    Best-effort diagnostic: any failure is reported instead of stopping the cell.
try:
    varphi = sparsifier._get_attributor()
    print(f"\nvarphi shape: {varphi.shape}")
    print(f"varphi ndim: {varphi.ndim}")
except Exception as err:
    print(f"\nError getting varphi: {err}")

# 4. Check whether Model recognized the wrapped estimator as a Pipeline
print(f"\nIs pipeline: {ml_model.is_pipeline}")

Data Dimensions Check
Original features: 9
Factual data shape (from COLAData): (5, 9)
Counterfactual data shape (from COLAData): (5, 9)

Transformed data shape: (1, 22)
Features after preprocessing: 22


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

INFO:shap:num_full_subsets = 1
INFO:shap:phi = [-0.00390544  0.28015951]
INFO:shap:phi = [ 0.00390544 -0.28015951]


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

INFO:shap:num_full_subsets = 1
INFO:shap:phi = [-0.19210347 -0.35355248]
INFO:shap:phi = [0.19210347 0.35355248]


  0%|          | 0/1 [00:00<?, ?it/s]


varphi shape: (5, 1, 9)
varphi ndim: 3

Is pipeline: True


In [29]:
# Ask the sparsifier for the minimum number of actions; the value is printed and returned.
sparsifier.query_minimum_actions()

The minimum number of actions is 20


20

In [37]:
# Display the raw attribution array.
# NOTE(review): `_get_attributor` is underscore-prefixed (private by convention) and
# re-runs the SHAP computation each time it is called here.
sparsifier._get_attributor()

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

INFO:shap:num_full_subsets = 1
INFO:shap:phi = [-0.00390544  0.28015951]
INFO:shap:phi = [ 0.00390544 -0.28015951]


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

INFO:shap:num_full_subsets = 1
INFO:shap:phi = [-0.19210347 -0.35355248]
INFO:shap:phi = [0.19210347 0.35355248]


  0%|          | 0/1 [00:00<?, ?it/s]

array([[[1.00000000e-20, 1.00000000e-20, 1.00000000e-20, 1.00000000e-20,
         1.00000000e-20, 1.00000000e-20, 4.56272312e-02, 1.00000000e-20,
         1.00000000e-20]],

       [[1.00000000e-20, 1.00000000e-20, 2.38109453e-03, 1.00000000e-20,
         1.00000000e-20, 1.00000000e-20, 1.00000000e-20, 1.70809713e-01,
         1.00000000e-20]],

       [[1.00000000e-20, 1.00000000e-20, 1.00000000e-20, 1.00000000e-20,
         1.00000000e-20, 1.00000000e-20, 3.21321539e-01, 1.00000000e-20,
         1.00000000e-20]],

       [[1.00000000e-20, 1.00000000e-20, 1.00000000e-20, 1.00000000e-20,
         1.00000000e-20, 1.17123058e-01, 1.00000000e-20, 2.15556481e-01,
         1.00000000e-20]],

       [[1.00000000e-20, 1.00000000e-20, 1.00000000e-20, 1.00000000e-20,
         1.00000000e-20, 1.00000000e-20, 1.27180883e-01, 1.00000000e-20,
         1.00000000e-20]]])

In [36]:
# Control the number of allowed actions here.
# No seed is set in this cell; the only seed in this notebook is random_state=1,
# passed to set_policy() above.
# limited_actions=20 equals the minimum reported by query_minimum_actions() above.
refined_cf = sparsifier.get_refined_counterfactual(limited_actions=20)
refined_cf

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

INFO:shap:num_full_subsets = 1
INFO:shap:phi = [-0.00390544  0.28015951]
INFO:shap:phi = [ 0.00390544 -0.28015951]


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

INFO:shap:num_full_subsets = 1
INFO:shap:phi = [-0.19210347 -0.35355248]
INFO:shap:phi = [0.19210347 0.35355248]


  0%|          | 0/1 [00:00<?, ?it/s]

self.x_counterfactual: [[23 '0' '2' 1 '1' '0' 4845 18 '6']
 [37 '0' '0' 1 '1' '1' 1274 24 '1']
 [29 '0' '0' 2 '1' '1' 3539 24 '1']
 [24 '0' '2' 2 '1' '0' 2124 14 '4']
 [53 '1' '2' 0 '1' '1' 2631 24 '1']]
data composer q: [[2.300e+01 0.000e+00 2.000e+00 1.000e+00 1.000e+00 0.000e+00 4.845e+03
  1.800e+01 6.000e+00]
 [3.700e+01 0.000e+00 0.000e+00 1.000e+00 1.000e+00 1.000e+00 1.274e+03
  2.400e+01 1.000e+00]
 [2.900e+01 0.000e+00 0.000e+00 2.000e+00 1.000e+00 1.000e+00 3.539e+03
  2.400e+01 1.000e+00]
 [2.400e+01 0.000e+00 2.000e+00 2.000e+00 1.000e+00 0.000e+00 2.124e+03
  1.400e+01 4.000e+00]
 [5.300e+01 1.000e+00 2.000e+00 0.000e+00 1.000e+00 1.000e+00 2.631e+03
  2.400e+01 1.000e+00]]
corresponding_counterfactual: [[2.300e+01 0.000e+00 2.000e+00 1.000e+00 1.000e+00 0.000e+00 4.845e+03
  1.800e+01 6.000e+00]
 [3.700e+01 0.000e+00 0.000e+00 1.000e+00 1.000e+00 1.000e+00 1.274e+03
  2.400e+01 1.000e+00]
 [2.900e+01 0.000e+00 0.000e+00 2.000e+00 1.000e+00 1.000e+00 3.539e+03
  2.400e+01

Found unknown categories in columns [0, 1, 2, 3, 4] during transform. These unknown categories will be encoded as all zeros
Found unknown categories in columns [0, 1, 2, 3] during transform. These unknown categories will be encoded as all zeros
Found unknown categories in columns [0, 1, 2, 3] during transform. These unknown categories will be encoded as all zeros
Found unknown categories in columns [0, 1, 2, 3] during transform. These unknown categories will be encoded as all zeros
Found unknown categories in columns [0, 1, 2, 3] during transform. These unknown categories will be encoded as all zeros
Found unknown categories in columns [0, 1, 2, 3] during transform. These unknown categories will be encoded as all zeros
Found unknown categories in columns [0, 1, 2, 3] during transform. These unknown categories will be encoded as all zeros
Found unknown categories in columns [0, 1, 2, 3] during transform. These unknown categories will be encoded as all zeros
Found unknown categories in c

Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
842,23.0,0.0,2.0,1.0,1.0,0.0,4845.0,18.0,6,0
747,37.0,0.0,0.0,1.0,1.0,1.0,1274.0,24.0,1,0
972,29.0,0.0,0.0,2.0,1.0,1.0,3539.0,24.0,1,0
761,24.0,0.0,2.0,2.0,1.0,0.0,2124.0,14.0,4,0
4,53.0,1.0,2.0,0.0,1.0,1.0,2631.0,24.0,1,0


In [23]:
# Fetch three result objects in one call, bound here as the factual, counterfactual
# and refined-counterfactual frames.
# limited_actions=42 here, versus 20 in the get_refined_counterfactual() cell above.
factual_df, counterfactual_df, refined_cf_df = sparsifier.get_all_results(
    limited_actions=42,
    # Optional (disabled): presumably restricts which features may change - confirm.
    # features_to_vary=['Saving accounts','Checking account','Credit amount','Duration','Purpose']
    )

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

INFO:shap:num_full_subsets = 1
INFO:shap:phi = [-0.00390544  0.28015951]
INFO:shap:phi = [ 0.00390544 -0.28015951]


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

INFO:shap:num_full_subsets = 1
INFO:shap:phi = [-0.19210347 -0.35355248]
INFO:shap:phi = [0.19210347 0.35355248]


  0%|          | 0/1 [00:00<?, ?it/s]

Found unknown categories in columns [0, 1, 2, 3, 4] during transform. These unknown categories will be encoded as all zeros
Found unknown categories in columns [0, 1, 2, 3] during transform. These unknown categories will be encoded as all zeros
Found unknown categories in columns [0, 1, 2, 3] during transform. These unknown categories will be encoded as all zeros
Found unknown categories in columns [0, 1, 2, 3] during transform. These unknown categories will be encoded as all zeros
Found unknown categories in columns [0, 1, 2, 3] during transform. These unknown categories will be encoded as all zeros
Found unknown categories in columns [0, 1, 2, 3] during transform. These unknown categories will be encoded as all zeros
Found unknown categories in columns [0, 1, 2, 3] during transform. These unknown categories will be encoded as all zeros
Found unknown categories in columns [0, 1, 2, 3] during transform. These unknown categories will be encoded as all zeros
Found unknown categories in c

#### 5. Highlight the generated counterfactuals

In [24]:
# Display the original DiCE counterfactuals again for reference.
counterfactual

Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
0,23,0,2,1,1,0,4845,18,6,0
1,37,0,0,1,1,1,1274,24,1,0
2,29,0,0,2,1,1,3539,24,1,0
3,24,0,2,2,1,0,2124,14,4,0
4,53,1,2,0,1,1,2631,24,1,0


In [26]:
# Build the comparison views of the changes.
# NOTE(review): this rebinds `factual`, which previously held the factual frame
# returned by DiCE.
factual, ce, refine_ace = sparsifier.highlight_changes_comparison()
display(factual,ce,refine_ace)  # changed cells render as "old -> new", e.g. "1943 -> 4845.0"
# refine_ace.to_html('comparison.html')

      Age  Sex  Job Housing Saving accounts Checking account Credit amount  \
842  23.0  0.0  2.0     1.0             1.0              0.0        4845.0   
747  37.0  0.0  0.0     1.0             1.0              1.0        1274.0   
972  29.0  0.0  0.0     2.0             1.0              1.0        3539.0   
761  24.0  0.0  2.0     2.0             1.0              0.0        2124.0   
4    53.0  1.0  2.0     0.0             1.0              1.0        2631.0   

    Duration Purpose  Risk  
842     18.0     6.0     0  
747     24.0     1.0     0  
972     24.0     1.0     0  
761     14.0     4.0     0  
4       24.0     1.0     0  


Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
842,23,0,2,1,1,0,1943,18,6,1
747,37,0,1,1,1,1,1274,12,1,1
972,29,0,0,2,1,1,1193,24,1,1
761,24,0,2,2,1,1,2124,18,4,1
4,53,1,2,0,1,1,4870,24,1,1


Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
842,23.0,0 -> 0.0,2 -> 2.0,1.0,1 -> 1.0,0 -> 0.0,1943 -> 4845.0,18.000000,6 -> 6.0,1 -> 0
747,37.0,0 -> 0.0,1 -> 0.0,1.0,1 -> 1.0,1 -> 1.0,1274.000000,12 -> 24.0,1 -> 1.0,1 -> 0
972,29.0,0 -> 0.0,0 -> 0.0,2.0,1 -> 1.0,1 -> 1.0,1193 -> 3539.0,24.000000,1 -> 1.0,1 -> 0
761,24.0,0 -> 0.0,2 -> 2.0,2.0,1 -> 1.0,1 -> 0.0,2124.000000,18 -> 14.0,4 -> 4.0,1 -> 0
4,53.0,1 -> 1.0,2 -> 2.0,0.0,1 -> 1.0,1 -> 1.0,4870 -> 2631.0,24.000000,1 -> 1.0,1 -> 0


Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
842,23.0,0,2,1.0,1 -> 1.0,0,1943 -> 4845.0,18,6,1 -> 0
747,37.0,0 -> 0.0,1 -> 0.0,1.0,1,1 -> 1.0,1274,12 -> 24.0,1,1 -> 0
972,29.0,0 -> 0.0,0 -> 0.0,2.0,1,1,1193 -> 3539.0,24.000000,1,1 -> 0
761,24.0,0,2 -> 2.0,2.0,1,1 -> 0.0,2124,18 -> 14.0,4,1 -> 0
4,53.0,1,2,0.0,1 -> 1.0,1,4870 -> 2631.0,24.000000,1,1 -> 0


In [None]:
# Build the final-value views of the changes.
factual_df, ce_style, ace_style = sparsifier.highlight_changes_final()
display(ce_style,ace_style)  # shows only the final value, e.g. "1103"
# ce_style.to_html('final.html')

In [None]:
# Binary heatmap (shows whether a value changed)
sparsifier.heatmap_binary(save_path='./results', save_mode='combined',show_axis_labels=False)

# Directional heatmap (shows the direction of the change)
sparsifier.heatmap_direction(save_path='./results', save_mode='combined',show_axis_labels=False)

In [None]:
# Draw the stacked bar chart; save_path points at ./results (presumably where the figure is written - confirm).
fig = sparsifier.stacked_bar_chart(save_path='./results')

In [None]:
# Run the diversity analysis and display one styled result per entry in diversity_styles.
factual_df, diversity_styles = sparsifier.diversity()
for i, style in enumerate(diversity_styles):
    print(f"Instance {i+1} diversity:")  # 1-based label for readability
    display(style)

In [None]:
# Try out the diversity feature (based on the refined counterfactual).
# NOTE(review): this repeats the sparsifier.diversity() call from the previous cell.
factual_df, diversity_styles = sparsifier.diversity()

# Show the diversity analysis for the first instance only
print("Instance 1 diversity (based on refined counterfactual):")
display(diversity_styles[0])