### Masking in Attention

#### 1. Causal Attention Mask: Decoder

In [1]:
import torch
# Seed the global RNG *before* anything random is created (the attention
# module's initial weights and the embeddings below), so that a fresh
# "Restart Kernel -> Run All" always reproduces the same numbers.
# NOTE: outputs captured before this seed was added will differ until the
# notebook is re-run.
torch.manual_seed(0)

embed_dim = 4
# single-head attention; batch_first=True -> inputs are (batch, seq_len, embed_dim)
mha = torch.nn.MultiheadAttention(embed_dim=embed_dim, num_heads=1, batch_first=True)

# assume we have a batch of 2 sentences. 1st has 3 tokens and 2nd has 2 tokens
embeddings = torch.normal(mean=0, std=1, size=(2, 3, embed_dim))
# create a padding mask that is all False so that every token is valid by default
key_padding_mask = torch.zeros(size=(2, 3), dtype=torch.bool)
# 3rd token of second sentence is a pad token (True = this key is ignored)
key_padding_mask[1, 2] = True

# The second return value holds the attention *weights* (each row sums to 1),
# not a mask; the variable name is kept unchanged for backward compatibility.
_, torch_attn_mask = mha(embeddings, embeddings, embeddings, key_padding_mask=key_padding_mask)
print(torch_attn_mask)

tensor([[[0.3844, 0.3415, 0.2741],
         [0.2364, 0.2711, 0.4925],
         [0.3552, 0.3550, 0.2898]],

        [[0.4541, 0.5459, 0.0000],
         [0.4614, 0.5386, 0.0000],
         [0.4634, 0.5366, 0.0000]]], grad_fn=<MeanBackward1>)


In [2]:
# Display the random input embeddings created above: shape (2, 3, embed_dim),
# i.e. 2 sentences x 3 tokens x 4 features.
embeddings

tensor([[[ 1.8123e-01,  3.6679e-01, -1.3186e+00, -3.0594e-01],
         [ 2.7683e-01, -6.5081e-01,  2.2066e-02, -1.2154e+00],
         [-1.1054e+00,  8.5865e-01,  5.0024e-01,  5.5196e-01]],

        [[-1.0171e+00, -2.0578e-01, -9.2888e-01, -1.8757e+00],
         [-1.2949e+00, -8.5911e-01,  6.2310e-04, -1.0198e+00],
         [-6.9919e-01, -3.5184e-01,  1.5795e+00,  7.6255e-01]]])

In [3]:
# Display the boolean padding mask, shape (2, 3): True marks a padded token
# (only the 3rd token of the 2nd sentence).
key_padding_mask 

tensor([[False, False, False],
        [False, False,  True]])

In [4]:
# Line the (bs, seq_len) padding mask up with the (bs, seq_len, seq_len) score
# matrix: insert a query axis -> (bs, 1, seq_len), then repeat the per-key
# flags along that axis once for each of the 3 query tokens.
key_padding_mask_expanded = key_padding_mask[:, None, :].expand(-1, 3, -1)
print(key_padding_mask_expanded)

tensor([[[False, False, False],
         [False, False, False],
         [False, False, False]],

        [[False, False,  True],
         [False, False,  True],
         [False, False,  True]]])


In [5]:
# Raw similarity of every query token with every key token. This is a plain
# dot product of the embeddings with themselves (no scaling is applied here),
# so the numbers are not expected to match the `mha` output above.
scores = torch.matmul(embeddings, embeddings.transpose(-2, -1))
print(scores)
# Every position flagged True in the mask is overwritten with -inf ...
scores = scores.masked_fill(key_padding_mask_expanded, float("-inf"))
print(scores)
# ... so that softmax assigns it a weight of exactly 0.
attn_weights = scores.softmax(dim=-1)
print(torch.round(attn_weights, decimals=2))

tensor([[[ 1.9997,  0.1542, -0.7139],
         [ 0.1542,  1.9779, -1.5246],
         [-0.7139, -1.5246,  2.5140]],

        [[ 5.4580,  3.4061, -2.1140],
         [ 3.4061,  3.4548,  0.4310],
         [-2.1140,  0.4310,  3.6889]]])
tensor([[[ 1.9997,  0.1542, -0.7139],
         [ 0.1542,  1.9779, -1.5246],
         [-0.7139, -1.5246,  2.5140]],

        [[ 5.4580,  3.4061,    -inf],
         [ 3.4061,  3.4548,    -inf],
         [-2.1140,  0.4310,    -inf]]])
tensor([[[0.8200, 0.1300, 0.0500],
         [0.1400, 0.8400, 0.0300],
         [0.0400, 0.0200, 0.9500]],

        [[0.8900, 0.1100, 0.0000],
         [0.4900, 0.5100, 0.0000],
         [0.0700, 0.9300, 0.0000]]])


In [6]:
scores = torch.matmul(embeddings, embeddings.transpose(-2, -1))
# Build the additive (float) form of the mask described previously:
# start from zeros, then put -inf wherever the boolean mask is True.
float_mask = torch.zeros_like(key_padding_mask_expanded, dtype=torch.float32)
float_mask = float_mask.masked_fill(key_padding_mask_expanded, float("-inf"))
# Adding the float mask to the scores before softmax zeroes out the masked keys.
print((scores + float_mask).softmax(dim=-1).round(decimals=2))

tensor([[[0.8200, 0.1300, 0.0500],
         [0.1400, 0.8400, 0.0300],
         [0.0400, 0.0200, 0.9500]],

        [[0.8900, 0.1100, 0.0000],
         [0.4900, 0.5100, 0.0000],
         [0.0700, 0.9300, 0.0000]]])


In [7]:
# One (3, 3) boolean mask per sentence (2 sentences, 3 tokens each).
# Keeping only the entries strictly above the main diagonal leaves True exactly
# where a query would look at a later token, so those positions are blocked.
# NOTE(review): a 3-D attn_mask is documented as (batch * num_heads, L, S); the
# (2, 3, 3) shape lines up here because num_heads=1 -- confirm before reusing.
causal_mask = torch.ones(2, 3, 3, dtype=torch.bool).triu(diagonal=1)
print(causal_mask)
# index [1] selects the attention weights returned by the module
print(mha(query=embeddings, key=embeddings, value=embeddings, attn_mask=causal_mask)[1])

tensor([[[False,  True,  True],
         [False, False,  True],
         [False, False, False]],

        [[False,  True,  True],
         [False, False,  True],
         [False, False, False]]])
tensor([[[1.0000, 0.0000, 0.0000],
         [0.4658, 0.5342, 0.0000],
         [0.3552, 0.3550, 0.2898]],

        [[1.0000, 0.0000, 0.0000],
         [0.4614, 0.5386, 0.0000],
         [0.3019, 0.3496, 0.3485]]], grad_fn=<MeanBackward1>)


In [8]:
# Let PyTorch build the causal mask instead of constructing it by hand.
# sz is the sequence length: we have 3 tokens, so sz=3.
causal_mask = torch.nn.Transformer.generate_square_subsequent_mask(
    sz=3,
)
# The same mask is passed as attn_mask; index [1] selects the attention weights.
print(mha(query=embeddings, key=embeddings, value=embeddings, attn_mask=causal_mask)[1])

tensor([[[1.0000, 0.0000, 0.0000],
         [0.4658, 0.5342, 0.0000],
         [0.3552, 0.3550, 0.2898]],

        [[1.0000, 0.0000, 0.0000],
         [0.4614, 0.5386, 0.0000],
         [0.3019, 0.3496, 0.3485]]], grad_fn=<MeanBackward1>)


In [9]:
# Display the generated causal mask: a (3, 3) float tensor with 0. where
# attention is allowed and -inf above the diagonal where it is blocked.
causal_mask

tensor([[0., -inf, -inf],
        [0., 0., -inf],
        [0., 0., 0.]])