# Example: WeightedSHAP on the fraud dataset

- In this notebook, we introduce **WeightedSHAP**, a generalized feature attribution method, and apply it to the fraud dataset.

In [None]:
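import sys, os
import pickle
import numpy as np
import matplotlib.pyplot as plt  # needed for the recovery-curve plot at the end
np.random.seed(2022)
import analysis_utils

# Make the weightedSHAP package in the parent directory importable
sys.path.append('../')
import weightedSHAP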

## Load data
- We use the fraud dataset (https://www.openml.org/search?type=data&status=active&id=42397).
- The function `weightedSHAP.load_data` loads the `train`, `val`, `est`, and `test` datasets.
 - `train`: to train the model to be explained
 - `val`: to tune hyperparameters
 - `est`: to estimate coalition functions
 - `test`: to evaluate the quality of feature attributions
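As a rough illustration of this four-way partition, here is a minimal sketch of a shuffled split. The helper `split_four_ways` and its default fractions are hypothetical; `weightedSHAP.load_data` may split the data differently.

```python
import numpy as np

def split_four_ways(X, y, fractions=(0.5, 0.1, 0.2, 0.2), seed=2022):
    """Hypothetical helper: shuffle rows once, then cut them into
    train / val / est / test blocks with the given fractions."""
    if not np.isclose(sum(fractions), 1.0):
        raise ValueError("fractions must sum to 1")
    rng = np.random.default_rng(seed)
    perm = rng.permutation(len(X))
    # Cut points between the four blocks (rounded to avoid float truncation)
    cuts = np.round(np.cumsum(fractions)[:-1] * len(X)).astype(int)
    return [(X[idx], y[idx]) for idx in np.split(perm, cuts)]

X = np.arange(20).reshape(10, 2)
y = np.arange(10) % 2
(X_tr, y_tr), (X_va, y_va), (X_es, y_es), (X_te, y_te) = split_four_ways(X, y)
print(len(X_tr), len(X_va), len(X_es), len(X_te))  # 5 1 2 2
```

Shuffling once and slicing the permutation keeps the four blocks disjoint, so the data used to estimate coalition functions never overlaps with the data used to evaluate attributions.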

In [None]:
# Load dataset
dir_path='./'
problem='classification' 
dataset='fraud'
ML_model='boosting' 
(X_train, y_train), (X_val, y_val), (X_est, y_est), (X_test, y_test)=weightedSHAP.load_data(problem, dataset, dir_path)    


## Train a model to explain
 - This step is a typical machine learning routine: given the training and validation datasets, we train a model. Our goal is to interpret this model by looking at a particular prediction (i.e., a local attribution problem).

In [None]:
# train a baseline model
model_to_explain=weightedSHAP.create_model_to_explain(X_train, y_train, X_val, y_val, problem, ML_model)

## Compute attributions and evaluate their performance
- `weightedSHAP.run_attribution_core` computes conditional expectations $\mathbb{E}[f(X) \mid X_S = x_S]$ for the feature subsets $S$ obtained by ranking features with various weighted versions of SHAP, $\phi_\mathbf{w} (x_i) := \sum_{j=1}^{d} w_j \Delta_{j}(x_i)$ (Equation (5) of the paper). The set $\mathcal{W}$ of candidate weight vectors includes the SHAP weights as well.
- The results are stored in `exp_dict`, which includes
 - `cond_pred_keep_absolute` (features with the largest absolute attributions are added first)
 - `cond_pred_remove_absolute` (features with the largest absolute attributions are removed first)
 - `pred_masking` (features with the largest absolute attributions are added first, while the remaining features are masked with zero).
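The weighted attribution $\phi_\mathbf{w}$ can be computed exactly for a tiny game. This is a minimal sketch, assuming $\Delta_j(x_i)$ is the average marginal contribution of feature $i$ to coalitions of size $j-1$ drawn from the remaining features, so that uniform weights $w_j = 1/d$ recover the Shapley value and $\mathbf{w} = (0, \dots, 0, 1)$ gives leave-one-out. The helper names and the toy coalition function `v` are hypothetical, and exhaustive enumeration is only feasible for small $d$.

```python
from itertools import combinations
from math import comb

def delta(v, i, j, d):
    """Delta_j(i): mean marginal contribution of feature i to coalitions
    of size j-1 drawn from the other d-1 features."""
    others = [k for k in range(d) if k != i]
    total = 0.0
    for S in combinations(others, j - 1):
        total += v(set(S) | {i}) - v(set(S))
    return total / comb(d - 1, j - 1)

def weighted_attribution(v, d, w):
    """phi_w(i) = sum_{j=1..d} w_j * Delta_j(i) for every feature i."""
    return [sum(w[j - 1] * delta(v, i, j, d) for j in range(1, d + 1))
            for i in range(d)]

# Hypothetical coalition function: additive effects plus an
# interaction between features 0 and 1.
a = [1.0, 2.0, 3.0]
def v(S):
    return sum(a[k] for k in S) + (4.0 if {0, 1} <= S else 0.0)

d = 3
phi_shap = weighted_attribution(v, d, [1 / d] * d)               # uniform weights
phi_loo = weighted_attribution(v, d, [0.0] * (d - 1) + [1.0])    # leave-one-out
print([round(x, 6) for x in phi_shap])  # [3.0, 4.0, 3.0]
print(phi_loo)                          # [5.0, 6.0, 3.0]
```

The Shapley weights split the interaction bonus of 4 evenly between features 0 and 1, while leave-one-out credits it fully to both, which is why different $\mathbf{w}$ can rank features differently.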

In [4]:
if not os.path.exists(f'{dir_path}/fraud_example.pickle'):
    # Generate a conditional coalition function
    # To efficiently obtain a conditional coalition function, it internally trains a surrogate model. 
    conditional_extension=weightedSHAP.generate_coalition_function(model_to_explain, X_train, X_est, problem, ML_model)
    
    # With the conditional coalition function, we compute attributions
    exp_dict=weightedSHAP.compute_attributions(problem, ML_model,
                                                 model_to_explain, conditional_extension,
                                                 X_train, y_train,
                                                 X_val, y_val, 
                                                 X_test, y_test)

    with open(f'{dir_path}/fraud_example.pickle', 'wb') as handle:
        pickle.dump(exp_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
else:
    with open(f'{dir_path}/fraud_example.pickle', 'rb') as handle:
        exp_dict = pickle.load(handle)


## WeightedSHAP and Prediction recovery error curve 
 - WeightedSHAP is defined as (Equation (6) of the paper)
 \begin{align*}
    \phi_{\mathrm{WeightedSHAP}}(\mathcal{T}, \mathcal{W}) := \phi_{\mathbf{w}^* (\mathcal{T}, \mathcal{W})},
 \end{align*} 
 where $\mathbf{w}^* (\mathcal{T}, \mathcal{W}) := \mathrm{argmax}_{\mathbf{w} \in \mathcal{W}} \mathcal{T} (\phi_{\mathbf{w}})$.
 - By default, $\mathcal{T}$ is the negative area under the prediction recovery error curve (AUP, Equation (4) of the paper), computed with `cond_pred_keep_absolute`. Note that `exp_dict` also contains `cond_pred_remove_absolute` and `pred_masking`.
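A minimal sketch of this selection step for a single explained sample, assuming the keep-absolute curve records $|f(x) - v(S_k)|$ as the $k$ features with the largest absolute attributions are added, and approximating the area by a plain sum over $k$. `toy_recovery_curve`, `aup`, `select_weighted_shap`, and the toy coalition function `v` are hypothetical illustrations, not functions from `weightedSHAP` or `analysis_utils`. The two attribution vectors are the exact Shapley and leave-one-out values of this `v`.

```python
import numpy as np

def toy_recovery_curve(v, phi):
    """Keep-absolute curve: |f(x) - v(S_k)| for k = 0..d, where S_k holds the
    k features with the largest |phi| (ties broken by feature index)."""
    d = len(phi)
    order = np.argsort(-np.abs(np.asarray(phi, dtype=float)), kind="stable")
    full = v(set(range(d)))  # f(x): the coalition function with all features
    return np.array([abs(full - v(set(order[:k].tolist()))) for k in range(d + 1)])

def aup(curve):
    # Assumption: the area is approximated by a plain sum over k = 0..d.
    return float(np.sum(curve))

def select_weighted_shap(v, candidates):
    """Return the candidate name maximizing T = -AUP, plus all scores."""
    scores = {name: -aup(toy_recovery_curve(v, phi)) for name, phi in candidates.items()}
    return max(scores, key=scores.get), scores

# Hypothetical coalition function for one sample: features 0 and 1 only
# help jointly, feature 2 helps on its own.
def v(S):
    return (3.0 if {0, 1} <= S else 0.0) + (2.0 if 2 in S else 0.0)

candidates = {
    "shapley": [1.5, 1.5, 2.0],        # uniform weights w_j = 1/d
    "leave_one_out": [3.0, 3.0, 2.0],  # w = (0, ..., 0, 1)
}
best, scores = select_weighted_shap(v, candidates)
print(best, scores)  # shapley {'shapley': -11.0, 'leave_one_out': -12.0}
```

Here the Shapley ranking puts feature 2 first, which recovers the prediction faster (curve `[5, 3, 3, 0]`) than the leave-one-out ranking (curve `[5, 5, 2, 0]`), so it attains the larger $\mathcal{T}$.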

In [5]:
exp_dict.keys()

dict_keys(['true_list', 'pred_list', 'cond_pred_keep_absolute', 'cond_pred_remove_absolute', 'pred_masking'])

In [None]:
y_test=np.array(exp_dict['true_list'])
pred_list=np.array(exp_dict['pred_list'])
cond_pred_keep_absolute=np.array(exp_dict['cond_pred_keep_absolute'])

# Find the optimal weight for each test sample and construct WeightedSHAP
optimal_ind_keep_absolute_list=analysis_utils.find_optimal_list(cond_pred_keep_absolute, pred_list)

# Weight indices in W: 13 -> MCI, 6 -> SHAP
cond_pred_keep_absolute_short=np.array((cond_pred_keep_absolute[:,13,:],
                                        cond_pred_keep_absolute[:,6,:],
                                        cond_pred_keep_absolute[np.arange(len(optimal_ind_keep_absolute_list)),
                                                                optimal_ind_keep_absolute_list,:])).transpose((1,0,2))
# Mean absolute gap |f(x) - E[f(X) | X_S = x_S]| over test samples, one curve per method
recovery_curve_keep_absolute=np.mean(np.abs(cond_pred_keep_absolute_short - pred_list.reshape(-1,1,1)), axis=0)
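To make the array bookkeeping in the cell above concrete, the toy example below uses random numbers in place of real predictions and assumes, as the indexing suggests, that `cond_pred_keep_absolute` has shape `(n_samples, n_weights, n_steps)`. The sizes are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_weights, n_steps = 4, 15, 6
# Toy stand-ins (random numbers, not real results)
cond_pred = rng.normal(size=(n_samples, n_weights, n_steps))
pred = rng.normal(size=n_samples)
optimal_ind = rng.integers(0, n_weights, size=n_samples)  # one weight index per sample

# Per-sample fancy indexing: row s takes weight optimal_ind[s]
best_rows = cond_pred[np.arange(n_samples), optimal_ind, :]               # (n_samples, n_steps)
stacked = np.array((cond_pred[:, 13, :], cond_pred[:, 6, :], best_rows))  # (3, n_samples, n_steps)
stacked = stacked.transpose((1, 0, 2))                                    # (n_samples, 3, n_steps)
# Broadcast f(x) over methods and steps, then average over samples
curves = np.mean(np.abs(stacked - pred.reshape(-1, 1, 1)), axis=0)       # (3, n_steps)
print(best_rows.shape, stacked.shape, curves.shape)  # (4, 6) (4, 3, 6) (3, 6)
```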


In [None]:
n_features=len(recovery_curve_keep_absolute[0])
n_display_features=int(n_features*0.6)
start=max(1, int(n_features*0.075))  # skip the first few steps
tick_step=max(1, n_display_features//6)

# Plot against the actual number of features added so the ticks line up with the data
x=np.arange(n_features)[start:n_display_features]
for curve, label, color in zip(recovery_curve_keep_absolute,
                               ['MCI', 'Shapley', 'WeightedSHAP'],
                               ['blue', 'green', 'red']):
    plt.plot(x, curve[start:n_display_features], label=label, color=color, linewidth=2, alpha=0.6)
plt.legend(fontsize=12)
plt.title('Prediction recovery error curve, Dataset: fraud \n the lower, the better', fontsize=15)
plt.xticks(x[::tick_step])
plt.xlabel('Number of features added', fontsize=15)
plt.ylabel(r'$|f(x)-\mathbb{E}[f(X) \mid X_S = x_S]|$', fontsize=15)
plt.show()