### L2X (learning to explain) for text classification

This example applies the L2X explainer to text classification. Unlike gradient-based methods, L2X trains a separate explanation model. Its advantage is that explanations are generated quickly once the explanation model has been trained. Its disadvantage is that the quality of the explanations depends heavily on that trained model, which is affected by several factors such as its network structure and the training hyperparameters.

For text classification, we implement the default CNN-based explanation model in `omnixai.explainers.nlp.agnostic.l2x`. One may implement other models by following the same interface (please refer to the docs for more details). If using this explainer, please cite the original work: "Learning to Explain: An Information-Theoretic Perspective on Model Interpretation, Jianbo Chen, Le Song, Martin J. Wainwright, Michael I. Jordan, https://arxiv.org/abs/1802.07814".
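To give a feel for what the explanation model learns, the following is a minimal NumPy sketch of the relaxed k-feature selection step described in the cited paper: k independent Gumbel-softmax samples are drawn over the features and combined with an element-wise maximum. The function name `relaxed_k_hot`, the temperature `tau`, and the example scores are illustrative assumptions; this is not the code used by `L2XText`, whose internals may differ.

```python
import numpy as np

def relaxed_k_hot(logits, k, tau=0.5, rng=None):
    """Continuous relaxation of a k-feature subset (sketch of the idea in the L2X paper).

    Draws k independent Gumbel-softmax samples over the d features and
    returns their element-wise maximum, a soft mask with values in (0, 1].
    """
    rng = np.random.default_rng() if rng is None else rng
    logits = np.asarray(logits, dtype=float)
    # Gumbel(0, 1) noise, one row per draw; the lower bound avoids log(0)
    uniform = rng.uniform(low=1e-12, high=1.0, size=(k, logits.shape[0]))
    gumbel = -np.log(-np.log(uniform))
    scores = (logits + gumbel) / tau
    scores -= scores.max(axis=1, keepdims=True)  # numerical stability
    samples = np.exp(scores)
    samples /= samples.sum(axis=1, keepdims=True)  # each row sums to 1
    return samples.max(axis=0)

# Hypothetical importance scores for five tokens
logits = np.array([2.0, 0.1, -1.0, 1.5, 0.0])
soft_mask = relaxed_k_hot(logits, k=2, rng=np.random.default_rng(0))

# When explaining, no sampling is needed: take the k highest-scoring tokens
top_k = np.argsort(-logits)[:2]
```

The soft mask keeps training differentiable, while the hard top-k choice is what makes explanation fast once the model is trained.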

In [1]:
import numpy as np
import sklearn.ensemble
import sklearn.metrics
from sklearn.datasets import fetch_20newsgroups

from omnixai.data.text import Text
from omnixai.preprocessing.text import Tfidf
from omnixai.explainers.nlp.agnostic.l2x import L2XText

We use a `Text` object to represent a batch of texts/sentences. The package `omnixai.preprocessing.text` provides some transforms related to text data such as `Tfidf`.

In [2]:
# Load the training and test datasets
categories = ['alt.atheism', 'soc.religion.christian']
newsgroups_train = fetch_20newsgroups(subset='train', categories=categories)
newsgroups_test = fetch_20newsgroups(subset='test', categories=categories)

x_train = Text(newsgroups_train.data)
y_train = newsgroups_train.target
x_test = Text(newsgroups_test.data)
y_test = newsgroups_test.target
class_names = ['atheism', 'christian']
# A TF-IDF transform
transform = Tfidf().fit(x_train)
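
For readers unfamiliar with TF-IDF, here is a small pure-Python sketch of the textbook weighting (term frequency times `log(N / document frequency)`). It is only an illustration: the helper `tfidf_vectors` and the three toy documents are made up, and the `Tfidf` transform in OmniXAI may use different tokenization, smoothing, or normalization.

```python
import math
from collections import Counter

def tfidf_vectors(docs):
    """Textbook TF-IDF: relative term frequency times log(N / document frequency)."""
    tokenized = [doc.lower().split() for doc in docs]
    vocab = sorted({tok for toks in tokenized for tok in toks})
    n_docs = len(tokenized)
    # Number of documents that contain each token at least once
    doc_freq = Counter(tok for toks in tokenized for tok in set(toks))
    idf = {tok: math.log(n_docs / doc_freq[tok]) for tok in vocab}
    vectors = []
    for toks in tokenized:
        counts = Counter(toks)
        length = max(len(toks), 1)  # guard against empty documents
        vectors.append([counts[tok] / length * idf[tok] for tok in vocab])
    return vocab, vectors

# Toy documents, chosen only for illustration
vocab, vectors = tfidf_vectors(
    ["god exists", "god does not exist", "science works"]
)
```

Tokens that appear in many documents (here `god`) receive a lower weight than tokens that are specific to one document.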

For this classification task, we train a random forest classifier with TF-IDF feature vectors.

In [3]:
train_vectors = transform.transform(x_train)
test_vectors = transform.transform(x_test)
model = sklearn.ensemble.RandomForestClassifier(n_estimators=500)
model.fit(train_vectors, y_train)
predict_function = lambda x: model.predict_proba(transform.transform(x))

predictions = model.predict(test_vectors)
# Note: f1_score reports the binary F1 score, not plain accuracy
print('Test F1 score: {}'.format(
    sklearn.metrics.f1_score(y_test, predictions, average='binary')))

Test F1 score: 0.925233644859813
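
The value printed above comes from `sklearn.metrics.f1_score` with `average='binary'`, which scores only the positive class (label 1 by default) and is generally not equal to plain accuracy. The sketch below computes both metrics by hand on made-up labels to show the difference; the helper `accuracy_and_f1` and the label arrays are illustrative and unrelated to the newsgroups data.

```python
import numpy as np

def accuracy_and_f1(y_true, y_pred, positive=1):
    """Return (accuracy, binary F1 for the `positive` class)."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    tp = int(np.sum((y_pred == positive) & (y_true == positive)))
    fp = int(np.sum((y_pred == positive) & (y_true != positive)))
    fn = int(np.sum((y_pred != positive) & (y_true == positive)))
    accuracy = float(np.mean(y_true == y_pred))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return accuracy, f1

# Made-up labels: 4 of 6 predictions are correct, but only 1 of 2 positives is found
acc, f1 = accuracy_and_f1([0, 0, 0, 0, 1, 1], [0, 0, 0, 1, 1, 0])
```

Here accuracy is 4/6 while F1 is 0.5, because F1 ignores the correctly predicted negatives.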


To initialize `L2XText`, we need to set the following parameters:

  - `training_data`: The data used to train the explainer. This should be the same dataset that was used to train the machine learning model.
  - `predict_function`: The prediction function corresponding to the model to explain. When the model is for classification, the outputs of the `predict_function` are the class probabilities. When the model is for regression, the outputs of the `predict_function` are the estimated values.
  - `mode`: The task type, e.g., `classification` or `regression`.
  - `selection_model`: A PyTorch model class for estimating P(S|X) in L2X. If `selection_model = None`, the default model `DefaultSelectionModel` is used.
  - `prediction_model`: A PyTorch model class for estimating Q(X_S) in L2X. If `prediction_model = None`, the default model `DefaultPredictionModel` is used.
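
As a rough illustration of the `predict_function` contract for classification described above, the sketch below maps a batch of texts to an array of class probabilities with one row per text and one column per class. It is a toy stand-in, not part of OmniXAI: `toy_predict_function` and the keyword lists are hypothetical, and it takes a plain list of strings, whereas the function in this tutorial receives a `Text` object.

```python
import numpy as np

# Hypothetical keyword lists per class index, chosen only for illustration
KEYWORDS = {0: ("atheism", "atheist"), 1: ("church", "christ")}

def toy_predict_function(texts):
    """Map a batch of strings to class probabilities of shape (n_texts, 2)."""
    scores = np.array([
        # Start from 1.0 so that a text without any keyword gets a uniform row
        [1.0 + sum(text.lower().count(word) for word in KEYWORDS[c]) for c in (0, 1)]
        for text in texts
    ])
    # Normalize each row so that it sums to 1
    return scores / scores.sum(axis=1, keepdims=True)

probs = toy_predict_function(
    ["The church bells rang.", "An atheist wrote this."]
)
```

Any callable with this input and output shape can play the same role, which is what makes the explainer model-agnostic.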

In [4]:
idx = 83
explainer = L2XText(
    training_data=x_train,
    predict_function=predict_function
)
explanations = explainer.explain(x_test[idx:idx+9])
explanations.ipython_plot(class_names=class_names)

 |████████████████████████████████████████| 100.0% Complete, Loss 0.0039
L2X prediction model accuracy: 0.8674698795180723
