## 2.2 Scaled dot product attention

#### Installation
- `pip install transformers`
- `pip install bertviz`

In [1]:
from transformers import BertModel  # Let's use a BERT model

ModuleNotFoundError: No module named 'transformers'

In [3]:
model = BertModel.from_pretrained('bert-base-uncased')

Downloading: 100%|████████████████████████████████████████████████████████████████████████████████████████| 570/570 [00:00<00:00, 211kB/s]
Downloading: 100%|█████████████████████████████████████████████████████████████████████████████████████| 440M/440M [00:19<00:00, 22.2MB/s]
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expec

In [4]:
len(model.encoder.layer)  # Base BERT has 12 encoders in the encoder stack

12

In [50]:
model.encoder.layer[0]  # the first encoder

BertLayer(
  (attention): BertAttention(
    (self): BertSelfAttention(
      (query): Linear(in_features=768, out_features=768, bias=True)
      (key): Linear(in_features=768, out_features=768, bias=True)
      (value): Linear(in_features=768, out_features=768, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (output): BertSelfOutput(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (intermediate): BertIntermediate(
    (dense): Linear(in_features=768, out_features=3072, bias=True)
    (intermediate_act_fn): GELUActivation()
  )
  (output): BertOutput(
    (dense): Linear(in_features=3072, out_features=768, bias=True)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [54]:
model.encoder.layer[0].attention  # The attention in the first encoder

BertAttention(
  (self): BertSelfAttention(
    (query): Linear(in_features=768, out_features=768, bias=True)
    (key): Linear(in_features=768, out_features=768, bias=True)
    (value): Linear(in_features=768, out_features=768, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (output): BertSelfOutput(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

## 2.3 Multi-headed attention

In [2]:
from transformers import BertModel, BertTokenizer
from bertviz import head_view
import torch
import pandas as pd

In [3]:
# Let's load a vanilla BERT-base model. 
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

Downloading: 100%|██████████████████████████████████████████████████████████████████████████████████████| 232k/232k [00:00<00:00, 751kB/s]
Downloading: 100%|█████████████████████████████████████████████████████████████████████████████████████| 28.0/28.0 [00:00<00:00, 7.69kB/s]
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expec

In [63]:
# 包括 . 和 ! 总共 18 个 token, 但是下面有个 CLS 和 SEP 不知道是啥，会变成 20个 token??
text = "My friend told me about this class and I love it so far! She was right."

tokens = tokenizer.encode(text)
print(tokens)
inputs = torch.tensor(tokens).unsqueeze(0) # unsqueeze changes the shape from (20,) -> (1, 20)
print(inputs)
print(inputs.size())

[101, 2026, 2767, 2409, 2033, 2055, 2023, 2465, 1998, 1045, 2293, 2009, 2061, 2521, 999, 2016, 2001, 2157, 1012, 102]
tensor([[ 101, 2026, 2767, 2409, 2033, 2055, 2023, 2465, 1998, 1045, 2293, 2009,
         2061, 2521,  999, 2016, 2001, 2157, 1012,  102]])
torch.Size([1, 20])


In [64]:
tokenizer.decode(tokens)

'[CLS] my friend told me about this class and i love it so far! she was right. [SEP]'

In [45]:
print(model(inputs, output_attentions=True).attentions)

(tensor([[[[3.6568e-02, 3.7711e-02, 2.2552e-02,  ..., 7.3119e-02,
           6.9431e-02, 1.7964e-01],
          [4.5540e-02, 8.1695e-02, 3.9291e-02,  ..., 3.4719e-02,
           4.2371e-02, 6.0818e-02],
          [3.9046e-02, 6.9282e-02, 3.4435e-02,  ..., 3.2120e-02,
           6.7234e-02, 4.6184e-02],
          ...,
          [3.4590e-02, 6.1221e-02, 3.8246e-02,  ..., 1.9331e-02,
           7.9744e-02, 4.9688e-02],
          [3.1316e-02, 3.5717e-02, 4.7802e-02,  ..., 3.5958e-02,
           7.8209e-02, 3.8041e-02],
          [3.0512e-02, 5.7132e-02, 3.0015e-02,  ..., 3.7977e-02,
           8.4284e-02, 5.3630e-02]],

         [[6.4301e-01, 7.0693e-03, 5.0891e-03,  ..., 8.4802e-03,
           3.2145e-02, 6.7034e-03],
          [8.5497e-03, 1.0524e-02, 3.4939e-01,  ..., 7.3552e-02,
           5.3895e-03, 1.7580e-02],
          [6.2130e-03, 2.2673e-02, 1.0013e-01,  ..., 8.1952e-02,
           5.0246e-03, 4.8268e-02],
          ...,
          [3.0313e-03, 1.1815e-02, 2.3804e-01,  ..., 3.074

In [49]:
# Grab the attention scores from BERT
# 咋知道第二个 index 是attention score?? 这里？ 
# https://huggingface.co/docs/transformers/main_classes/output
# 这里 model 是上面的 pretrain model 也即是这里其实是 inference 哦
# 这里其实是相当于 model(inputs, output_attentions=True).attentions, 
# 这样写真的不好，上其实是能 print 出来。 感觉其实是因为上面hidden_states=None, past_key_values=None 所以 .attention
# 就变成的第三个 index..
attention = model(inputs, output_attentions=True)[2]  

# 啊，他这里的 attention 这个其实是 tuple of tensors， 然后这边 len=12 对应上面有 12 layers of encoders
# [1, 12, 20, 20]? 这个就不知道啥回事？ 作者是说 each encoder has 12 attention heads, 不知道咋看的？？？
print(len(attention))
print(attention[-1].size())
print(attention)

12
torch.Size([1, 12, 20, 20])
(tensor([[[[3.6568e-02, 3.7711e-02, 2.2552e-02,  ..., 7.3119e-02,
           6.9431e-02, 1.7964e-01],
          [4.5540e-02, 8.1695e-02, 3.9291e-02,  ..., 3.4719e-02,
           4.2371e-02, 6.0818e-02],
          [3.9046e-02, 6.9282e-02, 3.4435e-02,  ..., 3.2120e-02,
           6.7234e-02, 4.6184e-02],
          ...,
          [3.4590e-02, 6.1221e-02, 3.8246e-02,  ..., 1.9331e-02,
           7.9744e-02, 4.9688e-02],
          [3.1316e-02, 3.5717e-02, 4.7802e-02,  ..., 3.5958e-02,
           7.8209e-02, 3.8041e-02],
          [3.0512e-02, 5.7132e-02, 3.0015e-02,  ..., 3.7977e-02,
           8.4284e-02, 5.3630e-02]],

         [[6.4301e-01, 7.0693e-03, 5.0891e-03,  ..., 8.4802e-03,
           3.2145e-02, 6.7034e-03],
          [8.5497e-03, 1.0524e-02, 3.4939e-01,  ..., 7.3552e-02,
           5.3895e-03, 1.7580e-02],
          [6.2130e-03, 2.2673e-02, 1.0013e-01,  ..., 8.1952e-02,
           5.0246e-03, 4.8268e-02],
          ...,
          [3.0313e-03, 1.18

In [23]:
# average attention in the last encoder
final_attention = attention[-1].mean(1)[0]

In [21]:
# 这里代码好讨厌，讲的好含糊..

# 新建一个 dataframe 但是这里 key 是 0 -> 19
attention_df = pd.DataFrame(final_attention.detach()).applymap(float).round(3)

# 这里就是吧 column 和 index 都改成 token 的名字
attention_df.columns = tokenizer.convert_ids_to_tokens(tokens)
attention_df.index = tokenizer.convert_ids_to_tokens(tokens)

attention_df  # sums across rows add up to 1. sums across columns do not


Unnamed: 0,[CLS],my,friend,told,me,about,this,class,and,i,love,it,so,far,!,she,was,right,.,[SEP]
[CLS],0.092,0.028,0.019,0.011,0.012,0.022,0.05,0.087,0.031,0.023,0.023,0.031,0.007,0.028,0.067,0.057,0.065,0.124,0.104,0.12
my,0.021,0.023,0.014,0.01,0.013,0.021,0.028,0.015,0.014,0.012,0.01,0.023,0.011,0.009,0.016,0.022,0.021,0.019,0.312,0.388
friend,0.018,0.009,0.129,0.009,0.005,0.008,0.008,0.012,0.009,0.005,0.009,0.006,0.004,0.005,0.009,0.023,0.01,0.006,0.314,0.401
told,0.01,0.004,0.013,0.084,0.004,0.011,0.005,0.005,0.005,0.002,0.008,0.005,0.005,0.003,0.006,0.008,0.004,0.003,0.351,0.464
me,0.024,0.013,0.01,0.011,0.017,0.016,0.018,0.011,0.014,0.01,0.01,0.014,0.007,0.008,0.014,0.009,0.006,0.005,0.347,0.436
about,0.019,0.01,0.007,0.018,0.01,0.079,0.021,0.012,0.012,0.006,0.014,0.019,0.008,0.008,0.012,0.005,0.003,0.005,0.32,0.412
this,0.026,0.014,0.003,0.004,0.01,0.015,0.069,0.02,0.011,0.01,0.011,0.018,0.006,0.008,0.012,0.005,0.003,0.004,0.331,0.421
class,0.028,0.01,0.007,0.006,0.006,0.015,0.029,0.096,0.01,0.009,0.013,0.019,0.006,0.009,0.015,0.01,0.005,0.005,0.312,0.39
and,0.031,0.016,0.006,0.007,0.012,0.009,0.013,0.009,0.08,0.013,0.01,0.01,0.008,0.009,0.024,0.014,0.012,0.011,0.316,0.386
i,0.023,0.014,0.008,0.005,0.011,0.011,0.019,0.012,0.021,0.029,0.014,0.013,0.008,0.014,0.019,0.012,0.009,0.008,0.334,0.414


In [55]:
# https://nlp.stanford.edu/pubs/clark2019what.pdf
# Layer index 2 seems to be attending to the previous token
# Layer index 6 seems to be for pronouns

In [60]:
# 卧槽，这个就非常 nice 了
# 这个 inputs 是上面的 inputs = torch.tensor(tokens).unsqueeze(0)
# 这里 inputs[0] 是因为上面的 shape 是 1x20
# 这里怎么把 token 专程 word 的就不知道了，挺厉害的 .. 上面 inputs 不是已经 encoded 了?
    # 哦，因为 tokenizer.decode(tokens) 就可以转回来，然后还会有 CLS 和 SEP
    # CLS” is the reserved token to represent the start of sequence while “SEP” separate segment (or sentence)
tokens_as_list = tokenizer.convert_ids_to_tokens(inputs[0])

# attention 就是上面 pretrain model forward 出来的 attention
# 这里每个颜色，代表 each self-attention QKV head at a given layer!
head_view(attention=attention, 
          tokens=tokens_as_list)

<IPython.core.display.Javascript object>

In [61]:
# Head 3-1 attends to previous token
head_view(attention, 
          tokenizer.convert_ids_to_tokens(inputs[0]), 
          layer=2, 
          heads=[0])

<IPython.core.display.Javascript object>

In [18]:
# Head 8-10 relating direct objects to their verbs eg told -> me
head_view(attention, tokenizer.convert_ids_to_tokens(inputs[0]), layer=7, heads=[9])

<IPython.core.display.Javascript object>

In [19]:
# attention in the 8th encoder's 10th head to see direct object attention
eight_ten = attention[7][0][9]

In [23]:
# Get the attention matrix
attention_df = pd.DataFrame(eight_ten.detach()).applymap(float).round(3)

attention_df.columns = tokenizer.convert_ids_to_tokens(tokens)
attention_df.index = tokenizer.convert_ids_to_tokens(tokens)

attention_df  # sums across rows add up to 1. sums across columns do not


Unnamed: 0,[CLS],my,friend,told,me,about,this,class,and,i,love,it,so,far,!,she,was,right,.,[SEP]
[CLS],0.007,0.004,0.005,0.002,0.001,0.001,0.003,0.004,0.002,0.001,0.002,0.003,0.001,0.005,0.006,0.005,0.01,0.039,0.033,0.867
my,0.031,0.03,0.027,0.009,0.004,0.002,0.001,0.004,0.006,0.001,0.001,0.002,0.001,0.002,0.033,0.002,0.002,0.004,0.05,0.788
friend,0.022,0.128,0.024,0.002,0.002,0.0,0.001,0.001,0.004,0.001,0.001,0.001,0.0,0.002,0.016,0.001,0.002,0.001,0.025,0.765
told,0.035,0.072,0.014,0.013,0.005,0.002,0.002,0.002,0.002,0.0,0.0,0.0,0.0,0.001,0.009,0.001,0.0,0.0,0.034,0.808
me,0.01,0.01,0.005,0.683,0.015,0.007,0.001,0.001,0.001,0.0,0.0,0.0,0.0,0.0,0.003,0.0,0.001,0.003,0.004,0.255
about,0.017,0.015,0.025,0.222,0.024,0.015,0.011,0.02,0.005,0.001,0.001,0.0,0.0,0.001,0.004,0.001,0.001,0.006,0.012,0.618
this,0.005,0.002,0.008,0.223,0.03,0.452,0.073,0.046,0.003,0.001,0.002,0.001,0.001,0.002,0.001,0.0,0.0,0.007,0.001,0.143
class,0.012,0.002,0.004,0.074,0.02,0.204,0.339,0.138,0.004,0.001,0.008,0.007,0.002,0.004,0.001,0.002,0.001,0.018,0.006,0.154
and,0.03,0.008,0.001,0.077,0.019,0.091,0.017,0.013,0.084,0.008,0.002,0.001,0.001,0.001,0.009,0.001,0.002,0.003,0.022,0.61
i,0.022,0.016,0.008,0.181,0.018,0.021,0.003,0.007,0.363,0.033,0.004,0.001,0.002,0.002,0.01,0.0,0.001,0.001,0.004,0.302
