# Text to Multiclass Explanation: Language Modelling Example

This notebook demonstrates how to get explanations for the top-k next words generated by a language model. In this demo, we use the pretrained gpt2 model provided by Hugging Face (https://huggingface.co/gpt2) to predict the top-k next words. We treat the top-k next words as k separate classes and learn an explanation for each of these k words. This lets us explain how much each word in the input contributes to the likelihood of each of the top-k next words being predicted.
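To make the "k separate classes" idea concrete, here is a minimal sketch of turning next-token scores into top-k log odds. The vocabulary and logits are made up for illustration, and `top_k_log_odds` is a hypothetical helper, not part of shap or transformers. The wrapper used later in this notebook may compute its outputs differently.

```python
import math

def top_k_log_odds(logits, vocab, k):
    """Return (word, log odds) pairs for the k most probable next words."""
    # Softmax, subtracting the maximum logit for numerical stability.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Indices of the k highest-probability words, most probable first.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    # Log odds of a probability p is log(p / (1 - p)).
    return [(vocab[i], math.log(probs[i] / (1.0 - probs[i]))) for i in ranked]

# Hypothetical five-word vocabulary and made-up next-token logits.
vocab = ["cave", "forest", "zoo", "the", "banana"]
logits = [2.0, 1.5, 0.5, 0.0, -1.0]

top = top_k_log_odds(logits, vocab, k=3)
for word, log_odds in top:
    print(f"{word}: {log_odds:.3f}")
```

Each of the k returned words then plays the role of one output class, and an explanation is computed for each class separately.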

In [1]:
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
import shap
import torch

### Load model and tokenizer

In [2]:
tokenizer = AutoTokenizer.from_pretrained('gpt2', use_fast=True)
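# Use the GPU when one is available; otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained('gpt2').to(device)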

We next wrap the model with the GenerateTopKLM model, which extracts the log odds of the top-k next words. We also create a Text masker, initialized with mask_token = "..." and collapse_mask_token = True, which is used to infill text when the inputs are perturbed.

In [3]:
wrapped_model = shap.models.GenerateTopKLM(model, tokenizer, k=100)
masker = shap.maskers.Text(tokenizer, mask_token="...", collapse_mask_token=True)
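The effect of the two masker settings can be illustrated with a word-level toy. This is only a sketch of the idea: `mask_text` is a hypothetical helper, and the real Text masker operates on tokenizer output, so its exact spacing and segmentation may differ.

```python
def mask_text(tokens, keep, mask_token="...", collapse_mask_token=True):
    """Replace tokens whose keep flag is False with mask_token.

    When collapse_mask_token is True, a run of adjacent masked tokens
    is replaced by a single mask_token instead of one per token.
    """
    out = []
    prev_masked = False
    for token, kept in zip(tokens, keep):
        if kept:
            out.append(token)
        elif not (collapse_mask_token and prev_masked):
            # First masked token of a run, or collapsing is disabled.
            out.append(mask_token)
        prev_masked = not kept
    return " ".join(out)

tokens = ["scientists", "discovered", "a", "herd", "of", "unicorns"]
keep = [True, False, False, False, True, True]

collapsed = mask_text(tokens, keep)
uncollapsed = mask_text(tokens, keep, collapse_mask_token=False)
print(collapsed)    # scientists ... of unicorns
print(uncollapsed)  # scientists ... ... ... of unicorns
```

Collapsing keeps the perturbed text closer to something a language model might plausibly see, since a single "..." reads as an elision rather than a string of repeated symbols.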

### Define data

Here we set the initial text for which we want the gpt2 model to predict the next word.

In [4]:
s = ["In a shocking finding, scientists discovered a herd of unicorns living in a"]

### Create explainer object

In [5]:
explainer = shap.Explainer(wrapped_model, masker)

explainers.Partition is still in an alpha state, so use with caution...


### Compute SHAP values

In [6]:
shap_values = explainer(s)
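As background on what these values represent, the following sketch computes exact Shapley values for a three-word toy input by averaging each word's marginal contribution over all orderings. The score function `toy_score` and its numbers are invented for illustration. The explainer used in this notebook (Partition, per the message above) does not enumerate all orderings like this; the brute-force version is only feasible for very small inputs.

```python
from itertools import permutations

def shapley_values(n_features, value_fn):
    """Exact Shapley values: average marginal contribution over all orderings."""
    totals = [0.0] * n_features
    orderings = list(permutations(range(n_features)))
    for order in orderings:
        present = set()
        prev = value_fn(frozenset(present))
        for i in order:
            # Unmask feature i and record how much the score changes.
            present.add(i)
            cur = value_fn(frozenset(present))
            totals[i] += cur - prev
            prev = cur
    return [t / len(orderings) for t in totals]

words = ["herd", "of", "unicorns"]

def toy_score(present):
    """Hypothetical model score given the set of unmasked word indices."""
    score = 0.0
    if 0 in present:
        score += 1.0
    if 2 in present:
        score += 2.0
    if 0 in present and 2 in present:
        score += 0.5  # interaction between "herd" and "unicorns"
    return score

phi = shapley_values(len(words), toy_score)
for word, value in zip(words, phi):
    print(f"{word}: {value:.2f}")
```

The attributions sum to the difference between the score with every word present and the score with every word masked, and the interaction term is split evenly between the two words that produce it.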

### Visualize the SHAP values across the input sentence for the top-k next words

We can now see the top-k next words predicted by gpt2 under "Output Text" in the plot below. Hover over each output token to see which words in the input sentence drive the prediction of that output word.

In [7]:
shap.plots.text(shap_values)

Output token,In a,shocking,finding,",",scientists discovered,a,herd,of unicorns living,in,a
cave,0.263,-0.035,-0.086,-0.628,0.538,0.227,0.022,1.809,2.772,4.778
forest,0.407,-0.262,-0.234,-0.124,-0.13,0.363,0.456,1.309,2.311,4.953
small,0.143,-0.004,0.134,-0.387,-0.45,0.301,0.126,0.216,0.927,4.738
desert,0.002,-0.053,0.121,-0.449,0.429,0.373,0.957,1.048,2.298,3.785
tiny,0.088,0.16,0.178,-0.338,0.302,0.134,-0.015,0.784,0.558,4.962
"""",-0.432,0.148,0.126,0.286,0.421,-0.071,-0.26,0.309,0.19,0.111
remote,0.163,-0.222,-0.158,-0.079,0.597,0.564,0.137,0.305,1.968,4.032
zoo,0.565,0.47,0.369,-0.311,0.245,0.317,0.612,1.552,1.037,3.856
tree,0.203,-0.423,-0.372,-0.238,0.2,0.224,-0.065,2.012,1.078,3.81
field,0.641,-0.5,-0.405,-0.066,-0.355,0.455,0.372,-0.373,1.954,3.841
house,0.01,-0.102,-0.124,-0.605,-0.547,-0.135,-0.226,1.233,1.783,3.981
nest,-0.132,-0.03,0.195,-0.365,0.074,0.454,0.873,2.602,0.494,3.707
tropical,0.262,0.125,0.075,-0.587,0.685,0.3,-0.075,1.545,1.638,3.863
lake,-0.184,0.082,0.297,-0.724,0.609,0.302,0.437,0.359,2.336,4.003
large,0.257,0.034,0.268,-0.226,-0.225,0.29,0.021,0.04,0.51,3.648
mountain,0.038,-0.094,-0.026,-0.658,-0.069,0.447,0.719,1.33,1.213,3.781
farm,0.225,0.164,0.055,-0.516,-0.316,0.446,1.057,1.038,1.231,3.007
group,0.459,-0.256,-0.189,-0.347,-0.23,0.256,0.328,-0.557,-0.239,5.012
wild,0.128,0.06,0.027,-0.485,-0.111,0.347,1.102,1.442,0.107,3.105
very,-0.195,0.308,0.257,-0.434,0.138,0.002,-0.367,-0.339,0.099,3.123
single,0.313,0.221,-0.017,-0.284,-0.459,-0.169,-0.307,-0.032,0.227,4.746
barn,-0.113,0.043,-0.072,-0.644,-0.349,0.288,0.923,1.287,2.082,4.011
jungle,-0.302,-0.17,-0.105,-0.584,0.197,0.241,0.621,2.052,1.525,3.782
new,0.1,0.098,0.044,-0.431,0.649,0.02,-0.282,-1.297,-0.03,3.393
valley,-0.123,-0.325,-0.134,-0.825,-0.59,0.084,0.611,0.476,2.347,6.045
world,-0.04,0.008,-0.28,-0.616,-0.09,-0.458,-0.537,0.356,1.041,3.817
garden,0.132,-0.078,0.019,-0.696,-0.165,-0.072,-0.007,1.789,1.781,3.794
herd,0.245,0.013,0.187,-0.319,0.111,0.544,2.731,0.817,-0.509,4.462
grass,0.284,0.059,0.169,-0.455,-0.032,0.318,0.795,0.989,1.51,3.291
natural,0.593,-0.102,-0.104,-0.329,0.816,0.044,0.047,0.265,0.868,2.614
park,0.252,0.259,0.173,-0.607,-0.22,0.126,0.256,0.4,1.97,3.267
swamp,-0.311,-0.06,0.016,-0.537,0.098,0.13,0.393,1.28,2.361,4.392
laboratory,0.765,0.139,0.386,-0.926,0.887,0.055,-0.629,-0.43,1.954,4.939
nearby,-0.423,-0.353,-0.006,-0.289,-0.033,0.276,0.416,1.21,1.586,2.94
well,0.146,0.508,0.372,0.397,1.076,0.21,-0.11,0.485,0.44,-1.26
rural,0.309,0.223,0.155,-0.712,-0.541,0.387,0.369,-0.022,2.303,3.756
pond,-0.607,-0.096,-0.021,-0.124,0.367,0.452,0.801,1.19,2.428,2.817
dark,-0.239,0.303,0.311,-0.773,0.1,-0.201,-0.374,0.339,1.448,4.151
wood,0.058,-0.189,-0.1,-0.104,-0.068,0.008,-0.025,0.821,1.958,2.791
subter,0.072,0.1,-0.029,-0.603,0.77,0.34,-0.198,2.017,1.517,4.326
room,-0.268,-0.227,-0.131,-0.864,-0.297,-0.383,-0.483,0.795,1.977,3.892
lab,0.605,0.007,0.338,-0.753,0.418,0.115,-0.122,-0.125,1.57,3.953
cage,0.072,0.055,-0.11,-0.713,-0.315,0.127,0.069,1.47,2.121,4.416
huge,-0.349,0.24,0.307,-0.482,0.159,0.145,0.024,-0.103,0.238,4.319
New,0.393,0.22,0.104,-0.387,0.212,0.086,-0.433,-0.222,0.619,1.533
water,0.136,0.149,0.329,-0.411,0.711,-0.243,-0.001,-0.489,1.633,1.895
colony,0.422,-0.352,-0.107,-0.232,0.34,0.255,0.328,1.193,0.332,4.908
massive,-0.194,0.291,0.25,-0.628,0.356,0.208,0.052,-0.387,0.307,4.737
common,0.353,-0.094,0.044,-0.17,0.059,-0.079,-0.321,0.503,0.29,3.454
state,1.193,-0.046,-0.281,-0.172,-0.548,0.047,-0.548,-1.41,1.669,2.883
deep,-0.183,0.167,0.301,-0.65,0.942,0.053,-0.284,0.802,0.747,2.158
home,0.147,0.072,-0.057,-0.523,-0.455,0.025,-0.278,0.757,1.071,2.175
man,-0.016,0.084,0.15,-0.547,-0.07,0.181,-0.133,0.088,0.025,3.175
mine,-0.731,0.068,0.392,-0.281,0.078,0.0,0.223,0.785,1.218,2.802
human,0.666,0.248,0.391,-0.601,0.393,-0.28,-0.29,0.552,-0.021,2.461
rock,-0.264,-0.176,-0.077,-0.578,0.712,0.14,0.12,0.819,1.106,2.647
region,0.158,-0.562,-0.496,-0.252,0.071,0.121,0.098,-0.688,1.912,4.417
box,-0.041,-0.335,-0.334,-0.195,-0.468,-0.215,-0.479,0.196,1.822,4.1
river,-0.202,-0.194,-0.024,-0.244,-0.088,0.185,0.31,0.437,1.548,3.849
part,0.287,0.151,0.016,-0.276,0.311,0.239,0.012,-0.049,0.082,1.369
hollow,-0.129,-0.096,0.143,-0.081,0.313,0.121,-0.055,1.193,1.983,3.356
c,0.137,-0.437,-0.514,0.5,-0.45,0.188,0.383,0.972,0.345,1.417
hole,-0.461,-0.356,-0.232,-0.667,0.435,0.188,-0.314,0.747,1.873,4.319
vast,0.221,0.088,0.276,-0.285,0.496,0.061,-0.171,-0.033,0.562,4.333
village,0.266,-0.161,-0.196,-0.611,-1.021,0.485,0.468,0.05,1.632,4.827
different,0.322,-0.133,-0.077,-0.289,0.021,0.086,-0.236,-0.149,1.074,1.965
virtual,0.696,-0.251,-0.262,-0.079,0.453,-0.167,-0.492,0.898,0.916,3.037
city,0.621,0.165,0.01,-0.505,-0.501,-0.07,-0.596,-0.109,1.435,3.015
strange,-0.528,0.162,0.187,-0.402,0.752,0.14,-0.065,0.676,0.739,3.158
greenhouse,0.81,0.045,0.2,-0.545,0.772,-0.258,-0.337,0.474,2.316,3.747
frozen,-0.034,0.103,0.383,-0.34,0.88,0.177,0.199,0.982,1.174,2.703
shallow,-0.125,-0.265,-0.124,-0.443,0.28,0.327,-0.1,0.174,2.292,4.472
semi,0.423,0.1,-0.086,-0.39,0.019,0.206,0.063,0.521,0.942,3.75
flat,-0.104,-0.125,-0.082,-0.436,-0.301,0.028,-0.31,1.321,0.976,3.158
patch,-0.303,-0.471,-0.708,0.059,0.794,0.289,0.023,0.202,1.494,3.722
mysterious,-0.5,0.461,0.517,-0.436,0.929,0.236,-0.251,0.425,0.308,4.01
local,0.498,-0.011,0.051,-0.35,-0.402,0.107,-0.146,-0.283,0.856,2.727
giant,0.007,0.268,0.274,-0.567,0.464,0.06,0.161,0.252,0.155,3.5
sub,0.37,-0.374,-0.414,0.092,-0.091,0.006,-0.004,0.098,0.669,2.95
barren,0.069,0.102,0.299,-0.136,0.272,0.368,0.586,0.997,1.755,2.808
special,0.193,0.094,0.018,-0.557,0.062,0.087,-0.324,-0.013,0.626,2.897
mountainous,-0.193,-0.059,-0.113,-0.414,0.029,0.627,0.309,0.612,2.402,4.261
mud,-0.467,-0.034,-0.003,-0.447,0.425,0.226,0.578,0.854,2.154,3.079
cemetery,0.166,0.325,0.3,-0.797,-0.38,0.218,-0.172,0.793,2.028,4.853
pod,0.284,-0.257,-0.377,0.112,0.078,0.228,0.306,1.43,0.559,4.059
hive,-0.152,-0.396,-0.304,-0.187,0.199,-0.054,0.209,1.783,0.741,5.058
newly,0.333,0.424,0.482,-0.248,0.4,0.256,-0.025,-0.456,0.482,3.192
closed,-0.057,-0.206,-0.269,-0.004,-0.051,0.178,0.103,0.056,2.153,2.264
community,0.346,-0.203,-0.135,-0.278,-0.684,-0.11,-0.099,-0.065,0.98,3.359
California,0.739,0.556,0.522,-0.74,0.107,0.172,-0.173,-0.463,1.443,1.589
place,0.091,-0.388,-0.419,-0.126,-0.169,-0.096,-0.315,-0.394,1.029,3.661
flooded,-0.535,0.015,0.195,-0.326,0.163,0.554,0.754,0.586,2.224,3.415
prehistoric,0.124,0.149,0.223,-0.614,1.263,0.448,0.37,1.217,0.485,3.582
sw,-0.09,-0.075,-0.149,-0.105,-0.573,0.39,0.93,0.27,0.688,2.742
high,0.277,0.316,0.303,-0.599,-0.331,-0.126,-0.486,-0.232,0.522,2.757
z,-0.084,-0.296,-0.378,0.357,-0.008,0.014,0.55,0.579,0.302,2.076
hot,-0.02,0.238,0.213,-0.566,0.585,-0.053,-0.227,-0.044,1.021,2.562
far,0.138,0.375,0.399,-0.574,0.221,0.141,0.172,0.155,-0.168,1.802
1,0.075,-0.044,-0.11,0.365,0.273,-0.277,-0.236,-0.28,0.23,-0.147
pasture,0.142,-0.192,-0.232,-0.145,-0.163,0.69,1.804,0.995,2.136,2.995
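The table above can be read programmatically. The sketch below assumes each row holds the SHAP values for one output word and each column corresponds to one span of the input sentence. It copies three rows verbatim and reports, for each output word, the input span with the largest positive value, the span with the most negative value, and the row total. The helper `summarize` is hypothetical.

```python
# Input spans, in column order. The two "a" columns are distinct positions,
# so spans are referenced by index rather than by text.
spans = ["In a", "shocking", "finding", ",", "scientists discovered",
         "a", "herd", "of unicorns living", "in", "a"]

# Three rows copied from the table above.
rows = {
    "cave": [0.263, -0.035, -0.086, -0.628, 0.538, 0.227, 0.022, 1.809, 2.772, 4.778],
    "herd": [0.245, 0.013, 0.187, -0.319, 0.111, 0.544, 2.731, 0.817, -0.509, 4.462],
    "well": [0.146, 0.508, 0.372, 0.397, 1.076, 0.21, -0.11, 0.485, 0.44, -1.26],
}

def summarize(values):
    """Return (index of largest value, index of smallest value, sum of values)."""
    top = max(range(len(values)), key=lambda i: values[i])
    bottom = min(range(len(values)), key=lambda i: values[i])
    return top, bottom, sum(values)

summary = {word: summarize(values) for word, values in rows.items()}
for word, (top, bottom, total) in summary.items():
    print(f"{word}: strongest push from span {top} ({spans[top]!r}), "
          f"strongest pull from span {bottom} ({spans[bottom]!r}), "
          f"total {total:.3f}")
```

For most output words the final "a" carries the largest value, but "well" is an exception: its largest value comes from "scientists discovered", while the final "a" pushes against it. Because SHAP values are additive, each row total is expected to approximate the gap between the model output for the full sentence and the output for a fully masked input.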
