In [13]:
# A simple single-head self-attention calculation on a small, arbitrarily chosen encoded
# input sequence. The actual Transformer architecture uses multi-headed attention.

import torch
import torch.nn.functional as F

# Toy input: a sentence that has already been encoded as 8 token ids
sentence = torch.tensor([0, 2, 8, 4, 1, 8, 5, 2])

# Embedding table with vocab_size=10 and embed_dim=16 (seeded so the run is repeatable)
torch.manual_seed(42)
embed = torch.nn.Embedding(10, 16)
# Look up one 16-dim vector per token id, giving a tensor of shape (8, 16)
embedded_sentence = embed(sentence)

# Step 1: projection weights for the query, key and value roles.
# The input sequence is (n, d) = (8, 16), so every weight matrix is (d, d) = (16, 16):
# multiplying one token's embedding by a weight matrix yields that token's q, k or v vector.
d = embedded_sentence.size(1)
torch.manual_seed(42)
# Drawn in the order query -> key -> value from the freshly seeded generator
W_query, W_key, W_value = (torch.rand(d, d) for _ in range(3))

# Step 2: project every token into its query, key and value vectors, each of shape (8, 16)
queries = embedded_sentence @ W_query
keys = embedded_sentence @ W_key
values = embedded_sentence @ W_value

# Step 3: alignment scores. omega[i, j] is the dot product of query i with key j,
# i.e. how strongly token i relates to token j within the sequence
omega = queries @ keys.transpose(0, 1)

# Step 4: scale by sqrt(d) and apply softmax across the keys so that each row
# becomes a probability distribution; the printed row sums should all be 1
attn_weights = F.softmax(omega / d ** 0.5, dim=-1)
print(attn_weights.sum(dim=1))

# Step 5: output of the self-attention layer, the attention-weighted sum of the value vectors
z = attn_weights @ values

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<SumBackward1>)


In [36]:
# The actual Transformer architecture uses a multi-headed self-attention mechanism,
# which is multiple sets of query, key and value projections. Each head gets its own
# query, key and value projection matrices, which are applied to every item in the sequence

import torch

# Multi-head self-attention on a toy sentence: h independent heads, each with its own
# (d, d) query/key/value projection, all evaluated at once with batched matmuls.

# Seed the RNG so the embedding table and the projection weights (and therefore every
# result below) are reproducible on a fresh kernel
torch.manual_seed(42)

sentence = torch.tensor([2, 1, 9, 3, 7, 0, 3, 6])
embed = torch.nn.Embedding(10, 16)  # vocab_size=10, embed_dim=16
embedded_sentence = embed(sentence)  # shape (n, d) = (8, 16)

print(embedded_sentence.shape)

# define multiple sets of weights to calculate Q, K and V: one (d, d) matrix per head
h = 8 # number of heads, i.e. number of sets of each weights
d = embedded_sentence.shape[1]
multihead_W_query = torch.rand(h, d, d)
multihead_W_key = torch.rand(h, d, d)
multihead_W_value = torch.rand(h, d, d)

# adjusting embedded input shape to be compatible with the weight tensors:
# torch.bmm needs equal batch sizes, so transpose to (d, n) and copy once per head -> (h, d, n)
stacked_inputs = embedded_sentence.T.repeat(h, 1, 1)

# calculate multihead q, k and v:
# (h, d, d) @ (h, d, n) -> (h, d, n), then permute to (h, n, d) so each row is one token
multihead_queries = torch.bmm(multihead_W_query, stacked_inputs).permute(0, 2, 1)
# alias under the original (misspelled) name so any cell that still uses it keeps working
mulithead_queries = multihead_queries
print(multihead_queries.shape)
multihead_keys = torch.bmm(multihead_W_key, stacked_inputs).permute(0, 2, 1)
multihead_values = torch.bmm(multihead_W_value, stacked_inputs).permute(0, 2, 1)

# comparison step: per head, omega[head, i, j] is the dot product of query i with key j
# (h, n, d) @ (h, d, n) -> (h, n, n)
omega = torch.bmm(multihead_queries, multihead_keys.permute(0, 2, 1))
print(omega.shape)

# calculate attention weights using softmax over the LAST axis (the keys).
# omega is (head, query, key); normalizing over dim=1 would normalize across queries,
# and the weighted sum below (which contracts over the key axis) would then use
# weights that do not sum to 1 for a given query.
attn_weights = torch.nn.functional.softmax(omega / d**0.5, dim=-1)
# sanity check: for every head and every query the weights over the keys sum to 1
print(attn_weights.sum(dim=-1))

# calculate self-attention layer output (contexts) as the weighted sum over 'multihead values'
# (h, n, n) @ (h, n, d) -> (h, n, d)
print(attn_weights.shape)
print(multihead_values.shape)
z = torch.bmm(attn_weights, multihead_values)
print(z.shape)

torch.Size([8, 16])
torch.Size([8, 8, 16])
torch.Size([8, 8, 8])
tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]],
       grad_fn=<SumBackward1>)
torch.Size([8, 8, 8])
torch.Size([8, 8, 16])
torch.Size([8, 8, 16])
