In [1]:
from mask import (
    create_decoder_atn_mask,
    create_encoder_atn_mask,
)
import torch

In [2]:
batch_size = 2
src_len = 4
tgt_len = 5
num_heads = 2

min_val, max_val = 0, 10

In [3]:
src = torch.randint(min_val, max_val, (batch_size, src_len))
tgt = torch.randint(min_val, max_val, (batch_size, tgt_len))

src.shape, tgt.shape

(torch.Size([2, 4]), torch.Size([2, 5]))

In [4]:
src_mask = create_encoder_atn_mask(src).type(torch.int64)
tgt_mask = create_decoder_atn_mask(tgt).type(torch.int64)
src_tgt_mask = create_encoder_atn_mask(src).type(torch.int64)

In [5]:
src_mask

tensor([[[[1, 4, 7, 1]]],


        [[[6, 6, 8, 1]]]])

In [6]:
tgt_mask

tensor([[[[0, 0, 0, 0, 0],
          [0, 1, 0, 0, 0],
          [0, 1, 1, 0, 0],
          [0, 1, 1, 1, 0],
          [0, 1, 1, 1, 1]]],


        [[[0, 0, 0, 0, 0],
          [0, 1, 0, 0, 0],
          [0, 1, 0, 0, 0],
          [0, 1, 0, 1, 0],
          [0, 1, 0, 1, 0]]]])

In [7]:
attn_src_weight = torch.randint(min_val, max_val, (batch_size, num_heads, src_len, src_len)).type(torch.float32)
attn_tgt_weight = torch.randint(min_val, max_val, (batch_size, num_heads, tgt_len, tgt_len)).type(torch.float32)
attn_src_tgt_weight = torch.randint(min_val, max_val, (batch_size, num_heads, tgt_len, src_len)).type(torch.float32)

In [8]:
w1 = attn_src_weight.masked_fill(
    mask=src_mask == 0,
    value=float("-inf"),
)
w2 = attn_tgt_weight.masked_fill(
    mask=tgt_mask == 0,
    value=float("-inf"),
)
w3 = attn_src_tgt_weight.masked_fill(
    mask=src_tgt_mask == 0,
    value=float("-inf"),
)

In [9]:
attn_src_weight.shape, attn_tgt_weight.shape, attn_src_tgt_weight.shape

(torch.Size([2, 2, 4, 4]), torch.Size([2, 2, 5, 5]), torch.Size([2, 2, 5, 4]))

In [10]:
src_mask.shape, tgt_mask.shape, src_tgt_mask.shape

(torch.Size([2, 1, 1, 4]), torch.Size([2, 1, 5, 5]), torch.Size([2, 1, 1, 4]))

In [11]:
w1.shape, w2.shape, w3.shape

(torch.Size([2, 2, 4, 4]), torch.Size([2, 2, 5, 5]), torch.Size([2, 2, 5, 4]))

In [12]:
src_mask_expand = src_mask.expand(-1, -1, attn_src_weight.size(2), -1).expand(-1, attn_src_weight.size(1), -1, -1)
src_mask_expand.shape

torch.Size([2, 2, 4, 4])

In [13]:
src_mask_expand

tensor([[[[1, 4, 7, 1],
          [1, 4, 7, 1],
          [1, 4, 7, 1],
          [1, 4, 7, 1]],

         [[1, 4, 7, 1],
          [1, 4, 7, 1],
          [1, 4, 7, 1],
          [1, 4, 7, 1]]],


        [[[6, 6, 8, 1],
          [6, 6, 8, 1],
          [6, 6, 8, 1],
          [6, 6, 8, 1]],

         [[6, 6, 8, 1],
          [6, 6, 8, 1],
          [6, 6, 8, 1],
          [6, 6, 8, 1]]]])

In [14]:
w1_expand = attn_src_weight.masked_fill(
    mask=src_mask_expand == 0,
    value=float("-inf"),
)

In [15]:
w1_expand.shape

torch.Size([2, 2, 4, 4])

In [16]:
w1_expand == w1

tensor([[[[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]],

         [[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]]],


        [[[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]],

         [[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]]]])

In [17]:
tgt_mask_expand = tgt_mask.expand(-1, attn_tgt_weight.size(1), -1, -1)

In [18]:
w2_expand = attn_tgt_weight.masked_fill(
    mask=tgt_mask_expand == 0,
    value=float("-inf"),
)

In [19]:
w2_expand == w2

tensor([[[[True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True]],

         [[True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True]]],


        [[[True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True]],

         [[True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True],
          [True, True, True, True, True]]]])