### In this notebook, we demonstrate what a robustness evaluation of explanations for text classification models could look like

In [26]:
import numpy as np
import pandas as pd
import quantus.nlp as qn
from datasets import load_dataset
import tensorflow as tf
from functools import partial
import logging
from typing import NamedTuple, Any
from transformers import AutoTokenizer, TFDistilBertForSequenceClassification
from IPython.display import HTML

# Hugging Face Hub checkpoint id; used below to load both the tokenizer and the model.
MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
# Only let WARNING and above through from the 'absl' logger to keep cell output readable.
logging.getLogger('absl').setLevel(logging.WARNING)
# Last expression of the cell: shows the physical devices TensorFlow can see.
tf.config.list_physical_devices()

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
 PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

#### Load the pre-trained model and tokenizer from the [Hugging Face](https://huggingface.co/models) Hub

In [27]:
# Tokenizer and classifier are loaded from the same checkpoint (MODEL_NAME).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = TFDistilBertForSequenceClassification.from_pretrained(MODEL_NAME)

Some layers from the model checkpoint at distilbert-base-uncased-finetuned-sst-2-english were not used when initializing TFDistilBertForSequenceClassification: ['dropout_19']
- This IS expected if you are initializing TFDistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFDistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some layers of TFDistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased-finetuned-sst-2-english and are newly initialized: ['dropout_39']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


#### Load test split of [GLUE/SST2](https://huggingface.co/datasets/sst2) dataset

In [29]:
# Human-readable class names, indexed by the integer class id.
CLASS_NAMES = ['negative', 'positive']


def decode_labels(y_batch):
    """Translate integer class indices into their class names.

    Args:
        y_batch: Iterable of integer indices into ``CLASS_NAMES``.

    Returns:
        A list with one class-name string per element of ``y_batch``.
    """
    names = []
    for class_index in y_batch:
        names.append(CLASS_NAMES[class_index])
    return names

In [30]:
# Number of sentences kept from the test split for the rest of the notebook.
BATCH_SIZE = 8
dataset = load_dataset("sst2")['test']
# Raw sentence strings of the first BATCH_SIZE test examples.
x_batch = dataset['sentence'][:BATCH_SIZE]



  0%|          | 0/3 [00:00<?, ?it/s]

#### Run an example inference

In [32]:
# Tokenize the whole batch, padding every sentence to the longest one in the batch.
tokens = tokenizer(x_batch, padding='longest', return_tensors='tf')
logits = model(**tokens).logits
# Predicted class index per sentence; y_batch therefore holds model predictions,
# which are what the explanation and metric cells below receive as targets.
y_batch = tf.argmax(tf.nn.softmax(logits), axis=1).numpy()

# Show each sentence next to its decoded predicted label.
pd.DataFrame([x_batch, decode_labels(y_batch)]).T

Unnamed: 0,0,1
0,uneasy mishmash of styles and genres .,negative
1,this film 's relationship to actual tension is...,negative
2,"by the end of no such thing the audience , lik...",positive
3,director rob marshall went out gunning to make...,positive
4,lathan and diggs have considerable personal ch...,positive
5,a well-made and often lovely depiction of the ...,positive
6,none of this violates the letter of behan 's b...,negative
7,although it bangs a very cliched drum at times...,positive


#### Let's also create a function to visualize our explanations

In [33]:
def _create_div(explanation: qn.TokenSalience, predicted_label):
    div = (
        """
        <div class="container">
            <p>
                Predicted: {{predicted_label}} <br>
                {{saliency_map}}
            </p>
        </div>
        """
    )

    TOKEN_SPAN = (
        """
        <span class="highlight-container" style="background:{{color}};">
            <span class="highlight"> {{token}} </span>
        </span>
        """
    )
    tokens = explanation.tokens
    grads = explanation.salience
    body = ""
    max_grad = np.max(grads)

    for t, g in zip(tokens, grads):
        green = 255.0 - 255.0 * (g / max_grad)
        blue = 255.0 - 255.0 * (g / max_grad)
        token_span = TOKEN_SPAN.replace(
            "{{color}}", f"rgb(255,{green},{blue})"
        ).replace("{{token}}", t)
        body += token_span + " "

    return div.replace("{{predicted_label}}", predicted_label).replace(
        "{{saliency_map}}", body
    )


def create_textual_heatmap(explanations, predicted_labels):
    """Assemble an HTML document fragment with one heatmap per explanation.

    Args:
        explanations: Explanations to render; each one is handed to
            ``_create_div``.
        predicted_labels: Integer class indices (looked up in ``CLASS_NAMES``),
            paired position-wise with ``explanations``.

    Returns:
        A string holding the shared ``<style>`` block followed by one rendered
        ``<div>`` per (explanation, label) pair.
    """
    stylesheet = (
        """
        <style>

            .container {
                line-height: 1.4;
                text-align: center;
                margin: 10px 10px 10px 10px;
                color: black;
                background: white;
            }

            p {
                font-size: 16px;
            }

            .highlight-container, .highlight {
                position: relative;
                border-radius: 10% 10% 10% 10%;
            }

            .highlight-container {
                display: inline-block;
            }

            .highlight-container:before {
                content: " ";
                display: block;
                height: 90%;
                width: 100%;
                margin-left: -3px;
                margin-right: -3px;
                position: absolute;
                top: -1px;
                left: -1px;
                padding: 10px 3px 3px 10px;
            }

        </style>
        """
    )

    # Pairs are consumed with zip, so rendering stops at the shorter input.
    rendered_divs = [
        _create_div(explanation, CLASS_NAMES[label_index])
        for explanation, label_index in zip(explanations, predicted_labels)
    ]
    return stylesheet + "".join(rendered_divs)

#### Generate and visualize explanations using baseline methods: Gradient Norm and Integrated Gradients

In [34]:
def explain_gradient_norm(model, input, target, model_name, tokenizer):
    """Explain one sentence by the norm of the target-logit gradient per token.

    Args:
        model: Classifier; the sub-model holding the input embeddings is read
            from the attribute named ``model_name``.
        input: A single raw sentence.
        target: Class index whose logit is differentiated.
        model_name: Attribute name of the base model on ``model``.
        tokenizer: Tokenizer used to produce the token ids and token strings.

    Returns:
        ``qn.TokenSalience`` pairing the token strings with one score per token.
    """
    input_ids = tokenizer([input], return_tensors='tf')["input_ids"]
    embedding_layer = getattr(model, model_name).get_input_embeddings()
    embedded_input = embedding_layer(input_ids=input_ids)

    with tf.GradientTape() as tape:
        # The embeddings are a plain tensor, so they must be watched explicitly.
        tape.watch(embedded_input)
        all_logits = model(None, inputs_embeds=embedded_input).logits
        target_logit = tf.gather(all_logits, axis=1, indices=target)

    gradients = tape.gradient(target_logit, embedded_input)
    # Norm over the last (embedding) axis, then drop the batch axis of size one.
    token_scores = tf.linalg.norm(gradients, axis=-1).numpy()[0]
    token_strings = tokenizer.convert_ids_to_tokens(input_ids[0])
    return qn.TokenSalience(token_strings, token_scores)


def explain_gradient_norm_batch(model: tf.keras.Model, inputs, targets, model_name, tokenizer):
    """Run ``explain_gradient_norm`` on every (sentence, target) pair.

    Args:
        model: Classifier forwarded to ``explain_gradient_norm``.
        inputs: Raw sentences.
        targets: Class indices, paired position-wise with ``inputs``.
        model_name: Attribute name of the base model on ``model``.
        tokenizer: Tokenizer forwarded to ``explain_gradient_norm``.

    Returns:
        A list with one explanation per pair, in input order.
    """
    explanations = []
    for sentence, label in zip(inputs, targets):
        explanations.append(
            explain_gradient_norm(model, sentence, label, model_name, tokenizer)
        )
    return explanations


# Pre-bind model_name and tokenizer so the explainer can be called as
# f(model, inputs, targets), which is how it is invoked here and passed as
# `explain_func` to the metrics below.
explain_gradient_norm_func = partial(explain_gradient_norm_batch, model_name="distilbert", tokenizer=tokenizer)
a_batch_grad_norm = explain_gradient_norm_func(model, x_batch, y_batch)

In [35]:
# Render heatmaps for the first three Gradient Norm explanations.
# Note: this binds the global name `html` to a string.
html = create_textual_heatmap(a_batch_grad_norm[:3], y_batch[:3])
HTML(html)

In [36]:
def get_interpolated_inputs(
        baseline: np.ndarray, target: np.ndarray, num_steps=10
) -> np.ndarray:
    """Build evenly spaced points on the straight line from baseline to target.

    Args:
        baseline: Start of the path.
        target: End of the path, same shape as ``baseline``.
        num_steps: Number of segments the path is divided into.

    Returns:
        An empty array when ``num_steps <= 0``. Otherwise an array with
        ``num_steps + 1`` points along the first axis (both end points
        included), i.e. shape ``(num_steps + 1,) + baseline.shape``.
    """
    if num_steps <= 0:
        return np.array([])
    if num_steps == 1:
        # A single segment is just the two end points, stacked as given.
        return np.array([baseline, target])

    num_points = num_steps + 1
    step_vector = target - baseline
    # Fractions 0..1 shaped (num_points, 1, 1) so they broadcast over the
    # trailing two axes of the replicated step vector.
    fractions = np.linspace(0, 1, num_points, dtype=np.float32).reshape(-1, 1, 1)
    tiled_steps = np.broadcast_to(step_vector, (num_points,) + step_vector.shape)
    return baseline + fractions * tiled_steps


def estimate_integral(path_gradients: np.ndarray) -> np.ndarray:
    """Reduce gradients along an interpolation path to one value per token.

    Neighbouring path points are averaged pairwise (trapezoid rule) and the
    result is averaged over axis 0 and axis 2, leaving axis 1.

    Args:
        path_gradients: Gradients at consecutive path points; the visible
            caller passes a rank-3 tensor whose first axis indexes the points.

    Returns:
        The result of ``tf.reduce_mean`` over axes 0 and 2.

    NOTE(review): the usual Integrated Gradients formulation also scales the
    path integral by (input - baseline); that factor is applied neither here
    nor in the visible caller -- confirm this is intentional.
    """
    earlier_points, later_points = path_gradients[:-1], path_gradients[1:]
    segment_means = (earlier_points + later_points) / 2
    return tf.reduce_mean(segment_means, axis=(0, 2))


def explain_int_grad(tokenizer, model: tf.keras.Model, input, target, model_name):
    """Explain one sentence with path-averaged gradients from a zero baseline.

    Args:
        tokenizer: Tokenizer used to produce the token ids and token strings.
        model: Classifier; the sub-model holding the input embeddings is read
            from the attribute named ``model_name``.
        input: A single raw sentence.
        target: Class index whose logit is differentiated.
        model_name: Attribute name of the base model on ``model``.

    Returns:
        ``qn.TokenSalience`` pairing the token strings with the absolute value
        of the per-token scores divided by their standard deviation.
    """
    input_ids = tokenizer([input], return_tensors='tf')["input_ids"]
    embedding_layer = getattr(model, model_name).get_input_embeddings()
    # Drop the batch axis: the interpolation helper adds its own leading axis.
    embedded_input = embedding_layer(input_ids=input_ids)[0]

    zero_baseline = np.zeros_like(embedded_input)
    path_points = tf.convert_to_tensor(
        get_interpolated_inputs(zero_baseline, embedded_input)
    )

    with tf.GradientTape() as tape:
        tape.watch(path_points)
        # All path points go through the model together as one batch.
        all_logits = model(None, inputs_embeds=path_points).logits
        target_logits = tf.gather(all_logits, axis=1, indices=target)

    path_gradients = tape.gradient(target_logits, path_points)
    token_scores = estimate_integral(path_gradients)
    scaled_scores = tf.abs(token_scores / tf.math.reduce_std(token_scores))
    return qn.TokenSalience(
        tokenizer.convert_ids_to_tokens(input_ids[0]), scaled_scores.numpy()
    )


def explain_int_grad_batch(model, inputs, targets, tokenizer, model_name):
    """Run ``explain_int_grad`` on every (sentence, target) pair.

    Args:
        model: Classifier forwarded to ``explain_int_grad``.
        inputs: Raw sentences.
        targets: Class indices, paired position-wise with ``inputs``.
        tokenizer: Tokenizer forwarded to ``explain_int_grad``.
        model_name: Attribute name of the base model on ``model``.

    Returns:
        A list with one explanation per pair, in input order.
    """
    explanations = []
    for sentence, label in zip(inputs, targets):
        explanations.append(
            explain_int_grad(
                tokenizer=tokenizer,
                model=model,
                input=sentence,
                target=label,
                model_name=model_name,
            )
        )
    return explanations


# Pre-bind model_name and tokenizer so the explainer can be called as
# f(model, inputs, targets), which is how it is invoked here and passed as
# `explain_func` to the metrics below.
explain_int_grad_func = partial(explain_int_grad_batch, model_name="distilbert", tokenizer=tokenizer)

a_batch_int_grad = explain_int_grad_func(model, x_batch, y_batch)

In [37]:
# Render heatmaps for the first three Integrated Gradients explanations.
# Note: this binds the global name `html` to a string.
html = create_textual_heatmap(a_batch_int_grad[:3], y_batch[:3])
HTML(html)

#### Now we compute the [Sensitivity](https://arxiv.org/abs/2005.00631) metrics

In [39]:
# This is only a workaround to account for hardcoded attribute access in the library:
# the metrics receive an object exposing `.model` and `.tokenizer`.
class ModelTuple(NamedTuple):
    """Bundle of a model and its tokenizer, passed as `model` to the metrics."""

    model: Any
    tokenizer: Any


model_stub = ModelTuple(model, tokenizer)
# `model_stub.model` is the very same object as `model`, so the two assignments
# below mutate the global `model` in place: they add a `bert` alias for
# `model.distilbert` and a `word_embeddings` attribute on its embeddings layer.
model_stub.model.bert = model.distilbert
model_stub.model.bert.embeddings.word_embeddings = model.distilbert.embeddings.weight

#### Average Sensitivity captures the average change in explanations under slight perturbation

In [40]:
avg_sensitivity = qn.AvgSensitivity()

# Same metric instance, inputs and perturbation for both explainers; only
# `explain_func` differs. `.mean()` collapses the metric output to one number
# per explainer (presumably one score per sample -- verify against the library).
avg_sensitivity_grad_norm = avg_sensitivity(
    model=model_stub,
    x_batch=x_batch,
    y_batch=y_batch,
    perturb_func=qn.change_spelling,
    explain_func=explain_gradient_norm_func,
).mean()

avg_sensitivity_int_grad = avg_sensitivity(
    model=model_stub,
    x_batch=x_batch,
    y_batch=y_batch,
    perturb_func=qn.change_spelling,
    explain_func=explain_int_grad_func
).mean()

Collecting perturbations:   0%|          | 0/10 [00:00<?, ?it/s]

Collecting explanations:   0%|          | 0/8 [00:00<?, ?it/s]

Collecting perturbations:   0%|          | 0/10 [00:00<?, ?it/s]

Collecting explanations:   0%|          | 0/9 [00:00<?, ?it/s]

#### Maximum Sensitivity captures the maximal change in explanations under slight perturbation

In [41]:
max_sensitivity = qn.MaxSensitivity()

# Same metric instance, inputs and perturbation for both explainers; only
# `explain_func` differs. `.mean()` collapses the metric output to one number
# per explainer (presumably one score per sample -- verify against the library).
max_sensitivity_grad_norm = max_sensitivity(
    model=model_stub,
    x_batch=x_batch,
    y_batch=y_batch,
    perturb_func=qn.change_spelling,
    explain_func=explain_gradient_norm_func,
).mean()

max_sensitivity_int_grad = max_sensitivity(
    model=model_stub,
    x_batch=x_batch,
    y_batch=y_batch,
    perturb_func=qn.change_spelling,
    explain_func=explain_int_grad_func
).mean()

Collecting perturbations:   0%|          | 0/10 [00:00<?, ?it/s]

Collecting explanations:   0%|          | 0/7 [00:00<?, ?it/s]

Collecting perturbations:   0%|          | 0/10 [00:00<?, ?it/s]

Collecting explanations:   0%|          | 0/8 [00:00<?, ?it/s]

#### Display results in tabular form

In [43]:
# 2x2 grid of the aggregated scores: rows are metrics, columns are explainers,
# matching the `index` and `columns` labels given to the DataFrame below.
all_results = np.asarray([
    [
        avg_sensitivity_grad_norm,
        avg_sensitivity_int_grad
    ],
    [
        max_sensitivity_grad_norm,
        max_sensitivity_int_grad
    ]
])
pd.DataFrame(
    all_results,
    columns=['GradNorm', 'IntGrad'],
    index=['Average Sensitivity', 'Max Sensitivity']
)

Unnamed: 0,GradNorm,IntGrad
Average Sensitivity,0.19266,0.607145
Max Sensitivity,0.290634,0.756526
