In [13]:
import torch
import torch.nn as nn

## First, some random practice

In [14]:
# A 4x4 matrix of standard-normal samples to practise masking on.
a = torch.randn((4, 4))
a

tensor([[-0.7444,  1.0120,  0.3675,  2.6592],
        [-0.9483, -1.0301, -0.1990, -0.2448],
        [-0.7577,  1.1288, -0.9627,  0.4792],
        [-1.9023, -0.7524, -0.5739, -0.4064]])

In [48]:
# Start from an all-True matrix shaped like `a` and keep only the lower triangle
# (diagonal included).
mask = torch.ones_like(a, dtype=torch.bool).tril()
print(f"Mask with only triangular booleans:\n{mask}")
# Invert the booleans: True now marks the entries strictly above the diagonal,
# i.e. the ones that masked_fill should overwrite.
mask = ~mask
print(f"\nMask after doing == 0:\n{mask}")
# Every position where the mask is True is replaced by -inf.
b = a.masked_fill(mask, float("-inf"))
print(f"\nMasked matrix:\n{b}")

Mask with only triangular booleans:
tensor([[ True, False, False, False],
        [ True,  True, False, False],
        [ True,  True,  True, False],
        [ True,  True,  True,  True]])

Mask after doing == 0:
tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])

Masked matrix:
tensor([[-0.7444,    -inf,    -inf,    -inf],
        [-0.9483, -1.0301,    -inf,    -inf],
        [-0.7577,  1.1288, -0.9627,    -inf],
        [-1.9023, -0.7524, -0.5739, -0.4064]])


In [32]:
# Row-wise softmax over the last dimension; the -inf entries come out as 0.
c = b.softmax(dim=-1)
c

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.5204, 0.4796, 0.0000, 0.0000],
        [0.1189, 0.7843, 0.0969, 0.0000],
        [0.0807, 0.2547, 0.3045, 0.3600]])

## Time to build the masked self-attention

In [110]:
torch.manual_seed(42)  # fixed seed so the random embeddings below are reproducible

# Toy stand-in for "token embeddings + positional encoding"; the attention
# module below treats each row as one token (with its default indices).
token_embed_w_pos = torch.randn(5, 5)

# The attention scores are (num_tokens x num_tokens), so size the mask from the
# token count rather than from the embedding matrix itself -- the two shapes
# only coincide here because this toy input happens to be square.
num_tokens = token_embed_w_pos.shape[0]
# True strictly above the diagonal marks the positions to hide (later tokens).
mask = torch.triu(
    torch.ones(
        num_tokens, num_tokens, dtype=torch.bool, device=token_embed_w_pos.device
    ),
    diagonal=1,
)


class MaskedSelfAttention(nn.Module):
    """Scaled dot-product self-attention with an optional boolean mask.

    Queries, keys and values are each produced by a bias-free linear
    projection of the same input.

    Args:
        embed_dim: Input and output feature size of the three projections.
        row_index: First of the two dimensions swapped when transposing the keys.
        col_index: Second dimension swapped in that transpose. It is also the
            key dimension whose size sets the scaling factor, and the dimension
            the softmax runs over.

    With the defaults (0, 1) this fits a 2-D input of shape
    (num_tokens, embed_dim).
    """

    def __init__(self, embed_dim=5, row_index=0, col_index=1):
        super().__init__()
        # Keep the creation order (query, key, value) so seeded initialisation
        # yields the same weights.
        self.W_q = nn.Linear(embed_dim, embed_dim, bias=False)
        self.W_k = nn.Linear(embed_dim, embed_dim, bias=False)
        self.W_v = nn.Linear(embed_dim, embed_dim, bias=False)
        self.row_index, self.col_index = row_index, col_index

    def forward(self, token_embeddings, mask=None):
        """Return the attention-weighted values for ``token_embeddings``.

        Args:
            token_embeddings: Tensor fed to all three projections.
            mask: Optional boolean tensor; positions where it is True are set
                to -inf before the softmax, so they receive zero weight.

        Returns:
            The softmax-normalised attention weights multiplied by the values.
        """
        queries = self.W_q(token_embeddings)
        keys = self.W_k(token_embeddings)
        values = self.W_v(token_embeddings)

        # Query-key similarities, divided by sqrt of the key size along
        # `col_index`.
        keys_t = keys.transpose(self.row_index, self.col_index)
        scale = torch.tensor(keys.shape[self.col_index] ** 0.5)
        scores = (queries @ keys_t) / scale

        if mask is not None:
            scores = scores.masked_fill(mask, torch.tensor(float("-inf")))

        weights = scores.softmax(dim=self.col_index)
        return torch.matmul(weights, values)

In [111]:
# Show the input matrix next to the boolean mask (True = position to hide).
token_embed_w_pos, mask

(tensor([[ 1.9269,  1.4873,  0.9007, -2.1055,  0.6784],
         [-1.2345, -0.0431, -1.6047, -0.7521, -0.6866],
         [-0.4934,  0.2415, -1.1109,  0.0915, -2.3169],
         [-0.2168, -1.3847, -0.3957,  0.8034, -0.6216],
         [-0.5920, -0.0631, -0.8286,  0.3309, -1.5576]]),
 tensor([[False,  True,  True,  True,  True],
         [False, False,  True,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False, False,  True],
         [False, False, False, False, False]]))

In [112]:
# Spell out the constructor arguments (same values as the class defaults) so
# the expected input layout is visible at the call site.
masked_selfattn = MaskedSelfAttention(embed_dim=5, row_index=0, col_index=1)

In [113]:
# Run the attention module on the toy embeddings with the mask applied.
masked_selfattn(token_embeddings=token_embed_w_pos, mask=mask)

tensor([[ 9.7145e-01,  1.1592e+00, -3.9426e-01, -5.4707e-01,  5.0947e-01],
        [ 5.7437e-01,  7.7056e-01,  2.8447e-02,  1.2527e-01,  6.6998e-01],
        [ 6.2325e-01,  8.1523e-01,  1.0536e-01, -1.6011e-01,  6.4922e-01],
        [ 1.4171e-01,  4.4811e-01,  3.9255e-01,  1.4965e-05,  4.4730e-01],
        [ 2.7494e-01,  5.2278e-01,  3.0604e-01, -9.1794e-02,  4.7333e-01]],
       grad_fn=<MmBackward0>)

---

### So, basically, what I did was build a mask and apply it before the softmax:

In [153]:
# 1. Create a mask tensor

randinput = torch.randn(5, 5)
# True strictly above the main diagonal, False on and below it.
mask2 = torch.triu(torch.ones_like(randinput, dtype=torch.bool), diagonal=1)
mask2

tensor([[False,  True,  True,  True,  True],
        [False, False,  True,  True,  True],
        [False, False, False,  True,  True],
        [False, False, False, False,  True],
        [False, False, False, False, False]])

In [154]:
# Display the raw random matrix before any masking is applied.
randinput

tensor([[ 0.0853,  0.7481, -0.1636, -0.9086,  0.3130],
        [ 0.8050, -1.1134,  0.5258, -1.2000, -0.8326],
        [-0.8129,  0.9700, -0.6758,  0.2043, -0.0265],
        [-0.4138,  0.5184,  0.3418, -2.7016,  0.0666],
        [-0.9120,  0.3682,  0.7050, -1.0838, -0.3889]])

In [155]:
# 2. Apply the mask to the input

# Entries flagged True in mask2 become -inf ...
masked_randinput = randinput.masked_fill(mask2, float("-inf"))
# ... so the row-wise softmax assigns them a weight of exactly 0.
masked_randinput_softmax = masked_randinput.softmax(dim=-1)
masked_randinput_softmax

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.8720, 0.1280, 0.0000, 0.0000, 0.0000],
        [0.1235, 0.7348, 0.1417, 0.0000, 0.0000],
        [0.1733, 0.4402, 0.3689, 0.0176, 0.0000],
        [0.0822, 0.2957, 0.4141, 0.0692, 0.1387]])