In [1]:
import os
import torch

from collections import defaultdict
from transformers import AutoTokenizer, AutoModel, utils
from bertviz.transformers_neuron_view import BertModel, BertTokenizer
from bertviz import model_view
# from bertviz.neuron_view import get_attention

In [2]:
utils.logging.set_verbosity_error()  # Suppress standard warnings

In [3]:
path_to_model = os.path.join(
    "../experiments/base/ASAP/models",
    "Original/base/sets/set1/lr2e-5-b8a1-ada-fp16",
    "bert-base-cased-512-conventions-mod"
)

In [4]:
# text = "Hi my name is @CAPS1 @CAPS2 @PERSON1 and I'm here to talk about why I thing we should have these certain materials in the libraries.                       Why I think these materials should stay well because,  For @CAPS3 lets say its a certain movie or song or book or whatever just that one thing you were looking for, for ages and you couldn't find it in any other store and you just happening to go to the library just to look. Even though  you know its not there, but you just looking and you end up finding whatever it is you where looking for.             And what if you wanted to look for a myster book and only the library had that book you where looking for. I can think of alot of reasons why we need certain materials. well i guess that about sums that up for me.                 And in concultion put yourself in the same shoes or even my shoes. What would you do? would you let them remove those materials or would you fight for them to stay."
"""
Hi my name is @CAPS1 @CAPS2 @PERSON1 and I'm here to talk about why I thing we should have these certain materials in the libraries.
"""
text = "Hi my name is @CAPS1 @CAPS2 @PERSON1 and I'm here to talk about why I thing we should have these certain materials in the libraries."

In [5]:
model_type = 'bert'
model = BertModel.from_pretrained(path_to_model, output_attentions=True)  # Configure model to return attention values
tokenizer = BertTokenizer.from_pretrained(path_to_model)

In [6]:
# Prepare inputs to model
tokens_a = [tokenizer.cls_token] + tokenizer.tokenize(text) + [tokenizer.sep_token]
token_ids = tokenizer.convert_tokens_to_ids(tokens_a)
tokens_tensor = torch.tensor(token_ids).unsqueeze(0)
tokens_tensor

tensor([[  101, 20844,  1139,  1271,  1110,   137, 10184,  1475,   137, 10184,
          1477,   137,  1825,  1475,  1105,   178,   112,   182,  1303,  1106,
          2037,  1164,  1725,   178,  1645,  1195,  1431,  1138,  1292,  2218,
          3881,  1107,  1103,  9818,   119,   102]])

In [7]:
# Call model to get attention data
model.eval()

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(28996, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): BertLayerNorm()
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): BertLayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Lin

In [8]:
output = model(tokens_tensor)
attn_data_list = output[-1]
attn_data_list

({'attn': tensor([[[[4.8071e-01, 8.0965e-03, 4.5993e-03,  ..., 6.8932e-03,
             2.1248e-02, 2.5123e-01],
            [3.7919e-02, 1.7543e-02, 2.0456e-02,  ..., 4.0543e-02,
             2.5979e-02, 1.5348e-02],
            [1.2372e-02, 1.7819e-02, 8.2233e-03,  ..., 4.0908e-02,
             4.0114e-02, 3.6282e-02],
            ...,
            [2.1380e-02, 1.6332e-02, 1.3302e-02,  ..., 7.0767e-02,
             8.2557e-02, 4.2069e-02],
            [2.3592e-02, 1.0758e-02, 5.7782e-03,  ..., 2.4473e-02,
             1.8987e-02, 1.0775e-01],
            [1.5904e-01, 4.7973e-03, 1.7883e-03,  ..., 2.6448e-04,
             2.6943e-04, 7.1344e-01]],
  
           [[3.9560e-01, 2.8550e-02, 1.7632e-02,  ..., 9.0587e-03,
             1.2278e-02, 1.0900e-02],
            [1.0781e-01, 1.4477e-01, 1.8125e-01,  ..., 4.9193e-04,
             5.0953e-04, 6.4658e-04],
            [3.0502e-02, 1.5718e-01, 8.7146e-02,  ..., 5.0765e-04,
             4.6791e-04, 3.9121e-04],
            ...,
         

In [9]:
# Populate map with attn data and, optionally, query, key data
attn_dict = defaultdict(list)

In [10]:
print(f"Num of Tokens: {len(tokens_a)}")
print(f"Num of Layers: {len(attn_data_list)}")

for layer, attn_data in enumerate(attn_data_list):
    # Process attention
    attn = attn_data['attn'][0]  # assume batch_size=1; shape = [num_heads, source_seq_len, target_seq_len]
    attn_dict['all'].append(attn.tolist())
    print(len(attn.tolist()))
    print(len(attn.tolist()[0]))
    print(len(attn.tolist()[0][0]))
    break

Num of Tokens: 36
Num of Layers: 12
12
36
36


In [11]:
for layer, attn_data in enumerate(attn_data_list):
    # Process attention
    attn = attn_data['attn'][0]  # assume batch_size=1; shape = [num_heads, source_seq_len, target_seq_len]
    attn_dict['all'].append(attn.tolist())
    # print(len(attn.tolist()[0]))
    # print(len(attn.tolist()[0][0]))
    print(f"layer: {layer}")
    print("-" * 100)
    for head in range(len(attn)):
        for i in range(len(tokens_a)):

            if tokens_a[i] in ["[CLS]", "[SEP]"]:
                continue

            max_weight_local = 0
            max_tokens_local = defaultdict(list)
            for j in range(len(tokens_a)):
                if attn.tolist()[head][i][j] > max_weight_local:
                    if tokens_a[j] not in ["[CLS]", "[SEP]"]:
                        max_weight_local = attn.tolist()[head][i][j]
                        max_tokens_local[max_weight_local].append(tokens_a[j])

                # break
            if max_weight_local >= 0.98:
                print(
                    f"layer: {layer}, ",
                    f"head: {head}, "
                    f"token_left: {tokens_a[i]}, ",
                    f"token_right: {max_tokens_local[max_weight_local]}, ",
                    f"attention_wight: {round(max_weight_local, 4)}")
            # break
        # break

    print("-" * 100)
    # break

layer: 0
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
layer: 1
----------------------------------------------------------------------------------------------------
layer: 1,  head: 3, token_left: talk,  token_right: ['about'],  attention_wight: 0.9956
layer: 1,  head: 4, token_left: is,  token_right: ['name'],  attention_wight: 0.9911
layer: 1,  head: 4, token_left: and,  token_right: ['##1'],  attention_wight: 0.9853
layer: 1,  head: 4, token_left: ',  token_right: ['i'],  attention_wight: 0.9977
layer: 1,  head: 4, token_left: here,  token_right: ['m'],  attention_wight: 0.9823
layer: 1,  head: 4, token_left: to,  token_right: ['here'],  attention_wight: 0.9895
layer: 1,  head: 4, token_left: talk,  token_right: ['to'],  attention_wight: 0.9951
layer: 1,  head: 4, token_left: about,  token_right: ['talk'],  attention_wight: 0.997

In [12]:
for layer, attn_data in enumerate(attn_data_list):
    # Process attention
    attn = attn_data['attn'][0]  # assume batch_size=1; shape = [num_heads, source_seq_len, target_seq_len]
    attn_dict['all'].append(attn.tolist())

In [13]:
def format_special_chars(tokens):
    return [t.replace('Ġ', ' ').replace('▁', ' ') for t in tokens]

In [14]:
tokens_a = format_special_chars(tokens_a)
tokens_a

['[CLS]',
 'hi',
 'my',
 'name',
 'is',
 '@',
 'caps',
 '##1',
 '@',
 'caps',
 '##2',
 '@',
 'person',
 '##1',
 'and',
 'i',
 "'",
 'm',
 'here',
 'to',
 'talk',
 'about',
 'why',
 'i',
 'thing',
 'we',
 'should',
 'have',
 'these',
 'certain',
 'materials',
 'in',
 'the',
 'libraries',
 '.',
 '[SEP]']

In [15]:
results = {
    'all': {
        'attn': attn_dict['all'],
        'left_text': tokens_a,
        'right_text': tokens_a
    }
}