<a href="https://colab.research.google.com/github/whistle-hikhi/attention-in-transformers/blob/main/masked_self_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [7]:
class MaskedSelfAttention(nn.Module):
  """Single-head scaled dot-product self-attention with an optional mask.

  Queries, keys and values are produced from the same token encodings by
  three bias-free linear projections of size ``d_model x d_model``.

  Args:
    d_model: Size of each token encoding; also the size of the query, key
      and value vectors produced by the projections.
    row_dim: Index of the tensor dimension that enumerates tokens. It is
      swapped with ``col_dim`` to transpose the keys.
    col_dim: Index of the tensor dimension that enumerates the features of
      a token. Its size is used for the ``sqrt`` scaling, and the softmax
      is taken along this dimension of the similarity matrix.

  With the defaults (``row_dim=0``, ``col_dim=1``) the module operates on a
  2-D tensor laid out as ``(num_tokens, d_model)``.
  """

  def __init__(self, d_model: int = 2, row_dim: int = 0, col_dim: int = 1):
    super().__init__()

    # NOTE: the creation order (query, key, value) determines which random
    # numbers each weight matrix receives for a given seed - keep it stable.
    self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
    self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
    self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)

    self.row_dim = row_dim
    self.col_dim = col_dim

  def forward(self, token_encodings: torch.Tensor, mask=None) -> torch.Tensor:
    """Compute (optionally masked) self-attention for ``token_encodings``.

    Args:
      token_encodings: Tensor whose last dimension has size ``d_model``.
      mask: Optional boolean tensor broadcastable to the similarity matrix.
        Positions where it is ``True`` are hidden, i.e. they receive
        (effectively) zero attention weight. ``None`` disables masking.

    Returns:
      The attention-weighted combination of the value vectors, one row per
      input token.
    """
    q = self.W_q(token_encodings)
    k = self.W_k(token_encodings)
    v = self.W_v(token_encodings)

    # Similarity of every query with every key.
    sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))

    # Scale by sqrt(key size). A plain Python float is enough here; there is
    # no need to allocate a new tensor on every forward pass.
    scaled_sims = sims / (k.size(self.col_dim) ** 0.5)

    if mask is not None:
      # Use the most negative finite value of the score dtype rather than a
      # hard-coded -1e9: the latter is not representable in float16 and makes
      # masked_fill fail. After the softmax both give a weight of zero.
      fill_value = torch.finfo(scaled_sims.dtype).min
      scaled_sims = scaled_sims.masked_fill(mask=mask, value=fill_value)

    attention_percents = F.softmax(scaled_sims, dim=self.col_dim)
    attention_scores = torch.matmul(attention_percents, v)

    return attention_scores

In [8]:
## token encodings: one row per token, two features each
encodings_matrix = torch.tensor([
    [1.16, 0.23],
    [0.57, 1.36],
    [4.41, -2.16],
])

## fix the RNG seed so the randomly initialised weights are reproducible
torch.manual_seed(42)

## build the masked self-attention module
maskedSelfAttention = MaskedSelfAttention(d_model=2, row_dim=0, col_dim=1)

## causal mask: True marks the positions a token must NOT attend to,
## i.e. every token that comes after it (strictly above the diagonal)
lower_triangle = torch.ones(3, 3).tril()
mask = lower_triangle.eq(0)
mask # print out the mask

tensor([[False,  True,  True],
        [False, False,  True],
        [False, False, False]])

In [9]:
## run the module on the encodings, hiding future tokens via the mask
maskedSelfAttention(token_encodings=encodings_matrix, mask=mask)

tensor([[ 0.6038,  0.7434],
        [-0.0062,  0.6072],
        [ 3.4989,  2.2427]], grad_fn=<MmBackward0>)