# Getting started

This notebook shows how to get started with Quantus using TensorFlow.

In [1]:
# Imports general.
from __future__ import annotations
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
import tf_explain

import quantus
from quantus.metrics import *

## 1. Preliminaries

### 1.1 Load datasets

We load a batch of input-output pairs, for which we will generate explanations and then evaluate them.

In [3]:
# Load MNIST (train and test splits) as (image, label) pairs, together with the
# dataset metadata. `ds_info` is needed below to size the training shuffle buffer,
# and `ds_train` feeds both the input plots and model training.
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
    try_gcs=True,
)

In [None]:
# Take the first element of the test set, converted to NumPy arrays.
x_batch, y_batch = next(ds_test.take(1).as_numpy_iterator())

In [None]:
# Plot some inputs!
# Iterate with dedicated names so the notebook-level `x_batch` / `y_batch` are not clobbered.
nr_images = 5
fig, axes = plt.subplots(nrows=1, ncols=nr_images, figsize=(nr_images*3, int(nr_images*2/3)))
for ax, (image, label) in zip(axes, ds_train.take(nr_images)):
    # Pixels are still raw `uint8` here (normalisation to float happens in the pipeline
    # cell below), so plot them on their native 0..255 range. Multiplying uint8 data
    # by 255 would overflow.
    ax.imshow(np.reshape(image.numpy(), (28, 28)), vmin=0, vmax=255, cmap="gray")
    ax.title.set_text(f"MNIST class - {label.numpy()}")
    ax.axis("off")

plt.show()

In [None]:
# Build a training pipeline

def normalize_img(image, label):
    """Normalizes images: `uint8` -> `float32`.

    Casts the image to float32, scales pixel values by 1/255 and inserts a new
    size-1 axis at position 3. The label is returned unchanged.
    """
    # NOTE(review): if the dataset images already carry a channel axis, this adds a
    # second singleton axis (a later cell slices axis 3 at index 0) — confirm shapes.
    return tf.expand_dims(tf.cast(image, tf.float32), axis=3) / 255., label


def to_rgb(image, label):
    """Converts a single-channel image to 3-channel RGB; the label is passed through."""
    return tf.image.grayscale_to_rgb(image), label


ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
# Shuffle buffer spans the whole training split for a full shuffle each epoch.
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

In [None]:
# Build an evaluation pipeline: normalise, batch, cache, then prefetch.
ds_test = (
    ds_test
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(128)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Pull the first test batch and convert both tensors to NumPy arrays.
x_batch, y_batch = (tensor.numpy() for tensor in next(iter(ds_test)))
# NOTE(review): keeps index 0 of axis 3, dropping one singleton axis (cf. the
# expand_dims in normalize_img) — confirm the resulting shape is what the model,
# the explainer and the Quantus metrics expect.
x_batch = x_batch[:, :, :, 0]

### 1.2 Train a model

In [None]:
# A small fully-connected classifier: flatten -> 128 ReLU units -> 10 logits.
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    # The final Dense layer has no activation, so the loss is computed from raw logits.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)

# `model.summary()` prints the table itself and returns None; interpolating it
# into an f-string would print "Model architecture: None".
print("\nModel architecture:")
model.summary()

### 1.3 Generate explanations

There are multiple ways to generate explanations for neural network models, e.g., using the `captum`, `innvestigate` or `tf-explain` libraries. In this example, we generate the attributions with `tf_explain`, and pass `quantus.explain` as the explanation function to metrics that need to re-compute explanations. You can, however, use whatever approach or library you like to create your explanations.

**Requirements.**

* **Data type.** Like the x-y pairs, the attributions should be of type `np.ndarray`.
* **Shape.** The attributions should have the same dimensions as the input, except for the number of channels, which is 1 for explanations. For example, if `x_batch` has shape (128, 3, 224, 224), then the attributions should have shape (128, 1, 224, 224).

In [None]:
# Generate Integrated Gradients attributions of the first batch of the test set,
# one sample at a time: each input `x` is wrapped as a batch of one (`[x]`) and its
# label `y` is passed as the class to explain (presumably tf_explain's class index).
a_batch_intgrad = np.array(
    [
        tf_explain.core.integrated_gradients.IntegratedGradients().explain(
            ([x], None), model, y, n_steps=10
        )
        for x, y in zip(x_batch, y_batch)
    ],
    dtype=float,
) / 255  # NOTE(review): presumably rescales tf_explain's 0..255 heatmaps to [0, 1] — confirm.

Visualise the attributions next to their corresponding input images.

In [None]:
# Plot explanations!
# Left column: input image; right column: its Integrated Gradients attribution.
nr_images = 3
fig, axes = plt.subplots(nrows=nr_images, ncols=2, figsize=(nr_images*2.5, int(nr_images*3)))
for i in range(nr_images):
    # x_batch already holds normalised pixels in [0, 1] (see normalize_img), so plot
    # them directly on that range. Rescaling to 0..255 under vmax=1.0 would saturate
    # almost every pixel to white.
    axes[i, 0].imshow(np.reshape(x_batch[i], (28, 28)), vmin=0.0, vmax=1.0, cmap="gray")
    axes[i, 0].title.set_text(f"MNIST digit {y_batch[i].item()}")
    axes[i, 0].axis("off")
    axes[i, 1].imshow(a_batch_intgrad[i], cmap="seismic")
    axes[i, 1].title.set_text("Integrated Gradients")
    axes[i, 1].axis("off")
plt.tight_layout()
plt.show()

## 2. Quantitative evaluation using Quantus

We can evaluate our explanations on a variety of quantitative criteria, but as a motivating example we test the Max-Sensitivity (Yeh et al., 2019) of the explanations. This metric measures the maximum change in an explanation when the input is subjected to slight perturbations.

In [None]:
# Define the Max-Sensitivity metric used for evaluation: 10 perturbed samples of
# uniform noise, Frobenius norms for numerator/denominator, and normalised absolute
# attributions.
metric_init = quantus.MaxSensitivity(
    nr_samples=10,
    lower_bound=0.1,
    norm_numerator=quantus.fro_norm,
    norm_denominator=quantus.fro_norm,
    perturb_func=quantus.uniform_noise,
    similarity_func=quantus.difference,
    disable_warnings=True,
    normalise=True,
    abs=True,
)

In [None]:
# Compute Max-Sensitivity scores in a one-liner by calling the metric instance.
# `explain_func` is what the metric uses to re-explain the perturbed inputs.
scores_intgrad_maxs = metric_init(
    model=model,
    x_batch=x_batch,
    y_batch=y_batch,
    a_batch=a_batch_intgrad,
    explain_func=quantus.explain,
    explain_func_kwargs={"method": "IntegratedGradients"},
)

In [None]:
# Keyword arguments for a fresh Max-Sensitivity instance, using the same settings
# as `metric_init`. This must be defined before use; otherwise the cell raises a NameError.
params_eval_maxs = {
    "nr_samples": 10,
    "lower_bound": 0.1,
    "norm_numerator": quantus.fro_norm,
    "norm_denominator": quantus.fro_norm,
    "perturb_func": quantus.uniform_noise,
    "similarity_func": quantus.difference,
    "disable_warnings": True,
    "normalise": True,
    "abs": True,
}

metrics = {"max-Sensitivity": quantus.MaxSensitivity(**params_eval_maxs)}

xai_methods = {"IntegratedGradients": a_batch_intgrad}

# Evaluate every (metric, attribution) pair, aggregating scores with the mean.
results = quantus.evaluate(metrics=metrics,
                           xai_methods=xai_methods,
                           model=model,
                           x_batch=x_batch,
                           y_batch=y_batch,
                           agg_func=np.mean,
                           explain_func=quantus.explain,
                           explain_func_kwargs={"method": "IntegratedGradients"})

df = pd.DataFrame(results)
df

In [None]:
# Calculate Selectivity (Montavon et al., 2018)
metric_init_select = quantus.Selectivity(
    perturb_func=quantus.baseline_replacement_by_patch,
    disable_warnings=True,
    normalise=True,
    abs=True,
    perturb_baseline="uniform",
    patch_size=4,
)

# Return Selectivity scores in a one-liner by calling the metric instance.
# Stored under its own name so the Max-Sensitivity scores (`scores_intgrad_maxs`)
# are not overwritten.
scores_intgrad_select = metric_init_select(
    model=model,
    x_batch=x_batch,
    y_batch=y_batch,
    a_batch=a_batch_intgrad,
    explain_func=quantus.explain,
    explain_func_kwargs={"method": "IntegratedGradients"},
)