# Visualize Attention in Transformer

Visualizing the attention weights learned by an RNN-based Seq2Seq model with attention gives some degree of interpretability, so we can try the same with a Transformer.

Recall that there are 3 types of Multihead Attention in Transformer:
- Self-Attention, in Encoder
- Masked Self-Attention, in Decoder
- Cross Attention from Decoder to Encoder
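
To make the three cases concrete, here is a minimal sketch in plain Python of how their attention weights differ. It is an illustration under simplifying assumptions, not what MarianMT actually runs: a single head, no learned query/key projections, and made-up 2-dimensional token vectors. The function name `attention_weights` is hypothetical.

```python
import math


def attention_weights(queries, keys, causal=False):
    """Return softmax(Q K^T / sqrt(d)) as a list of rows, one row per query."""
    d = len(keys[0])
    weights = []
    for i, q in enumerate(queries):
        # Scaled dot product between this query and every key.
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
        if causal:
            # Masked self-attention: position i must not look at positions > i.
            scores = [s if j <= i else float("-inf") for j, s in enumerate(scores)]
        # Numerically stable softmax over the keys.
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights


# Toy token vectors, invented for illustration only.
encoder_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 source tokens
decoder_states = [[0.5, 0.5], [1.0, -1.0]]             # 2 target tokens

# Self-attention in the encoder: source tokens attend to source tokens (3 x 3).
encoder_self = attention_weights(encoder_states, encoder_states)
# Masked self-attention in the decoder: target tokens attend to earlier targets (2 x 2).
decoder_self = attention_weights(decoder_states, decoder_states, causal=True)
# Cross attention: target tokens attend to source tokens (2 x 3).
cross = attention_weights(decoder_states, encoder_states)
```

Each row is one query token's distribution over the tokens it attends to. Row-wise weights of this kind are what the visualizations later in this tutorial display.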

In this tutorial, you'll use BertViz and Hugging Face's Transformers library to visualize these 3 types of attention.

But first, let's install the dependencies.

## 0. Setup

In [1]:
%%capture
!pip install bertviz transformers

In [2]:
from transformers import MarianTokenizer, MarianMTModel
import torch
from bertviz import head_view, model_view

We will use a pretrained English-to-Vietnamese translation model from [Helsinki-NLP](https://huggingface.co/Helsinki-NLP).

In [3]:
src = "en"  # source language
trg = "vi"  # target language

model_name = f"Helsinki-NLP/opus-mt-{src}-{trg}"
model = MarianMTModel.from_pretrained(model_name, output_attentions=True)
tokenizer = MarianTokenizer.from_pretrained(model_name)



## 1. Translate with a pretrained MarianMT model

In [4]:
sample_text = "When I was a young boy, my father, took me into the city to see a marching band"
tokenized_inputs = tokenizer(sample_text, return_tensors="pt")
print(tokenized_inputs)

{'input_ids': tensor([[  529,     9,    64,    15,  1144,  1256,     4,    85,   712,     4,
           887,    52,   344,     7,  1100,    10,   239,    15, 19579,  6313,
             0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}


In [5]:
generated_ids = model.generate(**tokenized_inputs)
print(generated_ids)

tensor([[53684,   422,    20,   156,    14,    27,   158,   559,     4,   501,
            20,     4,   474,    20,   124,   166,  1059,    77,   271,    27,
          1980, 14042,  1777,     0]])


In [6]:
prediction = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(prediction)

['Khi tôi còn là một cậu bé, cha tôi, đưa tôi vào thành phố để xem một đoàn diễu binh']


## 2. Visualize Attention with BertViz

In [7]:
input_text = "Time is a valuable thing."

tokenized_input = tokenizer(
    input_text,
    return_tensors="pt")

encoder_input_ids = tokenized_input.input_ids
generated_ids = model.generate(**tokenized_input)

In [8]:
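# Run a full forward pass so the output carries encoder, decoder, and cross attentions.
with torch.no_grad():
    output = model(input_ids=encoder_input_ids, decoder_input_ids=generated_ids)
# Map token IDs back to subword strings to label the visualizations.
encoder_tokens = tokenizer.convert_ids_to_tokens(encoder_input_ids[0])
decoder_tokens = tokenizer.convert_ids_to_tokens(generated_ids[0])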
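
Before plotting, it can help to check what the forward pass returned. With `output_attentions=True`, each of `output.encoder_attentions`, `output.decoder_attentions`, and `output.cross_attentions` is a tuple with one tensor per layer, shaped `(batch_size, num_heads, query_len, key_len)`. The helper below is a hypothetical convenience function, not part of BertViz or Transformers. It only assumes that each element has a `.shape` attribute.

```python
def describe_attentions(attentions):
    """Return the number of layers and the shape of each layer's attention tensor.

    Hypothetical helper: `attentions` is assumed to be a sequence of
    tensor-like objects, one per layer, each exposing `.shape`.
    """
    return len(attentions), [tuple(layer.shape) for layer in attentions]


# Usage with the notebook's forward pass (assumes `output` from the cell above):
# num_layers, shapes = describe_attentions(output.cross_attentions)
# For cross attention, query_len is the number of decoder tokens
# and key_len is the number of encoder tokens.
```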

In [9]:
head_view(attention=output.encoder_attentions, tokens=encoder_tokens)


In [10]:
head_view(attention=output.decoder_attentions, tokens=decoder_tokens)


In [11]:
model_view(
    cross_attention=output.cross_attentions,
    encoder_attention=output.encoder_attentions,
    decoder_attention=output.decoder_attentions,
    encoder_tokens=encoder_tokens, 
    decoder_tokens=decoder_tokens
)


## 3. Predefined translation pair

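What if the model produces a wrong target sentence? You can still inspect its attention for a translation pair you define yourself by passing the target sentence to the decoder directly. Complete the code below to do this.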

Note that MarianMT's tokenizer holds two independent SentencePiece tokenizers, one for the source language and one for the target language. To switch the tokenizer to the target language, you can use:

```python
with tokenizer.as_target_tokenizer():
    decoder_input_ids = tokenizer(
        "Đây là một câu trong bằng tiếng Việt", 
        return_tensors="pt").input_ids
```
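
In recent releases of Transformers, `as_target_tokenizer()` is deprecated in favor of the `text_target` argument of the tokenizer call. Which form works depends on your installed version, so treat the following as a sketch rather than the required approach. The wrapper name `encode_target` is hypothetical.

```python
def encode_target(tokenizer, target_sentence):
    """Tokenize a target-language sentence with the target vocabulary.

    Assumes a Transformers release whose tokenizers accept `text_target`.
    On older releases, use the `as_target_tokenizer()` context manager instead.
    """
    encoded = tokenizer(text_target=target_sentence, return_tensors="pt")
    return encoded.input_ids


# Usage with the notebook's tokenizer:
# decoder_input_ids = encode_target(tokenizer, "Đây là một câu trong bằng tiếng Việt")
```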

In [13]:
source_sentence = "Students learn biology"
target_sentence = "Học sinh học sinh học"


############# YOUR CODE HERE #################
# Tokenize source_sentence with the source vocabulary and
# target_sentence with the target vocabulary.
encoder_input_ids = None
decoder_input_ids = None

##############################################

output = model(input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids) 
encoder_tokens = tokenizer.convert_ids_to_tokens(encoder_input_ids[0])
decoder_tokens = tokenizer.convert_ids_to_tokens(decoder_input_ids[0])

model_view(
    cross_attention=output.cross_attentions,
    encoder_attention=output.encoder_attentions,
    decoder_attention=output.decoder_attentions,
    encoder_tokens=encoder_tokens, 
    decoder_tokens=decoder_tokens
)
