# Self Attention in Transformers

## Generate Data

In [1]:
import numpy as np
import math
# Length of the sequence, dimension of the key, dimension of the value
L, d_k, d_v = 4, 8, 8

# Seed the generator so "Restart & Run All" reproduces every output below
# (without a seed each run prints different Q/K/V and attention values).
SEED = 42
rng = np.random.default_rng(SEED)

# Random query, key and value matrices: one row per token in the sequence.
q = rng.standard_normal((L, d_k))
k = rng.standard_normal((L, d_k))
v = rng.standard_normal((L, d_v))
print(f"Query: {q.shape}, Key: {k.shape}, Value: {v.shape}")

Query: (4, 8), Key: (4, 8), Value: (4, 8)


In [2]:
# Show the raw inputs, one labelled matrix at a time.
for label, arr in (("Q\n", q), ("K\n", k), ("V\n", v)):
  print(label, arr)

Q
 [[-0.74947248 -0.61467048  0.27795685 -1.52031796  0.19959513 -0.82411098
   0.87743222 -0.10540908]
 [ 1.17429321 -1.84903911  0.04399508 -0.36574474  0.33059822 -0.30801206
  -1.58653368  2.04158393]
 [ 1.03565886 -0.28076427 -0.02179129 -1.15421532  0.68699884  0.82653069
  -0.34137593  1.10470612]
 [ 1.44176285  1.70663505 -0.38310084 -0.67941888 -0.20895877 -0.64475789
   0.90452786 -1.93127615]]
K
 [[ 0.34783044  0.33978026  1.84419648  0.1140325   0.43904318 -0.27640791
   0.38455458 -1.42061539]
 [ 0.53235409  2.15349007 -0.6754666  -0.16133213  1.59204081  0.91056436
   0.18703004 -0.68415461]
 [-0.5601536   1.25902356  0.06203232  1.20840015  0.37877804 -0.7393191
  -0.72186826  0.23106176]
 [ 1.65577883  0.79578684 -1.07266697  0.43970972  0.4445189  -1.99685624
   1.24360732 -0.30873901]]
V
 [[-0.7368274  -0.61469973  0.32486237 -0.0570102   1.52021924  0.54076089
   0.07458097 -0.51615085]
 [-0.04248102 -1.18526217 -0.8761295   1.10760605  2.20175399  1.31060859
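  -2.76460533 -1.70821689]
 [-0.64429528  1.04251321 -0.51087522  0.3928373   1.05181382 -0.81505056
   0.51667582  0.16278935]
 [ 1.11096906  0.48076118 -0.35379643  0.48345357 -0.76334168  0.08578319
   0.47118265  0.86716897]]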

## Self Attention

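$$
\text{self attention} = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right)
$$
where $M$ is the mask ($0$ for positions that may be attended to, $-\infty$ for blocked positions), and
$$
\text{new } V = \text{self attention} \cdot V
$$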

In [3]:
q @ k.T  # raw scores: (L, d_k) @ (d_k, L) -> (L, L), one row per query token

array([[ 0.67228716, -1.86156614, -2.1468383 ,  0.16131892],
       [-3.46051378, -4.77508585, -1.45506242, -1.57641328],
       [-1.53444764,  1.17434771, -2.17888867, -1.10343389],
       [ 3.47529607,  5.38183481, -0.20535062,  6.77329111]])

In [4]:
# Why divide by sqrt(d_k) inside the softmax?
# Each entry of q @ k.T is a sum of d_k products. For independent zero-mean,
# unit-variance inputs its variance is about d_k (compare the third value
# below with the first two). Large scores saturate the softmax (near one-hot
# output), so scaling by 1/sqrt(d_k) normalises the variance back to ~1.
(q.var(), k.var(), (q @ k.T).var())

(1.0064999976924682, 0.8923701905630524, 9.009341877886312)

In [5]:
# Scale the raw scores by 1/sqrt(d_k), then re-check the variances.
scaled = q @ k.T / math.sqrt(d_k)
(q.var(), k.var(), scaled.var())

(1.0064999976924682, 0.8923701905630524, 1.1261677347357888)

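Notice the reduction in variance: dividing by $\sqrt{d_k}$ brings the variance of $QK^T$ from roughly $d_k$ back to roughly 1, which keeps the softmax out of its saturated regime.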

In [6]:
scaled  # (L, L) scores after dividing by sqrt(d_k)

array([[ 0.23768941, -0.65816302, -0.75902196,  0.05703485],
       [-1.22347638, -1.68824779, -0.51444225, -0.55734626],
       [-0.54250916,  0.41519461, -0.77035348, -0.39012279],
       [ 1.22870271,  1.90276595, -0.07260241,  2.39472004]])

## Masking

- The look-ahead mask ensures a word cannot get context from words that come after it (words generated in the future).
- The look-ahead mask is not required in the encoder, but it is required in the decoder.
- Masking is done by setting the masked scores to $-\infty$ before the softmax, so those positions receive zero attention weight.
- In the encoder, masking is used only for padding (padding mask).
- In the decoder, masking is used both for padding (padding mask) and for future words (look-ahead mask).

In [7]:
# Look-ahead (causal) mask: 1 on and below the diagonal, 0 above it,
# so row i keeps only columns 0..i.
mask = np.tri(L, L)
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [8]:
# Turn the 0/1 look-ahead mask into an additive mask: 0 where attention is
# allowed, -inf where it is blocked (future positions).
# - Rebuilt from L instead of mutating `mask` in place, so re-running this
#   cell no longer turns every entry into -inf.
# - Uses `np.inf`: the `np.infty` alias was removed in NumPy 2.0.
mask = np.where(np.tril(np.ones((L, L))) == 1, 0.0, -np.inf)

In [9]:
mask  # additive mask: 0 = may attend, -inf = blocked (future position)

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [None]:
np.add(scaled, mask)  # blocked (future) positions become -inf

array([[ 0.68537216,        -inf,        -inf,        -inf],
       [ 0.47796088,  0.42358302,        -inf,        -inf],
       [ 0.37611945, -0.30709922, -0.65849946,        -inf],
       [ 0.78209275, -0.99700418,  1.88206279,  0.79213542]])

## Softmax

$$
\text{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}
$$

In [10]:
def softmax(x):
  """Numerically stable softmax along the last axis of ``x``.

  The row-wise max is subtracted before ``np.exp``. This does not change the
  result mathematically, but it prevents overflow (inf/inf = NaN) for large
  scores.

  Behaviour:
    - ``-inf`` entries (masked positions) get probability 0.
    - A row that is entirely ``-inf`` yields NaN.
    - Works for arrays of any rank, normalising along ``axis=-1``.
  """
  shifted = x - np.max(x, axis=-1, keepdims=True)
  exps = np.exp(shifted)
  return exps / np.sum(exps, axis=-1, keepdims=True)

In [11]:
attention = softmax(mask + scaled)  # masked, scaled scores -> row-normalised weights

In [None]:
attention  # row i = weights of query i over the keys; zeros above the diagonal come from the mask

array([[1.        , 0.        , 0.        , 0.        ],
       [0.51359112, 0.48640888, 0.        , 0.        ],
       [0.53753304, 0.27144826, 0.1910187 , 0.        ],
       [0.19293995, 0.03256643, 0.57960627, 0.19488734]])

In [12]:
# Each output row is the attention-weighted average of the rows of v.
new_v = attention @ v
new_v

array([[-0.7368274 , -0.61469973,  0.32486237, -0.0570102 ,  1.52021924,
         0.54076089,  0.07458097, -0.51615085],
       [-0.46891071, -0.83485383, -0.13854577,  0.39236225,  1.78319251,
         0.83781011, -1.02093191, -0.97611492],
       [-0.30907621, -0.65267498, -0.53722892,  0.71374701,  1.8389214 ,
         0.75122071, -1.52608719, -1.09897583],
       [ 0.39881098, -0.17285712, -0.41416886,  0.58581643,  0.57065709,
         0.49132587, -0.57380911, -0.16152413]])

In [13]:
v  # original values, for comparison with new_v

array([[-0.7368274 , -0.61469973,  0.32486237, -0.0570102 ,  1.52021924,
         0.54076089,  0.07458097, -0.51615085],
       [-0.04248102, -1.18526217, -0.8761295 ,  1.10760605,  2.20175399,
         1.31060859, -2.76460533, -1.70821689],
       [-0.64429528,  1.04251321, -0.51087522,  0.3928373 ,  1.05181382,
        -0.81505056,  0.51667582,  0.16278935],
       [ 1.11096906,  0.48076118, -0.35379643,  0.48345357, -0.76334168,
         0.08578319,  0.47118265,  0.86716897]])

In [16]:
np.shape(v)  # (L, d_v)

(4, 8)

# Function

In [17]:
def softmax(x):
  """Numerically stable softmax along the last axis of ``x``.

  The row-wise max is subtracted before ``np.exp`` to avoid overflow for
  large scores. ``-inf`` entries get probability 0. A row that is entirely
  ``-inf`` yields NaN. Works for arrays of any rank.
  """
  shifted = x - np.max(x, axis=-1, keepdims=True)
  exps = np.exp(shifted)
  return exps / np.sum(exps, axis=-1, keepdims=True)

def scaled_dot_product_attention(q, k, v, mask=None):
  """Scaled dot-product attention: softmax(q k^T / sqrt(d_k) + mask) v.

  Args:
    q: queries, shape (..., L_q, d_k).
    k: keys, shape (..., L_k, d_k).
    v: values, shape (..., L_k, d_v).
    mask: optional additive mask broadcastable to (..., L_q, L_k).
      Use 0 for positions that may be attended to and -inf for blocked ones.

  Returns:
    (out, attention):
      out: shape (..., L_q, d_v).
      attention: weights of shape (..., L_q, L_k). Each row sums to 1 unless
        it is fully masked.
  """
  d_k = q.shape[-1]
  # Swap only the last two axes so leading batch/head axes are preserved;
  # for 2-D input this is identical to k.T.
  scaled = np.matmul(q, np.swapaxes(k, -1, -2)) / math.sqrt(d_k)
  if mask is not None:
    scaled = scaled + mask
  attention = softmax(scaled)
  out = np.matmul(attention, v)
  return out, attention

In [None]:
values, attention = scaled_dot_product_attention(q, k, v, mask=mask)
# Echo the inputs, then the attended values and the attention weights.
for label, arr in (("Q\n", q), ("K\n", k), ("V\n", v),
                   ("New V\n", values), ("Attention\n", attention)):
  print(label, arr)

Q
 [[ 0.11672673 -2.54870451 -1.44065948  0.93661829  1.36278968  1.04252277
  -0.01310938 -1.3163937 ]
 [ 0.26721599 -0.90218255  0.07417847 -0.10430246  0.52684253 -0.07081531
  -0.60511725 -0.55225527]
 [-0.93297509  0.28724456  1.37184579  0.41589874  0.34981245 -0.24753755
  -1.24497125  0.05044148]
 [-0.11414585 -0.01545749 -0.58376828 -0.40193907  0.93931836 -1.94334363
  -0.34770465  1.50103406]]
K
 [[ 1.1226585  -0.85645535  0.54315044  1.36560451  0.52539476 -0.94502504
  -0.48444661  0.46268014]
 [-0.53713766 -1.16937329 -0.57988617  0.92713577 -0.85995607 -0.40352635
   0.26555146 -1.83159914]
 [-2.06994435 -0.09514715 -1.64928361 -0.17375184  0.13146819 -1.76335363
   1.56568846  0.69751826]
 [ 0.32910684 -0.1939204  -0.80444134  0.78816869  0.35599408  0.28309835
  -0.25970963  1.49744622]]
V
 [[-0.00368231  1.43739233 -0.59614565 -1.23171219  1.12030717 -0.98620738
  -0.15461465 -1.03106383]
 [ 0.85585446 -1.79878344  0.67321704  0.05607552 -0.15542661 -1.41264124
  -0.4