In [1]:
# Automatically reload edited modules (e.g. xai_cola) so source changes are picked up
# without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [2]:
from xai_cola.ce_sparsifier import COLA
from xai_cola.ce_sparsifier.data import COLAData
from xai_cola.ce_sparsifier.models import Model
from xai_cola.ce_generator import DiCE, DisCount

Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)


#### 0. Pick the factual samples (rows with "Risk" = 1)

In [31]:
from datasets.german_credit import GermanCreditDataset

# Load the German Credit data. The original train/test split is used to fit the model;
# the full dataframe is the pool the factual rows are sampled from.
dataset = GermanCreditDataset()
X_train, y_train, X_test, y_test = dataset.get_original_train_test_split()
df = dataset.get_dataframe()
preview = df.head(3)
print(preview)

   Age  Sex  Job  Housing  Saving accounts  Checking account  Credit amount  \
0   67    1    2        1                0                 1           1169   
1   22    0    2        1                1                 2           5951   
2   49    1    1        1                1                 0           2096   

   Duration  Purpose  Risk  
0         6        5     0  
1        48        5     1  
2        12        3     0  


Downcasting behavior in `replace` is deprecated and will be removed in a future version. To retain the old behavior, explicitly call `result.infer_objects(copy=False)`. To opt-in to the future behavior, set `pd.set_option('future.no_silent_downcasting', True)`


In [32]:
# Pick N_FACTUALS rows with Risk = 1 to serve as the factual instances.
# A fixed random_state makes the selection (and every downstream result) reproducible.
N_FACTUALS = 10
SAMPLE_SEED = 42
df_Risk_1 = df[df['Risk'] == 1].sample(N_FACTUALS, random_state=SAMPLE_SEED)

# Keep the target column for now: the full dataframe (features + 'Risk') is passed to COLAData.
feature_names = df_Risk_1.drop(columns=['Risk']).columns
print(df_Risk_1.head())

     Age  Sex  Job  Housing  Saving accounts  Checking account  Credit amount  \
677   24    1    2        1                2                 2           5595   
528   31    1    2        2                1                 1           2302   
240   29    0    2        1                0                 1            915   
398   46    1    2        2                1                 2           1223   
611   48    0    1        0                2                 3           1240   

     Duration  Purpose  Risk  
677        72        5     1  
528        36        5     1  
240        24        1     1  
398        12        1     1  
611        10        1     1  


#### 1. Train the model pipeline and initialize the data interface

In [49]:
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from lightgbm import LGBMClassifier
from xai_cola.ce_sparsifier.utils import create_pipeline_with_column_names, ensure_column_order

# Feature lists, split by how each column is preprocessed inside the pipeline.
numerical_features = ['Age', 'Housing', 'Credit amount', 'Duration']
categorical_features = ['Sex', 'Job', 'Saving accounts', 'Checking account', 'Purpose']

# Build a preprocessing + LightGBM pipeline. The pipeline contains a StandardScaler and a
# OneHotEncoder (see the sklearn warnings further down). `reference_dataframe` is presumably
# used to fix the expected column order — TODO confirm in xai_cola.ce_sparsifier.utils.
pipe, feature_order = create_pipeline_with_column_names(
    numerical_features=numerical_features,
    categorical_features=categorical_features,
    classifier=LGBMClassifier(random_state=42, verbose=-1),
    reference_dataframe=X_train
)

# Fit the whole pipeline on the raw (unscaled, unencoded) training data.
pipe.fit(X_train, y_train)

In [50]:
# The feature lists are defined once, in the pipeline cell above; redefining them here would
# let two copies drift apart. Instead, check that they partition the dataset's feature columns.
assert not set(numerical_features) & set(categorical_features), \
    "a feature is listed as both numerical and categorical"
assert set(numerical_features) | set(categorical_features) == set(feature_names), \
    "feature lists do not cover exactly the dataset's feature columns"

In [51]:
# Wrap the sampled factual rows (features + 'Risk' label) in COLA's data interface.
# Only the numerical columns are named; the remaining feature columns are presumably
# treated as categorical by COLAData — TODO confirm.
data = COLAData(
    factual_data=df_Risk_1,
    numerical_features=numerical_features,
    label_column='Risk',
)
print(data)

COLAData(factual: 10 rows, features: 9, label: Risk, no counterfactual)


In [52]:
# Display the factual rows held by COLAData, including the 'Risk' label column.
data.get_factual_all()

Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
677,24,1,2,1,2,2,5595,72,5,1
528,31,1,2,2,1,1,2302,36,5,1
240,29,0,2,1,0,1,915,24,1,1
398,46,1,2,2,1,2,1223,12,1,1
611,48,0,1,0,2,3,1240,10,1,1
780,25,1,2,1,1,2,4933,39,5,1
302,37,1,1,1,0,3,1344,24,1,1
435,25,1,2,1,0,2,1484,12,5,1
273,28,1,2,1,1,2,3060,48,5,1
727,25,0,2,2,1,1,1882,18,5,1


#### 2. Initialize model interface

In [53]:
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder, StandardScaler
import joblib

# Wrap the fitted preprocessing + classifier pipeline (`pipe`, fit on raw X_train) in COLA's
# model interface, so the CE generators and COLA can query it with the raw feature columns.
# An earlier experiment loaded a bare pickled LGBM classifier with joblib instead; the full
# pipeline is used here so that the preprocessing travels with the model.
ml_model = Model(model=pipe, backend="sklearn")

#### 3. Choose the counterfactual explanation algorithm

In [54]:
from xai_cola.ce_generator import DiCE, DisCount, ARecourseS, AlibiCounterfactualInstances

# Generate one DiCE counterfactual per factual row (all factual rows have Risk = 1),
# keeping the listed features unchanged.
explainer = DiCE(ml_model=ml_model)
immutable_features = ['Age', 'Sex']
factual, counterfactual = explainer.generate_counterfactuals(
    data=data,
    factual_class=1,
    total_cfs=1,
    features_to_keep=immutable_features,
)

100%|██████████| 10/10 [00:00<00:00, 11.05it/s]


In [55]:
# Factual rows returned by the generator (same rows and index as the sampled data).
factual

Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
677,24,1,2,1,2,2,5595,72,5,1
528,31,1,2,2,1,1,2302,36,5,1
240,29,0,2,1,0,1,915,24,1,1
398,46,1,2,2,1,2,1223,12,1,1
611,48,0,1,0,2,3,1240,10,1,1
780,25,1,2,1,1,2,4933,39,5,1
302,37,1,1,1,0,3,1344,24,1,1
435,25,1,2,1,0,2,1484,12,5,1
273,28,1,2,1,1,2,3060,48,5,1
727,25,0,2,2,1,1,1882,18,5,1


In [56]:
# One DiCE counterfactual per factual row. Note the index is 0..9 rather than the factual
# row index, and 'Risk' is 0 for every row.
counterfactual

Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
0,24,1,2,1,0,2,3754,72,5,0
1,31,1,2,2,1,1,2302,10,5,0
2,29,0,2,1,0,3,2059,24,1,0
3,46,1,2,2,1,2,2977,37,1,0
4,48,0,1,0,2,3,5301,10,1,0
5,25,1,2,1,0,3,4933,39,5,0
6,37,1,1,0,0,3,1838,24,1,0
7,25,1,2,1,0,2,3775,12,5,0
8,28,1,2,1,1,2,5455,20,5,0
9,25,0,2,2,1,1,3741,18,3,0


In [57]:
# Summary before counterfactuals are attached ('has_counterfactual': False).
data.summary()

{'factual_samples': 10,
 'feature_count': 9,
 'label_column': 'Risk',
 'all_columns': ['Age',
  'Sex',
  'Job',
  'Housing',
  'Saving accounts',
  'Checking account',
  'Credit amount',
  'Duration',
  'Purpose',
  'Risk'],
 'has_counterfactual': False}

#### 4. Choose the policy and limit the number of actions

In [58]:
# Attach the DiCE counterfactuals to the COLAData object before creating COLA, which reads
# both the factual and the counterfactual rows from it. `with_target_column=True` because
# `counterfactual` still contains the 'Risk' column.
data.add_counterfactuals(counterfactual, with_target_column=True)

In [59]:
# The repr now also reports the attached counterfactual rows.
data

COLAData(factual: 10 rows, features: 9, label: Risk, counterfactual: 10 rows)

In [60]:
# Summary after attaching counterfactuals ('has_counterfactual': True, plus their count).
data.summary()

{'factual_samples': 10,
 'feature_count': 9,
 'label_column': 'Risk',
 'all_columns': ['Age',
  'Sex',
  'Job',
  'Housing',
  'Saving accounts',
  'Checking account',
  'Credit amount',
  'Duration',
  'Purpose',
  'Risk'],
 'has_counterfactual': True,
 'counterfactual_samples': 10}

In [61]:
# COLA pulls both the factual and the counterfactual rows from `data`.
sparsifier = COLA(data=data, ml_model=ml_model)

# Policy: optimal-transport ("ot") matching, PSHAP attribution, A-values method "max".
# random_state is the seed passed to COLA.
policy = dict(
    matcher="ot",
    attributor="pshap",
    Avalues_method="max",
    random_state=25,
)
sparsifier.set_policy(**policy)

Policy set: pshap with Optimal Transport Matching, Avalues_method: max


In [63]:
# Diagnostic cell: check the dimensions flowing through the data interface, the pipeline
# and the attribution step.

# 1. Data dimensions as seen by COLAData
print("=" * 50)
print("Data Dimensions Check")
print("=" * 50)
print(f"Original features: {len(numerical_features + categorical_features)}")
print(f"Factual data shape (from COLAData): {data.get_factual_features().shape}")
print(f"Counterfactual data shape (from COLAData): {data.get_counterfactual_features().shape}")

# 2. Dimensions after the pipeline's preprocessing step (one-hot encoding widens the matrix)
sample_data = data.get_factual_features().head(1)
transformed = pipe.named_steps['preprocessor'].transform(sample_data)
print(f"\nTransformed data shape: {transformed.shape}")
print(f"Features after preprocessing: {transformed.shape[1]}")

# 3. Shape of the attribution matrix varphi. `_get_attributor` is a private COLA method, so
# this check is best-effort: report any failure (with its type) instead of aborting the cell.
try:
    varphi = sparsifier._get_attributor()
    print(f"\nvarphi shape: {varphi.shape}")
    print(f"varphi ndim: {varphi.ndim}")
    if varphi.ndim != 2:
        # NOTE(review): a 3-D varphi such as (10, 2, 9) looks like per-class SHAP values from
        # a binary classifier; query_minimum_actions() failed with "too many values to unpack
        # (expected 2)" in the recorded run — likely related, TODO confirm.
        print(f"WARNING: varphi is {varphi.ndim}-D; downstream code may expect 2-D (samples x features)")
except Exception as e:
    print(f"\nError getting varphi ({type(e).__name__}): {e}")

# 4. Confirm the Model wrapper recognised the sklearn Pipeline
print(f"\nIs pipeline: {ml_model.is_pipeline}")

X does not have valid feature names, but StandardScaler was fitted with feature names
X does not have valid feature names, but OneHotEncoder was fitted with feature names
X does not have valid feature names, but StandardScaler was fitted with feature names
X does not have valid feature names, but OneHotEncoder was fitted with feature names
INFO:shap:num_full_subsets = 2
X does not have valid feature names, but StandardScaler was fitted with feature names
X does not have valid feature names, but OneHotEncoder was fitted with feature names
INFO:shap:phi = [-6.21055141e-02  6.74526701e-02  4.55364912e-18 -6.87386406e-01]
INFO:shap:phi = [ 6.21055141e-02 -6.74526701e-02  6.50521303e-19  6.87386406e-01]
X does not have valid feature names, but StandardScaler was fitted with feature names
X does not have valid feature names, but OneHotEncoder was fitted with feature names
X does not have valid feature names, but StandardScaler was fitted with feature names
X does not have valid feature names

Data Dimensions Check
Original features: 9
Factual data shape (from COLAData): (10, 9)
Counterfactual data shape (from COLAData): (10, 9)

Transformed data shape: (1, 22)
Features after preprocessing: 22

varphi shape: (10, 2, 9)
varphi ndim: 3

Is pipeline: True


In [62]:
# Presumably reports the minimum number of feature changes needed to flip the predictions.
# NOTE(review): in the recorded run this raised "ValueError: too many values to unpack
# (expected 2)"; the dimension-check cell reported a 3-D varphi of shape (10, 2, 9) — likely
# related, TODO confirm.
sparsifier.query_minimum_actions()

X does not have valid feature names, but StandardScaler was fitted with feature names
X does not have valid feature names, but OneHotEncoder was fitted with feature names
X does not have valid feature names, but StandardScaler was fitted with feature names
X does not have valid feature names, but OneHotEncoder was fitted with feature names
INFO:shap:num_full_subsets = 2
X does not have valid feature names, but StandardScaler was fitted with feature names
X does not have valid feature names, but OneHotEncoder was fitted with feature names
INFO:shap:phi = [-6.21055141e-02  6.74526701e-02  4.55364912e-18 -6.87386406e-01]
INFO:shap:phi = [ 6.21055141e-02 -6.74526701e-02  6.50521303e-19  6.87386406e-01]
X does not have valid feature names, but StandardScaler was fitted with feature names
X does not have valid feature names, but OneHotEncoder was fitted with feature names
X does not have valid feature names, but StandardScaler was fitted with feature names
X does not have valid feature names

ValueError: too many values to unpack (expected 2)

In [49]:
# Here! control the limited actions.
# `limited_actions` is presumably the total budget of feature changes the refined
# counterfactuals may use — TODO confirm against COLA's documentation.
# No random seed is set in this cell; COLA's seed is the `random_state` passed to `set_policy`.
LIMITED_ACTIONS = 52

refined_cf = sparsifier.get_refined_counterfactual(limited_actions=LIMITED_ACTIONS)
refined_cf

INFO:shap:num_full_subsets = 3
INFO:shap:phi = [-0.03707051 -0.13978683  0.07460411  0.19879756 -0.0754925   0.11187736]
INFO:shap:num_full_subsets = 4
INFO:shap:phi = [ 0.05091269  0.04994483  0.10267253  0.1444564  -0.19084643  0.53886165
 -0.03983242  0.01731757]
INFO:shap:num_full_subsets = 3
INFO:shap:phi = [ 0.08041051 -0.04309788 -0.08829535  0.01530427 -0.23562114  0.38311789
  0.03888616]
INFO:shap:num_full_subsets = 1
INFO:shap:phi = [0.20631759 0.48375949]
INFO:shap:num_full_subsets = 3
INFO:shap:phi = [-0.16479924 -0.0422279  -0.03383711 -0.17210329  0.32182552  0.10074208
 -0.13131242]
INFO:shap:num_full_subsets = 3
INFO:shap:phi = [-0.09787612  0.03926393  0.20765632  0.14388528  0.39831383 -0.14938047
  0.04482932]
INFO:shap:num_full_subsets = 3
INFO:shap:phi = [-0.02215658 -0.07842826  0.03718402  0.12196376  0.34182635  0.31532637
 -0.13496945]
INFO:shap:num_full_subsets = 4
INFO:shap:phi = [ 0.18919298  0.02669583  0.02945507 -0.16654265  0.24992408  0.07299589
  0.11

Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
601,30,0,2,2,1,1,2103,9,5,0
602,34,0,2,1,0,1,3657,42,5,0
169,31,0,1,1,1,2,5163,24,0,0
332,24,0,1,1,2,2,7408,12,1,0
76,29,0,0,2,1,1,5481,24,1,0
727,31,1,2,1,1,2,3871,24,0,0
15,32,0,2,1,1,1,3154,9,4,0
557,34,0,2,0,1,2,7217,21,3,0
972,20,0,2,1,1,1,3095,12,5,0
188,29,0,2,1,0,0,674,12,1,0


In [50]:
# Collect the factual rows, the original counterfactuals and the refined counterfactuals
# under the same action budget (52) as the cell above.
# Optionally restrict which features may change, e.g.:
#   features_to_vary=['Saving accounts','Checking account','Credit amount','Duration','Purpose']
factual_df, counterfactual_df, refined_cf_df = sparsifier.get_all_results(limited_actions=52)

INFO:shap:num_full_subsets = 3
INFO:shap:phi = [-0.03707051 -0.13978683  0.07460411  0.19879756 -0.0754925   0.11187736]
INFO:shap:num_full_subsets = 4
INFO:shap:phi = [ 0.05091269  0.04994483  0.10267253  0.1444564  -0.19084643  0.53886165
 -0.03983242  0.01731757]
INFO:shap:num_full_subsets = 3
INFO:shap:phi = [ 0.08041051 -0.04309788 -0.08829535  0.01530427 -0.23562114  0.38311789
  0.03888616]
INFO:shap:num_full_subsets = 1
INFO:shap:phi = [0.20631759 0.48375949]
INFO:shap:num_full_subsets = 3
INFO:shap:phi = [-0.16479924 -0.0422279  -0.03383711 -0.17210329  0.32182552  0.10074208
 -0.13131242]
INFO:shap:num_full_subsets = 3
INFO:shap:phi = [-0.09787612  0.03926393  0.20765632  0.14388528  0.39831383 -0.14938047
  0.04482932]
INFO:shap:num_full_subsets = 3
INFO:shap:phi = [-0.02215658 -0.07842826  0.03718402  0.12196376  0.34182635  0.31532637
 -0.13496945]
INFO:shap:num_full_subsets = 4
INFO:shap:phi = [ 0.18919298  0.02669583  0.02945507 -0.16654265  0.24992408  0.07299589
  0.11

#### 5. Highlight the generated counterfactuals

In [51]:
# Styled tables where each changed cell is rendered as "old -> new".
refine_factual, refine_ce, refine_ace = sparsifier.highlight_changes_comparison()
display(refine_ce,refine_ace)  # shows changes as "old -> new", e.g. "1553 -> 1103"
# refine_ce.to_html('comparison.html')

Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
601,30 -> 25,0,2,1 -> 2,1,2 -> 1,918 -> 2103,9 -> 12,4 -> 5,1 -> 0
602,34,0 -> 1,1 -> 2,0 -> 1,1 -> 0,2 -> 1,1837 -> 3657,24 -> 42,3 -> 5,1 -> 0
169,31 -> 32,1 -> 0,2 -> 1,1,1 -> 2,2 -> 1,1935 -> 5163,24,0 -> 5,1 -> 0
332,24,0,3 -> 1,1,2,2,7408,60 -> 12,1,1 -> 0
76,34 -> 29,1 -> 0,2 -> 0,1 -> 2,1,1,3965 -> 5481,42 -> 24,5 -> 1,1 -> 0
727,25 -> 31,0 -> 1,2,2 -> 1,1,1 -> 2,1882 -> 3871,18 -> 24,5 -> 0,1 -> 0
15,32 -> 30,0,1 -> 2,1,2 -> 1,1 -> 2,1282 -> 3154,24 -> 9,5 -> 4,1 -> 0
557,29 -> 34,0,2 -> 1,1 -> 0,0 -> 1,0 -> 2,5003 -> 7217,21 -> 24,1 -> 3,1 -> 0
972,29 -> 20,0 -> 1,0 -> 2,2 -> 1,1 -> 2,1,1193 -> 3095,24 -> 12,1 -> 5,1 -> 0
188,20 -> 29,1 -> 0,2,1,2 -> 0,1 -> 0,674 -> 677,12 -> 21,5 -> 1,1 -> 0


Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
601,30,0,2,1 -> 2,1,2 -> 1,918 -> 2103,9,4 -> 5,1 -> 0
602,34,0,1 -> 2,0 -> 1,1 -> 0,2 -> 1,1837 -> 3657,24 -> 42,3 -> 5,1 -> 0
169,31,1 -> 0,2 -> 1,1,1,2,1935 -> 5163,24,0,1 -> 0
332,24,0,3 -> 1,1,2,2,7408,60 -> 12,1,1 -> 0
76,34 -> 29,1 -> 0,2 -> 0,1 -> 2,1,1,3965 -> 5481,42 -> 24,5 -> 1,1 -> 0
727,25 -> 31,0 -> 1,2,2 -> 1,1,1 -> 2,1882 -> 3871,18 -> 24,5 -> 0,1 -> 0
15,32,0,1 -> 2,1,2 -> 1,1,1282 -> 3154,24 -> 9,5 -> 4,1 -> 0
557,29 -> 34,0,2,1 -> 0,0 -> 1,0 -> 2,5003 -> 7217,21,1 -> 3,1 -> 0
972,29 -> 20,0,0 -> 2,2 -> 1,1,1,1193 -> 3095,24 -> 12,1 -> 5,1 -> 0
188,20 -> 29,1 -> 0,2,1,2 -> 0,1 -> 0,674,12,5 -> 1,1 -> 0


In [52]:
# Styled tables showing only the final values. Note: this rebinds `factual_df`, which was
# previously assigned by the get_all_results cell.
factual_df, ce_style, ace_style = sparsifier.highlight_changes_final()
display(ce_style,ace_style)  # shows only the new value, e.g. "1103" instead of "1553 -> 1103"
# ce_style.to_html('final.html')

Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
601,25,0,2,2,1,1,2103,12,5,0
602,34,1,2,1,0,1,3657,42,5,0
169,32,0,1,1,2,1,5163,24,5,0
332,24,0,1,1,2,2,7408,12,1,0
76,29,0,0,2,1,1,5481,24,1,0
727,31,1,2,1,1,2,3871,24,0,0
15,30,0,2,1,1,2,3154,9,4,0
557,34,0,1,0,1,2,7217,24,3,0
972,20,1,2,1,2,1,3095,12,5,0
188,29,0,2,1,0,0,677,21,1,0


Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
601,30,0,2,2,1,1,2103,9,5,0
602,34,0,2,1,0,1,3657,42,5,0
169,31,0,1,1,1,2,5163,24,0,0
332,24,0,1,1,2,2,7408,12,1,0
76,29,0,0,2,1,1,5481,24,1,0
727,31,1,2,1,1,2,3871,24,0,0
15,32,0,2,1,1,1,3154,9,4,0
557,34,0,2,0,1,2,7217,21,3,0
972,20,0,2,1,1,1,3095,12,5,0
188,29,0,2,1,0,0,674,12,1,0


In [None]:
# Shared plotting options for both heatmaps.
heatmap_kwargs = dict(save_path='./results', save_mode='combined', show_axis_labels=False)

# Binary heatmap: shows whether each feature changed.
sparsifier.heatmap_binary(**heatmap_kwargs)

# Directional heatmap: shows the direction of each change.
sparsifier.heatmap_direction(**heatmap_kwargs)

In [None]:
# Stacked bar chart from COLA, saved under the same ./results directory as the heatmaps.
fig = sparsifier.stacked_bar_chart(save_path='./results')

In [None]:
factual_df, diversity_styles = sparsifier.diversity()

# Display the diversity table of every factual instance, numbered from 1.
for instance_no, styled in enumerate(diversity_styles, start=1):
    print(f"Instance {instance_no} diversity:")
    display(styled)

In [None]:
# Check the diversity feature (based on the refined counterfactuals).
factual_df, diversity_styles = sparsifier.diversity()

# Show the diversity analysis of the first instance; guard against an empty result
# instead of failing with an IndexError.
if len(diversity_styles) > 0:
    print("Instance 1 diversity (based on refined counterfactual):")
    display(diversity_styles[0])
else:
    print("No diversity results were produced.")