In [13]:
# A minimal single-head self-attention computation over a randomly embedded,
# already-encoded input sequence. (The original Transformer architecture uses
# multi-head attention.)

import torch
import torch.nn.functional as F

# A toy encoded sentence: each integer is a token id (all ids are < 10, the
# vocabulary size used by the embedding below).
sentence = torch.tensor([0, 2, 8, 4, 1, 8, 5, 2])

# Embed the sentence with vocab_size=10 and embed_dim=16. Every token id is
# mapped to a 16-dim vector, so the embedded sentence has shape (8, 16).
torch.manual_seed(42)
embed = torch.nn.Embedding(10, 16)
embedded_sentence = embed(sentence)

# Step 1: define the query, key and value projection weights.
# For an input sequence of shape (n, d), each weight matrix here has shape
# (d, d). Every input row (one token's d-dim embedding) is multiplied by these
# matrices to produce that token's query, key and value vectors.
# Here n = 8 and d = 16.
d = embedded_sentence.shape[1]
torch.manual_seed(42)
W_query = torch.rand(d, d)
W_key = torch.rand(d, d)
W_value = torch.rand(d, d)

# Step 2: project the whole sequence at once; each result has shape (n, d).
queries = torch.matmul(embedded_sentence, W_query)
keys = torch.matmul(embedded_sentence, W_key)
values = torch.matmul(embedded_sentence, W_value)

# Step 3: unnormalised alignment scores. omega[i, j] is the dot product of
# query i with key j (matmul is just a batch of dot products). It measures how
# strongly token i relates to token j across the entire sequence.
# Shape: (n, n).
omega = torch.matmul(queries, keys.T)

# Step 4: scale by sqrt(d_k), where d_k is the KEY dimension, not the
# embedding dimension. The two coincide here only because W_key is (d, d).
# Reading the size from `keys` keeps the scaling correct if the key
# projection is ever given a different output size.
# Softmax over the last dim (the key axis) turns each row into a probability
# distribution over the sequence.
d_k = keys.shape[-1]
attn_weights = F.softmax(omega / d_k**0.5, dim=-1)
# Sanity check: every row of the attention matrix should sum to 1.
print(attn_weights.sum(dim=1))

# Step 5: the self-attention output is the attention-weighted sum of the value
# vectors. z has shape (n, d): a context-aware re-representation of the input
# sequence.
z = torch.matmul(attn_weights, values)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<SumBackward1>)
