## Visualizing Self-Attention in Transformers

Transformers have revolutionized the way we approach tasks in NLP. At their core lies self-attention, a mechanism that lets the model weigh how much each element of a sequence (each token embedding) should contribute to the representation of every other element. In this notebook, we explore how self-attention works, combining a short theoretical explanation with practical visualizations.



In [3]:
%pip install --upgrade bertviz transformers --quiet

Note: you may need to restart the kernel to use updated packages.


In [59]:
from transformers import AutoTokenizer, AutoModel, utils
from bertviz import head_view, model_view
from bertviz.neuron_view import show
from bertviz.transformers_neuron_view import BertModel

utils.logging.set_verbosity_error()  # Suppress standard warnings

In [60]:
model_type = 'bert'
model_version = 'bert-base-uncased'

### Self-Attention

The intuition behind self-attention is that representing each token as a weighted average of all the token embeddings in the input, rather than as a fixed embedding, lets the model capture how words relate to each other. In practice, these learned weights can reflect the syntactic and contextual structure of the sentence, giving the model a richer, context-dependent representation of each token.
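The weighted-average idea can be sketched in a few lines of NumPy. This is a simplified illustration, not BERT's implementation: it scores token pairs with raw dot products of the embeddings themselves, whereas a real Transformer first projects each embedding into separate query, key, and value vectors. The toy embeddings are made-up numbers.

```python
import numpy as np

def self_attention_average(embeddings: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Return (contextual embeddings, attention weights) for an (n_tokens, d) array.

    Simplification: scores are dot products between the raw embeddings;
    real Transformers use learned query/key/value projections instead.
    """
    d = embeddings.shape[-1]
    scores = embeddings @ embeddings.T / np.sqrt(d)     # (n, n) pairwise similarity
    scores -= scores.max(axis=-1, keepdims=True)        # stabilize the softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)      # each row sums to 1
    contextual = weights @ embeddings                   # weighted average of all tokens
    return contextual, weights

# Toy 4-token "sentence" with 3-dimensional embeddings (illustrative values only).
x = np.array([[1.0, 0.0, 0.0],
              [0.9, 0.1, 0.0],
              [0.0, 1.0, 0.0],
              [0.0, 0.0, 1.0]])
ctx, w = self_attention_average(x)
print(np.round(w, 2))
```

Tokens 0 and 1 have similar embeddings, so each puts more weight on the other than on the unrelated tokens, and their contextual embeddings are pulled toward each other.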

In [61]:
sentence_a = "leaves fall in autumn"
sentence_b = "autumn is marked by colorful foliage"

In [62]:
model = AutoModel.from_pretrained(model_version, output_attentions=True)
tokenizer = AutoTokenizer.from_pretrained(model_version)

In [63]:
inputs = tokenizer(sentence_a, sentence_b, return_tensors='pt')
tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
tokens

['[CLS]',
 'leaves',
 'fall',
 'in',
 'autumn',
 '[SEP]',
 'autumn',
 'is',
 'marked',
 'by',
 'colorful',
 'foliage',
 '[SEP]']
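Besides `input_ids`, the tokenizer returns `token_type_ids`, which mark the segment each token belongs to: 0 for `[CLS]`, sentence A and the first `[SEP]`, and 1 for sentence B and the final `[SEP]`. The sketch below rebuilds that list by hand from the printed tokens, purely to show why the index of the first 1 is where sentence B starts; in the notebook itself the values should be read from `inputs['token_type_ids']`.

```python
tokens = ['[CLS]', 'leaves', 'fall', 'in', 'autumn', '[SEP]',
          'autumn', 'is', 'marked', 'by', 'colorful', 'foliage', '[SEP]']

# Segment A = [CLS] + sentence A + first [SEP]  -> token type 0
# Segment B = sentence B + final [SEP]          -> token type 1
first_sep = tokens.index('[SEP]')
token_type_ids = [0] * (first_sep + 1) + [1] * (len(tokens) - first_sep - 1)

sentence_b_start = token_type_ids.index(1)  # first position belonging to sentence B
print(token_type_ids)            # [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
print(tokens[sentence_b_start])  # autumn
```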

In [53]:
attention = model(**inputs).attentions
len(attention)

12

In [64]:
attention[0].shape

torch.Size([1, 12, 13, 13])
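`attention` is a tuple with one tensor per layer (12 for BERT Base), and each tensor has shape `(batch, heads, query_tokens, key_tokens)`; here that is 1 sentence pair, 12 heads and 13 × 13 token pairs. The sketch below mimics that layout with a synthetic, randomly generated stand-in called `fake_attention` (so it does not overwrite the real `attention`) and shows how to index one head and read off which token each token attends to most. With the real tensors, `attention[layer][0, head]` selects the same 13 × 13 matrix.

```python
import numpy as np

tokens = ['[CLS]', 'leaves', 'fall', 'in', 'autumn', '[SEP]',
          'autumn', 'is', 'marked', 'by', 'colorful', 'foliage', '[SEP]']
n_layers, n_heads, n_tokens = 12, 12, len(tokens)

# Synthetic stand-in for `model(**inputs).attentions`: random scores pushed
# through a softmax so every row is a probability distribution.
rng = np.random.default_rng(0)
logits = rng.normal(size=(n_layers, 1, n_heads, n_tokens, n_tokens))
probs = np.exp(logits)
probs /= probs.sum(axis=-1, keepdims=True)
fake_attention = tuple(probs[i] for i in range(n_layers))  # 12 x (1, 12, 13, 13)

layer, head = 0, 0
weights = fake_attention[layer][0, head]  # (13, 13): row i = how token i attends
print(weights.shape)                      # (13, 13)
print(np.allclose(weights.sum(axis=-1), 1.0))  # True: rows sum to one

# Strongest attention target for each token (meaningless here, since data is random).
for i, tok in enumerate(tokens):
    j = int(weights[i].argmax())
    print(f"{tok:>10} -> {tokens[j]}")
```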

----

### Multi-head Attention

BERT actually learns multiple attention mechanisms, called heads, which operate in parallel. Because the heads do not share parameters, each head can learn a different attention pattern. The version of BERT considered here, BERT Base, has 12 layers with 12 heads each, for a total of 12 × 12 = 144 distinct attention mechanisms. The model view visualizes attention in all of these heads at once.
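A minimal NumPy sketch of one multi-head attention layer, assuming BERT Base sizes (hidden size 768, 12 heads of 64 dimensions each) and randomly initialized weights rather than BERT's trained ones. Each head uses its own 64-column slice of the query, key and value projections, which is why the heads do not share parameters; biases, masking and the rest of the Transformer block are omitted.

```python
import numpy as np

def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    z = z - z.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(x, w_q, w_k, w_v, w_o, n_heads):
    """x: (n, d_model). Returns (output (n, d_model), weights (heads, n, n))."""
    n, d_model = x.shape
    d_head = d_model // n_heads
    # Project, then split the feature dimension into heads -> (heads, n, d_head).
    q = (x @ w_q).reshape(n, n_heads, d_head).transpose(1, 0, 2)
    k = (x @ w_k).reshape(n, n_heads, d_head).transpose(1, 0, 2)
    v = (x @ w_v).reshape(n, n_heads, d_head).transpose(1, 0, 2)
    weights = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(d_head))  # (heads, n, n)
    out = weights @ v                                               # (heads, n, d_head)
    out = out.transpose(1, 0, 2).reshape(n, d_model)                # concatenate heads
    return out @ w_o, weights

rng = np.random.default_rng(0)
d_model, n_heads, n = 768, 12, 13  # BERT Base sizes; 13 tokens like the sentence pair
x = rng.normal(size=(n, d_model))
w_q, w_k, w_v, w_o = (rng.normal(scale=d_model ** -0.5, size=(d_model, d_model))
                      for _ in range(4))
out, weights = multi_head_attention(x, w_q, w_k, w_v, w_o, n_heads)
print(out.shape, weights.shape)  # (13, 768) (12, 13, 13)
```

The per-head weight array has the same `(heads, query_tokens, key_tokens)` layout as one layer of `attention` (minus the batch dimension), and it is these 12 matrices per layer that the model view draws.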

In [65]:
sentence_b_start = inputs['token_type_ids'][0].tolist().index(1)
model_view(attention, tokens, sentence_b_start)


Some patterns that can be observed:

- Attention to delimiter tokens.
- Attention to the next word.
- Attention to the previous word.
- Attention to identical or related words.
- Attention to identical or related words in the other sentence.
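Some of these patterns can be quantified with simple heuristics on a single head's attention matrix. The two helpers below are hypothetical, written for illustration rather than taken from bertviz: one averages the weight each token puts on the token `offset` positions away (`+1` for the next word, `-1` for the previous word), the other the weight placed on `[CLS]`/`[SEP]`. They are checked on a synthetic "next word" head; to try them on BERT, pass `attention[layer][0, head].detach().numpy()` and `tokens`.

```python
import numpy as np

def offset_attention_score(weights: np.ndarray, offset: int) -> float:
    """Mean weight each query token puts on the key `offset` positions away.

    weights: (n, n) attention matrix for one head, rows = querying tokens.
    offset=+1 -> attention to the next word; offset=-1 -> to the previous word.
    """
    return float(np.diagonal(weights, offset=offset).mean())

def delimiter_attention_score(weights: np.ndarray, tokens: list[str]) -> float:
    """Mean total weight each query token puts on [CLS]/[SEP] positions."""
    cols = [j for j, tok in enumerate(tokens) if tok in ("[CLS]", "[SEP]")]
    return float(weights[:, cols].sum(axis=1).mean())

# Synthetic "next word" head: every token attends fully to its successor,
# and the last token (which has no successor) attends to itself.
n = 5
next_word_head = np.zeros((n, n))
next_word_head[np.arange(n - 1), np.arange(1, n)] = 1.0
next_word_head[n - 1, n - 1] = 1.0

print(offset_attention_score(next_word_head, +1))  # 1.0
print(offset_attention_score(next_word_head, -1))  # 0.0
print(delimiter_attention_score(next_word_head, ["[CLS]", "a", "b", "c", "[SEP]"]))  # 0.4
```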



----

### Head View

In [79]:
head_view(attention, tokens, sentence_b_start)


----

### Neuron View

We can visualize how attention weights are computed from query and key vectors using the neuron view.

- Query q: the query vector q encodes the word on the left that is paying attention, i.e. the one that is “querying” the other words.
- Key k: the key vector k encodes the word on the right to which attention is being paid. Together, the query and key vectors determine a compatibility score between the two words.
- q×k (elementwise): the elementwise product of the selected word's query vector with each of the key vectors. This is a precursor to the dot product (the sum of the elementwise product) and is shown because it reveals how individual elements of the query and key vectors contribute to the dot product.
- q·k: the dot product of the selected query vector with each of the key vectors, scaled by 1/√d_k, where d_k is the size of each head's query and key vectors (64 in BERT Base). This is the unnormalized attention score.
- Softmax: the softmax of the scaled dot products, which normalizes the attention scores so that they are positive and sum to one.
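The computation behind the neuron view's columns can be reproduced in a few lines of NumPy. This sketch uses random query and key vectors (not BERT's) with d_k = 64 as in BERT Base and a hypothetical set of 5 key tokens; it shows that summing the elementwise product gives the dot product, which is then scaled and passed through a softmax.

```python
import numpy as np

rng = np.random.default_rng(0)
d_k = 64     # per-head vector size in BERT Base (768 hidden units / 12 heads)
n_keys = 5   # hypothetical number of right-hand tokens

q = rng.normal(size=d_k)            # query vector of the selected (left-hand) token
K = rng.normal(size=(n_keys, d_k))  # one key vector per right-hand token

elementwise = q * K                              # "q×k": (n_keys, d_k), q broadcast over rows
scores = elementwise.sum(axis=1) / np.sqrt(d_k)  # "q·k": scaled dot product per key
probs = np.exp(scores - scores.max())            # subtract max for numerical stability
probs /= probs.sum()                             # "Softmax": positive, sums to one

print(np.round(probs, 3))
```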

In [80]:
text = "astronomers discovered a new planet"

In [81]:
model = BertModel.from_pretrained(model_version)  # bertviz's BertModel exposes query/key vectors; the neuron view supports only BERT, GPT-2, and RoBERTa
show(model, model_type, tokenizer, text, display_mode="light", layer=0, head=8)
