In [None]:
!pip install transformers
!pip install sentencepiece

# Extracting BERT Attention Layers for Heatmap Visualizations

### Just for fun ;)

In [2]:
from transformers import BertTokenizer, BertForTokenClassification

In [None]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForTokenClassification.from_pretrained('bert-base-uncased')


In [4]:
input_text = "The quick brown fox jumps over the lazy dog."



In [5]:
inputs = tokenizer(input_text, return_tensors='pt')
outputs = model(**inputs, output_hidden_states=True, output_attentions=True )
last_hidden_states = outputs.hidden_states[-1]

In [6]:
tokens = [tokenizer.convert_ids_to_tokens(x, skip_special_tokens=False) for x in inputs.input_ids.detach().numpy()][0]

print(tokens)

['[CLS]', 'the', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog', '.', '[SEP]']


In [7]:
tokenizer.convert_ids_to_tokens(7592, skip_special_tokens=False)

'hello'

In [8]:
last_attentions_layer = outputs.attentions[11]

In [9]:
last_attentions_layer[0][10].detach().numpy()

array([[0.03147584, 0.19651039, 0.02151786, 0.05381653, 0.0314916 ,
        0.20405826, 0.03552675, 0.12405386, 0.02467739, 0.04839181,
        0.10014535, 0.12833437],
       [0.03165355, 0.02928682, 0.00815965, 0.01365662, 0.00899835,
        0.02280228, 0.00732772, 0.02578287, 0.00665841, 0.01132655,
        0.22706069, 0.6072865 ],
       [0.01180759, 0.01261185, 0.02239563, 0.01487199, 0.0087803 ,
        0.01523246, 0.01317953, 0.01807125, 0.01887546, 0.00947179,
        0.22242929, 0.6322729 ],
       [0.01740275, 0.01063784, 0.01179645, 0.01060013, 0.00633165,
        0.01391911, 0.00464607, 0.01675533, 0.01273173, 0.00879767,
        0.22735417, 0.6590271 ],
       [0.03047928, 0.01599603, 0.00464044, 0.00498035, 0.00568827,
        0.01444254, 0.00584413, 0.01518099, 0.00398714, 0.00689612,
        0.2368943 , 0.6549705 ],
       [0.02006025, 0.02081166, 0.00878422, 0.01931005, 0.00733317,
        0.02052327, 0.02668031, 0.03522412, 0.01858792, 0.01147815,
        0.21417305,

In [10]:
len(last_attentions_layer[0])

12

In [11]:
import pandas as pd
import numpy as np

square = last_attentions_layer[0][11].detach().numpy()

In [12]:
square[0][1]

0.1090239

In [13]:
df = pd.DataFrame(square)
df.style.background_gradient(cmap ='viridis').set_properties(**{'font-size': '20px'})

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11
0,0.014106,0.109024,0.041868,0.058096,0.030475,0.088796,0.064844,0.117027,0.045787,0.045433,0.144627,0.239918
1,0.020414,0.022445,0.015969,0.015657,0.014576,0.024353,0.016392,0.01498,0.010507,0.015763,0.214604,0.61434
2,0.010732,0.012563,0.023212,0.013554,0.012637,0.016352,0.015529,0.007773,0.019237,0.01002,0.227392,0.630999
3,0.016309,0.014172,0.018864,0.013878,0.015255,0.015497,0.021642,0.014787,0.031765,0.018139,0.21477,0.604923
4,0.01775,0.01718,0.010462,0.009634,0.02543,0.037314,0.031265,0.009961,0.0077,0.010239,0.231369,0.591695
5,0.009117,0.013612,0.015884,0.005553,0.018457,0.028312,0.034185,0.008713,0.011445,0.013144,0.219209,0.62237
6,0.012896,0.005992,0.006469,0.004776,0.012493,0.024708,0.031301,0.009206,0.021861,0.013433,0.189178,0.667688
7,0.016821,0.014285,0.011316,0.008726,0.010586,0.017164,0.027378,0.022018,0.024503,0.025152,0.198837,0.623213
8,0.017943,0.011,0.018357,0.009987,0.011777,0.011784,0.039077,0.018484,0.042742,0.02809,0.207798,0.582961
9,0.023179,0.017851,0.012909,0.012829,0.028873,0.021419,0.033318,0.025655,0.026706,0.032064,0.205401,0.559795


In [16]:
def generate_heatmap_data(tokens):
  heatmap_data = []
  for layer in range(len(last_attentions_layer[0])):
    square_2 = last_attentions_layer[0][layer].detach().numpy()
    layer_list = []
    for x in range(len(tokens)):
      for y in range(len(tokens)):
        layer_list.append({'x':tokens[x], 'y':tokens[y], 'color':str(square_2[x][y])})
    heatmap_data.append(layer_list)
  return heatmap_data


In [17]:
data = generate_heatmap_data(tokens)

In [None]:
data

[[{'x': '[CLS]', 'y': '[CLS]', 'color': '0.31455678'},
  {'x': '[CLS]', 'y': 'je', 'color': '0.06671245'},
  {'x': '[CLS]', 'y': 'ne', 'color': '0.050966017'},
  {'x': '[CLS]', 'y': 'lu', 'color': '0.0076460005'},
  {'x': '[CLS]', 'y': '##i', 'color': '0.023676913'},
  {'x': '[CLS]', 'y': 'ai', 'color': '0.01923178'},
  {'x': '[CLS]', 'y': 'pas', 'color': '0.040870495'},
  {'x': '[CLS]', 'y': 'par', 'color': '0.03169796'},
  {'x': '[CLS]', 'y': '##le', 'color': '0.033290785'},
  {'x': '[CLS]', 'y': '[SEP]', 'color': '0.10227836'},
  {'x': '[CLS]', 'y': 'i', 'color': '0.029609028'},
  {'x': '[CLS]', 'y': 'haven', 'color': '0.021251015'},
  {'x': '[CLS]', 'y': "'", 'color': '0.10492619'},
  {'x': '[CLS]', 'y': 't', 'color': '0.008520775'},
  {'x': '[CLS]', 'y': 'spoken', 'color': '0.009183092'},
  {'x': '[CLS]', 'y': 'to', 'color': '0.008547873'},
  {'x': '[CLS]', 'y': 'him', 'color': '0.022086926'},
  {'x': '[CLS]', 'y': '[SEP]', 'color': '0.10494747'},
  {'x': 'je', 'y': '[CLS]', 'colo

In [18]:
import json

dumps = json.dumps(data, indent=4)  

with open("bert_heatmap.json", "w") as outfile:
    outfile.write(dumps)