In [109]:
import torch  # Tensor library and helper functions
import torch.nn as nn  # nn.module() and nn.Linear()
import torch.nn.functional as F  # softmax()

In [110]:
class MaskedSelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention with an optional mask.

    The token encodings are projected by three bias-free linear layers into
    queries, keys and values. Query/key similarities are scaled, optionally
    masked, turned into weights with softmax and used to mix the values.

    Args:
        d_model: Width of each token encoding; it is both the input and the
            output size of the three projections.
        row_dim: Index of the dimension that is swapped with ``col_dim`` when
            the keys are transposed (the token/row dimension).
        col_dim: Index of the feature/column dimension. Its size sets the
            scaling factor and it is the dimension softmax normalises over.
            NOTE: ``nn.Linear`` acts on the last dimension, so ``col_dim`` is
            expected to be the last dimension of the input.
    """

    def __init__(self, d_model=2, row_dim=0, col_dim=1):
        super().__init__()

        ## Weight matrices that create the queries, keys and values
        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)

        self.row_dim = row_dim
        self.col_dim = col_dim

    def forward(self, token_encodings, mask=None):
        """Compute masked self-attention for ``token_encodings``.

        Args:
            token_encodings: Tensor of encodings whose last dimension has
                size ``d_model``.
            mask: Optional tensor broadcastable to the similarity matrix.
                Positions that are ``True`` (or non-zero, for a non-boolean
                mask) are masked out, i.e. receive an attention weight of 0.

        Returns:
            Tensor holding, for every query, the attention-weighted
            combination of the values.
        """
        q = self.W_q(token_encodings)
        k = self.W_k(token_encodings)
        v = self.W_v(token_encodings)

        ## Row i of `sims` holds the similarity of query i to every key
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))

        ## Scale by the square root of the key width (a plain float, so no
        ## extra tensor is allocated on every forward pass)
        scaled_sims = sims / (k.size(self.col_dim) ** 0.5)

        if mask is not None:
            ## Here we are masking out things we don't want to pay attention to
            ##
            ## masked_fill() only accepts boolean masks, so a 0/1 mask of any
            ## other dtype is converted first (non-zero means "mask out").
            if mask.dtype != torch.bool:
                mask = mask != 0

            ## We replace the values we want masked out with a very large
            ## negative number so that the SoftMax() function will give all
            ## masked elements an output value (or "probability") of 0.
            ## -1e9 is kept whenever the dtype can represent it; for narrower
            ## dtypes (e.g. float16) we fall back to the most negative finite
            ## value, because masked_fill() rejects values that overflow.
            fill_value = max(-1e9, torch.finfo(scaled_sims.dtype).min)
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=fill_value)

        ## Softmax runs along the column dimension because each row belongs to
        ## one query: normalising across its columns makes that query's
        ## weights over all keys sum to 1.
        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)

        attention_scores = torch.matmul(attention_percents, v)

        return attention_scores

In [111]:
## create a matrix of token encodings (one row per token)...
encodings_matrix = torch.tensor([[1.16, 0.23],
                                 [0.57, 1.36],
                                 [4.41, -2.16]])

## read the sizes from the data instead of repeating them as literals
num_tokens, d_model = encodings_matrix.shape

## set the seed for the random number generator so that the randomly
## initialised weights of the attention layer are the same on every run
torch.manual_seed(42)

masked_self_attention = MaskedSelfAttention(d_model=d_model, row_dim=0, col_dim=1)

## start from a lower-triangular matrix of ones; every position that is 0
## (strictly above the diagonal) becomes True, meaning "mask this out"
mask = torch.tril(torch.ones(num_tokens, num_tokens)) == 0
mask # print out the mask

tensor([[False,  True,  True],
        [False, False,  True],
        [False, False, False]])

In [112]:
## run the whole masked self-attention layer on the encodings
masked_output = masked_self_attention(encodings_matrix, mask=mask)
masked_output.detach().numpy()

array([[ 0.60376686,  0.74339145],
       [-0.0061961 ,  0.6071508 ],
       [ 3.498918  ,  2.242719  ]], dtype=float32)

In [113]:
## print out the weight matrix that creates the queries
## (nn.Linear stores it as out_features x in_features, so we show the
## transpose, which is the matrix the encodings are multiplied by)
masked_self_attention.W_q.weight.detach().T.numpy()

array([[ 0.5406104 , -0.16565567],
       [ 0.5869042 ,  0.6495562 ]], dtype=float32)

In [114]:
## print out the weight matrix that creates the keys
## (nn.Linear stores it as out_features x in_features, so we show the
## transpose, which is the matrix the encodings are multiplied by)
masked_self_attention.W_k.weight.detach().T.numpy()

array([[-0.15492964, -0.3442585 ],
       [ 0.14268756,  0.41527158]], dtype=float32)

In [115]:
## print out the weight matrix that creates the values
## (nn.Linear stores it as out_features x in_features, so we show the
## transpose, which is the matrix the encodings are multiplied by)
masked_self_attention.W_v.weight.detach().T.numpy()

array([[ 0.62334496,  0.61461455],
       [-0.5187534 ,  0.13234162]], dtype=float32)

In [116]:
## calculate the queries by passing the encodings through W_q
queries = masked_self_attention.W_q(encodings_matrix)
queries.detach().numpy()

array([[ 0.76209605, -0.04276264],
       [ 1.1063377 ,  0.78897274],
       [ 1.1163784 , -2.133583  ]], dtype=float32)

In [117]:
## calculate the keys by passing the encodings through W_k
keys = masked_self_attention.W_k(encodings_matrix)
keys.detach().numpy()

array([[-0.14690024, -0.30382735],
       [ 0.10574519,  0.368542  ],
       [-0.9914448 , -2.4151666 ]], dtype=float32)

In [118]:
## calculate the values by passing the encodings through W_v
values = masked_self_attention.W_v(encodings_matrix)
values.detach().numpy()

array([[ 0.60376686,  0.74339145],
       [-0.35019803,  0.5303149 ],
       [ 3.8694587 ,  2.4245923 ]], dtype=float32)

In [119]:
## calculate the queries again, this time keeping them in `q`
## so that the step-by-step cells below can reuse them
q = masked_self_attention.W_q(encodings_matrix)
q.detach().numpy()

array([[ 0.76209605, -0.04276264],
       [ 1.1063377 ,  0.78897274],
       [ 1.1163784 , -2.133583  ]], dtype=float32)

In [120]:
## calculate the keys again, this time keeping them in `k`
## so that the step-by-step cells below can reuse them
k = masked_self_attention.W_k(encodings_matrix)
k.detach().numpy()

array([[-0.14690024, -0.30382735],
       [ 0.10574519,  0.368542  ],
       [-0.9914448 , -2.4151666 ]], dtype=float32)

In [121]:
## similarity of every query (rows) with every key (columns):
## the queries multiplied by the transposed keys
sims = q @ k.T
sims.detach().numpy()

array([[-0.09895963,  0.06482816, -0.6522973 ],
       [-0.40223277,  0.40775946, -3.0023735 ],
       [ 0.4842446 , -0.6682633 ,  4.0461307 ]], dtype=float32)

In [122]:
## scale the similarities by the square root of the key width, read from
## `k` itself rather than hard-coded (it is 2 for these encodings)
scaled_sims = sims / (k.size(1) ** 0.5)

In [123]:
## show the scaled similarities
scaled_sims.detach().numpy()

array([[-0.06997503,  0.04584043, -0.46124387],
       [-0.28442153,  0.28832948, -2.1229987 ],
       [ 0.34241265, -0.47253352,  2.8610466 ]], dtype=float32)

In [124]:
## overwrite every masked position (True in `mask`) with a very large
## negative number so that softmax gives it a weight of 0
masked_scaled_sims = scaled_sims.masked_fill(mask, -1e9)
masked_scaled_sims.detach().numpy()

array([[-6.9975026e-02, -1.0000000e+09, -1.0000000e+09],
       [-2.8442153e-01,  2.8832948e-01, -1.0000000e+09],
       [ 3.4241265e-01, -4.7253352e-01,  2.8610466e+00]], dtype=float32)

In [125]:
## softmax along each row (dim=1) turns the masked similarities into
## attention weights that sum to 1 for every query
attention_percents = masked_scaled_sims.softmax(dim=1)
attention_percents.detach().numpy()

array([[1.        , 0.        , 0.        ],
       [0.36060232, 0.63939774, 0.        ],
       [0.0721798 , 0.03195134, 0.8958689 ]], dtype=float32)

In [126]:
## weight the values by the attention percentages; the result matches the
## output of masked_self_attention(encodings_matrix, mask) shown earlier
v = masked_self_attention.W_v(encodings_matrix)
(attention_percents @ v).detach().numpy()

array([[ 0.60376686,  0.74339145],
       [-0.0061961 ,  0.6071508 ],
       [ 3.498918  ,  2.242719  ]], dtype=float32)