# Positive vs. Negative Sentiment Classification

Here we demonstrate how to explain a sentiment classification model for movie reviews. positive vs. negative sentim

In [3]:
import datasets
import numpy as np
import transformers
import time
import shap

## Load the IMDB movie review dataset

In [5]:
dataset = datasets.load_dataset("imdb", split="test")

# shorten the strings to fit into the pipeline model
short_data = [v[:500] for v in dataset["text"][:20]]

## Load and run a sentiment analysis pipeline

In [7]:
classifier = transformers.pipeline("sentiment-analysis", return_all_scores=True)
classifier(short_data[:2])

No model was supplied, defaulted to distilbert/distilbert-base-uncased-finetuned-sst-2-english and revision af0f99b (https://huggingface.co/distilbert/distilbert-base-uncased-finetuned-sst-2-english).
Using a pipeline without specifying a model name and revision in production is not recommended.


[[{'label': 'NEGATIVE', 'score': 0.07582081109285355},
  {'label': 'POSITIVE', 'score': 0.924179196357727}],
 [{'label': 'NEGATIVE', 'score': 0.018342547118663788},
  {'label': 'POSITIVE', 'score': 0.9816573858261108}]]

## Explain the sentiment analysis pipeline

In [9]:
# define the explainer
explainer = shap.Explainer(classifier)

In [10]:
# explain the predictions of the pipeline on the first two samples
start_time = time.time()
shap_values = explainer(short_data[:2])
explanation_time = time.time() - start_time

PartitionExplainer explainer: 3it [13:22, 401.36s/it]              


In [11]:
explanation_time

802.7239351272583

In [12]:
shap_values

.values =
array([array([ 0.00000000e+00,  0.00000000e+00, -2.87491275e-09,  0.00000000e+00,
               0.00000000e+00,  0.00000000e+00,  0.00000000e+00, -6.28688213e-11,
              -3.71869646e-10,  0.00000000e+00, -3.71778697e-09, -1.77976744e-09,
               0.00000000e+00,  0.00000000e+00,  4.23483471e-10,  0.00000000e+00,
              -4.09750101e-09,  0.00000000e+00,  0.00000000e+00,  1.49179868e-09,
               5.64455149e-10,  0.00000000e+00,  0.00000000e+00, -4.99596808e-10,
               0.00000000e+00, -1.32672540e-09,  3.13275450e-09,  0.00000000e+00,
               2.99147966e-10,  0.00000000e+00,  2.93152880e-09,  0.00000000e+00,
              -3.15640136e-09,  0.00000000e+00,  0.00000000e+00,  5.66114977e-09,
               0.00000000e+00,  4.13820089e-09, -1.09139364e-09,  0.00000000e+00,
               0.00000000e+00,  1.93631422e-09, -7.05767889e-10,  0.00000000e+00,
              -3.11501935e-10,  5.18298293e-10,  0.00000000e+00,  0.00000000e+00,
      

In [13]:
54.178290367126465

.values =
array([array([ 0.00000000e+00,  0.00000000e+00, -2.87491275e-09,  0.00000000e+00,
               0.00000000e+00,  0.00000000e+00,  0.00000000e+00, -6.28688213e-11,
              -3.71869646e-10,  0.00000000e+00, -3.71778697e-09, -1.77976744e-09,
               0.00000000e+00,  0.00000000e+00,  4.23483471e-10,  0.00000000e+00,
              -4.09750101e-09,  0.00000000e+00,  0.00000000e+00,  1.49179868e-09,
               5.64455149e-10,  0.00000000e+00,  0.00000000e+00, -4.99596808e-10,
               0.00000000e+00, -1.32672540e-09,  3.13275450e-09,  0.00000000e+00,
               2.99147966e-10,  0.00000000e+00,  2.93152880e-09,  0.00000000e+00,
              -3.15640136e-09,  0.00000000e+00,  0.00000000e+00,  5.66114977e-09,
               0.00000000e+00,  4.13820089e-09, -1.09139364e-09,  0.00000000e+00,
               0.00000000e+00,  1.93631422e-09, -7.05767889e-10,  0.00000000e+00,
              -3.11501935e-10,  5.18298293e-10,  0.00000000e+00,  0.00000000e+00,
               0.00000000e+00, -5.13940298e-10,  0.00000000e+00,  0.00000000e+00,
               0.00000000e+00,  9.52582013e-10,  0.00000000e+00,  0.00000000e+00,
               0.00000000e+00,  2.35169344e-10, -1.00760644e-09,  1.13686838e-09,
               0.00000000e+00,  0.00000000e+00,  7.56017471e-10, -1.27988642e-09,
               0.00000000e+00,  8.73933459e-09, -6.95308700e-10,  3.10183168e-09,
               7.73297870e-10,  1.88038030e-10,  0.00000000e+00,  4.51564119e-10,
              -5.22049959e-10,  3.64252628e-10,  0.00000000e+00,  2.71938916e-10,
               2.94676283e-10, -1.58252078e-10,  0.00000000e+00, -5.87760951e-10,
               7.56926966e-10, -6.58474164e-10,  1.04819264e-09,  0.00000000e+00,
              -1.60844138e-09, -2.05182005e-09, -6.73026079e-10, -1.06410880e-09,
               2.72848394e-12, -6.01175998e-10,  8.47876436e-10, -2.41698217e-10,
               0.00000000e+00,  1.52272150e-09,  1.46474122e-09,  3.17004378e-09,
               1.18461685e-09, -1.80307325e-09, -2.38742359e-10, -3.79532139e-10,
               0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
               0.00000000e+00, -3.64570951e-09,  4.16952793e-10,  6.42424258e-10,
              -2.04426425e-10, -7.84811246e-10, -3.13593773e-09,  4.72209649e-09,
               0.00000000e+00, -2.66315207e-09,  2.89583113e-09,  8.61473382e-09,
               2.95403879e-09,  7.13043846e-10, -1.31694833e-09,  1.54977897e-09,
               1.14960130e-09,  1.42212872e-09, -2.98314262e-10, -5.17199321e-10,
               8.58135001e-11, -4.48684053e-10, -3.44597437e-11,  1.73496942e-10,
               0.00000000e+00])                                                  ,
       array([ 0.00000000e+00,  1.15251169e-08,  0.00000000e+00, -5.05679054e-09,
               3.42879503e-10,  0.00000000e+00,  0.00000000e+00,  5.98538463e-09,
               0.00000000e+00, -6.57928467e-09,  0.00000000e+00,  0.00000000e+00,
              -1.67347025e-09, -3.89627530e-09,  1.74259185e-09,  0.00000000e+00,
              -3.33784556e-10, -1.58888724e-09,  1.75896275e-09,  0.00000000e+00,
              -1.76623871e-09,  1.23691279e-10, -8.34916136e-10,  0.00000000e+00,
              -1.77124093e-09,  0.00000000e+00,  0.00000000e+00, -1.34241418e-09,
              -2.91402102e-09, -1.15051080e-09, -3.06772563e-09,  3.17413651e-10,
               0.00000000e+00,  0.00000000e+00, -7.14317139e-09,  4.99676389e-09,
              -2.73075784e-10,  0.00000000e+00,  1.56069291e-09,  1.09139364e-11,
              -8.52196536e-10, -4.51018423e-09,  0.00000000e+00,  0.00000000e+00,
              -7.26686267e-10,  1.19416654e-09, -2.10548023e-09, -1.19234755e-09,
               0.00000000e+00,  5.04587661e-09,  0.00000000e+00,  4.82714313e-09,
               0.00000000e+00, -5.30235411e-09,  4.36496824e-09,  4.82941687e-10,
               2.66572897e-09,  4.35920811e-09,  0.00000000e+00,  1.37333700e-09,
               0.00000000e+00, -6.13636075e-09, -4.22642188e-09, -2.77577783e-09,
               1.88265403e-09,  0.00000000e+00, -3.62888386e-09, -3.77743466e-10,
              -1.08723595e-09, -3.29600880e-09, -2.72029865e-09,  1.79261406e-09,
               1.47338142e-10,  3.58886609e-09, -1.98633643e-09,  0.00000000e+00,
               0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
              -4.86534191e-09,  1.71121428e-09, -4.48153514e-09, -1.51749191e-09,
              -1.40062184e-10, -1.22872734e-09,  0.00000000e+00,  2.09911377e-09,
              -8.35825631e-10, -7.97626853e-10,  0.00000000e+00,  3.22961569e-09,
               0.00000000e+00, -1.70954687e-09,  9.27684596e-11, -3.81583555e-10,
               9.49967216e-10, -1.34984172e-10,  4.96584107e-10, -2.25069622e-09,
              -1.44564183e-09,  2.71529643e-09,  1.31071179e-09,  3.45426088e-09,
               1.87583282e-10, -9.48299810e-11, -8.85847840e-10,  0.00000000e+00])],
      dtype=object)

.base_values =
array([[0.56289107, 0.4371089 ],
       [0.53882831, 0.46117169]])

.data =
(array(['', 'I ', 'love ', 'sci', '-', 'fi ', 'and ', 'am ', 'willing ',
       'to ', 'put ', 'up ', 'with ', 'a ', 'lot', '. ', 'Sci', '-',
       'fi ', 'movies', '/', 'TV ', 'are ', 'usually ', 'under', 'fu',
       'nded', ', ', 'under', '-', 'appreciated ', 'and ',
       'misunderstood', '. ', 'I ', 'tried ', 'to ', 'like ', 'this',
       ', ', 'I ', 'really ', 'did', ', ', 'but ', 'it ', 'is ', 'to ',
       'good ', 'TV ', 'sci', '-', 'fi ', 'as ', 'Babylon ', '5 ', 'is ',
       'to ', 'Star ', 'Trek ', '(', 'the ', 'original', ')', '. ',
       'Silly ', 'pro', 'st', 'hetic', 's', ', ', 'cheap ', 'cardboard ',
       'sets', ', ', 'stil', 'ted ', 'dialogues', ', ', 'C', 'G ',
       'that ', 'doesn', "'", 't ', 'match ', 'the ', 'background', ', ',
       'and ', 'painfully ', 'one', '-', 'dimensional ', 'characters ',
       'cannot ', 'be ', 'overcome ', 'with ', 'a ', "'", 'sci', '-',
       'fi', "' ", 'setting', '. ', '(', 'I', "'", 'm ', 'sure ',
       'there ', 'are ', 'those ', 'of ', 'you ', 'out ', 'there ',
       'who ', 'think ', 'Babylon ', '5 ', 'is ', 'good ', 'sci', '-',
       'fi', ''], dtype=object), array(['', 'Worth ', 'the ', 'entertainment ', 'value ', 'of ', 'a ',
       'rental', ', ', 'especially ', 'if ', 'you ', 'like ', 'action ',
       'movies', '. ', 'This ', 'one ', 'features ', 'the ', 'usual ',
       'car ', 'chases', ', ', 'fights ', 'with ', 'the ', 'great ',
       'Van ', 'Dam', 'me ', 'kick ', 'style', ', ', 'shooting ',
       'battles ', 'with ', 'the ', '40 ', 'shell ', 'load ', 'shotgun',
       ', ', 'and ', 'even ', 'terrorist ', 'style ', 'bombs', '. ',
       'All ', 'of ', 'this ', 'is ', 'entertaining ', 'and ',
       'competent', 'ly ', 'handled ', 'but ', 'there ', 'is ',
       'nothing ', 'that ', 'really ', 'blows ', 'you ', 'away ', 'if ',
       'you', "'", 've ', 'seen ', 'your ', 'share ', 'before', '.', '<',
       'br ', '/', '>', '<', 'br ', '/', '>', 'The ', 'plot ', 'is ',
       'made ', 'interesting ', 'by ', 'the ', 'inclusion ', 'of ', 'a ',
       'rabbit', ', ', 'which ', 'is ', 'clever ', 'but ', 'hardly ',
       'profound', '. ', 'Many ', 'of ', 'the ', 'c', ''], dtype=object))



SyntaxError: invalid syntax (1494382757.py, line 3)

In [None]:
shap.plots.text(shap_values[:, :, "POSITIVE"])

## Wrap the pipeline manually

SHAP requires tensor outputs from the classifier, and explanations works best in additive spaces so we transform the probabilities into logit values (information values instead of probabilites).

### Create a TransformersPipeline wrapper

In [None]:
pmodel = shap.models.TransformersPipeline(classifier, rescale_to_logits=False)

In [None]:
pmodel(short_data[:2])

In [None]:
pmodel = shap.models.TransformersPipeline(classifier, rescale_to_logits=True)
pmodel(short_data[:2])

In [None]:
explainer2 = shap.Explainer(pmodel)
shap_values2 = explainer2(short_data[:2])
shap.plots.text(shap_values2[:, :, 1])

### Pass a tokenizer as the masker object

In [None]:
explainer2 = shap.Explainer(pmodel, classifier.tokenizer)
shap_values2 = explainer2(short_data[:2])
shap.plots.text(shap_values2[:, :, 1])

### Build a Text masker explicitly

In [None]:
masker = shap.maskers.Text(classifier.tokenizer)
explainer2 = shap.Explainer(pmodel, masker)
shap_values2 = explainer2(short_data[:2])
shap.plots.text(shap_values2[:, :, 1])

## Explore how the Text masker works

In [None]:
masker.shape("I like this movie.")

In [None]:
model_args = masker(np.array([True, True, True, True, True, True, True]), "I like this movie.")
model_args

In [None]:
pmodel(*model_args)

In [None]:
model_args = masker(np.array([True, True, False, False, True, True, True]), "I like this movie.")
model_args

In [None]:
pmodel(*model_args)

In [None]:
masker2 = shap.maskers.Text(classifier.tokenizer, mask_token="...", collapse_mask_token=True)

In [None]:
model_args2 = masker2(np.array([True, True, False, False, True, True, True]), "I like this movie.")
model_args2

In [None]:
pmodel(*model_args2)

## Plot summary statistics and bar charts

In [None]:
# explain the predictions of the pipeline on the first two samples
shap_values = explainer(short_data[:20])

In [None]:
shap.plots.bar(shap_values[0, :, "POSITIVE"])

In [None]:
shap.plots.bar(shap_values[:, :, "POSITIVE"].mean(0))

In [None]:
shap.plots.bar(shap_values[:, :, "POSITIVE"].mean(0), order=shap.Explanation.argsort)