# Text-to-Text Explanation: Abstractive Summarization Example

This notebook demonstrates how to generate model explanations for a text-to-text scenario with a pretrained transformer model. We explain distilbart, a model pretrained on the Extreme Summarization (XSum) dataset and hosted on Hugging Face (https://huggingface.co/sshleifer/distilbart-xsum-12-6).

The first example needs only the model and tokenizer: we use the model's decoder to generate the log odds of the output tokens to be explained. The second example shows how to generate explanations for a model that is available only as an API/function (text in, text out). In that case we approximate the log odds with a text similarity model. In both examples, the underlying explainer used to compute the SHAP values is the Partition explainer.
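To make "log odds of the output tokens" concrete, here is a minimal sketch in plain Python. It assumes we already have the decoder's raw scores (logits) at each output position and the ids of the tokens that were actually generated. The helper names `softmax` and `target_log_odds` and the toy numbers are illustrative assumptions; this is not SHAP's internal implementation.

```python
import math

def softmax(logits):
    """Turn a list of raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def target_log_odds(step_logits, target_ids):
    """For each output position, return log(p / (1 - p)) of the target token."""
    log_odds = []
    for logits, target_id in zip(step_logits, target_ids):
        p = softmax(logits)[target_id]
        # assumes 0 < p < 1; a real implementation would guard the extremes
        log_odds.append(math.log(p) - math.log1p(-p))
    return log_odds

# Toy example (hypothetical values): 2 output positions, 3-token vocabulary
toy_logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
toy_targets = [0, 1]
print(target_log_odds(toy_logits, toy_targets))
```

A positive value means the target token is more likely than not at that position; a negative value means the opposite. The explainer attributes changes in these per-token values to parts of the input text.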

In [1]:
import numpy as np
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import nlp  # Hugging Face `nlp` package (since renamed to `datasets`)
import shap

### Load model and tokenizer

In [2]:
tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-xsum-12-6")
model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distilbart-xsum-12-6")

### Load data

In [3]:
dataset = nlp.load_dataset('xsum', split='train')

Using custom data configuration default


In [7]:
print(dataset)

Dataset(features: {'document': Value(dtype='string', id=None), 'summary': Value(dtype='string', id=None)}, num_rows: 204017)


In [4]:
# slice inputs from dataset to run model inference on
s = dataset['document'][0:1]

### Create an explainer object

In [5]:
explainer = shap.Explainer(model, tokenizer)

In [6]:
print(s)

['The problem is affecting people using the older versions of the PlayStation 3, called the "Fat" model.The problem isn\'t affecting the newer PS3 Slim systems that have been on sale since September last year.Sony have also said they are aiming to have the problem fixed shortly but is advising some users to avoid using their console for the time being."We hope to resolve this problem within the next 24 hours," a statement reads. "In the meantime, if you have a model other than the new slim PS3, we advise that you do not use your PS3 system, as doing so may result in errors in some functionality, such as recording obtained trophies, and not being able to restore certain data."We believe we have identified that this problem is being caused by a bug in the clock functionality incorporated in the system."The PlayStation Network is used by millions of people around the world.It allows users to play their friends at games like Fifa over the internet and also do things like download software 

### Compute shap values

In [6]:
shap_values = explainer(s)

Partition explainer: 2it [00:20, 20.83s/it]               


### Visualize shap explanations

In [7]:
shap.plots.text(shap_values)

Output token,"The problem is affecting people using the older versions of the PlayStation 3,","called the ""Fat"" model.",The problem isn't affecting the newer PS,3 Slim systems that have been on sale since September last year.,Sony have also said they are aiming to have the problem fixed shortly but is advising,"some users to avoid using their console for the time being.""We hope to resolve this problem within the next 24 hours,"" a statement reads.","""In the meantime,","if you have a model other than the new slim PS3,","we advise that you do not use your PS3 system,","as doing so may result in errors in some functionality, such as recording obtained trophies,","and not being able to restore certain data.""We believe we have identified that this","problem is being caused by a bug in the clock functionality incorporated in the system.""The PlayStation Network is used by millions of people around the world.",It allows users to play their friends at games like Fifa over the internet and also do things like download software or visit online stores.
Sony,1.026,0.069,0.526,0.43,1.786,1.413,0.053,0.637,0.42,0.115,0.412,1.495,0.459
has,0.387,0.084,0.239,0.403,0.169,0.404,0.16,0.475,0.31,0.306,0.181,0.211,0.475
said,0.332,-0.178,0.175,-0.054,0.571,0.573,0.247,0.253,-0.117,-0.148,0.527,0.568,-0.408
that,0.352,0.19,0.196,0.343,-0.325,-0.232,0.003,-0.035,0.16,0.234,0.258,0.292,0.01
a,0.345,0.131,0.166,-0.091,0.303,0.691,0.004,0.087,0.317,-0.161,0.346,1.093,-0.097
bug,1.163,-0.133,0.628,0.012,0.141,0.377,0.034,-0.343,0.211,0.712,1.87,3.104,-0.279
in,0.098,0.01,0.171,0.126,0.108,0.141,0.151,0.294,0.196,0.59,0.35,1.029,0.122
its,-0.344,-0.352,-0.204,-0.21,-0.202,-0.03,0.031,-0.165,-0.188,-0.038,0.175,0.767,0.09
PlayStation,1.446,-0.09,0.733,0.268,1.283,1.372,0.13,0.94,0.943,0.115,0.302,3.129,0.524
3,0.983,0.42,0.662,0.758,0.719,0.24,0.115,1.176,1.764,0.435,-0.21,-1.162,-0.565
console,0.163,0.063,0.169,0.009,0.189,1.849,0.025,0.2,0.324,-0.021,-0.205,0.081,-0.354
is,0.275,0.115,0.083,-0.018,0.156,0.08,0.045,-0.098,-0.065,-0.177,0.193,0.791,-0.075
causing,0.31,0.05,0.296,-0.035,0.376,0.57,-0.059,-0.077,0.325,0.284,0.702,1.192,-0.195
some,0.231,0.041,0.168,0.196,0.452,0.385,-0.005,0.217,0.129,0.392,0.552,0.334,0.006
users,0.455,0.048,-0.128,-0.058,0.162,0.683,-0.132,-0.246,-0.273,-0.081,0.679,0.601,0.876
to,0.085,0.091,0.002,0.024,0.01,0.193,0.03,0.044,-0.088,0.015,-0.096,0.028,0.075
lose,0.167,0.041,0.29,0.223,0.75,-0.139,-0.122,-0.025,0.113,-0.367,0.919,0.635,0.122
access,0.646,-0.18,0.343,0.163,0.243,0.937,0.195,0.032,0.308,0.164,1.139,0.875,0.355
to,0.006,-0.001,0.021,0.013,0.063,0.015,0.003,0.005,-0.016,-0.011,-0.053,0.026,-0.021
the,-0.449,-0.233,-0.153,-0.123,-0.007,0.392,0.099,0.151,-0.015,0.038,-0.914,0.412,0.539
PlayStation,-0.05,-0.109,-0.13,-0.266,0.068,0.237,0.124,0.567,0.285,0.236,0.079,3.452,0.89
Network,0.34,-0.007,0.041,0.178,-0.47,-0.026,-0.011,0.142,0.101,0.29,1.087,4.923,1.5
.,0.03,0.047,-0.008,0.125,0.151,0.188,-0.135,-0.059,-0.041,-0.188,0.368,0.711,0.128
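One way to read the table above: each row is a generated summary token, each column is a span of the input text, and each cell is the SHAP value of that span for that token. The sketch below copies three rows verbatim from the table and finds the input span with the largest attribution for each token. The helper `top_span` is a hypothetical name introduced for illustration, and span indices are zero-based.

```python
# SHAP values copied from three rows of the table above (13 input spans each)
rows = {
    "bug": [1.163, -0.133, 0.628, 0.012, 0.141, 0.377, 0.034,
            -0.343, 0.211, 0.712, 1.87, 3.104, -0.279],
    "console": [0.163, 0.063, 0.169, 0.009, 0.189, 1.849, 0.025,
                0.2, 0.324, -0.021, -0.205, 0.081, -0.354],
    "Network": [0.34, -0.007, 0.041, 0.178, -0.47, -0.026, -0.011,
                0.142, 0.101, 0.29, 1.087, 4.923, 1.5],
}

def top_span(values):
    """Return (index, value) of the input span with the largest SHAP value."""
    idx = max(range(len(values)), key=lambda i: values[i])
    return idx, values[idx]

for token, values in rows.items():
    idx, val = top_span(values)
    print(f"{token!r}: most influential input span is index {idx} (SHAP {val})")
```

For "Network" the strongest span is index 11, the span that contains "The PlayStation Network is used by millions of people around the world." For "console" it is index 5, the span that mentions "their console".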


### API

Below we demonstrate generating explanations for a model that is available only as an API/function. Because this case is model agnostic, we use a text similarity model to approximate the log odds of generating the output text, and those approximate log odds are then used to compute the SHAP explanations.

In [8]:
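# Define a function that maps input text to the generated summary text
def f(x):
    # encode one input string into a batch holding a single sequence of token ids
    input_ids = torch.tensor([tokenizer.encode(x)])
    with torch.no_grad():
        out = model.generate(input_ids)
    # decode the first (and only) generated sequence back into text
    sentence = [tokenizer.decode(g, skip_special_tokens=True) for g in out][0]
    return sentence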

For the model agnostic case, we wrap the model to be explained with the shap.models.TeacherForcingLogits class and specify the text similarity model and tokenizer. The TeacherForcingLogits class uses the similarity model to approximate the log odds of generating the output text produced by the model (here, the function f).

We also have to create a Text masker with mask_token="..." and collapse_mask_token=True, which cues the algorithm to use text infilling while masking.
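To illustrate what collapsing the mask token means, here is a simplified sketch that masks whole words. It is an assumption-laden illustration rather than the Text masker's actual implementation, which operates on tokenizer tokens; `mask_text` and the toy sentence are hypothetical.

```python
def mask_text(tokens, keep, mask_token="...", collapse=True):
    """Replace hidden tokens with mask_token.

    With collapse=True, a run of consecutive hidden tokens becomes a single
    mask_token, which reads like a gap to be filled in.
    """
    pieces = []
    prev_masked = False
    for token, kept in zip(tokens, keep):
        if kept:
            pieces.append(token)
            prev_masked = False
        else:
            # emit a mask unless we are collapsing and already inside a masked run
            if not (collapse and prev_masked):
                pieces.append(mask_token)
            prev_masked = True
    return " ".join(pieces)

tokens = ["Sony", "says", "a", "bug", "hit", "the", "PS3"]
keep = [True, False, False, True, False, True, True]

print(mask_text(tokens, keep))                  # Sony ... bug ... the PS3
print(mask_text(tokens, keep, collapse=False))  # Sony ... ... bug ... the PS3
```

With collapsing, the masked input does not reveal how many tokens were hidden, so it resembles a fill-in-the-gap prompt instead of a token-by-token replacement.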

In [9]:
# wrap model with TeacherForcingLogits class
wrapped_model = shap.models.TeacherForcingLogits(f, similarity_model=model, similarity_tokenizer=tokenizer)
# create a Text masker
masker = shap.maskers.Text(tokenizer, mask_token="...", collapse_mask_token=True)

### Create an explainer object using wrapped model and Text masker

In [10]:
explainer_model_agnostic = shap.Explainer(wrapped_model, masker)

### Compute shap values

In [11]:
shap_values_model_agnostic = explainer_model_agnostic(s)

Partition explainer: 2it [01:43, 103.91s/it]              


### Visualize shap explanations

In [12]:
shap.plots.text(shap_values_model_agnostic)

Output token,"The problem is affecting people using the older versions of the PlayStation 3,","called the ""Fat"" model.",The problem isn't affecting the newer PS,3 Slim systems that have been on sale since September last year.,Sony have also said they are aiming to,have the problem fixed shortly but is advising,"some users to avoid using their console for the time being.""We hope to resolve this problem within the next 24 hours,"" a statement reads.","""In the meantime, if you have a model other than the new slim PS3,","we advise that you do not use your PS3 system,","as doing so may result in errors in some functionality, such as recording obtained trophies,","and not being able to restore certain data.""We believe we have identified that this problem is being caused by a bug in the clock functionality incorporated in the system.""The PlayStation Network is used by millions of people around the world.",It allows users to play their friends at games like Fifa over the internet and also do things like download software or visit online stores.
Sony,1.016,0.434,0.488,0.256,1.017,0.144,-0.512,0.965,0.665,0.006,1.541,0.636
has,0.203,0.06,0.343,-0.036,0.306,0.169,0.262,0.297,0.334,0.036,0.672,0.173
said,0.391,-0.088,0.195,0.044,0.19,0.079,0.357,0.377,0.005,-0.062,0.603,-0.121
that,0.525,0.331,0.18,0.171,0.392,0.582,0.009,0.156,0.134,0.098,0.553,0.051
a,0.245,-0.01,0.147,0.016,0.011,0.08,-0.038,0.423,-0.002,-0.097,0.759,0.187
bug,0.893,0.025,0.517,-0.072,0.266,0.779,0.114,0.579,0.163,0.625,4.494,-0.025
in,0.449,0.095,0.034,0.056,0.772,0.438,0.249,0.37,-0.088,0.018,1.119,-0.045
its,-0.525,-0.175,-0.223,-0.133,-0.588,-0.763,-0.176,0.311,-0.072,0.095,0.592,-0.254
PlayStation,1.131,0.243,0.91,1.008,0.262,-0.171,0.426,1.03,0.969,0.544,2.538,-0.048
3,0.564,0.457,0.389,0.821,-0.216,-0.186,-0.149,0.568,0.924,0.568,1.583,-1.148
console,-0.094,-0.06,0.233,-0.181,0.058,0.164,1.421,0.728,0.385,0.201,0.05,-0.284
is,0.135,0.101,0.053,-0.016,0.204,0.264,0.382,-0.264,-0.157,-0.255,0.249,-0.186
causing,0.524,0.083,0.275,0.071,0.366,0.487,0.163,0.116,-0.114,-0.04,0.888,-0.236
some,0.3,0.146,0.117,0.019,0.365,0.39,0.425,-0.03,-0.068,-0.068,1.396,0.037
users,0.414,0.189,0.055,0.179,0.002,0.142,0.649,0.162,-0.388,-0.127,1.274,1.176
to,-0.069,-0.026,0.023,0.045,0.154,0.057,0.221,-0.021,-0.043,-0.069,-0.064,0.18
lose,0.494,-0.041,0.039,-0.06,1.074,1.183,-0.203,0.006,0.012,-0.101,1.334,0.283
access,0.445,-0.029,0.582,0.478,0.016,-0.206,-0.075,0.582,0.268,0.809,3.008,0.677
to,-0.015,-0.016,0.015,0.001,-0.024,-0.065,0.052,0.008,-0.004,-0.0,0.037,-0.026
the,0.219,0.324,-0.088,0.281,0.073,0.038,0.46,0.405,0.065,0.017,-1.029,0.3
PlayStation,0.182,0.182,-0.005,0.281,-0.161,-0.252,-0.38,0.442,0.454,0.428,-0.104,0.282
Network,0.376,0.293,0.319,0.475,0.577,0.537,0.828,0.695,0.283,0.591,1.842,0.422
.,-0.054,0.165,0.191,0.193,0.228,0.122,0.64,0.401,0.131,-0.08,-0.483,0.265
