In [1]:
import copy
import os
import shutil

import matplotlib.pyplot as plt
import nlp
import numpy as np
import pandas as pd
import scipy as sp
import shap
import torch
from shap.utils import GenerateLogits
from transformers import AutoModelForSeq2SeqLM, AutoModelWithLMHead, AutoTokenizer

%matplotlib inline

In [19]:
# Load the Marian English->Spanish translation checkpoint and its tokenizer.
# Marian is an encoder-decoder model, so it is loaded through
# AutoModelForSeq2SeqLM instead of the deprecated AutoModelWithLMHead.
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-es", use_fast=True)
model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-es")

The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.


In [20]:
# Display the loaded model's module tree to inspect its architecture
model

MarianMTModel(
  (model): BartModel(
    (shared): Embedding(65001, 512, padding_idx=65000)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(65001, 512, padding_idx=65000)
      (embed_positions): SinusoidalPositionalEmbedding(512, 512)
      (layers): ModuleList(
        (0): EncoderLayer(
          (self_attn): Attention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
        (1): EncoderLayer

In [3]:
# Wrap the model and tokenizer in a GenerateLogits helper; it is used below to
# name the output tokens and to produce conditional logits for a fixed target.
logit_generator_model = GenerateLogits(
    model=model,
    tokenizer=tokenizer,
)

In [4]:
def gen_kwargs(x):
    """Build the extra keyword arguments needed to explain the translation of ``x``.

    The model first translates ``x``; the generated ids, with the first and
    last generated position stripped, become the fixed target sentence whose
    conditional logits are explained.

    Args:
        x: Source sentence, passed to ``tokenizer.encode``.

    Returns:
        dict with two entries:
        ``'target_sentence_ids'`` -- the generated ids ``out[:, 1:-1]``, on the
        same device as the model;
        ``'output_names'`` -- the value returned by
        ``logit_generator_model.get_output_names`` for those ids.
    """
    model.eval()
    # Follow the model's device instead of hard-coding CUDA so this also runs
    # on CPU-only machines; on a GPU-resident model the behavior is unchanged.
    device = next(model.parameters()).device
    input_ids = torch.tensor([tokenizer.encode(x)], device=device)
    with torch.no_grad():
        # generate input ids for the output translation which we aim to explain
        out = model.generate(input_ids)
    # presumably drops the decoder start token and the trailing EOS token -- TODO confirm
    target_sentence_ids = out[:, 1:-1]
    output_names = logit_generator_model.get_output_names(target_sentence_ids)
    return {'target_sentence_ids': target_sentence_ids, 'output_names': output_names}

In [5]:
# Supplies the additional arguments passed on to the model function, which are
# required in order to get the conditional logits corresponding to the
# translation of the original input sentence.
def f_kwargs(x):
    """Return the model keyword arguments for source sentence ``x``.

    Thin wrapper that delegates to ``gen_kwargs`` and returns its result unchanged.
    """
    return gen_kwargs(x)

In [6]:
def f(x_batch, **kwargs):
    """Model function handed to the SHAP explainer.

    Args:
        x_batch: Iterable of source sentences.
        **kwargs: Must contain ``'target_sentence_ids'``, the target sentence
            for which conditional logits are generated.

    Returns:
        ``np.ndarray`` built from one ``generate_logits`` result per entry of
        ``x_batch``, in the same order.

    Raises:
        KeyError: If ``'target_sentence_ids'`` is not supplied.
    """
    # The target is fixed for the whole batch; only the source sentence varies.
    target_ids = kwargs['target_sentence_ids']
    logits_per_sentence = [
        logit_generator_model.generate_logits(sentence, target_ids)
        for sentence in x_batch
    ]
    return np.array(logits_per_sentence)

In [7]:
# Example function which prints the generated output for a sentence and returns its ids
def example_summarize(x, model, tokenizer):
    """Generate output for ``x`` with ``model``, print it, and return the generated ids.

    Args:
        x: Input sentence.
        model: Model exposing ``parameters()`` and ``generate(input_ids)``.
        tokenizer: Callable tokenizer that returns a mapping with an
            ``'input_ids'`` tensor and provides ``decode``.

    Returns:
        NumPy array holding the generated ids of the first sequence with its
        first and last position stripped (``summary_ids[0, 1:-1]``).
    """
    print(f"Input: {x}")
    inputs = tokenizer([x], max_length=512, return_tensors='pt', truncation=True)
    # Send the inputs to wherever the given model lives rather than assuming
    # CUDA, so the helper also works on CPU-only machines.
    device = next(model.parameters()).device
    input_ids = inputs['input_ids'].to(device)
    summary_ids = model.generate(input_ids).detach().cpu().numpy()
    summary = [
        tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        for g in summary_ids
    ]
    print(f"summary: {summary[0]}")
    return summary_ids[0, 1:-1]

In [12]:
# English source sentence to explain; print its translation first
s = "The doctor cared for its patient."
ids = example_summarize(s, model, tokenizer)

Input: The doctor cared for its patient.
summary: El médico cuidó a su paciente.


In [13]:
# Build a SHAP explainer around the model function `f`, passing the tokenizer as its second argument
explainer = shap.Explainer(f, tokenizer)

explainers.Partition is still in an alpha state, so use with caution...


In [14]:
# Explain the single sentence; `f_kwargs` supplies the per-sentence model kwargs
shap_values = explainer(
    [s],
    model_kwargs=f_kwargs,
)

[ -6.10614284  -6.26744355 -10.34296496  -2.82822171   1.7401661
  -1.19378975  -2.04853493  -5.72463461  -0.3384375 ] (9,)
[ 1.31737219  0.55470154 -1.41519825  3.02896184  2.35388625 -0.32803499
  1.89735776  2.36388507  2.1269851 ] (9,)
[0 1 2 3 4 5 6 7 8] (9,)


In [15]:
# Render the text plot for the first (and only) explained sentence
shap.plots.text(shap_values[0])

invalid value encountered in double_scalars


Unnamed: 0_level_0,The,doctor,cared,for,its,patient,.,Unnamed: 8_level_0
El,2.483,3.542,0.417,0.059,0.326,0.597,-0.001,0.0
médico,-0.299,5.482,0.796,0.48,0.001,0.675,-0.313,0.0
cu,-0.654,0.078,9.54,-0.868,0.376,0.548,-0.092,0.0
id,-0.315,-0.357,5.637,0.214,-0.03,0.425,0.283,0.0
ó,0.089,0.041,0.668,-0.29,-0.021,0.203,-0.075,0.0
a,-0.349,-0.558,0.192,-0.202,-0.64,3.467,-1.045,0.0
su,0.008,0.001,0.415,0.116,2.527,1.223,-0.344,0.0
paciente,0.194,0.044,0.621,0.461,-0.423,7.036,0.156,0.0
.,0.103,0.081,0.187,-0.001,-0.313,-0.042,2.451,0.0


In [16]:
# Same sentence with "nurse" in place of "doctor", explained with a freshly built explainer
s_ = "The nurse cared for its patient."
explainer_ = shap.Explainer(f, tokenizer)
shap_values_ = explainer_(
    [s_],
    model_kwargs=f_kwargs,
)

explainers.Partition is still in an alpha state, so use with caution...


[-5.7834518  -7.54599042 -9.88510135 -3.81695614  2.4835077  -1.29051606
 -1.57667796 -6.31239202 -0.1745286 ] (9,)
[ 0.80590728  2.71151152 -1.07809089  3.332745    2.34999685 -0.31970506
  1.91842388  2.49270899  2.11823395] (9,)
[0 1 2 3 4 5 6 7 8] (9,)


In [17]:
# Render the text plot for the first (and only) explained sentence of the second example
shap.plots.text(shap_values_[0])

invalid value encountered in double_scalars


Unnamed: 0_level_0,The,nurse,cared,for,its,patient,.,Unnamed: 8_level_0
La,2.038,3.448,0.344,0.078,0.24,0.131,0.31,0.0
enfermera,-0.152,9.077,0.513,0.514,-0.013,0.115,0.204,0.0
cu,-0.516,0.709,9.444,-1.041,0.262,0.091,-0.143,0.0
id,0.59,0.011,5.691,0.753,-0.184,0.095,0.193,0.0
ó,-0.12,-0.016,0.628,-0.306,-0.13,0.049,-0.238,0.0
a,-0.22,-0.312,-0.047,-0.551,-0.643,3.847,-1.104,0.0
su,-0.026,0.191,0.3,-0.013,2.175,1.197,-0.329,0.0
paciente,0.296,0.431,0.406,0.241,-0.244,7.589,0.085,0.0
.,0.09,0.103,0.159,0.006,-0.329,-0.074,2.339,0.0
