In [21]:
import numpy as np
from scipy.special import softmax

# Scaled dot-product attention simulation

In [22]:
# random word embeddings: three words, each a 5-dimensional column vector.
# Seed numpy's global RNG before the first draw so that every np.random.rand
# call in this notebook (embeddings, projection matrices, WO) is reproducible
# when the cells are run in order.
SEED = 42
np.random.seed(SEED)

w1 = np.random.rand(5,1)
w2 = np.random.rand(5,1)
w3 = np.random.rand(5,1)

print(f'w1={w1}\n')
print(f'w2={w2}\n')
print(f'w3={w3}\n')

w1=[[0.69453339]
 [0.56874769]
 [0.67876956]
 [0.46420813]
 [0.72690923]]

w2=[[0.45540857]
 [0.74785557]
 [0.81749043]
 [0.28979171]
 [0.03700459]]

w3=[[0.69380642]
 [0.43719921]
 [0.67627902]
 [0.32564306]
 [0.26195105]]



In [23]:
# random transformation matrices: each maps a 5-dim embedding to a 3-dim query / key / value.
# The generator is consumed left to right, so WQ, WK, WV are drawn in that order.
WQ, WK, WV = (np.random.rand(5, 3) for _ in range(3))
for label, matrix in (('WQ', WQ), ('WK', WK), ('WV', WV)):
    print(f'\n{label}={matrix}\n')


WQ=[[0.85768293 0.8097524  0.1131521 ]
 [0.03434299 0.78957685 0.98330118]
 [0.50379376 0.63112608 0.67818403]
 [0.930376   0.83203011 0.04887063]
 [0.61405312 0.98550047 0.40614657]]


WK=[[0.36919691 0.1361247  0.72053443]
 [0.71822876 0.39686467 0.04528545]
 [0.19157466 0.36145483 0.2659768 ]
 [0.10982799 0.77172553 0.55713648]
 [0.87763333 0.45291319 0.10164459]]


WV=[[0.817592   0.94143668 0.22078388]
 [0.3274662  0.41683922 0.81230439]
 [0.73981798 0.76074838 0.09931974]
 [0.0048496  0.74979256 0.64001566]
 [0.35269625 0.31325836 0.31122616]]



In [25]:
# calculate the queries: each (5, 1) embedding is transposed to a row vector
# and projected through WQ, giving one (1, 3) query per word
q1, q2, q3 = (word.T @ WQ for word in (w1, w2, w3))
for tag, vec in (('q1', q1), ('q2', q2), ('q3', q3)):
    print(f'{tag}={vec}\n')

q1=[[1.83543079 2.54246378 1.41608669]]

q2=[[1.12046437 1.75278063 1.37049824]]

q3=[[1.41460816 1.86292883 1.08935068]]



In [26]:
# calculate the keys: same per-word projection as the queries, but through WK
k1, k2, k3 = [np.matmul(word.T, WK) for word in (w1, w2, w3)]
for tag, vec in zip(('k1', 'k2', 'k3'), (k1, k2, k3)):
    print(f'{tag}={vec}\n')

k1=[[1.48388839 1.25307159 1.03924184]]

k2=[[0.92618096 0.89467521 0.74465286]]

k3=[[0.96537985 0.88234558 0.90763824]]



In [27]:
# calculate values: project each transposed word embedding through WV
v1, v2, v3 = (np.matmul(word.T, WV) for word in (w1, w2, w3))
for tag, vec in (('v1', v1), ('v2', v2), ('v3', v3)):
    print(f'{tag}={vec}\n')

v1=[[1.51488588 1.98307858 1.20608688]]

v2=[[1.23648669 1.59125405 0.9862142 ]]

v3=[[1.30471031 1.67611786 0.86543065]]



In [28]:
# matrix of word embeddings: place the three (5, 1) columns side by side,
# then transpose so that row i of W is word i
W = np.hstack((w1, w2, w3)).T
print(f'\nW={W}\n')


W=[[0.69453339 0.56874769 0.67876956 0.46420813 0.72690923]
 [0.45540857 0.74785557 0.81749043 0.28979171 0.03700459]
 [0.69380642 0.43719921 0.67627902 0.32564306 0.26195105]]



In [29]:
# queries, keys and values all at once using matrix multiplication:
# row i of Q / K / V is the projection of row i of W (word i)
Q = W @ WQ
K = W @ WK
V = W @ WV

# Sanity check: the batched projections must reproduce the per-word vectors
# computed one at a time in the earlier cells (rows stacked in word order).
np.testing.assert_allclose(Q, np.vstack([q1, q2, q3]))
np.testing.assert_allclose(K, np.vstack([k1, k2, k3]))
np.testing.assert_allclose(V, np.vstack([v1, v2, v3]))

print(f'\nQ={Q}\n')
print(f'\nK={K}\n')
print(f'\nV={V}\n')


Q=[[1.83543079 2.54246378 1.41608669]
 [1.12046437 1.75278063 1.37049824]
 [1.41460816 1.86292883 1.08935068]]


K=[[1.48388839 1.25307159 1.03924184]
 [0.92618096 0.89467521 0.74465286]
 [0.96537985 0.88234558 0.90763824]]


V=[[1.51488588 1.98307858 1.20608688]
 [1.23648669 1.59125405 0.9862142 ]
 [1.30471031 1.67611786 0.86543065]]



In [30]:
# output of the attention layer: softmax(Q K^T / sqrt(d_k)) V,
# where the softmax is taken along each row (one row of weights per query)
d_k = K.shape[1]
scores = (Q @ K.T) / np.sqrt(d_k)
weights = softmax(scores, axis=1)
attention = weights @ V
print(f'\nattention={attention}\n')


attention=[[1.42834786 1.85912936 1.10401702]
 [1.40540313 1.82628689 1.07729048]
 [1.41006411 1.8329904  1.08321943]]



# Multi-head attention simulation

In [31]:
# randomly initialize the projection matrices for each head.
# The generator is consumed left to right, so the twelve (5, 3) matrices are
# drawn in the same order as listing them head by head (Q, K, V per head).
(WQ1, WK1, WV1,
 WQ2, WK2, WV2,
 WQ3, WK3, WV3,
 WQ4, WK4, WV4) = (np.random.rand(5, 3) for _ in range(12))

In [32]:
# let's calculate queries, keys and values for each head:
# every head projects the same embedding matrix W with its own weight matrices
Q1, Q2, Q3, Q4 = (W @ proj for proj in (WQ1, WQ2, WQ3, WQ4))
K1, K2, K3, K4 = (W @ proj for proj in (WK1, WK2, WK3, WK4))
V1, V2, V3, V4 = (W @ proj for proj in (WV1, WV2, WV3, WV4))

In [33]:
# attention from each head: the same scaled dot-product formula,
# softmax(Q_h K_h^T / sqrt(d_k)) V_h, applied to each head's (Q, K, V) triple
attention1, attention2, attention3, attention4 = (
    softmax((Qh @ Kh.T) / np.sqrt(Kh.shape[1]), axis=1) @ Vh
    for Qh, Kh, Vh in ((Q1, K1, V1), (Q2, K2, V2), (Q3, K3, V3), (Q4, K4, V4))
)

In [34]:
# concatenate all heads column-wise: one row per word, the four heads' outputs side by side
Z = np.hstack((attention1, attention2, attention3, attention4))
Z.shape

(3, 12)

In [35]:
# multihead projection matrix: maps the concatenated head outputs (Z.shape[1] columns)
# to 3 output features.
# NOTE(review): in Vaswani et al. (2017) W^O projects back to d_model (the embedding
# size, 5 here); 3 matches the single-head attention width above — confirm intent.
n_concat, n_out = Z.shape[1], 3
WO = np.random.rand(n_concat, n_out)

In [37]:
# output of multihead attention: project the concatenated heads Z through WO
multihead = np.matmul(Z, WO)
print(f'\nmultihead={multihead}\n')


multihead=[[8.8701854  9.01219884 8.6135078 ]
 [8.74879861 8.87606941 8.46771228]
 [8.74000155 8.87047045 8.47205605]]

