# Self Attention in Transformers

## Generate Data

In [1]:
# https://www.youtube.com/watch?v=QCJQG4DuHT0&list=PLTl9hO2Oobd97qfWC40gOSU8C0iu0m2l4
import numpy as np
import math
#length of input sequence
#dimension of d_k and d_v 8
L, d_k, d_v = 4, 8, 8
q = np.random.randn(L, d_k) #(4,8)
k = np.random.randn(L, d_k) #(4,8)
v = np.random.randn(L, d_v)

In [2]:
print("Q\n", q)
print("K\n", k)
print("V\n", v)

Q
 [[ 0.18366517  0.69549388  0.28884603  1.88644768  0.356903    0.39637446
   1.18332284 -1.58432176]
 [ 0.45665325  0.74760165 -1.64697687  0.50174414  0.96491998 -0.05649554
   1.27325481  0.99872008]
 [ 0.73219883 -1.1298786  -1.47287961 -0.75663372  0.67867084 -0.12142817
   0.67622107  1.43914014]
 [-0.08864795 -0.32775726  0.43245463 -1.37071227  0.14828554 -0.40827039
   0.18127552  0.28698922]]
K
 [[-0.12816748  0.21190403 -0.44249429 -0.13868744  1.42614063 -0.52737663
   0.43228756  0.18449538]
 [-1.73838804 -0.45762746 -0.84376157 -0.05844661  0.00398864  0.86225838
  -0.4556103   1.06896442]
 [ 0.27108615  1.41199393  0.46256588 -0.77860606 -0.21756045  1.15999105
   0.1461041   1.43479781]
 [-1.18038042 -1.83654413  1.1959073   0.93188211 -0.16055269 -0.03169963
   0.41339721  0.26344204]]
V
 [[-0.14600147  0.19098749 -0.52006742 -0.05113114  0.28067984  1.13968417
   0.29878616  0.67347476]
 [ 0.06907276 -1.2488427  -0.07867611  0.83387608 -2.40047255 -0.13323873
  -0.9

## Self Attention

$$
\text{self attention} = softmax\bigg(\frac{Q.K^T}{\sqrt{d_k}}+M\bigg)
$$

$$
\text{new V} = \text{self attention}.V
$$ 

In [4]:
np.matmul(q, k.T)

array([[ 0.25358968, -2.88104898, -2.02151729,  0.6112187 ],
       [ 2.89966145,  0.66699015,  1.37043142, -2.77775803],
       [ 2.0131616 ,  1.65950051,  0.38608924, -0.70215742],
       [ 0.4987502 , -0.1079327 ,  0.7128667 ,  0.08609056]])

In [5]:
# Why we need sqrt(d_k) in denominator
q.var(), k.var(), np.matmul(q, k.T).var()

(0.7725460876338028, 0.6963076537474191, 2.440736333874628)

In [6]:
scaled = np.matmul(q, k.T) / math.sqrt(d_k)
q.var(), k.var(), scaled.var()

(0.7725460876338028, 0.6963076537474191, 0.30509204173432847)

Notice the reduction in variance of the product

In [7]:
scaled

array([[ 0.08965749, -1.01860463, -0.71471429,  0.21609844],
       [ 1.02518514,  0.23581663,  0.48452067, -0.98208577],
       [ 0.71176011,  0.58672203,  0.13650316, -0.24825014],
       [ 0.17633482, -0.03815997,  0.25203644,  0.03043761]])

## Masking

- This is to ensure words don't get context from words generated in the future. 
- Not required in the encoders, but required int he decoders

In [22]:
# https://numpy.org/doc/stable/reference/generated/numpy.tril.html
# import numpy as np
# np.tril([[1,2,3],[4,5,6],[7,8,9],[10,11,12]],-1)

In [23]:
mask = np.tril(np.ones( (L, L) ))
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [24]:
mask[mask == 0] = -np.infty
mask[mask == 1] = 0

In [25]:
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [26]:
scaled + mask

array([[ 0.08965749,        -inf,        -inf,        -inf],
       [ 1.02518514,  0.23581663,        -inf,        -inf],
       [ 0.71176011,  0.58672203,  0.13650316,        -inf],
       [ 0.17633482, -0.03815997,  0.25203644,  0.03043761]])

## Softmax

$$
\text{softmax} = \frac{e^{x_i}}{\sum_j e^x_j}
$$

In [27]:
def softmax(x):
  return (np.exp(x).T / np.sum(np.exp(x), axis=-1)).T

In [92]:
3/6

0.5

In [126]:
from pprint import pprint
#bradcasting
ar = np.array = [[1,2,3],[4,5,6],[7,8,9]]

k = np.sum(ar,axis =1)
print(np.sum(ar,axis=0)
print(np.sum(ar,axis=1))
print(np.sum(ar,axis=-1))
print(np.sum(ar,axis=-2))
# print(k)
# print(ar/k)

[[12]
 [15]
 [18]]
[ 6 15 24]
[ 6 15 24]
[12 15 18]


In [128]:
def softmax_new(x):
    exp = np.exp(x)
    return exp / np.sum(exp,axis=1).reshape(-1,1)

In [129]:
attention = softmax_new(scaled + mask)

In [130]:
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.68769572, 0.31230428, 0.        , 0.        ],
       [0.40899401, 0.3609222 , 0.23008379, 0.        ],
       [0.26667831, 0.2151958 , 0.28765008, 0.23047581]])

In [131]:
attention.shape,v.shape

((4, 4), (4, 8))

In [132]:
new_v = np.dot(attention, v)
new_v.shape

(4, 8)

In [133]:
v

array([[-0.14600147,  0.19098749, -0.52006742, -0.05113114,  0.28067984,
         1.13968417,  0.29878616,  0.67347476],
       [ 0.06907276, -1.2488427 , -0.07867611,  0.83387608, -2.40047255,
        -0.13323873, -0.90461205, -0.73292915],
       [-0.16592726, -1.524926  ,  0.61945438, -0.42440031,  0.26770069,
        -0.02320381,  0.1071973 , -1.96950857],
       [ 0.40209635,  0.64090491, -1.08064203, -1.47329574, -0.81535135,
         0.06387228,  1.01493538,  2.53451224]])

# Function

In [None]:
def softmax(x):
  return (np.exp(x).T / np.sum(np.exp(x), axis=-1)).T

def scaled_dot_product_attention(q, k, v, mask=None):
  d_k = q.shape[-1]
  scaled = np.matmul(q, k.T) / math.sqrt(d_k)
  if mask is not None:
    scaled = scaled + mask
  attention = softmax(scaled)
  out = np.matmul(attention, v)
  return out, attention

In [None]:
values, attention = scaled_dot_product_attention(q, k, v, mask=mask)
print("Q\n", q)
print("K\n", k)
print("V\n", v)
print("New V\n", values)
print("Attention\n", attention)

Q
 [[ 0.11672673 -2.54870451 -1.44065948  0.93661829  1.36278968  1.04252277
  -0.01310938 -1.3163937 ]
 [ 0.26721599 -0.90218255  0.07417847 -0.10430246  0.52684253 -0.07081531
  -0.60511725 -0.55225527]
 [-0.93297509  0.28724456  1.37184579  0.41589874  0.34981245 -0.24753755
  -1.24497125  0.05044148]
 [-0.11414585 -0.01545749 -0.58376828 -0.40193907  0.93931836 -1.94334363
  -0.34770465  1.50103406]]
K
 [[ 1.1226585  -0.85645535  0.54315044  1.36560451  0.52539476 -0.94502504
  -0.48444661  0.46268014]
 [-0.53713766 -1.16937329 -0.57988617  0.92713577 -0.85995607 -0.40352635
   0.26555146 -1.83159914]
 [-2.06994435 -0.09514715 -1.64928361 -0.17375184  0.13146819 -1.76335363
   1.56568846  0.69751826]
 [ 0.32910684 -0.1939204  -0.80444134  0.78816869  0.35599408  0.28309835
  -0.25970963  1.49744622]]
V
 [[-0.00368231  1.43739233 -0.59614565 -1.23171219  1.12030717 -0.98620738
  -0.15461465 -1.03106383]
 [ 0.85585446 -1.79878344  0.67321704  0.05607552 -0.15542661 -1.41264124
  -0.4