# How to do Masking

Two types of masks are used in a Transformer:
- Padding Mask: Ensures the model does not pay attention to the padding tokens added to sequences within a batch to make them the same length.
- Causal Mask (or Look-Ahead Mask): Used specifically in decoder self-attention to prevent a position from attending to subsequent (future) positions. This maintains the autoregressive property – predictions for position `i` can only depend on outputs at positions `<i`.

In [1]:
import torch
import torch.nn as nn

In [2]:
# Fix the RNG seed so the random example sequences are reproducible
torch.manual_seed(100)
# Three 1-D sequences of different lengths (3, 2 and 5) that will be padded to a common length
a = torch.randn(3)
b = torch.randn(2)
c = torch.randn(5)
a, b, c

(tensor([ 0.3607, -0.2859, -0.3938]),
 tensor([ 0.2429, -1.3833]),
 tensor([-2.3134, -0.3172, -0.8660,  1.7482, -0.2759]))

First, I need to determine the maximum seq_len among these 3 sequences (5 here) and then pad every shorter sequence up to that max seq_len using a padding value of 0.

In [3]:
def pad_sequence(seq, max_len, pad_value=0.0):
    """Right-pad (or truncate) a sequence to exactly ``max_len`` elements.

    Intended for 1-D sequences: the padding that gets appended is 1-D.

    Note: this helper shares its name with ``torch.nn.utils.rnn.pad_sequence``;
    importing that function afterwards rebinds the name.

    Args:
        seq: A tensor, or anything ``torch.tensor`` accepts (e.g. a list).
        max_len: Target length of the returned tensor.
        pad_value: Value written into the appended positions. Defaults to 0.0.

    Returns:
        A tensor of length ``max_len``: ``seq`` followed by ``pad_value``
        entries when it is shorter than ``max_len``, otherwise its first
        ``max_len`` elements.
    """
    seq_tensor = seq if isinstance(seq, torch.Tensor) else torch.tensor(seq)
    missing = max_len - len(seq_tensor)

    if missing <= 0:
        # Already long enough: truncate (a no-op when the length matches).
        return seq_tensor[:max_len]

    # new_full inherits dtype and device from the input, so integer token IDs
    # stay integers and tensors on an accelerator can be concatenated.
    padding = seq_tensor.new_full((missing,), pad_value)
    return torch.cat([seq_tensor, padding])

In [4]:
# The longest sequence sets the common target length
max_len = max(map(len, (a, b, c)))

# Bring every sequence up to that length
padded_a, padded_b, padded_c = (pad_sequence(seq, max_len) for seq in (a, b, c))

# Report the lengths before and after padding, then the padded tensors themselves
print('Original lengths:', len(a), len(b), len(c))
print('After padding:', len(padded_a), len(padded_b), len(padded_c))
print('\nPadded sequences:')
print('a:', padded_a)
print('b:', padded_b)
print('c:', padded_c)

Original lengths: 3 2 5
After padding: 5 5 5

Padded sequences:
a: tensor([ 0.3607, -0.2859, -0.3938,  0.0000,  0.0000])
b: tensor([ 0.2429, -1.3833,  0.0000,  0.0000,  0.0000])
c: tensor([-2.3134, -0.3172, -0.8660,  1.7482, -0.2759])


In [5]:
def create_padding_mask(sequence, pad_value=0.0):
    """Build a float padding mask for a single padded sequence.

    Positions equal to ``pad_value`` are treated as padding. A genuine element
    that happens to equal ``pad_value`` cannot be told apart from padding and
    is masked out as well.

    Args:
        sequence: Padded tensor; the mask is computed element-wise.
        pad_value: Value that marks a padded position. Defaults to 0.0.

    Returns:
        A float tensor holding 1.0 for real values and 0.0 for padding, with
        two leading singleton dimensions added (shape ``[1, 1, seq_len]`` for
        a 1-D input).
    """
    is_real = sequence != pad_value
    return is_real.float().unsqueeze(0).unsqueeze(0)

# Build one mask per padded sequence
mask_a, mask_b, mask_c = (
    create_padding_mask(padded) for padded in (padded_a, padded_b, padded_c)
)

# Show the masks: 1 marks a real value, 0 marks padding
print('Padding masks:')
print('a:', mask_a)
print('b:', mask_b)
print('c:', mask_c)

Padding masks:
a: tensor([[[1., 1., 1., 0., 0.]]])
b: tensor([[[1., 1., 0., 0., 0.]]])
c: tensor([[[1., 1., 1., 1., 1.]]])


In [6]:
# Collect the variable-length sequences into a list so they can be padded as one batch
list_of_tensors = [a, b, c]

In [7]:
# Display the (still unpadded) list of sequences
list_of_tensors

[tensor([ 0.3607, -0.2859, -0.3938]),
 tensor([ 0.2429, -1.3833]),
 tensor([-2.3134, -0.3172, -0.8660,  1.7482, -0.2759])]

In [8]:
from torch.nn.utils.rnn import pad_sequence

# Pad every tensor in the list to the length of the longest one with 0.0;
# batch_first=True stacks the result as [batch_size, max_seq_len].
padded_batch = pad_sequence(sequences=list_of_tensors, batch_first=True, padding_value=0.0)
padded_batch

tensor([[ 0.3607, -0.2859, -0.3938,  0.0000,  0.0000],
        [ 0.2429, -1.3833,  0.0000,  0.0000,  0.0000],
        [-2.3134, -0.3172, -0.8660,  1.7482, -0.2759]])

In [9]:
# True where the batch holds a real value, False where it holds the pad value 0
padding_mask_bool = padded_batch.ne(0)
# Same mask as floats: 1.0 = keep, 0.0 = padding
padding_mask = padding_mask_bool.to(torch.float32)
padding_mask

tensor([[1., 1., 1., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1.]])

In [10]:
# Insert two singleton axes so the mask has shape [batch_size, 1, 1, key_seq_len].
# It then broadcasts against attn_scores of shape
# [batch, num_heads, query_seq_len, key_seq_len], masking along the KEY axis.
padding_mask = padding_mask[:, None, None, :]
padding_mask.shape

torch.Size([3, 1, 1, 5])

## As a proper function in a Transformer

In [11]:
def create_padding_mask(input_ids: torch.Tensor, padding_idx: int = 0) -> torch.Tensor:
    """
    Build the padding mask that multi-head attention uses to ignore pad tokens.

    Every position whose token ID equals `padding_idx` is flagged as padding.
    The result carries two singleton axes so it broadcasts against attention
    scores inside MHA.

    Convention:
        - 1.0 : real token, may be attended to.
        - 0.0 : padding token, to be masked out.

    Args:
        input_ids (torch.Tensor): Token IDs of shape
            [batch_size, sequence_length].
        padding_idx (int, optional): Vocabulary index of the padding token.
            Defaults to 0.

    Returns:
        torch.Tensor: Float mask of shape
            [batch_size, 1, 1, sequence_length], ready to be passed to
            MultiHeadAttention.
    """
    # Compare against the pad ID and cast in one go: 1.0 for real tokens,
    # 0.0 for padding. This is the form `masked_fill(mask == 0, -inf)` expects.
    # Shape: [batch_size, sequence_length], on the same device as input_ids.
    keep = input_ids.ne(padding_idx).to(torch.float32)

    # Add singleton axes at dim 1 (`num_heads`) and dim 2
    # (`query_sequence_length`); the mask itself runs along the key axis (last dim).
    # Shape: [batch_size, 1, 1, sequence_length]
    return keep[:, None, None]

# --- Example Usage ---

# A tokenized, already-padded batch of token IDs
# (batch_size=3, seq_len=7, padding_idx=0)
input_ids_batch = torch.tensor(
    [
        [101, 567, 890, 102, 0, 0, 0],        # 4 real tokens, 3 pads
        [101, 432, 102, 0, 0, 0, 0],          # 3 real tokens, 4 pads
        [101, 666, 777, 888, 999, 555, 102],  # 7 real tokens, no padding
    ]
)

# Derive the attention padding mask from the token IDs
padding_mask = create_padding_mask(input_ids_batch, padding_idx=0)

print("Input IDs Batch Shape:", input_ids_batch.size())
print("Padding Mask Shape:", padding_mask.size())

# Show each sequence's mask with the singleton axes dropped for readability
print("\nPadding Mask (Batch 0):")
print(torch.squeeze(padding_mask[0]))
print("\nPadding Mask (Batch 1):")
print(torch.squeeze(padding_mask[1]))
print("\nPadding Mask (Batch 2):")
print(torch.squeeze(padding_mask[2]))

# --- How to integrate with our MHA ---

# Assume 'embeddings' is the output of our nn.Embedding layer applied to input_ids_batch
# embeddings = embedding_layer(input_ids_batch) -> Shape [3, 7, 512]

# Instantiate our MHA module
# multihead_attn = MultiHeadAttention(num_heads=8, d_model=512, dropout=0.1)

# Pass the embeddings and the created mask to the forward method
# output, attn_weights = multihead_attn(
#     query_input=embeddings,
#     key_input=embeddings,   # For self-attention
#     value_input=embeddings, # For self-attention
#     mask=padding_mask       # Pass the generated mask here
# )

# print("\nOutput shape from MHA:", output.shape)
# print("Attention weights shape:", attn_weights.shape)

Input IDs Batch Shape: torch.Size([3, 7])
Padding Mask Shape: torch.Size([3, 1, 1, 7])

Padding Mask (Batch 0):
tensor([1., 1., 1., 1., 0., 0., 0.])

Padding Mask (Batch 1):
tensor([1., 1., 1., 0., 0., 0., 0.])

Padding Mask (Batch 2):
tensor([1., 1., 1., 1., 1., 1., 1.])
