# Text to Text Explanation: Open Ended Text Generation Using GPT2

This notebook demonstrates how to generate model explanations for open-ended text generation with GPT-2. We use the pretrained `gpt2` model provided by Hugging Face (https://huggingface.co/gpt2), pass it custom text-generation settings, let it continue an initial prompt, and then use SHAP to explain how each input token contributed to each generated token.

```python
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
import shap
import torch
```

### Load model and tokenizer

```python
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# .cuda() moves the model to the GPU, so this cell assumes a CUDA device is available
model = AutoModelForCausalLM.from_pretrained("gpt2").cuda()
```

Below, we set a few model configuration options. SHAP needs to know whether the model is a decoder-only or an encoder-decoder model; this is set through the `is_decoder` or `is_encoder_decoder` attribute of the model's config.
We can also set custom generation parameters, which are used when decoding the output text.

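```python
# mark the model as a decoder-only model and set the text generation parameters
model.config.is_decoder = True
model.config.text_generation_params = {
    "do_sample": True,   # sample the next token instead of greedy decoding
    "max_length": 50,    # total length in tokens, including the prompt
    "temperature": 0.7,  # values < 1 sharpen the next-token distribution
    "top_k": 0,          # 0 disables top-k filtering
}
```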
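To make the effect of `do_sample`, `temperature`, and `top_k` concrete, here is a simplified pure-Python sketch of temperature sampling over a vector of next-token logits. It is not the transformers implementation, and the logits are hypothetical; it only shows what the settings above do conceptually: divide the logits by the temperature, optionally keep the `top_k` largest (with `top_k=0` meaning no filtering), apply a softmax, and draw one token.

```python
import math
import random


def next_token_probs(logits, temperature=0.7, top_k=0):
    """Turn raw next-token logits into sampling probabilities.

    temperature < 1 sharpens the distribution; top_k=0 means no top-k filtering.
    """
    scaled = [x / temperature for x in logits]
    if top_k > 0:
        # keep only the k largest logits (ties at the cutoff are kept too)
        cutoff = sorted(scaled, reverse=True)[top_k - 1]
        scaled = [x if x >= cutoff else float("-inf") for x in scaled]
    # numerically stable softmax; exp(-inf) == 0.0 removes filtered tokens
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_next_token(logits, temperature=0.7, top_k=0, rng=None):
    """Draw one token index from the tempered (and optionally filtered) distribution."""
    rng = rng or random.Random()
    probs = next_token_probs(logits, temperature, top_k)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.0]  # hypothetical logits for a 3-token vocabulary
print(next_token_probs(logits, temperature=0.7))
print(next_token_probs(logits, temperature=1.5))
print(sample_next_token(logits, rng=random.Random(0)))
```

Note that `max_length` counts the prompt as well: in the first example below, the prompt is 7 tokens (the 7 columns of the first SHAP table) and 43 tokens are generated (its 43 rows), for 50 in total.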

### Define initial text

```python
s = ["I enjoy walking with my cute dog"]
```

### Create an explainer object

```python
explainer = shap.Explainer(model, tokenizer)
```

```
explainers.Partition is still in an alpha state, so use with caution...
```
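Passing the tokenizer as the second argument lets SHAP hide subsets of the input tokens and observe how the model output changes. The sketch below is a hypothetical illustration of that masking step, not SHAP's masker code: the `"..."` mask string, the whole-word tokens joined by spaces, and the collapsing of consecutive masked tokens into one mask string are all assumptions made for readability (GPT-2's real tokens carry their own leading spaces).

```python
def mask_tokens(tokens, keep, mask_token="...", collapse=True):
    """Build the text the model sees when only tokens with keep[i] == True are visible.

    mask_token and the collapsing behaviour are illustrative assumptions.
    """
    out = []
    prev_masked = False
    for tok, visible in zip(tokens, keep):
        if visible:
            out.append(tok)
            prev_masked = False
        elif not (collapse and prev_masked):
            # first masked token of a run (or every masked token if collapse=False)
            out.append(mask_token)
            prev_masked = True
        # otherwise: a consecutive masked token is collapsed into the previous mask
    return " ".join(out)


tokens = ["I", "enjoy", "walking", "with", "my", "cute", "dog"]
print(mask_tokens(tokens, [True, False, False, True, True, True, True]))
print(mask_tokens(tokens, [True, False, False, True, True, True, True], collapse=False))
```

Each such masked variant is run through the model, and the differences in output are what the SHAP values summarize.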


### Compute shap values

```python
shap_values = explainer(s)
```

```
Setting `pad_token_id` to 50256 (first `eos_token_id`) to generate sequence
```
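The warning above shows that SHAP picked its Partition explainer, which works over a hierarchy of token groups and so yields Owen values rather than plain Shapley values. For intuition about what a SHAP value is, the sketch below swaps in the simpler exact Shapley computation: each token's value is its weighted average marginal contribution over all subsets of the other tokens. The three-token value function is hypothetical, and the exponential cost is why exact computation is only practical for very short inputs.

```python
from itertools import combinations
from math import factorial


def exact_shapley(n, value):
    """Exact Shapley values for n input tokens.

    value(visible) -> float gives the model output when only the token
    indices in the frozenset `visible` are unmasked. This needs O(2^n)
    calls to `value`, so it is only practical for very short inputs.
    """
    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n! for coalitions of this size
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = frozenset(subset)
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi


# Hypothetical 3-token "model": token 0 adds 1.0 on its own,
# tokens 1 and 2 add 2.0 only when both are visible.
def toy_value(visible):
    return (1.0 if 0 in visible else 0.0) + (2.0 if {1, 2} <= visible else 0.0)


phi = exact_shapley(3, toy_value)
print(phi)
```

The interaction bonus is split evenly between tokens 1 and 2, and the values add up to the output with all tokens visible minus the output with none visible, the same additivity that SHAP explanations rely on.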


### Visualize shap explanations

```python
shap.plots.text(shap_values)
```

| Generated token | I | enjoy | walking | with | my | cute | dog |
|---|---:|---:|---:|---:|---:|---:|---:|
| and | -0.22 | 0.341 | 0.243 | 0.018 | -0.384 | 1.097 | 1.073 |
| watching | 0.079 | 0.723 | 3.184 | 0.372 | 0.503 | 0.334 | -0.201 |
| the | 0.228 | -0.284 | 0.241 | -0.085 | 0.055 | -0.423 | 0.194 |
| Heroes | 0.534 | -0.023 | -0.16 | 0.085 | -0.197 | -0.358 | -0.631 |
| of | -0.023 | 0.385 | -0.028 | 0.117 | 0.089 | -0.071 | 0.1 |
| the | 0.187 | 0.108 | 0.124 | 0.05 | -0.01 | -0.02 | 0.064 |
| Storm | 0.004 | -0.017 | -0.092 | 0.078 | -0.099 | 0.444 | -0.049 |
| stream | 0.188 | -0.187 | 0.348 | -0.055 | 0.12 | 0.139 | 0.748 |
| after | -0.242 | -0.312 | -0.111 | 0.128 | 0.106 | 0.155 | -0.027 |
| work | 0.967 | 0.463 | 0.41 | 0.068 | 0.246 | 0.198 | 0.619 |
| . | -0.115 | 0.182 | 0.147 | 0.235 | -0.033 | 0.113 | 0.135 |
| `\n` | -0.539 | 0.062 | -0.079 | 0.12 | -0.363 | 0.13 | 0.044 |
| `\n` | -0.02 | 0.241 | -0.012 | 0.147 | -0.579 | -0.458 | 0.253 |
| We | 0.11 | -0.42 | 0.025 | 0.225 | -0.559 | -0.076 | -0.018 |
| share | 0.354 | 0.573 | 0.252 | 0.209 | 0.686 | 0.251 | 0.201 |
| a | 0.335 | -0.45 | 0.275 | 0.176 | 0.383 | 0.019 | 0.096 |
| lot | -0.077 | 0.357 | -0.369 | -0.037 | -0.648 | -0.218 | -0.427 |
| of | 0.116 | 0.124 | 0.012 | -0.028 | -0.151 | -0.028 | 0.035 |
| the | 0.006 | -0.246 | 0.002 | -0.064 | -0.176 | -0.177 | -0.122 |
| same | 0.587 | -0.456 | 0.387 | 0.174 | 0.731 | 0.097 | 0.101 |
| interests | 0.414 | 0.168 | 0.232 | 0.113 | 0.542 | 0.396 | 0.309 |
| . | 0.237 | -0.119 | 0.09 | -0.085 | 0.234 | -0.008 | -0.005 |
| We | -0.154 | 0.088 | -0.033 | 0.001 | -0.314 | -0.058 | -0.06 |
| spend | -0.025 | 0.119 | 0.171 | 0.026 | 0.184 | 0.278 | 0.088 |
| a | 0.026 | -0.056 | -0.001 | -0.017 | -0.099 | -0.102 | -0.078 |
| lot | 0.166 | -0.055 | -0.017 | 0.059 | 0.024 | -0.029 | -0.08 |
| of | 0.03 | -0.064 | 0.146 | -0.002 | 0.007 | 0.031 | 0.121 |
| time | 0.158 | 0.068 | 0.103 | 0.071 | 0.077 | 0.028 | -0.039 |
| together | 0.136 | 0.105 | 0.179 | 0.101 | 0.492 | 0.331 | 0.175 |
| . | 0.283 | -0.108 | 0.104 | -0.104 | 0.158 | -0.051 | -0.033 |
| We | -0.216 | -0.014 | -0.06 | 0.016 | -0.141 | -0.031 | -0.009 |
| like | 0.078 | 0.199 | -0.01 | -0.022 | 0.041 | 0.14 | 0.081 |
| to | -0.005 | 0.041 | -0.008 | 0.024 | -0.002 | -0.028 | 0.057 |
| be | -0.034 | -0.053 | 0.028 | 0.014 | -0.067 | -0.032 | 0.022 |
| together | -0.075 | 0.04 | -0.052 | 0.009 | 0.075 | 0.096 | -0.012 |
| . | 0.007 | -0.093 | 0.042 | -0.07 | 0.025 | -0.078 | -0.062 |
| And | 0.016 | -0.145 | -0.001 | 0.015 | -0.093 | -0.064 | -0.22 |
| we | -0.19 | 0.098 | -0.034 | -0.005 | -0.193 | -0.017 | -0.05 |
| enjoy | 0.028 | 0.345 | -0.047 | -0.001 | 0.018 | 0.114 | -0.002 |
| each | -0.202 | -0.583 | 0.054 | -0.034 | 0.046 | 0.088 | 0.029 |
| other | 0.03 | 0.005 | 0.074 | 0.023 | -0.057 | -0.035 | 0.029 |
| 's | 0.044 | -0.38 | 0.166 | 0.044 | 0.076 | -0.027 | 0.038 |
| company | 0.17 | -0.727 | 0.374 | 0.053 | 0.086 | -0.067 | 0.265 |
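In this table each row is a generated token and each column an input token; a positive value means that input token pushed the model's output for that generated token up, a negative value pushed it down. A quick way to read a row is to rank its input tokens by SHAP value. The sketch below does this for three rows copied from the table; `top_contributors` is a hypothetical helper, not part of the SHAP API.

```python
# Input tokens (columns) of the first SHAP table
inputs = ["I", "enjoy", "walking", "with", "my", "cute", "dog"]

# A few rows copied from that table: generated token -> SHAP value per input token
rows = {
    "and": [-0.22, 0.341, 0.243, 0.018, -0.384, 1.097, 1.073],
    "watching": [0.079, 0.723, 3.184, 0.372, 0.503, 0.334, -0.201],
    "work": [0.967, 0.463, 0.41, 0.068, 0.246, 0.198, 0.619],
}


def top_contributors(values, names, k=2):
    """Return the k (input token, SHAP value) pairs with the largest values."""
    ranked = sorted(zip(names, values), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


for generated, values in rows.items():
    print(generated, top_contributors(values, inputs))
```

For example, "walking" has by far the largest value (3.184) for the generated token "watching", while "cute" and "dog" drive the first generated token "and".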


### Another example

```python
s = ["Scientists confirmed the worst possible outcome: the massive asteroid will collide with Earth"]
```

```python
explainer = shap.Explainer(model, tokenizer)
```

```
explainers.Partition is still in an alpha state, so use with caution...
```


```python
shap_values = explainer(s)
```

```
Setting `pad_token_id` to 50256 (first `eos_token_id`) to generate sequence
```


```python
shap.plots.text(shap_values)
```

| Generated token | Scientists confirmed the worst | possible outcome | : the | massive | asteroid | will | collide | with | Earth |
|---|---:|---:|---:|---:|---:|---:|---:|---:|---:|
| in | 0.478 | -0.183 | -0.648 | 0.063 | 0.251 | 0.135 | 0.979 | 0.489 | 1.436 |
| a | -0.178 | 0.295 | -0.174 | -0.132 | -0.121 | 0.106 | 0.885 | -0.168 | 0.31 |
| massive | 0.767 | 0.119 | -0.191 | 0.655 | 0.688 | 0.222 | 2.157 | 0.214 | 0.009 |
| collision | 0.896 | 0.192 | 0.078 | -0.053 | 1.062 | 0.443 | 3.414 | 0.001 | 0.413 |
| that | 0.693 | -0.016 | -0.039 | -0.047 | -0.09 | -0.285 | 1.08 | 0.409 | -0.912 |
| will | -1.035 | 0.446 | -0.263 | 0.197 | 0.338 | 1.81 | 0.622 | 0.068 | 0.662 |
| cause | -0.247 | 0.087 | 0.106 | -0.055 | 0.044 | 0.305 | 0.531 | 0.147 | 0.179 |
| widespread | 1.294 | 0.507 | -0.143 | -0.061 | 0.12 | -0.146 | 0.009 | -0.079 | -0.016 |
| devastation | 0.151 | -0.087 | -0.219 | 0.365 | 0.59 | -0.013 | 0.038 | -0.048 | 0.251 |
| . | -0.725 | 0.237 | 0.382 | -0.097 | -0.183 | 0.052 | 0.086 | 0.264 | 0.312 |
| `\n` | 0.73 | 0.353 | -0.034 | 0.085 | 0.045 | -0.1 | 0.069 | 0.133 | -0.164 |
| `\n` | 1.374 | 0.596 | -0.577 | -0.125 | -0.023 | -0.378 | 0.183 | 0.095 | 0.015 |
| The | 0.576 | 0.11 | -0.055 | -0.007 | -0.0 | -0.091 | 0.09 | 0.04 | -0.04 |
| news | 0.157 | 0.117 | 0.186 | 0.192 | 0.415 | 0.089 | -0.019 | 0.085 | 0.121 |
| comes | 0.228 | -0.078 | -0.256 | -0.059 | -0.003 | 0.023 | 0.167 | 0.098 | -0.045 |
| after | -0.028 | -0.023 | 0.037 | -0.038 | -0.057 | -0.001 | -0.011 | 0.055 | -0.068 |
| a | 0.164 | -0.003 | 0.005 | 0.016 | 0.029 | -0.0 | 0.069 | -0.053 | 0.036 |
| team | 0.732 | 0.213 | 0.091 | 0.155 | 0.261 | -0.08 | 0.08 | -0.029 | 0.177 |
| of | -0.24 | -0.075 | 0.035 | 0.029 | 0.017 | -0.001 | 0.009 | 0.008 | -0.138 |
| international | -0.001 | 0.172 | 0.018 | 0.104 | 0.157 | 0.025 | -0.064 | -0.058 | -0.029 |
| scientists | 0.791 | 0.058 | 0.065 | 0.005 | 0.126 | 0.036 | -0.004 | -0.052 | 0.777 |
| published | -0.089 | 0.2 | 0.19 | 0.048 | 0.106 | 0.047 | 0.043 | -0.033 | 0.383 |
| a | -0.296 | 0.017 | 0.047 | 0.022 | 0.028 | 0.03 | 0.02 | -0.0 | -0.007 |
| preliminary | 0.53 | 0.12 | -0.047 | 0.028 | 0.083 | -0.002 | 0.042 | -0.049 | -0.063 |
| report | 0.103 | 0.031 | -0.009 | -0.053 | -0.049 | 0.018 | -0.022 | -0.008 | -0.064 |
| on | 0.059 | -0.046 | 0.014 | -0.009 | -0.01 | 0.004 | 0.0 | 0.017 | -0.013 |
| Wednesday | 0.345 | 0.092 | -0.004 | -0.139 | -0.176 | -0.028 | -0.072 | -0.01 | -0.246 |
| that | 0.013 | -0.004 | 0.042 | 0.022 | 0.037 | -0.019 | 0.009 | -0.017 | 0.023 |
| described | 0.207 | 0.08 | -0.022 | 0.08 | 0.128 | 0.006 | 0.098 | -0.02 | -0.05 |
| a | -0.044 | -0.049 | -0.074 | 0.009 | 0.018 | 0.038 | 0.056 | -0.059 | -0.044 |
| possible | -0.535 | 0.515 | 0.175 | 0.08 | 0.279 | 0.29 | -0.011 | 0.095 | 0.268 |
| ' | -0.147 | -0.177 | -0.026 | -0.002 | -0.03 | -0.089 | -0.125 | 0.005 | 0.074 |
| super | -0.06 | -0.182 | -0.067 | 0.082 | 0.102 | 0.009 | 0.132 | -0.021 | 0.184 |
| asteroid | -0.208 | 0.071 | -0.012 | 1.207 | 2.025 | 0.196 | 0.023 | 0.099 | 0.241 |
| impact | 0.19 | 0.241 | -0.032 | -0.187 | -0.3 | 0.003 | 0.125 | 0.041 | -0.091 |
| ' | 0.057 | -0.059 | -0.011 | -0.071 | -0.06 | -0.075 | -0.076 | -0.032 | -0.016 |


## Custom text generation

Below we demonstrate how to explain the likelihood of the model generating a particular output sentence given an input sentence. To do this, we define an input-output sentence pair and a custom generation function that, instead of letting the model generate freely, returns the token ids of the chosen target sentence for a given input.
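Scoring a fixed target sentence relies on teacher forcing: at each step the model is fed the reference prefix rather than its own samples, so the log-likelihood factorizes as log p(y | x) = Σ_t log p(y_t | x, y_<t). The pure-Python sketch below shows that loop with a hypothetical bigram "model" standing in for GPT-2. It uses log-probabilities; the exact output scale of the SHAP wrapper used below is not assumed here.

```python
import math


def teacher_forced_log_probs(next_token_probs, source_ids, target_ids):
    """Score a fixed target sequence under teacher forcing.

    next_token_probs(context_ids) -> {token_id: probability} is a hypothetical
    stand-in for a causal LM's next-token distribution. Returns the total
    log-likelihood and the per-token log-probabilities.
    """
    context = list(source_ids)
    per_token = []
    for tok in target_ids:
        per_token.append(math.log(next_token_probs(context)[tok]))
        # teacher forcing: extend the context with the reference token,
        # not with whatever the model would have generated
        context.append(tok)
    return sum(per_token), per_token


# Hypothetical bigram "model": the next-token distribution depends only on the last token
bigram = {0: {1: 0.5, 2: 0.5}, 1: {2: 0.25, 3: 0.75}}


def toy_model(context):
    return bigram[context[-1]]


total, per_token = teacher_forced_log_probs(toy_model, [0], [1, 2])
print(total, per_token)
```

The per-token view matches the shape of the explanation at the end of this notebook, which has one row per target token.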

```python
sentence_pairs = {
    "I know many people who are Russian.": "They love their vodka!"
}
```

```python
def generate_target_sentence_ids_for_output(x):
    # look up the target sentence for input x in the sentence pairs defined above
    target_sentence = sentence_pairs[x]
    # encode it and add a batch dimension: shape (1, num_target_tokens)
    target_sentence_ids = torch.tensor([tokenizer.encode(target_sentence)])
    return target_sentence_ids
```

```python
wrapped_model = shap.models.PTTeacherForcingLogits(
    model,
    tokenizer,
    generation_function_for_target_sentence_ids=generate_target_sentence_ids_for_output,
)
```

```python
explainer = shap.Explainer(wrapped_model, tokenizer)
```

```
explainers.Partition is still in an alpha state, so use with caution...
```


```python
s = list(sentence_pairs.keys())
```

```python
shap_values = explainer(s)
```

```python
shap.plots.text(shap_values)
```

| Generated token | I | know | many | people | who | are | Russian | . |
|---|---:|---:|---:|---:|---:|---:|---:|---:|
| They | -1.954 | -1.215 | -1.509 | -1.876 | -0.569 | -1.412 | 1.81 | 4.452 |
| love | 0.64 | 0.145 | 0.173 | 0.079 | -0.569 | 0.202 | 0.252 | 0.104 |
| their | 0.014 | -0.11 | 0.226 | 0.187 | 0.05 | 0.197 | 0.067 | -0.307 |
| vodka | -0.096 | -0.118 | -0.153 | -0.034 | 0.028 | 0.108 | 2.477 | -0.389 |
| ! | -0.099 | -0.126 | -0.164 | -0.467 | -0.298 | 0.06 | 0.032 | -0.145 |
