Self attention


In [2]:
# imports
import numpy as np
import math

In [3]:
# for our example
# we are considering 'My name is Satya' sentence
# the length of this sentence (no of words) = 4
seq_len = 4

# let's take the dimension of word embeding = 8
d_dim = 8

# word_embed_dim = q_dim = k_dim = v_dim

# generating q,k,v matrices
# shape of each matrix will be (seq_len, d_dim)
q_mat = np.random.randn(seq_len,d_dim)
print(f'The shape of Q matrix is {q_mat.shape}\n {q_mat}')
k_mat = np.random.randn(seq_len,d_dim)
print(f'The shape of K matrix is {k_mat.shape}\n {k_mat}')
v_mat =np.random.randn(seq_len,d_dim)
print(f'The shape of V matrix is {v_mat.shape}\n {v_mat}')

The shape of Q matrix is (4, 8)
 [[-0.00321774 -0.3779472  -0.91467156  0.21036567 -1.0191051  -0.53591204
   1.28280734 -0.70308224]
 [ 0.66052383  2.22399239  1.23155579 -0.0514911   1.32791039 -1.04957902
   0.1978133  -0.01616973]
 [-0.77985494 -0.14101652  1.59431799  0.66724715 -0.06256293 -1.26988843
  -0.15370181 -1.32643477]
 [ 0.13122343 -1.88428763 -1.40611083 -0.84072834  0.02527446  0.58994408
  -0.78625852 -0.8132695 ]]
The shape of K matrix is (4, 8)
 [[ 1.82249133  0.82941662 -0.20053611 -1.01045564 -0.34123768  0.11549175
   0.34662191 -0.56428335]
 [-0.62026462 -0.10328412 -1.73701664  3.2175722   1.84092872  0.30026842
  -0.97256273  0.82610147]
 [ 0.96093656 -0.66796415 -0.93710505  0.22011191  1.65140729  1.76855159
  -0.22928647  0.09913789]
 [ 0.08842656 -0.23808166 -0.70630787  1.02076699 -0.41354525 -0.45657494
   0.43511297 -1.15513179]]
The shape of V matrix is (4, 8)
 [[ 0.55163977  0.33126945 -0.51472954 -0.20295922  1.21707387  1.03348466
   0.14556205  0.

In [4]:
# q.kT
q_kT = np.matmul(q_mat, k_mat.T)
print(f'q_kT shape ; {q_kT.shape}\n{q_kT}')

q_kT shape ; (4, 4)
[[ 0.77876988 -1.55874695 -1.84176767  2.98692065]
 [ 2.35681322 -1.02062263 -1.72652934 -1.35869215]
 [-1.96228646 -1.56692616 -4.44780981  1.59063924]
 [ 0.05367787  0.16707017  3.70209083  0.9126937 ]]


In [5]:
# scale q_kt/sqrt(d_dim)
scaled = q_kT/math.sqrt(d_dim)
print(f'scaled shape : {scaled.shape}\n{scaled}')

scaled shape : (4, 4)
[[ 0.27533673 -0.55110027 -0.6511632   1.05603592]
 [ 0.8332593  -0.36084459 -0.6104203  -0.48037022]
 [-0.69377303 -0.55399206 -1.57253824  0.56237589]
 [ 0.01897799  0.05906822  1.30888676  0.32268595]]


In [6]:
# mask
# the shape of the mask would be same as q_kT, i.e., [seq_len, seq_len]
mask = np.tril(np.ones((seq_len,seq_len)))
print(f'mask shape : {mask.shape}\n{mask}')

mask shape : (4, 4)
[[1. 0. 0. 0.]
 [1. 1. 0. 0.]
 [1. 1. 1. 0.]
 [1. 1. 1. 1.]]


In [9]:
# for softmax, the zeros in the upper diagonal needs to be set to -infinity
# we will set 1 to 0 and 0 to -infity

mask[mask==0]= -np.inf
mask[mask==1] = 0

print(f'mask shape : {mask.shape}\n{mask}')

mask shape : (4, 4)
[[  0. -inf -inf -inf]
 [  0.   0. -inf -inf]
 [  0.   0.   0. -inf]
 [  0.   0.   0.   0.]]


In [10]:
# adding mask to the scaled
# scaled + mask
# then applying softmax
scaled + mask

array([[ 0.27533673,        -inf,        -inf,        -inf],
       [ 0.8332593 , -0.36084459,        -inf,        -inf],
       [-0.69377303, -0.55399206, -1.57253824,        -inf],
       [ 0.01897799,  0.05906822,  1.30888676,  0.32268595]])

In [11]:
# softmax
import numpy as np

x = np.array([[1, 2, 3], [1, 2, 3]])
#print('main array x shape ; ',x.shape)

# Compute the softmax
# softmax = (np.exp(x).T / np.sum(np.exp(x), axis=-1)).T
def softmax(x):
  return (np.exp(x).T / np.sum(np.exp(x), axis=-1)).T

attention = softmax(scaled + mask)

In [12]:
new_v = np.matmul(attention, v_mat)
print(new_v, new_v.shape)

[[ 0.55163977  0.33126945 -0.51472954 -0.20295922  1.21707387  1.03348466
   0.14556205  0.14549306]
 [ 0.52459434  0.26499682 -0.2459435  -0.20360945  0.74367678  0.48018717
  -0.14423018 -0.21064683]
 [ 0.35263425  0.11837709  0.17507933 -0.05020481  0.17125333 -0.25159593
  -0.623385   -0.71988646]
 [ 0.00492915  0.33837452  0.10923947  0.81378563  0.19496172  0.0014559
  -0.70012792 -0.75036496]] (4, 8)


Positional Embedding

In [13]:
# import
import torch as t
import torch.nn as nn

In [14]:
# the sequence of lenght, i.e., the number of words that will be fed at once to the encoder = 10
# d_dim = 6
seq_len = 10
d_dim = 6

In [15]:
even_i = 2 * t.arange(0,d_dim,1).float()    # 2i
print(even_i)

tensor([ 0.,  2.,  4.,  6.,  8., 10.])


In [16]:
denom = t.pow(1000,(even_i/d_dim))  #10000 ** (2i/d_model)
print(denom)

tensor([1.0000e+00, 1.0000e+01, 1.0000e+02, 1.0000e+03, 1.0000e+04, 1.0000e+05])


In [17]:
# even positions in the sequence
even_pos = t.arange(0,seq_len,2).float()    # 2i
print(f'even positions in the sequence :\n{even_pos}')

# odd positions in the sequence
odd_pos = t.arange(1,seq_len,2).float()   # 2i+1
print(f'even positions in the sequence :\n{odd_pos}')

even positions in the sequence :
tensor([0., 2., 4., 6., 8.])
even positions in the sequence :
tensor([1., 3., 5., 7., 9.])


In [18]:
# reshaping to (10,1)
even_pos = even_pos.reshape(even_pos.shape[0],1)
print(even_pos.shape)

odd_pos = odd_pos.reshape(odd_pos.shape[0],1)
print(odd_pos.shape)

torch.Size([5, 1])
torch.Size([5, 1])


In [19]:
even_pe = t.sin(even_pos/denom) # sin for even positions
print(f'even position tensor :\n{even_pe.shape}\n{even_pe}')

odd_pe = t.cos(odd_pos/denom) # cos for odd positions
print(f'odd position tensor :\n{odd_pe.shape}\n{odd_pe}')

even position tensor :
torch.Size([5, 6])
tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00],
        [ 9.0930e-01,  1.9867e-01,  1.9999e-02,  2.0000e-03,  2.0000e-04,
          2.0000e-05],
        [-7.5680e-01,  3.8942e-01,  3.9989e-02,  4.0000e-03,  4.0000e-04,
          4.0000e-05],
        [-2.7942e-01,  5.6464e-01,  5.9964e-02,  6.0000e-03,  6.0000e-04,
          6.0000e-05],
        [ 9.8936e-01,  7.1736e-01,  7.9915e-02,  7.9999e-03,  8.0000e-04,
          8.0000e-05]])
odd position tensor :
torch.Size([5, 6])
tensor([[ 0.5403,  0.9950,  0.9999,  1.0000,  1.0000,  1.0000],
        [-0.9900,  0.9553,  0.9996,  1.0000,  1.0000,  1.0000],
        [ 0.2837,  0.8776,  0.9988,  1.0000,  1.0000,  1.0000],
        [ 0.7539,  0.7648,  0.9976,  1.0000,  1.0000,  1.0000],
        [-0.9111,  0.6216,  0.9960,  1.0000,  1.0000,  1.0000]])


In [20]:
# stacking even_pos tensor and odd_pos tensor along column
pe = t.stack([even_pe,odd_pe], dim = 1)
print(f'pe tensor shape : {pe.shape}\n{pe}')

pe tensor shape : torch.Size([5, 2, 6])
tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00],
         [ 5.4030e-01,  9.9500e-01,  9.9995e-01,  1.0000e+00,  1.0000e+00,
           1.0000e+00]],

        [[ 9.0930e-01,  1.9867e-01,  1.9999e-02,  2.0000e-03,  2.0000e-04,
           2.0000e-05],
         [-9.8999e-01,  9.5534e-01,  9.9955e-01,  1.0000e+00,  1.0000e+00,
           1.0000e+00]],

        [[-7.5680e-01,  3.8942e-01,  3.9989e-02,  4.0000e-03,  4.0000e-04,
           4.0000e-05],
         [ 2.8366e-01,  8.7758e-01,  9.9875e-01,  9.9999e-01,  1.0000e+00,
           1.0000e+00]],

        [[-2.7942e-01,  5.6464e-01,  5.9964e-02,  6.0000e-03,  6.0000e-04,
           6.0000e-05],
         [ 7.5390e-01,  7.6484e-01,  9.9755e-01,  9.9998e-01,  1.0000e+00,
           1.0000e+00]],

        [[ 9.8936e-01,  7.1736e-01,  7.9915e-02,  7.9999e-03,  8.0000e-04,
           8.0000e-05],
         [-9.1113e-01,  6.2161e-01,  9.9595e-01,  9.9996e-01,

In [21]:
# Reshape the stacked tensor to interleave the rows
pe = pe.view(-1, 6)
print(pe.shape)  # Output: torch.Size([10, 6])
print(pe)

torch.Size([10, 6])
tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00],
        [ 5.4030e-01,  9.9500e-01,  9.9995e-01,  1.0000e+00,  1.0000e+00,
          1.0000e+00],
        [ 9.0930e-01,  1.9867e-01,  1.9999e-02,  2.0000e-03,  2.0000e-04,
          2.0000e-05],
        [-9.8999e-01,  9.5534e-01,  9.9955e-01,  1.0000e+00,  1.0000e+00,
          1.0000e+00],
        [-7.5680e-01,  3.8942e-01,  3.9989e-02,  4.0000e-03,  4.0000e-04,
          4.0000e-05],
        [ 2.8366e-01,  8.7758e-01,  9.9875e-01,  9.9999e-01,  1.0000e+00,
          1.0000e+00],
        [-2.7942e-01,  5.6464e-01,  5.9964e-02,  6.0000e-03,  6.0000e-04,
          6.0000e-05],
        [ 7.5390e-01,  7.6484e-01,  9.9755e-01,  9.9998e-01,  1.0000e+00,
          1.0000e+00],
        [ 9.8936e-01,  7.1736e-01,  7.9915e-02,  7.9999e-03,  8.0000e-04,
          8.0000e-05],
        [-9.1113e-01,  6.2161e-01,  9.9595e-01,  9.9996e-01,  1.0000e+00,
          1.0000e+00]])
