In [5]:
import torch

# Build a decoder input of the form [SOS] + token IDs + [PAD] * n, padded to a
# fixed seq_len, then combine a padding mask with a causal mask to obtain the
# final decoder mask.

# Define the token IDs for the decoder
decoder_token_ids = torch.tensor([68, 72, 96])
sos_token_id = torch.tensor([2])  # SOS token ID
pad_token_id = torch.tensor([0])  # padding token ID
seq_len = 8  # total length the decoder input is padded to

# Derive the padding count from seq_len rather than hard-coding it separately.
# If the two disagreed, decoder_input would not have seq_len elements, and the
# padding mask could not broadcast against the (1, seq_len, seq_len) causal mask.
decoder_padding_num = seq_len - len(decoder_token_ids) - 1  # -1 accounts for SOS
if decoder_padding_num < 0:
    raise ValueError(
        f"{len(decoder_token_ids)} tokens plus SOS do not fit in seq_len={seq_len}"
    )

# Create the decoder input by concatenating the SOS token, the decoder token IDs,
# and the padding tokens; its length is exactly seq_len by construction.
decoder_input = torch.cat(
    [
        sos_token_id,
        decoder_token_ids,
        pad_token_id.repeat(decoder_padding_num),
    ]
)

print("decoder_input:", decoder_input)

# Padding mask: 1 where decoder_input holds a real (non-padding) token, 0 at
# padding positions. Shape (1, seq_len).
padding_mask = (decoder_input != pad_token_id).unsqueeze(0).int()  # (1, seq_len)
print("padding_mask: \n", padding_mask)
print("padding_mask shape:", padding_mask.shape)

# Causal mask: True on and below the diagonal, so row i is True only for
# columns j <= i (a position cannot see later positions). Shape (1, seq_len, seq_len).
causal_mask = torch.triu(torch.ones((1, seq_len, seq_len)), diagonal=1).type(torch.int) == 0
print('causal_mask: \n', causal_mask)
print('causal_mask shape:', causal_mask.shape)

# Combine both masks. The (1, seq_len) padding mask broadcasts across every row
# of the (1, seq_len, seq_len) causal mask, zeroing the padding columns.
# Rows belonging to padding positions still keep 1s for the real-token columns.
decoder_mask = padding_mask & causal_mask
print('decoder_mask: \n', decoder_mask)
print('decoder_mask shape:', decoder_mask.shape)

decoder_input: tensor([ 2, 68, 72, 96,  0,  0,  0,  0])
padding_mask: 
 tensor([[1, 1, 1, 1, 0, 0, 0, 0]], dtype=torch.int32)
padding_mask shape: torch.Size([1, 8])
causal_mask: 
 tensor([[[ True, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True,  True,  True,  True]]])
causal_mask shape: torch.Size([1, 8, 8])
decoder_mask: 
 tensor([[[1, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 0, 0, 0, 0],
         [1, 1, 1, 1, 0, 0, 0, 0],
         [1, 1, 1, 1, 0, 0, 0, 0],
         [1, 1,