## Example - RobustnessTest

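This notebook demonstrates the robustness tests (`RobustnessTest` and related metrics) and the faithfulness tests of the XAI quantification toolbox on a batch of images.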

In [1]:
# Mount Google Drive under /content/drive; later cells import the toolbox and
# read the image folder from that mount point, so this cell has to run first.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [20]:
!pip install captum
!pip install opencv-python

import torch
import torchvision
from torchvision import transforms
import numpy as np
import h5py
from tqdm import tqdm
from captum.attr import Saliency, IntegratedGradients
from pathlib import Path
import warnings

# Retrieve source code.
from drive.MyDrive.Projects.xai_quantification_toolbox import * #import xaiquantificationtoolbox

# Notebook settings.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
warnings.filterwarnings("ignore", category=UserWarning)
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
import gc
# Run a garbage-collection pass first, then ask PyTorch to empty its CUDA cache.
gc.collect()
torch.cuda.empty_cache()

### Load model, data and attributions.

In [None]:
# Load pre-trained ResNet18 model and put it in evaluation mode.
model = torchvision.models.resnet18(pretrained=True)
model.eval()

# Load test data and loaders.
# Images are resized, centre-cropped to 224x224, converted to tensors and
# normalised channel-wise with the mean/std values given below.
test_set = torchvision.datasets.ImageFolder(
    root='/content/drive/My Drive/imagenet_images',
    transform=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
)
test_loader = torch.utils.data.DataLoader(test_set, shuffle=True, batch_size=8)


# Evaluate model performance.
#predictions, labels = evaluate_model(model.to(device), data=test_loader, device=device)
#print(f"\nModel test accuracy: {(100 * score_model(predictions, labels)):.2f}%")

# Load data, targets and attributions.
# Use the built-in next(): iterators are only guaranteed to implement __next__,
# a `.next()` method is not part of the iterator protocol.
x_batch, y_batch = next(iter(test_loader))
a_batch = explain(model.to(device), x_batch.to(device), y_batch.to(device), explanation_func="Saliency")

In [4]:
"""
# Plot some explanations!
import matplotlib.pyplot as plt

for i in range(20, 30): #[4140, 2091, 78, 1195]: 
    plt.imshow(denormalize_image(x_batch.cpu().data[i]).transpose(0, 1).transpose(1, 2))
    plt.show()
    plt.imshow(a_batch.cpu().data[i], cmap="seismic")
    plt.colorbar()
    plt.show()
""";

### Robustness tests

In [None]:
# One-liner to measure robustness of provided attributions.
scores = RobustnessTest(**{
    "similarity_func": lipschitz_constant,
    "perturb_func": gaussian_noise,
})(model=model, 
   x_batch=x_batch.cpu().numpy(), 
   y_batch=y_batch.cpu().numpy(), 
   a_batch=a_batch.cpu().numpy(), 
   **{"explanation_func": "Saliency", "device": device})

scores

In [None]:
# One-liner to measure continuity of provided attributions.
scores = ContinuityTest(**{
    "similarity_func": correlation_spearman,
    "perturb_func": translation_x_direction,
    "nr_patches": 4,
    "nr_steps": 10,
})(model=model, 
   x_batch=x_batch.cpu().numpy(),
   y_batch=y_batch.cpu().numpy(),
   a_batch=a_batch.cpu().numpy(),
   **{"explanation_func": "Saliency", "device": device})

scores

In [None]:
# One-liner to measure input independence of provided attributions.
scores = InputIndependenceRate(**{
    "similarity_func": abs_difference,
    "perturb_func": optimization_scheme, # TODO.
    "perturb_std": 0.01,
    "threshold": 0.1,
})(model=model, 
   x_batch=x_batch.cpu().numpy(), 
   y_batch=y_batch.cpu().numpy(), 
   a_batch=a_batch.cpu().numpy(), 
   **{"explanation_func": "Saliency", "device": device})

scores

In [None]:
# One-liner to measure local lipschitz constant of provided attributions.
scores = LocalLipschitzEstimate(**{
    "similarity_func": lipschitz_constant,
    "perturb_func": gaussian_noise,
    "distance_numerator": distance_euclidean,
    "distance_denominator": distance_euclidean,
    "perturb_std": 0.1,
    "nr_steps": 10, #200
})(model=model, 
   x_batch=x_batch.cpu().numpy(), 
   y_batch=y_batch.cpu().numpy(), 
   a_batch=a_batch.cpu().numpy(), 
   **{"explanation_func": "Saliency", "device": device})

scores

In [None]:
# One-liner to measure local lipschitz constant of provided attributions.
scores = SensitivityMax(**{
    "similarity_func": difference,
    "perturb_func": uniform_sampling,
    "norm_numerator": fro_norm,
    "norm_denominator": fro_norm,
    "perturb_radius": 0.2,
    "nr_steps": 10,
})(model=model, 
   x_batch=x_batch.cpu().numpy(), 
   y_batch=y_batch.cpu().numpy(), 
   a_batch=a_batch.cpu().numpy(), 
   **{"explanation_func": "Saliency", "device": device})
scores

### Faithfulness tests

In [None]:
# One-liner to for faithfulness base class of provided attributions.
scores = FaithfulnessTest(**{
    "perturb_func": baseline_replacement_by_indices,
    "similarity_func": correlation_spearman,
    "perturb_baseline": 0.0,  
    "pixels_in_step": 128,
})(model=model, 
   x_batch=x_batch.cpu().numpy(), 
   y_batch=y_batch.cpu().numpy(), 
   a_batch=a_batch.cpu().numpy(), 
   **{"explanation_func": "Saliency", "device": device})

scores

In [None]:
# One-liner to measure faithfulness estimate of provided attributions.
scores = FaithfulnessEstimate(**{
    "perturb_func": replacement_by_indices,
    "similarity_func": correlation_pearson,
    "perturb_baseline": 0.0,  
    "pixels_in_step": 8,
})(model=model, 
   x_batch=x_batch.cpu().numpy(),
   y_batch=y_batch.cpu().numpy(), 
   a_batch=a_batch.cpu().numpy(), 
   **{"explanation_func": "Saliency", "device": device})

scores

In [None]:
# One-liner to measure infidelity of provided attributions.
scores = Infidelity(**{
    "perturb_func": baseline_replacement_by_patch,
    "similarity_func": mse,
    "perturb_baseline": "black",  
    "perturb_patch_sizes": [14, 28] #list(np.arange(10,30)),
})(model=model, 
  x_batch=x_batch.cpu().numpy(), 
  y_batch=y_batch.cpu().numpy(), 
  a_batch=a_batch.cpu().numpy(), 
  **{"explanation_func": "Saliency", "device": device})

scores

In [None]:

class MonotonicityMetric(FaithfulnessTest):
    """
    Implementation of Monotonicity Metric by Luss et al., 2019.

    It captures attributions' faithfulness by incrementally adding each attribute
    in order of increasing importance and evaluating the effect on model performance.
    As more features are added, the performance of the model is expected to increase
    and thus result in monotonically increasing model performance.

    For every sample the metric starts from an input in which all attributed
    positions are replaced through 'perturb_func', then restores them in chunks of
    'pixels_in_step' (least important first) and records the model output for the
    target class after each chunk. The score of a sample is True when the recorded
    outputs never decrease.

    References:
        Luss, Ronny, et al. "Generating contrastive explanations with monotonic attribute functions."
        arXiv preprint arXiv:1905.12698 (2019).

    Current assumptions:
        • Inputs can be reshaped to (1, nr_channels, img_size, img_size).
        • 'perturb_func' accepts the keywords 'img', 'index' and 'perturb_baseline' and
          returns the perturbed (flattened) input.
        • 'similarity_func' is stored for interface parity with the other tests but is
          not used by this metric.
    """

    def __init__(self, *args, **kwargs):
        """
        Store the configuration of the metric.

        Recognised keyword arguments (all optional): 'similarity_func', 'perturb_func',
        'perturb_baseline', 'img_size', 'nr_channels', 'pixels_in_step' and
        'max_steps_per_input'. When 'max_steps_per_input' is given it overrides
        'pixels_in_step'.

        Raises:
            AssertionError: if 'pixels_in_step' or 'max_steps_per_input' does not
                divide img_size * img_size.
        """
        self.args = args
        self.kwargs = kwargs
        self.similarity_func = self.kwargs.get("similarity_func", correlation_pearson)
        self.perturb_func = self.kwargs.get("perturb_func", baseline_replacement_by_indices)
        self.perturb_baseline = self.kwargs.get("perturb_baseline", 0.0)

        self.img_size = self.kwargs.get("img_size", 224)
        self.nr_channels = self.kwargs.get("nr_channels", 3)

        nr_pixels = self.img_size * self.img_size

        self.pixels_in_step = self.kwargs.get("pixels_in_step", 1)
        assert (
            nr_pixels % self.pixels_in_step == 0
        ), "Set 'pixels_in_step' so that the modulo remainder returns 0 given the image size."
        self.max_steps_per_input = self.kwargs.get("max_steps_per_input", None)

        if self.max_steps_per_input is not None:
            assert (
                nr_pixels % self.max_steps_per_input == 0
            ), "Set 'max_steps_per_input' so that the modulo remainder returns 0 given the image size."
            # Integer division: the value is used as a slice step/offset and must be an int.
            self.pixels_in_step = nr_pixels // self.max_steps_per_input

        # NOTE(review): super(FaithfulnessTest, self) starts the lookup *after*
        # FaithfulnessTest in the MRO, i.e. FaithfulnessTest.__init__ is bypassed.
        # Kept as in the original; confirm that this is intended.
        super(FaithfulnessTest, self).__init__()

    def __call__(
            self,
            model,
            x_batch: np.array,
            y_batch: Union[np.array, int],
            a_batch: Union[np.array, None],
            **kwargs
    ):
        """
        Score the monotonicity of the attributions of every sample in the batch.

        Parameters:
            model: callable torch model; it is applied to inputs reshaped to
                (1, nr_channels, img_size, img_size).
            x_batch: batch of inputs.
            y_batch: target index per sample, or a single int used for all samples.
            a_batch: attributions per sample; when None they are computed with 'explain'.
            kwargs: must contain 'explanation_func'; 'device' is optional.

        Returns:
            list of bool, one entry per sample: True when the model output for the
            target class is non-decreasing while features are added in order of
            increasing absolute attribution.

        Raises:
            AssertionError: if 'explanation_func' is missing or the number of inputs
                and attributions differs.
        """
        assert (
                "explanation_func" in kwargs
        ), "To run MonotonicityMetric specify 'explanation_func' (str) e.g., 'Gradient'."

        device = kwargs.get("device", None)

        # Compute the attributions first when none were provided; the sample-count
        # check below needs an actual array.
        if a_batch is None:
            a_batch = explain(
                model.to(device),
                x_batch,
                y_batch,
                explanation_func=kwargs.get("explanation_func", "Gradient"),
                device=device,
            )
            # The notebook converts the result of 'explain' with .cpu().numpy(), so a
            # tensor may come back here; convert it for the numpy code below.
            if torch.is_tensor(a_batch):
                a_batch = a_batch.detach().cpu().numpy()

        assert (
                np.shape(x_batch)[0] == np.shape(a_batch)[0]
        ), "Inputs and attributions should include the same number of samples."

        # A single target index is broadcast to every sample of the batch.
        if np.ndim(y_batch) == 0:
            y_batch = np.full(np.shape(x_batch)[0], y_batch)

        step = self.pixels_in_step
        results = []

        for x, y, a in zip(x_batch, y_batch, a_batch):

            # Get indices of sorted absolute attributions (ascending: least important first).
            a_indices = np.argsort(np.abs(a.flatten()))

            nr_steps = int(np.ceil(len(a_indices) / step))
            preds = np.zeros(nr_steps)

            for i_ix in range(nr_steps):

                # After step i_ix the (i_ix + 1) * step least important features are
                # present; every remaining index is still replaced by the baseline.
                # NOTE(review): indices come from the flattened attribution but are
                # applied to the flattened input; confirm that 'perturb_func' maps
                # them as intended when both differ in size.
                not_yet_added = a_indices[step * (i_ix + 1):]
                x_perturbed = self.perturb_func(img=x.flatten(),
                                                index=not_yet_added,
                                                perturb_baseline=self.perturb_baseline)

                # Predict on perturbed input x.
                with torch.no_grad():
                    y_pred_i = float(model(
                        torch.Tensor(x_perturbed)
                            .reshape(1, self.nr_channels, self.img_size, self.img_size)
                            .to(device))[:, y])
                preds[i_ix] = y_pred_i

            # Monotonic when the prediction never drops as features are added.
            results.append(bool(np.all(np.diff(preds) >= 0)))

        return results

In [None]:
# One-liner to measure faithfulness estimate of provided attributions.
scores = MonotonicityMetric(**{
    "perturb_func": replacement_by_indices,
    "similarity_func": correlation_pearson,
    "perturb_baseline": 0.0,  
    "pixels_in_step": 8,
})(model=model, 
   x_batch=x_batch.cpu().numpy(),
   y_batch=y_batch.cpu().numpy(), 
   a_batch=a_batch.cpu().numpy(), 
   **{"explanation_func": "Saliency", "device": device})

scores

### Option 1. Evaluate the robustness of attributions in one line of code.

### Option 2. Evaluate the robustness of provided attributions while enjoying more functionality of Quantifier and Plotting.

In [None]:
# TODO: Provide notebooks for the different use cases: comparing models, XAI methods and measures.
# ...

In [None]:
# Specify the tests.
tests = [RobustnessTest(**{
    "similarity_function": similarity_fn,
    "perturbation_function": gaussian_blur,
}) for similarity_fn in [lipschitz_constant, distance_euclidean, cosine]]

# Load attributions of another explanation method.
a_batch_intgrad = IntegratedGradients(model).attribute(inputs=x_batch, targets=y_batch)

# Init the quantifier object.
quantifier = Quantifier(measures=tests, io_object=h5py.File("PATH_TO_H5PY_FILE"), checkpoints=..)

# Score the tests.
results = [quantifier.score(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch)
           for a_batch in [a_batch_saliency, a_batch_intgrad]]

# Plot Saliency vs Integrated Gradients.
Plotting(results, show=False, path_to_save="PATH_TO_SAVE_FIGURE")