In [None]:
import numpy as np
import math

Query
- What am I looking for
- [ sequence length, d_k]

Key
- What I can offer
- [ sequence length, d_k ]

Value
- What I actually offer
- [ sequence length, d_v ]

In [None]:
L, d_k, d_v = 4, 8, 8
# L   = length of the sequence
# d_k = size of each query / key vector
# d_v = size of each value vector
SEED = 42
# Seeded generator so q, k and v are identical on every Restart & Run All.
rng = np.random.default_rng(SEED)
q = rng.standard_normal((L, d_k))
k = rng.standard_normal((L, d_k))
v = rng.standard_normal((L, d_v))

In [None]:
q

array([[ 1.36356156, -0.18691182,  1.28193485, -0.1105039 ,  1.85959866,
        -1.16846223,  0.03678578, -0.64119846],
       [-0.53128764,  0.23745705, -0.49561824, -1.84935788,  0.22371562,
        -0.50987601, -0.02358382,  1.86298275],
       [ 0.0385568 ,  1.86055424, -1.46178503,  1.01461714, -1.23424907,
         0.30591167, -0.77874867, -0.29045989],
       [-0.70135465,  1.72472528, -0.14120923, -0.62749337, -0.39903098,
        -0.00354677, -1.51848765, -1.38518668]])

In [None]:
k

array([[-1.37286052, -0.23153669,  0.2715611 , -0.65591073, -0.13488642,
         0.72374622, -0.17189557, -0.2228312 ],
       [ 0.24281772, -1.1105715 , -0.32557579, -0.3432607 ,  0.41166963,
        -0.34568284,  1.47291009, -2.06148413],
       [ 1.76737621, -0.29152771,  0.56727085,  0.60468754, -0.31842822,
         1.29399788, -0.64047442, -1.65696562],
       [ 2.7337647 , -0.44233393, -0.31082483, -1.38237913,  0.82808137,
        -0.81943862, -0.03059131, -1.51333032]])

In [None]:
v

array([[ 0.34131403,  1.24786025, -0.12903802, -0.59826645, -0.24232617,
         0.54173432, -0.86491139, -0.76773299],
       [-0.44219429, -0.23895719,  2.27152296, -0.08224842,  0.45455451,
         0.31486997,  0.87850504, -1.27064538],
       [ 1.65933199, -0.2012565 ,  1.40002879, -0.03821449, -1.88177953,
        -0.73513935,  1.01886936,  0.69763793],
       [ 0.38614034, -1.22304162,  0.94780548, -1.88126599, -0.41911076,
        -1.18638461,  1.90089709, -0.88107686]])

## Self Attention

$$
\text{self attention} = \text{softmax}\bigg(\frac{QK^T}{\sqrt{d_k}}+M\bigg)
$$

$$
\text{new V} = \text{self attention} \cdot V
$$

In [None]:
# Raw (unscaled) scores: dot product of every query row with every key row -> (L, L).
q @ k.T

array([[-2.36804757,  2.70470079,  2.0595473 ,  7.03123685],
       [ 0.94255307, -3.2034398 , -6.21045325, -1.0624065 ],
       [-0.95970908, -3.09137222,  1.07895899, -2.47515733],
       [ 1.55769863, -1.36846225,  1.18832837,  0.04625574]])

In [None]:
# Why we need sqrt(d_k) in the denominator:
# q and k are drawn from a standard normal, so their variance is about 1, but
# every entry of q @ k.T is a sum of d_k products, so the variance of the raw
# scores grows to roughly d_k. Compare the three variances below.
q.var(), k.var(), np.matmul(q, k.T).var()

(1.0504598953361777, 1.033290378314194, 8.741723348838764)

In [None]:
# Dividing the scores by sqrt(d_k) divides their variance by d_k, bringing it
# back in line with the variance of q and k.
scaled = (q @ k.T) / math.sqrt(d_k)
np.var(q), np.var(k), np.var(scaled)

(1.0504598953361777, 1.033290378314194, 1.0927154186048453)

In [None]:
scaled

array([[-0.83723125,  0.95625613,  0.72815993,  2.48591763],
       [ 0.33324283, -1.132587  , -2.1957268 , -0.37561742],
       [-0.3393084 , -1.09296513,  0.38146961, -0.87510027],
       [ 0.55072963, -0.48382447,  0.42013753,  0.01635387]])

## Masking

- This ensures that a word doesn't get context from words generated in the future (later positions in the sequence).
- Not required in the encoders, but required in the decoders.

In [None]:
# Lower-triangular matrix of ones: row i has a 1 in every column j <= i,
# i.e. position i may look at itself and at earlier positions only.
mask = np.tri(L)
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [None]:
# Turn the 0/1 causal mask into an additive mask: 0 where attention is allowed,
# -inf where it is blocked (exp(-inf) == 0 after the softmax).
# The 0/1 mask is rebuilt first so that re-running this cell is idempotent;
# rewriting an already-converted mask in place would turn every entry into -inf.
# np.inf replaces np.infty, an alias that was removed in NumPy 2.0.
mask = np.tril(np.ones((L, L)))
mask[mask == 0] = -np.inf
mask[mask == 1] = 0

In [None]:
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [None]:
# Adding the mask leaves allowed positions unchanged (+0) and pushes blocked
# (future) positions to -inf.
scaled + mask

array([[ 0.68537216,        -inf,        -inf,        -inf],
       [ 0.47796088,  0.42358302,        -inf,        -inf],
       [ 0.37611945, -0.30709922, -0.65849946,        -inf],
       [ 0.78209275, -0.99700418,  1.88206279,  0.79213542]])

## Softmax

$$
\text{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}
$$

In [None]:
def softmax(x):
  """Softmax over the last axis of ``x``.

  Args:
    x: array-like of scores; any number of leading dimensions is allowed.
       Entries equal to -inf receive a weight of exactly 0.

  Returns:
    Array with the same shape as ``x`` whose entries along the last axis are
    non-negative and sum to 1. A row consisting only of -inf yields NaN.
  """
  x = np.asarray(x)
  # Subtracting the per-row maximum does not change the result mathematically,
  # but it keeps np.exp from overflowing to inf (inf / inf would give NaN).
  shifted = x - np.max(x, axis=-1, keepdims=True)
  exp_x = np.exp(shifted)
  # keepdims lets the normaliser broadcast for inputs of any rank.
  return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

In [None]:
# Row-wise softmax of the masked scores: each row becomes a set of weights that
# sums to 1, and blocked (-inf) positions get weight exp(-inf) = 0.
attention = softmax(scaled + mask)

In [None]:
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.51359112, 0.48640888, 0.        , 0.        ],
       [0.53753304, 0.27144826, 0.1910187 , 0.        ],
       [0.19293995, 0.03256643, 0.57960627, 0.19488734]])

In [None]:
# Each output row is the attention-weighted combination of the rows of v.
new_v = attention @ v
new_v

array([[-0.00368231,  1.43739233, -0.59614565, -1.23171219,  1.12030717,
        -0.98620738, -0.15461465, -1.03106383],
       [ 0.41440401, -0.13671232,  0.02128364, -0.60532081,  0.49977893,
        -1.1936286 , -0.27463831, -1.10169151],
       [ 0.32673907,  0.72121642, -0.00947672, -0.59897862,  0.90155754,
        -0.88535361, -0.21384855, -0.7053796 ],
       [ 0.18700384,  1.67754576,  0.33105314, -0.41795742,  1.4258469 ,
        -0.18788199, -0.10285145,  0.54683565]])

In [None]:
v

array([[-0.00368231,  1.43739233, -0.59614565, -1.23171219,  1.12030717,
        -0.98620738, -0.15461465, -1.03106383],
       [ 0.85585446, -1.79878344,  0.67321704,  0.05607552, -0.15542661,
        -1.41264124, -0.40136933, -1.17626611],
       [ 0.50465335,  2.28693419,  0.67128338,  0.2506863 ,  1.78802234,
         0.14775751, -0.11405725,  0.88026286],
       [-0.68069105,  0.68385101,  0.17994557, -1.68013201,  0.91543969,
        -0.19108312,  0.03160471,  1.40527326]])

# Function

In [None]:
def softmax(x):
  """Softmax over the last axis of ``x``.

  Args:
    x: array-like of scores; any number of leading dimensions is allowed.
       Entries equal to -inf receive a weight of exactly 0.

  Returns:
    Array with the same shape as ``x`` whose entries along the last axis are
    non-negative and sum to 1. A row consisting only of -inf yields NaN.
  """
  x = np.asarray(x)
  # Subtracting the per-row maximum does not change the result mathematically,
  # but it keeps np.exp from overflowing to inf (inf / inf would give NaN).
  shifted = x - np.max(x, axis=-1, keepdims=True)
  exp_x = np.exp(shifted)
  # keepdims lets the normaliser broadcast for inputs of any rank.
  return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

def scaled_dot_product_attention(q, k, v, mask=None):
  """Compute softmax(q @ k^T / sqrt(d_k) + mask) @ v.

  Args:
    q: queries, shape (..., L_q, d_k).
    k: keys, shape (..., L_k, d_k).
    v: values, shape (..., L_k, d_v).
    mask: optional additive mask broadcastable to (..., L_q, L_k); use 0 for
      positions that may be attended to and -inf for blocked positions.

  Returns:
    Tuple ``(out, attention)``: ``out`` has shape (..., L_q, d_v) and
    ``attention`` holds the softmax weights with shape (..., L_q, L_k).
  """
  d_k = q.shape[-1]
  # Swap only the last two axes: k.T would reverse *all* axes, which is wrong
  # for batched (3-D and higher) inputs. For 2-D k this equals k.T.
  scores = np.matmul(q, np.swapaxes(k, -1, -2)) / math.sqrt(d_k)
  if mask is not None:
    scores = scores + mask
  attention = softmax(scores)
  out = np.matmul(attention, v)
  return out, attention

In [None]:
# Run the packaged function on q, k, v with the additive causal mask
# (0 = allowed, -inf = blocked) and print the inputs next to the results.
values, attention = scaled_dot_product_attention(q, k, v, mask=mask)
print("Q\n", q)
print("K\n", k)
print("V\n", v)
print("New V\n", values)
print("Attention\n", attention)

Q
 [[ 0.11672673 -2.54870451 -1.44065948  0.93661829  1.36278968  1.04252277
  -0.01310938 -1.3163937 ]
 [ 0.26721599 -0.90218255  0.07417847 -0.10430246  0.52684253 -0.07081531
  -0.60511725 -0.55225527]
 [-0.93297509  0.28724456  1.37184579  0.41589874  0.34981245 -0.24753755
  -1.24497125  0.05044148]
 [-0.11414585 -0.01545749 -0.58376828 -0.40193907  0.93931836 -1.94334363
  -0.34770465  1.50103406]]
K
 [[ 1.1226585  -0.85645535  0.54315044  1.36560451  0.52539476 -0.94502504
  -0.48444661  0.46268014]
 [-0.53713766 -1.16937329 -0.57988617  0.92713577 -0.85995607 -0.40352635
   0.26555146 -1.83159914]
 [-2.06994435 -0.09514715 -1.64928361 -0.17375184  0.13146819 -1.76335363
   1.56568846  0.69751826]
 [ 0.32910684 -0.1939204  -0.80444134  0.78816869  0.35599408  0.28309835
  -0.25970963  1.49744622]]
V
 [[-0.00368231  1.43739233 -0.59614565 -1.23171219  1.12030717 -0.98620738
  -0.15461465 -1.03106383]
 [ 0.85585446 -1.79878344  0.67321704  0.05607552 -0.15542661 -1.41264124
  -0.4