In [17]:
import torch
import torch.nn.functional as F
import numpy as np
import pickle
from multiprocessing import Pool
from transformers import BertTokenizer, BertModel, BertConfig
import math

In [2]:
# Load the pretrained tokenizer and BERT encoder from the same checkpoint name.
# The two output_* flags make the forward pass additionally return the hidden
# states and the attention matrices, which later cells compare against.
PRETRAINED_NAME = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(PRETRAINED_NAME)
model = BertModel.from_pretrained(
    PRETRAINED_NAME,
    output_hidden_states=True,
    output_attentions=True,
)
# Put the module in evaluation mode; as the cell's last expression this also
# displays the module tree.
model.eval()

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

In [3]:
# Encode one example sentence into token ids, then wrap the id list in an outer
# list so the resulting tensor has a leading batch dimension of size 1.
sample_sent = 'This is a sample sentence.'
token_ids = tokenizer.encode(sample_sent)
input_sent = torch.tensor([token_ids])
print(input_sent)

tensor([[ 101, 2023, 2003, 1037, 7099, 6251, 1012,  102]])


In [4]:
# The target outputs we want to recreate.
# How the rest of this notebook indexes them:
#   outputs[2][0]    is compared with the embedding output,
#   outputs[2][i+1]  is compared with the output of encoder layer i,
#   outputs[3][i]    holds layer i's attention matrices, indexed as
#                    [batch, head, query position, key position].
# This is a pure inference pass, so run it under no_grad to avoid recording an
# autograd graph for the reference tensors; the values themselves are unaffected.
with torch.no_grad():
    outputs = model(input_sent)

### Recreate the z representations within each layer

In [5]:
# Recreate the first two encoder layers one module at a time and check each
# result against the hidden states returned by the full forward pass.
#
# torch.equal is used instead of torch.all(a == b): `==` broadcasts, so a
# tensor with a mismatched (but broadcastable) shape could silently pass,
# whereas torch.equal requires the same size and the same elements.

# Recreate the sentence embedding
embedded = model.embeddings(input_sent)
assert torch.equal(embedded, outputs[2][0]), "embedding output differs from outputs[2][0]"

# Recreate layer 0 (the layer returns a tuple; element 0 is the hidden state)
layer_0 = model.encoder.layer[0](embedded)[0]
assert torch.equal(layer_0, outputs[2][1]), "layer 0 output differs from outputs[2][1]"

# Recreate layer 1
layer_1 = model.encoder.layer[1](layer_0)[0]
assert torch.equal(layer_1, outputs[2][2]), "layer 1 output differs from outputs[2][2]"

In [6]:
# OK. Let's recreate all the layers!
# Recreate all the layers from attention.self, attention.output, intermediate and output layers
# Note: The output of attention.self should only be the first element (i.e. we need [0]),
#       but output of the intermediate should be kept the way it is (i.e. no [0] needed).
#       When this is messed up, we observe a drifting error on the order of 10^-4 in total variation.
#
# The loop walks every layer the encoder actually has instead of assuming 12,
# so the pooler below always receives the final layer's output.
# torch.equal (rather than torch.all(a == b)) also rejects shape mismatches
# that `==` would silently broadcast away.
hidden = embedded.clone()
for layer_num, bert_layer in enumerate(model.encoder.layer):
    hidden_post_attn_self = bert_layer.attention.self(hidden)[0]  # This is the "z" representation
    # attention.output takes the pre-attention hidden state as its second argument
    hidden_post_attn = bert_layer.attention.output(hidden_post_attn_self, hidden)
    hidden_post_interm = bert_layer.intermediate(hidden_post_attn)
    # output takes the post-attention hidden state as its second argument
    hidden = bert_layer.output(hidden_post_interm, hidden_post_attn)
    assert torch.equal(hidden, outputs[2][layer_num+1]), (
        f"layer {layer_num} output differs from outputs[2][{layer_num+1}]"
    )
# NOTE(review): `hidden` is rebound to the pooler's result here and is not
# checked against anything; presumably it corresponds to outputs[1] - confirm.
hidden = model.pooler(hidden)

### Recreate the attention matrix

In [7]:
# Recreate the self attention layer using value and attention matrix.
# The first encoder layer serves as the example.
# Step one: run the embedding output through that layer's query, key and
# value projections.
self_attn_0 = model.encoder.layer[0].attention.self
query = self_attn_0.query(embedded)
key = self_attn_0.key(embedded)
value = self_attn_0.value(embedded)

In [9]:
# Go through the entire first layer, one sub-module at a time.
# layer_0_post_attn_0 is the representation we want to recreate
#
# torch.equal is used instead of torch.all(a == b) so that a shape mismatch
# cannot slip through via broadcasting.
layer_0_post_attn_0 = model.encoder.layer[0].attention.self(embedded)[0]
layer_0_post_attn_1 = model.encoder.layer[0].attention.output(layer_0_post_attn_0,embedded)
# The composed attention module must agree with the two manual steps above.
layer_0_post_attn = model.encoder.layer[0].attention(embedded)[0]
assert torch.equal(layer_0_post_attn, layer_0_post_attn_1), \
    "attention(embedded) differs from attention.output(attention.self(embedded), embedded)"
layer_0_post_interm = model.encoder.layer[0].intermediate(layer_0_post_attn)
layer_0_post_all = model.encoder.layer[0].output(layer_0_post_interm,layer_0_post_attn)
assert torch.equal(layer_0_post_all, outputs[2][1]), \
    "manually composed layer 0 differs from outputs[2][1]"

In [28]:
# Finally, create attn_matrix from query and key, make sure it matches outputs[3],
# and use it to recreate the attention.self output (layer_0_post_attn_0).
#
# Comparisons in this cell use a small tolerance rather than bit-exact equality:
# the per-head 2-D matmuls below may be rounded differently from the model's own
# computation depending on the backend, which would fail an exact check even
# though the math is the same. Shapes are asserted separately because
# torch.allclose broadcasts.
ATTN_RTOL = 1e-5
ATTN_ATOL = 1e-5

# zeros_like keeps the dtype and device of `value`
layer_post_attn_0_new = torch.zeros_like(value)
# The attention matrices must be square (sequence length x sequence length)
assert outputs[3][0].shape[2]==outputs[3][0].shape[3]
num_heads = outputs[3][0].shape[1]
head_dim = outputs[2][0].shape[-1]//num_heads
for head_id in range(num_heads):
    # Columns of the hidden dimension that belong to this head
    head_cols = slice(head_dim*head_id, head_dim*(head_id+1))
    head_query = query[0,:,head_cols]
    head_key = key[0,:,head_cols]
    head_value = value[0,:,head_cols]

    # Scaled dot-product scores, normalised over the key positions (dim=1)
    attn_matrix = F.softmax(head_query@head_key.T/math.sqrt(head_dim),dim=1)
    expected_attn = outputs[3][0][0,head_id]
    assert attn_matrix.shape == expected_attn.shape, f"head {head_id}: attention shape mismatch"
    assert torch.allclose(attn_matrix, expected_attn, rtol=ATTN_RTOL, atol=ATTN_ATOL), \
        f"head {head_id}: attention matrix differs from outputs[3][0]"

    # Attention-weighted sum of this head's values fills the head's columns
    layer_post_attn_0_new[0,:,head_cols] = attn_matrix@head_value
assert layer_post_attn_0_new.shape == layer_0_post_attn_0.shape
assert torch.allclose(layer_post_attn_0_new, layer_0_post_attn_0, rtol=ATTN_RTOL, atol=ATTN_ATOL), \
    "recreated attention.self output differs from layer_0_post_attn_0"