In [1]:
import copy
import os
import shutil

import matplotlib.pyplot as plt
import nlp
import numpy as np
import pandas as pd
import scipy as sp
import shap
import torch
from shap.utils import GenerateLogits
from transformers import AutoModelForSeq2SeqLM, AutoModelWithLMHead, AutoTokenizer

%matplotlib inline

In [2]:
# English -> Spanish translation checkpoint used throughout the notebook.
MODEL_NAME = "Helsinki-NLP/opus-mt-en-es"
# Use the GPU when one is available instead of unconditionally calling .cuda(),
# which raises on CPU-only machines.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
# AutoModelWithLMHead is deprecated in transformers; AutoModelForSeq2SeqLM is the
# encoder-decoder replacement (should resolve to the same Marian model class for this
# checkpoint -- NOTE(review): confirm against the installed transformers version).
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE)

In [3]:
# Wrap the translation model and its tokenizer in a GenerateLogits helper; the
# functions below call its get_output_names / generate_logits methods.
logit_generator_model = GenerateLogits(
    model=model,
    tokenizer=tokenizer,
)

In [4]:
def gen_kwargs(x):
    model.eval()
    input_ids = torch.tensor([tokenizer.encode(x)]).cuda()
    with torch.no_grad():
        # generate input ids for output translation which we aim to explain
        out=model.generate(input_ids)
    target_sentence_ids = out[:,1:-1].cuda()
    output_names = logit_generator_model.get_output_names(target_sentence_ids)
    del out, input_ids
    return {'target_sentence_ids':target_sentence_ids,'output_names':output_names}

In [5]:
# Supplies the additional keyword arguments passed on to the model function `f`, which
# are required in order to get logits conditioned on the translation of the original
# (unmasked) input sentence.
def f_kwargs(x):
    """Return the model kwargs built by ``gen_kwargs`` for source sentence ``x``."""
    return gen_kwargs(x)

In [6]:
def f(x_batch, **kwargs):
    """Model function handed to shap: score each source sentence against a fixed target.

    Args:
        x_batch: iterable of (possibly masked) source sentences supplied by the explainer.
        **kwargs: must contain ``'target_sentence_ids'``; other keys are ignored here.

    Returns:
        numpy array built from the ``logit_generator_model.generate_logits`` result
        for each sentence in ``x_batch``, in order.
    """
    # The same target translation is reused for every perturbed source sentence.
    target_ids = kwargs['target_sentence_ids']
    return np.array([
        logit_generator_model.generate_logits(sentence, target_ids)
        for sentence in x_batch
    ])

In [7]:
# Example function which returns a summary ids 
def example_summarize(x,model,tokenizer):
    print(f"Input: {x}")
    inputs = tokenizer([x], max_length=512, return_tensors='pt',truncation=True)
    input_ids=inputs['input_ids'].cuda()
    summary_ids = model.generate(input_ids).detach().cpu().numpy()
    del input_ids
    summary=[tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids]
    print(f"summary: {summary[0]}")
    return summary_ids[0,1:-1]

In [8]:
# Sanity-check the translation of the example sentence before explaining it.
s = "In this picture, there are four persons: my father, my mother, my brother and my sister."
ids = example_summarize(s, model=model, tokenizer=tokenizer)

Input: In this picture, there are four persons: my father, my mother, my brother and my sister.
summary: En este cuadro, hay cuatro personas: mi padre, mi madre, mi hermano y mi hermana.


In [9]:
# `f` scores (masked) inputs, the tokenizer is passed as the masker argument, and
# `f_kwargs` supplies the fixed target translation for each explained sentence.
explainer = shap.Explainer(
    f,
    tokenizer,
    model_kwargs=f_kwargs,
)

In [10]:
# The explainer is called with a list of inputs, so wrap the single sentence.
sentences_to_explain = [s]
shap_values = explainer(sentences_to_explain)

In [None]:
# Render the attributions for the first (and only) explained sentence.
first_explanation = shap_values[0]
shap.plots.text(first_explanation)