In [28]:
import torch
from rich import print
import einops
import math
import torch.nn.functional as F

In [2]:
# Toy batch for eyeballing the helpers below: 3 identical rows, each holding 0..19.
x = einops.repeat(torch.arange(0, 20), pattern="length -> batch length", batch=3)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Special token ids shared by the data helpers.
# NOTE(review): DENOISING_TOKEN must not collide with a real vocabulary id — confirm against the tokenizer.
DENOISING_TOKEN, MASK_TOKEN = 100, -1  # task token written into column 0 / fill value for masked positions

In [29]:
@torch.no_grad()
def mask_spans(
    sequence: torch.Tensor,
    mask_width: int,
    num_masks: int,
    max_percentage_masked: float,
    causal: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Replace random fixed-width spans of ``sequence`` with ``MASK_TOKEN`` and build (inputs, targets).

    Args:
        sequence: 2-D tensor indexed as (batch, position); the function unpacks exactly two dims.
        mask_width: length of every masked span; must be >= 1.
        num_masks: requested number of spans per row. Spans are placed independently and may overlap.
        max_percentage_masked: upper bound on the masked share, given as a *fraction* (e.g. 0.8).
            ``num_masks`` is capped at ``floor(max_percentage_masked * seq_length / mask_width)``.
        causal: if True, targets are an unshifted copy of ``sequence`` (inputs are shifted right by one,
            so each position predicts its own token from the previous one). If False, targets are
            ``sequence`` shifted right by one with ``DENOISING_TOKEN`` in column 0, i.e. aligned with the
            unmasked inputs.

    Returns:
        ``(inputs, targets)``. ``inputs`` is the masked sequence shifted right by one along the last dim,
        with ``DENOISING_TOKEN`` written into column 0 (the last token is dropped by the shift).
        If no span fits (cap of 0 spans, or ``mask_width`` longer than the sequence) nothing is masked.

    Raises:
        ValueError: if ``mask_width`` < 1.
    """
    if mask_width < 1:
        raise ValueError(f"mask_width must be >= 1, got {mask_width}")

    batch_size, seq_length = sequence.shape

    if causal:
        # Copy so later in-place edits to targets cannot touch the caller's tensor.
        targets = sequence.clone()
    else:
        # dims=-1 keeps the shift inside each row; without it roll() flattens across the batch.
        targets = sequence.roll(1, dims=-1)
        targets[:, 0] = DENOISING_TOKEN

    num_masks = min(num_masks, math.floor(max_percentage_masked * seq_length / mask_width))

    # Same dtype/device as sequence, so scatter_ and masked_fill line up without extra transfers.
    mask_tensor = torch.zeros_like(sequence)
    if num_masks > 0 and mask_width <= seq_length:
        # Starts are drawn so every span fits inside the row; no clamping of positions is needed.
        starts = torch.randint(
            0, seq_length - mask_width + 1, (batch_size, num_masks), device=sequence.device
        )
        width_range = torch.arange(mask_width, device=sequence.device)
        mask_positions = starts.unsqueeze(-1) + width_range  # (batch_size, num_masks, mask_width)
        # Explicit sizes (not -1) so the reshape is never ambiguous.
        mask_tensor.scatter_(1, mask_positions.reshape(batch_size, num_masks * mask_width), 1)

    inputs = sequence.masked_fill(mask_tensor.bool(), MASK_TOKEN)

    # Shift right by one and write the task token into the freed first column.
    inputs = inputs.roll(1, dims=-1)
    inputs[:, 0] = DENOISING_TOKEN

    return inputs, targets

# Causal span masking on the toy batch: 2 spans of width 4 (the 80% cap would allow up to 4).
inputs, targets = mask_spans(x, mask_width=4, num_masks=2, max_percentage_masked=0.8, causal=True)
print(f"{inputs=}\n{targets=}\n{x=}")

In [5]:
@torch.no_grad()
def get_s_denoised_data(
        sequence: torch.Tensor,
        mask_width: int | None,
        masking_rate: float = 0.5,
        causal: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    if causal:
        targets = torch.zeros_like(sequence).copy_(sequence)  # copy sequence to not have negative downstream effects
    else:
        targets = sequence.roll(1)
        targets[:, 0] = DENOISING_TOKEN

    mask_width = mask_width or math.floor(masking_rate * sequence.shape[-1])
    mask_width = min(mask_width, math.floor(masking_rate * sequence.shape[-1]))
    inputs = sequence.roll(1, dims=-1)
    inputs [:, -mask_width:] = MASK_TOKEN
    inputs[:, 0] = DENOISING_TOKEN

    return inputs, targets

# Disabled demo of suffix denoising on the toy batch; change the guard to True to run it.
if False:
    inputs, targets = get_s_denoised_data(x, mask_width=None, masking_rate=0.25, causal=True)
    print(f"{inputs=}\n{targets=}\n{x=}")

In [6]:
@torch.no_grad()
def get_causal_data(sequence: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Build (inputs, targets) for plain next-token prediction.

    Targets are a copy of ``sequence``. Inputs are ``sequence`` shifted right by one along the last
    dim with ``DENOISING_TOKEN`` in column 0, so ``inputs[:, t] == sequence[:, t - 1]`` for t >= 1.
    Expects a 2-D (batch, position) tensor, since column 0 is addressed as ``[:, 0]``.
    """
    # Copy, as the other data helpers in this notebook do, so in-place edits to targets
    # cannot alias the caller's tensor.
    targets = sequence.clone()

    # Roll instead of slicing to keep the length unchanged; the wrapped-around final token lands in
    # column 0 and is overwritten by the task token.
    inputs = sequence.roll(1, dims=-1)
    inputs[:, 0] = DENOISING_TOKEN

    return inputs, targets

In [7]:
def test_get_causal_data():
    """Print a toy sequence's shifted slices next to what ``get_causal_data`` returns, for visual comparison."""
    sequence = torch.arange(10).unsqueeze(0)  # shape (1, 10)
    for shifted_view in (sequence[:, :-1], sequence[:, 1:]):
        print(shifted_view)
    causal_inputs, causal_targets = get_causal_data(sequence)
    for produced in (causal_inputs, causal_targets):
        print(produced)


test_get_causal_data()

## Try masking: attention masks for masked spans

In [8]:
# Side length of the precomputed positional-bias / causal-mask tables; sequences longer than this won't fit them.
max_seq_len = 2 ** 10

In [9]:
with torch.no_grad():
    # Create the base arrays for the learnable linear positional bias. This helps save some memory consumption & processing time
    # Offsets are computed in exact int64 arithmetic and only then cast. bfloat16 carries just 8 significant
    # bits, so casting the range first would round e.g. -1023..-1018 to multiples of 4. That would corrupt
    # the small offsets (j - i) that every [:L, :L] slice relies on. After the cast, offsets with
    # |j - i| <= 256 are exact and larger ones are rounded.
    exact_range                   = torch.arange(-max_seq_len+1, 1)
    bias_range                    = exact_range.to(torch.bfloat16)  # kept in bf16 in case other cells use it
    position_bias_base            = (exact_range.unsqueeze(0) - exact_range.unsqueeze(1)).to(torch.bfloat16)  # [i, j] = j - i
    negative_infinity_matrix_base = torch.full_like(position_bias_base, -float("inf"))
    # causal_mask[i, j] is True for j <= i (lower triangle including the diagonal).
    causal_mask = torch.tril(torch.ones((max_seq_len, max_seq_len), dtype=torch.bool))

In [31]:
def make_mask(x: torch.Tensor) -> torch.Tensor:
    """Build an additive attention mask/bias for a batch that may contain ``MASK_TOKEN`` spans.

    A query at position i may attend to key j when j <= i (causal), or when both i and j hold
    ``MASK_TOKEN`` (masked positions see each other in both directions). Allowed pairs get the
    linear distance penalty ``-|j - i|``; every other pair gets -inf.

    Args:
        x: 2-D (batch, seq_length) tensor of tokens. seq_length must not exceed the size of the
            precomputed ``causal_mask`` / ``position_bias_base`` tables.

    Returns:
        (batch, seq_length, seq_length) tensor in the dtype of ``position_bias_base``.
    """
    seq_length = x.shape[1]
    is_masked = x == MASK_TOKEN  # (batch, seq_length) bool
    # Pair (i, j) is bidirectionally visible only when both positions are masked.
    both_masked = is_masked.unsqueeze(2) & is_masked.unsqueeze(1)  # (batch, L, L)
    allowed = both_masked | causal_mask[:seq_length, :seq_length].unsqueeze(0)

    # position_bias_base is built as j - i: it is <= 0 under the causal mask but positive for the
    # future keys that masked spans unlock. Taking -|j - i| makes the bias decay with distance in
    # both directions, and leaves the values on and below the diagonal unchanged.
    distance_bias = -position_bias_base[:seq_length, :seq_length].abs()
    blocked = negative_infinity_matrix_base[:seq_length, :seq_length]
    return torch.where(allowed, distance_bias, blocked)


# Small bf16 demo: ask for far more spans than fit, so the 80% cap decides how many are placed.
x = torch.rand(2, 6).to(dtype=torch.bfloat16)
inputs, targets = mask_spans(
    x,
    mask_width=2,
    num_masks=100000000,
    max_percentage_masked=0.8,
    causal=True,
)
attn_mask = make_mask(inputs)
print(inputs)
print(attn_mask)
# Next step: check that scaled_dot_product_attention accepts this mask, e.g.
# q, k, v = torch.randn(2, 6, 8).to(dtype=torch.bfloat16), torch.randn(2, 6, 8).to(dtype=torch.bfloat16), torch.randn(2, 6, 8).to(dtype=torch.bfloat16)
# print(attn_mask.shape, q.shape, k.shape, v.shape)
# y = F.scaled_dot_product_attention(q, k, v, attn_mask)
# print(y)