# BERT playground

In [None]:
import torch
from transformers import BertTokenizer, BertModel, BertConfig
import matplotlib.pyplot as plt

# Checkpoint name shared by the tokenizer, the config and the model weights.
model_name = 'bert-base-uncased'

# Hidden states are requested up front so that forward passes also return the
# intermediate activations that later cells compare against manual results.
config = BertConfig.from_pretrained(model_name, output_hidden_states=True)
model = BertModel.from_pretrained(model_name, config=config)
tokenizer = BertTokenizer.from_pretrained(model_name)

### Tokenize and encode two different sentences containing the same word

In [None]:
# Two short sentences that both contain the word "love".
text1 = "I love her much"
text2 = "Much love for her"

# Encode each sentence on its own (no padding) as PyTorch tensors.
inputs1, inputs2 = (
    tokenizer(sentence, return_tensors="pt", padding=False, truncation=True)
    for sentence in (text1, text2)
)

### Extract word embeddings

In [None]:
# Look up the word-embedding rows for every token id of both sentences.
# Gradients are not needed for inspection, so the lookup runs under no_grad.
with torch.no_grad():
    outputs1, outputs2 = (
        model.embeddings.word_embeddings(encoded["input_ids"])
        for encoded in (inputs1, inputs2)
    )

### Find the vocabulary id of the token 'love'

In [None]:
# Vocabulary id of the token 'love'. Despite the name, this is a token id and
# not a position in a sentence; the same name is rebound to a sequence
# position further down in this file.
love_index = tokenizer.convert_tokens_to_ids('love')
love_index

### Assuming 'love' is present in both sentences, compare embeddings

In [None]:
# Boolean masks over the token axis marking where the 'love' id occurs.
is_love1 = inputs1["input_ids"][0] == love_index
is_love2 = inputs2["input_ids"][0] == love_index

# Keep only the masked positions from each sentence's word embeddings.
embedding1 = outputs1[:, is_love1]
embedding2 = outputs2[:, is_love2]

# Report whether the two selected embeddings agree within 1e-6.
print(torch.allclose(embedding1, embedding2, atol=1e-6))

In [None]:
# Shape of the slice selected for 'love' in the first sentence.
embedding1.shape

# Positional Embeddings

In [None]:
# Position of 'love' inside the encoded first sentence, looked up from the
# token ids instead of being hard-coded. From here on `love_index` is a
# sequence position, not a vocabulary id. The sentence encoded further down
# ("I love her much") has the same text as `text1`, so the position carries over.
love_index = inputs1["input_ids"][0].tolist().index(tokenizer.convert_tokens_to_ids('love'))

In [None]:
# Position-embedding table of the model (one row per position index).
positional_embeddings = model.embeddings.position_embeddings.weight
# Row for the position at which 'love' sits in the sentence.
pos_embedding = positional_embeddings[love_index, :]

In [None]:
# Shape of the single position-embedding row selected above.
pos_embedding.shape

# Attention Layer

In [None]:
# Encode the sentence whose forward pass is dissected in the rest of the file.
text = "I love her much"
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
input_ids = inputs["input_ids"]

# Inspection only: no autograd graph is needed for any of these lookups.
with torch.no_grad():
    word_embeddings = model.embeddings.word_embeddings(input_ids)
    outputs = model(input_ids)
    embeddings = model.embeddings(input_ids=input_ids)
    # Sanity check: the first returned hidden state is the embedding-module
    # output. The field is read by name because positional indexing into the
    # output object depends on which optional fields happen to be populated.
    assert torch.equal(embeddings, outputs.hidden_states[0])

# Word embedding plus position embedding for the 'love' position only
# (token-type embedding and LayerNorm are deliberately not applied here).
manual_sum = word_embeddings[0, love_index] + positional_embeddings[love_index]

In [None]:
# Compare the embedding-module output at the 'love' position with the plain
# word + position sum; that sum omits the token-type embedding and LayerNorm.
print(torch.allclose(embeddings[0][love_index], manual_sum, atol=1e-6))

## Why doesn't the manual sum match?

In [None]:
# Display the model object so its sub-modules can be inspected.
model

### In BERT, we also need to add the token-type embedding and apply LayerNorm

In [None]:
# Add row 0 of the token-type table. NOTE(review): presumably 0 is the segment
# id the model uses when no token_type_ids are passed — confirm against the
# library docs; the allclose check in the next cell exercises this.
manual_sum_with_type = manual_sum + model.embeddings.token_type_embeddings.weight[0]
# Apply the embedding module's own LayerNorm to the three-way sum.
manual_sum_with_layer_norm = model.embeddings.LayerNorm(manual_sum_with_type)

In [None]:
# Same comparison as before, now with the token-type embedding added and
# LayerNorm applied to the manual sum.
print(torch.allclose(embeddings[0][love_index], manual_sum_with_layer_norm, atol=1e-6))

### How does LayerNorm work?

In [None]:
# Re-implement LayerNorm by hand and compare it with the module's output.
# The learned scale, the learned shift and the numerical-stability epsilon are
# all read from the module, so the manual formula uses exactly the values the
# module was configured with instead of a hard-coded constant.
gamma = model.embeddings.LayerNorm.weight
beta = model.embeddings.LayerNorm.bias
epsilon = model.embeddings.LayerNorm.eps

# Statistics over the feature (last) axis; biased variance, as in LayerNorm.
mean = manual_sum_with_type.mean(dim=-1, keepdim=True)
variance = manual_sum_with_type.var(dim=-1, keepdim=True, unbiased=False)

normalized = (manual_sum_with_type - mean) / torch.sqrt(variance + epsilon)

layer_normed_embeddings = gamma * normalized + beta
# NOTE(review): atol=1e-3 is looser than the 1e-6 used by the other checks in
# this file; kept as-is so the printed result is unchanged.
print(torch.allclose(model.embeddings.LayerNorm(manual_sum_with_type), layer_normed_embeddings, atol=1e-3))

# Now, let's compute the attention scores

In [None]:
# Slice the first encoder layer's query/key/value projections down to a
# single attention head.
# The head count comes from the model config rather than a literal, so the
# slicing stays correct for checkpoints with a different number of heads.
num_heads = model.config.num_attention_heads
head = 0  # which head to inspect

# Each head owns a contiguous band of `head_size` output rows.
head_size = model.config.hidden_size // num_heads
rows_start = head * head_size
rows_end = (head + 1) * head_size

self_attention = model.encoder.layer[0].attention.self

q_weight = self_attention.query.weight
k_weight = self_attention.key.weight
v_weight = self_attention.value.weight

q_bias = self_attention.query.bias
k_bias = self_attention.key.bias
v_bias = self_attention.value.bias

# Rows of the weight matrices (and matching bias entries) for this head.
q_weight_head = q_weight[rows_start:rows_end]
k_weight_head = k_weight[rows_start:rows_end]
v_weight_head = v_weight[rows_start:rows_end]

q_bias_head = q_bias[rows_start:rows_end]
k_bias_head = k_bias[rows_start:rows_end]
v_bias_head = v_bias[rows_start:rows_end]

In [None]:
# Project the layer input through the selected head's slice of each linear
# map: input @ W_head^T + b_head.
queries_head = embeddings @ q_weight_head.T + q_bias_head
keys_head = embeddings @ k_weight_head.T + k_bias_head
values_head = embeddings @ v_weight_head.T + v_bias_head

In [None]:
# Query-key dot products for the selected head, divided by sqrt(head_size),
# then normalised over the key axis with a softmax.
attention_scores_head = (
    queries_head @ keys_head.transpose(-2, -1)
) / torch.sqrt(torch.tensor(head_size, dtype=torch.float32))
attention_probs_head = attention_scores_head.softmax(dim=-1)

In [None]:
# Ask the model itself for its attention maps. This rebinds `outputs`; later
# cells read the hidden states from this result. The call only inspects
# values, so it runs under no_grad like the earlier forward pass.
with torch.no_grad():
    outputs = model(inputs['input_ids'], output_attentions=True)
# Despite the name, this is what the softmax output computed by hand is
# compared against below.
direct_attention_scores = outputs.attentions[0]  # For the first layer

In [None]:
# Attention map of the first batch element for the selected head.
direct_attention_scores_head = direct_attention_scores[0][head]

# Check the hand-computed softmax output against the model's attention map.
print(torch.allclose(attention_probs_head, direct_attention_scores_head, atol=1e-6))

## Nice! :)

### Now, let's compute the full output after the multi-head attention

In [None]:
# Value projection for all heads at once, using the full weight and bias.
values = embeddings @ v_weight.T + v_bias
values.shape

In [None]:
# Split the last axis into (num_heads, head_size) and move the head axis in
# front of the token axis.
values_in_heads = values.view(*values.shape[:2], num_heads, head_size).transpose(1, 2)

In [None]:
# Print both shapes side by side before multiplying the two tensors.
print(values_in_heads.shape)
print(direct_attention_scores.shape)

In [None]:
# The matrix product acts on the last two axes; the leading axes are treated
# as batch dimensions, so all heads are handled in one call.
weighted_values_per_head = direct_attention_scores @ values_in_heads
weighted_values_per_head.shape

In [None]:
# Cross-check: the batched result for the selected head must equal the product
# computed from that head's own attention map and value slice. The left-hand
# side is indexed with `head` (not a literal 0) so both sides always refer to
# the same head.
torch.allclose(
    weighted_values_per_head[0][head],
    torch.matmul(direct_attention_scores_head, values_head[0]),
    atol=1e-6,
)

In [None]:
# Merge the heads back into one feature axis: move the token axis in front of
# the head axis, then flatten (head, head_size) into a single dimension.
# Flattening from dim 2 avoids hard-coding the batch size and sequence length.
post_attention = weighted_values_per_head.permute(0, 2, 1, 3).flatten(start_dim=2)

# After merging, the selected head's output for the first token must sit in
# that head's band of features (rows_start:rows_end).
result = torch.allclose(
    post_attention[0, 0, rows_start:rows_end],
    weighted_values_per_head[0, head, 0],
    atol=1e-6,
)

result, post_attention.shape

### Let's find the embeddings at layer 1 now!

In [None]:
# Display the model object again to look up the next sub-modules to apply.
model

## We need to apply a fully connected layer, add the residual, and apply another LayerNorm

In [None]:
# Output sub-module of the first layer's attention block: a dense projection
# and a LayerNorm.
attention_output = model.encoder.layer[0].attention.output
fully_connected = attention_output.dense
layer_norm = attention_output.LayerNorm

In [None]:
# Dense projection of the merged heads ...
fc1 = fully_connected(post_attention)
# ... plus the residual connection back to the layer input ...
fc2 = fc1 + embeddings
# ... followed by the attention block's LayerNorm.
fc3 = layer_norm(fc2)

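## BERT also has a feed-forward block: an intermediate dense layer with an activation, then an output dense layer with a residual connection and LayerNorm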

In [None]:
# Intermediate sub-module of the first layer: a dense layer followed by the
# layer's activation function.
ffn_intermediate = model.encoder.layer[0].intermediate
fully_connected_intermediate = ffn_intermediate.dense
activation = ffn_intermediate.intermediate_act_fn

# Feed the post-attention, layer-normed tensor through both steps.
intermediate1 = fully_connected_intermediate(fc3)
intermediate2 = activation(intermediate1)

In [None]:
# Output sub-module of the first layer: dense projection and a LayerNorm.
ffn_output = model.encoder.layer[0].output
fully_connected_output = ffn_output.dense
layer_norm_output = ffn_output.LayerNorm

# Project the activated tensor, add the residual from the attention block's
# output (fc3), and normalise.
output1 = fully_connected_output(intermediate2)
output2 = fc3 + output1
output3 = layer_norm_output(output2)

In [None]:
# Shape of the hand-computed result for the first encoder layer.
output3.shape

### Remember that we had the following

In [None]:
# Reminder only: the guard is always false, so this body never executes. It
# mirrors the earlier cell that produced `outputs` and `embeddings` and checked
# that the first hidden state equals the embedding-module output.
if False:
    outputs = model(input_ids)
    embeddings = model.embeddings(input_ids=input_ids)
    assert (embeddings == outputs[2][0]).all()

In [None]:
# List the fields available on the most recent model output.
outputs.keys()

In [None]:
# Hidden state right after the first encoder layer. Entry 0 of the hidden
# states is the embedding-module output (asserted earlier in this file), so
# entry 1 is the next one. The field is read by name because positional
# indexing into the output object depends on which optional fields are present.
embeddings_layer1 = outputs.hidden_states[1]
embeddings_layer1.shape

In [None]:
# Final check: does the hand-computed first-layer output match the model's own
# hidden state within an absolute tolerance of 1e-5?
did_we_make_it = torch.allclose(output3, embeddings_layer1, atol=1e-5)
did_we_make_it  # displayed as the cell result