In [2]:
import numpy as np
import math

In [3]:
L, dk, dv = 4, 8, 8
q = np.random.randn(L, dk)
k = np.random.randn(L, dk)
v = np.random.randn(L, dv)

In [6]:
print("Q: ", q, '\n\n')
print("K: ", k, '\n\n')
print("V: ", v, '\n\n')

Q:  [[ 1.16797418 -0.73724956  0.75041732 -0.1220419   1.70315355  0.31347885
   0.34485802  0.4460984 ]
 [ 3.30107262  0.52084263 -0.30178335  1.6757843  -0.5155369  -2.57282086
   0.7492196  -0.1116739 ]
 [ 1.79346332  1.43356091  1.04335097 -0.17879843  0.64886601 -1.36614847
   0.23380491  0.87536806]
 [ 1.47103375 -1.58489204 -0.45911092  0.28391595  0.21250504 -1.61590866
  -1.29484902  0.00475083]] 


K:  [[-0.27718816 -1.65065268  0.65895024 -0.30877976  0.08718337 -1.25869869
  -0.189555    0.19835477]
 [ 0.34020432  0.52734301  0.70054018  0.02666352 -0.47575991  0.27451574
   0.31178952 -0.80116187]
 [-0.32764932 -0.14549352 -0.40258405  0.75537362  1.05953459 -1.34530308
   1.99241804  0.17674986]
 [-1.40543092  0.43389184 -0.07028518  0.10003431  1.26343057 -2.0707861
   0.51659185 -1.03069964]] 


V:  [[-1.20455614 -0.78141984 -0.60550966 -0.49607676  0.05358656  0.31648808
   1.97222775  0.97455048]
 [-0.22988732 -2.3585375   0.49206019 -0.74539263 -0.530485   -0.2361971

# Self Attention

In [7]:
np.matmul(q, k.T)

array([[ 1.20239352, -0.44310134,  1.479061  , -0.80531921],
       [ 0.53823363,  1.09303249,  4.61797795,  0.95396625],
       [-0.21525807,  0.78011689,  1.79464107,  0.87754205],
       [ 4.11701088, -1.60160544, -0.03209583,  0.24644341]])

In [8]:
q.var(), k.var(), v.var()

(np.float64(1.3772888688135403),
 np.float64(0.7481376289021062),
 np.float64(1.2000978210749365))

In [9]:
# here we are calculating the variance of the dot product of q and k which is ....
# larger then the variance of q and k. To reduce the variance of the dot product we can scale the q and k by 1/sqrt(dk)

np.matmul(q, k.T).var()


np.float64(2.443300481459231)

In [11]:
scaled = np.matmul(q, k.T) / math.sqrt(dk)

q.var(), k.var(), scaled.var()

(np.float64(1.3772888688135403),
 np.float64(0.7481376289021062),
 np.float64(0.30541256018240387))

In [12]:
scaled

array([[ 0.42511031, -0.15665998,  0.52292703, -0.28472334],
       [ 0.19029433,  0.38644534,  1.63270176,  0.337278  ],
       [-0.07610522,  0.27581297,  0.63450144,  0.31025797],
       [ 1.45558315, -0.56625303, -0.01134759,  0.0871309 ]])

# Masking

- A process to hide the future information from model durning the training process
- Upper diagonal would be `zero`, with will become `inf` drurning the `softmax` calculation

In [13]:
mask = np.tril(np.ones( (L, L) ))
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [21]:
mask[mask == 0] = -np.inf
# this will help the model to see the future tokens in the sequence
(mask)

array([[  1., -inf, -inf, -inf],
       [  1.,   1., -inf, -inf],
       [  1.,   1.,   1., -inf],
       [  1.,   1.,   1.,   1.]])

In [22]:
scaled + mask

array([[1.42511031,       -inf,       -inf,       -inf],
       [1.19029433, 1.38644534,       -inf,       -inf],
       [0.92389478, 1.27581297, 1.63450144,       -inf],
       [2.45558315, 0.43374697, 0.98865241, 1.0871309 ]])

# Softmax

In [24]:
def softmax_fn(x) -> np.ndarray:
    return (np.exp(x).T / np.sum(np.exp(x), axis=-1)).T

In [25]:
attention = softmax_fn(scaled + mask)  # (L, L)
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.45111887, 0.54888113, 0.        , 0.        ],
       [0.22436527, 0.31900079, 0.45663394, 0.        ],
       [0.61822085, 0.08185993, 0.14258168, 0.15733755]])

In [26]:
new_v = np.matmul(attention, v)
new_v

array([[-1.20455614, -0.78141984, -0.60550966, -0.49607676,  0.05358656,
         0.31648808,  1.97222775,  0.97455048],
       [-0.66957881, -1.64706996, -0.00307428, -0.63292153, -0.2669993 ,
         0.0131296 ,  1.32977774,  0.97090339],
       [-0.89506767, -1.53613058, -0.70591991,  0.84266061, -0.00318276,
         0.20823676,  1.38968769,  0.86036549],
       [-0.84228483, -0.96535234, -0.71058386,  0.14229459, -0.20809817,
         0.29146427,  1.44269616,  0.98682583]])

In [27]:
v

array([[-1.20455614, -0.78141984, -0.60550966, -0.49607676,  0.05358656,
         0.31648808,  1.97222775,  0.97455048],
       [-0.22988732, -2.3585375 ,  0.49206019, -0.74539263, -0.530485  ,
        -0.23619713,  0.80175571,  0.96790588],
       [-1.20769139, -1.33242785, -1.59215534,  2.60984511,  0.33729297,
         0.46552584,  1.51418349,  0.72913429],
       [ 0.5936924 , -0.6305729 , -0.95026899,  0.87634056, -1.56283632,
         0.3099347 , -0.36928654,  1.2784266 ]])