# Text to Text Explanation: Machine Translation Example

This notebook demonstrates how to generate model explanations for a text-to-text scenario using pretrained transformer models for machine translation. In this demo, we showcase explanations for 2 different models provided by Hugging Face: English to Spanish (https://huggingface.co/Helsinki-NLP/opus-mt-en-es) and English to French (https://huggingface.co/Helsinki-NLP/opus-mt-en-fr).

In [1]:
import numpy as np
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import shap
import torch

## English to Spanish model

In [2]:
# load the model and tokenizer
en_es_checkpoint = "Helsinki-NLP/opus-mt-en-es"
tokenizer = AutoTokenizer.from_pretrained(en_es_checkpoint)

# Use the GPU when one is available, but fall back to the CPU so the notebook
# still runs on machines without CUDA (a bare `.cuda()` raises there).
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForSeq2SeqLM.from_pretrained(en_es_checkpoint).to(device)

# define the input sentences we want to translate
data = [
    "Transformers have rapidly become the model of choice for NLP problems, replacing older recurrent neural network models"
]

### Explain the model's predictions

In [9]:
# Wrap the translation model in a SHAP explainer; the tokenizer is passed as
# the masker, which is what splits each input string into the pieces that
# receive attributions.
explainer = shap.Explainer(model=model, masker=tokenizer)

# The explainer is called like a model on the raw input strings.
# `fixed_context=1` is forwarded to the explainer call as-is (see the SHAP
# explainer docs for its exact meaning).
shap_values = explainer(data, fixed_context=1)

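### Visualize SHAP explanations

Note: the saved output below shows a French target sentence ("Les transformateurs sont rapidement devenus ..."), which does not match the English-to-Spanish model loaded above. It appears to be stale output; re-run the notebook to regenerate it.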

In [10]:
# Render the English->Spanish attributions with SHAP's text plot.
shap.plots.text(shap_values=shap_values)

Unnamed: 0_level_0,Trans,former,s,have,rapidly,become,the,model,of,choice,for,N,LP,problems,",",replacing older,recurrent,ne,ural,network,models,Unnamed: 22_level_0
Les,2.309,-0.521,3.223,1.268,0.195,0.043,-0.291,-0.408,-0.081,-0.034,0.026,-0.046,0.081,-0.148,-0.2,-0.509,-0.276,-0.011,0.293,0.149,-0.138,-0.18
transformateurs,4.796,7.167,0.666,-0.249,0.184,-0.001,0.001,-0.158,-0.041,0.103,-0.009,0.216,0.118,-0.145,-0.088,0.217,0.096,0.084,0.077,0.123,0.17,-0.012
sont,-0.195,0.622,0.093,1.408,0.016,3.529,-0.806,-0.825,-0.011,-0.034,0.041,-0.053,-0.034,0.042,0.028,0.019,0.01,-0.072,-0.103,-0.15,-0.199,0.047
rapidement,0.619,0.792,0.458,0.584,9.143,0.63,-1.139,-1.234,0.187,0.045,0.25,0.099,0.069,0.099,0.044,0.268,0.097,0.061,0.098,0.023,-0.051,0.061
devenus,-0.288,-0.281,-0.481,1.327,0.342,6.992,0.055,0.062,0.024,0.096,-0.076,-0.072,0.039,-0.043,-0.116,-0.05,-0.039,0.004,0.238,0.034,0.044,-0.015
le,0.153,0.164,0.205,0.077,-0.315,1.594,3.844,0.652,-0.15,0.045,0.039,0.156,0.054,-0.057,0.001,0.069,0.037,-0.0,-0.072,-0.095,-0.337,0.2
modèle,0.106,0.186,0.089,-0.018,0.1,0.11,0.251,6.34,0.104,0.1,0.056,0.132,0.114,-0.045,-0.002,0.096,0.065,0.034,0.098,0.004,0.062,0.084
de,-0.379,-0.354,-0.361,-0.361,-0.057,0.047,0.052,1.618,5.402,0.202,-0.29,-0.349,-0.555,-0.513,-0.573,0.024,0.004,-0.014,-0.042,-0.017,-0.031,0.053
choix,-0.54,-0.471,-0.457,-0.516,1.763,1.349,1.218,1.729,0.321,8.713,-0.845,-0.463,-0.917,-0.747,-0.942,-0.107,-0.046,-0.054,0.104,-0.186,-0.17,0.02
pour,0.006,0.021,-0.007,0.021,0.054,0.211,0.142,0.3,-1.013,1.084,6.293,-1.449,-0.405,-0.216,-0.524,-0.014,0.03,0.012,-0.035,0.003,-0.105,0.173
les,-0.121,-0.064,-0.193,-0.184,-0.023,-0.032,-0.139,-0.014,-1.31,0.329,0.492,0.434,1.168,2.656,-0.61,0.063,0.034,-0.009,0.258,0.092,0.006,-0.023
problèmes,-0.044,-0.058,-0.115,-0.182,0.016,-0.029,-0.166,0.164,-0.618,-0.075,0.494,0.093,-1.499,11.208,-0.43,0.044,0.094,0.089,0.314,0.127,0.01,-0.021
de,-0.055,-0.124,-0.027,-0.016,0.056,-0.008,-0.149,-0.002,0.085,-0.172,-1.001,0.643,0.991,0.151,0.02,-0.144,0.063,0.106,0.304,0.11,0.397,-0.327
N,0.169,0.209,-0.007,0.11,0.122,0.146,0.154,0.17,-0.075,0.407,-0.611,7.011,-0.534,1.475,-0.018,0.222,0.113,-0.015,0.022,0.234,-0.49,0.509
LP,-0.15,-0.207,-0.148,-0.158,-0.157,-0.166,-0.152,-0.19,-1.388,-1.352,-1.519,5.141,14.464,-1.474,-1.271,-0.317,-0.109,-0.102,0.104,-0.113,-0.187,-0.07
",",-0.013,0.184,0.07,0.149,0.34,0.265,-0.045,0.262,-0.3,0.236,0.343,-0.458,-1.473,1.005,4.409,-0.071,-0.037,0.049,0.138,-0.029,-0.312,0.217
remplaçant,-0.273,-0.252,-0.247,-0.14,-0.076,-0.028,-0.274,-0.172,-0.103,0.022,0.042,-0.08,-0.646,-0.206,3.194,12.727,-0.351,-0.522,-0.353,-0.559,-0.354,-0.712
les,-0.12,-0.08,-0.059,-0.072,0.037,-0.015,-0.062,0.395,0.002,0.151,-0.067,-0.21,-0.138,0.156,0.247,1.073,-0.051,-0.127,-0.004,-0.098,2.141,-0.499
anciens,0.019,-0.049,-0.001,-0.05,-0.033,0.02,-0.005,0.27,-0.075,0.021,-0.067,-0.064,0.063,0.016,0.06,7.321,-0.022,-0.189,-0.306,0.172,0.473,-0.273
modèles,-0.077,0.169,-0.016,-0.173,0.095,0.013,-0.154,0.029,-0.097,0.114,0.168,-0.292,-0.129,0.642,0.188,-0.99,1.145,-0.165,0.109,-1.317,10.428,-1.37
de,-0.056,-0.073,-0.077,-0.049,-0.013,-0.006,-0.059,-0.189,0.029,0.003,0.078,-0.0,0.046,0.018,0.039,-0.097,-0.265,0.034,-0.895,2.908,0.745,-0.49
réseaux,0.013,-0.056,-0.064,-0.052,0.018,-0.027,0.018,-0.086,0.076,-0.16,-0.031,0.174,0.015,-0.194,0.052,-0.817,0.077,0.454,-1.276,7.679,1.272,-0.772
neuro,0.007,0.022,0.014,0.107,0.1,0.096,0.049,-0.005,-0.003,-0.012,0.029,0.045,-0.034,0.058,-0.07,-1.155,-0.298,5.612,8.263,-0.342,-0.582,-0.568
naux,-0.038,-0.045,-0.048,-0.015,-0.046,-0.106,-0.014,-0.032,-0.037,-0.051,-0.084,0.12,0.028,-0.107,-0.05,-0.698,-0.575,1.715,5.921,-1.16,-0.302,-0.613
récurrent,-0.127,-0.049,-0.077,-0.119,0.094,0.048,-0.085,0.055,-0.122,0.091,0.041,-0.181,0.078,0.008,0.102,4.513,8.702,-2.844,1.051,0.765,0.882,-1.101
s,-0.01,0.002,-0.007,-0.011,0.003,-0.009,-0.003,0.012,-0.005,-0.006,-0.005,0.001,0.017,0.002,-0.005,-0.079,-0.036,-0.005,0.105,0.164,0.205,-0.113


## English to French model

In [11]:
# Load the English->French tokenizer and model, replacing the previous ones.
en_fr_checkpoint = "Helsinki-NLP/opus-mt-en-fr"
tokenizer = AutoTokenizer.from_pretrained(en_fr_checkpoint)

# Prefer the GPU, but fall back to the CPU when CUDA is unavailable instead of
# crashing on a hard-coded `.cuda()`.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForSeq2SeqLM.from_pretrained(en_fr_checkpoint).to(device)

In [12]:
# Rebuild the explainer around the newly loaded English->French model, again
# using its tokenizer as the masker.
explainer = shap.Explainer(model=model, masker=tokenizer)

In [13]:
# Explain the same input sentences with the English->French explainer.
# Unlike the English->Spanish cell, no `fixed_context` argument is passed here,
# so the explainer's default call options are used.
shap_values = explainer(data)

In [14]:
# Render the English->French attributions with SHAP's text plot.
shap.plots.text(shap_values=shap_values)

Unnamed: 0_level_0,Trans former s have,rapidly become the model,of choice for N,LP problems,",",replacing older,recurrent ne,ural network,models
Les,6.444,-0.428,-0.511,0.289,0.054,-0.955,-0.312,0.113,0.051
transformateurs,12.165,0.195,0.831,-0.67,-0.213,0.87,0.122,-0.042,0.055
sont,1.463,2.382,0.033,-0.047,-0.0,-0.081,-0.068,-0.144,-0.165
rapidement,2.078,7.607,0.215,0.394,0.093,0.426,0.193,0.1,0.097
devenus,0.982,6.918,-0.299,0.038,0.008,0.352,-0.185,-0.023,0.007
le,0.392,5.862,0.022,-0.036,0.01,0.007,-0.042,0.028,0.02
modèle,0.125,7.002,0.389,-0.137,0.047,-0.035,-0.082,0.163,0.593
de,-0.022,0.387,3.329,-0.13,-0.063,-0.02,-0.114,0.129,0.007
choix,-0.42,3.933,6.109,-0.179,-0.117,0.093,-0.003,-0.348,-0.311
pour,0.247,0.334,3.965,-0.088,-0.065,0.029,0.056,0.141,-0.039
les,-0.105,-0.21,0.513,2.212,0.048,-0.022,0.043,0.259,0.072
problèmes,0.317,-0.29,1.43,6.483,0.871,0.059,0.37,0.229,-0.051
de,-0.316,-0.055,-0.128,0.805,0.164,-0.05,0.17,0.203,0.109
N,0.271,0.074,6.024,1.467,0.437,0.055,0.764,0.224,0.006
LP,-0.285,-0.276,1.518,9.614,0.988,-0.063,-0.587,-0.169,-0.262
",",0.311,0.705,0.686,1.201,1.939,0.044,0.025,0.016,0.003
remplaçant,-0.123,0.948,0.565,0.524,0.769,7.441,0.36,-0.032,0.187
les,-0.226,0.453,-0.04,0.224,0.126,1.174,-0.165,0.175,0.878
anciens,0.149,0.219,-0.184,0.192,0.056,6.71,0.126,-0.146,0.179
modèles,-0.05,0.133,0.14,0.093,0.19,0.551,0.968,0.614,5.679
de,-0.305,-0.245,0.152,0.128,0.043,-0.053,-0.092,1.213,0.792
réseaux,-0.038,-0.02,-0.21,0.22,-0.043,-0.335,0.324,4.518,1.894
neuro,-0.048,-0.058,0.344,0.019,0.058,0.257,4.493,5.286,0.984
naux,-0.076,-0.337,0.053,-0.028,-0.049,-0.373,1.163,3.022,0.39
récurrent,-0.222,0.118,-0.058,0.08,0.013,3.32,7.238,1.104,0.13
s,-0.027,-0.003,-0.03,0.038,0.013,-0.009,0.015,0.14,0.079
