### Counterfactual explanation for text classification

This notebook shows an example of counterfactual explanations for sentiment analysis. The explainer implemented in the library is based on the Polyjuice model developed by Wu et al. If you use this explainer, please cite their work: https://github.com/tongshuangwu/polyjuice
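A counterfactual explanation is an edited version of an input for which the model's prediction changes. The toy sketch below shows only the selection step of that idea. Given a prediction function that returns class probabilities and a few candidate edits, it keeps the edits that flip the predicted label. `toy_predict` and `find_counterfactuals` are hypothetical helpers written for illustration, and the candidates are hand-written. Polyjuice instead generates perturbations with a fine-tuned language model, and OmniXAI's internal selection logic may differ from this sketch.

```python
import numpy as np

def toy_predict(texts):
    """Hypothetical stand-in classifier returning [P(negative), P(positive)] per text,
    based on a tiny hand-made word list instead of a real model."""
    positive = {"great", "fantastic", "best"}
    negative = {"horrible", "bad", "worst"}
    probs = []
    for t in texts:
        words = set(t.lower().replace("!", "").split())
        pos = len(words & positive)
        neg = len(words & negative)
        p_pos = (pos + 1) / (pos + neg + 2)  # add-one smoothing keeps p_pos in (0, 1)
        probs.append([1.0 - p_pos, p_pos])
    return np.array(probs)

def find_counterfactuals(text, candidates, predict):
    """Keep only the candidate edits whose predicted label differs from the original's."""
    original_label = predict([text]).argmax(axis=1)[0]
    labels = predict(candidates).argmax(axis=1)
    return [c for c, label in zip(candidates, labels) if label != original_label]

# Hand-written perturbations (hypothetical); Polyjuice would generate these automatically
text = "it was a fantastic performance"
candidates = [
    "it was a horrible performance",        # lexical swap
    "it was not a fantastic performance",   # negation, which this toy model ignores
    "it was a fantastic show",              # meaning-preserving paraphrase
]
print(find_counterfactuals(text, candidates, toy_predict))
```

Only the lexical swap flips this toy model's prediction; a real classifier would likely also react to the negation, which is why the quality of the perturbation generator matters.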

In [1]:
import sys
import os

# Make a local OmniXAI checkout importable (skip this cell if OmniXAI is installed via pip).
# In Jupyter or other interactive environments `__file__` is not defined,
# so fall back to the current working directory.
try:
    directory = os.path.dirname(os.path.abspath(__file__))
except NameError:
    directory = os.path.abspath('')

# Add the parent directory of this notebook to the module search path
parent_dir = os.path.dirname(directory)
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

In [2]:
# This default renderer is used for sphinx docs only. Please delete this cell in IPython.
import plotly.io as pio
pio.renderers.default = "png"

In [3]:
import transformers
import numpy as np
from omnixai.data.text import Text
from omnixai.explainers.nlp.counterfactual.polyjuice import Polyjuice

In [5]:
# A transformer model for sentiment analysis; `top_k=None` returns the scores of all classes
model = transformers.pipeline(
    'sentiment-analysis',
    model='distilbert-base-uncased-finetuned-sst-2-english',
    top_k=None
)
# Maps each label name returned by the pipeline to its class index
idx2label = {"NEGATIVE": 0, "POSITIVE": 1}

Device set to use cpu


In [6]:
# Build the prediction function, i.e., a function that maps a batch of texts to class probabilities
def _predict(x):
    # `x` is an omnixai Text object; `x.values` is the list of input strings
    predictions = model(x.values)
    scores = np.zeros((len(predictions), len(idx2label)))
    for i, pred in enumerate(predictions):
        # `pred` is a list of {'label': ..., 'score': ...} dicts, one per class
        for d in pred:
            scores[i, idx2label[d['label']]] = d['score']
    return scores
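Since the explainer is built from `_predict` alone, it helps to check its contract: a batch of texts in, an `(n_samples, 2)` array of class probabilities out, with columns ordered as in `idx2label`. The sketch below checks this without downloading a model. `FakeText`, `fake_model`, and `predict_with` are hypothetical: `FakeText` mimics only the `.values` attribute of OmniXAI's `Text`, `fake_model` imitates the per-input list of `{'label', 'score'}` dicts that the sentiment pipeline returns when all class scores are requested, and `predict_with` follows the same logic as `_predict` but takes the model as an argument.

```python
import numpy as np

class FakeText:
    """Hypothetical stand-in exposing only the `.values` attribute of omnixai's Text."""
    def __init__(self, values):
        self.values = values

def fake_model(texts):
    """Hypothetical stub shaped like a sentiment pipeline's output with all class scores."""
    outputs = []
    for t in texts:
        p_pos = 0.9 if "great" in t else 0.2
        outputs.append([
            {"label": "POSITIVE", "score": p_pos},
            {"label": "NEGATIVE", "score": 1.0 - p_pos},
        ])
    return outputs

idx2label = {"NEGATIVE": 0, "POSITIVE": 1}

def predict_with(model, x):
    # Place each score in the column given by its label, regardless of list order
    scores = np.zeros((len(x.values), len(idx2label)))
    for i, pred in enumerate(model(x.values)):
        for d in pred:
            scores[i, idx2label[d["label"]]] = d["score"]
    return scores

probs = predict_with(fake_model, FakeText(["such a great show!", "it was a horrible movie"]))
print(probs.shape, probs.argmax(axis=1))
```

Because scores are written by label name rather than by position, the result does not depend on the order in which the pipeline lists the classes.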

In [9]:
# Initialize the counterfactual explainer based on Polyjuice
explainer = Polyjuice(predict_function=_predict)

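This cell requires the `polyjuice_nlp` package. If it is not installed, the cell fails with `ModuleNotFoundError: No module named 'polyjuice_nlp'`; install it with `pip install polyjuice_nlp` and rerun the cell.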

In [8]:
x = Text([
    "What a great movie! if you have no taste.",
    "it was a fantastic performance!",
    "best film ever",
    "such a great show!",
    "it was a horrible movie",
    "i've never watched something as bad"
])
explanations = explainer.explain(x)
explanations.ipython_plot()
