# Quantus + NLP
[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/understandable-machine-intelligence-lab/Quantus/main?labpath=tutorials%2FTutorial_NLP_Demonstration.ipynb)


This tutorial demonstrates how to use the library to evaluate explanations of text classification models.
For this purpose, we use a pre-trained `Distilbert` model from the [Hugging Face](https://huggingface.co/models) hub and the `GLUE/SST2` dataset (available [here](https://huggingface.co/datasets/sst2)).

Author: Artem Sereda

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eWK9ebfMUVRG4mrOAQvXdJ452SMLfffv?usp=sharing)

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
from datasets import load_dataset
import tensorflow as tf
import logging
import random
# import matplotx
import matplotlib.pyplot as plt
import gc
from quantus.helpers.plotting import plot_model_parameter_randomisation_experiment
from quantus.nlp.helpers.utils import map_explanations, get_logits_for_labels
import quantus.nlp as qn


# plt.style.use(matplotx.styles.dracula)
logging.getLogger("absl").setLevel(logging.WARNING)
random.seed(42)
tf.config.list_physical_devices()

## 1) Preliminaries

### 1.1 Load a pre-trained model and tokenizer from the [Hugging Face](https://huggingface.co/models) hub

In [None]:
model = qn.TensorFlowHuggingFaceTextClassifier.from_pretrained(
    "distilbert-base-uncased-finetuned-sst-2-english"
)

### 1.2 Load the test split of the [GLUE/SST2](https://huggingface.co/datasets/sst2) dataset

In [None]:
BATCH_SIZE = 32

dataset = load_dataset("sst2")["test"]
x_batch = dataset["sentence"][:BATCH_SIZE]
random.shuffle(x_batch)

Run an example inference and display the model's predictions.

In [None]:
def decode_labels(y_batch: np.ndarray):
    """A helper function to map integer labels to human-readable class names."""
    return [model.internal_model.config.id2label[i] for i in y_batch]


y_batch = model.predict(x_batch).argmax(axis=-1)

# Show the x, y data.
pd.DataFrame([x_batch[:10], decode_labels(y_batch[:10])]).T

### 1.3 Visualise the explanations

In [None]:
labels = list(map(lambda i: "Predicted label: " + i, decode_labels(y_batch)))

In [None]:
# Visualise Integrated Gradients explanations.
a_batch_int_grad = map_explanations(
    qn.explain(model, x_batch[:4], y_batch[:4], method="IntGrad"),
    qn.normalize_sum_to_1,
)


qn.visualise_explanations_as_pyplot(a_batch_int_grad, labels=labels)

In [None]:
# Visualise SHAP explanations.
a_batch_shap = map_explanations(
    qn.explain(model, x_batch[:4], y_batch[:4], method="SHAP", call_kwargs={"max_evals": 10}),
    qn.normalize_sum_to_1,
)
qn.visualise_explanations_as_pyplot(a_batch_shap, labels=labels)

## 2) Quantitative analysis using Quantus.
To see all available metrics and their categories, run `quantus.nlp.available_metrics()` (imported here as `qn`).

In [None]:
qn.available_metrics()

In [None]:
# Embedding of the tokenizer's unknown token; used below as an alternative Integrated Gradients baseline.
# fmt: off
unk_token_embedding = model.embedding_lookup([model.internal_tokenizer.unk_token_id])[0, 0]  # noqa
# fmt: on

metrics = {
    # By default, perturbation is applied to plain-text inputs.
    "Average Sensitivity": qn.AvgSensitivity(nr_samples=10, disable_warnings=True),
    "Max Sensitivity": qn.MaxSensitivity(
        nr_samples=10,
        # Perturbation type is inferred from perturb_func signature.
        perturb_func=qn.gaussian_noise,
        disable_warnings=True,
    ),
    "Relative Input Stability": qn.RelativeInputStability(
        nr_samples=10,
        disable_warnings=True,
    ),
    "Relative Output Stability": qn.RelativeOutputStability(
        nr_samples=10,
        disable_warnings=True,
    ),
    "Relative Representation Stability": qn.RelativeRepresentationStability(
        nr_samples=10,
        disable_warnings=True,
    ),
    "Model Parameter Randomisation": qn.ModelParameterRandomisation(
        seed=42,
        disable_warnings=True,
    ),
    # "Random Logit": qn.RandomLogit(num_classes=2, seed=42, disable_warnings=True),
    "Token Flipping": qn.TokenFlipping(disable_warnings=True, abs=True),
}

# By default, qn.explain is used to generate explanations.
call_kwargs = {
    # We use GradXInput as the default explanation method for all metrics.
    "explain_func_kwargs": {"method": "GradXInput"},
    # We evaluate Relative Input Stability for IntGrad with two baselines: the default one and the unknown-token embedding.
    "Relative Input Stability": [
        {
            "explain_func_kwargs": {
                "method": "IntGrad",
            }
        },
        {
            "explain_func_kwargs": {
                "method": "IntGrad",
                "baseline_fn": tf.function(lambda x: unk_token_embedding),
            }
        },
    ],
}

# Note that no y_batch is required, and explain_func_kwargs are passed only to the metrics' __call__ method.
result = qn.evaluate(metrics, model, x_batch, call_kwargs=call_kwargs, run_gc=False)
gc.collect()
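The `call_kwargs` above use Gradient × Input by default and Integrated Gradients with two different baselines for Relative Input Stability. The NumPy sketch below illustrates both methods on a hypothetical toy function with an analytic gradient; it is not the DistilBERT model and not the library's implementation. `unk_like_baseline` is a made-up stand-in for `unk_token_embedding`. It shows that Integrated Gradients attributions sum to f(x) - f(baseline), which is why changing the baseline changes the attributions.

```python
import numpy as np

# Hypothetical toy "model" f(x) = sum(w * x**2), used only for illustration.
w = np.array([0.5, -1.0, 2.0])


def f(x: np.ndarray) -> float:
    return float(np.sum(w * x**2))


def grad_f(x: np.ndarray) -> np.ndarray:
    # Analytic gradient of f with respect to x.
    return 2.0 * w * x


def grad_x_input(x: np.ndarray) -> np.ndarray:
    """Gradient x Input attribution: element-wise product of gradient and input."""
    return grad_f(x) * x


def integrated_gradients(x: np.ndarray, baseline: np.ndarray, steps: int = 50) -> np.ndarray:
    """Midpoint Riemann approximation of Integrated Gradients along the straight path
    from the baseline to x."""
    alphas = (np.arange(steps) + 0.5) / steps
    grads = np.stack([grad_f(baseline + a * (x - baseline)) for a in alphas])
    return (x - baseline) * grads.mean(axis=0)


x = np.array([1.0, 2.0, -1.0])
zero_baseline = np.zeros_like(x)
unk_like_baseline = np.array([0.1, 0.1, 0.1])  # hypothetical stand-in for the unknown-token embedding

print("GradXInput:", grad_x_input(x))
for name, b in [("zeros", zero_baseline), ("unk-like", unk_like_baseline)]:
    attr = integrated_gradients(x, b)
    # Completeness: the attributions sum to f(x) - f(baseline).
    print(name, attr.round(3), round(float(attr.sum()), 3), round(f(x) - f(b), 3))
```

Because the toy gradient is linear along the path, the midpoint rule is exact here; for a real network the number of steps controls the approximation error.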

### 2.1 Results visualisation

For the pruning task (the default), tokens are removed in order of increasing relevance score.
We expect the Mean Squared Error to grow as more important tokens are removed,
which indicates that higher scores are assigned to important features.

In [None]:
qn.plot_token_flipping_experiment(
    result["Token Flipping"],
    get_logits_for_labels(model.predict(x_batch), y_batch),
    task="pruning"
)
plt.show()
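The pruning procedure described above can be sketched on a hypothetical additive model, where the logit is the sum of per-token contributions and removing a token sets its contribution to zero. These are assumptions for illustration, not how the library perturbs DistilBERT inputs. With a single logit, the squared difference stands in for the Mean Squared Error. Ordering by absolute relevance mirrors `abs=True` in the metric definition above.

```python
import numpy as np

# Hypothetical additive "model": the logit is the sum of per-token contributions,
# and removing a token sets its contribution to 0.
contributions = np.array([0.05, 2.0, 0.3, 1.5, 0.1, 0.8])


def logit(mask: np.ndarray) -> float:
    return float(np.sum(contributions * mask))


def pruning_curve(relevance: np.ndarray) -> np.ndarray:
    """Remove tokens in order of increasing |relevance| and record the squared
    error between the original logit and the logit after each removal."""
    mask = np.ones_like(contributions)
    original = logit(mask)
    errors = []
    for idx in np.argsort(np.abs(relevance)):  # least relevant token first
        mask[idx] = 0.0
        errors.append((original - logit(mask)) ** 2)
    return np.array(errors)


# For an additive model, the contributions themselves are a faithful explanation.
faithful = pruning_curve(relevance=contributions)
# A poor explanation that ranks the tokens in the opposite order.
reversed_rank = pruning_curve(relevance=1.0 / np.abs(contributions))

print(faithful.round(4))
print(reversed_rank.round(4))
```

The faithful curve stays low while unimportant tokens are removed and rises only at the end, whereas the reversed ranking causes a large error immediately; both curves meet once every token is gone.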

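Model Parameter Randomisation shows how strongly explanations depend on the layers' learned weights, by correlating the explanations of the original model with the explanations obtained after randomising the layers' weights:
- Values close to 0 mean that randomising the layers' weights changes the explanation completely.
- Values close to 1 mean that randomising the layers' weights does not change the explanation at all.

Since a faithful explanation should depend on the learned parameters, lower values are preferable.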

In [None]:
mpr = result["Model Parameter Randomisation"]
plot_model_parameter_randomisation_experiment(mpr)
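The idea behind the plot above can be sketched with a hypothetical two-layer linear model and Gradient × Input explanations. The sketch uses a cascading, top-down randomisation (top layer first) and Spearman rank correlation as the similarity; both are assumptions made for illustration and may differ from the library's defaults.

```python
import numpy as np

rng = np.random.default_rng(42)

# Hypothetical two-layer linear "model": logits = W2 @ (W1 @ x).
W1 = rng.normal(size=(8, 16))
W2 = rng.normal(size=(2, 8))
x = rng.normal(size=16)
target = 0  # class whose logit is explained


def grad_x_input(W1: np.ndarray, W2: np.ndarray, x: np.ndarray, c: int) -> np.ndarray:
    # For a linear model, d logit_c / dx = W1.T @ W2[c].
    return x * (W1.T @ W2[c])


def spearman(a: np.ndarray, b: np.ndarray) -> float:
    # Spearman rank correlation = Pearson correlation of the ranks (assumes no ties).
    ra, rb = np.argsort(np.argsort(a)), np.argsort(np.argsort(b))
    return float(np.corrcoef(ra, rb)[0, 1])


original = grad_x_input(W1, W2, x, target)

# Cascading randomisation: replace the top layer first, then also the layer below it.
W2_rand = rng.normal(size=W2.shape)
W1_rand = rng.normal(size=W1.shape)
scores = {
    "W2 randomised": spearman(original, grad_x_input(W1, W2_rand, x, target)),
    "W2 + W1 randomised": spearman(original, grad_x_input(W1_rand, W2_rand, x, target)),
}
print(scores)
```

A score near 1 at some layer would mean the explanation barely reacts to destroying that layer's weights, which is the failure mode this metric is designed to expose.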

Robustness metrics show how sensitive explanations are to slight perturbations of the input.
There is no general rule for interpreting this category. Typically:
- Higher values could mean that the explanations take more of the input into account than the model does.
- Lower values could mean that the explanations are highly biased against certain features of the input space (tokens).

In [None]:
pd.DataFrame.from_dict(
    {
        k: v
        for k, v in result.items()
        if k
        in (
            "Average Sensitivity",
            "Max Sensitivity",
        )
    }
).boxplot()
# plt.yscale("log")
plt.title("Sensitivity metrics comparison")
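The two sensitivity metrics compared above can be sketched as the average and the maximum distance between the explanation of an input and the explanations of noisy copies of it. The sketch perturbs a continuous input with Gaussian noise, similar in spirit to `perturb_func=qn.gaussian_noise`, and uses a hypothetical toy explanation function. The library may additionally normalise these distances; this sketch omits that.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical explanation function: Gradient x Input of f(x) = sum(w * x**2).
w = np.array([0.5, -1.0, 2.0, 0.1])


def explain(x: np.ndarray) -> np.ndarray:
    return 2.0 * w * x * x


def sensitivities(x: np.ndarray, nr_samples: int = 10, sigma: float = 0.05) -> np.ndarray:
    """Euclidean distance between the explanation of x and the explanations of
    nr_samples Gaussian-perturbed copies of x."""
    e_x = explain(x)
    dists = []
    for _ in range(nr_samples):
        x_perturbed = x + rng.normal(scale=sigma, size=x.shape)
        dists.append(np.linalg.norm(e_x - explain(x_perturbed)))
    return np.array(dists)


x = np.array([1.0, -2.0, 0.5, 3.0])
d = sensitivities(x)
print("average sensitivity:", round(float(d.mean()), 4))
print("max sensitivity:", round(float(d.max()), 4))
```

Max Sensitivity is by construction at least as large as Average Sensitivity on the same samples, which is worth keeping in mind when reading the box plot.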

In [None]:
rs = {
    "RIS_0": result["Relative Input Stability"][0],
    "RIS_{unk}": result["Relative Input Stability"][1],
    "ROS": result["Relative Output Stability"],
    "RRS": result["Relative Representation Stability"]
}

pd.DataFrame.from_dict(rs).boxplot()
plt.yscale("log")
plt.title("Relative Stability metrics comparison")
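The relative stability metrics plotted above divide the relative change of the explanation by a change in the input (RIS), in the model output (ROS), or in an internal representation (RRS), and keep the worst case over perturbations. The sketch below follows that structure under stated assumptions: a hypothetical toy model whose "representation" is `x**2`, Gaussian perturbations, an L2 norm, and a small `EPS` guarding divisions. It is not the library's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
EPS = 1e-6  # guards against division by zero

# Hypothetical toy model: representation L(x) = x**2, logit h(x) = w @ L(x).
w = np.array([0.5, -1.0, 2.0, 0.1])


def representation(x: np.ndarray) -> np.ndarray:
    return x**2


def model_logit(x: np.ndarray) -> np.ndarray:
    return np.array([w @ representation(x)])


def explain(x: np.ndarray) -> np.ndarray:
    # Gradient x Input of the logit with respect to x.
    return 2.0 * w * x * x


def rel_change(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # Element-wise relative change of b with respect to a.
    return (a - b) / (np.abs(a) + EPS)


def relative_stability(x: np.ndarray, nr_samples: int = 10, sigma: float = 0.05) -> dict:
    e_x, h_x, l_x = explain(x), model_logit(x), representation(x)
    ris, ros, rrs = [], [], []
    for _ in range(nr_samples):
        x_p = x + rng.normal(scale=sigma, size=x.shape)
        num = np.linalg.norm(rel_change(e_x, explain(x_p)))
        ris.append(num / max(np.linalg.norm(rel_change(x, x_p)), EPS))
        ros.append(num / max(np.linalg.norm(h_x - model_logit(x_p)), EPS))
        rrs.append(num / max(np.linalg.norm(rel_change(l_x, representation(x_p))), EPS))
    # Keep the worst case (largest ratio) over all perturbations.
    return {"RIS": max(ris), "ROS": max(ros), "RRS": max(rrs)}


x = np.array([1.0, -2.0, 0.5, 3.0])
print(relative_stability(x))
```

Because these quantities are ratios, they can span several orders of magnitude, which is one reason a logarithmic axis is convenient for comparing them.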