### Scaled dot-product attention

The purpose of self-attention is to enrich the input features. In sequential data, besides the individual elements, the relationships among them carry a lot of information. These relationships are commonly called the context. In Donald Knuth's famous quote "premature optimization is the root of all evil", the word "evil" is better understood in the context of "premature optimization".

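The context is computed from dot products between projections of the input. Each input vector is projected into a query, a key and a value. The dot product of a query with every key scores how strongly two positions relate. After scaling and a row-wise softmax, these scores become weights for a weighted sum of the values.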
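In matrix form this is the standard scaled dot-product attention formula. Here $X$ is the input (SEQ_LEN × EMBD_LEN) and $W_q, W_k, W_v$ are weight matrices stored as (output dimension × EMBD_LEN), matching the shapes used in the code. The symbol $d_k$ is the key dimension (`k_d` in the code), and the softmax is applied to each row:

$$
Q = X W_q^\top, \qquad K = X W_k^\top, \qquad V = X W_v^\top
$$

$$
\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V
$$

With SEQ_LEN $= 5$, $d_k = 20$ and $d_v = 25$, the score matrix $QK^\top$ is $5 \times 5$ and the output is $5 \times 25$, one context vector per position.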


```python
import numpy as np

SEQ_LEN  = 5
EMBD_LEN = 10

# input matrix: one EMBD_LEN-dimensional embedding per sequence position
x = np.random.normal(size=(SEQ_LEN, EMBD_LEN))

# dimensions of q, k and v
q_d = k_d = 20  # query and key dimensions must match for their dot product
v_d = 25        # dimension of the value projection

# weight matrices
wq = np.random.normal(size=(q_d, EMBD_LEN))
wk = np.random.normal(size=(k_d, EMBD_LEN))
wv = np.random.normal(size=(v_d, EMBD_LEN))

print(f"x, input shape {x.shape}")
print(f"wq, q weight matrix shape {wq.shape}")
print(f"wk, k weight matrix shape {wk.shape}")
print(f"wv, v weight matrix shape {wv.shape}")
```

```
x, input shape (5, 10)
wq, q weight matrix shape (20, 10)
wk, k weight matrix shape (20, 10)
wv, v weight matrix shape (25, 10)
```


```python
# projection operation: (d, EMBD_LEN) @ (EMBD_LEN, SEQ_LEN), transposed -> (SEQ_LEN, d)
wqp = np.matmul(wq, x.T).T  # queries
wkp = np.matmul(wk, x.T).T  # keys
wvp = np.matmul(wv, x.T).T  # values

print(f"wqp, query matrix shape {wqp.shape}")
print(f"wkp, key matrix shape {wkp.shape}")
print(f"wvp, value matrix shape {wvp.shape}")
```

```
wqp, query matrix shape (5, 20)
wkp, key matrix shape (5, 20)
wvp, value matrix shape (5, 25)
```
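The transpose-matmul-transpose pattern in the projection is the same as multiplying each row of `x` by the transposed weight matrix, since $(W X^\top)^\top = X W^\top$. Below is a quick check of that identity on its own seeded random data. The seed and the lowercase variable names are illustrative assumptions that mirror the notebook's sizes.

```python
import numpy as np

# Illustrative sizes mirroring SEQ_LEN, EMBD_LEN and q_d
seq_len, embd_len, q_d = 5, 10, 20

rng = np.random.default_rng(0)  # seeded so the check is reproducible
x = rng.normal(size=(seq_len, embd_len))
wq = rng.normal(size=(q_d, embd_len))

# Form used in the notebook: project the transposed input, then transpose back
q_notebook = np.matmul(wq, x.T).T
# Equivalent row-wise form: each row of x times wq^T
q_rowwise = x @ wq.T

print(q_notebook.shape, q_rowwise.shape)  # (5, 20) (5, 20)
print(np.allclose(q_notebook, q_rowwise))  # True
```

The `x @ W.T` form avoids the double transpose and extends directly to batched inputs of shape (batch, seq_len, embd_len).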


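```python
# score calculation
def softmax(z):
    # row-wise softmax; subtracting each row's max keeps np.exp from overflowing
    e_z = np.exp(z - z.max(axis=1, keepdims=True))
    return e_z / e_z.sum(axis=1, keepdims=True)

# raw scores: dot product of every query with every key -> (SEQ_LEN, SEQ_LEN)
score = np.matmul(wqp, wkp.T)

print(f"score shape {score.shape}")
# scale by the square root of the key dimension, then turn each row into weights
scaled_score = score / np.sqrt(k_d)
scaled_softmax_score = softmax(scaled_score)
```

```
score shape (5, 5)
```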
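The usual motivation for dividing by the square root of the key dimension goes like this. Suppose query and key components are independent with zero mean and unit variance. Then their dot product over $d_k$ components has variance $d_k$, so raw scores grow with the dimension and push the softmax toward a near one-hot output. Dividing by $\sqrt{d_k}$ brings the variance back to about 1. The sketch below checks this numerically. The unit-variance inputs, the seed and the sample count are idealized assumptions; the notebook's projected q and k are not unit-variance.

```python
import numpy as np

rng = np.random.default_rng(42)  # illustrative seed
n_samples = 10_000

results = {}
for d_k in (4, 64, 512):
    # idealized independent, unit-variance query and key vectors
    q = rng.standard_normal((n_samples, d_k))
    k = rng.standard_normal((n_samples, d_k))
    dots = np.einsum("ij,ij->i", q, k)  # one dot product per sample
    scaled = dots / np.sqrt(d_k)
    results[d_k] = (dots.var(), scaled.var())
    print(f"d_k={d_k:4d}  var(q.k)={dots.var():8.2f}  var(q.k/sqrt(d_k))={scaled.var():.3f}")
```

The raw variance tracks $d_k$, while the scaled variance stays close to 1 for every dimension.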


```python
# (SEQ_LEN, SEQ_LEN) @ (SEQ_LEN, v_d) -> (SEQ_LEN, v_d): one context vector per token;
# summing over axis 0 then pools them into a single v_d-dimensional vector
context_vector = np.sum(np.matmul(scaled_softmax_score, wvp), axis=0)
context_vector
```

```
array([  2.83474113,   5.01490124,  -0.08648265,  -2.21421563,
         1.90748451,   5.32280097,  -7.62641771,  -2.64733799,
       -10.30350218,   7.95395517,   3.18308381,   4.1216483 ,
        -7.71829382, -10.80888751,  -3.56157144,  -9.99145609,
         1.45420529,  10.05256878,  -9.57323371,   5.71147096,
        -4.5186982 ,   1.80217433,   1.17459313,  -2.42544073,
        -0.58961725])
```
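The product `np.matmul(scaled_softmax_score, wvp)` already holds one context vector per token, with shape (SEQ_LEN, v_d). The `np.sum(..., axis=0)` then pools these into the single 25-dimensional vector shown. Standard scaled dot-product attention returns the per-token matrix.

Below is a minimal self-contained sketch of that, using a numerically stable softmax. The function name `scaled_dot_product_attention`, the seed and the sizes are illustrative assumptions, not part of the notebook.

```python
import numpy as np

def softmax(z):
    # row-wise softmax; subtracting the row max avoids overflow in np.exp
    e_z = np.exp(z - z.max(axis=1, keepdims=True))
    return e_z / e_z.sum(axis=1, keepdims=True)

def scaled_dot_product_attention(q, k, v):
    """Return (context, weights) where context = softmax(q k^T / sqrt(d_k)) v."""
    d_k = k.shape[-1]
    weights = softmax(q @ k.T / np.sqrt(d_k))  # (seq_len, seq_len), rows sum to 1
    return weights @ v, weights                # context: (seq_len, d_v)

rng = np.random.default_rng(7)  # illustrative seed and sizes
seq_len, embd_len, qk_d, v_d = 5, 10, 20, 25
x = rng.normal(size=(seq_len, embd_len))
wq = rng.normal(size=(qk_d, embd_len))
wk = rng.normal(size=(qk_d, embd_len))
wv = rng.normal(size=(v_d, embd_len))

context, weights = scaled_dot_product_attention(x @ wq.T, x @ wk.T, x @ wv.T)
print(context.shape)                          # (5, 25): one context vector per token
print(np.allclose(weights.sum(axis=1), 1.0))  # True

# pooling the rows reproduces the single vector computed in the notebook cell
pooled = context.sum(axis=0)
print(pooled.shape)                           # (25,)
```

Row `i` of `context` is the enriched representation of token `i`. Keeping the rows separate preserves the per-token output that later layers typically consume.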