# Experiments and Evaluation of SHAP
- Evaluation protocol for DeepSHAP, following the evaluation procedure and skill-score calculation described in **["Finding the right XAI Method --- A Guide for the Evaluation and Ranking of Explainable AI Methods in Climate Science"](https://arxiv.org/abs/2303.00652)** by Bommer et al.
- **Note that** the calculations have been separated into a Colab Python notebook due to version conflicts with innvestigate v1.0.9.
- For execution via Colab:
    - 1.) Create a Colab account.
    - 2.) Sync the project to the Google Drive used by Colab (via the app) and create a shortcut in Google Drive (right-click -> Organise -> shortcut).
    - 3.) Adapt the paths in the 'Preliminaries' section.

In [None]:
# Install packages.
# scipy and matplotlib are pinned to fixed versions.
# NOTE(review): keras and shap are unpinned; the recorded run installed shap 0.42.0 -- consider pinning it.
!pip install scipy==1.10.1
!pip install matplotlib==3.5.3
!pip install keras
!pip install shap

Collecting matplotlib==3.5.3
  Downloading matplotlib-3.5.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (11.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.9/11.9 MB[0m [31m87.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: matplotlib
  Attempting uninstall: matplotlib
    Found existing installation: matplotlib 3.7.1
    Uninstalling matplotlib-3.7.1:
      Successfully uninstalled matplotlib-3.7.1
Successfully installed matplotlib-3.5.3


Collecting shap
  Downloading shap-0.42.0-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (547 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m547.1/547.1 kB[0m [31m40.5 MB/s[0m eta [36m0:00:00[0m
Collecting slicer==0.0.7 (from shap)
  Downloading slicer-0.0.7-py3-none-any.whl (14 kB)
Installing collected packages: slicer, shap
Successfully installed shap-0.42.0 slicer-0.0.7


In [None]:
# Import python packages.
import keras
import numpy as np
import matplotlib.pyplot as plt
import shap
import json
import os
# Mount google drive
# All data, model and package paths used below live under this mount point.
from google.colab import drive
drive.mount('/content/drive/', force_remount = True)



Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)


Mounted at /content/drive/


In [None]:
%%capture
# Install a local package.
# Editable (-e) user install of the QuantusClimate package stored on Google Drive;
# the pip output is suppressed by %%capture.
!pip install -e /content/drive/MyDrive/Climate_X_Quantus/QuantusClimate/. --user

In [None]:
%%capture
# NOTE(review): %%capture discards the output, so this cell has no visible effect;
# capture it into a variable (e.g. `%%capture freeze_out`) if the package list is needed.
!pip freeze

In [None]:
%%capture
# Import local package.
import sys
sys.path.insert(0,'/content/drive/My Drive/Climate_X_Quantus/')
import QuantusClimate as quantus

## Functions

In [None]:
%%capture
from typing import Dict, Any, Tuple

def generate_tf_innvestigation(
    model, inputs, targets, **kwargs
) -> np.ndarray:
    """
    Generate DeepSHAP explanations for a Keras model (used as Quantus explain_func).

    Only the "DeepSHAP" method is implemented; for any other (truthy) method the
    returned explanation is all zeros.
    :param model: trained model instance keras.Model
    :param inputs: input samples; reshaped per sample to model.input_shape[1:]
    :param targets: vector of class labels to explain, one per sample
    :param kwargs: 'method' - tuple (name, settings dict, key); for DeepSHAP the
                              settings dict must provide 'base' (background
                              samples), 'net', 'lat' and 'lon'
                   'num_classes' - number of classes; a value > 0 selects the
                                   per-target branch (run_quantus only passes
                                   > 0 for RandomLogit)
                   'nr_samples' - number of iterations of the evaluation metric (unused here)
                   'explanation' - precomputed explanation (unused here)
    :return: explanation array with the same shape as `inputs`
    :raises KeyError: if `method` is falsy
    """

    method = kwargs.get("method", "random")
    og_shape = inputs.shape

    inputs = inputs.reshape(-1, *model.input_shape[1:])

    explanation = np.zeros_like(inputs)

    if not method:
        raise KeyError(
            "Specify a XAI method that already has been implemented."
        )

    elif "DeepSHAP" in method[2]:
        settings = method[1]
        base = settings["base"]
        if 'MLP' in settings["net"]:
            # The MLP takes flattened lat*lon vectors.
            n_cells = settings["lat"] * settings["lon"]
            base = base.reshape((len(base), n_cells))
            inputs = inputs.reshape((len(inputs), n_cells))
        else:
            # The CNN takes (lat, lon, 1) grids.
            grid = (settings["lat"], settings["lon"], 1)
            base = base.reshape((len(base), *grid))
            inputs = inputs.reshape((len(inputs), *grid))

        exp = shap.DeepExplainer(model, base)

        num_classes = kwargs.get('num_classes', 0)
        if num_classes > 0:
            expl = []
            # One SHAP array per model output; the right one is picked per sample.
            s_values = exp.shap_values(inputs, check_additivity=False)

            preds = None
            for itrs in range(targets.shape[0]):
                target = int(targets[itrs])
                if target == 0:
                    if preds is None:
                        # Predict the whole batch once instead of once per sample.
                        preds = np.argmax(model.predict(inputs), axis=-1)
                    pred = preds[itrs]
                    # Explain a randomly chosen class other than the predicted one.
                    idx = np.random.choice(
                        [c for c in range(num_classes) if c != pred]
                    )
                    expl.append(np.array(s_values[idx][itrs, ...]))
                else:
                    expl.append(np.array(s_values[target][itrs, ...]))
            explanation = np.array(expl).reshape(inputs.shape)
        else:
            # ranked_outputs=1 returns (shap_values, indexes) for the top-ranked
            # output only; computed once (previously this call ran twice).
            shap_values, _indexes = exp.shap_values(inputs, ranked_outputs=1, check_additivity=False)
            explanation = np.array(shap_values[0])

    if np.isnan(explanation).any():
        print("<<< Error: Explanations contain NaNs >>>")

    return explanation.reshape(og_shape)


In [None]:
%%capture

def _run_metric(metric_func, model, x_batch, y_batch, a_batch, s_batch,
                method, nr_samples, num_classes):
    """Call one Quantus metric with generate_tf_innvestigation as explanation function."""
    return metric_func(model=model,
                       x_batch=x_batch,
                       y_batch=y_batch,
                       a_batch=a_batch,
                       s_batch=s_batch,
                       explain_func=generate_tf_innvestigation,
                       explain_func_kwargs={"method": method,
                                            "explanation": a_batch,
                                            'nr_samples': nr_samples,
                                            "num_classes": num_classes})


def run_quantus(args: Dict,
                explanations: Dict,
                metrics: Dict,
                xai_methods: Any,
                **params,
                ):
    """
    Run pre-defined Quantus evaluation metrics on every explanation technique.

    :param args: 'model' - keras.Model (trained model instance)
                 'x_batch' - input batch
                 'y_batch' - output batch
                 's_batch' - mask batch
                 'n_smps' - nr_samples passed to the explanation function
                 'n_sms' - chunk size (and nr_samples) for CNN robustness metrics
                 'n_iter' - repetitions for ROAD / number of chunks for CNN robustness
                 'num_cl' - number of classes (only forwarded for RandomLogit)
    :param explanations: dict {method key (method[2]): precomputed explanations}
    :param metrics: dict {metric name: callable Quantus metric}
    :param xai_methods: iterable of (name, settings dict, key) tuples
    :param params: 'net' - network type; 'CNN' evaluates robustness metrics chunk-wise
                   'dirout', 'csvfile' - back-up location; accepted but currently unused
    :return: dict of {metric: {method key: scores}}; ROAD and chunked CNN
             robustness metrics store a list with one score per iteration
    """
    results = {metric: {} for metric in metrics}
    for metric, metric_func in metrics.items():
        # Only RandomLogit explains with respect to (other) classes.
        num_cl = args["num_cl"] if metric == "RandomLogit" else 0

        for method in xai_methods:
            print(metric, ":", method[2])
            key = method[2]
            if metric == "ROAD":
                # Repeat ROAD n_iter times on the full batch.
                scores = [
                    _run_metric(metric_func, args['model'], args['x_batch'], args['y_batch'],
                                explanations[key], args['s_batch'], method,
                                args["n_smps"], num_cl)
                    for _ in range(args["n_iter"])
                ]
            elif metric in ["Robustness", "LocalLipschitzEstimate", "AvgSensitivity"]:
                if method[0] == "Control Var. Random Uniform":
                    as_list = list(method)
                    as_list[1] = {'fix': 0}
                    method = tuple(as_list)
                if params['net'] == 'CNN':
                    # Evaluate the CNN in n_iter consecutive chunks of n_sms samples.
                    chunk = args["n_sms"]
                    scores = []
                    for i in range(args["n_iter"]):
                        window = slice(i * chunk, (i + 1) * chunk)
                        scores.append(_run_metric(metric_func, args['model'],
                                                  args['x_batch'][window, ...],
                                                  args['y_batch'][window],
                                                  explanations[key][window, ...],
                                                  args['s_batch'][window, ...],
                                                  method, chunk, num_cl))
                else:
                    scores = _run_metric(metric_func, args['model'], args['x_batch'], args['y_batch'],
                                         explanations[key], args['s_batch'], method,
                                         args["n_smps"], num_cl)
            else:
                scores = _run_metric(metric_func, args['model'], args['x_batch'], args['y_batch'],
                                     explanations[key], args['s_batch'], method,
                                     args["n_smps"], num_cl)

            results[metric][key] = scores

    return results


In [None]:
%%capture
import sklearn.metrics as metrix

def area_score(results: Any,
               **kwargs):
    """
    Area under the ROAD curve.

    :param results: dict mapping each perturbation percentage (key, castable
                    to float) to the metric value at that percentage
    :param kwargs: unused
    :return: area under the (percentage, value) curve via sklearn.metrics.auc
    """
    percentages = np.array([float(key) for key in results.keys()])
    curve = np.array([float(value) for value in results.values()])
    return metrix.auc(percentages, curve)

def _collect_scores(metric, raw):
    """Flatten one method's (or the reference's) raw scores for `metric` into a 1-D array."""
    if metric == "ROAD":
        # One ROAD curve per iteration -> one area-under-curve value each.
        return np.array([area_score(raw[r]) for r in range(len(raw))])
    if type(raw) is dict:
        return np.array(list(raw.values())).flatten()
    return np.array(raw).flatten()


def bss_mean_var(metricx: Dict,
                 methods: Dict,
                 results: Dict,
                 base: Dict,
                 **params):
    """
    Calculate Brier skill score statistics: mean BSS and its SEM across samples.

    For every metric the raw scores of each method are compared with the
    reference scores in `base`:
      - metrics listed in `min_norm`:  BSS = 1 - |score| / reference
      - all other metrics:             BSS = (score - reference) / (1 - reference)
    ROAD results (a list of {percentage: value} curves) are first reduced to an
    area under the curve with `area_score`; the ROAD reference is taken in
    absolute value.

    :param metricx: dict of quantus metrics; only the keys (metric names) are used
    :param methods: iterable of explanation-method tuples; element [2] is the
                    key into results[metric]
    :param results: dict {metric: {method key: scores}}
    :param base: dict {metric: reference scores}, same layout as one method's scores
    :param params: 'num_xai' - number of XAI methods (defaults to len(methods))
                   'min_norm' - metric names normalised with the first formula
                                (default ["Randomisation", "Robustness"])
    :return: tuple (means, sems), each a dict {metric: {method key: float}}
    """
    # Set params.
    num_xai = params.get('num_xai', len(methods))
    string_list = params.get('min_norm', ["Randomisation", "Robustness"])

    # Initialize result dicts.
    means = {}
    var = {}
    # Aggregate mean and SEM.
    for metric in metricx:
        means[metric] = {}
        var[metric] = {}
        unnormed_scores = np.array(
            [_collect_scores(metric, results[metric][methoddict[2]]) for methoddict in methods]
        )

        base_score = _collect_scores(metric, base[metric])
        if metric == "ROAD":
            base_score = np.abs(base_score)
        base_score = base_score.reshape(unnormed_scores[0].shape)
        base_scores = np.repeat(base_score[np.newaxis, :], num_xai, axis=0)

        if metric in string_list:
            unnormed_scores = np.abs(unnormed_scores)
            scores = np.ones(base_scores.shape) - (unnormed_scores / base_scores)
        else:
            scores = (unnormed_scores - base_scores) / (np.ones(base_scores.shape) - base_scores)

        n_samples = scores.shape[1]
        for i, methoddict in enumerate(methods):
            meth = methoddict[2]
            row = scores[i, :]
            # Mask +/-inf (e.g. from a zero reference) before averaging.
            if np.isinf(np.abs(np.mean(row))):
                means[metric][meth] = np.nanmean(np.ma.masked_invalid(row))
            else:
                means[metric][meth] = np.nanmean(row)
            if np.isinf(np.std(row)):
                var[metric][meth] = np.nanstd(np.ma.masked_invalid(row)) / np.sqrt(n_samples)
            else:
                var[metric][meth] = np.nanstd(row) / np.sqrt(n_samples)
    return means, var


## Preliminaries

- Set raw_path to raw data path
- Set save_path to DeepShap result path
- Set score_path to general result path (holding data from python script evaluation)
- Set net = 'MLP' for MLP-based evaluation of DeepShap
- Set net = 'CNN' for CNN-based evaluation of DeepShap

- Set properties = '#_0' with # = {Robustness, Faithfulness, Complexity}
  - Robustness runs the Robustness evaluation (incl. Local Lipschitz Estimate & Avg. Sensitivity; the latter for the MLP only)
  - Faithfulness runs Faithfulness evaluation (incl. Faithfulness Correlation & ROAD)
  - Complexity runs the Complexity, Localisation and Randomisation evaluation (incl. Complexity, Sparseness, Relevance Rank Accuracy, Top-k, Model Parameter Randomisation Test and, for the MLP only, Random Logit Test)

In [None]:
# Set experiment paths.
import yaml

exp_path = '/content/drive/MyDrive/Climate_X_Quantus/Experiment/'
raw_path = '/content/drive/MyDrive/Climate_X_Quantus/Data/'
save_path = '/content/drive/MyDrive/Climate_X_Quantus/Data/Quantus/Baseline/Shap/'
score_path = '/content/drive/MyDrive/Climate_X_Quantus/Data/Quantus/Baseline/'

# save_path is nested inside score_path, so makedirs creates both (and any
# missing parents), also when score_path already exists without save_path.
if not os.path.isdir(save_path):
    print("path does not exist")
    os.makedirs(save_path, exist_ok=True)

In [None]:
# Set experiment settings.

# Read the YAML configs inside context managers so the file handles are closed.
with open(exp_path + 'plot_config.yaml') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)
with open(exp_path + 'Post_config.yaml') as post_config_file:
    post_settings = yaml.load(post_config_file, Loader=yaml.FullLoader)

# Experiment variables.
properties = 'Complexity_0'
net = 'CNN'
params = config['params']
params['net'] = net




### Load Data


In [None]:
# Load the full data object.
# NOTE(review): the '%s' in this path is never %-formatted, so the literal
# directory 'Quantus/%s/0/' is opened -- confirm whether a substitution was intended.
# NOTE(review): allow_pickle=True unpickles object arrays; only load trusted files.
with np.load(raw_path + 'Quantus/%s' + '/0/' + 'Postprocessed_data_ALL.npz', allow_pickle=True) as full_data:
    # Read each archive member once (every [] access re-reads from the zip).
    all_inputs = full_data["Input"]
    wh = full_data["wh"]

background = all_inputs.reshape(all_inputs.shape[0], 1, len(wh[0]), len(wh[1]))

# Select a set of background examples to take an expectation over.
# NOTE(review): unseeded -- the background set differs between runs.
background = background[np.random.choice(background.shape[0], 100, replace=False)]

# Latitudes and longitudes.
lat = wh[0]
lon = wh[1]

del full_data, all_inputs, wh

# Load experiment data, including random baseline samples.
data = np.load(raw_path + 'Quantus/Baseline/' + 'data_%s_%s.npz' % (properties, net), allow_pickle=True)

# Input images; for the MLP, axis 1 is moved to the last position.
if 'MLP' in net:
    x_batch = data['x_batch'].swapaxes(1, 3).swapaxes(1, 2)
else:
    x_batch = data['x_batch']

# Classification labels.
y_batch = data['y_batch']

# Mask data.
s_batch = data['s_batch']

# Years of the input images.
y_out = data['y_out']

# Experiment settings.
n_smps = data['n_smps']
n_sms = data['n_sms']
n_iter = data['n_iter']
num_cl = y_out.shape[1]

# Reference scores.
ref = data['reference']

### Load model
- load trained model

In [None]:
from keras.models import load_model
import keras

model = load_model(os.path.join(raw_path, 'Network', 'lens_%s_0_T2M_1.tf' % net), compile=False)

# Compile the loaded model; `learning_rate` replaces the deprecated `lr`
# argument reported in this cell's output.
model.compile(optimizer=keras.optimizers.SGD(learning_rate=0.001, momentum=0.9, nesterov=True),
              loss='binary_crossentropy',
              metrics=[keras.metrics.categorical_accuracy],)
print(model)

<keras.engine.functional.Functional object at 0x7ad4d7f75c60>


The `lr` argument is deprecated, use `learning_rate` instead.


### Create explanations SHAP
- Generate the explanations that are passed to the evaluation metrics as precomputed attributions (`a_batch`)

In [None]:
# Reshape data for the network: the CNN takes (lat, lon, 1) grids with a
# trailing singleton axis, the MLP flattened lat*lon vectors.
if 'MLP' not in net:
    backg = background.reshape((len(background), len(lat), len(lon), 1))
    x_b = x_batch.reshape((len(x_batch), len(lat), len(lon), 1))
else:
    backg = background.reshape((len(background), len(lat) * len(lon)))
    x_b = x_batch.reshape((len(x_batch), len(lat) * len(lon)))

In [None]:
# Explanation variables.
xai_methods = [("DeepSHAP", {"base": backg, "lat": len(lat), "lon": len(lon), "net": net}, "DeepSHAP")]
explanations = {}

# Explain the model predictions on all inputs in x_b, with backg as background set.
# ranked_outputs=1 returns (shap_values, indexes) for the top-ranked output only.
exp = shap.DeepExplainer(model, backg)
shapley_values = exp.shap_values(x_b, ranked_outputs=1, check_additivity=False)

# Key by method[2], the key run_quantus and bss_mean_var use for lookups.
if 'MLP' in net:
    explanations[xai_methods[0][2]] = shapley_values[0][0].reshape((len(x_batch), 1, len(lat), len(lon)))
else:
    explanations[xai_methods[0][2]] = shapley_values[0][0].reshape(x_batch.shape)

keras is no longer supported, please use tf.keras instead.
Your TensorFlow version is newer than 2.4.0 and so graph support has been removed in eager mode and some static graphs may not be supported. See PR #1483 for discussion.
`tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.


## Quantus Experiment

### Set up configuration
- create metrics dict and set metrics hyperparameters for the evaluation metrics
- build the baseline (reference) dict from previously saved experiment results (this guarantees that all skill scores are calculated against the same reference scores)

In [None]:
%%capture
# Arguments forwarded to run_quantus.
# NOTE(review): run_quantus reads the network type from params['net'], not from 'net' here.
arguments = {'model': model,
                'x_batch': x_batch,
                'y_batch': y_batch,
                's_batch': s_batch,
                'net': net,
                'y_out': y_out,
                'n_smps': n_smps,
                'n_sms' : n_sms,
                'n_iter': n_iter,
                "num_cl": num_cl,
}

# Select the metric set for the requested property. Metrics listed in
# params['min_norm'] are scored as 1 - |score| / reference in bss_mean_var.
if "Robustness" in properties:

    # AvgSensitivity is only evaluated for the MLP; LocalLipschitzEstimate for both networks.
    # NOTE(review): nr_samples comes from config['n_sms'], while run_quantus chunks
    # CNN batches with n_sms loaded from the data file -- confirm both agree.
    if 'MLP' in net:
      metrics = {"AvgSensitivity": quantus.AvgSensitivity(nr_samples= config['n_sms'],
                                                      lower_bound=0.1,
                                                      norm_numerator=quantus.fro_norm,
                                                      norm_denominator=quantus.fro_norm,
                                                      perturb_func=quantus.gaussian_noise,
                                                      similarity_func= quantus.difference,
                                                      disable_warnings= True,normalise=True)}
    else:
      metrics = dict()


    metrics["LocalLipschitzEstimate"] = quantus.LocalLipschitzEstimate(
                                                          nr_samples = config['n_sms'],
                                                          perturb_std =0.1,
                                                          perturb_mean= 0,
                                                          norm_numerator= quantus.distance_euclidean,
                                                          norm_denominator= quantus.distance_euclidean,
                                                          perturb_func = quantus.gaussian_noise,
                                                          similarity_func = quantus.lipschitz_constant,normalise=True)




    # All robustness metrics use the min_norm normalisation.
    params['min_norm'] = list(metrics.keys())
    config['property'] = "Robustness_0"

elif "Faithfulness"in properties:
    # ROAD perturbs n_smps evenly spaced percentages between 1 and 50;
    # run_quantus repeats it n_iter times.
    metrics = {
        "FaithfulnessCorrelation": quantus.FaithfulnessCorrelation(
            nr_runs=n_smps,
            subset_size=40,
            perturb_baseline="uniform",
            perturb_func=quantus.baseline_replacement_by_indices,
            similarity_func=quantus.correlation_pearson,
            return_aggregate=False,
            normalise=True, ),
        "ROAD": quantus.ROAD(noise=0.01,
                             normalise=True,
                             perturb_baseline="uniform",
                             perturb_func=quantus.noisy_linear_imputation,
                             percentages=np.linspace(1, 50, n_smps).tolist()), }
    params['min_norm'] = ["ROAD"]
    config['property'] = "Faithfulness_0"
else:
    # Complexity, localisation and randomisation metrics; TopK uses
    # k = 1% of the lat*lon grid cells.
    metrics = {
        "Complexity:Complexity": quantus.Complexity(
            normalise=True,
            disable_warnings=True),
        "Complexity:Sparseness": quantus.Sparseness(
            normalise=True,
            disable_warnings=True),
        "Localisation:TopK": quantus.TopKIntersection(
            normalise=True,
            disable_warnings=True,
            k=(int(0.01 * int(lat.shape[0]) * int(lon.shape[0])))),
        "Localisation:RRA": quantus.RelevanceRankAccuracy(
            normalise=True,
            disable_warnings=True),
        "Randomisation": quantus.ModelParameterRandomisation(layer_order="bottom_up",
                                                             similarity_func=quantus.correlation_spearman,
                                                             normalise=True),}
    # RandomLogit is only evaluated for the MLP.
    if 'MLP' in net:
        metrics["RandomLogit"] = quantus.RandomLogit(
            normalise=True,
            num_classes=num_cl,
            similarity_func=quantus.correlation_spearman, )

    params['min_norm'] = ["Complexity:Complexity", "Randomisation", "RandomLogit"]
    config['property'] = "Complexity_0"

In [None]:
# Set up reference dict: the i-th saved reference entry belongs to the i-th
# metric in `metrics`. For Randomisation, every reference array is replaced
# (in place, inside `ref`) by ones of the same length.
metric_names = list(metrics.keys())
reference = {}
for pos, entry in enumerate(ref):
    name = metric_names[pos]
    if type(entry) is not dict:
        reference[name] = np.asarray(entry)
        continue
    if 'Randomisation' in name:
        for key in list(entry.keys()):
            entry[key] = np.ones((len(entry[key]),))
    reference[name] = entry

### Run experiment

In [None]:
# Initiate intermediate results save.
# Use `net` (as the saved .npz file names do) instead of config['net'], which
# is not updated when `net` is changed in the settings cell.
csv_files = 'inter_results_%s_xai_%s_%s.csv' % (properties, len(xai_methods), net)
params['dirout'] = save_path
params['csvfile'] = csv_files
print('>>>>> Run %s analysis and baseline test <<<<<' % properties)
# Run Quantus.
results = run_quantus(arguments, explanations, metrics, xai_methods, **params)

>>>>> Run Complexity_0 analysis and baseline test <<<<<
Complexity:Complexity : DeepSHAP
Complexity:Sparseness : DeepSHAP
Localisation:TopK : DeepSHAP
Localisation:RRA : DeepSHAP
Randomisation : DeepSHAP


### Calculate scores
- calculate and save skill scores in a numpy '.npz'-file

In [None]:
import pandas as pd

dfs = pd.DataFrame.from_dict(results)

# Aggregate one row per explanation method.
params['num_xai'] = len(xai_methods)

# Brier skill score statistics: mean and standard error of the mean (SEM).
bss_mean, bss_sem = bss_mean_var(metrics, xai_methods, results, reference, **params)

bss = pd.DataFrame.from_dict(bss_mean)
bss2 = pd.DataFrame.from_dict(bss_sem)

# Save the SEM, then the mean BSS, each with the method name and metric names.
np.savez(save_path + 'bss_%s_SEM_scores_xai_%s_%s.npz' % (properties, len(xai_methods), net),
         sem=bss2.values,
         xai=xai_methods[0][0],
         properties=bss2.columns.values)
np.savez(save_path + 'bss_%s_abs_agg_scores_xai_%s_%s.npz' % (properties, len(xai_methods), net),
         mean=bss.values,
         xai=xai_methods[0][0],
         properties=bss.columns.values)



Complexity:Complexity 

scores: [[8.79207698 8.684202   9.20592117 8.95304717 9.20413404 8.68367985
  9.12495122 8.99593703 8.94811223 8.83899008 9.05229144 9.01394969
  9.04962166 8.74869978 9.10583063 8.8485133  9.04652193 8.95927355
  8.6940937  8.99092313 9.10368441 9.12926649 8.92240323 8.84576479
  9.16139004 9.17886412 8.75139114 9.07302641 9.03731666 9.13541147
  9.07468621 8.78691811 9.0169081  9.11549435 9.1857293  8.88237359
  8.71953023 9.06770443 8.79576085 8.73677593 9.06847041 8.72058222
  9.18817063 9.07696617 8.98938953 8.86273159 9.14942472 8.77978836
  8.81477131 9.15427981]]
ref: [[9.17664792 9.1501766  9.13662945 9.14530576 9.13735445 9.17285714
  9.13719639 9.17425328 9.17570316 9.17406199 9.14095628 9.16999619
  9.16516854 9.1740146  9.15737853 9.17382023 9.16753486 9.13975896
  9.17525423 9.17680795 9.15287837 9.13772086 9.14063365 9.17143283
  9.13831851 9.14523687 9.1509302  9.17577137 9.16371386 9.15155839
  9.1710098  9.13389663 9.17261187 9.15373951 9.13557