In [42]:
import torch; torch.set_printoptions(linewidth=200)
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM

This tutorial covers how to modify the attention mask when pre-training transformer models on packed sequences. Packing concatenates several sequences so that the model's full context window is used. The modified mask then ensures that each token only attends to tokens of its own sequence.

In [43]:
# The config and the tokenizer both come from the "gpt2" checkpoint.
# The model itself is built from the config via from_config, so no pretrained weights are loaded.
config = AutoConfig.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_config(config)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

Let's assume we would like to process the following three sentences in a single forward pass to make full use of the available GPU resources:

In [44]:
# three short, independent example sentences that will be packed into one sequence
sentence1, sentence2, sentence3 = (
    "The cat sat on the mat",
    "The dog ate my homework",
    "My aunt is a teacher",
)

We can simply concatenate the tokenized sentences and separate them with a <bos> or <eos> token, so the model knows where a new sentence starts.

In [45]:
sentences = [sentence1, sentence2, sentence3]
# Tokenize without special tokens, then pack the sentences:
# flatten into one id list and append the EOS id after every sentence as a separator.
tokenized_sentences = [
    token_id
    for sentence_ids in tokenizer(sentences, return_attention_mask=False, add_special_tokens=False)["input_ids"]
    for token_id in (*sentence_ids, tokenizer.eos_token_id)
]
tokenizer.decode(tokenized_sentences)

'The cat sat on the mat<|endoftext|>The dog ate my homework<|endoftext|>My aunt is a teacher<|endoftext|>'

The standard attention mask for causal language modeling on the packed sequence would look like this:

In [46]:
tokenized_sentences = torch.tensor(tokenized_sentences)
# standard causal mask over the whole packed sequence: token i may attend to tokens 0..i
attn_mask = torch.tril(torch.ones((len(tokenized_sentences),) * 2, dtype=torch.bool))
attn_mask

tensor([[ True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False],

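With this mask, however, the model can still attend to tokens of the first sentence while processing the second one. This is not ideal, because the two examples are independent. To fix this, we additionally mask out every key that belongs to an earlier sentence. The mask then becomes block-diagonal: each token attends only to itself and the preceding tokens of its own sentence. Likewise, the position ids should restart at 0 for every sentence.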

With only one sample in the batch, this is relatively easy to do in PyTorch.

In [47]:
def get_attention_mask_for_packed_sequence(x, token_id, eos: bool = True):
    """Build a block-diagonal causal attention mask and position ids for ONE packed sequence.

    Args:
        x: 1-D tensor of T token ids that contains several packed sequences.
        token_id: id of the separator token (EOS when ``eos=True``, BOS otherwise).
        eos: if True, a new sequence starts right *after* each ``token_id``;
            if False, a new sequence starts *at* each ``token_id``.

    Returns:
        mask: (T, T) bool tensor; ``mask[i, j]`` is True when query ``i`` may
            attend to key ``j`` (``j <= i`` and both belong to the same sequence).
        pos_ids: (T,) int64 tensor of positions that restart at 0 for every
            packed sequence.
    """
    # store sequence length in variable for easier readability
    T = x.size(0)
    device = x.device
    # start index of every sequence: right after each EOS token (or at each BOS token)
    starts = (x == token_id).nonzero(as_tuple=True)[0] + int(eos)
    # add 0 and T so the first sequence and a trailing, unterminated sequence are
    # covered too; unique() removes duplicates and returns the boundaries sorted
    bounds = torch.cat([torch.tensor([0, T], device=device), starts]).unique()
    # from the boundaries, get the length of each sequence
    reps = bounds[1:] - bounds[:-1]
    # for every key column j: exclusive end index of the sequence j belongs to
    repeated_idx = torch.repeat_interleave(bounds[1:], reps).view(1, -1).expand(T, -1)
    # query (row) index i, repeated along dimension 1
    mask_indices = torch.arange(T, device=device).view(-1, 1).expand(-1, T)
    # causal mask that additionally hides keys whose sequence ended before query i
    mask = torch.ones(T, T, dtype=torch.bool, device=device).tril()
    mask.masked_fill_(mask_indices >= repeated_idx, False)
    # position ids restart at 0 at the start of every packed sequence
    pos_ids = torch.arange(T, device=device) - torch.repeat_interleave(bounds[:-1], reps)
    return mask, pos_ids


In [48]:
mask, pos_ids = get_attention_mask_for_packed_sequence(
    tokenized_sentences, token_id=tokenizer.eos_token_id
)
(mask, pos_ids)

(tensor([[ True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, 

When working with a batch of packed sequences, it is a little more challenging due to the additional dimension. Let's create a second item of packed sequences to get a batch.

In [49]:
sentence4 = "Rome wasn't built in a day"
sentence5 = "My hovercraft is full of eels"

sentences = [sentence4, sentence5]
# pack the second sample exactly like the first: token ids of every sentence followed by EOS
tokenized_sentences2 = torch.tensor([
    token_id
    for sentence_ids in tokenizer(sentences, return_attention_mask=False, add_special_tokens=False)["input_ids"]
    for token_id in (*sentence_ids, tokenizer.eos_token_id)
])

# stack both packed samples into one (B, T) batch, padding the shorter one with the EOS id
batch = torch.nn.utils.rnn.pad_sequence(
    [tokenized_sentences, tokenized_sentences2],
    batch_first=True,
    padding_value=tokenizer.eos_token_id,
)

Let's go over the solution step by step. First, let's assign the shape of the batch to two variables, B and T. This makes the following code more readable.

In [50]:
B, T = batch.size()  # batch size and (padded) packed-sequence length

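Now we will construct a batched version of the `repeated_idx` tensor from the example above. For this we need the indices of the EOS tokens in the flattened batch. These indices are shifted by one so that each points to the first token of the following sequence.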

In [51]:
# flattened positions directly after every EOS token, i.e. where the next sequence starts
eos_idx = torch.nonzero(batch.view(-1).eq(tokenizer.eos_token_id), as_tuple=True)[0].add(1)
eos_idx

tensor([ 7, 13, 19, 28, 37, 38])

To this index vector we add the start offset of every batch item (0, T, 2T, …) as well as the final end offset B·T. This is needed to be able to separate the batch items again later on. We then remove duplicates (in case such an offset is already present) and sort.

In [52]:
# add every row start offset (0, T, 2T, ...) plus the final end B*T, deduplicated and sorted ascending
eos_idx_expanded = torch.unique(torch.cat((eos_idx, torch.arange(0, B * T + 1, T))), sorted=True)
eos_idx_expanded

tensor([ 0,  7, 13, 19, 28, 37, 38])

Next, our index vector contains global indices within the flattened batch (e.g. the first index of the second batch item is T). We therefore normalize them to positions within their own row by taking them modulo the sequence length T. In the normalized indices we replace zeros with T, because a boundary at the start of a row marks the end of the previous row. This is needed in the following step.

In [53]:
# in-row position of every boundary; a row start (0) is treated as the end (T) of the previous row
normalized_idx = torch.remainder(eos_idx_expanded, T)
normalized_idx = normalized_idx.masked_fill(normalized_idx == 0, T)
normalized_idx

tensor([19,  7, 13, 19,  9, 18, 19])

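With the normalized indices, we can compute how often each end index has to be repeated. This count is the length of each sequence, i.e. the difference between consecutive indices. When two consecutive indices belong to different rows, the difference is not positive. In that case the sequence starts at the beginning of its row, so its length is simply the normalized index itself. This only works if the end of every row is present as T. Had we not replaced 0s with T in the previous step, the number of repetitions for the last sequence in each row would be wrong.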

In [54]:
# sequence lengths = gaps between consecutive boundaries; a non-positive gap means we crossed
# into the next row, where the length equals the boundary's own in-row index
reps = torch.diff(normalized_idx)
reps = torch.where(reps > 0, reps, normalized_idx[1:])
reps

tensor([7, 6, 6, 9, 9, 1])

Now we can create the batched repeated index tensor:

In [55]:
# repeat every sequence end once per token of that sequence, then broadcast it over all query rows
repeated_idx = (
    normalized_idx[1:]
    .repeat_interleave(reps)
    .reshape(B, 1, T)
    .expand(B, T, T)
)
repeated_idx[0]

tensor([[ 7,  7,  7,  7,  7,  7,  7, 13, 13, 13, 13, 13, 13, 19, 19, 19, 19, 19, 19],
        [ 7,  7,  7,  7,  7,  7,  7, 13, 13, 13, 13, 13, 13, 19, 19, 19, 19, 19, 19],
        [ 7,  7,  7,  7,  7,  7,  7, 13, 13, 13, 13, 13, 13, 19, 19, 19, 19, 19, 19],
        [ 7,  7,  7,  7,  7,  7,  7, 13, 13, 13, 13, 13, 13, 19, 19, 19, 19, 19, 19],
        [ 7,  7,  7,  7,  7,  7,  7, 13, 13, 13, 13, 13, 13, 19, 19, 19, 19, 19, 19],
        [ 7,  7,  7,  7,  7,  7,  7, 13, 13, 13, 13, 13, 13, 19, 19, 19, 19, 19, 19],
        [ 7,  7,  7,  7,  7,  7,  7, 13, 13, 13, 13, 13, 13, 19, 19, 19, 19, 19, 19],
        [ 7,  7,  7,  7,  7,  7,  7, 13, 13, 13, 13, 13, 13, 19, 19, 19, 19, 19, 19],
        [ 7,  7,  7,  7,  7,  7,  7, 13, 13, 13, 13, 13, 13, 19, 19, 19, 19, 19, 19],
        [ 7,  7,  7,  7,  7,  7,  7, 13, 13, 13, 13, 13, 13, 19, 19, 19, 19, 19, 19],
        [ 7,  7,  7,  7,  7,  7,  7, 13, 13, 13, 13, 13, 13, 19, 19, 19, 19, 19, 19],
        [ 7,  7,  7,  7,  7,  7,  7, 13, 13, 13, 13, 1

The rest is similar to the example with batch size 1. For every batch item, we construct a tensor with the row indices 0 to T-1 repeated along the last dimension, and we create a causal mask. We then mask out all tokens from preceding sequences.

In [56]:
# query (row) index i, broadcast over all key columns j and all batch items
mask_indices = torch.arange(T)[None, :, None].expand(B, T, T)
# start from a plain causal mask shared by all batch items ...
mask = torch.tril(torch.ones(T, T, dtype=torch.bool)).expand(B, T, T)
# ... and hide keys whose sequence has already ended at query position i
mask = mask.masked_fill(mask_indices >= repeated_idx, False)
mask[1]

tensor([[ True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False],

Here is the full function. I added the possibility to choose between splitting on EOS tokens (default) or on BOS tokens (`eos=False`).

In [57]:
def get_attention_mask_for_packed_sequence(x, token_id, eos: bool = True):
    """Build block-diagonal causal attention masks and position ids for packed sequences.

    Args:
        x: (B, T) tensor of token ids where every row holds several packed
            sequences. A 1-D (T,) tensor is treated as a batch of one, and the
            batch dimension is dropped again from the outputs.
        token_id: id of the separator token (EOS when ``eos=True``, BOS otherwise).
        eos: if True, a new sequence starts right *after* each ``token_id``;
            if False, a new sequence starts *at* each ``token_id``.

    Returns:
        mask: (B, T, T) bool tensor; ``mask[b, i, j]`` is True when query ``i``
            may attend to key ``j`` (``j <= i`` and both in the same sequence).
        pos_ids: (B, T) int64 tensor of positions that restart at 0 for every
            packed sequence.
    """
    unbatched = x.dim() == 1
    if unbatched:
        x = x.unsqueeze(0)
    B, T = x.shape
    # create every helper tensor on the input's device so CUDA inputs work
    device = x.device
    # global (flattened) index of the first token of every sequence;
    # reshape instead of view so non-contiguous inputs are accepted too
    eos_idx = (x.reshape(-1) == token_id).nonzero(as_tuple=True)[0] + int(eos)
    # add the start offset of every row (0, T, 2T, ...) and the final end B*T so
    # the rows can be separated again; unique(sorted=True) dedupes and sorts
    eos_idx_expanded = torch.unique(
        torch.cat([eos_idx, torch.arange(0, B * T + 1, T, device=device)]), sorted=True
    )
    # in-row position of each boundary; a row start (0) is the end (T) of the previous row
    normalized_idx = eos_idx_expanded % T
    normalized_idx = torch.where(normalized_idx == 0, T, normalized_idx)
    # length of every sequence; a non-positive difference means we crossed into a
    # new row, where the sequence starts at 0 and its length is the end index itself
    reps = normalized_idx[1:] - normalized_idx[:-1]
    reps = torch.where(reps < 1, normalized_idx[1:], reps)
    # for every key column j: exclusive end index of the sequence j belongs to
    repeated_idx = torch.repeat_interleave(normalized_idx[1:], reps).view(B, 1, T).expand(-1, T, -1)
    # query (row) index i broadcast over every key column
    mask_indices = torch.arange(T, device=device).view(1, -1, 1).expand(B, -1, T)
    mask = torch.ones(T, T, dtype=torch.bool, device=device).tril().expand(B, -1, -1)
    # block keys whose sequence ended at or before query position i
    mask = mask.masked_fill(mask_indices >= repeated_idx, False)
    # position ids: distance of every token to the start of its own sequence
    seq_starts = torch.repeat_interleave(eos_idx_expanded[:-1], reps)
    pos_ids = (torch.arange(B * T, device=device) - seq_starts).view(B, T)
    if unbatched:
        return mask[0], pos_ids[0]
    return mask, pos_ids

In [58]:
mask, pos_ids = get_attention_mask_for_packed_sequence(
    batch, token_id=tokenizer.eos_token_id
)
(mask, pos_ids)

(tensor([[[ True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
          [ True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
          [ True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
          [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
          [ True,  True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
          [ True,  True,  True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False],
          [ True,  True,  True,  True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, 