In [3]:
import torch
from rich import print
import einops
import math

In [17]:
# Toy batch: 3 identical rows, each holding the token ids 0..19 (a broadcast view of one arange).
x = torch.arange(20).expand(3, -1)

In [5]:
# Special ids used by mask_spans:
#   X_DENOISING_TOKEN -> written into column 0 of both inputs and targets as the task prefix.
#   MASK_TOKEN        -> value that replaces every masked position in the inputs.
# NOTE(review): -1 is not a valid index for an embedding lookup; presumably it is remapped
# downstream before being fed to a model — confirm.
X_DENOISING_TOKEN, MASK_TOKEN = 100, -1

In [31]:
@torch.no_grad()
def mask_spans(
    sequence: torch.Tensor,
    mask_width: int,
    num_masks: int,
    max_percentage_masked: float,
    task_token: "int | None" = None,
    mask_token: "int | None" = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Corrupt random fixed-width spans of a batch of token ids for span denoising.

    Span start positions are drawn independently and uniformly per row, so spans may
    overlap and the number of distinct masked positions can be below
    ``num_masks * mask_width``.

    Both outputs are shifted right by one position along the last dimension, and a task
    token is written into column 0. As a consequence, the last token of ``sequence`` is
    dropped from both ``inputs`` and ``targets``.

    Args:
        sequence: 2-D tensor of shape ``(batch_size, seq_length)``. It is not modified.
        mask_width: Length of every masked span; must be >= 1.
        num_masks: Requested number of spans per row (negative values are treated as 0).
        max_percentage_masked: Caps the span count so that
            ``num_masks * mask_width <= max_percentage_masked * seq_length``.
            If the cap rounds down to zero, or a span is wider than the sequence,
            nothing is masked.
        task_token: Id written into column 0 of both outputs. Defaults to the
            module-level ``X_DENOISING_TOKEN``.
        mask_token: Value written over masked positions. Defaults to the
            module-level ``MASK_TOKEN``.

    Returns:
        ``(inputs, targets)``, both shaped like ``sequence``. ``inputs`` is the shifted,
        masked sequence; ``targets`` is the shifted, unmasked sequence.

    Raises:
        ValueError: if ``sequence`` is not 2-D or ``mask_width < 1``.
    """
    if sequence.dim() != 2:
        raise ValueError(f"sequence must be 2-D (batch, length), got shape {tuple(sequence.shape)}")
    if mask_width < 1:
        raise ValueError(f"mask_width must be >= 1, got {mask_width}")
    if task_token is None:
        task_token = X_DENOISING_TOKEN
    if mask_token is None:
        mask_token = MASK_TOKEN

    batch_size, seq_length = sequence.shape
    device = sequence.device

    # Shift right along the sequence axis (not the flattened tensor) and prepend the task token.
    targets = sequence.roll(1, dims=-1)
    targets[:, 0] = task_token

    # Respect the masking budget. A span wider than the sequence cannot be placed at all,
    # and a non-positive count means "mask nothing" rather than an error.
    num_masks = min(num_masks, math.floor(max_percentage_masked * seq_length / mask_width))
    if mask_width > seq_length:
        num_masks = 0
    num_masks = max(num_masks, 0)

    # Build the mask on the same device as the input so scatter_/masked_fill work on GPU too.
    mask = torch.zeros((batch_size, seq_length), dtype=torch.bool, device=device)
    if num_masks > 0:
        # The largest start is seq_length - mask_width, so every span already lies in bounds.
        starts = torch.randint(0, seq_length - mask_width + 1, (batch_size, num_masks), device=device)
        offsets = torch.arange(mask_width, device=device)
        # (batch, num_masks, mask_width) -> (batch, num_masks * mask_width)
        positions = (starts.unsqueeze(-1) + offsets).flatten(1)
        mask.scatter_(1, positions, True)

    # Apply the mask, then shift and prepend the task token exactly like the targets.
    inputs = sequence.masked_fill(mask, mask_token).roll(1, dims=-1)
    inputs[:, 0] = task_token

    return inputs, targets

# Seed the RNG so the masked spans shown below are reproducible on Restart & Run All.
torch.manual_seed(0)
mask_spans(x, mask_width=4, num_masks=2, max_percentage_masked=0.8)


(tensor([[100,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
           13,  -1,  -1,  -1,  -1,  -1],
         [100,   0,  -1,  -1,  -1,  -1,   5,   6,   7,   8,   9,  -1,  -1,  -1,
           -1,  14,  15,  16,  17,  18],
         [100,   0,  -1,  -1,  -1,  -1,   5,   6,   7,   8,   9,  10,  11,  12,
           13,  -1,  -1,  -1,  -1,  18]]),
 tensor([[100,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
           13,  14,  15,  16,  17,  18],
         [100,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
           13,  14,  15,  16,  17,  18],
         [100,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
           13,  14,  15,  16,  17,  18]]))