In [1]:
from numpy import array
from numpy import random
from numpy import dot
from scipy.special import softmax

In [2]:
# Encoder representations of four different words.
# In actual practice, these word embeddings would have been generated by an
# encoder; for this particular example they are defined manually as
# 3-dimensional vectors.
word_1 = array([1, 0, 0])
word_2 = array([0, 1, 0])
word_3 = array([1, 1, 0])
word_4 = array([0, 0, 1])

In [3]:
# Stack the four embeddings row-wise into one matrix: row i holds the
# embedding of word i+1, and each column is one embedding dimension.
words = array((word_1, word_2, word_3, word_4))

In [4]:
# Generate the weight matrices that will later be multiplied with the word
# embeddings to produce the queries, keys, and values.
#
# randint(3, ...) draws integers from {0, 1, 2} (the upper bound is
# exclusive); size=(3, 3) maps 3-dimensional embeddings to 3-dimensional
# queries/keys/values.
#
# The legacy global seed is kept (rather than numpy.random.default_rng) so the
# printed attention output stays reproducible. The three matrices are drawn
# in sequence from the same seeded stream, so reordering these lines would
# change their values.
random.seed(40)
W_Q = random.randint(3, size=(3, 3))
W_K = random.randint(3, size=(3, 3))
W_V = random.randint(3, size=(3, 3))

In [5]:
# Generate the queries, keys, and values: each word embedding (one row of
# `words`) is multiplied by each of the three weight matrices.
Q = words @ W_Q
K = words @ W_K
V = words @ W_V

# Sanity check: there is exactly one query, key, and value row per word.
assert Q.shape[0] == K.shape[0] == V.shape[0] == words.shape[0]

In [6]:
# Score each word's query vector against every word's key vector with a
# dot product: scores[i, j] = dot(Q[i], K[j]). Row i therefore holds the raw
# (unnormalized) scores of word i against all words.
scores = Q @ K.T

# One score per (query word, key word) pair.
assert scores.shape == (Q.shape[0], K.shape[0])

In [7]:
# Convert the scores into attention weights with a row-wise softmax.
# Before the softmax, the scores are divided by sqrt(d_k), the square root of
# the key-vector dimensionality (here, 3). This keeps large dot products from
# saturating the softmax, which keeps the gradients stable.
d_k = K.shape[1]
weights = softmax(scores / d_k ** 0.5, axis=1)

# Each row of weights is a probability distribution over the attended words.
assert (abs(weights.sum(axis=1) - 1) < 1e-9).all()

In [8]:
# The attention output for each word is the weighted sum of all four value
# vectors, using that word's row of attention weights as the coefficients.
attention = weights @ V

# One output row per query word, with the same width as the value vectors.
assert attention.shape == (weights.shape[0], V.shape[1])

In [9]:
# Show the final attention matrix as plain text (one row per word).
print(str(attention))

[[1.68198596 1.68198596 1.57598128]
 [1.16552931 1.16552931 0.41477798]
 [1.50216109 1.50216109 1.05426852]
 [1.08547961 1.08547961 0.22483734]]
