In [1]:
import os
import copy
import shutil
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from transformers import AutoTokenizer, AutoModelWithLMHead
import shap
from shap.utils import GenerateLogits
import scipy as sp
import nlp
import torch

In [2]:
tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-xsum-12-6")
model =  AutoModelWithLMHead.from_pretrained("sshleifer/distilbart-xsum-12-6").cuda()

The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.


In [3]:
dataset = nlp.load_dataset('xsum',split='train')



### Explanation using model decoder to generate logits scores

In [3]:
# Instantiate object of GenerateLogits with model and tokenizer
logit_generator_model = GenerateLogits(model=model,tokenizer=tokenizer)

In [4]:
def f_predict(x):
    model.eval()
    input_ids = torch.tensor([tokenizer.encode(x)]).cuda()
    with torch.no_grad():
        out=model.generate(input_ids)
    sentence = [tokenizer.decode(g, skip_special_tokens=True) for g in out][0]
    del input_ids, out
    return sentence

In [22]:
sentence=["Canada's Minister of Defense resigned today, a day after an army official testified that top military officials had altered documents to cover up responsibility for the beating death of a Somali teen-ager at the hands of Canadian peacekeeping troops in 1992. Defense minister David Collenette insisted that his resignation had nothing to do with the Somalia's scandal. Ted Williams was the first name to come to mind, and he's the greatest living hitter. ..."]
f_predict(sentence[0])

"Think of a baseball player and you're likely to think of Ted Williams."

In [5]:
def gen_kwargs(x):
    model.eval()
    input_ids = torch.tensor([tokenizer.encode(x)]).cuda()
    with torch.no_grad():
        # generate input ids for output translation which we aim to explain
        out=model.generate(input_ids)
    # slice sentence by removing bos and eos token ids
    target_sentence_ids = out[:,1:-1].cuda()
    output_names = logit_generator_model.get_output_names(target_sentence_ids)
    print(output_names)
    del out, input_ids
    return {'target_sentence_ids':target_sentence_ids,'output_names':output_names,'fixed_context':None}

In [6]:
# This function defines the additional arguments passed onto the model function required inorder to get conditional logits corresponding to the original input sentence translation
def f_kwargs(x):
    kwargs = gen_kwargs(x)
    return kwargs

In [7]:
def f(x_batch,**kwargs):
    output_batch=[]
    # Extract target sentence for which we want to generate conditional logits
    target_sentence_ids=kwargs['target_sentence_ids']
    for i,source_sentence in enumerate(x_batch):
        conditional_logits = logit_generator_model.generate_logits(source_sentence,target_sentence_ids)
        output_batch.append(conditional_logits)
    return np.array(output_batch)

In [8]:
explainer = shap.Explainer(f,tokenizer)
explainer.masker.mask_token_id=None
explainer.masker.mask_token="<infill>"

explainers.Partition is still in an alpha state, so use with caution...


In [9]:
sentence=["Canada's Minister of Defense resigned today, a day after an army official testified that top military officials had altered documents to cover up responsibility for the beating death of a Somali teen-ager at the hands of Canadian peacekeeping troops in 1992. Defense minister David Collenette insisted that his resignation had nothing to do with the Somalia's scandal. Ted Williams was the first name to come to mind, and he's the greatest living hitter."]

In [10]:
shap_values = explainer(sentence,model_kwargs=f_kwargs)

['Think', 'Ġof', 'Ġthe', 'Ġmost', 'Ġfamous', 'Ġbaseball', 'Ġplayer', 'Ġin', 'Ġthe', 'Ġworld', '.']
[-8.70479713 -0.79089107 -1.59380761 -3.78756783 -2.94892869 -7.26049415
 -2.29238384 -3.67837013 -1.61220231  0.1103331  -4.06675762] (11,)
[-3.20295928  1.26738027 -1.82697374 -1.24433248  1.25981615 -0.04890383
  0.88660428 -1.7513906   0.88193335  1.26045096 -2.58991679] (11,)
[ 0  1  2  3  4  5  6  7  8  9 10] (11,)


In [11]:
shap.plots.text(shap_values[0])

invalid value encountered in double_scalars


Unnamed: 0_level_0,Unnamed: 1_level_0,Canada,'s,Minister,of,Defense,resigned,today,",",a,day,after,an,army,official,testified,that,top,military,officials,had,altered,documents,to,cover,up,responsibility,for,the,beating,death,of,a,Somali,teen,-,ager,at,the,hands,of,Canadian,peace,keeping,troops,in,1992,.,Defense,minister,David,Coll,en,ette,insisted,that,his,resignation,had,nothing,to,do,with,the,Somalia,'s,scandal,.,Ted,Williams,was,the,first,name,to,come,to,mind,",",and,he,'s,the,greatest,living,hitter,.,Unnamed: 88_level_0
Think,0.095,0.095,0.095,0.095,0.095,0.095,0.095,0.095,0.095,0.001,0.001,0.001,0.001,0.001,0.001,0.001,0.001,0.001,0.001,0.001,0.001,0.001,0.001,0.001,0.001,-0.032,-0.032,-0.032,-0.032,-0.032,-0.032,-0.032,-0.032,-0.03,-0.03,-0.03,-0.03,-0.03,-0.03,-0.03,-0.03,-0.02,-0.02,-0.02,-0.02,-0.02,-0.02,-0.02,-0.043,-0.043,-0.043,-0.043,-0.043,-0.043,-0.043,-0.043,-0.001,-0.001,-0.001,-0.001,-0.001,-0.001,-0.001,-0.001,-0.001,-0.001,-0.001,-0.001,0.345,0.345,0.345,0.345,0.345,0.345,0.345,0.345,0.345,0.345,0.345,0.23,0.23,0.23,0.23,0.228,0.228,0.228,0.228,0.0
of,0.012,0.012,0.012,0.012,0.012,0.012,0.012,0.012,0.012,-0.019,-0.019,-0.019,-0.019,-0.019,-0.019,-0.019,-0.019,-0.019,-0.019,-0.019,-0.019,-0.019,-0.019,-0.019,-0.019,-0.04,-0.04,-0.04,-0.04,-0.04,-0.04,-0.04,-0.04,-0.016,-0.016,-0.016,-0.016,-0.016,-0.016,-0.016,-0.016,-0.006,-0.006,-0.006,-0.006,-0.006,-0.006,-0.006,0.002,0.002,0.002,0.002,0.002,0.002,0.002,0.002,0.02,0.02,0.02,0.02,0.02,0.02,0.02,0.02,0.02,0.02,0.02,0.02,0.176,0.176,0.176,0.176,0.176,0.176,0.176,0.176,0.176,0.176,0.176,0.066,0.066,0.066,0.066,0.07,0.07,0.07,0.07,0.0
the,-0.026,-0.026,-0.026,-0.026,-0.026,-0.026,-0.026,-0.026,-0.026,-0.005,-0.005,-0.005,-0.005,-0.005,-0.005,-0.005,-0.005,-0.005,-0.005,-0.005,-0.005,-0.005,-0.005,-0.005,-0.005,0.003,0.003,0.003,0.003,0.003,0.003,0.003,0.003,0.006,0.006,0.006,0.006,0.006,0.006,0.006,0.006,0.012,0.012,0.012,0.012,0.012,0.012,0.012,-0.001,-0.001,-0.001,-0.001,-0.001,-0.001,-0.001,-0.001,-0.01,-0.01,-0.01,-0.01,-0.01,-0.01,-0.01,-0.01,-0.01,-0.01,-0.01,-0.01,0.017,0.017,0.017,0.017,0.017,0.017,0.017,0.017,0.017,0.017,0.017,-0.009,-0.009,-0.009,-0.009,-0.023,-0.023,-0.023,-0.023,0.0
most,0.005,0.005,0.005,0.005,0.005,0.005,0.005,0.005,0.005,-0.031,-0.031,-0.031,-0.031,-0.031,-0.031,-0.031,-0.031,-0.031,-0.031,-0.031,-0.031,-0.031,-0.031,-0.031,-0.031,0.023,0.023,0.023,0.023,0.023,0.023,0.023,0.023,0.03,0.03,0.03,0.03,0.03,0.03,0.03,0.03,0.037,0.037,0.037,0.037,0.037,0.037,0.037,-0.082,-0.082,-0.082,-0.082,-0.082,-0.082,-0.082,-0.082,-0.017,-0.017,-0.017,-0.017,-0.017,-0.017,-0.017,-0.017,-0.017,-0.017,-0.017,-0.017,0.149,0.149,0.149,0.149,0.149,0.149,0.149,0.149,0.149,0.149,0.149,0.194,0.194,0.194,0.194,0.191,0.191,0.191,0.191,0.0
famous,-0.005,-0.005,-0.005,-0.005,-0.005,-0.005,-0.005,-0.005,-0.005,-0.051,-0.051,-0.051,-0.051,-0.051,-0.051,-0.051,-0.051,-0.051,-0.051,-0.051,-0.051,-0.051,-0.051,-0.051,-0.051,-0.027,-0.027,-0.027,-0.027,-0.027,-0.027,-0.027,-0.027,-0.048,-0.048,-0.048,-0.048,-0.048,-0.048,-0.048,-0.048,0.023,0.023,0.023,0.023,0.023,0.023,0.023,-0.061,-0.061,-0.061,-0.061,-0.061,-0.061,-0.061,-0.061,0.028,0.028,0.028,0.028,0.028,0.028,0.028,0.028,0.028,0.028,0.028,0.028,0.384,0.384,0.384,0.384,0.384,0.384,0.384,0.384,0.384,0.384,0.384,0.131,0.131,0.131,0.131,0.227,0.227,0.227,0.227,0.0
baseball,0.02,0.02,0.02,0.02,0.02,0.02,0.02,0.02,0.02,-0.008,-0.008,-0.008,-0.008,-0.008,-0.008,-0.008,-0.008,-0.008,-0.008,-0.008,-0.008,-0.008,-0.008,-0.008,-0.008,0.003,0.003,0.003,0.003,0.003,0.003,0.003,0.003,-0.076,-0.076,-0.076,-0.076,-0.076,-0.076,-0.076,-0.076,-0.047,-0.047,-0.047,-0.047,-0.047,-0.047,-0.047,0.02,0.02,0.02,0.02,0.02,0.02,0.02,0.02,-0.065,-0.065,-0.065,-0.065,-0.065,-0.065,-0.065,-0.065,-0.065,-0.065,-0.065,-0.065,0.283,0.283,0.283,0.283,0.283,0.283,0.283,0.283,0.283,0.283,0.283,0.558,0.558,0.558,0.558,0.837,0.837,0.837,0.837,0.0
player,0.037,0.037,0.037,0.037,0.037,0.037,0.037,0.037,0.037,-0.003,-0.003,-0.003,-0.003,-0.003,-0.003,-0.003,-0.003,-0.003,-0.003,-0.003,-0.003,-0.003,-0.003,-0.003,-0.003,0.035,0.035,0.035,0.035,0.035,0.035,0.035,0.035,-0.003,-0.003,-0.003,-0.003,-0.003,-0.003,-0.003,-0.003,-0.005,-0.005,-0.005,-0.005,-0.005,-0.005,-0.005,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.001,0.001,0.001,0.001,0.001,0.001,0.001,0.001,0.001,0.001,0.001,0.001,0.097,0.097,0.097,0.097,0.097,0.097,0.097,0.097,0.097,0.097,0.097,0.173,0.173,0.173,0.173,0.214,0.214,0.214,0.214,0.0
in,0.007,0.007,0.007,0.007,0.007,0.007,0.007,0.007,0.007,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,-0.005,-0.005,-0.005,-0.005,-0.005,-0.005,-0.005,-0.005,-0.01,-0.01,-0.01,-0.01,-0.01,-0.01,-0.01,-0.01,-0.025,-0.025,-0.025,-0.025,-0.025,-0.025,-0.025,0.057,0.057,0.057,0.057,0.057,0.057,0.057,0.057,0.075,0.075,0.075,0.075,0.075,0.075,0.075,0.075,0.075,0.075,0.075,0.075,0.001,0.001,0.001,0.001,0.001,0.001,0.001,0.001,0.001,0.001,0.001,0.074,0.074,0.074,0.074,0.088,0.088,0.088,0.088,0.0
the,0.064,0.064,0.064,0.064,0.064,0.064,0.064,0.064,0.064,0.027,0.027,0.027,0.027,0.027,0.027,0.027,0.027,0.027,0.027,0.027,0.027,0.027,0.027,0.027,0.027,-0.001,-0.001,-0.001,-0.001,-0.001,-0.001,-0.001,-0.001,0.002,0.002,0.002,0.002,0.002,0.002,0.002,0.002,-0.015,-0.015,-0.015,-0.015,-0.015,-0.015,-0.015,0.069,0.069,0.069,0.069,0.069,0.069,0.069,0.069,0.053,0.053,0.053,0.053,0.053,0.053,0.053,0.053,0.053,0.053,0.053,0.053,0.013,0.013,0.013,0.013,0.013,0.013,0.013,0.013,0.013,0.013,0.013,0.038,0.038,0.038,0.038,0.026,0.026,0.026,0.026,0.0
world,0.051,0.051,0.051,0.051,0.051,0.051,0.051,0.051,0.051,0.011,0.011,0.011,0.011,0.011,0.011,0.011,0.011,0.011,0.011,0.011,0.011,0.011,0.011,0.011,0.011,-0.024,-0.024,-0.024,-0.024,-0.024,-0.024,-0.024,-0.024,-0.005,-0.005,-0.005,-0.005,-0.005,-0.005,-0.005,-0.005,-0.021,-0.021,-0.021,-0.021,-0.021,-0.021,-0.021,0.035,0.035,0.035,0.035,0.035,0.035,0.035,0.035,0.032,0.032,0.032,0.032,0.032,0.032,0.032,0.032,0.032,0.032,0.032,0.032,-0.065,-0.065,-0.065,-0.065,-0.065,-0.065,-0.065,-0.065,-0.065,-0.065,-0.065,0.114,0.114,0.114,0.114,0.125,0.125,0.125,0.125,0.0
.,0.04,0.04,0.04,0.04,0.04,0.04,0.04,0.04,0.04,-0.031,-0.031,-0.031,-0.031,-0.031,-0.031,-0.031,-0.031,-0.031,-0.031,-0.031,-0.031,-0.031,-0.031,-0.031,-0.031,0.001,0.001,0.001,0.001,0.001,0.001,0.001,0.001,-0.001,-0.001,-0.001,-0.001,-0.001,-0.001,-0.001,-0.001,-0.005,-0.005,-0.005,-0.005,-0.005,-0.005,-0.005,-0.045,-0.045,-0.045,-0.045,-0.045,-0.045,-0.045,-0.045,-0.004,-0.004,-0.004,-0.004,-0.004,-0.004,-0.004,-0.004,-0.004,-0.004,-0.004,-0.004,0.077,0.077,0.077,0.077,0.077,0.077,0.077,0.077,0.077.1,0.077,0.077,0.11,0.11,0.11,0.11,0.194,0.194,0.194,0.194,0.0


In [76]:
sentence_1=["Canada's Minister of Defense resigned today, a day after an army official testified that top military officials had altered documents to cover up responsibility for the beating death of a Somali teen-ager at the hands of Canadian peacekeeping troops in 1992. Defense minister David Collenette insisted that his resignation had nothing to do with the Somalia's scandal."]

In [77]:
explainer_1 = shap.Explainer(f,tokenizer)
explainer_1.masker.mask_token_id=None
explainer_1.masker.mask_token="<infill>"

explainers.Partition is still in an alpha state, so use with caution...


In [78]:
shap_values_1 = explainer_1(sentence_1,model_kwargs=f_kwargs)

['Canada', "'s", 'Ġgovernment', 'Ġhas', 'Ġbeen', 'Ġaccused', 'Ġof', 'Ġcovering', 'Ġup', 'Ġthe', 'Ġdeath', 'Ġof', 'Ġa', 'ĠSomali', 'Ġteenager', 'Ġin', 'Ġthe', 'Ġ1990', 's', '.']
[-8.15470622 -1.15579588 -4.59085941 -0.74557675 -4.72267899 -1.18106749
  1.50782969 -6.9694069   2.53513752 -1.21564111 -2.86950386  2.13763188
 -1.93039932 -7.84776692 -4.22949403 -1.34443042 -2.33908424 -5.68284087
  2.98158524 -1.64317105] (20,)
[-1.18920808 -0.82581124 -1.61407043  0.23026044 -1.89440115 -0.59630836
  1.0806098  -2.21626642  3.02949138 -0.93755033 -0.8430304   2.40176673
  1.84053891  0.70533753  0.35256843 -1.14111113  0.35015019  0.98970668
  3.12948769  0.99192094] (20,)
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19] (20,)


In [79]:
shap_values_1.output_names

array([['Canada', "'s", 'Ġgovernment', 'Ġhas', 'Ġbeen', 'Ġaccused',
        'Ġof', 'Ġcovering', 'Ġup', 'Ġthe', 'Ġdeath', 'Ġof', 'Ġa',
        'ĠSomali', 'Ġteenager', 'Ġin', 'Ġthe', 'Ġ1990', 's', '.']],
      dtype='<U11')

In [80]:
shap.plots.text(shap_values_1[0])

invalid value encountered in double_scalars


Unnamed: 0_level_0,Unnamed: 1_level_0,Canada,'s,Minister,of,Defense,resigned,today,",",a,day,after,an,army,official,testified,that,top,military,officials,had,altered,documents,to,cover,up,responsibility,for,the,beating,death,of,a,Somali,teen,-,ager,at,the,hands,of,Canadian,peace,keeping,troops,in,1992,.,Defense,minister,David,Coll,en,ette,insisted,that,his,resignation,had,nothing,to,do,with,the,Somalia,'s,scandal,.,Unnamed: 69_level_0
Canada,0.517,0.517,0.517,0.517,0.105,0.105,0.105,0.105,0.105,0.024,0.024,0.024,0.024,0.024,0.024,0.024,0.024,0.024,0.024,0.024,0.024,0.024,0.024,0.024,0.024,0.006,0.006,0.006,0.006,0.006,0.006,0.006,0.006,0.016,0.016,0.016,0.016,0.016,0.016,0.016,0.016,0.226,0.226,0.226,0.226,0.226,0.226,0.226,0.365,0.365,0.365,0.365,0.323,0.323,0.323,0.323,-0.03,-0.03,-0.03,-0.03,-0.03,-0.03,-0.03,-0.03,-0.07,-0.07,-0.07,-0.07,0.0
's,-0.132,-0.132,-0.132,-0.132,-0.109,-0.109,-0.109,-0.109,-0.109,0.031,0.031,0.031,0.031,0.031,0.031,0.031,0.031,0.031,0.031,0.031,0.031,0.031,0.031,0.031,0.031,0.034,0.034,0.034,0.034,0.034,0.034,0.034,0.034,-0.02,-0.02,-0.02,-0.02,-0.02,-0.02,-0.02,-0.02,-0.022,-0.022,-0.022,-0.022,-0.022,-0.022,-0.022,0.033,0.033,0.033,0.033,0.02,0.02,0.02,0.02,0.08,0.08,0.08,0.08,0.08,0.08,0.08,0.08,0.024,0.024,0.024,0.024,0.0
government,0.028,0.028,0.028,0.028,0.109,0.109,0.109,0.109,0.109,0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.05,-0.042,-0.042,-0.042,-0.042,-0.042,-0.042,-0.042,-0.042,0.023,0.023,0.023,0.023,0.023,0.023,0.023,0.023,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.227,0.227,0.227,0.227,0.16,0.16,0.16,0.16,-0.026,-0.026,-0.026,-0.026,-0.026,-0.026,-0.026,-0.026,0.083,0.083,0.083,0.083,0.0
has,-0.036,-0.036,-0.036,-0.036,-0.028,-0.028,-0.028,-0.028,-0.028,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.03,0.03,0.03,0.03,0.03,0.03,0.03,0.03,0.006,0.006,0.006,0.006,0.006,0.006,0.006,0.006,0.01,0.01,0.01,0.01,0.01,0.01,0.01,0.062,0.062,0.062,0.062,0.023,0.023,0.023,0.023,0.048,0.048,0.048,0.048,0.048,0.048,0.048,0.048,0.015,0.015,0.015,0.015,0.0
been,0.209,0.209,0.209,0.209,0.119,0.119,0.119,0.119,0.119,0.047,0.047,0.047,0.047,0.047,0.047,0.047,0.047,0.047,0.047,0.047,0.047,0.047,0.047,0.047,0.047,-0.001,-0.001,-0.001,-0.001,-0.001,-0.001,-0.001,-0.001,-0.005,-0.005,-0.005,-0.005,-0.005,-0.005,-0.005,-0.005,0.01,0.01,0.01,0.01,0.01,0.01,0.01,0.076,0.076,0.076,0.076,0.05,0.05,0.05,0.05,-0.023,-0.023,-0.023,-0.023,-0.023,-0.023,-0.023,-0.023,0.076,0.076,0.076,0.076,0.0
accused,0.007,0.007,0.007,0.007,-0.011,-0.011,-0.011,-0.011,-0.011,0.065,0.065,0.065,0.065,0.065,0.065,0.065,0.065,0.065,0.065,0.065,0.065,0.065,0.065,0.065,0.065,0.045,0.045,0.045,0.045,0.045,0.045,0.045,0.045,-0.018,-0.018,-0.018,-0.018,-0.018,-0.018,-0.018,-0.018,-0.047,-0.047,-0.047,-0.047,-0.047,-0.047,-0.047,0.037,0.037,0.037,0.037,-0.019,-0.019,-0.019,-0.019,-0.072,-0.072,-0.072,-0.072,-0.072,-0.072,-0.072,-0.072,0.045,0.045,0.045,0.045,0.0
of,0.001,0.001,0.001,0.001,0.005,0.005,0.005,0.005,0.005,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,-0.018,-0.018,-0.018,-0.018,-0.018,-0.018,-0.018,-0.018,-0.007,-0.007,-0.007,-0.007,-0.007,-0.007,-0.007,-0.007,-0.016,-0.016,-0.016,-0.016,-0.016,-0.016,-0.016,-0.015,-0.015,-0.015,-0.015,-0.015,-0.015,-0.015,-0.015,-0.026,-0.026,-0.026,-0.026,-0.026,-0.026,-0.026,-0.026,0.016,0.016,0.016,0.016,0.0
covering,0.028,0.028,0.028,0.028,0.079,0.079,0.079,0.079,0.079,0.091,0.091,0.091,0.091,0.091,0.091,0.091,0.091,0.091,0.091,0.091,0.091,0.091,0.091,0.091,0.091,0.171,0.171,0.171,0.171,0.171,0.171,0.171,0.171,0.033,0.033,0.033,0.033,0.033,0.033,0.033,0.033,0.005,0.005,0.005,0.005,0.005,0.005,0.005,-0.047,-0.047,-0.047,-0.047,0.001,0.001,0.001,0.001,0.06,0.06,0.06,0.06,0.06,0.06,0.06,0.06,0.205,0.205,0.205,0.205,0.0
up,-0.005,-0.005,-0.005,-0.005,0.007,0.007,0.007,0.007,0.007,0.011,0.011,0.011,0.011,0.011,0.011,0.011,0.011,0.011,0.011,0.011,0.011,0.011,0.011,0.011,0.011,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,-0.011,-0.011,-0.011,-0.011,-0.011,-0.011,-0.011,0.022,0.022,0.022,0.022,0.004,0.004,0.004,0.004,-0.002,-0.002,-0.002,-0.002,-0.002,-0.002,-0.002,-0.002,0.039,0.039,0.039,0.039,0.0
the,0.009,0.009,0.009,0.009,0.038,0.038,0.038,0.038,0.038,0.004,0.004,0.004,0.004,0.004,0.004,0.004,0.004,0.004,0.004,0.004,0.004,0.004,0.004,0.004,0.004,-0.023,-0.023,-0.023,-0.023,-0.023,-0.023,-0.023,-0.023,-0.012,-0.012,-0.012,-0.012,-0.012,-0.012,-0.012,-0.012,0.006,0.006,0.006,0.006,0.006,0.006,0.006,0.01,0.01,0.01,0.01,0.023,0.023,0.023,0.023,0.041,0.041,0.041,0.041,0.041,0.041,0.041,0.041,-0.055,-0.055,-0.055,-0.055,0.0
death,-0.005,-0.005,-0.005,-0.005,-0.01,-0.01,-0.01,-0.01,-0.01,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.008,0.187,0.187,0.187,0.187,0.187,0.187,0.187,0.187,0.066,0.066,0.066,0.066,0.066,0.066,0.066,0.066,0.056,0.056,0.056,0.056,0.056,0.056,0.056,-0.043,-0.043,-0.043,-0.043,-0.014,-0.014,-0.014,-0.014,0.013,0.013,0.013,0.013,0.013,0.013,0.013,0.013,-0.081,-0.081,-0.081,-0.081,0.0
of,0.003,0.003,0.003,0.003,0.0,0.0,0.0,0.0,0.0,-0.039,-0.039,-0.039,-0.039,-0.039,-0.039,-0.039,-0.039,-0.039,-0.039,-0.039,-0.039,-0.039,-0.039,-0.039,-0.039,0.09,0.09,0.09,0.09,0.09,0.09,0.09,0.09,0.038,0.038,0.038,0.038,0.038,0.038,0.038,0.038,0.026,0.026,0.026,0.026,0.026,0.026,0.026,-0.046,-0.046,-0.046,-0.046,-0.026,-0.026,-0.026,-0.026,0.024,0.024,0.024,0.024,0.024,0.024,0.024,0.024,-0.058,-0.058,-0.058,-0.058,0.0
a,-0.009,-0.009,-0.009,-0.009,0.01,0.01,0.01,0.01,0.01,0.03,0.03,0.03,0.03,0.03,0.03,0.03,0.03,0.03,0.03,0.03,0.03,0.03,0.03,0.03,0.03,0.168,0.168,0.168,0.168,0.168,0.168,0.168,0.168,0.119,0.119,0.119,0.119,0.119,0.119,0.119,0.119,0.039,0.039,0.039,0.039,0.039,0.039,0.039,0.093,0.093,0.093,0.093,0.04,0.04,0.04,0.04,-0.044,-0.044,-0.044,-0.044,-0.044,-0.044,-0.044,-0.044,0.131,0.131,0.131,0.131,0.0
Somali,0.03,0.03,0.03,0.03,-0.001,-0.001,-0.001,-0.001,-0.001,0.014,0.014,0.014,0.014,0.014,0.014,0.014,0.014,0.014,0.014,0.014,0.014,0.014,0.014,0.014,0.014,0.102,0.102,0.102,0.102,0.102,0.102,0.102,0.102,0.471,0.471,0.471,0.471,0.471,0.471,0.471,0.471,0.025,0.025,0.025,0.025,0.025,0.025,0.025,-0.012,-0.012,-0.012,-0.012,-0.023,-0.023,-0.023,-0.023,0.039,0.039,0.039,0.039,0.039,0.039,0.039,0.039,0.822,0.822,0.822,0.822,0.0
teenager,0.043,0.043,0.043,0.043,0.025,0.025,0.025,0.025,0.025,0.023,0.023,0.023,0.023,0.023,0.023,0.023,0.023,0.023,0.023,0.023,0.023,0.023,0.023,0.023,0.023,0.084,0.084,0.084,0.084,0.084,0.084,0.084,0.084,0.363,0.363,0.363,0.363,0.363,0.363,0.363,0.363,0.044,0.044,0.044,0.044,0.044,0.044,0.044,-0.017,-0.017,-0.017,-0.017,0.012,0.012,0.012,0.012,-0.064,-0.064,-0.064,-0.064,-0.064,-0.064,-0.064,-0.064,0.142,0.142,0.142,0.142,0.0
in,0.015,0.015,0.015,0.015,0.013,0.013,0.013,0.013,0.013,0.013,0.013,0.013,0.013,0.013,0.013,0.013,0.013,0.013,0.013,0.013,0.013,0.013,0.013,0.013,0.013,0.014,0.014,0.014,0.014,0.014,0.014,0.014,0.014,-0.026,-0.026,-0.026,-0.026,-0.026,-0.026,-0.026,-0.026,0.013,0.013,0.013,0.013,0.013,0.013,0.013,0.002,0.002,0.002,0.002,-0.002,-0.002,-0.002,-0.002,-0.014,-0.014,-0.014,-0.014,-0.014,-0.014,-0.014,-0.014,-0.003,-0.003,-0.003,-0.003,0.0
the,0.026,0.026,0.026,0.026,0.003,0.003,0.003,0.003,0.003,0.009,0.009,0.009,0.009,0.009,0.009,0.009,0.009,0.009,0.009,0.009,0.009,0.009,0.009,0.009,0.009,0.099,0.099,0.099,0.099,0.099,0.099,0.099,0.099,0.11,0.11,0.11,0.11,0.11,0.11,0.11,0.11,0.134,0.134,0.134,0.134,0.134,0.134,0.134,-0.016,-0.016,-0.016,-0.016,-0.0,-0.0,-0.0,-0.0,0.01,0.01,0.01,0.01,0.01,0.01,0.01,0.01,-0.047,-0.047,-0.047,-0.047,0.0
1990,0.003,0.003,0.003,0.003,-0.021,-0.021,-0.021,-0.021,-0.021,0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.117,0.117,0.117,0.117,0.117,0.117,0.117,0.117,0.108,0.108,0.108,0.108,0.108,0.108,0.108,0.108,0.472,0.472,0.472,0.472,0.472,0.472,0.472,-0.004,-0.004,-0.004,-0.004,0.027,0.027,0.027,0.027,0.019,0.019,0.019,0.019,0.019,0.019,0.019,0.019,0.156,0.156,0.156,0.156,0.0
s,-0.008,-0.008,-0.008,-0.008,-0.005,-0.005,-0.005,-0.005,-0.005,0.005,0.005,0.005,0.005,0.005,0.005,0.005,0.005,0.005,0.005,0.005,0.005,0.005,0.005,0.005,0.005,-0.004,-0.004,-0.004,-0.004,-0.004,-0.004,-0.004,-0.004,0.006,0.006,0.006,0.006,0.006,0.006,0.006,0.006,-0.004,-0.004,-0.004,-0.004,-0.004,-0.004,-0.004,0.003,0.003,0.003,0.003,-0.007,-0.007,-0.007,-0.007,0.004,0.004,0.004,0.004,0.004,0.004,0.004,0.004,0.028,0.028,0.028,0.028,0.0
.,0.075,0.075,0.075,0.075,0.09,0.09,0.09,0.09,0.09,0.03,0.03,0.03,0.03,0.03,0.03,0.03,0.03,0.03,0.03,0.03,0.03,0.03,0.03,0.03,0.03,0.133,0.133,0.133,0.133,0.133,0.133,0.133,0.133,0.012,0.012,0.012,0.012,0.012,0.012,0.012,0.012,0.008,0.008,0.008,0.008,0.008,0.008,0.008,-0.012,-0.012,-0.012,-0.012,0.011,0.011,0.011,0.011,0.01,0.01,0.01,0.01,0.01,0.01,0.01,0.01,0.03,0.03,0.03,0.03,0.0


#### Visualize explanation using text plot

In [21]:
shap.plots.text(shap_values[0])

Unnamed: 0_level_0,Unnamed: 1_level_0,north,k,orea,is,entering,its,fourth,winter,of,chronic,food,shortages,with,its,people,mal,n,our,ished,and,at,risk,of,dying,from,normally,cur,able,illnesses,",",senior,red,cross,officials,said,t,uesday,Unnamed: 39_level_0
The,-0.089,-0.089,-0.089,-0.089,0.031,0.031,0.031,0.031,0.095,0.095,0.095,0.095,0.125,0.125,0.125,0.125,0.064,0.064,0.064,0.064,0.05,0.05,0.05,0.05,0.066,0.066,0.066,0.066,0.066,0.066,0.066,0.279,0.329,0.19,0.311,0.145,0.145,0.145,0.145
Red,0.09,0.09,0.09,0.09,0.008,0.008,0.008,0.008,0.137,0.137,0.137,0.137,0.105,0.105,0.105,0.105,0.151,0.151,0.151,0.151,0.083,0.083,0.083,0.083,-0.084,-0.084,-0.084,-0.084,-0.084,-0.084,-0.084,1.207,2.015,1.419,0.808,-0.085,-0.085,-0.085,-0.085
Cross,0.107,0.107,0.107,0.107,-0.039,-0.039,-0.039,-0.039,0.16,0.16,0.16,0.16,0.102,0.102,0.102,0.102,0.192,0.192,0.192,0.192,0.055,0.055,0.055,0.055,0.093,0.093,0.093,0.093,0.093,0.093,0.093,0.265,0.459,2.372,1.113,-0.062,-0.062,-0.062,-0.062
has,-0.008,-0.008,-0.008,-0.008,-0.005,-0.005,-0.005,-0.005,0.032,0.032,0.032,0.032,0.037,0.037,0.037,0.037,0.032,0.032,0.032,0.032,0.022,0.022,0.022,0.022,0.029,0.029,0.029,0.029,0.029,0.029,0.029,0.124,0.155,-0.009,0.116,0.002,0.002,0.002,0.002
warned,0.148,0.148,0.148,0.148,-0.006,-0.006,-0.006,-0.006,0.175,0.175,0.175,0.175,0.167,0.167,0.167,0.167,-0.01,-0.01,-0.01,-0.01,0.015,0.015,0.015,0.015,0.164,0.164,0.164,0.164,0.164,0.164,0.164,0.071,-0.017,-0.308,0.303,0.09,0.09,0.09,0.09
that,-0.054,-0.054,-0.054,-0.054,0.017,0.017,0.017,0.017,-0.025,-0.025,-0.025,-0.025,-0.017,-0.017,-0.017,-0.017,0.023,0.023,0.023,0.023,0.039,0.039,0.039,0.039,0.031,0.031,0.031,0.031,0.031,0.031,0.031,0.043,0.055,0.226,0.143,-0.003,-0.003,-0.003,-0.003
North,1.236,1.236,1.236,1.236,0.262,0.262,0.262,0.262,0.169,0.169,0.169,0.169,0.186,0.186,0.186,0.186,0.022,0.022,0.022,0.022,0.014,0.014,0.014,0.014,-0.058,-0.058,-0.058,-0.058,-0.058,-0.058,-0.058,-0.162,-0.213,0.191,0.383,-0.06,-0.06,-0.06,-0.06
Korea,0.453,0.453,0.453,0.453,0.042,0.042,0.042,0.042,0.081,0.081,0.081,0.081,0.06,0.06,0.06,0.06,0.084,0.084,0.084,0.084,0.062,0.062,0.062,0.062,-0.012,-0.012,-0.012,-0.012,-0.012,-0.012,-0.012,-0.251,-0.294,0.145,0.432,-0.115,-0.115,-0.115,-0.115
is,0.107,0.107,0.107,0.107,0.046,0.046,0.046,0.046,-0.035,-0.035,-0.035,-0.035,-0.022,-0.022,-0.022,-0.022,0.012,0.012,0.012,0.012,0.006,0.006,0.006,0.006,0.028,0.028,0.028,0.028,0.028,0.028,0.028,-0.019,0.004,0.006,0.107,0.016,0.016,0.016,0.016
facing,-0.115,-0.115,-0.115,-0.115,0.106,0.106,0.106,0.106,0.196,0.196,0.196,0.196,0.177,0.177,0.177,0.177,0.059,0.059,0.059,0.059,0.038,0.038,0.038,0.038,0.027,0.027,0.027,0.027,0.027,0.027,0.027,0.096,0.193,0.4,0.195,-0.104,-0.104,-0.104,-0.104
its,-0.026,-0.026,-0.026,-0.026,0.12,0.12,0.12,0.12,0.202,0.202,0.202,0.202,0.177,0.177,0.177,0.177,0.042,0.042,0.042,0.042,0.01,0.01,0.01,0.01,-0.007,-0.007,-0.007,-0.007,-0.007,-0.007,-0.007,-0.096,-0.118,-0.146,-0.086,0.032,0.032,0.032,0.032
worst,0.006,0.006,0.006,0.006,-0.169,-0.169,-0.169,-0.169,0.066,0.066,0.066,0.066,0.071,0.071,0.071,0.071,0.074,0.074,0.074,0.074,0.048,0.048,0.048,0.048,0.043,0.043,0.043,0.043,0.043,0.043,0.043,0.004,0.039,0.27,0.12,0.05,0.05,0.05,0.05
winter,0.029,0.029,0.029,0.029,0.072,0.072,0.072,0.072,0.691,0.691,0.691,0.691,0.424,0.424,0.424,0.424,-0.044,-0.044,-0.044,-0.044,-0.033,-0.033,-0.033,-0.033,-0.061,-0.061,-0.061,-0.061,-0.061,-0.061,-0.061,-0.145,0.001,0.145,-0.033,-0.009,-0.009,-0.009,-0.009
on,-0.119,-0.119,-0.119,-0.119,-0.058,-0.058,-0.058,-0.058,0.179,0.179,0.179,0.179,0.133,0.133,0.133,0.133,-0.036,-0.036,-0.036,-0.036,-0.02,-0.02,-0.02,-0.02,0.012,0.012,0.012,0.012,0.012,0.012,0.012,-0.269,-0.248,-0.144,-0.208,0.049,0.049,0.049,0.049
record,-0.096,-0.096,-0.096,-0.096,0.038,0.038,0.038,0.038,0.115,0.115,0.115,0.115,0.083,0.083,0.083,0.083,0.005,0.005,0.005,0.005,-0.003,-0.003,-0.003,-0.003,0.011,0.011,0.011,0.011,0.011,0.011,0.011,-0.199,-0.164,-0.124,0.037,0.011,0.011,0.011,0.011
.,0.006,0.006,0.006,0.006,0.043,0.043,0.043,0.043,0.118,0.118,0.118,0.118,0.119,0.119,0.119,0.119,-0.005,-0.005,-0.005,-0.005,0.004,0.004,0.004,0.004,-0.021,-0.021,-0.021,-0.021,-0.021,-0.021,-0.021,-0.121,-0.173,-0.413,-0.499,-0.031,-0.031,-0.031,-0.031


### Explanation by approximating logit scores using a language model (distilgpt2)

In [10]:
lm_tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
lm_model =  AutoModelWithLMHead.from_pretrained("distilgpt2")

In [11]:
logit_generator_model_lm = GenerateLogits(model=lm_model,tokenizer=lm_tokenizer,device='cpu')

In [12]:
# return model prediction
def f_lm_predict(x):
    model.eval()
    input_ids = torch.tensor([tokenizer.encode(x)]).cuda()
    with torch.no_grad():
        out=model.generate(input_ids)
    sentence = [tokenizer.decode(g, skip_special_tokens=True) for g in out][0]
    del input_ids, out
    return sentence

In [13]:
def gen_kwargs_lm(x):
    target_sentence = f_lm_predict(x)
    output_names = logit_generator_model_lm.get_output_names(target_sentence)
    return {'target_sentence':target_sentence,'output_names':output_names}

In [14]:
def f_kwargs_lm(x):
    kwargs = gen_kwargs_lm(x)
    return kwargs

In [15]:
def f_lm(x_batch,**kwargs):
    output_batch=[]
    # Extract target sentence for which we want to generate conditional logits
    target_sentence=kwargs['target_sentence']
    for i,x in enumerate(x_batch):
        source_sentence = f_lm_predict(x)
        conditional_logits = logit_generator_model_lm.generate_logits(source_sentence,target_sentence)
        output_batch.append(conditional_logits)
    return np.array(output_batch)

In [16]:
explainer_lm = shap.Explainer(f_lm,tokenizer)
explainer_lm.masker.mask_token_id=None
explainer_lm.masker.mask_token=""

In [17]:
shap_values_lm = explainer_lm(dataset['document'][0:1],model_kwargs=f_kwargs_lm)

Partition explainer: 2it [01:03, 31.67s/it]               


#### Visualize explanation using text plot

In [None]:
shap.plots.text(shap_values_lm[0])