Quantus supports text-classification models out of the box.
In this tutorial we show how you can use Quantus with your custom model.

In [13]:
from typing import NamedTuple, List
import logging

import numpy as np
import tensorflow as tf
from tensorflow import keras

from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers import normalizers
from tokenizers.normalizers import NFD, Lowercase, StripAccents
from tokenizers.models import WordPiece
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordPieceTrainer
from tokenizers.processors import TemplateProcessing
from keras_nlp.layers import FNetEncoder, TokenAndPositionEmbedding

from helpers.model.text_classifier import Tokenizable, TextClassifier, EmbeddingsCallable
from quantus.helpers.tf_utils import is_xla_compatible_platform
from quantus.helpers.types import Explanation

# Keep only ERROR-level messages from Python's root logger and TensorFlow's logger.
logging.getLogger().setLevel(logging.ERROR)
tf.get_logger().setLevel(logging.ERROR)

# Seed Python's `random`, NumPy and TensorFlow in one call so that data shuffling,
# weight initialisation and dropout are repeatable across "Restart & Run All".
SEED = 42
keras.utils.set_random_seed(SEED)

# Display the devices TensorFlow can use (the last expression is the cell output).
tf.config.list_logical_devices()

[LogicalDevice(name='/device:CPU:0', device_type='CPU'),
 LogicalDevice(name='/device:GPU:0', device_type='GPU')]

## 1) Preliminaries.
### 1.1 Load SST2 dataset.

In [14]:
# Download the SST-2 dataset from the Hugging Face hub.
dataset = load_dataset("sst2")
train_split, val_split = dataset["train"], dataset["validation"]

# Each split exposes parallel columns: the raw sentence text and its label.
X_train, Y_train = train_split["sentence"], train_split["label"]
X_val, Y_val = val_split["sentence"], val_split["label"]



  0%|          | 0/3 [00:00<?, ?it/s]

### 1.2. Train WordPiece tokenizer.
More about the WordPiece algorithm [here](https://huggingface.co/course/chapter6/6?fw=pt).
More about the `tokenizers` library [here](https://huggingface.co/docs/tokenizers/index).

In [15]:
# Special tokens; the trainer places them at the start of the vocabulary in this order.
SPECIAL_TOKENS = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]

tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
trainer = WordPieceTrainer(special_tokens=SPECIAL_TOKENS, vocab_size=10_000)

# Input pre-processing: NFD-normalise, lowercase and strip accents, then split into
# words with the `Whitespace` pre-tokenizer.
tokenizer.pre_tokenizer = Whitespace()
tokenizer.normalizer = normalizers.Sequence([NFD(), Lowercase(), StripAccents()])

# Train on all data (train + validation sentences).
tokenizer.train_from_iterator(X_train + X_val, trainer)

# Append classification tokens: "[CLS] A [SEP]" (pairs: "[CLS] A [SEP] B [SEP]").
# The post-processor is only used at encode time, so it is attached after training,
# which lets us read the real ids from the vocabulary instead of hard-coding them.
tokenizer.post_processor = TemplateProcessing(
    single="[CLS] $A [SEP]",
    pair="[CLS] $A [SEP] $B:1 [SEP]:1",
    special_tokens=[
        ("[CLS]", tokenizer.token_to_id("[CLS]")),
        ("[SEP]", tokenizer.token_to_id("[SEP]")),
    ],
)






### 1.3 Encode text.

In [16]:
# Configure for usage. Pad with the vocabulary's real [PAD] id: `enable_padding()`
# defaults to pad_id=0, which here is [UNK] (the first special token given to the trainer).
tokenizer.enable_padding(pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]")
tokenizer.enable_truncation(max_length=30)

# Encode inputs: each sentence becomes a list of token ids, truncated to 30 tokens
# and padded to the longest sequence in the batch.
X_train_encoded = [encoding.ids for encoding in tokenizer.encode_batch(X_train)]
X_val_encoded = [encoding.ids for encoding in tokenizer.encode_batch(X_val)]

### 1.4. Convert to TF dataset.

In [17]:
BATCH_SIZE = 2048

# Training: shuffle over the whole (in-memory) training set, and re-shuffle every epoch.
# Nothing is cached downstream of `shuffle`, which would freeze epoch 1's order forever.
# `drop_remainder=True` keeps batch shapes static, which helps XLA compilation.
train_ds = (
    tf.data.Dataset.from_tensor_slices((X_train_encoded, Y_train))
    .shuffle(buffer_size=len(Y_train))
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)

# Validation: order does not matter, so no shuffling. Keep the final partial batch;
# SST-2's validation split (~870 sentences) is smaller than one batch, so
# `drop_remainder=True` would leave this dataset empty.
val_ds = (
    tf.data.Dataset.from_tensor_slices((X_val_encoded, Y_val))
    .batch(BATCH_SIZE)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

### 1.5. Define model
We use the simple and lightweight FNet architecture. More about it [here](https://arxiv.org/abs/2105.03824).
The implementation is based on [keras_nlp](https://github.com/keras-team/keras-nlp/blob/master/keras_nlp/layers/f_net_encoder.py).

In [18]:
class FNetConfig(NamedTuple):
    """Hyper-parameters for the FNet text classifier.

    Every field carries a type annotation, so ``NamedTuple`` turns it into a real
    field with a default that can be overridden, e.g. ``FNetConfig(num_labels=3)``.
    Without annotations the values would only be class attributes and the
    constructor would accept no arguments.
    """

    embedding_dim: int = 128  # size of the token + position embeddings
    intermediate_dim: int = 256  # passed to each FNetEncoder as `intermediate_dim`
    num_encoder_blocks: int = 3  # number of stacked FNetEncoder layers
    max_sequence_length: int = 30  # keep in sync with the tokenizer's truncation length
    vocab_size: int = 10_000  # keep in sync with the WordPiece trainer's vocab_size
    num_labels: int = 2  # SST-2 is binary sentiment classification



class FNetClassifier(keras.Model):
    """Subclassed-Keras FNet classifier with the same layer stack as `fnet_classifier`.

    Layers: token+position embedding -> ``config.num_encoder_blocks`` FNet encoders
    -> global average pooling -> dropout -> dense softmax over ``config.num_labels``.
    """

    def __init__(self, config: FNetConfig):
        super().__init__()
        self.embedding = TokenAndPositionEmbedding(
            vocabulary_size=config.vocab_size,
            sequence_length=config.max_sequence_length,
            embedding_dim=config.embedding_dim,
        )
        self.encoder_blocks = [
            FNetEncoder(intermediate_dim=config.intermediate_dim)
            for _ in range(config.num_encoder_blocks)
        ]
        self.pooling = keras.layers.GlobalAveragePooling1D()
        self.dropout = keras.layers.Dropout(0.1)
        self.classifier = keras.layers.Dense(config.num_labels, activation="softmax")

    def call(self, inputs, training=None, mask=None):
        """Map a batch of token ids to per-class softmax probabilities.

        Args:
            inputs: Integer token ids; assumed shape (batch, seq_len) with
                seq_len <= config.max_sequence_length (TODO confirm for all callers).
            training: Forwarded to Dropout so it is only active while training.
            mask: Accepted for Keras API compatibility; unused.
        """
        x = self.embedding(inputs)
        for encoder in self.encoder_blocks:
            x = encoder(x)
        x = self.pooling(x)
        x = self.dropout(x, training=training)
        return self.classifier(x)


def fnet_classifier(config: FNetConfig, dropout_rate: float = 0.1):
    """Build a functional-API FNet text classifier.

    Architecture: token+position embedding -> ``config.num_encoder_blocks`` FNet
    encoder blocks -> global average pooling -> dropout -> dense softmax over
    ``config.num_labels`` classes.

    Args:
        config: Model hyper-parameters.
        dropout_rate: Dropout applied to the pooled representation before the
            classification head (default 0.1, the previously hard-coded value).

    Returns:
        An uncompiled ``keras.Model`` named ``"fnet_classifier"`` whose single input
        ``"input_ids"`` is int64 token ids of shape ``(batch, seq_len)``.
    """
    input_ids = keras.Input(shape=(None,), dtype=tf.int64, name="input_ids")
    x = TokenAndPositionEmbedding(
        vocabulary_size=config.vocab_size,
        sequence_length=config.max_sequence_length,
        embedding_dim=config.embedding_dim,
    )(input_ids)

    for _ in range(config.num_encoder_blocks):
        x = FNetEncoder(intermediate_dim=config.intermediate_dim)(inputs=x)

    # Collapse the sequence axis into a single vector per example.
    x = keras.layers.GlobalAveragePooling1D()(x)
    x = keras.layers.Dropout(dropout_rate)(x)
    outputs = keras.layers.Dense(config.num_labels, activation="softmax")(x)
    return keras.Model(input_ids, outputs, name="fnet_classifier")


# Instantiate the functional FNet model with the default hyper-parameters.
fnet_config = FNetConfig()
model = fnet_classifier(config=fnet_config)

### 1.6 Train model.

In [19]:
# Enable XLA (`jit_compile`) only when Quantus' platform check (imported from
# `quantus.helpers.tf_utils` at the top of the notebook) reports it is supported.
use_xla = is_xla_compatible_platform()

model.compile(
    optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=5e-4),
    # Integer class labels vs. per-class probabilities from the softmax head.
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
    jit_compile=use_xla,
)

# Keep the History object so training curves can be inspected later.
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=5,
)
model.summary()

Epoch 1/5
Epoch 2/5

KeyboardInterrupt: 

### 1.7 Create adapter for Quantus.

In [None]:
from quantus.helpers.model.text_classifier import Tokenizable, R
from quantus.helpers.model.text_classifier import TextClassifier
from quantus.helpers.model.tf_model import TFModelWrapper
from quantus.helpers.model.tf_model_randomizer import TFNestedModelRandomizer

class TokenizerAdapter(Tokenizable):
    """Expose a Hugging Face `tokenizers.Tokenizer` through Quantus' `Tokenizable` interface."""

    def __init__(self, tokenizer: Tokenizer, max_length: int = 30):
        """
        Args:
            tokenizer: A trained tokenizer. NOTE: it is configured in place (truncation
                and padding are enabled on this very object), so the caller's instance
                is affected too.
            max_length: Maximum number of tokens per encoded sequence (default 30,
                matching the notebook's tokenizer configuration).
        """
        self.tokenizer = tokenizer
        self.tokenizer.enable_truncation(max_length=max_length)
        # `enable_padding()` defaults to pad_id=0; use the vocabulary's real [PAD] id
        # when it exists so padding is not encoded as another special token.
        pad_id = self.tokenizer.token_to_id("[PAD]")
        if pad_id is None:
            self.tokenizer.enable_padding()
        else:
            self.tokenizer.enable_padding(pad_id=pad_id, pad_token="[PAD]")

    def batch_encode(self, text: List[str], **kwargs) -> R:
        """Encode a batch of strings into ``{"input_ids": 2-D array of token ids}``.

        Padding is enabled in ``__init__``, so all rows have the same length.
        """
        encodings = self.tokenizer.encode_batch(text)
        return {"input_ids": np.asarray([encoding.ids for encoding in encodings])}

    def convert_ids_to_tokens(self, ids: np.ndarray) -> List[str]:
        """Map a 1-D sequence of token ids to their token strings."""
        # Cast explicitly: the Rust binding expects Python ints, not numpy scalars.
        return [self.tokenizer.id_to_token(int(token_id)) for token_id in ids]

    def token_id(self, token: str) -> int:
        """Return the id of ``token`` (``None`` if it is not in the vocabulary)."""
        return self.tokenizer.token_to_id(token)

    def batch_decode(self, ids: np.ndarray, **kwargs) -> List[str]:
        """Decode a batch of id sequences back to strings.

        Honors an optional ``skip_special_tokens`` keyword (default True).
        """
        sequences = np.asarray(ids).tolist()
        return self.tokenizer.decode_batch(
            sequences, skip_special_tokens=kwargs.get("skip_special_tokens", True)
        )


class FNetAdapter(
    TextClassifier,
    # Inherit from TFNestedModelRandomizer to support Model Parameter Randomisation.
    TFNestedModelRandomizer,
    TFModelWrapper
):
    """Wrap a Keras FNet classifier and its tokenizer so Quantus can evaluate it."""

    # If you want to measure Relative Representation Stability, implement HiddenRepresentationsModel.
    # You can try to inherit TFHiddenRepresentationsModel, but from our experience it does not work with custom models and layers.
    # If you want to use gradient-based XAI methods, or latent-space perturbations for robustness metrics implement EmbeddingsCallable.

    def __init__(self, fnet: keras.Model, tokenizer: TokenizerAdapter):
        # NOTE(review): base-class __init__s (e.g. TFModelWrapper) are not called;
        # presumably the attributes they set are not needed here — confirm.
        self._tokenizer = tokenizer
        self.model = fnet

    def predict(self, text: List[str], **kwargs) -> np.ndarray:
        """Run the wrapped Keras model on raw sentences and return its outputs.

        For models built by ``fnet_classifier`` these are softmax class probabilities.
        Extra ``kwargs`` are accepted for interface compatibility and ignored.
        """
        # `get_input_ids` is not defined on TokenizerAdapter; presumably inherited
        # from Quantus' `Tokenizable` — verify.
        ids, _ = self.tokenizer.get_input_ids(text)
        # verbose=0: Quantus calls predict many times, and Keras progress bars
        # would flood the notebook output.
        return self.model.predict(ids, verbose=0)

    def embedding_lookup(self, input_ids):
        """Map token ids to token embeddings (position embeddings are not added).

        The embedding layer is located by type rather than by name, because Keras
        auto-generates layer names (``fnet_classifier`` does not name it).

        Raises:
            ValueError: if the wrapped model has no top-level TokenAndPositionEmbedding.
        """
        embedding_layer = next(
            (layer for layer in self.model.layers if isinstance(layer, TokenAndPositionEmbedding)),
            None,
        )
        if embedding_layer is None:
            raise ValueError("Wrapped model has no TokenAndPositionEmbedding layer.")
        return embedding_layer.token_embedding(input_ids)

    @property
    def tokenizer(self) -> Tokenizable:
        """The adapter used to encode/decode text for this model."""
        return self._tokenizer

### 1.8 Define custom explanation function.
A basic explanation function has the following signature:

```python

from typing import List, Protocol, Tuple
from numpy.typing import ArrayLike
from quantus.helpers.model.text_classifier import TextClassifier


class ExplainFn(Protocol):

    def __call__(
        self,
        model: TextClassifier,
        x_batch: List[str],
        y_batch: ArrayLike,
        **kwargs
    ) -> List[Tuple[List[str], ArrayLike]]: ...
```

We will implement counterfactual explanations; the specific counterfactual method is still to be decided (**TODO**).

In [None]:
def explain_counterfactual(
    model: FNetAdapter, x_batch: List[str], y_batch: np.ndarray, **kwargs
) -> List[Explanation]:
    """Placeholder for a counterfactual explanation function (Quantus ExplainFn protocol).

    Args:
        model: The wrapped FNet classifier to explain.
        x_batch: Input sentences.
        y_batch: Target labels, one per sentence.
        **kwargs: Extra explanation arguments, accepted as required by the protocol.

    Returns:
        Once implemented: the explanations for ``x_batch``.

    Raises:
        NotImplementedError: Always. The method has not been implemented yet, so
            fail loudly instead of silently returning ``None``.
    """
    raise NotImplementedError("Counterfactual explanations are not implemented yet.")

## 4) Quantitative evaluation with Quantus.

In [None]:
# TODO

## 5) Results visualization.