## Single Head Attention

In [11]:
import torch
import torch.nn as nn
from typing import List

In [6]:
def get_input_embeddings(words: List[str], embeddings_dim: int):
    """Build one random stand-in embedding vector per word.

    The vectors are sampled with ``torch.randn`` and do not depend on the
    words themselves; they only stand in for real embeddings so that the
    attention steps can be demonstrated. Normally a trained tokenizer /
    embedding pipeline would supply them (see the blog post on tokenizers).

    Args:
        words: Words of the input text; only the number of words is used.
        embeddings_dim: Length of each generated vector.

    Returns:
        A list with one 1-D tensor of size ``embeddings_dim`` per word.
    """
    vectors = []
    for _ in words:
        vectors.append(torch.randn(embeddings_dim))
    return vectors

In [7]:
# Example sentence, split on single spaces so that each word is one token.
text = "I should sleep now"
words = text.split(sep=" ")
len(words)  # 4

4

In [12]:
# Give every word a random vector of size embeddings_dim.
embeddings_dim = 512
embeddings = get_input_embeddings(words, embeddings_dim)
embeddings[0].shape  # torch.Size([512])

torch.Size([512])

In [20]:
# Initialize the query, key and value matrices: one linear layer each,
# mapping an embeddings_dim vector to another embeddings_dim vector.
query_matrix, key_matrix, value_matrix = (
    nn.Linear(embeddings_dim, embeddings_dim) for _ in range(3)
)
tuple(layer.weight.shape for layer in (query_matrix, key_matrix, value_matrix))
# torch.Size([512, 512]), torch.Size([512, 512]), torch.Size([512, 512])

(torch.Size([512, 512]), torch.Size([512, 512]), torch.Size([512, 512]))

In [29]:
# Project every word embedding into its query, key and value vector, then
# stack the per-word results into one tensor per role.
query_vectors = torch.stack(list(map(query_matrix, embeddings)))
key_vectors = torch.stack(list(map(key_matrix, embeddings)))
value_vectors = torch.stack(list(map(value_matrix, embeddings)))
query_vectors.shape, key_vectors.shape, value_vectors.shape
# torch.Size([4, 512]), torch.Size([4, 512]), torch.Size([4, 512])

(torch.Size([4, 512]), torch.Size([4, 512]), torch.Size([4, 512]))

In [58]:
# Swapping the two axes of the keys gives the layout needed for Q @ K^T.
key_vectors.transpose(0, 1).shape

torch.Size([512, 4])

In [34]:
# Scaled dot-product scores: every query against every key,
# divided by sqrt(embeddings_dim).
scale = torch.sqrt(torch.tensor(embeddings_dim, dtype=torch.float32))
scores = torch.matmul(query_vectors, key_vectors.transpose(0, 1)) / scale
scores.shape  # torch.Size([4, 4])

torch.Size([4, 4])

In [40]:
# Turn the scores into attention weights: for each word, a distribution
# over all the words (softmax along the last axis).
softmax = nn.Softmax(dim=-1)  # normalises across the key axis
attention_weights = softmax(scores)
attention_weights.shape  # torch.Size([4, 4])

torch.Size([4, 4])

In [42]:
# Attention output: each word's result is the weighted sum of the value vectors.
output = attention_weights @ value_vectors
output.shape  # torch.Size([4, 512])

torch.Size([4, 512])

## Multi-Head Attention

The computations are the same as in single-head attention up to and including the query, key and value vector computation.

In [54]:
# Only one text is processed, so the batch dimension is 1.
batch_size = 1
# Number of attention heads.
num_heads = 8

In [44]:
# Same example sentence as in the single-head section, one token per word.
text = "I should sleep now"
words = text.split(sep=" ")
len(words)  # 4

4

In [45]:
# Fresh random vectors of size embeddings_dim, one per word.
embeddings_dim = 512
embeddings = get_input_embeddings(words, embeddings_dim)
embeddings[0].shape  # torch.Size([512])

torch.Size([512])

In [46]:
# Initialize the query, key and value matrices (one linear layer each);
# the heads are carved out of these full-size projections later.
query_matrix, key_matrix, value_matrix = (
    nn.Linear(embeddings_dim, embeddings_dim) for _ in range(3)
)
tuple(layer.weight.shape for layer in (query_matrix, key_matrix, value_matrix))
# torch.Size([512, 512]), torch.Size([512, 512]), torch.Size([512, 512])

(torch.Size([512, 512]), torch.Size([512, 512]), torch.Size([512, 512]))

In [47]:
# Query, key and value vectors for every word embedding, stacked per role.
query_vectors = torch.stack(list(map(query_matrix, embeddings)))
key_vectors = torch.stack(list(map(key_matrix, embeddings)))
value_vectors = torch.stack(list(map(value_matrix, embeddings)))
query_vectors.shape, key_vectors.shape, value_vectors.shape
# torch.Size([4, 512]), torch.Size([4, 512]), torch.Size([4, 512])

(torch.Size([4, 512]), torch.Size([4, 512]), torch.Size([4, 512]))

In [55]:
# Split the feature dimension into heads, then move the head axis forward:
#   view      -> (batch_size, seq_len, num_heads, embeddings_dim // num_heads)
#   transpose -> (batch_size, num_heads, seq_len, embeddings_dim // num_heads)
split_shape = (batch_size, -1, num_heads, embeddings_dim // num_heads)
query_vectors_view = query_vectors.view(*split_shape).transpose(1, 2)
key_vectors_view = key_vectors.view(*split_shape).transpose(1, 2)
value_vectors_view = value_vectors.view(*split_shape).transpose(1, 2)
query_vectors_view.shape, key_vectors_view.shape, value_vectors_view.shape
# torch.Size([1, 8, 4, 64]),
#  torch.Size([1, 8, 4, 64]),
#  torch.Size([1, 8, 4, 64])

(torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]))

We are splitting each vector into 8 heads. Assuming we have one text (batch size of 1), each 512-dimensional query, key and value vector is split into 8 parts of 64 features, and each head works on one of these parts. If we do this one head at a time:

In [59]:
# Slices for the first head of the single batch element.
head1_query_vector = query_vectors_view[0, 0]
head1_key_vector = key_vectors_view[0, 0]
head1_value_vector = value_vectors_view[0, 0]
head1_query_vector.shape, head1_key_vector.shape, head1_value_vector.shape

(torch.Size([4, 64]), torch.Size([4, 64]), torch.Size([4, 64]))

In [60]:
# These vectors have the same count as before; only the feature dim
# shrinks from 512 to 64. Compute the scaled scores for the first head,
# dividing by sqrt(embeddings_dim // num_heads).
head_scale = torch.sqrt(torch.tensor(embeddings_dim // num_heads, dtype=torch.float32))
scores_head1 = torch.matmul(head1_query_vector, head1_key_vector.transpose(0, 1)) / head_scale
scores_head1.shape  # torch.Size([4, 4])

torch.Size([4, 4])

In [61]:
# Attention weights of the first head: for each word, a distribution over
# all the words (softmax along the last axis).
softmax = nn.Softmax(dim=-1)  # normalises across the key axis
attention_weights_head1 = softmax(scores_head1)
attention_weights_head1.shape  # torch.Size([4, 4])

torch.Size([4, 4])

In [62]:
# Output of the first head: weighted sum of that head's value vectors.
output_head1 = attention_weights_head1 @ head1_value_vector
output_head1.shape  # torch.Size([4, 64])

torch.Size([4, 64])

In [63]:
# Repeat the single-head computation for every head and collect the results.
# The softmax module and the scale do not change between heads, so they are
# built once before the loop.
softmax = nn.Softmax(dim=-1)
head_scale = torch.sqrt(torch.tensor(embeddings_dim // num_heads, dtype=torch.float32))
outputs = []
for head_idx in range(num_heads):
    # query / key / value slices that belong to this head
    head_idx_query_vector = query_vectors_view[0, head_idx]
    head_idx_key_vector = key_vectors_view[0, head_idx]
    head_idx_value_vector = value_vectors_view[0, head_idx]

    # scaled scores -> attention weights -> weighted sum of the values
    scores_head_idx = torch.matmul(head_idx_query_vector, head_idx_key_vector.transpose(0, 1)) / head_scale
    attention_weights_idx = softmax(scores_head_idx)
    output = attention_weights_idx @ head_idx_value_vector
    outputs.append(output)

In [64]:
# One output tensor per head.
[head_output.shape for head_output in outputs]

[torch.Size([4, 64]),
 torch.Size([4, 64]),
 torch.Size([4, 64]),
 torch.Size([4, 64]),
 torch.Size([4, 64]),
 torch.Size([4, 64]),
 torch.Size([4, 64]),
 torch.Size([4, 64])]

In [75]:
# Concatenate the first word's result from every head into a single vector.
word0_outputs = torch.cat([head_output[0] for head_output in outputs], dim=0)
word0_outputs.shape

torch.Size([512])

In [76]:
# Let's do it for all the words: join each word's slice from every head.
attn_outputs = [
    torch.cat([head_output[i] for head_output in outputs])
    for i in range(len(words))
]
[merged.shape for merged in attn_outputs]

[torch.Size([512]), torch.Size([512]), torch.Size([512]), torch.Size([512])]

In [77]:
# Now let's do it in a vectorized way.
# We can now swap the last two dimensions of the key tensor for all heads at once.
key_vectors_view.transpose(-2, -1).shape

torch.Size([1, 8, 64, 4])

In [78]:
# Scaled dot-product attention weights for all heads at once.
# Transposing the last two dims of the keys gives Q @ K^T with one
# square score matrix per head.
score = torch.matmul(query_vectors_view, key_vectors_view.permute(0, 1, 3, 2))  # Q @ K^T
# Scale by sqrt(embeddings_dim // num_heads), the same factor the per-head
# loop uses; without it the vectorized result does not match the loop.
score = score / torch.sqrt(torch.tensor(embeddings_dim // num_heads, dtype=torch.float32))
# Softmax over the last axis turns the scaled scores into attention weights.
score = torch.softmax(score, dim=-1)

In [82]:
# Apply the attention weights to the value vectors of every head.
attention_results = score @ value_vectors_view
attention_results.shape  # [1, 8, 4, 64]

torch.Size([1, 8, 4, 64])

In [88]:
# Merge the heads: swap the head and word axes, then flatten the per-head
# features back into embeddings_dim.
# NOTE: this overwrites attention_results, so the cell cannot be re-run on
# its own output.
attention_results = attention_results.transpose(1, 2).contiguous().view(batch_size, -1, embeddings_dim)
attention_results.shape

torch.Size([1, 4, 512])