# BERT attention heads

Going deeper on the BERT representations.

See:  
https://huggingface.co/transformers/bertology.html

From Clark et al.'s analysis of BERT's attention heads:  
https://www-nlp.stanford.edu/pubs/clark2019what.pdf

![title](../data/coref_head.png)

Let's grab head 5-4 (layer 5, head 4, counting from 1) and check whether it shows this coreference pattern.
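Clark et al. label heads as "layer-head" and count from 1, while the `attentions` tuple and its tensors are indexed from 0. The small helper below is a hypothetical convenience (not part of `transformers`), assuming that 1-indexed convention and bert-base's 12 layers × 12 heads. It maps the paper's label onto the `all_attentions[4][0][3]` lookup used in the cell below.

```python
def head_label_to_indices(label):
    """Convert a 1-indexed "layer-head" label such as "5-4" into the
    0-indexed (layer, head) pair used to index BERT's attentions.

    Hypothetical helper; assumes bert-base (12 layers, 12 heads per layer).
    """
    layer_str, head_str = label.split("-")
    layer, head = int(layer_str), int(head_str)
    if not (1 <= layer <= 12 and 1 <= head <= 12):
        raise ValueError(f"label out of range for bert-base: {label}")
    # Shift both numbers down by one to get Python indices
    return layer - 1, head - 1


layer_idx, head_idx = head_label_to_indices("5-4")
print(layer_idx, head_idx)  # 4 3
```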

In [4]:
import torch
from transformers import BertModel, BertTokenizer

# All the classes for an architecture can be instantiated from pretrained weights for that architecture.
# Note that any additional weights added for fine-tuning are only initialized
# and need to be trained on the downstream task.
pretrained_weights = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(pretrained_weights)

# Ask the model to return the hidden states and attention weights of every layer
model = BertModel.from_pretrained(pretrained_weights,
                                  output_hidden_states=True,
                                  output_attentions=True)

In [5]:
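# Use the sentence from the paper. add_special_tokens=False keeps [CLS]/[SEP] out,
# so row 2 is "talks" and row -2 is "negotiations" in the cell below.
input_ids = torch.tensor([tokenizer.encode("joining peace talks between Israel and the Palestinians. The negotiations are",
                                           add_special_tokens=False)])
# The last two model outputs are the per-layer hidden states and attention weights
all_hidden_states, all_attentions = model(input_ids)[-2:]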

### Extracting Attention Weights

From: https://huggingface.co/transformers/model_doc/bert.html#bertmodel

> **attentions**: Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
>
> Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
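To make "weights after the softmax, used to compute the weighted average" concrete, here is a minimal plain-Python sketch of standard scaled dot-product attention for a single query. It deliberately omits BERT's learned query/key/value projections and multiple heads, and the toy vectors are made up, so it illustrates the computation rather than reproducing BERT. The `weights` list plays the role of one row of a `(sequence_length, sequence_length)` attention matrix, and `output` is the weighted average of the value vectors.

```python
import math


def softmax(scores):
    # Subtract the max before exponentiating for numerical stability
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(query, keys, values):
    """Single-query scaled dot-product attention on plain Python lists."""
    d = len(query)
    # Dot product of the query with every key, scaled by sqrt(d)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
    weights = softmax(scores)  # one row of the attention matrix
    # Weighted average of the value vectors, one output dimension at a time
    output = [sum(w * v[j] for w, v in zip(weights, values))
              for j in range(len(values[0]))]
    return weights, output


# Toy inputs (made up for illustration)
query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]

weights, output = attend(query, keys, values)
print([round(w, 3) for w in weights], round(sum(weights), 6))
print(output)
```

Because each row is a softmax, its entries are positive and sum to 1, which is exactly what the rows printed below should show.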

In [6]:
# Layer index 4 (layer 5), batch element 0, head index 3 (head 4): head 5-4
coref_head = all_attentions[4][0][3]

print(coref_head.shape)

# Compare the attention rows for the two co-referent words:
# row -2 is "negotiations", row 2 is "talks"
print("\"Negotiations\" attention weights: \n\t{}".format(coref_head[-2]))

print("\n\"Talks\" attention weights: \n\t{}".format(coref_head[2]))

torch.Size([12, 12])
"Negotiations" attention weights: 
	tensor([0.0180, 0.0123, 0.3793, 0.0214, 0.0043, 0.0070, 0.0026, 0.0116, 0.0046,
        0.0022, 0.5357, 0.0011], grad_fn=<SelectBackward>)

"Talks" attention weights: 
	tensor([1.5396e-02, 2.9177e-03, 5.4983e-01, 2.8408e-03, 3.1879e-04, 5.3173e-04,
        3.1515e-04, 8.3999e-04, 4.0847e-03, 3.0516e-03, 4.1416e-01, 5.7109e-03],
       grad_fn=<SelectBackward>)
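As a quick sanity check on the output above, each row should sum to roughly 1, and sorting a row shows which tokens the word attends to most. The token list below is inferred rather than printed by the notebook: it assumes bert-base-uncased keeps every word as a single WordPiece and that no [CLS]/[SEP] tokens were added, which is consistent with the `torch.Size([12, 12])` shape. `top_attended` is a hypothetical helper.

```python
# ASSUMPTION: the 12 WordPiece tokens of the input sentence (inferred from the 12x12 shape)
tokens = ["joining", "peace", "talks", "between", "israel", "and",
          "the", "palestinians", ".", "the", "negotiations", "are"]

# "Negotiations" row copied from the printed output above
negotiations_row = [0.0180, 0.0123, 0.3793, 0.0214, 0.0043, 0.0070,
                    0.0026, 0.0116, 0.0046, 0.0022, 0.5357, 0.0011]


def top_attended(row, tokens, k=2):
    """Return the k (token, weight) pairs with the largest attention weight."""
    ranked = sorted(zip(tokens, row), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


print(top_attended(negotiations_row, tokens))
# [('negotiations', 0.5357), ('talks', 0.3793)]
print(round(sum(negotiations_row), 4))
```

Apart from attending to itself, "negotiations" puts most of its weight on "talks", which is the coreference pattern Clark et al. describe for this head.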


### Use BertViz
https://github.com/jessevig/bertviz

**Important**: BertViz must be installed first. Clone it one directory above the notebook (the next cell adds `../bertviz_repo` to the Python path):

`git clone https://github.com/jessevig/bertviz bertviz_repo`


In [7]:
# Set notebook up to run files from the BertViz repo
import os
import sys
module_path = os.path.abspath('../bertviz_repo')
if module_path not in sys.path:
    sys.path.append(module_path)

In [8]:
from bertviz import head_view

In [9]:
%%javascript
require.config({
  paths: {
      d3: '//cdnjs.cloudflare.com/ajax/libs/d3/3.4.8/d3.min',
      jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
  }
});


In [11]:
sentence = "joining peace talks between Israel and the Palestinians. The negotiations are"

inputs = tokenizer.encode_plus(sentence, return_tensors='pt')
input_ids = inputs['input_ids']
input_id_list = input_ids[0].tolist()  # Batch index 0
tokens = tokenizer.convert_ids_to_tokens(input_id_list)

all_hidden_states, all_attentions = model(input_ids)[-2:]

In [12]:
# The text contains curly quotes, so read it explicitly as UTF-8
with open("data/greeneyes.txt", encoding="utf-8") as f:
    data = f.read()

In [13]:
data[:400]

'“All right, try to take it a little easy now, Arthur,” the gray-haired man said. “In the first place, if I know the Ellenbogens, they probably all hopped in a cab and went down to the Village for a couple of hours. All three of ’em’ll probably barge -”\n\n“I have a feeling she went to work on some bastard in the kitchen. I just have a feeling. She always starts necking some bastard in the kitchen wh'

In [14]:
inputs = tokenizer.encode_plus(data, return_tensors='pt', add_special_tokens=True)
input_ids = inputs['input_ids']
input_id_list = input_ids[0].tolist()  # Batch index 0
tokens = tokenizer.convert_ids_to_tokens(input_id_list)

Token indices sequence length is longer than the specified maximum sequence length for this model (1352 > 512). Running this sequence through the model will result in indexing errors
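The warning appears because bert-base-uncased has 512 learned position embeddings, so a single forward pass cannot cover all 1352 tokens. The notebook sidesteps this with short lookback windows in the next cell. For comparison, here is a minimal sketch of another common workaround: split the IDs into non-overlapping chunks that each fit in 512 positions once [CLS] and [SEP] are re-added. `chunk_for_bert` is a hypothetical helper; the defaults 101 and 102 are the [CLS] and [SEP] IDs in the bert-base-uncased vocabulary, and with a real tokenizer you would pass `tokenizer.cls_token_id` and `tokenizer.sep_token_id` instead. The 1350-ID stand-in list is 1352 minus the two special tokens.

```python
def chunk_for_bert(ids, max_len=512, cls_id=101, sep_id=102):
    """Split WordPiece IDs (without special tokens) into chunks that each
    fit in max_len positions after [CLS] and [SEP] are added."""
    body = max_len - 2  # leave room for [CLS] and [SEP]
    return [[cls_id] + ids[start:start + body] + [sep_id]
            for start in range(0, len(ids), body)]


# Stand-in for the 1350 content IDs of the story
chunks = chunk_for_bert(list(range(1350)))
print([len(c) for c in chunks])  # [512, 512, 332]
```

Non-overlapping chunks lose attention across chunk boundaries, which is one reason to prefer overlapping windows when the goal is to inspect local attention patterns.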


In [16]:
attention_and_tokens_list = []

# input_id_list (from the cell above) is a plain Python list of token IDs
for i in range(10, 200, 20):

    # Define a lookback window of up to 30 tokens ending at position i
    window_start = max(0, i - 30)
    window_ids = input_id_list[window_start:i]

    # Recover the WordPiece tokens for this window (used as BertViz labels)
    tokens = tokenizer.convert_ids_to_tokens(window_ids)

    # Wrap the IDs in a batch of one: shape (1, window_length)
    window_tensor = torch.tensor([window_ids])

    # The attention weights are the last element of the model output
    with torch.no_grad():
        attentions = model(window_tensor)[-1]

    # Append the attention weights and the tokens to our list
    attention_and_tokens_list.append((attentions, tokens))
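The window arithmetic in the loop above can be checked without running the model. This plain-Python sketch (the helper name `lookback_windows` is hypothetical) lists the `(window_start, i)` pairs the loop produces. Consecutive windows overlap by 10 tokens, only the first window is shorter than 30 tokens, and entry 7 covers positions 120 to 150, which is why its attention tensors below have 30 × 30 matrices.

```python
def lookback_windows(start, stop, step, lookback):
    """(window_start, window_end) pairs produced by the lookback loop."""
    return [(max(0, i - lookback), i) for i in range(start, stop, step)]


windows = lookback_windows(10, 200, 20, 30)
print(len(windows))  # 10
print(windows[:3])   # [(0, 10), (0, 30), (20, 50)]
print(windows[7])    # (120, 150)
```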

In [91]:
attentions, tokens = attention_and_tokens_list[7]

head_view(attentions, tokens)


In [89]:
attentions[0].shape

torch.Size([1, 12, 30, 30])