# BERT Attention Visualization

This notebook demonstrates how to extract and visualize the self-attention weights of a pre-trained BERT model (`bert-base-uncased`) using the `transformers` and `bertviz` libraries.


## Understanding the Visualization

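The interactive `head_view` visualization produced at the end of this notebook shows:
- **Layers**: the 12 Transformer encoder layers of `bert-base-uncased`, each taking the previous layer's output as its input
- **Heads**: 12 attention heads per layer, each of which can learn to track a different kind of relationship between tokens
- **Attention patterns**: lines connecting each token on the left to the tokens it attends to on the right, with darker lines indicating higher attention weights
- **Interactive exploration**: select a layer from the Layer drop-down (layers are zero-indexed), click the colored tiles to toggle individual heads, and hover over a token to filter attention from or to it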

### Key Concepts:
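- **Attention**: for each token, a set of weights over all tokens in the sequence (summing to 1) that determines how much of each token's representation it mixes into its own
- **Multi-head attention**: each layer runs several attention heads in parallel, and different heads can learn different types of relationships (syntactic, semantic, positional, etc.)
- **Layer depth**: deeper layers tend to capture more complex, abstract relationships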
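To make the attention and multi-head ideas concrete, here is a minimal NumPy sketch of scaled dot-product attention weights, softmax(QKᵀ/√d_head), split across 12 heads. The sizes (6 tokens, hidden size 768, 12 heads of 64 dimensions each) match the BERT run in this notebook, but the hidden states and projection matrices are random stand-ins rather than BERT's learned weights, and biases, the value/output projections, and the attention mask are deliberately omitted. It assumes NumPy is available.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along `axis`."""
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def multi_head_attention_weights(hidden: np.ndarray, w_q: np.ndarray,
                                 w_k: np.ndarray, num_heads: int) -> np.ndarray:
    """Return attention weights of shape (num_heads, seq_len, seq_len).

    hidden: (seq_len, d_model); w_q, w_k: (d_model, d_model) query/key projections.
    """
    seq_len, d_model = hidden.shape
    d_head = d_model // num_heads
    # Project, then split the last dimension into num_heads chunks of size d_head
    q = (hidden @ w_q).reshape(seq_len, num_heads, d_head).transpose(1, 0, 2)
    k = (hidden @ w_k).reshape(seq_len, num_heads, d_head).transpose(1, 0, 2)
    # Scaled dot-product scores: (num_heads, seq_len, seq_len)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_head)
    # Each row (one query token) becomes a probability distribution over key tokens
    return softmax(scores, axis=-1)


rng = np.random.default_rng(0)
seq_len, d_model, num_heads = 6, 768, 12  # same sizes as the BERT run in this notebook
hidden = rng.standard_normal((seq_len, d_model))  # random stand-in for token representations
w_q = rng.standard_normal((d_model, d_model)) / np.sqrt(d_model)
w_k = rng.standard_normal((d_model, d_model)) / np.sqrt(d_model)

weights = multi_head_attention_weights(hidden, w_q, w_k, num_heads)
print(weights.shape)            # (12, 6, 6)
print(weights.sum(axis=-1)[0])  # every row of head 0 sums to 1
```

Each head sees only its own 64-dimensional slice of the projected queries and keys, which is what lets different heads produce different attention patterns from the same input.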

### Requirements:
To run this notebook, you'll need to install the required packages:
```bash
pip install transformers torch bertviz
```
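Before running the cells below, a quick standard-library check can confirm that the three packages are importable in the current kernel. This is a small convenience sketch, not part of the original notebook; it relies on the pip names of these three packages matching their import names.

```python
import importlib.util

REQUIRED = ["transformers", "torch", "bertviz"]


def missing_packages(names):
    """Return the import names from `names` that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_packages(REQUIRED)
if missing:
    print("Missing packages, install with: pip install " + " ".join(missing))
else:
    print("All required packages are installed.")
```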


In [17]:
from transformers import BertModel, BertTokenizer
import torch
from bertviz import head_view


In [18]:
# Load pre-trained BERT model and tokenizer
model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


In [19]:
# Define and tokenize a sample sentence
sentence = "Robots are incredible!"
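inputs = tokenizer(sentence, return_tensors="pt", return_token_type_ids=True)
input_ids = inputs['input_ids']
# Segment IDs: all 0 here, since the input is a single sentence
token_type_ids = inputs['token_type_ids']
# Map IDs back to WordPiece strings (including [CLS] and [SEP]) to label the visualization
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])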

print(f"Original sentence: {sentence}")
print(f"Tokens: {tokens}")


Original sentence: Robots are incredible!
Tokens: ['[CLS]', 'robots', 'are', 'incredible', '!', '[SEP]']
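The token list comes from lowercasing, splitting off punctuation, and then WordPiece, which splits each word greedily into the longest pieces found in the vocabulary (continuation pieces carry a `##` prefix), with `[CLS]` and `[SEP]` added around the sequence. The sketch below reproduces that pipeline in pure Python under simplifying assumptions: `toy_vocab` is a hypothetical six-entry vocabulary (the real `bert-base-uncased` vocabulary has 30,522 entries), and only ASCII punctuation is handled, whereas the real tokenizer also normalizes accents and Unicode punctuation.

```python
import string


def basic_tokenize(text):
    """Lowercase, split on whitespace, and split off ASCII punctuation as separate tokens."""
    tokens = []
    for word in text.lower().split():
        current = ""
        for ch in word:
            if ch in string.punctuation:
                if current:
                    tokens.append(current)
                    current = ""
                tokens.append(ch)
            else:
                current += ch
        if current:
            tokens.append(current)
    return tokens


def wordpiece(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first split of one word into vocabulary pieces."""
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        match = None
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry a ## prefix
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return [unk]  # no piece matches: the whole word becomes [UNK]
        pieces.append(match)
        start = end
    return pieces


def tokenize(text, vocab):
    """Wrap the WordPiece tokens of `text` in [CLS] ... [SEP]."""
    tokens = ["[CLS]"]
    for word in basic_tokenize(text):
        tokens.extend(wordpiece(word, vocab))
    tokens.append("[SEP]")
    return tokens


# Hypothetical toy vocabulary, for illustration only
toy_vocab = {"robots", "robot", "##ics", "are", "incredible", "!"}
print(tokenize("Robots are incredible!", toy_vocab))
# ['[CLS]', 'robots', 'are', 'incredible', '!', '[SEP]']
print(tokenize("Robotics!", toy_vocab))
# ['[CLS]', 'robot', '##ics', '!', '[SEP]']
```

Because every word in the sample sentence is a whole vocabulary entry, no `##` pieces appear in the real output; a rarer word would be split into several pieces, each getting its own row and column in the attention matrices.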


In [20]:
# Run the model to get outputs and attention weights
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

attentions = outputs.attentions
last_hidden_state = outputs.last_hidden_state
hidden_states = outputs.hidden_states


In [21]:
# Print shapes of various outputs
print(f"Input IDs shape: {inputs['input_ids'].shape}")        # (batch, seq_len)
print(f"Last hidden state shape: {last_hidden_state.shape}")  # (batch, seq_len, hidden_size)
print(f"Number of hidden states: {len(hidden_states)}")       # embedding output + 12 encoder layers
print(f"Hidden state[0] shape: {hidden_states[0].shape}")     # embedding output
print(f"Attention[0] shape: {attentions[0].shape}")           # (batch, num_heads, seq_len, seq_len)


Input IDs shape: torch.Size([1, 6])
Last hidden state shape: torch.Size([1, 6, 768])
Number of hidden states: 13
Hidden state[0] shape: torch.Size([1, 6, 768])
Attention[0] shape: torch.Size([1, 12, 6, 6])
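`outputs.attentions` is a tuple with one tensor per layer, each shaped (batch, num_heads, seq_len, seq_len), where row `i` holds the weights token `i` assigns to every token. For analysis beyond the interactive view it is convenient to stack these into a single (layers, heads, seq_len, seq_len) array. The sketch below does that in NumPy and reports, for one layer and head, which token each token attends to most. It is self-contained and runs on simulated, softmax-normalized weights; with the real outputs you would pass `[a.numpy() for a in attentions]` instead of `fake_attentions` (the model ran under `torch.no_grad()`, so `.numpy()` works on these CPU tensors).

```python
import numpy as np


def stack_attentions(attentions):
    """Stack a per-layer sequence of (batch, heads, seq, seq) arrays, keeping batch item 0.

    Returns an array of shape (num_layers, num_heads, seq_len, seq_len).
    """
    return np.stack([np.asarray(layer) for layer in attentions])[:, 0]


def most_attended(att, tokens, layer, head):
    """For one layer/head, pair each query token with the key token it weights most."""
    weights = att[layer, head]      # (seq_len, seq_len), rows = query tokens
    best = weights.argmax(axis=-1)  # index of the highest-weight key for each query
    return [(tokens[q], tokens[k]) for q, k in enumerate(best)]


# Simulated stand-in for `outputs.attentions`: 12 layers of (1, 12, 6, 6) softmax weights
rng = np.random.default_rng(0)
tokens = ['[CLS]', 'robots', 'are', 'incredible', '!', '[SEP]']
scores = rng.standard_normal((12, 1, 12, 6, 6))
fake_attentions = tuple(np.exp(s) / np.exp(s).sum(axis=-1, keepdims=True) for s in scores)

att = stack_attentions(fake_attentions)
print(att.shape)  # (12, 12, 6, 6)
print(most_attended(att, tokens, layer=0, head=0))
```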


In [22]:
# Visualize attention patterns
head_view(attentions, tokens)


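The head view is qualitative. One simple quantitative complement (an addition here, not a `bertviz` feature) is the entropy of each head's attention rows: values near 0 mean a head concentrates almost all of its weight on a single token, while values near log(seq_len) mean it spreads attention almost uniformly. A minimal NumPy sketch, checked on the two extreme cases:

```python
import numpy as np


def attention_entropy(att):
    """Mean entropy (in nats) of each head's attention rows.

    att: (num_layers, num_heads, seq_len, seq_len), each row summing to 1.
    Returns an array of shape (num_layers, num_heads).
    """
    eps = 1e-12  # avoid log(0) for exactly-zero weights
    row_entropy = -(att * np.log(att + eps)).sum(axis=-1)  # (layers, heads, seq_len)
    return row_entropy.mean(axis=-1)


seq_len = 6
uniform = np.full((1, 1, seq_len, seq_len), 1.0 / seq_len)              # spread evenly
one_hot = np.broadcast_to(np.eye(seq_len), (1, 1, seq_len, seq_len))    # fully focused
print(attention_entropy(uniform))  # ≈ log(6) ≈ 1.79
print(attention_entropy(one_hot))  # ≈ 0
```

With the model outputs from this notebook, `attention_entropy(np.stack([a[0].numpy() for a in attentions]))` should return a (12, 12) array, one value per layer and head, which can help pick out sharply focused heads to inspect in the head view.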