# Quantus + NLP

This tutorial demonstrates how to use the library to evaluate the robustness of explanations for text classification models.
For this purpose, we use a pre-trained `DistilBERT` model from [Hugging Face](https://huggingface.co/models) and the `GLUE/SST2` dataset, available [here](https://huggingface.co/datasets/sst2).

This is not yet a working example and is meant for demonstration purposes only.
For this demo, we use an as-yet unreleased version of Quantus.

Author: Artem Sereda

In [None]:
# Use an unreleased version of Quantus.
!pip install 'quantus @ git+https://github.com/aaarrti/Quantus.git@nlp-domain' --no-deps
!pip install transformers datasets nlpaug tf_explain

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting quantus@ git+https://github.com/aaarrti/Quantus.git@nlp-domain
  Cloning https://github.com/aaarrti/Quantus.git (to revision nlp-domain) to /tmp/pip-install-9wfq6g9y/quantus_790c297bd50b44deba2935555becf6d5
  Running command git clone -q https://github.com/aaarrti/Quantus.git /tmp/pip-install-9wfq6g9y/quantus_790c297bd50b44deba2935555becf6d5
  Running command git checkout -b nlp-domain --track origin/nlp-domain
  Switched to a new branch 'nlp-domain'
  Branch 'nlp-domain' set up to track remote branch 'nlp-domain' from 'origin'.
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Using cached transformers-4.23.1-py3-none-any.whl (5.3 MB)
Collecting datasets
  Using cached datasets-2.6.1-py3-none-any.whl (441 kB)
Collecting nlpaug
  Using cached nlpaug-1.1.11-py3-none-any.whl (410 kB)
Collecting tf_explain


In [None]:
import numpy as np
import pandas as pd
from datasets import load_dataset
import tensorflow as tf
from functools import partial
import logging
from typing import NamedTuple, Any
from transformers import AutoTokenizer, TFDistilBertForSequenceClassification
from IPython.display import HTML
import quantus.nlp as qn

# Suppress debug logs.
logging.getLogger('absl').setLevel(logging.WARNING)
tf.config.list_physical_devices()

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
 PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

## 1) Preliminaries

### 1.1 Load a pre-trained model and tokenizer from the [Hugging Face](https://huggingface.co/models) hub

In [None]:
MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = TFDistilBertForSequenceClassification.from_pretrained(MODEL_NAME)


All model checkpoint layers were used when initializing TFDistilBertForSequenceClassification.

All the layers of TFDistilBertForSequenceClassification were initialized from the model checkpoint at distilbert-base-uncased-finetuned-sst-2-english.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertForSequenceClassification for predictions without further training.


### 1.2 Load the test split of the [GLUE/SST2](https://huggingface.co/datasets/sst2) dataset

In [None]:
BATCH_SIZE = 8
dataset = load_dataset("sst2")['test']
x_batch = dataset['sentence'][:BATCH_SIZE]

Downloading and preparing dataset sst2/default (download: 7.09 MiB, generated: 4.78 MiB, post-processed: Unknown size, total: 11.88 MiB) to /root/.cache/huggingface/datasets/sst2/default/2.0.0/9896208a8d85db057ac50c72282bcb8fe755accc671a57dd8059d4e130961ed5...


Dataset sst2 downloaded and prepared to /root/.cache/huggingface/datasets/sst2/default/2.0.0/9896208a8d85db057ac50c72282bcb8fe755accc671a57dd8059d4e130961ed5. Subsequent calls will reuse this data.



Run an example inference and display the model's predictions.

In [None]:
CLASS_NAMES = ['negative', 'positive']

def decode_labels(y_batch: np.ndarray, class_names: list = CLASS_NAMES) -> list:
    """A helper function to map integer labels to human-readable class names."""
    return [class_names[i] for i in y_batch]

# Run tokenizer.
tokens = tokenizer(x_batch, padding='longest', return_tensors='tf')
logits = model(**tokens).logits
y_batch = tf.argmax(tf.nn.softmax(logits), axis=1).numpy()

# Show the x, y data.
pd.DataFrame([x_batch, decode_labels(y_batch)]).T

|   | 0 | 1 |
|---|---|---|
| 0 | uneasy mishmash of styles and genres . | negative |
| 1 | this film 's relationship to actual tension is... | negative |
| 2 | by the end of no such thing the audience , lik... | positive |
| 3 | director rob marshall went out gunning to make... | positive |
| 4 | lathan and diggs have considerable personal ch... | positive |
| 5 | a well-made and often lovely depiction of the ... | positive |
| 6 | none of this violates the letter of behan 's b... | negative |
| 7 | although it bangs a very cliched drum at times... | positive |
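
Softmax is strictly increasing within each row, so applying `tf.nn.softmax` before `tf.argmax` does not change the predicted class: the argmax of the raw logits gives the same labels. The NumPy sketch below checks this on a few made-up logits (illustrative values, not outputs of the DistilBERT model); `softmax` here is a local helper, not a library function.

```python
import numpy as np

def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax; subtracting the row maximum keeps exp() numerically stable."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

class_names = ['negative', 'positive']
# Made-up logits for three sentences: [negative score, positive score].
logits = np.array([[2.0, -1.0], [-0.5, 3.0], [0.1, 0.2]])

preds_from_probs = softmax(logits).argmax(axis=1)
preds_from_logits = logits.argmax(axis=1)
labels = [class_names[i] for i in preds_from_logits]
print(labels)  # ['negative', 'positive', 'positive']
```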


### 1.3 Helper functions: visualise explanations

Few XAI libraries support NLP models, so here we rely entirely on our own implementations of explanation methods. In this section, we write helper functions to visualise the explanations.

In [None]:
def create_div(explanation: qn.TokenSalience, predicted_label: str) -> str:
    """
    Create a div whose tokens are highlighted according to their relevance scores.

    Parameters
    ----------
    explanation: quantus.nlp.types.TokenSalience
        Tokens and their relevance scores.
    predicted_label: str
        The predicted class name.

    Returns
    -------
    div: str

    The highest saliency score is rendered red, RGB(255, 0, 0), and a score of
    zero is rendered white, RGB(255, 255, 255). Scores in between are linearly
    interpolated between white and red.
    """
    # Create a container, which inherits root styles.
    div_template = (
        """
        <div class="container">
            <p>
                Predicted: {{predicted_label}} <br>
                {{saliency_map}}
            </p>
        </div>
        """
    )
    
    # For each token, create a separate highlight span with different background color.
    token_span_template = (
        """
        <span class="highlight-container" style="background:{{color}};">
            <span class="highlight"> {{token}} </span>
        </span>
        """
    )
    tokens = explanation.tokens
    grads = explanation.salience
    body = ""
    max_grad = np.max(grads)
    # Avoid division by zero when every score is zero (all tokens stay white).
    if max_grad == 0:
        max_grad = 1.0

    for t, g in zip(tokens, grads):
        # Green and blue fall linearly from 255 (score 0) to 0 (max score): white -> red.
        intensity = int(round(255.0 - 255.0 * (g / max_grad)))
        token_span = (
            token_span_template
            .replace("{{color}}", f"rgb(255,{intensity},{intensity})")
            .replace("{{token}}", t)
        )
        body += token_span + " "

    return (
        div_template
        .replace("{{predicted_label}}", predicted_label)
        .replace("{{saliency_map}}", body)
    )


def create_textual_heatmap(explanations, predicted_labels) -> str:
    """
    Build an HTML string that visualises a textual heatmap; the output can be
    rendered in Jupyter.

    Parameters
    ----------
    explanations: a list of quantus.nlp.types.TokenSalience
        The explanations.
    predicted_labels: 1D np.ndarray
        The predicted labels.

    Returns
    -------
    heatmap: str
    """
    # Define top-level styles
    heatmap_template = (
        """
        <style>

            .container {
                line-height: 1.4;
                text-align: center;
                margin: 10px 10px 10px 10px;
                color: black;
                background: white;
            }

            p {
                font-size: 16px;
            }

            .highlight-container, .highlight {
                position: relative;
                border-radius: 10% 10% 10% 10%;
            }

            .highlight-container {
                display: inline-block;
            }

            .highlight-container:before {
                content: " ";
                display: block;
                height: 90%;
                width: 100%;
                margin-left: -3px;
                margin-right: -3px;
                position: absolute;
                top: -1px;
                left: -1px;
                padding: 10px 3px 3px 10px;
            }

        </style>
        
        {{body}}
        """
    )

    spans = ""
    # For each input sequence, create a separate div holding the whole sequence on one line.
    for explanation, label in zip(explanations, predicted_labels):
        name = CLASS_NAMES[label]
        div = create_div(explanation, name)
        spans += div
    return heatmap_template.replace("{{body}}", spans)
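
The colour scale in `create_div` can be isolated into a small, testable function. `relevance_to_rgb` below is a hypothetical helper, not part of Quantus; it assumes non-negative scores (which holds for the gradient norms and the absolute-valued Integrated Gradients used here) and falls back to white when every score is zero.

```python
import numpy as np

def relevance_to_rgb(scores) -> list:
    """Map non-negative relevance scores to CSS rgb() strings: 0 -> white, max -> red."""
    scores = np.asarray(scores, dtype=float)
    # Normalise by the maximum; use 1.0 if the array is empty or all zeros.
    max_score = scores.max() if scores.size and scores.max() > 0 else 1.0
    # Green and blue fall linearly from 255 to 0 as relevance rises; red stays at 255.
    green_blue = np.rint(255.0 * (1.0 - scores / max_score)).astype(int)
    return [f"rgb(255,{v},{v})" for v in green_blue]

print(relevance_to_rgb([0.0, 0.5, 1.0]))
# ['rgb(255,255,255)', 'rgb(255,128,128)', 'rgb(255,0,0)']
```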

### 1.4 Helper functions: generate explanations

We write functions that generate explanations with two baseline methods: Gradient Norm and Integrated Gradients.

In [None]:
def explain_gradient_norm(model, input, target, tokenizer):
    """
    Generate Gradient Norm explanation for a single sequence.
    
    Parameters
    ----------
    model: tf.keras.Model
        Model used to predict the label.
    input: str
        A single input sequence.
    target: int
        The predicted label.
    tokenizer: tokenizer
        A tokenizer object to encode inputs.

    Returns
    -------
    explanation: quantus.nlp.types.TokenSalience
        A named tuple of tokens and their corresponding relevance scores.
    """

    # Convert input to tokens.
    token_ids = tokenizer([input], return_tensors='tf')["input_ids"]
    
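    # Look up the raw word embeddings only. When the model receives inputs_embeds,
    # its embeddings layer adds position embeddings and LayerNorm itself, so calling
    # the full embeddings layer here would apply them twice.
    embeddings = tf.gather(model.distilbert.embeddings.weight, token_ids)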
    with tf.GradientTape() as tape:
        tape.watch(embeddings)
        logits = model(None, inputs_embeds=embeddings).logits
        logits_for_label = tf.gather(logits, axis=1, indices=target)

    # Compute gradients of logits with respect to embeddings.    
    grads = tape.gradient(logits_for_label, embeddings)
    
    # Compute L2 norm of gradients.
    grad_norm = tf.linalg.norm(grads, axis=-1)
    
    return qn.TokenSalience(
        tokenizer.convert_ids_to_tokens(token_ids[0]), grad_norm.numpy()[0]
    )


def explain_gradient_norm_batch(model: tf.keras.Model, inputs, targets, tokenizer):
    """A wrapper around explain_gradient_norm which allows calling it on batch"""
    return [
        explain_gradient_norm(model, x, y, tokenizer)
        for x, y in zip(inputs, targets)
    ]


def get_interpolated_inputs(baseline: np.ndarray, target: np.ndarray, num_steps=10) -> np.ndarray:
    """
    Get num_steps + 1 linearly interpolated inputs from baseline to target (both endpoints included).
    Reference: https://github.com/PAIR-code/lit/blob/main/lit_nlp/components/gradient_maps.py#L238

    Returns
    -------
    interpolated_inputs: <float32>[num_steps + 1, num_tokens, emb_size]
    """
    if num_steps <= 0:
        return np.array([])
    if num_steps == 1:
        return np.array([baseline, target])

    # Work in NumPy even if target is a tf.Tensor. <float32>[num_tokens, emb_size]
    delta = np.asarray(target) - np.asarray(baseline)

    # scales has shape [num_steps + 1, 1, 1]; scales[i] is the i-th value of
    # np.linspace(0, 1, num_steps + 1).
    scales = np.linspace(0, 1, num_steps + 1, dtype=np.float32)[:, np.newaxis, np.newaxis]
    shape = (num_steps + 1,) + delta.shape

    # <float32>[num_steps + 1, num_tokens, emb_size]
    deltas = scales * np.broadcast_to(delta, shape)
    interpolated_inputs = baseline + deltas

    return interpolated_inputs


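def estimate_integral(path_gradients: np.ndarray) -> np.ndarray:
    """
    Estimate the integral of path_gradients with the trapezoid rule, averaging
    over the interpolation steps (axis 0) and the embedding dimension (axis 2).
    Reference: https://github.com/PAIR-code/lit/blob/main/lit_nlp/components/gradient_maps.py#L257
    """

    # Average each pair of neighbouring points: <float32>[num_steps, num_tokens, emb_size].
    path_gradients = (path_gradients[:-1] + path_gradients[1:]) / 2

    # <float32>[num_tokens]
    return tf.reduce_mean(path_gradients, axis=(0, 2))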


def explain_int_grad(tokenizer, model: tf.keras.Model, input, target):
    """
    Generate an Integrated Gradients explanation for a single sequence.
    Adjusted code from: https://github.com/PAIR-code/lit/blob/main/lit_nlp/components/gradient_maps.py#L181

    Parameters
    ----------
    tokenizer: tokenizer
        A tokenizer object to encode inputs.
    model: tf.keras.Model
        Model used to predict the label.
    input: str
        A single input sequence.
    target: int
        The predicted label.

    Returns
    -------
    explanation: quantus.nlp.types.TokenSalience
        A named tuple of tokens and their corresponding relevance scores.
    """
    # Convert input to tokens.
    token_ids = tokenizer([input], return_tensors='tf')["input_ids"]
    
    # Look up the raw word embeddings only; the model adds position embeddings
    # and LayerNorm itself when it receives inputs_embeds.
    embeddings = tf.gather(model.distilbert.embeddings.weight, token_ids)[0]
    
    baseline = np.zeros_like(embeddings)
    
    # Generate interpolation from 0 to embeddings.
    interpolated_embeddings = get_interpolated_inputs(baseline, embeddings)
    interpolated_embeddings = tf.convert_to_tensor(interpolated_embeddings)

    with tf.GradientTape() as tape:
        tape.watch(interpolated_embeddings)
        logits = model(None, inputs_embeds=interpolated_embeddings).logits
        logits_for_label = tf.gather(logits, axis=1, indices=target)

    # Compute gradients of logits with respect to interpolations. 
    grads = tape.gradient(logits_for_label, interpolated_embeddings)
    
    # Integrate gradients.
    int_grad = tf.abs(estimate_integral(grads))
    
    return qn.TokenSalience(
        tokenizer.convert_ids_to_tokens(token_ids[0]),
        int_grad.numpy()
    )


def explain_int_grad_batch(model, inputs, targets, tokenizer):
    """A wrapper around explain_int_grad which allows calling it on batch"""
    return [
        explain_int_grad(
            tokenizer=tokenizer, model=model, input=x, target=y
        ) for x, y in zip(inputs, targets)
    ]


# Create functions which match the signature required by Quantus.
explain_gradient_norm_func = partial(explain_gradient_norm_batch, tokenizer=tokenizer)
explain_int_grad_func = partial(explain_int_grad_batch, tokenizer=tokenizer)
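
To make the integration step concrete, the NumPy sketch below applies the same recipe (straight-line path from a zero baseline, `num_steps + 1` points, trapezoid rule) to a toy quadratic "logit" whose gradient is known analytically. It follows the standard Integrated Gradients formulation, which multiplies the averaged path gradients by `(input - baseline)` and sums over the embedding axis; `estimate_integral` above instead averages the gradients over the embedding axis, and `explain_int_grad` takes their absolute value, without that factor. All names here are hypothetical. With the standard formulation, the per-token scores sum to `f(input) - f(baseline)` (completeness), and because the toy gradient is linear along the path, the trapezoid rule makes this exact up to rounding.

```python
import numpy as np

def toy_logit(e: np.ndarray, w: np.ndarray) -> float:
    """Toy stand-in for a class logit: f(e) = sum(w * e**2)."""
    return float(np.sum(w * e ** 2))

def toy_logit_grad(e: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Analytical gradient of toy_logit (broadcasts over a leading path axis)."""
    return 2.0 * w * e

def integrated_gradients(e, baseline, w, num_steps=10):
    """Per-token IG scores for embeddings e of shape [num_tokens, emb_dim]."""
    # num_steps + 1 points on the straight line from baseline to e.
    alphas = np.linspace(0.0, 1.0, num_steps + 1)[:, np.newaxis, np.newaxis]
    path = baseline + alphas * (e - baseline)  # [num_steps + 1, num_tokens, emb_dim]
    grads = toy_logit_grad(path, w)            # same shape as path
    # Trapezoid rule: average neighbouring points, then average over the steps.
    avg_grads = ((grads[:-1] + grads[1:]) / 2).mean(axis=0)  # [num_tokens, emb_dim]
    # Scale by (input - baseline) and sum over the embedding axis.
    return np.sum((e - baseline) * avg_grads, axis=-1)       # [num_tokens]

def gradient_norm(e, w):
    """Per-token L2 norm of the gradient, analogous to explain_gradient_norm."""
    return np.linalg.norm(toy_logit_grad(e, w), axis=-1)

rng = np.random.default_rng(0)
e = rng.normal(size=(4, 3))  # 4 "tokens" with 3-dimensional "embeddings"
w = rng.normal(size=(4, 3))
baseline = np.zeros_like(e)

ig_scores = integrated_gradients(e, baseline, w)
print(ig_scores.sum(), toy_logit(e, w) - toy_logit(baseline, w))  # equal up to rounding
print(gradient_norm(e, w))
```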

### 1.5 Visualise the explanations

In [None]:
# Visualise GradNorm.
a_batch_grad_norm = explain_gradient_norm_func(model, x_batch, y_batch)
html = create_textual_heatmap(a_batch_grad_norm[:3], y_batch[:3])
HTML(html)

In [None]:
# Visualise IntegratedGradients explanations.
a_batch_int_grad = explain_int_grad_func(model, x_batch, y_batch)
html = create_textual_heatmap(a_batch_int_grad[:3], y_batch[:3])
HTML(html)

## 2) Quantitative analysis using Quantus
For this example, we compute the [Sensitivity](https://arxiv.org/abs/1901.09392) metrics, in their average and maximum variants.

In [None]:
# Workaround: the library accesses some model attributes by hard-coded names.
class ModelTuple(NamedTuple):
    model: Any
    tokenizer: Any

# Also a workaround: expose the DistilBERT attributes under the names the library expects.
model_stub = ModelTuple(model, tokenizer)
model_stub.model.bert = model.distilbert
model_stub.model.bert.embeddings.word_embeddings = model.distilbert.embeddings.weight

Average Sensitivity captures the average change in explanations under small perturbations of the input (here, spelling changes via `qn.change_spelling`).
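
Average Sensitivity can be estimated by Monte Carlo sampling: perturb the input several times, recompute the explanation, and average how far it moves. The sketch below illustrates the idea under stated assumptions and is not the `qn.AvgSensitivity` implementation: it uses a hypothetical `explain` function (gradient × input of a fixed linear model), hypothetical additive Gaussian noise on a continuous input in place of `qn.change_spelling`, and an L2 distance normalised by the norm of the original explanation (one common choice).

```python
import numpy as np

rng = np.random.default_rng(42)

def explain(x: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Hypothetical explanation: gradient x input of the linear model f(x) = w . x."""
    return w * x

def gaussian_noise(x: np.ndarray, scale: float = 0.05) -> np.ndarray:
    """Hypothetical perturbation: add small Gaussian noise to every feature."""
    return x + rng.normal(scale=scale, size=x.shape)

def avg_sensitivity(x, w, perturb, num_samples=10):
    """Mean normalised L2 change of the explanation over sampled perturbations."""
    a = explain(x, w)
    denom = np.linalg.norm(a) or 1.0  # avoid dividing by zero for an all-zero explanation
    changes = [
        np.linalg.norm(a - explain(perturb(x), w)) / denom
        for _ in range(num_samples)
    ]
    return float(np.mean(changes))

x = rng.normal(size=8)
w = rng.normal(size=8)
print(avg_sensitivity(x, w, gaussian_noise))
```

An identity "perturbation" leaves the explanation unchanged and therefore yields a sensitivity of zero.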

In [None]:
# Instantiate metric.
avg_sensitivity = qn.AvgSensitivity()

# Evaluate avg sensitivity for Gradient Norm.
avg_sensitivity_grad_norm = avg_sensitivity(
    model=model_stub,
    x_batch=x_batch,
    y_batch=y_batch,
    perturb_func=qn.change_spelling,
    explain_func=explain_gradient_norm_func,
).mean()

# Evaluate avg sensitivity for Integrated Gradients.
avg_sensitivity_int_grad = avg_sensitivity(
    model=model_stub,
    x_batch=x_batch,
    y_batch=y_batch,
    perturb_func=qn.change_spelling,
    explain_func=explain_int_grad_func
).mean()


Maximum Sensitivity captures the maximum change in explanations under small perturbations of the input.
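
Max Sensitivity replaces the mean over perturbations with a maximum, so it reports the worst case. The hypothetical example below uses hand-written explanation vectors (not outputs of the notebook's model), three small perturbations plus one outlier, and compares both aggregations of the same normalised L2 change; the function names are illustrative, not Quantus APIs.

```python
import numpy as np

def normalised_change(a: np.ndarray, a_perturbed: np.ndarray) -> float:
    """L2 distance between explanations, divided by the norm of the original."""
    denom = np.linalg.norm(a) or 1.0
    return float(np.linalg.norm(a - a_perturbed) / denom)

def max_and_avg_sensitivity(a, perturbed_explanations):
    """Return (max, mean) of the normalised changes over all perturbations."""
    changes = [normalised_change(a, b) for b in perturbed_explanations]
    return max(changes), float(np.mean(changes))

# Hypothetical explanations over three tokens.
a = np.array([1.0, 0.0, 0.0])
perturbed = [
    np.array([1.0, 0.1, 0.0]),
    np.array([1.0, 0.0, 0.1]),
    np.array([0.9, 0.0, 0.0]),
    np.array([0.0, 1.0, 0.0]),  # outlier: relevance jumps to another token
]
max_s, avg_s = max_and_avg_sensitivity(a, perturbed)
print(max_s, avg_s)
```

A single perturbation that moves relevance to a different token dominates the maximum but is diluted in the average.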

In [None]:
# Instantiate metric.
max_sensitivity = qn.MaxSensitivity()

# Evaluate max sensitivity metric for Gradient Norm.
max_sensitivity_grad_norm = max_sensitivity(
    model=model_stub,
    x_batch=x_batch,
    y_batch=y_batch,
    perturb_func=qn.change_spelling,
    explain_func=explain_gradient_norm_func,
).mean()

# Evaluate max sensitivity metric for Integrated Gradients.
max_sensitivity_int_grad = max_sensitivity(
    model=model_stub,
    x_batch=x_batch,
    y_batch=y_batch,
    perturb_func=qn.change_spelling,
    explain_func=explain_int_grad_func
).mean()


Display the results in tabular form.

In [None]:
# Reformat the results.
all_results = np.asarray([
    [
        avg_sensitivity_grad_norm,
        avg_sensitivity_int_grad
    ],
    [
        max_sensitivity_grad_norm,
        max_sensitivity_int_grad
    ]
])

# Print out the evaluation outcome!
pd.DataFrame(
    all_results,
    columns=['Gradient Norm', 'Integrated Gradients'],
    index=['Average Sensitivity', 'Max Sensitivity']
)

|   | Gradient Norm | Integrated Gradients |
|---|---|---|
| Average Sensitivity | 0.148746 | 9.245282e-11 |
| Max Sensitivity | 0.192074 | 1.34273e-10 |
