In [43]:
import torch
import torch.nn.functional as F
import math


### Single-head attention

In [44]:
# Toy single-head self-attention over one short sequence.
d_model = 6                               # embedding width shared by x, Q, K and V
rng = torch.Generator().manual_seed(0)    # dedicated seeded RNG: Restart & Run All gives the same numbers,
                                          # and the global torch RNG state is left untouched

x = torch.randn(1, 4, d_model, generator=rng)   # (batch=1, seq_len=4, dim=6)

# Random stand-ins for the query / key / value projection matrices.
W_q = torch.randn(d_model, d_model, generator=rng)
W_k = torch.randn(d_model, d_model, generator=rng)
W_v = torch.randn(d_model, d_model, generator=rng)

Q = x @ W_q     # (1,4,6)
K = x @ W_k     # (1,4,6)
V = x @ W_v     # (1,4,6)

# Scaled dot-product scores: entry (i, j) is how strongly query i matches key j.
# The scale is read from Q's last axis so it always follows the projection width.
scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))   # (1,4,4)
scores.shape

torch.Size([1, 4, 4])

### Padding mask
- Let us assume a padded sequence whose last position is a padding token:
tokens = [x1, x2, x3, pad]

In [45]:
# Key-padding mask for the 4-token sequence: 1 marks a real token, 0 marks the pad slot.
# It is used to remove the non-existent (padded) position from attention.
pad_mask = torch.tensor([[1, 1, 1, 0]])  # (1,4)

# Insert a singleton axis in the middle, (1,4) -> (1,1,4), so the mask can be
# broadcast against the (1,4,4) scores tensor.
pad_mask_expanded = pad_mask[:, None, :]
pad_mask_expanded.shape

torch.Size([1, 1, 4])

In [46]:
# Block the padded key column: where the padding mask is 0 the score is overwritten
# with -inf (blocked); where it is 1 the original score is kept unchanged (allowed).

scores_masked = scores.masked_fill(pad_mask_expanded == 0, float('-inf'))
# masked_fill(mask, value): positions where mask is True are replaced by value,
# positions where mask is False keep their original entry.

In [47]:
# Normalise over the last (key) axis; a score of -inf becomes a weight of exactly 0.
attn = scores_masked.softmax(dim=-1)
attn

tensor([[[1.9768e-01, 6.3114e-01, 1.7118e-01, 0.0000e+00],
         [6.2592e-01, 1.4123e-03, 3.7266e-01, 0.0000e+00],
         [2.2770e-02, 9.5557e-01, 2.1659e-02, 0.0000e+00],
         [2.6485e-04, 9.9972e-01, 1.0329e-05, 0.0000e+00]]])

- The last column is zero, so the pad token gets no attention weight and is not considered in the computation.

In [48]:
# Weighted sum of the value vectors: (1,4,4) x (1,4,6) -> (1,4,6)
output = torch.matmul(attn, V)

In [49]:
# Print the input embeddings followed by the attention output.
for label, tensor in (('input: \n', x), ('output: \n', output)):
    print(label, tensor)

input: 
 tensor([[[-1.0791, -0.3096,  1.1264,  0.9998,  1.1162,  0.8280],
         [-0.3296, -1.8103, -0.5017,  1.2415, -1.2396, -0.3374],
         [-1.3646, -0.3819,  0.8389,  0.8109,  1.9820,  0.4137],
         [-0.3675, -1.4777,  1.9310, -0.9693, -1.3933,  0.3477]]])
output: 
 tensor([[[ 2.5162, -0.3909, -3.9355,  3.2794,  0.6374,  3.9869],
         [ 3.2079,  3.7480, -0.3780,  2.7375,  0.1082,  0.2387],
         [ 2.1723, -2.5449, -5.7781,  3.5293,  0.8449,  5.9320],
         [ 2.1258, -2.8386, -6.0291,  3.5624,  0.8713,  6.1971]]])


- Padding mask: zeroes out the attention given to positions that do not really exist (pad tokens).

### Causal mask - no looking ahead
- During training a transformer processes the whole sequence in parallel (non-autoregressively), while during inference it generates autoregressively, one token at a time.
- So during training we need to stop the transformer from seeing (training on) the future part of the sequence.

- The causal mask mimics autoregressive behaviour during non-autoregressive (parallel) training.

In [50]:
# Number of positions in the sequence = length of axis 1 of x.
seq_len = x.shape[1]
seq_len

4

In [51]:
# Strictly upper-triangular matrix of ones: entry (i, j) is 1 exactly when j > i,
# i.e. when key j lies in the future of query i. 1 = blocked, 0 = visible.
# ("casual" is a misspelling of "causal"; the name is kept because later cells use it.)
casual_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
# torch.triu(matrix, diagonal=k) keeps the elements on and above the k-th diagonal and zeroes the rest:
#   diagonal=0 -> main diagonal included, diagonal=1 -> only elements strictly above the main diagonal.
# The lower triangle comes from torch.tril, not from a different `diagonal` value.

casual_mask

tensor([[0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.]])

In [52]:
# Nonzero entries of casual_mask become True; the scores at those positions are set to -inf.
future_keys = casual_mask.bool()
scores_masked = scores.masked_fill(future_keys, float('-inf'))

scores_masked

tensor([[[ -2.6225,     -inf,     -inf,     -inf],
         [ -0.2130,  -6.3070,     -inf,     -inf],
         [ -2.9483,   0.7886,  -2.9983,     -inf],
         [  0.6890,   8.9251,  -2.5553, -14.8012]]])

- Now the 1st token only sees the 1st token, the 2nd sees only up to the 2nd, and so on.
- In general, the i-th token cannot see any token that comes after position i.

In [53]:
# Softmax over the key axis; the -inf (future) entries end up with weight 0.
attn = scores_masked.softmax(dim=-1)

attn

tensor([[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [9.9775e-01, 2.2513e-03, 0.0000e+00, 0.0000e+00],
         [2.2770e-02, 9.5557e-01, 2.1659e-02, 0.0000e+00],
         [2.6485e-04, 9.9972e-01, 1.0329e-05, 4.9624e-11]]])

In [54]:
# Mix the value vectors with the causally-masked attention weights.
output = torch.matmul(attn, V)

In [55]:
# Print the input embeddings followed by the attention output.
for label, tensor in (('input: \n', x), ('output: \n', output)):
    print(label, tensor)

input: 
 tensor([[[-1.0791, -0.3096,  1.1264,  0.9998,  1.1162,  0.8280],
         [-0.3296, -1.8103, -0.5017,  1.2415, -1.2396, -0.3374],
         [-1.3646, -0.3819,  0.8389,  0.8109,  1.9820,  0.4137],
         [-0.3675, -1.4777,  1.9310, -0.9693, -1.3933,  0.3477]]])
output: 
 tensor([[[ 3.3110,  3.5816, -0.4492,  2.4966, -0.4249,  0.3445],
         [ 3.3084,  3.5671, -0.4618,  2.4990, -0.4220,  0.3577],
         [ 2.1723, -2.5449, -5.7781,  3.5293,  0.8449,  5.9320],
         [ 2.1258, -2.8386, -6.0291,  3.5624,  0.8713,  6.1971]]])


### Combined mask - padding mask + causal mask

- Let us assume that the sequence is [x1, x2, x3, pad].


In [65]:
# unsqueeze() returns a view that shares storage with casual_mask, so take a copy:
# otherwise any in-place update of combined_mask would also rewrite casual_mask.
combined_mask = casual_mask.unsqueeze(0).clone() # (4 x 4) -> (1 x 4 x 4) to compute with input shape

In [57]:
# Display the size of the mask together with its contents.
combined_mask.size(), combined_mask

(torch.Size([1, 4, 4]),
 tensor([[[0., 1., 1., 1.],
          [0., 0., 1., 1.],
          [0., 0., 0., 1.],
          [0., 0., 0., 0.]]]))

In [58]:
# Built out-of-place from casual_mask, so casual_mask itself is never modified and
# re-running this cell always gives the same result.
# Each entry counts the reasons a position is blocked: 0 = visible, 1 = future key or pad key, 2 = both.
combined_mask = casual_mask.unsqueeze(0) + (pad_mask == 0).unsqueeze(1) # (1,4,4) + (1,1,4) -> (1,4,4)

In [59]:
# Show the merged mask: 0 = visible position, nonzero (1 or 2) = blocked position.
combined_mask

tensor([[[0., 1., 1., 2.],
         [0., 0., 1., 2.],
         [0., 0., 0., 2.],
         [0., 0., 0., 1.]]])

In [61]:
# 0 -> score left unchanged; 1 or 2 -> score replaced by -inf (which becomes 0 after softmax).
blocked = combined_mask.bool()
scores_masked = scores.masked_fill(blocked, float('-inf'))

scores_masked

tensor([[[-2.6225,    -inf,    -inf,    -inf],
         [-0.2130, -6.3070,    -inf,    -inf],
         [-2.9483,  0.7886, -2.9983,    -inf],
         [ 0.6890,  8.9251, -2.5553,    -inf]]])

In [62]:
# Softmax over the key axis; future keys and the pad key (both -inf) get weight 0.
attn = scores_masked.softmax(dim=-1)

attn

tensor([[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [9.9775e-01, 2.2513e-03, 0.0000e+00, 0.0000e+00],
         [2.2770e-02, 9.5557e-01, 2.1659e-02, 0.0000e+00],
         [2.6485e-04, 9.9972e-01, 1.0329e-05, 0.0000e+00]]])

In [63]:
# Mix the value vectors with the attention weights produced under the combined mask.
output = torch.matmul(attn, V)

output

tensor([[[ 3.3110,  3.5816, -0.4492,  2.4966, -0.4249,  0.3445],
         [ 3.3084,  3.5671, -0.4618,  2.4990, -0.4220,  0.3577],
         [ 2.1723, -2.5449, -5.7781,  3.5293,  0.8449,  5.9320],
         [ 2.1258, -2.8386, -6.0291,  3.5624,  0.8713,  6.1971]]])