In [None]:
from pprint import pprint
import copy
import sys

In [None]:
import pandas as pd
import numpy as np
import sklearn
import shap
import xgboost as xgb

In [None]:
# Make the repository root importable so the `src.*` imports used by this
# notebook resolve. Guarded so that re-running the cell does not keep
# appending duplicate entries to sys.path.
repo_root = "/path_to_repo/explaining_cirrus"  # specify directory
if repo_root not in sys.path:
    sys.path.append(repo_root)

In [None]:
from src.ml_pipeline.instantaneous.ml_preprocess import create_dataset
from src.ml_pipeline.instantaneous.experiment import evaluate_model, run_experiment


# Train & Evaluate ML model trained on instantaneous data

## Specify Experiment config

* *filters*: conditions the dataset should be filtered on 
* *predictors*: column names of predictor variables
* *predictand*: column name of target variable, must be element of ['iwc', 'icnc_5um']
* *preproc_steps*: preprocessing steps to be conducted
    * x_log_trans: logarithmic transformation of aerosol variables (recommended)
    * y_log_trans: logarithmic transformation of target variable (recommended)
    * kickout_outliers: if True, outliers (of target variable) are removed
    * oh_encoding: if True, do one-hot encoding of categorical variables
* *random_state*: random state used for the train/val/test split and for model training

In [None]:
# Experiment definition: which rows to keep, which columns to use as
# predictors/target, which preprocessing steps to run, and the random state.
experiment_config = dict(
    # filter expressions applied to the dataset
    filters=["nightday_flag ==1"],
    # column names of the predictor variables
    predictors=[
        "t", "w", "wind_speed",
        "DU_sup", "DU_sub", "SO4",
        "dz_top_v2", "cloud_thickness_v2", "surface_height",
        "season", "land_water_mask", "lat_region",
    ],
    # column name of the target variable
    predictand="icnc_5um",
    # switches for the individual preprocessing steps
    preproc_steps=dict(
        x_log_trans=True,
        y_log_trans=True,
        kickout_outliers=False,
        oh_encoding=True,
    ),
    random_state=53,
)

## Specify XGBoost Hyperparameters

In [None]:
# Hyperparameters handed to the XGBoost model via `run_experiment`.
xgboost_config = {
    "objective": "reg:squarederror",
    # row / column subsampling
    "subsample": 0.4,
    "colsample_bytree": 0.8,
    # boosting setup
    "learning_rate": 0.02,
    "max_depth": 15,
    # regularisation terms
    "alpha": 38,
    "lambda": 7,
    "n_estimators": 250,
    # number of parallel workers; adjust to the machine at hand
    "n_jobs": 32,
}

## Load Dataset

In [None]:
# Location of the instantaneous dataset (CSV); adjust to your environment.
inst_data_set_path = "/path_to_instantaneous_data/instantaneous.csv"

# Read the full table into memory; every experiment below starts from `df`.
df = pd.read_csv(
    inst_data_set_path,
)

## Linear Regression Baseline

In [None]:
# Build the train / validation / test splits (features and target) from the
# raw table according to the experiment config.
(
    X_train, X_val, X_test,
    y_train, y_val, y_test,
) = create_dataset(df, **experiment_config)

In [None]:
# A bare `import sklearn` does not guarantee that the `linear_model`
# submodule has been loaded, so import the estimator explicitly.
from sklearn.linear_model import LinearRegression

# Linear-regression baseline fitted on the plain NumPy arrays of the training split.
lin_reg = LinearRegression().fit(X_train.values, y_train.values)

In [None]:
# Score the linear baseline on the held-out test split.
# NOTE(review): the model was fitted on `.values` (plain arrays) but is
# evaluated on the pandas objects here — confirm `evaluate_model` handles this.
evaluate_model(lin_reg, X_test, y_test)

## Run experiment

In [None]:
# Run the experiment with the XGBoost and experiment configs defined above.
# Returns the model used by the XAI section below plus a second object bound to
# `validate_df` (presumably validation results — confirm in `run_experiment`).
model, validate_df = run_experiment(df, xgboost_config, experiment_config)

# XAI

## Helpers & setup

In [None]:
# Physical unit string for each variable name.
VARIABLE_UNITS = dict(
    # cloud variables
    iwc="mg m⁻³",
    reffcli="um",
    icnc_5um="cm⁻³",
    icnc_100um="cm⁻³",
    # aerosol variables
    SO4="mg kg⁻¹",
    DU="mg kg⁻¹",
    DU_sub="mg kg⁻¹",
    DU_sup="mg kg⁻¹",
    # geometry
    lev="m",
    cloud_thickness="m",
    # meteorology / surface
    wind_speed="m s⁻¹",
    w="Pa s⁻¹",
    t="K",
    surface_height="m",
)

In [None]:
# Re-create the splits from the experiment config so that the XAI section
# works on exactly the data layout described by `experiment_config`.
(
    X_train, X_val, X_test,
    y_train, y_val, y_test,
) = create_dataset(df, **experiment_config)

In [None]:
# Distinguish categorical from continuous features (needed for LIME and the
# XAI evaluation metrics): the first 9 features are continuous ("c"), the
# remaining 19 are the one-hot encoded categorical features ("d").
feature_type = ["c"] * 9 + ["d"] * 19

# Boolean masks over the feature axis, one entry per feature.
continuous_features = np.array(feature_type) == "c"
discrete_features = np.array(feature_type) == "d"

## Calculate Attributions

In [None]:
# Plain NumPy views of the test split; the explainers and evaluation metrics
# below operate on arrays rather than on the pandas objects.
X_test_sample, y_test_sample = X_test.values, y_test.values

In [None]:
# Use an independent masker built from the training data; `max_samples=1000`
# caps how much of that data is used (see shap.maskers.Independent).
masker = shap.maskers.Independent(X_train,max_samples=1000)

### Random explainer

In [None]:
# Random attribution baseline: the reference every real attribution method
# is compared against in the evaluation below.
random_explainer = shap.explainers.other.Random(model.predict, masker)

In [None]:
# Explain the test samples with the random baseline and keep the attribution
# values from the returned explanation object.
random_exps = random_explainer(X_test_sample).values

In [None]:
# Bar summary plot of the random-baseline attributions, labelled with the
# training column names.
shap.summary_plot(random_exps,X_test_sample, feature_names=X_train.columns, plot_type="bar")

### SHAP

In [None]:
# Tree-based SHAP attributions for the test samples: non-approximate mode,
# with the additivity check enabled.
shap_explainer = shap.TreeExplainer(model)
print("created explainer")

shap_values = shap_explainer.shap_values(
    X_test_sample,
    approximate=False,
    check_additivity=True,
)
print("calculated {} shap values".format(shap_values.shape[0]))

### LIME

In [None]:
# LIME explainer (regression mode) wrapped by shap.
# NOTE(review): the reference data passed here is X_test, whereas the masker
# for the random baseline uses X_train — confirm this asymmetry is intended.
lime_explainer = shap.explainers.other.LimeTabular(model.predict, X_test, mode="regression")

In [None]:
# Positions of the one-hot encoded (categorical) feature columns.
cat_indices = list(np.flatnonzero(discrete_features))

In [None]:
# Tell the wrapped LIME explainer which columns are categorical.
# NOTE(review): this attribute is assigned after the explainer was constructed;
# confirm that LIME picks it up at explanation time and does not rely on
# statistics derived from it during construction.
lime_explainer.explainer.categorical_features = cat_indices

In [None]:
# LIME attributions for the test samples (can be slow: the call covers the
# whole test array).
lime_attributions = lime_explainer.attributions(X_test_sample)

## Evaluate attribution methods

In [None]:
from src.ml_pipeline.xai_evaluation.xai_evaluation_metrics import eval_faithfulness, evaluate_stability_metric, MarginalPerturbation

In [None]:
# Baseline ("background") feature vector: the per-column mean of the test
# data for continuous features, and that mean rounded to the nearest integer
# for the one-hot encoded features.
_column_means = np.mean(X_test.values, 0)
base_values = (
    _column_means * continuous_features
    + np.round(_column_means, 0) * discrete_features
)

### Faithfulness metrics

* Estimated Faithfulness

adapted from Alvarez-Melis and Jaakkola https://doi.org/10.48550/arXiv.1806.07538

In [None]:
# Estimated faithfulness of each attribution method (SHAP, LIME, random
# baseline), evaluated on the same samples, model and baseline vector.
_faithfulness_inputs = (
    ("shap faithfulness", shap_values),
    ("lime faithfulness", lime_attributions),
    ("random faithfulness", random_exps),
)

_faithfulness_scores = []
for _title, _attributions in _faithfulness_inputs:
    print(_title)
    _faithfulness_scores.append(
        eval_faithfulness(X_test_sample, y_test_sample, _attributions, model, base_values)
    )
    print("\n")

shap_pred_corr_faith, lime_pred_corr_faith, random_pred_corr_faith = _faithfulness_scores

In [None]:
# One column of faithfulness scores per attribution method.
faithfulness_df = pd.DataFrame(
    data=np.column_stack(
        (shap_pred_corr_faith, lime_pred_corr_faith, random_pred_corr_faith)
    ),
    columns=["SHAP", "LIME", "RandomBaseline"],
)

### Stability

* Relative Input Stability
* Relative Output Stability

Adapted from Agarwal et al. 2022 https://doi.org/10.48550/arXiv.2203.06877

In [None]:
# For each column, perturbations are drawn from a marginal distribution with
# mean=0 and std=(column std / 100).
#
# The standard deviations are computed once for all columns (the previous
# loop recomputed the std of the whole frame for every column). ddof=0 keeps
# NumPy's np.std convention used before, and axis=0 is passed explicitly
# instead of relying on how a DataFrame reduction treats axis=None.
col_dist_stds = list(X_train.std(axis=0, ddof=0) / 100)

perturber = MarginalPerturbation(col_dist_stds)

In [None]:
# Result lists: Relative Input Stability (ris_*) and Relative Output
# Stability (ros_*) for each attribution method.
ris_shap_stability, ros_shap_stability = [], []
ris_lime_stability, ros_lime_stability = [], []
ris_rand_stability, ros_rand_stability = [], []

# (explainer, RIS list, ROS list) for SHAP, LIME and the random baseline.
_stability_targets = (
    (shap_explainer, ris_shap_stability, ros_shap_stability),
    (lime_explainer, ris_lime_stability, ros_lime_stability),
    (random_explainer, ris_rand_stability, ros_rand_stability),
)

# Evaluate both stability metrics on the first 10 test samples. For every
# sample the methods are visited in the order SHAP -> LIME -> random, RIS
# before ROS, and only the continuous features are passed as feature mask.
for sample in X_test_sample[:10]:
    for _explainer, _ris_scores, _ros_scores in _stability_targets:
        _ris_scores.append(
            evaluate_stability_metric(sample, model, _explainer, perturber, feature_mask=continuous_features, stability_metric="RIS")
        )
        _ros_scores.append(
            evaluate_stability_metric(sample, model, _explainer, perturber, feature_mask=continuous_features, eps=0.0001, stability_metric="ROS")
        )
    
    

In [None]:
# Relative Input Stability: one column per attribution method, one row per
# evaluated sample, plus a leading "metric" label column.
ris_stability_df = pd.DataFrame(
    data=np.column_stack((ris_shap_stability, ris_lime_stability, ris_rand_stability)),
    columns=["SHAP", "LIME","RandomBaseline"],
)
ris_stability_df.insert(0, "metric", "ris")

# Relative Output Stability, same layout.
ros_stability_df = pd.DataFrame(
    data=np.column_stack((ros_shap_stability, ros_lime_stability, ros_rand_stability)),
    columns=["SHAP", "LIME","RandomBaseline"],
)
ros_stability_df.insert(0, "metric", "ros")

In [None]:
# Display the Relative Input Stability scores (one row per evaluated sample).
ris_stability_df

### Create XAI Evaluation Plot

In [None]:
import hvplot.pandas
import holoviews as hv

In [None]:
# Shared hvplot options for the XAI evaluation box plots.
plt_options = dict(
    fontsize=dict(xlabel='30px', ylabel='25px', ticks='20px'),
    legend=False,
    box_color='Variable',
    cmap='Set3',
    outlier_alpha=0.001,
    box_width=0.8,
    width=500,
)

In [None]:
def _stability_box(frame, panel_label):
    """Log-scaled box plot of one stability metric, labelled `panel_label` on the x-axis."""
    return frame.hvplot.box(
        logy=True,
        xlabel=panel_label,
        ylabel="Stability",
        yticks=(1e-2,1e-1,1e0,1e1,1e2,1e3,1e4,1e5,1e6),
        ylim=[1e-3,1e6],
        **plt_options,
    )


# Panels (A) and (B): stability metrics on a shared log scale.
ris_plt = _stability_box(ris_stability_df, "(A) Relative Input Stability")
ros_plt = _stability_box(ros_stability_df, "(B) Relative Output Stability")

# Panel (C): faithfulness on a linear axis, not shared with the other panels.
faith_plt = faithfulness_df.hvplot.box(
    xlabel="(C) Estimated Faithfulness",
    ylabel="Estimated Faithfulness",
    yticks=[-1,-0.5,0,0.5,1],
    shared_axes=False,
    **plt_options,
)

In [None]:
# Combine the three panels (A, B, C) into one layout.
xai_eval_plt = ris_plt + ros_plt + faith_plt

In [None]:
# Render the combined XAI evaluation figure.
xai_eval_plt

## SHAP deepdive

### Absolute SHAP values

In [None]:
# Sum the SHAP values of the one-hot columns belonging to the same categorical
# feature, so that each categorical feature gets a single attribution.
#
# Each pair holds the inclusive [first, last] column position of one feature.
# NOTE(review): the positions are hard-coded — confirm them against
# X_train.columns. An upper bound beyond the last column is harmless because
# slicing clamps to the array width.
season_idx = [9,12]
land_water_mask_idx = [13,21]
region_idx = [22,30]


def _sum_onehot_columns(attributions, first, last):
    """Sum columns first..last (inclusive) of `attributions` into one (n, 1) column."""
    return attributions[:, first:last + 1].sum(axis=1, keepdims=True)


# The previous slices (9:12, 13:21, 22:30) treated the upper bounds as
# exclusive, so columns 12 and 21 were left out of every group and their
# attributions were silently dropped from the aggregation.
season_shap_values = _sum_onehot_columns(shap_values, *season_idx)
lwm_shap_values = _sum_onehot_columns(shap_values, *land_water_mask_idx)
region_shap_values = _sum_onehot_columns(shap_values, *region_idx)

# Continuous features stay as they are, followed by one column per categorical feature.
agg_shap = np.concatenate(
    (shap_values[:, :season_idx[0]], season_shap_values, lwm_shap_values, region_shap_values),
    axis=1,
)

In [None]:
# Feature names for the aggregated attributions: the first 9 training columns
# followed by one name per categorical feature.
fnames = [*X_train.columns[:9], "season", "land_water_mask", "region"]

In [None]:
# Human-readable feature names used as labels in the summary plot.
fnames = [
    "Temperature",
    "Vertical velocity",
    "Horizontal wind speed",
    "Distance from cloud top",
    "Cloud thickness",
    "Surface height",
    "Dust > 1 um",
    "Dust < 1 um",
    "SO4",
    "Season",
    "Land water mask",
    "Region",
]

In [None]:
# `pl` was referenced here without ever being imported, which raised a
# NameError on a fresh kernel. Bind it to matplotlib.pyplot, the plotting
# backend that shap.summary_plot draws on.
import matplotlib.pyplot as pl

# Bar summary of the aggregated SHAP values. show=False keeps the figure open
# so that the axis label and tick sizes can be adjusted before rendering.
shap.summary_plot(agg_shap, feature_names=fnames, plot_type="bar", show=False)
pl.xlabel("Mean absolute SHAP value",fontsize=20)
pl.xticks(fontsize=15)
pl.yticks(fontsize=15)
#pl.savefig("../PaperPlots/ClimateInformaticsPaper/absolute_shap_icnc.pdf", format='pdf', dpi=600, bbox_inches='tight')
pl.show()

### SHAP dependence plots

In [None]:
# Axis label for each feature column, looked up by column name in the
# dependence plots.
VARIABLE_LABELS = dict(
    SO4_log="SO4 [mg m⁻³]",
    DU="mg kg⁻¹",
    DU_sub_log="DUST < 1um [mg m⁻³]",
    DU_sup_log="DUST > 1um [mg m⁻³]",
    dz_top_v2="Distance from cloud top [m]",
    cloud_thickness_v2="Cloud thickness [m]",
    wind_speed="Horizontal wind speed [m s⁻¹]",
    w="Vertical velocity [Pa s⁻¹]",
    t="Temperature [K]",
    surface_height="Surface height [m]",
)

In [None]:
# Number of decimals each x-axis variable is rounded to before grouping, so
# that the values fall into plottable bins (negative values round to the
# left of the decimal point).
round_dict = dict(
    t=0,
    w=1,
    wind_speed=0,
    dz_top_v2=-2,
    cloud_thickness_v2=-2,
    surface_height=-2,
    DU_sup_log=1,
    DU_sub_log=1,
    SO4_log=1,
)

In [None]:
# The dependence plots overlay the SHAP values of the iwc and the icnc model.
# To get both, train one model per target by changing the "predictand" key of
# the experiment config and store each model's SHAP values under its own name.
# As written, both names refer to the SHAP values of the current model.
icnc_shap_values = iwc_shap_values = shap_values

In [None]:
# Bins of the x-axis variable need more than this many samples to be plotted.
min_count = 5000

# hvplot options for the dependence plots.
# NOTE: this rebinds the `plt_options` name used for the XAI evaluation box
# plots earlier in the notebook.
plt_options = {
    'fontsize': {'xlabel': '30px', 'ylabel': '25px', 'ticks': '20px', 'legend': '30px'},
    'cmap': 'Set3',
    'legend': False,
    'shared_axes': False,
}

# Horizontal reference line at SHAP value 0.
hv_line_plt = hv.HLine(0).opts(line_width=0.5, color="grey")

# Only the first feature column is plotted; widen the slice for more features
# (each plotted column needs an entry in round_dict and VARIABLE_LABELS).
for var_index, var_name in enumerate(X_train.columns[:1]):
    print(var_index, var_name)

    # Rounded feature value next to the SHAP values of both models.
    pds_shap = pd.DataFrame(
        np.array((
            X_test_sample[:, var_index].round(round_dict[var_name]),
            iwc_shap_values[:, var_index],
            icnc_shap_values[:, var_index],
        )).T,
        columns=["variable", "iwc_shap", "icnc_shap"],
    )

    # Group once and derive mean, spread and sample count per rounded value.
    binned = pds_shap.groupby("variable")
    mean = binned.mean()
    sd = binned.std().fillna(0)  # single-sample bins have no std; treat as 0

    ylabel = "SHAP value"
    xlabel = VARIABLE_LABELS[var_name]

    # Mean +/- one standard deviation band for each model.
    for pred in ["iwc", "icnc"]:
        mean[f"{pred}_lower"] = mean[f"{pred}_shap"] - sd[f"{pred}_shap"]
        mean[f"{pred}_upper"] = mean[f"{pred}_shap"] + sd[f"{pred}_shap"]
        mean[f"{pred}_sd"] = sd[f"{pred}_shap"]
    mean["count"] = binned.count().iloc[:, 0]

    # Drop sparsely populated bins. The explicit copy makes the column
    # assignment below act on an independent frame rather than on a filtered
    # selection (which triggers pandas' SettingWithCopyWarning).
    mean = mean.reset_index().query(f"count>{min_count}").copy()

    plt_options["logx"] = "log" in var_name
    if plt_options["logx"]:
        # NOTE(review): assumes the "*_log" columns hold base-10 logarithms —
        # confirm against the preprocessing step.
        mean["variable"] = 10**mean["variable"]

    # plot with legend
    #mean_plt = mean.rename(columns={"iwc_shap":"IWC","icnc_shap":"Nᵢ"}).hvplot.line(x="variable", y=["IWC","Nᵢ"], line_width=3, xlabel=xlabel, ylabel=ylabel, **plt_options) * mean.hvplot.area(x="variable",y="iwc_lower",y2="iwc_upper", line_alpha=0, fill_alpha=0.2,stacked=False) * mean.hvplot.area(x="variable",y="icnc_lower",y2="icnc_upper", line_alpha=0, fill_alpha=0.2,stacked=False) * hv_line_plt

    # Mean curves, +/- 1 sd bands and the zero line.
    mean_plt = (
        mean.hvplot.line(x="variable", y=["iwc_shap", "icnc_shap"], line_width=3, xlabel=xlabel, ylabel=ylabel, **plt_options)
        * mean.hvplot.area(x="variable", y="iwc_lower", y2="iwc_upper", line_alpha=0, fill_alpha=0.2, stacked=False)
        * mean.hvplot.area(x="variable", y="icnc_lower", y2="icnc_upper", line_alpha=0, fill_alpha=0.2, stacked=False)
        * hv_line_plt
    )

    # Marginal distributions: SHAP values of both models and the raw feature values.
    y_dist_plt = (
        hv.Distribution(pds_shap.iwc_shap.values, kdims=["y"]).opts(width=80, xaxis=None, yaxis=None, alpha=0.5)
        * hv.Distribution(pds_shap.icnc_shap.values, kdims=["y"]).opts(width=80, xaxis=None, yaxis=None, alpha=0.5)
    )
    x_dist_plt = hv.Distribution(X_test_sample[:, var_index], kdims=["x"]).opts(height=80, xaxis=None, yaxis=None, color="grey")

    whole_plt = mean_plt << y_dist_plt << x_dist_plt
    display(whole_plt)