## Example - RobustnessTest

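This notebook demonstrates the robustness-related metrics of the toolbox (e.g. `RobustnessTest`, `ContinuityTest`, `InputIndependenceRate`, `EstimatedLocalLipschitzConstant` and `SensitivityMax`) on Saliency attributions of a pre-trained ResNet18.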

In [1]:
# Mount Google Drive so that the toolbox sources and the image folder are reachable.
from google.colab import drive

drive.mount(mountpoint='/content/drive', force_remount=True)

Mounted at /content/drive


In [48]:
!pip install captum
!pip install opencv-python

import torch
import torchvision
from torchvision import transforms
import numpy as np
import h5py
from tqdm import tqdm
from captum.attr import Saliency, IntegratedGradients
from pathlib import Path
import warnings

# Retrieve source code.
from drive.MyDrive.Projects.xai_quantification_toolbox import * #import xaiquantificationtoolbox

# Notebook settings.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
warnings.filterwarnings("ignore", category=UserWarning)
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Load model, data and attributions.

In [8]:
# Load pre-trained ResNet18 model.
model = torchvision.models.resnet18(pretrained=True)
model.eval()

# Load test data and loaders.
test_set = torchvision.datasets.ImageFolder(root='/content/drive/My Drive/imagenet_images', 
                                            transform=transforms.Compose([transforms.Resize(256),
                                                                          transforms.CenterCrop((224, 224)),
                                                                          transforms.ToTensor(),
                                                                          transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]))
# NOTE: shuffle=True without a fixed seed means every run draws a different batch.
test_loader = torch.utils.data.DataLoader(test_set, shuffle=True, batch_size=64)


# Evaluate model performance.
#predictions, labels = evaluate_model(model.to(device), data=test_loader, device=device)
#print(f"\nModel test accuracy: {(100 * score_model(predictions, labels)):.2f}%")

# Load data, targets and attributions.
# Use the built-in `next()`: the Python 2 style `.next()` method is not part of
# the iterator protocol and is missing from DataLoader iterators in newer
# PyTorch releases.
x_batch, y_batch = next(iter(test_loader))
a_batch = explain(model.to(device), x_batch.to(device), y_batch.to(device), explanation_func="Saliency")

In [None]:
# Show a handful of inputs, each followed by its attribution map.
import matplotlib.pyplot as plt

for sample_idx in range(20, 30):
    # The two transposes move the leading axis to the end
    # (presumably channels-first -> channels-last, as `imshow` expects).
    image = denormalize_image(x_batch.cpu().data[sample_idx])
    image = image.transpose(0, 1).transpose(1, 2)
    plt.imshow(image)
    plt.show()

    # Attribution map with a diverging colour map and its colour scale.
    heatmap = a_batch.cpu().data[sample_idx]
    plt.imshow(heatmap, cmap="seismic")
    plt.colorbar()
    plt.show()


### Option 1. Evaluate the robustness of attributions in one line of code.

In [14]:
# Score the robustness of the precomputed attributions: configure the metric,
# then call it on the batch (inputs are handed over as NumPy arrays).
robustness_test = RobustnessTest(
    similarity_func=lipschitz_constant,
    perturb_func=gaussian_noise,
)
scores = robustness_test(
    model=model,
    x_batch=x_batch.cpu().numpy(),
    y_batch=y_batch.cpu().numpy(),
    a_batch=a_batch.cpu().numpy(),
    device=device,
    explanation_func="Saliency",
)

scores

[autoreload of drive.MyDrive.Projects.xai_quantification_toolbox.xai_quantification_toolbox.helpers.explanation_func failed: Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/IPython/extensions/autoreload.py", line 247, in check
    superreload(m, reload, self.old_objects)
TypeError: Union[arg, ...]: each arg must be a type. Got <module 'torchvision.models' from '/usr/local/lib/python3.7/dist-packages/torchvision/models/__init_.
]


[36.15371720125984,
 21.459338637757334,
 43.42919743262423,
 37.08792479182698,
 13.523527100876139,
 20.94329984881238,
 44.70699158950547,
 23.67869291188523,
 29.25100910964856,
 24.647430136850215,
 50.885522424134564,
 27.256833951635276,
 35.857669598139225,
 23.510530839622227,
 38.806234502506776,
 22.85824393740396,
 25.52348039434638,
 37.083352979412076,
 23.802950806915256,
 35.60551562569577,
 12.987899146346574,
 41.11055848110009,
 32.444884071396245,
 23.838824690429323,
 22.463970294103255,
 38.40506339860404,
 9.806916670998078,
 20.191377891336156,
 18.814474579961832,
 21.044851109846004,
 31.120535577480528,
 20.508416900853167,
 28.96616358275558,
 28.041295090252827,
 32.706932361397975,
 28.4013933342942,
 23.71318363130131,
 30.811142791578604,
 33.38204749130086,
 31.117921891874843,
 39.05896825072428,
 27.431349474539743,
 37.06734211738538,
 26.39497906634696,
 41.276337324742194,
 24.21098467976126,
 32.66730895680301,
 40.45729727541629,
 26.440908447079

In [40]:

# Score the continuity of the precomputed attributions: configure the metric,
# then call it on the batch (inputs are handed over as NumPy arrays).
continuity_test = ContinuityTest(
    similarity_func=correlation_spearman,
    perturb_func=translation_x_direction,
    nr_patches=4,
    nr_steps=10,
)
scores = continuity_test(
    model=model,
    x_batch=x_batch.cpu().numpy(),
    y_batch=y_batch.cpu().numpy(),
    a_batch=a_batch.cpu().numpy(),
    device=device,
    explanation_func="Saliency",
)

scores

[0.4878787878787878,
 0.6,
 0.7999999999999999,
 0.5878787878787878,
 0.4818181818181818,
 0.6242424242424243,
 0.8272727272727272,
 0.5303030303030303,
 0.7727272727272727,
 0.8545454545454544,
 0.7030303030303029,
 0.7818181818181817,
 0.7303030303030302,
 0.5969696969696969,
 0.6909090909090908,
 0.7424242424242424,
 0.6303030303030303,
 0.9151515151515152,
 0.33030303030303027,
 0.718181818181818,
 0.5424242424242424,
 0.39393939393939387,
 0.7151515151515151,
 0.3424242424242424,
 0.903030303030303,
 0.8454545454545455,
 0.3848484848484847,
 0.5393939393939393,
 0.506060606060606,
 0.2696969696969697,
 0.4969696969696969,
 0.796969696969697,
 0.8303030303030303,
 0.5909090909090909,
 0.35454545454545455,
 0.34545454545454546,
 0.49999999999999994,
 0.5393939393939393,
 0.509090909090909,
 0.5333333333333333,
 0.7909090909090908,
 0.8303030303030302,
 0.7757575757575756,
 0.718181818181818,
 0.6757575757575757,
 0.812121212121212,
 0.7424242424242422,
 0.5878787878787878,
 0.739393

In [43]:
# Score the input independence of the precomputed attributions: configure the
# metric, then call it on the batch (inputs are handed over as NumPy arrays).
input_independence_rate = InputIndependenceRate(
    similarity_func=abs_difference,
    perturb_func=optimization_scheme,
    std=0.01,
)
scores = input_independence_rate(
    model=model,
    x_batch=x_batch.cpu().numpy(),
    y_batch=y_batch.cpu().numpy(),
    a_batch=a_batch.cpu().numpy(),
    device=device,
    explanation_func="Saliency",
)

scores

1.0

In [44]:
# Estimate the local Lipschitz constant of the precomputed attributions:
# configure the metric, then call it on the batch (inputs as NumPy arrays).
local_lipschitz = EstimatedLocalLipschitzConstant(
    similarity_func=lipschitz_constant,
    perturb_func=gaussian_noise,
    distance_numerator=distance_euclidean,
    distance_denominator=distance_euclidean,
    perturb_std=0.1,
    nr_steps=10,
)
scores = local_lipschitz(
    model=model,
    x_batch=x_batch.cpu().numpy(),
    y_batch=y_batch.cpu().numpy(),
    a_batch=a_batch.cpu().numpy(),
    device=device,
    explanation_func="Saliency",
)

scores

[37.30127575118598,
 22.121814107683537,
 45.067826177396114,
 39.232404747828,
 14.437890142892192,
 22.202293938019196,
 45.9585545621177,
 24.456367302295472,
 30.57050645608584,
 26.44729388104142,
 54.052140853975125,
 28.8199717935829,
 37.18822060425881,
 24.189395625404025,
 40.529155887407825,
 24.487054122209425,
 26.603376709913938,
 37.90717359248646,
 25.241728448144286,
 36.94950261647965,
 14.21280285975028,
 41.24405426106984,
 34.78548224497834,
 24.558473536604215,
 22.712949325246996,
 39.56552581615689,
 10.481931323168332,
 21.126266072565624,
 19.233047753960527,
 22.323757902618798,
 33.82471535585488,
 22.653545378826532,
 31.47558044439554,
 28.46441162660213,
 36.54705731638505,
 30.017435091098957,
 24.35241438395752,
 33.228235072373764,
 34.68023558398905,
 32.4090633205909,
 42.98214852031414,
 30.940671220760954,
 38.26997863420922,
 27.307556240335924,
 42.498760145207825,
 25.67364961895835,
 35.4298964453991,
 43.074363599307304,
 29.72153336288962,
 1

In [67]:
# Measure the max-sensitivity of the precomputed attributions: configure the
# metric, then call it on the batch (inputs are handed over as NumPy arrays).
sensitivity_max_metric = SensitivityMax(
    similarity_func=difference,
    perturb_func=uniform_sampling,
    norm_numerator=fro_norm,
    norm_denominator=fro_norm,
    perturb_radius=0.02,
    nr_steps=10,
)
scores = sensitivity_max_metric(
    model=model,
    x_batch=x_batch.cpu().numpy(),
    y_batch=y_batch.cpu().numpy(),
    a_batch=a_batch.cpu().numpy(),
    device=device,
    explanation_func="Saliency",
)

scores

[0.0137402145,
 0.0064482414,
 0.014117877,
 0.016248235,
 0.0034813983,
 0.0076988307,
 0.015668634,
 0.008358887,
 0.007957364,
 0.0075780484,
 0.015087734,
 0.008030167,
 0.01583539,
 0.009883713,
 0.020468712,
 0.009848966,
 0.006580384,
 0.016781868,
 0.010596507,
 0.01325135,
 0.003844463,
 0.015161874,
 0.011751871,
 0.005740982,
 0.0074924245,
 0.017333066,
 0.002706792,
 0.007640656,
 0.006835906,
 0.010460794,
 0.009770189,
 0.012336688,
 0.011851529,
 0.009166229,
 0.011856076,
 0.011575303,
 0.011354601,
 0.010320113,
 0.011020264,
 0.009523336,
 0.014258611,
 0.009800352,
 0.015084063,
 0.010181025,
 0.013268139,
 0.009647754,
 0.022006867,
 0.01575007,
 0.010010494,
 0.004446627,
 0.009163492,
 0.011745132,
 0.009597097,
 0.012944428,
 0.007821564,
 0.009619963,
 0.0071163746,
 0.01375225,
 0.010869938,
 0.012274247,
 0.014073269,
 0.02050844,
 0.019222727,
 0.0071925684]

In [None]:

from copy import deepcopy
from inspect import signature
from typing import Any, Callable, Optional, Tuple, Union, cast

import torch
from torch import Tensor

from captum._utils.common import (
    _expand_and_update_additional_forward_args,
    _expand_and_update_baselines,
    _expand_and_update_target,
    _format_baseline,
    _format_input,
    _format_tensor_into_tuples,
)
from captum._utils.typing import TensorOrTupleOfTensorsGeneric
from captum.log import log_usage
from captum.metrics._core.sensitivity import default_perturb_func
from captum.metrics._utils.batching import _divide_and_aggregate_metrics


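# NOTE(review): the function below appears to be adapted from Captum's
# `sensitivity_max`; the `@log_usage()` decorator is left disabled here.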

def sensitivity_max(
    explanation_func: Callable,
    inputs: TensorOrTupleOfTensorsGeneric,
    perturb_func: Callable = default_perturb_func,
    perturb_radius: float = 0.02,
    n_perturb_samples: int = 10,
    norm_ord: Union[int, float, str] = "fro",
    max_examples_per_batch: Optional[int] = None,
    **kwargs: Any,
) -> Tensor:
    r"""
    Explanation sensitivity measures the extent of explanation change when
    the input is slightly perturbed. It has been shown that the models that
    have high explanation sensitivity are prone to adversarial attacks:
    `Interpretation of Neural Networks is Fragile`
    https://www.aaai.org/ojs/index.php/AAAI/article/view/4252

    `sensitivity_max` metric measures maximum sensitivity of an explanation
    using Monte Carlo sampling-based approximation. By default in order to
    do so it samples multiple data points from a sub-space of an L-Infinity
    ball that has a `perturb_radius` radius using `default_perturb_func`
    default perturbation function. In a general case users can
    use any L_p ball or any other custom sampling technique that they
    prefer by providing a custom `perturb_func`.

    Note that max sensitivity is similar to Lipschitz Continuity metric
    however it is more robust and easier to estimate.
    Since the explanation, for instance an attribution function,
    may not always be continuous, it can lead to unbounded
    Lipschitz continuity. Therefore the latter isn't always appropriate.

    More about the Lipschitz Continuity Metric can also be found here
    `On the Robustness of Interpretability Methods`
    https://arxiv.org/pdf/1806.08049.pdf
    and
    `Towards Robust Interpretability with Self-Explaining Neural Networks`
    https://papers.nips.cc/paper/8003-towards-robust-interpretability-with-self-explaining-neural-networks.pdf

    More details about sensitivity max can be found here:
    `On the (In)fidelity and Sensitivity of Explanations`
    https://arxiv.org/pdf/1901.09392.pdf

    Args:

        explanation_func (callable):
                This function can be the `attribute` method of an
                attribution algorithm or any other explanation method
                that returns the explanations. It may return either a
                single tensor or a tuple of tensors; a single tensor is
                wrapped into a one-element tuple internally.

        inputs (tensor or tuple of tensors):  Input for which
                explanations are computed. If `explanation_func` takes a
                single tensor as input, a single input tensor should
                be provided.
                If `explanation_func` takes multiple tensors as input, a tuple
                of the input tensors should be provided. It is assumed
                that for all given input tensors, dimension 0 corresponds
                to the number of examples (aka batch size), and if
                multiple input tensors are provided, the examples must
                be aligned appropriately.

        perturb_func (callable):
                The perturbation function of model inputs. This function takes
                model inputs and optionally `perturb_radius` if
                the function takes more than one argument and returns
                perturbed inputs.

                If there are more than one inputs passed to sensitivity function those
                will be passed to `perturb_func` as tuples in the same order as they
                are passed to sensitivity function.

                It is important to note that for performance reasons `perturb_func`
                isn't called for each example individually but on a batch of
                input examples that are repeated `max_examples_per_batch / batch_size`
                times within the batch.

            Default: default_perturb_func
        perturb_radius (float, optional): The epsilon radius used for sampling.
            In the `default_perturb_func` it is used as the radius of
            the L-Infinity ball. In a general case it can serve as a radius of
            any L_p norm.
            This argument is passed to `perturb_func` if it takes more than
            one argument.

            Default: 0.02
        n_perturb_samples (int, optional): The number of times input tensors
                are perturbed. Each input example in the inputs tensor is
                expanded `n_perturb_samples` times before calling
                `perturb_func` function.

                Default: 10
        norm_ord (int, float, inf, -inf, 'fro', 'nuc', optional): The type of norm
                that is used to compute the
                norm of the sensitivity matrix which is defined as the difference
                between the explanation function at its input and perturbed input.

                Default: 'fro'
        max_examples_per_batch (int, optional): The number of maximum input
                examples that are processed together. In case the number of
                examples (`input batch size * n_perturb_samples`) exceeds
                `max_examples_per_batch`, they will be sliced
                into batches of `max_examples_per_batch` examples and processed
                in a sequential order. If `max_examples_per_batch` is None, all
                examples are processed together. `max_examples_per_batch` should
                at least be equal `input batch size` and at most
                `input batch size * n_perturb_samples`.

                Default: None
         **kwargs (Any, optional): Contains a list of arguments that are passed
                to `explanation_func` explanation function which in some cases
                could be the `attribute` function of an attribution algorithm.
                Any additional arguments that need be passed to the explanation
                function should be included here.
                For instance, such arguments include:
                `additional_forward_args`, `baselines` and `target`.

    Returns:

        sensitivities (tensor): A tensor holding one sensitivity score per
               input example. Within each processed chunk the score of an
               example is the maximum over its perturbations, and chunks
               are combined with `torch.max`. Returned sensitivities are
               normalized by the magnitudes of the input explanations.

    Examples::
        >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
        >>> # and returns an Nx10 tensor of class probabilities.
        >>> net = ImageClassifier()
        >>> saliency = Saliency(net)
        >>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
        >>> # Computes sensitivity score for saliency maps of class 3
        >>> sens = sensitivity_max(saliency.attribute, input, target = 3)

    """

    def _generate_perturbations(
        current_n_perturb_samples: int,
    ) -> TensorOrTupleOfTensorsGeneric:
        r"""
        The perturbations are generated for each example
        `current_n_perturb_samples` times.

        For performance reasons we are not calling `perturb_func` on each example but
        on a batch that contains `current_n_perturb_samples` repeated instances
        per example.
        """
        inputs_expanded: Union[Tensor, Tuple[Tensor, ...]] = tuple(
            torch.repeat_interleave(input, current_n_perturb_samples, dim=0)
            for input in inputs
        )
        # A single input is handed to `perturb_func` as a bare tensor.
        if len(inputs_expanded) == 1:
            inputs_expanded = inputs_expanded[0]

        # Only pass the radius along when `perturb_func` accepts a second argument.
        return (
            perturb_func(inputs_expanded, perturb_radius)
            if len(signature(perturb_func).parameters) > 1
            else perturb_func(inputs_expanded)
        )

    def max_values(input_tnsr: Tensor) -> Tensor:
        """Return the row-wise maximum (over dim 1) of `input_tnsr`."""
        return torch.max(input_tnsr, dim=1).values  # type: ignore

    # Cache of the expanded kwargs, keyed by the perturbation count they were
    # expanded for, so they are only rebuilt when the chunk size changes.
    kwarg_expanded_for = None
    kwargs_copy: Any = None

    def _next_sensitivity_max(current_n_perturb_samples: int) -> Tensor:
        """Compute the per-example max sensitivity for one chunk of perturbations."""
        inputs_perturbed = _generate_perturbations(current_n_perturb_samples)

        # copy kwargs and update some of the arguments that need to be expanded
        nonlocal kwarg_expanded_for
        nonlocal kwargs_copy
        if (
            kwarg_expanded_for is None
            or kwarg_expanded_for != current_n_perturb_samples
        ):
            kwarg_expanded_for = current_n_perturb_samples
            kwargs_copy = deepcopy(kwargs)
            _expand_and_update_additional_forward_args(
                current_n_perturb_samples, kwargs_copy
            )
            _expand_and_update_target(current_n_perturb_samples, kwargs_copy)
            if "baselines" in kwargs:
                baselines = kwargs["baselines"]
                baselines = _format_baseline(
                    baselines, cast(Tuple[Tensor, ...], inputs)
                )
                # Only per-example baselines (same shape as the inputs) are expanded.
                if (
                    isinstance(baselines[0], Tensor)
                    and baselines[0].shape == inputs[0].shape
                ):
                    _expand_and_update_baselines(
                        cast(Tuple[Tensor, ...], inputs),
                        current_n_perturb_samples,
                        kwargs_copy,
                    )

        expl_perturbed_inputs = explanation_func(inputs_perturbed, **kwargs_copy)

        # tuplize `expl_perturbed_inputs` in case it is not
        expl_perturbed_inputs = _format_tensor_into_tuples(expl_perturbed_inputs)

        # Repeat the unperturbed explanations so they line up with the perturbed ones.
        expl_inputs_expanded = tuple(
            expl_input.repeat_interleave(current_n_perturb_samples, dim=0)
            for expl_input in expl_inputs
        )

        # Flatten and concatenate the explanation differences of all inputs.
        sensitivities = torch.cat(
            [
                (expl_input - expl_perturbed).view(expl_perturbed.size(0), -1)
                for expl_perturbed, expl_input in zip(
                    expl_perturbed_inputs, expl_inputs_expanded
                )
            ],
            dim=1,
        )
        # compute the norm of original input explanations
        expl_inputs_norm_expanded = torch.norm(
            torch.cat(
                [expl_input.view(expl_input.size(0), -1) for expl_input in expl_inputs],
                dim=1,
            ),
            p=norm_ord,
            dim=1,
            keepdim=True,
        ).repeat_interleave(current_n_perturb_samples, dim=0)
        # Replace zero norms by one to avoid dividing by zero below.
        expl_inputs_norm_expanded = torch.where(
            expl_inputs_norm_expanded == 0.0,
            torch.tensor(
                1.0,
                device=expl_inputs_norm_expanded.device,
                dtype=expl_inputs_norm_expanded.dtype,
            ),
            expl_inputs_norm_expanded,
        )

        # compute the norm for each input noisy example
        sensitivities_norm = (
            torch.norm(sensitivities, p=norm_ord, dim=1, keepdim=True)
            / expl_inputs_norm_expanded
        )
        return max_values(sensitivities_norm.view(bsz, -1))

    inputs = _format_input(inputs)  # type: ignore

    bsz = inputs[0].size(0)

    with torch.no_grad():
        expl_inputs = explanation_func(inputs, **kwargs)
        # An explanation function may return a bare tensor. Wrap it so the
        # per-input loops above iterate over inputs, not over batch rows
        # (mirrors the tuplizing of the perturbed explanations).
        if isinstance(expl_inputs, Tensor):
            expl_inputs = (expl_inputs,)
        metrics_max = _divide_and_aggregate_metrics(
            cast(Tuple[Tensor, ...], inputs),
            n_perturb_samples,
            _next_sensitivity_max,
            max_examples_per_batch=max_examples_per_batch,
            agg_func=torch.max,
        )
    return metrics_max

### Option 2. Evaluate the robustness of provided attributions while enjoying more functionality of Quantifier and Plotting.

In [None]:
# TODO: Provide separate notebooks for the different use cases:
# comparing models, comparing XAI methods and comparing different measures.

In [None]:
# Specify the tests.
# The keyword names match those used by the one-liner cells above
# (`similarity_func` / `perturb_func`).
# NOTE(review): `gaussian_blur` and `cosine` are expected to come from the
# toolbox star-import; confirm that they exist.
tests = [RobustnessTest(**{
    "similarity_func": similarity_fn,
    "perturb_func": gaussian_blur,
}) for similarity_fn in [lipschitz_constant, distance_euclidean, cosine]]

# Attributions of the Saliency method computed earlier in this notebook.
a_batch_saliency = a_batch

# Load attributions of another explanation method.
# Captum's keyword is `target` (singular); inputs and targets are moved to the
# device the model was moved to earlier in this notebook.
a_batch_intgrad = IntegratedGradients(model).attribute(inputs=x_batch.to(device), target=y_batch.to(device))

# Init the quantifier object.
# TODO: replace "PATH_TO_H5PY_FILE" and the `...` placeholder for `checkpoints`
# with real values before running this cell.
quantifier = Quantifier(measures=tests, io_object=h5py.File("PATH_TO_H5PY_FILE"), checkpoints=...)

# Score the tests.
results = [quantifier.score(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=attributions)
           for attributions in [a_batch_saliency, a_batch_intgrad]]

# Plot Saliency vs Integrated Gradients.
Plotting(results, show=False, path_to_save="PATH_TO_SAVE_FIGURE")