# Quantus + NLP
[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/understandable-machine-intelligence-lab/Quantus/main?labpath=tutorials%2FTutorial_NLP_Demonstration.ipynb)


This tutorial demonstrates how to use the library to evaluate explanations of text classification models, with a focus on robustness.
For this purpose, we use a pre-trained DistilBERT model from [Hugging Face](https://huggingface.co/models) and the [GLUE/SST-2](https://huggingface.co/datasets/sst2) dataset.

Author: Artem Sereda

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eWK9ebfMUVRG4mrOAQvXdJ452SMLfffv?usp=sharing)

In [1]:
```python
import numpy as np
import pandas as pd
from datasets import load_dataset
import tensorflow as tf
import logging
from IPython.display import HTML
import random
import matplotx
import matplotlib.pyplot as plt
import gc
from quantus.helpers.plotting import plot_model_parameter_randomisation_experiment
import quantus.nlp as qn


# Use a dark plotting theme.
plt.style.use(matplotx.styles.dracula)
# Silence verbose absl/TensorFlow logging.
logging.getLogger("absl").setLevel(logging.WARNING)
# Make the sentence shuffle below reproducible.
random.seed(42)
# List the devices TensorFlow can see.
tf.config.list_physical_devices()
```

```
[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
 PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
```

## 1) Preliminaries

### 1.1 Load a pre-trained model and tokenizer from the [Hugging Face](https://huggingface.co/models) Hub

In [2]:
```python
model = qn.TFHuggingFaceTextClassifier.from_pretrained(
    "distilbert-base-uncased-finetuned-sst-2-english"
)
```



### 1.2 Load the test split of the [GLUE/SST-2](https://huggingface.co/datasets/sst2) dataset

In [3]:
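```python
BATCH_SIZE = 32
MINI_BATCH_SIZE = 4

# Take the first BATCH_SIZE test sentences and shuffle their order.
dataset = load_dataset("sst2")["test"]
x_batch = dataset["sentence"][:BATCH_SIZE]
random.shuffle(x_batch)

# A smaller subset for quick qualitative inspection.
mini_x_batch = x_batch[:MINI_BATCH_SIZE]
```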


Run an example inference and show the model's predictions.

In [4]:
```python
from typing import List

CLASS_NAMES = ["negative", "positive"]


def decode_labels(y_batch: np.ndarray, class_names: List[str]) -> List[str]:
    """Map integer labels to human-readable class names."""
    return [class_names[i] for i in y_batch]


# The predicted class is the index of the highest model score.
mini_y_batch = model.predict(mini_x_batch).argmax(axis=-1)

# Show the sentences alongside their predicted labels.
pd.DataFrame(
    {"sentence": mini_x_batch, "predicted label": decode_labels(mini_y_batch, CLASS_NAMES)}
)
```

|   | sentence | predicted label |
|---|---|---|
| 0 | this is junk food cinema at its greasiest . | negative |
| 1 | uneasy mishmash of styles and genres . | negative |
| 2 | it 's just incredibly dull . | negative |
| 3 | a feel-good picture in the best sense of the t... | positive |


### 1.3 Visualise the explanations

In [5]:
```python
# Caption each explanation with the model's predicted label.
labels = [f"Predicted label: {name}" for name in decode_labels(mini_y_batch, CLASS_NAMES)]
```

In [6]:
```python
# Compute Integrated Gradients explanations, normalise them to sum to 1,
# and render them as HTML.
a_batch_int_grad = qn.normalise_attributions(
    qn.explain(model, mini_x_batch, mini_y_batch, method="IntGrad", num_steps=10),
    qn.normalize_sum_to_1,
)
HTML(
    qn.visualise_explanations_as_html(
        a_batch_int_grad, labels=labels, ignore_special_tokens=True
    )
)
```
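Integrated Gradients attributes a prediction by averaging gradients along a straight path from a baseline to the input and multiplying that average by the difference between the input and the baseline. The sketch below illustrates the idea on a hypothetical toy function with an analytic gradient, not on DistilBERT (where the path runs through token embeddings). It uses a midpoint Riemann sum with `num_steps` points and assumes an all-zeros baseline. How `qn.explain` discretises the integral internally is not shown here.

```python
import numpy as np


def toy_model(x: np.ndarray, w: np.ndarray) -> float:
    """Hypothetical differentiable "model": a weighted sum of squared features."""
    return float(np.sum(w * x**2))


def toy_grad(x: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Analytic gradient of toy_model with respect to x."""
    return 2.0 * w * x


def integrated_gradients(x, baseline, w, num_steps=10):
    """Midpoint Riemann approximation of Integrated Gradients."""
    # Interpolation coefficients at the midpoints of num_steps equal intervals.
    alphas = (np.arange(num_steps) + 0.5) / num_steps
    # Gradients at each point on the straight path from baseline to x.
    grads = np.stack([toy_grad(baseline + a * (x - baseline), w) for a in alphas])
    # Average gradient, scaled by the input-minus-baseline difference.
    return (x - baseline) * grads.mean(axis=0)


w = np.array([0.5, -1.0, 2.0])
x = np.array([1.0, 2.0, -0.5])
baseline = np.zeros_like(x)  # assumed all-zeros baseline

attributions = integrated_gradients(x, baseline, w, num_steps=10)
print(attributions)
# Completeness check: attributions should sum to f(x) - f(baseline).
print(attributions.sum(), toy_model(x, w) - toy_model(baseline, w))
```

For this quadratic toy function, the integrand is linear in the interpolation coefficient, so the midpoint rule is exact. The attributions therefore sum to f(x) - f(baseline), which is the completeness property. For real networks, more steps give a closer approximation.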

In [7]:
```python
# Compute SHAP explanations, normalise them to sum to 1, and render them as HTML.
a_batch_shap = qn.normalise_attributions(
    qn.explain(
        model, mini_x_batch, mini_y_batch, method="SHAP", call_kwargs={"max_evals": 10}
    ),
    qn.normalize_sum_to_1,
)
HTML(qn.visualise_explanations_as_html(a_batch_shap, labels=labels))
```
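SHAP attributions approximate Shapley values, which average a token's marginal contribution over all coalitions of the other tokens. The minimal sketch below computes these values exactly for a hypothetical three-token scorer by enumerating every coalition. The masking-based `value_fn` and its scores are invented for illustration. Exact enumeration grows as 2^n in the number of tokens, which is why SHAP approximates the values under a budget of model evaluations. In the `shap` library, that budget is the `max_evals` argument; presumably `call_kwargs={"max_evals": 10}` above is forwarded to it.

```python
from itertools import combinations
from math import factorial

import numpy as np


def shapley_values(value_fn, n_tokens):
    """Exact Shapley values by enumerating every coalition of tokens."""
    phi = np.zeros(n_tokens)
    for i in range(n_tokens):
        others = [j for j in range(n_tokens) if j != i]
        for size in range(n_tokens):
            for subset in combinations(others, size):
                # Standard Shapley weight: |S|! (n - |S| - 1)! / n!
                weight = factorial(size) * factorial(n_tokens - size - 1) / factorial(n_tokens)
                gain = value_fn(set(subset) | {i}) - value_fn(set(subset))
                phi[i] += weight * gain
    return phi


# Hypothetical "model": a positive-sentiment score computed from the unmasked tokens.
tokens = ["a", "feel-good", "picture"]
token_scores = {"a": 0.0, "feel-good": 0.8, "picture": 0.1}


def value_fn(kept):
    """Score of the sentence when only the token indices in `kept` are unmasked."""
    score = sum(token_scores[tokens[j]] for j in kept)
    # Interaction bonus when "feel-good" and "picture" appear together.
    if {1, 2} <= kept:
        score += 0.2
    return score


phi = shapley_values(value_fn, len(tokens))
print(dict(zip(tokens, phi.round(3))))
```

The values satisfy the efficiency property: they sum to the score of the full sentence minus the score of the fully masked one. The interaction bonus is split equally between the two interacting tokens.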

## 2) Quantitative analysis using Quantus

### 2.1 Evaluate the explanations

In [None]:
```python
# Embedding of the [UNK] token, used below as an alternative Integrated Gradients baseline.
# fmt: off
unk_token_embedding = model.embedding_lookup([model.tokenizer.tokenizer.unk_token_id])[0, 0]
# fmt: on

metrics = {
    # By default, perturbations are applied to the plain-text inputs.
    "Average Sensitivity": qn.AvgSensitivity(nr_samples=10, disable_warnings=True),
    # Alternatively, set perturbation_type to perturb in latent space
    # with a numerical perturbation function.
    "Max Sensitivity": qn.MaxSensitivity(
        nr_samples=10,
        # Note: it is up to the user to pair a compatible perturbation type and perturbation function.
        perturbation_type=qn.PerturbationType.latent_space,
        perturb_func=qn.gaussian_noise,
        disable_warnings=True,
    ),
    # By default, explanation scores are normalised to sum to 1; normalise=False disables this.
    # abs=True evaluates the absolute values of the explanation scores.
    "Local Lipschitz Estimate": qn.LocalLipschitzEstimate(
        nr_samples=10, normalise=False, abs=True, disable_warnings=True
    ),
    "Relative Input Stability": qn.RelativeInputStability(
        nr_samples=10,
        disable_warnings=True,
    ),
    "Relative Output Stability": qn.RelativeOutputStability(
        nr_samples=10,
        disable_warnings=True,
    ),
    "Relative Representation Stability": qn.RelativeRepresentationStability(
        nr_samples=10,
        disable_warnings=True,
    ),
    "Model Parameter Randomisation": qn.ModelParameterRandomisation(
        seed=42,
        disable_warnings=True,
    ),
    "Random Logit": qn.RandomLogit(num_classes=2, seed=42, disable_warnings=True),
    "Token Flipping": qn.TokenFlipping(disable_warnings=True, abs=True),
}

# By default, qn.explain is used to generate explanations.
call_kwargs = {
    # Use Gradient x Input as the default explanation method for all metrics.
    "explain_func_kwargs": {"method": "GradXInput"},
    # Evaluate Relative Input Stability for Integrated Gradients with two baselines:
    # the default one and the [UNK] token embedding.
    "Relative Input Stability": [
        {"explain_func_kwargs": {"method": "IntGrad"}},
        {
            "explain_func_kwargs": {
                "method": "IntGrad",
                "baseline_fn": lambda x: unk_token_embedding,
            }
        },
    ],
}

# Note that no y_batch is required, and explain_func_kwargs are passed only
# to the metrics' __call__ methods.
result = qn.evaluate(metrics, model, x_batch, call_kwargs=call_kwargs)
# Run garbage collection to release memory after the evaluation.
gc.collect()
```

```
The settings for perturbing input e.g., 'perturb_func' didn't cause change in input. Reconsider the parameter settings.
```
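Max Sensitivity and Average Sensitivity both perturb the input several times (`nr_samples`), recompute the explanation, and measure how much it changes. They differ only in whether the worst case or the mean is reported. The sketch below assumes Gaussian noise on a numerical input, as in the latent-space `MaxSensitivity` configuration above, and a hypothetical Gradient x Input explainer for a linear model. The relative-norm formulation and `noise_std=0.1` are assumptions, not Quantus' exact defaults.

```python
import numpy as np


def sensitivity_scores(explain_fn, x, nr_samples=10, noise_std=0.1, seed=42):
    """Max and mean relative change of an explanation under Gaussian input noise."""
    rng = np.random.default_rng(seed)
    a = explain_fn(x)
    scores = []
    for _ in range(nr_samples):
        # Perturb the numerical input (e.g. an embedding vector) with Gaussian noise.
        x_perturbed = x + rng.normal(0.0, noise_std, size=x.shape)
        a_perturbed = explain_fn(x_perturbed)
        # Norm of the explanation change, relative to the norm of the original explanation.
        scores.append(np.linalg.norm(a - a_perturbed) / np.linalg.norm(a))
    scores = np.asarray(scores)
    return scores.max(), scores.mean()


# Hypothetical Gradient x Input explainer for the linear model f(x) = w . x.
w = np.array([0.3, -1.2, 0.7, 2.0])


def grad_x_input(x):
    return w * x


x = np.array([1.0, 0.5, -2.0, 0.1])
max_sens, avg_sens = sensitivity_scores(grad_x_input, x)
print(f"max sensitivity={max_sens:.3f}, average sensitivity={avg_sens:.3f}")
```

Lower values indicate a more robust explanation. An explainer that ignores its input scores exactly zero.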


### 2.2 Results visualisation

In [None]:
```python
# Plot the token flipping results, using the model's predicted classes as targets.
qn.plot_token_flipping_experiment(
    result["Token Flipping"], model.predict(x_batch).argmax(axis=-1)
)
```
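Token flipping probes faithfulness. Tokens are replaced one at a time, in order of decreasing attribution, while the model's score is tracked. If the explanation ranks tokens well, the score should drop quickly. The minimal sketch below assumes replacement with `[UNK]` and a hypothetical additive negative-sentiment scorer. The exact replacement token and ordering used by `qn.TokenFlipping` are not confirmed here. The notebook passes `abs=True`, which suggests ranking by absolute attribution; the sketch's attributions are non-negative, so both orderings coincide.

```python
import numpy as np


def token_flipping_curve(score_fn, tokens, attributions, unk="[UNK]"):
    """Replace tokens with `unk` in order of decreasing attribution, tracking the score."""
    order = np.argsort(-np.asarray(attributions))  # most relevant token first
    current = list(tokens)
    curve = [score_fn(current)]
    for idx in order:
        current[idx] = unk
        curve.append(score_fn(current))
    return np.asarray(curve)


# Hypothetical additive negative-sentiment scorer; [UNK] contributes nothing.
token_scores = {"it": 0.0, "'s": 0.0, "just": 0.1, "incredibly": 0.3, "dull": 0.9}


def score_fn(tokens):
    return sum(token_scores.get(t, 0.0) for t in tokens)


tokens = ["it", "'s", "just", "incredibly", "dull"]
attributions = [token_scores[t] for t in tokens]  # a "perfect" explanation for this scorer
curve = token_flipping_curve(score_fn, tokens, attributions)
print(curve)
```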

In [None]:
```python
# Compare Relative Input Stability of Integrated Gradients for the two baselines.
ris = result["Relative Input Stability"]
pd.DataFrame(ris).T.plot().legend(["[0..0] as baseline", "[UNK] as baseline"])
plt.yscale("log")
plt.title("Relative Input Stability")
plt.xticks([])
```
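Relative Input Stability divides the relative change of an explanation by the relative change of the input that caused it, and keeps the worst case over perturbations: `RIS = max over x' of ||(e_x - e_x') / e_x|| / max(||(x - x') / x||, eps_min)`, where the divisions are element-wise. Relative Output Stability and Relative Representation Stability keep the same numerator but measure the change in the model's outputs or internal representations instead. The sketch below implements the RIS ratio with the L2 norm and small `eps` guards, both of which are assumptions, and applies it to two hypothetical explainers.

```python
import numpy as np


def safe_divide(num, den, eps=1e-10):
    """Element-wise num / den, with exact zeros in den replaced by eps."""
    return num / np.where(den == 0, eps, den)


def relative_input_stability(x, x_perturbed, e_x, e_perturbed, eps_min=1e-6):
    """Largest ratio of relative explanation change to relative input change.

    x_perturbed and e_perturbed hold one perturbed sample (and its explanation) per row.
    """
    ratios = []
    for xp, ep in zip(x_perturbed, e_perturbed):
        explanation_change = np.linalg.norm(safe_divide(e_x - ep, e_x))
        input_change = np.linalg.norm(safe_divide(x - xp, x))
        ratios.append(explanation_change / max(input_change, eps_min))
    return max(ratios)


def proportional_explainer(v):
    """Hypothetical explainer whose attributions scale with the input."""
    return 3.0 * v


def constant_explainer(v):
    """Hypothetical explainer that ignores the input entirely."""
    return np.ones_like(v)


rng = np.random.default_rng(0)
x = np.array([1.0, 2.0, -0.5, 0.8])
x_perturbed = x + rng.normal(0.0, 0.05, size=(10, x.size))

for explainer in (proportional_explainer, constant_explainer):
    ris_score = relative_input_stability(
        x, x_perturbed, explainer(x), explainer(x_perturbed)
    )
    print(f"{explainer.__name__}: RIS = {ris_score:.3f}")
```

An explainer whose attributions scale with the input scores 1, because its relative change equals the input's. A constant explainer scores 0. Larger values mean less stable explanations, which is the quantity compared across the two baselines in the plot above.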

In [None]:
```python
mpr = result["Model Parameter Randomisation"]
plot_model_parameter_randomisation_experiment(mpr)
```
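Model Parameter Randomisation checks whether explanations depend on what the model has learned. The model's parameters are randomised (in Quantus, typically layer by layer), and the resulting explanations are compared with the original ones. Low similarity is the desired outcome. The sketch below reduces this to a hypothetical single-layer linear model, uses Gradient x Input as the explainer, and assumes Spearman rank correlation as the similarity measure. It is not the Quantus implementation.

```python
import numpy as np


def spearman(a, b):
    """Spearman rank correlation (assumes no ties): Pearson correlation of the ranks."""
    rank_a = np.argsort(np.argsort(a))
    rank_b = np.argsort(np.argsort(b))
    return float(np.corrcoef(rank_a, rank_b)[0, 1])


rng = np.random.default_rng(42)
x = rng.normal(size=8)  # hypothetical input features
w = rng.normal(size=8)  # weights of a hypothetical linear model f(x) = w . x


def linear_grad_x_input(weights, inputs):
    # For a linear model, the gradient is the weight vector, so Gradient x Input = w * x.
    return weights * inputs


e_original = linear_grad_x_input(w, x)
e_randomised = linear_grad_x_input(rng.normal(size=8), x)  # same input, re-initialised weights

print("original vs. itself:", spearman(e_original, e_original))
print("original vs. randomised:", spearman(e_original, e_randomised))
```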

In [None]:
```python
# Compare the remaining metrics; the excluded ones are plotted separately above.
excluded = ("Model Parameter Randomisation", "Relative Input Stability", "Token Flipping")
pd.DataFrame.from_dict({k: v for k, v in result.items() if k not in excluded}).plot()
plt.yscale("log")
plt.title("Quantus NLP metrics comparison")
```
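One of the metrics compared in this final plot, the Local Lipschitz Estimate, measures the largest change in the explanation per unit change in the input over a small neighbourhood. The minimal sketch below assumes Gaussian neighbours and L2 norms, and uses hypothetical explainers. With `normalise=False, abs=True` as configured above, Quantus would apply the metric to unnormalised, absolute-valued attributions; this sketch omits that step.

```python
import numpy as np


def local_lipschitz_estimate(explain_fn, x, nr_samples=10, noise_std=0.1, seed=42):
    """Largest ratio ||e(x) - e(x')|| / ||x - x'|| over Gaussian-perturbed neighbours x'."""
    rng = np.random.default_rng(seed)
    e_x = explain_fn(x)
    best = 0.0
    for _ in range(nr_samples):
        x_perturbed = x + rng.normal(0.0, noise_std, size=x.shape)
        ratio = np.linalg.norm(e_x - explain_fn(x_perturbed)) / np.linalg.norm(x - x_perturbed)
        best = max(best, ratio)
    return best


x = np.array([0.5, -1.0, 2.0])
# Hypothetical linear explainer: its Lipschitz constant is exactly 2.
print(local_lipschitz_estimate(lambda v: 2.0 * v, x))
# Hypothetical non-linear explainer: the estimate depends on x and the neighbourhood size.
print(local_lipschitz_estimate(lambda v: v**3, x))
```

The linear explainer `2 * v` has a Lipschitz constant of exactly 2, which the estimate recovers. For a non-linear explainer, the estimate depends on the input and on the neighbourhood size.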