In [1]:

import torch

from torch.nn import functional as F



## Attention

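* Attention projects the input `x` into queries (Q), keys (K) and values (V). The attention scores are the matrix product of Q with the transpose of K; after scaling and a softmax they are used as weights for a weighted sum of V.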


In [2]:

## Fix the RNG seed so that Restart & Run All draws the same random tensors
## every time (the notebook is otherwise not reproducible).
torch.manual_seed(0)

N = 32      ## batch size (number of sentences)

## Toy input batch of shape (N, 40, 512); 40 is presumably the sentence length
## and 512 the embedding width.
x = torch.randn(N, 40, 512)
x.shape


torch.Size([32, 40, 512])


## Q


In [3]:

## Random stand-in for the query projection: one 512 -> 64 matrix per sample.
## NOTE(review): a learned projection is normally shared across the batch,
## i.e. shape (512, 64) -- confirm the leading N is intentional.
wq = torch.randn((N, 512, 64))
wq.size()


torch.Size([32, 512, 64])

In [4]:

## Random stand-in for the query bias, shaped like the projected tensor.
## NOTE(review): a layer bias is normally a single vector of length 64 that
## broadcasts over batch and position -- confirm.
bq = torch.randn((N, 40, 64))
bq.size()


torch.Size([32, 40, 64])

In [5]:

## Batched matrix product: (N, 40, 512) @ (N, 512, 64) -> (N, 40, 64).
Q = x @ wq
Q.size()


torch.Size([32, 40, 64])

In [6]:

## Build Q from its inputs instead of `Q = Q + bq`, so that re-running this
## cell never adds the bias a second time (the cell is now idempotent).
Q = torch.matmul(x, wq) + bq
Q.shape


torch.Size([32, 40, 64])


## K 


In [7]:

## Random stand-in for the key projection: one 512 -> 64 matrix per sample.
## NOTE(review): normally shared across the batch, shape (512, 64) -- confirm.
wk = torch.randn((N, 512, 64))
wk.size()


torch.Size([32, 512, 64])

In [8]:

## Random stand-in for the key bias, shaped like the projected tensor.
bk = torch.randn((N, 40, 64))
bk.size()


torch.Size([32, 40, 64])

In [9]:

## Batched matrix product: (N, 40, 512) @ (N, 512, 64) -> (N, 40, 64).
K = x @ wk
K.size()


torch.Size([32, 40, 64])

In [10]:

## Build K from its inputs instead of `K = K + bk`, so that re-running this
## cell never adds the bias a second time (the cell is now idempotent).
K = torch.matmul(x, wk) + bk
K.shape


torch.Size([32, 40, 64])


## Attention scores: Q · Kᵀ → [N, 40, 40]


In [11]:

## Raw (unscaled, un-normalised) scores: every query position against every
## key position, (N, 40, 64) @ (N, 64, 40) -> (N, 40, 40).
attention_scores = Q @ K.transpose(-1, -2)
attention_scores.size()


torch.Size([32, 40, 40])


## V


In [12]:

## Random stand-in for the value projection: one 512 -> 64 matrix per sample.
## NOTE(review): normally shared across the batch, shape (512, 64) -- confirm.
wv = torch.randn((N, 512, 64))
wv.size()


torch.Size([32, 512, 64])

In [13]:

## Random stand-in for the value bias, shaped like the projected tensor.
bv = torch.randn((N, 40, 64))
bv.size()


torch.Size([32, 40, 64])

In [14]:

## Batched matrix product: (N, 40, 512) @ (N, 512, 64) -> (N, 40, 64).
V = x @ wv
V.size()


torch.Size([32, 40, 64])

In [15]:

## Build V from its inputs instead of `V = V + bv`, so that re-running this
## cell never adds the bias a second time (the cell is now idempotent).
V = torch.matmul(x, wv) + bv
V.shape


torch.Size([32, 40, 64])

In [16]:

## Scaled dot-product attention. The raw Q K^T scores must be divided by
## sqrt(d_k) and normalised with a softmax over the key axis before they are
## used to average V; multiplying V by the raw scores is not an attention
## weighting (rows can be negative and do not sum to 1).
d_k = Q.size(-1)                                   ## per-head key/query width (64)
attention_weights = F.softmax(attention_scores / d_k ** 0.5, dim=-1)

## (N, 40, 40) @ (N, 40, 64) -> (N, 40, 64): one weighted sum of V per query.
out = torch.matmul(attention_weights, V)
out.shape


torch.Size([32, 40, 64])


## Concatenate All 8 heads


In [17]:

## Placeholder for 8 heads: the same head output is simply repeated 8 times
## (a real multi-head layer would compute 8 differently-parameterised heads).
list_head = [out] * 8


In [18]:

## Show the shape of every head, one per line.
for head in list_head:
    print(head.shape)


torch.Size([32, 40, 64])
torch.Size([32, 40, 64])
torch.Size([32, 40, 64])
torch.Size([32, 40, 64])
torch.Size([32, 40, 64])
torch.Size([32, 40, 64])
torch.Size([32, 40, 64])
torch.Size([32, 40, 64])


In [19]:

## Join the heads along the last (feature) axis: 8 x 64 -> 512 features.
out_cat = torch.cat(tensors=list_head, dim=-1)
out_cat.size()


torch.Size([32, 40, 512])


## Another projection for the concatenated 8 heads


In [20]:

## 8 heads x 64 features per head = width of the concatenated tensor
8*64


512

In [21]:

## Random stand-in for the output projection applied to the concatenated
## heads: (8 * 64) -> 512, one matrix per sample.
## NOTE(review): normally shared across the batch, shape (512, 512) -- confirm.
w0 = torch.randn((N, 8 * 64, 512))
w0.size()


torch.Size([32, 512, 512])

In [22]:

## Random stand-in for the output-projection bias, shaped like the result.
b0 = torch.randn((N, 40, 512))
b0.size()


torch.Size([32, 40, 512])

In [23]:

## Batched matrix product: (N, 40, 512) @ (N, 512, 512) -> (N, 40, 512).
z = out_cat @ w0
z.size()


torch.Size([32, 40, 512])

In [24]:

## Build z from its inputs instead of `z = z + b0`, so that re-running this
## cell never adds the bias a second time (the cell is now idempotent).
z = torch.matmul(out_cat, w0) + b0
z.shape


torch.Size([32, 40, 512])


## The Mask


In [25]:

## Lower-triangular matrix of ones: 1 on and below the diagonal, 0 above it.
## The real size would be 40 x 40; 10 x 10 is used only so it prints compactly.
tril_def = torch.ones(10, 10).tril()
tril_def.size()


torch.Size([10, 10])

In [26]:

## display the lower-triangular matrix of ones
tril_def 


tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])

In [27]:

## this is just to record tril_def as a buffer that is not updated during training
## NOTE(review): the snippet below is disabled (it is only a string literal).
## register_buffer is normally called on a module instance, e.g.
## self.register_buffer('tril', tril_def) inside __init__; the unbound
## nn.Module.register_buffer(...) call would presumably fail as written --
## confirm before enabling it.

'''


import torch.nn as nn

my_tril_reg = nn.Module.register_buffer('tril', tril_def)
my_tril_reg

'''


"\n\n\nimport torch.nn as nn\n\nmy_tril_reg = nn.Module.register_buffer('tril', tril_def)\nmy_tril_reg\n\n"


## Attention matrices for a batch of 32 sentences (40x40 each)


In [28]:

## shape of the full score tensor computed earlier: one 40 x 40 matrix per sentence
attention_scores.shape


torch.Size([32, 40, 40])

In [29]:

## Small random stand-in for the scores (10 x 10 per sample) so that the
## masked result is easy to read when printed.
size10_attention = torch.randn((N, 10, 10))
size10_attention.size()


torch.Size([32, 10, 10])


## Use the tril for masking


In [30]:

## Slice the mask down to the score-matrix size (here the full 10 x 10).
tril_def[:10, :10].size()


torch.Size([10, 10])

In [31]:

## Boolean mask: True above the diagonal, i.e. the entries that will be masked.
tril_def[:10, :10].eq(0)


tensor([[False,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False, False, False,  True,  True,  True],
        [False, False, False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False, False, False]])

In [32]:

## Overwrite every score above the diagonal with -inf; the (10, 10) mask
## broadcasts over the batch dimension.
size10_attention = size10_attention.masked_fill(tril_def[:10, :10].eq(0), float('-inf'))
size10_attention.size()


torch.Size([32, 10, 10])


## Negative infinities

* softmax maps every `-inf` entry to exactly zero (exp(-inf) = 0), so masked positions receive no attention weight


In [33]:

## masked scores of the first sample in the batch
size10_attention[0]


tensor([[-0.1414,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf],
        [-0.5979,  0.9721,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf],
        [ 2.4868, -0.5984,  0.5984,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf],
        [-0.6860,  0.2422,  0.3465, -1.3606,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf],
        [ 0.7698, -0.9402, -1.5300,  0.1798,  0.3901,    -inf,    -inf,    -inf,
            -inf,    -inf],
        [-1.4165, -1.1399,  0.2813, -0.3932, -0.5217,  1.7322,    -inf,    -inf,
            -inf,    -inf],
        [ 0.6403,  0.9066,  0.0405,  0.5469,  1.6245,  0.0386, -0.6646,    -inf,
            -inf,    -inf],
        [-0.3735,  0.9200,  0.2612, -0.1700, -0.5403, -0.1230,  0.0656,  1.3097,
            -inf,    -inf],
        [ 0.2676,  0.8405, -2.3592,  1.5083,  0.1222,  0.8123, -0.3797, -0.3743,
          0.0901,    -inf],
        [-0.1547,  

In [34]:

## Normalise each row (last axis) of the masked scores into weights.
size10_attention_softmax = torch.softmax(size10_attention, dim=-1)
size10_attention_softmax.size()


torch.Size([32, 10, 10])


## Assume batch of only one sentence


In [35]:

## Same small example, but with a batch containing a single sentence.
size_1_attention = torch.randn((1, 10, 10))
size_1_attention.size()


torch.Size([1, 10, 10])

In [36]:

## Overwrite every score above the diagonal with -inf; the (10, 10) mask
## broadcasts over the leading batch dimension of size 1.
size_1_attention = size_1_attention.masked_fill(tril_def[:10, :10].eq(0), float('-inf'))
size_1_attention.size()


torch.Size([1, 10, 10])

In [37]:

## masking leaves the shape unchanged (same value as shown by the previous cell)
size_1_attention.shape


torch.Size([1, 10, 10])

In [38]:

## masked scores: -inf everywhere above the diagonal
size_1_attention


tensor([[[ 0.3694,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
             -inf,    -inf,    -inf],
         [-0.4083, -1.0354,    -inf,    -inf,    -inf,    -inf,    -inf,
             -inf,    -inf,    -inf],
         [-0.2602,  0.3039,  0.2962,    -inf,    -inf,    -inf,    -inf,
             -inf,    -inf,    -inf],
         [-2.0628, -0.6610,  1.0288, -0.9754,    -inf,    -inf,    -inf,
             -inf,    -inf,    -inf],
         [-0.4197, -0.4583,  0.3219,  0.3110, -2.2110,    -inf,    -inf,
             -inf,    -inf,    -inf],
         [ 0.7639,  2.3473,  1.0977,  1.4825, -0.2518, -1.1985,    -inf,
             -inf,    -inf,    -inf],
         [-1.9602,  0.9058,  0.1762, -1.3424, -1.8256,  0.7835, -1.0198,
             -inf,    -inf,    -inf],
         [-2.6099,  0.2272,  1.0389,  0.6421, -1.0396, -1.2019,  0.4806,
          -1.5056,    -inf,    -inf],
         [-0.8609,  2.1190, -1.5526,  0.0927, -0.0035, -0.7038,  1.1124,
           0.5445, -0.3165,    -inf],
 

In [39]:

## Normalise each row (last axis) of the masked scores into weights.
size_1_attention_softmax = torch.softmax(size_1_attention, dim=-1)
size_1_attention_softmax.size()


torch.Size([1, 10, 10])

In [40]:

## attention weights after softmax: each row sums to 1 and masked entries are 0
size_1_attention_softmax


tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000],
         [0.6518, 0.3482, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000],
         [0.2221, 0.3905, 0.3874, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000],
         [0.0333, 0.1352, 0.7327, 0.0987, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000],
         [0.1586, 0.1526, 0.3330, 0.3294, 0.0264, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000],
         [0.1018, 0.4960, 0.1422, 0.2089, 0.0369, 0.0143, 0.0000, 0.0000,
          0.0000, 0.0000],
         [0.0208, 0.3649, 0.1759, 0.0385, 0.0238, 0.3229, 0.0532, 0.0000,
          0.0000, 0.0000],
         [0.0086, 0.1468, 0.3306, 0.2223, 0.0414, 0.0352, 0.1892, 0.0260,
          0.0000, 0.0000],
         [0.0248, 0.4884, 0.0124, 0.0644, 0.0585, 0.0290, 0.1785, 0.1012,
          0.0428, 0.0000],
         [0.0606, 0.3990, 0.2010, 0.0539, 0.0107, 0.0576, 0.0815, 0.0237,
          0.0813,