In [288]:
from data_module.data_module import Data
from dataset.german_credit import GermanCreditDataset
import pandas as pd
import joblib
from model_module.pretrained_model import PreTrainedModel

import numpy as np
import dice_ml
from sklearn.model_selection import train_test_split
import ot

In [289]:
%reload_ext autoreload
%autoreload 2

### 1. Data (x_factual)

In [290]:
dataset = GermanCreditDataset()
df = dataset.get_dataframe()  
df.head(3)         ## dataset with target name

Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
0,67,1,2,1,0,1,1169,6,5,0
1,22,0,2,1,1,2,5951,48,5,1
2,49,1,1,1,1,0,2096,12,3,0


In [291]:
data = Data(dataset=df, target_name='Risk')
x_factual = data.get_x()
y_factual = data.get_y()
x_labels = data.get_x_labels()

### 2. Model (pre-trained model)

In [292]:
# model_filename = 'trained_LGBMClassifier.pkl'
# loaded_model = joblib.load(model_filename)
# print(f"-----Model has already been downloaded-----")
# #### normalized data

In [293]:
model_filename2 = 'LGBMClassifier.pkl'
LGBMClassifier = joblib.load(model_filename2)
print(f"-----Model has already been downloaded-----")
#### non-normalized data

-----Model has already been downloaded-----


In [294]:
# X_train, X_test, y_train, y_test = dataset.get_standardized_train_test_split(random_state=None)
x_train, x_test, y_train, y_test = train_test_split(x_factual, y_factual, test_size=0.2, random_state=None)

In [295]:
lgbmclassifier = PreTrainedModel(model_path=model_filename2, backend='sklearn')

In [314]:
loaded_model = lgbmclassifier.load_model()

----LGBMClassifier.pkl model has been loaded----


In [297]:
y_factual = lgbmclassifier.predict(x_factual=x_test)

---- predictions have been made----


In [298]:
# n=0
# for i in range(len(y_factual_model_2)):
#     if y_factual_model_2['Prediction'][i] == y_factual[i]:
#         n=n+1
#     else:
#         n=n
# print(n)
# ###  200个数据中只有124个预测对了

### 3. CE model

In [308]:
from ce_module.ce_models import DiCE

In [316]:
dice= DiCE(ml_model=loaded_model, x_factual=x_test, target_name='Risk', sample_num=4)

In [317]:
x_factual, y_factual, x_counterfactual, y_counterfactual = dice.generate_x_counterfactuals()

100%|██████████| 4/4 [00:00<00:00, 10.39it/s]

---- x_counterfactual has already been generated ----
---- y_counterfactual has already been generated ----





In [318]:
print(f'---The X_factual is ---')
print(x_factual)
print(f'---The Y_factual is ---')
print(y_factual)
print(f'---The X_counterfactual is --- ')
print(x_counterfactual)
print(f'---The Y_counterfactual is --- ')
print(y_counterfactual)

---The X_factual is ---
[[  40    1    3    1    0    1 1358   24    7]
 [  22    1    2    1    1    2 3832   30    4]
 [  23    1    2    2    1    2 1534   12    5]
 [  39    0    2    1    1    2 1188   21    0]]
---The Y_factual is ---
[1 0 1 1]
---The X_counterfactual is --- 
[[  48    1    3    1    3    1 1358   24    7]
 [  65    1    2    2    1    2 1534    8    5]
 [  39    0    2    0    1    0 1188   21    0]
 [  22    1    2    1    1    2 5817   27    4]]
---The Y_counterfactual is --- 
[0 0 0 0]


### 4. Probability of joint distribution (Policy)

In [319]:
from policy_module import policy

In [321]:
policy = policy.compute_intervention_policy(
                            model=lgbmclassifier,   #没有使用我们自己模块生成的model
                            X_train=x_train,
                            X_factual=x_factual,
                            X_counterfactual=x_counterfactual,
                            shapley_method="CF_OTMatch",
                            Avalues_method='max', # 'avg'
                        )


AttributeError: 'dict' object has no attribute 'compute_intervention_policy'

In [None]:
varphi = policy['varphi']
p = policy['p']
q = policy['q']

### 6. get the q, use the shapley value to change the specific position And get the Z.

In [None]:
action = 5

In [None]:
# 1. 找到 varphi 矩阵中概率最高的前 action 个数及其位置
flat_indices = np.argpartition(varphi.flatten(), -action)[-action:]
row_indices, col_indices = np.unravel_index(flat_indices, varphi.shape)

# 2. 在 q 中找到这些位置对应的值
q_values = q[row_indices, col_indices]

# 3. 将 x_factual 中对应位置的值替换为 q 中找到的相应值
x_action_constrained = x_factual.copy()

for row_idx, col_idx, q_val in zip(row_indices, col_indices, q_values):
    x_action_constrained[row_idx, col_idx] = q_val

# 打印结果
print("原始 x_factual:")
print(x_factual)
print("\nDiCE处理后的x_counterfactual")
print(x_counterfactual)
print("\naction constrained  x_action_constrained:")
print(x_action_constrained)

原始 x_factual:
[[  25    1    0    1    1    1 2473   18    4]
 [  21    1    1    2    1    1 1987   24    5]
 [  26    0    2    1    1    2 9960   48    4]
 [  35    1    2    1    1    1  691   12    1]]

DiCE处理后的x_counterfactual
[[  26    0    2    2    0    2 9960   48    4]
 [  35    1    2    1    2    1 5557   12    1]
 [  65    1    0    1    2    1 2473   18    4]
 [  21    1    1    2    1    1  476   24    5]]

action constrained  x_action_constrained:
[[  25    1    0    1    1    1 5557   18    4]
 [  65    1    1    2    1    1 1987   24    5]
 [  26    0    2    1    0    2 9960   48    4]
 [  35    1    2    1    1    1  476   12    5]]


In [None]:
x_factual= pd.DataFrame(x_factual)
x_factual.columns = x_labels
x_factual['Risk'] = y_factual
x_factual

Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
0,25,1,0,1,1,1,2473,18,4,1
1,21,1,1,2,1,1,1987,24,5,1
2,26,0,2,1,1,2,9960,48,4,1
3,35,1,2,1,1,1,691,12,1,1


In [None]:
x_action_constrained_ce = pd.DataFrame(x_action_constrained)
y_action_constrained_ce = lgbmclassifier.predict(x_action_constrained_ce)
x_action_constrained_ce.columns = x_labels
x_action_constrained_ce['Risk'] = y_action_constrained_ce
x_action_constrained_ce

----LGBMClassifier.pkl model has been loaded----
---- predictions have been made----


Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
0,25,1,0,1,1,1,5557,18,4,0
1,65,1,1,2,1,1,1987,24,5,0
2,26,0,2,1,0,2,9960,48,4,0
3,35,1,2,1,1,1,476,12,5,1
