In [1]:
import os
import copy
import shutil
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from transformers import AutoTokenizer, AutoModelWithLMHead
import shap
from shap.utils import GenerateLogits
import scipy as sp
import nlp
import torch

In [3]:
# Load the tokenizer and the translation model from the
# "Helsinki-NLP/opus-mt-en-es" checkpoint.
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-es", use_fast=True)
model = AutoModelWithLMHead.from_pretrained("Helsinki-NLP/opus-mt-en-es")
# Use the GPU when one is present, otherwise stay on the CPU; an unconditional
# .cuda() raises on hosts without a CUDA device.
model = model.to("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Wrap the model/tokenizer pair in shap's GenerateLogits helper. Later cells
# call its get_output_names() and generate_logits() methods.
logit_generator_model = GenerateLogits(
    model=model,
    tokenizer=tokenizer,
)

In [5]:
def gen_kwargs(x):
    """Translate ``x`` and package the result as keyword arguments for ``f``.

    Parameters
    ----------
    x
        Input handed to ``tokenizer.encode`` (a source sentence string in this
        notebook).

    Returns
    -------
    dict
        ``'target_sentence_ids'``: the ids produced by ``model.generate`` with
        the first and last position of every sequence removed.
        ``'output_names'``: the value returned by
        ``logit_generator_model.get_output_names`` for those ids.
    """
    model.eval()
    # Build the input on the model's own device instead of hard-coding CUDA,
    # so the function works wherever the model has been placed.
    input_ids = torch.tensor([tokenizer.encode(x)], device=model.device)
    with torch.no_grad():
        # Generate the ids of the output translation that we aim to explain.
        out = model.generate(input_ids)
    # Strip the first and last generated position (presumably the decoder start
    # and end-of-sequence tokens -- TODO confirm against the tokenizer).
    target_sentence_ids = out[:, 1:-1].to(model.device)
    output_names = logit_generator_model.get_output_names(target_sentence_ids)
    return {'target_sentence_ids': target_sentence_ids, 'output_names': output_names}

In [6]:
def f_kwargs(x):
    """Build the additional keyword arguments passed on to the model function.

    Delegates to ``gen_kwargs`` and returns its result unchanged; that result
    carries the translation of the original input sentence, which is needed in
    order to get the conditional logits for it.
    """
    return gen_kwargs(x)

In [7]:
def f(x_batch, **kwargs):
    """Return the conditional logits of one target sentence for every source.

    Parameters
    ----------
    x_batch
        Iterable of source sentences.
    **kwargs
        Must contain ``'target_sentence_ids'``, the target sentence for which
        conditional logits are generated.

    Returns
    -------
    numpy.ndarray
        ``np.array`` over the per-source results of
        ``logit_generator_model.generate_logits``, in ``x_batch`` order.
    """
    # The same target sentence is scored against each source in the batch.
    target_ids = kwargs['target_sentence_ids']
    logits_per_source = [
        logit_generator_model.generate_logits(source, target_ids)
        for source in x_batch
    ]
    return np.array(logits_per_source)

In [8]:
# Example function which prints the generated text and returns its token ids
def example_summarize(x, model, tokenizer):
    """Generate output text for ``x``, print it, and return the generated ids.

    Despite the name, this notebook calls it with a translation model, so the
    printed "summary" is the translation of ``x``.

    Parameters
    ----------
    x
        Input sentence; it is tokenized with truncation at ``max_length=512``.
    model
        Object exposing ``generate`` and ``device``.
    tokenizer
        Callable returning a mapping with ``'input_ids'`` and exposing
        ``decode``.

    Returns
    -------
    numpy.ndarray
        Generated ids of the first sequence, without its first and last entry.
    """
    print(f"Input: {x}")
    inputs = tokenizer([x], max_length=512, return_tensors='pt', truncation=True)
    # Move the inputs to wherever the supplied model lives instead of assuming
    # that it is on CUDA.
    input_ids = inputs['input_ids'].to(model.device)
    summary_ids = model.generate(input_ids).detach().cpu().numpy()
    summary = [
        tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        for g in summary_ids
    ]
    print(f"summary: {summary[0]}")
    return summary_ids[0, 1:-1]

In [9]:
# Example English sentence that is translated here and explained further down.
s = "In this picture, there are four persons: my father, my mother, my brother and my sister."
# Prints the input and its generated translation; `ids` holds the generated
# token ids without the first and last position.
ids = example_summarize(s, model, tokenizer)

Input: In this picture, there are four persons: my father, my mother, my brother and my sister.
summary: En este cuadro, hay cuatro personas: mi padre, mi madre, mi hermano y mi hermana.


In [10]:
# Build the explainer around the model function `f`. The tokenizer is passed as
# the second positional argument, and `f_kwargs` is registered as
# `model_kwargs` to supply the extra keyword arguments that `f` reads.
explainer = shap.Explainer(
    f,
    tokenizer,
    model_kwargs=f_kwargs,
)

In [11]:
# Run the explainer on a one-element batch containing the example sentence `s`.
shap_values = explainer([s])

In [12]:
# Render the text plot for the first (and only) explanation in `shap_values`.
shap.plots.text(shap_values[0])

Unnamed: 0_level_0,In,this,picture,",",there,are,four,persons,:,my,father,",",my,mother,",",my,brother,and,my,sister,.,Unnamed: 22_level_0
En,4.781,1.04,0.963,0.193,0.008,-0.04,0.13,-0.02,-0.005,0.054,0.078,0.032,-0.029,-0.024,-0.064,-0.055,-0.062,-0.068,0.037,0.039,0.217,-0.039
este,-1.083,1.734,1.036,0.147,-0.024,0.088,0.022,0.486,0.605,-0.028,0.035,0.097,-0.08,-0.083,0.041,-0.081,0.044,0.095,-0.106,-0.044,0.136,-0.008
cuadro,0.765,0.862,5.77,-0.219,0.01,0.158,0.069,0.063,0.28,-0.062,0.016,-0.006,-0.073,-0.067,-0.054,-0.061,-0.025,-0.029,-0.105,-0.064,-0.103,-0.058
",",-0.424,-0.102,0.796,2.35,-0.554,-0.735,0.589,0.029,-0.485,-0.063,-0.034,-0.247,0.03,0.041,-0.051,0.004,-0.045,-0.086,-0.052,-0.013,-0.029,0.013
hay,-0.29,-0.147,0.401,-0.051,3.004,2.337,0.749,1.189,-0.2,-0.297,-0.261,-0.355,0.022,-0.096,-0.104,-0.054,-0.108,-0.175,-0.059,-0.09,0.099,-0.023
cuatro,-0.181,-0.186,-0.218,-0.203,-0.639,-0.788,11.116,-1.16,-0.113,-0.181,-0.202,-0.217,-0.003,-0.023,-0.027,-0.025,-0.013,-0.01,-0.023,-0.039,0.059,-0.017
personas,-0.041,0.081,0.331,0.125,-0.123,-0.286,-1.112,6.789,0.126,0.044,-0.199,0.005,0.029,0.025,-0.014,0.035,0.004,-0.027,0.01,0.003,0.096,0.022
:,0.161,0.13,0.159,0.07,0.109,0.07,0.274,0.33,3.444,-0.259,-0.348,-0.016,-0.027,0.019,-0.013,0.022,0.02,-0.004,0.004,0.029,-0.019,-0.014
mi,0.441,0.47,0.633,0.459,0.486,0.462,0.644,0.555,-1.197,2.639,-0.069,0.515,0.461,0.395,0.49,0.534,0.518,0.63,0.598,0.553,0.648,0.353
padre,-0.094,-0.188,-0.157,-0.148,-0.206,-0.188,-0.053,-0.078,-0.326,-0.188,8.686,-0.102,-0.195,-0.147,-0.132,-0.163,-0.134,-0.133,-0.145,-0.124,-0.102,-0.142
",",0.122,0.099,0.117,0.098,0.148,0.147,0.165,0.154,-0.246,-0.256,-0.085,1.723,0.124,0.152,0.164,0.144,0.172,0.183,0.171,0.163,0.19,0.078
mi,-0.252,-0.194,-0.182,-0.239,-0.241,-0.208,-0.188,-0.194,-0.301,-0.199,-0.027,-0.19,1.412,-0.05,0.563,0.656,0.69,0.67,0.666,0.677,0.649,-0.221
madre,-0.257,-0.191,-0.188,-0.215,-0.515,-0.501,-0.496,-0.451,0.631,0.69,1.012,-0.099,-0.511,9.801,-0.884,-0.866,-0.716,-0.809,-0.788,-0.796,-0.861,-0.141
",",0.028,0.005,0.006,-0.002,-0.056,-0.064,-0.031,-0.054,0.148,0.16,0.206,0.121,0.622,1.029,1.36,0.307,0.394,-0.89,-0.172,-0.184,-0.152,0.003
mi,-0.178,-0.073,-0.088,-0.144,-0.143,-0.14,-0.121,-0.135,-0.076,-0.133,-0.128,-0.092,-0.133,-0.169,-0.16,1.762,1.352,-1.214,0.808,0.607,0.674,-0.12
hermano,-0.071,-0.014,-0.027,-0.011,-0.036,-0.043,0.009,-0.015,-0.055,-0.001,0.503,0.108,0.114,0.682,-0.005,-0.437,7.66,-0.517,-0.636,-0.457,-0.545,0.01
y,0.025,0.0,0.061,-0.041,-0.037,-0.067,0.13,-0.032,-0.082,-0.135,-0.065,0.465,-0.178,-0.009,0.341,1.219,1.793,3.463,-0.827,-0.735,-1.044,-0.057
mi,-0.041,0.037,0.004,-0.026,-0.04,0.001,0.027,-0.033,-0.001,0.052,-0.077,-0.018,-0.074,-0.041,-0.043,-0.589,1.978,-0.388,1.536,1.844,-1.481,-0.005
hermana,-0.02,0.044,0.009,-0.007,-0.045,-0.061,-0.005,-0.084,-0.009,0.008,0.341,-0.03,-0.061,0.459,0.039,0.503,2.47,-0.644,-2.051,6.215,-0.795,0.015
.,0.123,0.059,0.109,0.025,-0.078,-0.089,-0.05,-0.077,-0.131,-0.12,-0.105,-0.057,-0.028,0.01,0.004,-0.064,-0.023,-0.016,-0.326,-0.107,2.653,-0.063
