In [14]:
import numpy as np
import math
# Toy dimensions: sequence length L, query/key width d_k, value width d_v.
L, d_k, d_v = 4, 8, 8

# Seed the global NumPy RNG so Restart & Run All reproduces the same matrices.
np.random.seed(0)

q = np.random.randn(L, d_k)
k = np.random.randn(L, d_k)
# Values use their own width d_v (it only coincidentally equals d_k here).
v = np.random.randn(L, d_v)

In [15]:
# Show each input matrix under its label.
for label, matrix in (("Q", q), ("K", k), ("V", v)):
    print(f"{label}\n", matrix)

Q
 [[ 1.76572456 -0.09270103  1.85045424  1.93527855 -1.70270076  0.88448396
   0.63425415 -0.01749719]
 [ 1.07486017  0.06255792 -0.88694926  0.40140468 -0.67338265 -0.59566549
  -0.57667649  0.08293533]
 [ 0.9830268  -1.29905319 -1.27460438 -0.07483937 -0.00826813 -0.78283928
   0.16372286  0.61971329]
 [-0.88467336 -0.5604085   0.94750255 -1.62358875 -1.0599029  -0.6144543
  -0.0409876   0.16435924]]
K
 [[ 1.89160877  1.174909    1.87040293  0.48762916 -1.09535518  0.55980148
  -1.47887525  2.27295813]
 [ 1.30067442  1.11379992 -1.24644673 -2.18297141  0.35262162 -1.05881751
   1.20386723 -0.76781893]
 [ 0.34492372 -0.75682887 -0.5320155  -0.5253843  -1.47270244 -0.4158723
  -1.17572458  0.85470216]
 [ 1.53277875  0.84089595 -0.47458492 -0.15825379 -0.91495678  0.78670665
  -0.77250063 -1.30371231]]
V
 [[ 0.76749538 -0.54523642  1.50417453  1.44480733 -0.59631405  0.48483205
   0.02024295 -0.31996345]
 [ 0.92058135 -0.64767917  1.09034415 -0.37986402 -0.09516703 -0.04966115
   0.626

## Attention

In [16]:
# Raw (unscaled) attention scores: dot product of every query with every key.
q @ k.T

array([[ 9.01838249, -5.09769172,  0.0570399 ,  3.23062692],
       [ 2.0889788 ,  1.33233073,  2.57268894,  2.54239762],
       [-1.35000616,  2.13105255,  2.71457389, -0.5115581 ],
       [-0.10018764,  0.68968995,  2.47303495, -1.71622925]])

In [17]:
# Compare the variance of the inputs with the variance of the raw scores.
np.var(q), np.var(k), np.var(q @ k.T)

(0.9080552617780995, 1.3393643336317367, 8.458295994928855)

In [18]:
# Divide the scores by sqrt(d_k) to bring their variance back down.
scaled = (q @ k.T) / math.sqrt(d_k)

In [19]:
# Same comparison as before, now against the scaled scores.
np.var(q), np.var(k), np.var(scaled)

(0.9080552617780995, 1.3393643336317367, 1.0572869993661067)

## Masking
- Masking ensures that a position cannot draw context from tokens that come later in the sequence (i.e. tokens that have not been generated yet).
- It is not required in the encoder, but it is required in the decoder.

In [20]:
# L x L float matrix with ones on and below the main diagonal, zeros above.
mask = np.tri(L)
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [21]:
# Build the additive causal mask in a single step: 0 where attention is allowed
# (on/below the diagonal), -inf where it is blocked (above the diagonal).
# Rebuilding from L rather than rewriting `mask` in place keeps this cell
# idempotent: the in-place version turned every entry into -inf when re-run.
# `np.inf` is used because the `np.infty` alias was removed in NumPy 2.0.
mask = np.where(np.tril(np.ones((L, L))) == 1, 0.0, -np.inf)

In [22]:
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [23]:
# Adding the mask pushes every blocked (future) position to -inf.
np.add(scaled, mask)

array([[ 3.18847971,        -inf,        -inf,        -inf],
       [ 0.73856554,  0.47105005,        -inf,        -inf],
       [-0.47729925,  0.75344085,  0.9597468 ,        -inf],
       [-0.03542168,  0.24384222,  0.87434989, -0.60677867]])

## Softmax

In [24]:
def softmax(x, axis=-1):
    """Numerically stable softmax along ``axis``.

    Parameters
    ----------
    x : array_like
        Input scores; may contain ``-inf`` entries (e.g. masked positions),
        which receive a weight of exactly 0.
    axis : int, optional
        Axis along which the weights are normalised to sum to 1.
        Defaults to the last axis.

    Returns
    -------
    numpy.ndarray
        Array of the same shape as ``x`` holding the softmax weights.

    Notes
    -----
    A slice made up entirely of ``-inf`` yields NaN, because there is no
    finite score to normalise against.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalise; np.max would raise on a zero-length axis.
        return np.exp(x)
    # Subtracting the per-slice maximum leaves the result mathematically
    # unchanged but keeps exp() from overflowing on large scores.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(shifted)
    # keepdims makes the denominator broadcast correctly for any rank.
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

In [25]:
# Masked attention weights: each row sums to 1 and ignores future positions.
attention = softmax(np.add(scaled, mask))
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.56648286, 0.43351714, 0.        , 0.        ],
       [0.11584789, 0.39663565, 0.48751646, 0.        ],
       [0.18619651, 0.24618088, 0.46246654, 0.10515607]])

In [26]:
# Unmasked attention weights, for comparison with the masked version above:
# here every position also puts weight on later positions. The last row is
# the same in both, because the last row of the mask is all zeros.
softmax(scaled)

array([[0.84883198, 0.00577233, 0.03571419, 0.10968149],
       [0.24238535, 0.18549229, 0.28759297, 0.2845294 ],
       [0.10022987, 0.34316327, 0.421792  , 0.13481486],
       [0.18619651, 0.24618088, 0.46246654, 0.10515607]])

In [27]:
# Each output row is the attention-weighted average of the rows of v.
new_v = attention @ v

In [28]:
new_v

array([[ 0.76749538, -0.54523642,  1.50417453,  1.44480733, -0.59631405,
         0.48483205,  0.02024295, -0.31996345],
       [ 0.83386077, -0.58964711,  1.32477197,  0.65378102, -0.37905823,
         0.25312009,  0.28309717, -0.47802678],
       [ 0.71271418, -0.39142579,  0.4432656 ,  0.20315135,  0.51673106,
        -0.39047738, -0.05402122, -0.43594749],
       [ 0.73016552, -0.3826378 ,  0.49985682,  0.4817085 ,  0.41970286,
        -0.33424314,  0.02736255, -0.16400372]])

In [29]:
v

array([[ 0.76749538, -0.54523642,  1.50417453,  1.44480733, -0.59631405,
         0.48483205,  0.02024295, -0.31996345],
       [ 0.92058135, -0.64767917,  1.09034415, -0.37986402, -0.09516703,
        -0.04966115,  0.62657242, -0.68457032],
       [ 0.53057917, -0.14639229, -0.33528961,  0.38243035,  1.27905306,
        -0.87575866, -0.6253887 , -0.26123303],
       [ 1.09604623, -0.51322684,  1.01204913,  1.23002346, -0.35524795,
        -0.069254  ,  1.50789757,  1.75845304]])

## Function 

In [30]:
def softmax(x, axis=-1):
    """Numerically stable softmax along ``axis`` (default: last axis).

    ``-inf`` entries (masked positions) receive a weight of exactly 0.
    A slice made up entirely of ``-inf`` yields NaN.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalise; np.max would raise on a zero-length axis.
        return np.exp(x)
    # Subtracting the per-slice maximum keeps exp() from overflowing.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(shifted)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)


def scaled_dot_product_attention(q, k, v, mask=None):
    """Compute scaled dot-product attention.

    Parameters
    ----------
    q : numpy.ndarray
        Queries of shape ``(..., L_q, d_k)``.
    k : numpy.ndarray
        Keys of shape ``(..., L_k, d_k)``.
    v : numpy.ndarray
        Values of shape ``(..., L_k, d_v)``.
    mask : numpy.ndarray, optional
        Additive mask broadcastable to ``(..., L_q, L_k)``: 0 where attention
        is allowed and ``-inf`` where it is blocked. ``None`` disables masking.

    Returns
    -------
    out : numpy.ndarray
        Attention-weighted values, shape ``(..., L_q, d_v)``.
    attention : numpy.ndarray
        Attention weights, shape ``(..., L_q, L_k)``; each row sums to 1.
    """
    d_k = q.shape[-1]
    # Transpose only the last two axes so leading batch/head axes are kept;
    # a plain `.T` would reverse every axis. A 1-D `k` is left as-is, which
    # is what `.T` did for it.
    k_t = np.swapaxes(k, -1, -2) if k.ndim >= 2 else k
    # Dividing by sqrt(d_k) keeps the score variance from growing with d_k.
    scores = np.matmul(q, k_t) / math.sqrt(d_k)
    if mask is not None:
        scores = scores + mask
    attention = softmax(scores)
    out = np.matmul(attention, v)
    return out, attention

In [31]:
# Run the full attention computation with the causal mask, then show the
# inputs alongside the results.
values, attention = scaled_dot_product_attention(q, k, v, mask=mask)
for label, matrix in (
    ("Q", q),
    ("K", k),
    ("V", v),
    ("New V", values),
    ("Attention", attention),
):
    print(f"{label}\n", matrix)

Q
 [[ 1.76572456 -0.09270103  1.85045424  1.93527855 -1.70270076  0.88448396
   0.63425415 -0.01749719]
 [ 1.07486017  0.06255792 -0.88694926  0.40140468 -0.67338265 -0.59566549
  -0.57667649  0.08293533]
 [ 0.9830268  -1.29905319 -1.27460438 -0.07483937 -0.00826813 -0.78283928
   0.16372286  0.61971329]
 [-0.88467336 -0.5604085   0.94750255 -1.62358875 -1.0599029  -0.6144543
  -0.0409876   0.16435924]]
K
 [[ 1.89160877  1.174909    1.87040293  0.48762916 -1.09535518  0.55980148
  -1.47887525  2.27295813]
 [ 1.30067442  1.11379992 -1.24644673 -2.18297141  0.35262162 -1.05881751
   1.20386723 -0.76781893]
 [ 0.34492372 -0.75682887 -0.5320155  -0.5253843  -1.47270244 -0.4158723
  -1.17572458  0.85470216]
 [ 1.53277875  0.84089595 -0.47458492 -0.15825379 -0.91495678  0.78670665
  -0.77250063 -1.30371231]]
V
 [[ 0.76749538 -0.54523642  1.50417453  1.44480733 -0.59631405  0.48483205
   0.02024295 -0.31996345]
 [ 0.92058135 -0.64767917  1.09034415 -0.37986402 -0.09516703 -0.04966115
   0.626