# Self Attention in Transformers

## Generate Data

In [1]:
import numpy as np
import math

SEED = 0                            ### Fixed seed so Q, K and V are the same on every Restart & Run All
rng = np.random.default_rng(SEED)

L, d_k, d_v = 4, 8, 8               ### L = sequence length (number of tokens), d_k = query/key dimension, d_v = value dimension
q = rng.standard_normal((L, d_k))   ### one query vector per token
k = rng.standard_normal((L, d_k))   ### one key vector per token
v = rng.standard_normal((L, d_v))   ### one value vector per token

In [2]:
# Show the three input matrices, each under its own label.
for label, matrix in (("Q\n", q), ("K\n", k), ("V\n", v)):
  print(label, matrix)

Q
 [[ 1.89674261  0.16876788 -0.1147014  -1.01039339  0.22240272  1.18669652
   1.53849536 -0.65083759]
 [ 0.75051365  0.2883125  -0.83135186 -0.42848096 -1.51048834 -0.43071518
   0.59017915  0.25628828]
 [ 0.38115841 -0.08466821  0.02170363  0.98229443 -1.37154773 -0.32327473
  -1.20469893  0.06098321]
 [-0.99297478  0.94789714 -1.65177143  0.88554675 -0.16072056  0.80961247
   0.06632881 -0.06898069]]
K
 [[ 0.4481545  -1.79396534  0.0908088  -0.74266193 -0.83537924  0.26304504
  -1.28730123 -0.81080393]
 [-1.52075254 -0.58949283 -0.59143947 -1.07363383  1.88036975  0.62861599
  -0.67031112 -0.12577496]
 [-0.1323446  -0.25197528 -0.41384947 -0.24159246  1.10554571 -2.77289787
   1.96111947 -0.08857994]
 [-0.86718853  0.74922133  0.82114015 -1.39302018  0.53917891 -0.7638527
  -0.28285176  1.4148641 ]]
V
 [[-0.07703553 -1.56496974 -0.95833645 -0.17826546  0.07570981  0.09826632
  -0.01383334 -1.28698563]
 [ 0.2261494  -0.64931334 -0.09503368  2.01603016  0.38349369  1.01334865
  -0.30

## Self Attention

$$
\text{self attention} = \text{softmax}\bigg(\frac{QK^T}{\sqrt{d_k}}+M\bigg)
$$

$$
\text{new V} = \text{self attention} \cdot V
$$

In [3]:
# Raw (unscaled) scores: entry (i, j) is the dot product of query i with key j.
q @ k.T

array([[-0.03920645, -1.61656787,  0.02813604, -2.34763583],
       [ 0.2428411 , -3.89844706,  0.93472338, -0.81034134],
       [ 2.15725957, -3.57957683, -3.26326825, -1.81004684],
       [-2.63538616,  1.14839658, -1.9242561 , -1.84008057]])

In [4]:
# Why sqrt(d_k) belongs in the denominator: compare the variance of the raw
# score matrix with the variances of q and k.
q.var(), k.var(), (q @ k.T).var()

(np.float64(0.7485991117903039),
 np.float64(1.059126177871358),
 np.float64(3.0306764642564294))

In [5]:
# Divide the raw scores by sqrt(d_k), then compare the variances again.
scaled = (q @ k.T) / math.sqrt(d_k)
q.var(), k.var(), scaled.var()

(np.float64(0.7485991117903039),
 np.float64(1.059126177871358),
 np.float64(0.3788345580320537))

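- Notice how dividing by $\sqrt{d_k}$ shrinks the variance of the product by a factor of $d_k$ (here from about 3.03 to about 0.38), bringing it back to the same order as the variances of Q and K. This is why we scale by $\sqrt{d_k}$: it keeps the attention scores in a range where softmax does not saturate, which stabilizes learning.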

In [6]:
# Scaled scores that go into masking and softmax.
scaled

array([[-0.01386157, -0.57154305,  0.00994759, -0.83001461],
       [ 0.08585729, -1.37830918,  0.33047462, -0.28649893],
       [ 0.76270644, -1.26557152, -1.15373955, -0.6399482 ],
       [-0.93174971,  0.40601951, -0.68032727, -0.65056672]])

## Masking

- Masking ensures that a token cannot draw context from tokens that come later in the sequence (i.e. tokens that would only be generated in the future).
- It is not required in the encoder, but it is required in the decoder.

In [7]:
### Lower-triangular pattern of ones: row i has ones in columns 0..i
mask = np.tril(np.ones(shape=(L, L)))
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [9]:
### Additive mask: 0 where attention is allowed (on/below the diagonal), -inf where it is blocked.
### We use -inf instead of 0 for blocked positions because softmax maps -inf to a weight of exactly 0.
### Built fresh from the triangular pattern rather than by overwriting `mask` in place, so re-running
### this cell is safe (an in-place 0 -> -inf rewrite would, on a second run, also hit the zeros that
### were written for the allowed positions).
mask = np.where(np.tril(np.ones((L, L))) == 1, 0.0, -np.inf)

In [None]:
# Additive mask: 0 leaves a score unchanged, -inf blocks the position.
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [10]:
# Apply the mask: blocked positions become -inf, allowed scores are unchanged.
scaled + mask

array([[-0.01386157,        -inf,        -inf,        -inf],
       [ 0.08585729, -1.37830918,        -inf,        -inf],
       [ 0.76270644, -1.26557152, -1.15373955,        -inf],
       [-0.93174971,  0.40601951, -0.68032727, -0.65056672]])

## Softmax

$$
\text{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}
$$

In [12]:
def softmax(x):
  """Softmax over the last axis of ``x``.

  Args:
    x: array-like of scores. Entries equal to ``-inf`` (masked positions)
      receive a weight of exactly 0.

  Returns:
    ndarray with the same shape as ``x`` whose entries along the last axis
    are non-negative and sum to 1. A slice that consists only of ``-inf``
    yields ``nan`` (there is nothing to normalise).
  """
  x = np.asarray(x)
  # Subtracting the per-row maximum leaves the result unchanged mathematically
  # but keeps np.exp from overflowing to inf (inf / inf = nan) on large scores.
  shifted = x - np.max(x, axis=-1, keepdims=True)
  exp_x = np.exp(shifted)
  # keepdims=True lets the row sums broadcast against exp_x for any number of
  # leading dimensions, so no transpose trick is needed.
  return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

In [13]:
# Turn the masked scores into attention weights; -inf entries become weight 0.
attention = softmax(scaled + mask)

In [14]:
# Row i holds the weights token i assigns to tokens 0..i; each row sums to 1.
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.8121691 , 0.1878309 , 0.        , 0.        ],
       [0.78204988, 0.10288795, 0.11506216, 0.        ],
       [0.13475123, 0.51347407, 0.17327029, 0.17850441]])

In [15]:
# Each output row is the attention-weighted sum of the rows of v.
new_v = attention @ v
new_v

array([[-0.07703553, -1.56496974, -0.95833645, -0.17826546,  0.07570981,
         0.09826632, -0.01383334, -1.28698563],
       [-0.02008803, -1.39298117, -0.79618151,  0.23389106,  0.13352113,
         0.27014706, -0.06861845, -0.89739785],
       [-0.19326156, -1.35294493, -0.91257525, -0.16088769,  0.07437566,
         0.01851655, -0.01973131, -1.01979623],
       [-0.1337351 , -0.4252172 , -0.34859627,  0.73803012,  0.11741718,
        -0.17214688,  0.113683  ,  0.27128599]])

In [16]:
# Original values, for comparison with new_v: the first rows match because
# the first row of `attention` is [1, 0, 0, 0] (token 0 attends only to itself).
v

array([[-0.07703553, -1.56496974, -0.95833645, -0.17826546,  0.07570981,
         0.09826632, -0.01383334, -1.28698563],
       [ 0.2261494 , -0.64931334, -0.09503368,  2.01603016,  0.38349369,
         1.01334865, -0.30550586,  0.78715539],
       [-1.35825696, -0.54104669, -1.33258855, -1.98936307, -0.21110388,
        -1.41309689,  0.1957196 , -0.8195403 ],
       [-0.02314086,  1.1922242 ,  0.3374487 ,  0.40094203, -0.29758844,
        -2.58183501,  1.33612388,  1.02253491]])

# Function

In [17]:
def softmax(x):
  """Softmax over the last axis of ``x``.

  Args:
    x: array-like of scores. Entries equal to ``-inf`` (masked positions)
      receive a weight of exactly 0.

  Returns:
    ndarray with the same shape as ``x`` whose entries along the last axis
    are non-negative and sum to 1. A slice that consists only of ``-inf``
    yields ``nan`` (there is nothing to normalise).
  """
  x = np.asarray(x)
  # Subtracting the per-row maximum leaves the result unchanged mathematically
  # but keeps np.exp from overflowing to inf (inf / inf = nan) on large scores.
  shifted = x - np.max(x, axis=-1, keepdims=True)
  exp_x = np.exp(shifted)
  # keepdims=True lets the row sums broadcast against exp_x for any number of
  # leading dimensions, so no transpose trick is needed.
  return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

def scaled_dot_product_attention(q, k, v, mask=None):
  """Scaled dot-product attention for 2-D query/key/value matrices.

  Computes ``softmax(q @ k.T / sqrt(d_k) + mask) @ v``.

  Args:
    q: queries, one row per query; the last dimension is d_k.
    k: keys, one row per key, with the same last dimension as ``q``.
    v: values, one row per key.
    mask: optional additive mask that broadcasts against the score matrix;
      0 keeps a position and ``-inf`` blocks it. ``None`` means no masking.

  Returns:
    Tuple ``(out, attention)``: ``attention`` holds the softmax weights
    (one row per query) and ``out`` is ``attention @ v``.
  """
  d_k = q.shape[-1]
  # Dividing by sqrt(d_k) keeps the score variance from growing with d_k.
  scaled = np.matmul(q, k.T) / math.sqrt(d_k)
  if mask is not None:
    # Additive mask: a new array is created, so the caller's mask is not modified.
    scaled = scaled + mask
  attention = softmax(scaled)
  out = np.matmul(attention, v)
  return out, attention

In [18]:
# Run the whole pipeline with the causal mask, then show inputs and results.
values, attention = scaled_dot_product_attention(q, k, v, mask=mask)
for label, matrix in (
  ("Q\n", q),
  ("K\n", k),
  ("V\n", v),
  ("New V\n", values),
  ("Attention\n", attention),
):
  print(label, matrix)

Q
 [[ 1.89674261  0.16876788 -0.1147014  -1.01039339  0.22240272  1.18669652
   1.53849536 -0.65083759]
 [ 0.75051365  0.2883125  -0.83135186 -0.42848096 -1.51048834 -0.43071518
   0.59017915  0.25628828]
 [ 0.38115841 -0.08466821  0.02170363  0.98229443 -1.37154773 -0.32327473
  -1.20469893  0.06098321]
 [-0.99297478  0.94789714 -1.65177143  0.88554675 -0.16072056  0.80961247
   0.06632881 -0.06898069]]
K
 [[ 0.4481545  -1.79396534  0.0908088  -0.74266193 -0.83537924  0.26304504
  -1.28730123 -0.81080393]
 [-1.52075254 -0.58949283 -0.59143947 -1.07363383  1.88036975  0.62861599
  -0.67031112 -0.12577496]
 [-0.1323446  -0.25197528 -0.41384947 -0.24159246  1.10554571 -2.77289787
   1.96111947 -0.08857994]
 [-0.86718853  0.74922133  0.82114015 -1.39302018  0.53917891 -0.7638527
  -0.28285176  1.4148641 ]]
V
 [[-0.07703553 -1.56496974 -0.95833645 -0.17826546  0.07570981  0.09826632
  -0.01383334 -1.28698563]
 [ 0.2261494  -0.64931334 -0.09503368  2.01603016  0.38349369  1.01334865
  -0.30