To talk about masking, one must have training and inference in mind. The GPT models are autoregressive, which is pretty much just a fancy way of saying "give it a few words/tokens, let it predict the next one". The huge scaling benefit lies in the self-supervised (often loosely called unsupervised) nature of this setup: no labels are needed, because the text itself supplies the targets (to wit, "the input is also the label"). This allows autoregressive models to train on any amount of text data available; content from the whole internet is the limit (and has indeed been attempted: https://commoncrawl.org/).
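To make "the input is also the label" concrete, here is a minimal sketch of how one token sequence turns into next-token training examples. The token ids are made up, and `make_next_token_pairs` is a hypothetical helper, not part of any library.

```python
def make_next_token_pairs(tokens):
    """Return (context, target) pairs: at position i the model sees
    tokens[0..i] and must predict tokens[i + 1]."""
    return [(tokens[: i + 1], tokens[i + 1]) for i in range(len(tokens) - 1)]


tokens = [17, 4, 9, 12, 3]  # hypothetical token ids for a 5-token sentence

# Shifted view used when training on the whole sequence in one pass
inputs, targets = tokens[:-1], tokens[1:]
print(inputs)   # [17, 4, 9, 12]
print(targets)  # [4, 9, 12, 3]

for context, target in make_next_token_pairs(tokens):
    print(context, "->", target)
```

A sequence of n tokens yields n - 1 examples with no labelling effort at all.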

Masking is just a technicality to ensure that the autoregressive property holds by not allowing the model to look ahead. While it is not a deep concept, its implementation can be a source of confusion. So let's get the questions-party started.

There are other masks for different needs (e.g. BERT's masked-token objective, or padding masks for batches of unequal length). We focus on the autoregressive (causal) one here.

>   Q1. Why would the model look ahead? 

This is due to the training process and the very heart of attention: processing many tokens at once. Picture (or refer to a diagram of) the attention mechanism. It receives an input sequence of n vectors, each called a token and representing a word. It outputs n vectors, which are eventually transformed into n next-token predictions, and the "ground truth" for each prediction is that word's next neighbour in the input sequence, i.e. the solution is right there in the input. If we let attention see the whole sequence, we reveal the ground truth and the model can simply copy it instead of learning to predict it. So we must hide the solution so that the model cannot cheat, especially since during inference the next tokens do not exist yet and cannot be supplied.
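A quick way to see the leak: run attention without any mask, change only the last input token, and check whether the first output moves. This is a minimal sketch assuming random embeddings and freshly initialised projections (names mirror the notebook cells; like them, it omits the usual $1/\sqrt{d}$ scaling).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
emb_dim, seq_len = 4, 5

# Key/query/value projections, created once and reused for both runs
M_K, M_Q, M_V = [nn.Linear(emb_dim, emb_dim, bias=False) for _ in range(3)]


def unmasked_attention(x):
    K, Q, V = M_K(x), M_Q(x), M_V(x)
    W = F.softmax(Q @ K.transpose(-1, -2), dim=-1)  # every row sees every token
    return W @ V


x = torch.randn(seq_len, emb_dim)
x_future_changed = x.clone()
x_future_changed[-1] += 1.0  # alter only the LAST token

with torch.no_grad():
    y = unmasked_attention(x)
    y_changed = unmasked_attention(x_future_changed)

# Output 0 should predict token 1 from token 0 alone, yet it reacts
# to a change in token 4: information has leaked from the future.
print((y[0] - y_changed[0]).abs().max())
```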

Another way to look at this: the very feature that lets an attention layer take in multiple tokens at once (a key advantage over RNNs) also throws away the causality that an RNN preserves naturally by reading one token at a time. We need something to restore it, and the trick we use is masking.
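For contrast, a hand-rolled recurrent loop is causal by construction: at step t it simply has not read tokens after t yet. This is a sketch with random, untrained weights (a plain Elman-style recurrence, not `torch.nn.RNN`), perturbing the last token exactly as one might for attention.

```python
import torch

torch.manual_seed(0)
emb_dim, hidden_dim, seq_len = 4, 4, 5
W_x = torch.randn(hidden_dim, emb_dim)     # input-to-hidden weights
W_h = torch.randn(hidden_dim, hidden_dim)  # hidden-to-hidden weights


def rnn_outputs(x):
    """Elman-style recurrence: h_t = tanh(W_x x_t + W_h h_{t-1})."""
    h = torch.zeros(hidden_dim)
    outs = []
    for t in range(x.shape[0]):  # strictly left to right, one token per step
        h = torch.tanh(W_x @ x[t] + W_h @ h)
        outs.append(h)
    return torch.stack(outs)


x = torch.randn(seq_len, emb_dim)
x_future_changed = x.clone()
x_future_changed[-1] += 1.0  # alter only the LAST token

y, y_changed = rnn_outputs(x), rnn_outputs(x_future_changed)
print(torch.equal(y[:-1], y_changed[:-1]))  # True: earlier outputs are untouched
```

The price is the sequential loop: step t cannot start before step t - 1 finishes, which is exactly what attention avoids.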

>   Q2. So what do we do?

We add an extra step inside attention so that the $i$-th output token can only receive information from the $0$-th through $i$-th (inclusive) input tokens.

>   Q3. What extra step? And why is it positioned right before softmax? And why is the mask a triangular matrix filled with $-\infty$?

Actually, the core logic is almost trivial if we recast this as a translation problem from vector notation to matrix notation.

Masking: an artifact of using matrix notation to express vector sums.

Summing only the vectors of past and current tokens preserves causality with no cleverness at all. It is only when we translate this into matrix notation that it begins to look like an advanced trick. Viewed this way, the core logic is as simple as saying "drop any info from future tokens", or "if (model about to cheat): don't".
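The "drop any info from future tokens" rule is easy to state with explicit vector sums; the matrix version only has to reproduce it. A sketch with random $Q$, $K$, $V$ (no $1/\sqrt{d}$ scaling, as in the notebook code; the function names are illustrative) that checks the two views agree:

```python
import torch
import torch.nn.functional as F


def causal_attention_loop(Q, K, V):
    """Vector-sum view: output i is a softmax-weighted sum of v_0..v_i only."""
    outputs = []
    for i in range(Q.shape[0]):
        scores = Q[i] @ K[: i + 1].T          # scores against keys 0..i only
        weights = F.softmax(scores, dim=-1)
        outputs.append(weights @ V[: i + 1])  # future values never enter the sum
    return torch.stack(outputs)


def causal_attention_matrix(Q, K, V):
    """Matrix view: the same result via a -inf mask on the full score matrix."""
    n = Q.shape[0]
    W_raw = Q @ K.T
    future = torch.triu(torch.ones(n, n), diagonal=1).bool()  # True above diagonal
    W_raw = W_raw.masked_fill(future, float("-inf"))
    return F.softmax(W_raw, dim=-1) @ V


torch.manual_seed(0)
Q, K, V = torch.randn(3, 5, 4).unbind(0)  # 5 tokens, dimension 4
loop_out = causal_attention_loop(Q, K, V)
matrix_out = causal_attention_matrix(Q, K, V)
print(torch.allclose(loop_out, matrix_out, atol=1e-6))
```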

Let's use a sequence of length 2 to illustrate the process. Attention is really just a sequence-to-sequence mapping, which looks like:

[x1, x2] -> attention layer -> [y1, y2]

Let's say y1 is the first to come out: it should only see x1, while y2 should be able to see both x1 and x2, as discussed above.

So we need to mask out some information. Let's look closer.

If we write attention as a sum:

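$$\begin{aligned}
y_1 &= w_1 v_1 \\
&\propto \exp(k_1 \cdot q_1)\, v_1 \\
&= \exp(k_1 \cdot q_1)\, v_1 + 0 \cdot v_2 \\
&= \exp(k_1 \cdot q_1)\, v_1 + \exp(-\infty)\, v_2
\end{aligned}$$

$$\begin{aligned}
y_2 &= w_3 v_1 + w_4 v_2 \\
&\propto \exp(k_1 \cdot q_2)\, v_1 + \exp(k_2 \cdot q_2)\, v_2
\end{aligned}$$

where $\propto$ means "proportional to": we drop the softmax normalisation constant, since it is not needed to illustrate the point. $w_1$, $w_3$ and $w_4$ are the post-softmax attention weights that make up the matrix $W$ below; the missing $w_2$ (the weight of $v_2$ in $y_1$) is the one forced to zero.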
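The step that makes the masked term vanish is simply $\exp(-\infty) = 0$, and floating-point arithmetic (hence PyTorch) follows the same rule. A tiny check for $y_1$ in the length-2 example, with a placeholder score of 0.7 and made-up value vectors:

```python
import torch
import torch.nn.functional as F

neg_inf = float("-inf")
print(torch.exp(torch.tensor(neg_inf)))  # tensor(0.)

# Raw scores for y_1 in the length-2 example: [k_1 . q_1, -inf]
scores_y1 = torch.tensor([0.7, neg_inf])  # 0.7 is an arbitrary placeholder
weights_y1 = F.softmax(scores_y1, dim=-1)
print(weights_y1)                          # tensor([1., 0.])

V = torch.tensor([[1.0, 2.0],   # v_1 (placeholder values)
                  [3.0, 4.0]])  # v_2
y_1 = weights_y1 @ V
print(y_1)                                 # tensor([1., 2.]), i.e. exactly v_1
```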

At this point, we ask: how do we turn this into a matrix operation?

In matrix form, it looks like:

$$ Y = W V $$

where,

$$ W = \begin{bmatrix}
w_1 & 0\\
w_3 & w_4
\end{bmatrix}
$$
 
$$ Y = \begin{bmatrix} y_1 \\ y_2  \end{bmatrix}$$

$$ V = \begin{bmatrix} v_1 \\ v_2  \end{bmatrix}$$
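To see that the matrix product really reproduces the two sums above, here is the length-2 case with placeholder numbers: $w_1 = 1$ (the only unmasked weight in row 1), $w_3 = 0.25$, $w_4 = 0.75$, and arbitrary 2-dimensional value vectors. The zero in the upper-right corner of $W$ is what keeps $v_2$ out of $y_1$.

```python
import torch

# Placeholder post-softmax weights: [[w_1, 0], [w_3, w_4]]
W = torch.tensor([[1.00, 0.00],
                  [0.25, 0.75]])
# Rows are the value vectors v_1 and v_2 (made-up values)
V = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0]])

Y = W @ V
print(Y)  # row 0 = 1 * v_1, row 1 = 0.25 * v_1 + 0.75 * v_2
```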

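Now, we know $W$ is obtained by applying softmax row-wise to a raw score matrix $W_{raw}$. Softmax only outputs an exact $0$ for a score of $-\infty$ (because $\exp(-\infty) = 0$), so before softmax the matrix must look like,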

$$ W_{raw} = \begin{bmatrix}
w_{raw1} & -\infty\\
w_{raw3} & w_{raw4}
\end{bmatrix}
$$
 
This is why we apply a "mask" that sets every element above the diagonal of $W_{raw}$ to $-\infty$: in matrix notation, this is exactly what drops every vector derived from future tokens.
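Q3 also asked why the mask goes right *before* the softmax, and why it is $-\infty$ rather than $0$. A sketch with random placeholder scores compares three options: $-\infty$ before softmax, zeroing weights after softmax, and $0$ before softmax.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n = 4
W_raw = torch.randn(n, n)                                  # placeholder raw scores
future = torch.triu(torch.ones(n, n), diagonal=1).bool()  # True above the diagonal

# (a) -inf BEFORE softmax: future weights are 0 and rows still sum to 1
W_pre = F.softmax(W_raw.masked_fill(future, float("-inf")), dim=-1)

# (b) zero AFTER softmax: future weights removed, but rows no longer sum to 1
W_post = F.softmax(W_raw, dim=-1).masked_fill(future, 0.0)

# (c) 0 (not -inf) BEFORE softmax: future tokens still get weight, since exp(0) = 1
W_zero = F.softmax(W_raw.masked_fill(future, 0.0), dim=-1)

print(W_pre.sum(dim=-1))        # all ones
print(W_post.sum(dim=-1))       # below 1 except the last row
print((W_zero[future] > 0).all())  # tensor(True): information still leaks

# Renormalising (b) row by row recovers (a)
W_post_renorm = W_post / W_post.sum(dim=-1, keepdim=True)
print(torch.allclose(W_post_renorm, W_pre, atol=1e-6))
```

Putting $-\infty$ into the scores gives zero weight to the future and correct normalisation in one step, which is why it sits right before the softmax.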

Now we see clearly where the mask comes in -- right before the softmax step, 
everything else is business as usual.
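For comparison, PyTorch 2.x ships a fused `torch.nn.functional.scaled_dot_product_attention` whose `is_causal=True` flag applies the same lower-triangular mask internally. Unlike the notebook code, it also divides the scores by $\sqrt{d}$ (standard scaled dot-product attention), so this sketch adds that scaling to the manual version before comparing. It assumes PyTorch 2.0 or later and uses random placeholder tensors.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, emb_dim = 5, 4
# Shape (batch=1, heads=1, seq_len, emb_dim)
Q, K, V = torch.randn(3, 1, 1, seq_len, emb_dim).unbind(0)

# Manual version: scale, mask right before softmax, then weight the values
W_raw = Q @ K.transpose(-1, -2) / math.sqrt(emb_dim)
future = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
Y_manual = F.softmax(W_raw.masked_fill(future, float("-inf")), dim=-1) @ V

# Built-in version; its default scale is 1 / sqrt(emb_dim)
Y_builtin = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

print(torch.allclose(Y_manual, Y_builtin, atol=1e-5))
```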



In [6]:
#Let's demonstrate the masking in action
#first we start as usual until the softmax step

import torch
import torch.nn as nn

emb_dim = 4
seq_len = 5

#input tokens:
x = torch.tensor([0,1,2,3,4]) 
embedding = nn.Embedding(20, emb_dim) #vocab size 20, emb dim 4
x = embedding(x)

#attention matrices
M_K, M_Q, M_V = [nn.Linear(emb_dim, emb_dim, bias=False) for _ in range(3)]
K, Q, V = [M(x) for M in [M_K, M_Q, M_V ]]
W_raw = Q@(K.transpose(-1,-2))  # raw scores; standard attention also divides by emb_dim**0.5, omitted here

# stop before the next steps:
# W = F.softmax(W_raw, dim=-1)  # row-wise softmax
# Y = W@V

In [35]:
W_raw

tensor([[ 0.7918, -0.5124,  0.7467, -0.1884, -0.7821],
        [-0.0497, -0.0079, -0.2227,  0.0234,  0.1894],
        [-0.9871,  0.3926,  0.0859,  0.3219,  0.9626],
        [-0.3943,  0.1961, -0.0807,  0.1332,  0.4240],
        [-0.9710,  0.5766, -0.3295,  0.3093,  0.9815]], grad_fn=<MmBackward>)

In [46]:
seq_len = 5
ones = torch.ones((seq_len, seq_len), dtype=torch.uint8)
mask = torch.triu(ones, diagonal=1)
mask

tensor([[0, 1, 1, 1, 1],
        [0, 0, 1, 1, 1],
        [0, 0, 0, 1, 1],
        [0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0]], dtype=torch.uint8)

In [38]:
# In PyTorch, indexing with a boolean tensor selects the entries where it is True,
# i.e. if mask[i][j] == 1: W_raw[i][j] = float('-inf')
# (convert to bool first: indexing with a uint8 mask is deprecated)
W_raw[mask.bool()] = float('-inf')



In [39]:
W_raw

tensor([[ 0.7918,    -inf,    -inf,    -inf,    -inf],
        [-0.0497, -0.0079,    -inf,    -inf,    -inf],
        [-0.9871,  0.3926,  0.0859,    -inf,    -inf],
        [-0.3943,  0.1961, -0.0807,  0.1332,    -inf],
        [-0.9710,  0.5766, -0.3295,  0.3093,  0.9815]],
       grad_fn=<IndexPutBackward>)
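The cell above overwrites `W_raw` in place, so re-running it keeps mutating the same tensor. A sketch of an out-of-place alternative using `Tensor.masked_fill` with a boolean mask (random placeholder scores stand in for `Q@K^T`):

```python
import torch

torch.manual_seed(0)
seq_len = 5
W_raw = torch.randn(seq_len, seq_len)  # placeholder scores standing in for Q @ K^T

# Boolean mask: True marks a "future" position (column j > row i)
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

W_masked = W_raw.masked_fill(mask, float("-inf"))  # returns a new tensor

print(torch.isinf(W_masked).sum().item())  # 10 = seq_len * (seq_len - 1) / 2
print(torch.isinf(W_raw).any().item())     # False: the original is unchanged
```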

In [42]:
import torch.nn.functional as F

# proceed with the next steps to finish attention as usual
W = F.softmax(W_raw, dim=-1)
W

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4896, 0.5104, 0.0000, 0.0000, 0.0000],
        [0.1266, 0.5032, 0.3702, 0.0000, 0.0000],
        [0.1704, 0.3076, 0.2332, 0.2888, 0.0000],
        [0.0548, 0.2576, 0.1041, 0.1972, 0.3862]], grad_fn=<SoftmaxBackward>)

In [43]:
V

tensor([[-0.7833,  0.4562,  0.3547,  0.4063],
        [ 0.2458, -0.4403,  0.2678, -0.7130],
        [-0.0089, -0.5916,  0.7337, -0.9883],
        [ 0.1896, -0.2253,  0.1491, -0.5264],
        [ 0.6331, -0.2735, -0.1155, -0.8580]], grad_fn=<MmBackward>)

In [45]:
Y = W@V
Y

tensor([[-0.7833,  0.4562,  0.3547,  0.4063],
        [-0.2580, -0.0014,  0.3103, -0.1651],
        [ 0.0212, -0.3828,  0.4513, -0.6732],
        [-0.0052, -0.2607,  0.3570, -0.5326],
        [ 0.3014, -0.3001,  0.1496, -0.6995]], grad_fn=<MmBackward>)

In [47]:
# To summarize, we can write the masked version of attention as:

def self_attention_with_mask(x, emb_dim, seq_len):
    M_K, M_Q, M_V = [nn.Linear(emb_dim, emb_dim, bias=False) for _ in range(3)]
    K, Q, V = [M(x) for M in [M_K, M_Q, M_V ]]
    W_raw = Q@(K.transpose(-1,-2))
    # == masking begins ==
    ones = torch.ones((seq_len, seq_len))
    mask = torch.triu(ones, diagonal=1).bool()  # True above the diagonal = future tokens
    W_raw[mask] = float('-inf')
    # == masking ends ==
    W = F.softmax(W_raw, dim=-1)
    Y = W@V
    return Y

In [48]:
self_attention_with_mask(x, emb_dim, seq_len) 
# the result differs on each call because the nn.Linear layers are re-initialised inside the function



tensor([[-0.7491,  0.7891,  0.0687, -0.4984],
        [-0.2918,  0.6358, -0.1248, -0.1398],
        [-0.4738,  0.8256, -0.1297, -0.2850],
        [ 0.0148,  0.7448, -0.4092,  0.0290],
        [ 0.1783,  0.6184, -0.4607,  0.0386]], grad_fn=<MmBackward>)
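As the comment above notes, `self_attention_with_mask` re-creates its `nn.Linear` layers on every call, so the weights (and the output) change each time. One common way to avoid that, sketched here as an illustrative `nn.Module` (the class name `CausalSelfAttention` and the `max_seq_len` argument are assumptions, not from the original), is to create the layers once and precompute the mask as a buffer:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalSelfAttention(nn.Module):
    """Single-head masked self-attention (illustrative sketch).
    Layers are created once in __init__, so repeated calls agree."""

    def __init__(self, emb_dim, max_seq_len):
        super().__init__()
        self.M_K = nn.Linear(emb_dim, emb_dim, bias=False)
        self.M_Q = nn.Linear(emb_dim, emb_dim, bias=False)
        self.M_V = nn.Linear(emb_dim, emb_dim, bias=False)
        # Precompute the "future" mask once; a buffer moves with .to(device)
        future = torch.triu(torch.ones(max_seq_len, max_seq_len), diagonal=1).bool()
        self.register_buffer("future", future)

    def forward(self, x):
        seq_len = x.shape[-2]
        K, Q, V = self.M_K(x), self.M_Q(x), self.M_V(x)
        W_raw = Q @ K.transpose(-1, -2)  # no 1/sqrt(d) scaling, as in the notebook
        W_raw = W_raw.masked_fill(self.future[:seq_len, :seq_len], float("-inf"))
        W = F.softmax(W_raw, dim=-1)
        return W @ V


torch.manual_seed(0)
attn = CausalSelfAttention(emb_dim=4, max_seq_len=16)
x = torch.randn(5, 4)
print(torch.equal(attn(x), attn(x)))  # True: same weights on every call
```

Slicing the precomputed mask to the current length lets the same module handle any sequence up to `max_seq_len`.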

Next, we will move on to constructing a transformer and the training process.