In [1]:
import os
import copy
import shutil
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from transformers import AutoTokenizer, AutoModelWithLMHead
import shap
from shap.utils import GenerateLogits
import scipy as sp
import nlp
import torch

In [2]:
# Load the distilled BART checkpoint fine-tuned on XSum, together with its
# tokenizer. The model is moved to the GPU; the tokenizer stays on the host.
tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-xsum-12-6")
model = AutoModelWithLMHead.from_pretrained(
    "sshleifer/distilbart-xsum-12-6"
).cuda()

In [3]:
# Load the training split of the 'xsum' dataset.
dataset = nlp.load_dataset("xsum", split="train")

Using custom data configuration default


### Explanation using the model decoder to generate logit scores

In [4]:
# Wrap the summarization model and its tokenizer in a GenerateLogits helper;
# it is used below to obtain output names and conditional logits.
logit_generator_model = GenerateLogits(
    model=model,
    tokenizer=tokenizer,
)

In [5]:
def gen_kwargs(x):
    """Generate the model output for ``x`` and package it as explainer kwargs.

    The input text is encoded with ``tokenizer``, passed through
    ``model.generate`` (in eval mode, without gradient tracking), and the
    first and last generated token ids are sliced off. The remaining ids are
    the target sequence that the explanation is computed for.

    Returns:
        dict with two keys:
        ``'target_sentence_ids'`` - the sliced generated ids (on the GPU), and
        ``'output_names'`` - whatever ``logit_generator_model.get_output_names``
        returns for those ids.
    """
    model.eval()
    encoded = torch.tensor([tokenizer.encode(x)]).cuda()
    with torch.no_grad():
        generated = model.generate(encoded)
    # Drop the first and last id of each generated row (presumably the
    # bos / eos markers) so only the content tokens remain.
    target_sentence_ids = generated[:, 1:-1].cuda()
    result = {
        'target_sentence_ids': target_sentence_ids,
        'output_names': logit_generator_model.get_output_names(target_sentence_ids),
    }
    # Release the references to the intermediate tensors.
    del generated, encoded
    return result

In [6]:
def f_kwargs(x):
    """Build the extra keyword arguments handed to the model function ``f``.

    These are needed in order to compute conditional logits with respect to
    the output the model generates for the original (unmasked) input ``x``.
    Delegates entirely to ``gen_kwargs``.
    """
    return gen_kwargs(x)

In [7]:
def f(x_batch, **kwargs):
    """Score every source sentence in ``x_batch`` against a fixed target.

    Args:
        x_batch: iterable of source sentences.
        **kwargs: must contain ``'target_sentence_ids'``, the target ids for
            which conditional logits are requested.

    Returns:
        ``np.ndarray`` built from one ``generate_logits`` result per entry of
        ``x_batch``, in the same order.
    """
    # Looked up once, before any scoring, so a missing key fails immediately.
    target_sentence_ids = kwargs['target_sentence_ids']
    return np.array([
        logit_generator_model.generate_logits(source_sentence, target_sentence_ids)
        for source_sentence in x_batch
    ])

In [8]:
# Build a SHAP explainer around the model function `f`, passing the
# summarization tokenizer as the second argument.
explainer = shap.Explainer(f,tokenizer)
# Override the masker's mask token with an empty string and clear its id.
# NOTE(review): presumably this makes masked spans disappear from the text
# instead of being replaced by a placeholder token - confirm against the
# shap text masker.
explainer.masker.mask_token_id=None
explainer.masker.mask_token=""

In [9]:
# Explain only the first training document. The dataset rows are sliced first
# so that a single record is fetched; indexing the 'document' column first
# would materialize every document before the slice is taken.
shap_values = explainer(dataset[0:1]['document'],model_kwargs=f_kwargs)



#### Visualize explanation using text plot

In [None]:
# Render the text plot for the first (and only) explained document.
shap.plots.text(shap_values[0])

### Explanation by approximating logit scores using a language model (distilgpt2)

In [10]:
# Load distilgpt2 and its tokenizer; this language model is used to
# approximate the logit scores. It is not moved to the GPU here.
lm_tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
lm_model = AutoModelWithLMHead.from_pretrained("distilgpt2")

In [11]:
# GenerateLogits helper backed by the distilgpt2 language model, explicitly
# configured for the CPU.
logit_generator_model_lm = GenerateLogits(
    model=lm_model,
    tokenizer=lm_tokenizer,
    device="cpu",
)

In [12]:
def f_lm_predict(x):
    """Return the text the summarization model generates for input ``x``.

    The input is encoded with ``tokenizer``, run through ``model.generate``
    in eval mode without gradient tracking, and the first generated sequence
    is decoded to a string with special tokens stripped.
    """
    model.eval()
    encoded = torch.tensor([tokenizer.encode(x)]).cuda()
    with torch.no_grad():
        generated = model.generate(encoded)
    # Only the first generated row is returned to the caller.
    sentence = tokenizer.decode(generated[0], skip_special_tokens=True)
    # Release the references to the intermediate tensors.
    del encoded, generated
    return sentence

In [13]:
def gen_kwargs_lm(x):
    """Generate the model output for ``x`` and package it as explainer kwargs.

    Returns:
        dict with ``'target_sentence'`` (the decoded text produced by
        ``f_lm_predict``) and ``'output_names'`` (the result of
        ``logit_generator_model_lm.get_output_names`` for that text).
    """
    generated_text = f_lm_predict(x)
    return {
        'target_sentence': generated_text,
        'output_names': logit_generator_model_lm.get_output_names(generated_text),
    }

In [14]:
def f_kwargs_lm(x):
    """Build the extra keyword arguments handed to the model function ``f_lm``.

    Delegates entirely to ``gen_kwargs_lm``.
    """
    return gen_kwargs_lm(x)

In [15]:
def f_lm(x_batch, **kwargs):
    """Score the model's output for each input in ``x_batch`` against a fixed target.

    For every entry of ``x_batch`` the summarization model's prediction is
    obtained via ``f_lm_predict``; that prediction is then used as the source
    sentence when asking ``logit_generator_model_lm`` for conditional logits
    of the target sentence.

    Args:
        x_batch: iterable of input texts.
        **kwargs: must contain ``'target_sentence'``, the target text for
            which conditional logits are requested.

    Returns:
        ``np.ndarray`` built from one ``generate_logits`` result per entry of
        ``x_batch``, in the same order.
    """
    # Looked up once, before any scoring, so a missing key fails immediately.
    target_sentence = kwargs['target_sentence']
    return np.array([
        logit_generator_model_lm.generate_logits(f_lm_predict(x), target_sentence)
        for x in x_batch
    ])

In [16]:
# Build a SHAP explainer around the language-model-based function `f_lm`,
# passing the summarization tokenizer as the second argument.
explainer_lm = shap.Explainer(f_lm,tokenizer)
# Override the masker's mask token with an empty string and clear its id.
# NOTE(review): presumably this makes masked spans disappear from the text
# instead of being replaced by a placeholder token - confirm against the
# shap text masker.
explainer_lm.masker.mask_token_id=None
explainer_lm.masker.mask_token=""

In [17]:
# Explain only the first training document. The dataset rows are sliced first
# so that a single record is fetched; indexing the 'document' column first
# would materialize every document before the slice is taken.
shap_values_lm = explainer_lm(dataset[0:1]['document'],model_kwargs=f_kwargs_lm)

Partition explainer: 2it [01:03, 31.67s/it]               


#### Visualize explanation using text plot

In [None]:
# Render the text plot for the first (and only) explained document.
shap.plots.text(shap_values_lm[0])

In [None]:
class TeacherForcingLogits(Model):
    def __init__(self, model, tokenizer, generation_function = None, text_similarity_model = None, text_similarity_tokenizer = None):