## Sample masking for Flash attention 2.

The Llama 3 architecture masks self-attention between different samples that are packed into the same sequence, as described in Section 3.2 (Model Architecture) of the [Llama 3 paper](https://arxiv.org/pdf/2407.21783): "We use an attention mask that prevents self-attention between different documents within the same sequence."

This is easy to implement with a plain self-attention implementation (e.g., one built on `torch.matmul`, where an explicit attention mask can be applied), but making sample-level masking work with [flash-attention-v2](https://github.com/Dao-AILab/flash-attention) takes a few tricks.

In [None]:
!pip install transformers==4.40.3

In [None]:
from transformers import AutoTokenizer

# Tokenizer for "meta-llama/Meta-Llama-3-8B"; it produces the token ids listed below.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

## Suppose there are three samples; each one is tokenized into the token ids shown below.
'''
tokenizer('Hello world!').input_ids => [9906, 1917, 0]
tokenizer('Nice to meet you!').input_ids => [46078, 311, 3449, 499, 0]
tokenizer('What\' you name?').input_ids => [3923, 6, 499, 836, 30]
'''

Suppose there are two training sequences:
1. 'Hello world!<|end_of_text|>Nice to meet you!<|end_of_text|>What\' you name?<|end_of_text|>' => \[9906, 1917, 0, 128001, 46078, 311, 3449, 499, 0, 128001, 3923, 6, 499, 836, 30, 128001\]
2. 'What\' you name?<|end_of_text|>Nice to meet you!<|end_of_text|>Hello world!<|end_of_text|>' => \[3923, 6, 499, 836, 30, 128001, 46078, 311, 3449, 499, 0, 128001, 9906, 1917, 0, 128001\]

In [None]:
import torch

## Token id of <|end_of_text|>; it terminates every sample in the packed
## sequences below and is what the masking code later searches for.
eos_token_id = 128001

# Two packed training sequences (batch of 2, 16 tokens each). Each row holds
# three samples, and every sample ends with eos_token_id.
tokens = torch.tensor(
    [
        [9906, 1917, 0, 128001, 46078, 311, 3449, 499, 0, 128001, 3923, 6, 499, 836, 30, 128001],
        [3923, 6, 499, 836, 30, 128001, 46078, 311, 3449, 499, 0, 128001, 9906, 1917, 0, 128001],
    ],
    dtype=torch.long)

def to_device(batch, device):
    """Return a new dict with every movable value of ``batch`` placed on ``device``.

    Values that expose a callable ``.to`` (tensors and tensor-like objects)
    are moved with ``value.to(device)``; all other values (strings, ints,
    ``None``, ...) are carried over unchanged.

    Errors raised by ``.to`` itself (for example an unavailable device) are
    propagated instead of being silently swallowed, so a batch can never be
    left on the wrong device without the caller noticing.

    Args:
        batch: Mapping from names to values.
        device: Target passed verbatim to each value's ``.to``.

    Returns:
        A new ``dict`` with the same keys; ``batch`` itself is not modified.
    """
    output = {}
    for k, v in batch.items():
        # Duck-type on `.to` rather than wrapping the call in a blanket
        # try/except, which would also hide genuine device failures.
        mover = getattr(v, "to", None)
        output[k] = mover(device) if callable(mover) else v
    return output

# Target the first CUDA GPU.
device = torch.device("cuda:0")

# Inputs and labels are the same `tokens` tensor (both keys reference it).
batch = {"input_ids": tokens, 'labels': tokens}
batch = to_device(batch, device)

### Option 1: No masking between different samples.

In [None]:
# AutoModelForCausalLM is not imported by the earlier cells shown in this
# notebook (only AutoTokenizer is), so import it here to avoid a NameError.
from transformers import AutoModelForCausalLM

# `attn_implementation="flash_attention_2"` is the current spelling of the
# deprecated `use_flash_attention_2=True` flag.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

# Baseline forward pass: no attention mask is passed, so the packed samples
# are not masked from one another.
outputs = model(**batch, use_cache=False)
loss = outputs.loss

### Option 2: Apply masking between different samples.

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import transformers
from einops import rearrange
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis

def _get_unpad_data_for_concatenated_sequences(attention_mask_in_length):
    """
    Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model).
    The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).

    For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
        ```
        [
          [2, 3, 0, 0, 0, 0],
          [3, 2, 0, 0, 0, 0],
          [6, 0, 0, 0, 0, 0]
        ]
        ```
    , which refers to the 3D-attention mask:
        ```
        [
          [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0],
            [0, 0, 1, 1, 0, 0],
            [0, 0, 1, 1, 1, 0],
            [0, 0, 0, 0, 0, 1]
          ],
          [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0],
            [0, 0, 0, 1, 0, 0],
            [0, 0, 0, 1, 1, 0],
            [0, 0, 0, 0, 0, 1]
          ],
          [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0],
            [1, 1, 1, 1, 0, 0],
            [1, 1, 1, 1, 1, 0],
            [1, 1, 1, 1, 1, 1]
          ]
        ]
        ```.

    Arguments:
        attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none.
    Return:
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
    """
    length = attention_mask_in_length.sum(dim=-1)
    seqlen = attention_mask_in_length.size(-1)
    attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length), seqlen) < length.unsqueeze(1)
    real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
    seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
    indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )

def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
    """
    Remove padding from the query/key/value tensors using per-sample boundaries.

    Intended as a drop-in replacement for the attention module's own
    `_upad_input`: the key/value unpadding metadata comes from
    `_get_unpad_data_for_concatenated_sequences`, so `attention_mask` is
    interpreted as per-row sample lengths rather than a 0/1 padding mask.

    Arguments:
        query_layer, key_layer, value_layer: attention inputs; `key_layer.shape` is unpacked as
            (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        attention_mask: the "attention_mask_in_length" tensor described above.
        query_length: number of query positions per batch row.
    Return:
        (query_layer, key_layer, value_layer, indices_q,
         (cu_seqlens_q, cu_seqlens_k), (max_seqlen_in_batch_q, max_seqlen_in_batch_k))
    """
    # Per-sample boundaries for keys/values (replaces the per-row `_get_unpad_data` lookup).
    kv_indices, kv_cu_seqlens, kv_max_seqlen = _get_unpad_data_for_concatenated_sequences(attention_mask)

    batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
    flat_kv_shape = (batch_size * kv_seq_len, num_key_value_heads, head_dim)

    # Flatten (batch, seq) into one axis and keep only the real tokens.
    key_layer = index_first_axis(key_layer.reshape(flat_kv_shape), kv_indices)
    value_layer = index_first_axis(value_layer.reshape(flat_kv_shape), kv_indices)

    if query_length == kv_seq_len:
        # Queries line up with keys one-to-one, so they share the same metadata.
        indices_q = kv_indices
        cu_seqlens_q = kv_cu_seqlens
        max_seqlen_in_batch_q = kv_max_seqlen
        flat_q_shape = (batch_size * kv_seq_len, self.num_heads, head_dim)
        query_layer = index_first_axis(query_layer.reshape(flat_q_shape), indices_q)
    elif query_length == 1:
        # One query token per batch row: every query "sequence" has length 1.
        max_seqlen_in_batch_q = 1
        cu_seqlens_q = torch.arange(
            batch_size + 1, dtype=torch.int32, device=query_layer.device
        )  # There is a memcpy here, that is very bad.
        indices_q = cu_seqlens_q[:-1]
        query_layer = query_layer.squeeze(1)
    else:
        # The -q_len: slice assumes left padding.
        # NOTE(review): `unpad_input` presumably expects a 0/1 padding mask, while
        # `attention_mask` here holds sample lengths -- confirm this branch is never
        # reached with packed inputs.
        attention_mask = attention_mask[:, -query_length:]
        query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

    cu_seqlens = (cu_seqlens_q, kv_cu_seqlens)
    max_seqlens = (max_seqlen_in_batch_q, kv_max_seqlen)
    return (
        query_layer,
        key_layer,
        value_layer,
        indices_q,
        cu_seqlens,
        max_seqlens,
    )



In [None]:
# AutoModelForCausalLM is not imported by the earlier cells shown in this
# notebook (only AutoTokenizer is), so import it here to avoid a NameError.
from transformers import AutoModelForCausalLM

# Replace the flash-attention unpadding hook so that `attention_mask` is read
# as per-sample lengths (see `_upad_input` above).
transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input = _upad_input

# NOTE(review): this loads "Meta-Llama-3.1-8B", while the tokenizer and the
# Option 1 cell use "Meta-Llama-3-8B" -- confirm the different checkpoint is
# intended and is supported by the pinned transformers version.
# `attn_implementation="flash_attention_2"` is the current spelling of the
# deprecated `use_flash_attention_2=True` flag.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B",
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

In [None]:
def mask_concat_samples(batch_data, eos_token_id, reset_position_ids=False):
    """
    Rewrite a packed batch in place so that concatenated samples are masked from each other.

    Each row of `batch_data['input_ids']` is split into samples at every
    `eos_token_id` (the EOS token closes the sample it belongs to; any tokens
    after the last EOS form one final sample). The function then updates
    `batch_data` as follows:

    * `batch_data['attention_mask']` becomes a (batch, seq_length) int tensor
      on the same device as `input_ids`; row b holds the sample lengths of
      that row, packed at the start and padded with zeros. This is the
      "attention_mask_in_length" format consumed by
      `_get_unpad_data_for_concatenated_sequences`.
    * `batch_data['labels']` is replaced by a copy in which the first token of
      every sample except the first is set to -100, so that (presumably, given
      the usual one-step label shift inside causal-LM losses) an EOS token is
      never trained to predict the start of the next, unrelated sample.
    * If `reset_position_ids` is true, `batch_data['position_ids']` is set to a
      (batch, seq_length) long tensor whose positions restart at 0 at the
      beginning of every sample. Otherwise no `position_ids` key is written.

    Arguments:
        batch_data: dict with 2D tensors `input_ids` and `labels` of identical shape.
        eos_token_id: int, token id that terminates a sample.
        reset_position_ids: bool, whether to emit per-sample position ids.
    Return:
        None; `batch_data` is modified in place. The original `labels` and
        `input_ids` tensors themselves are left untouched.
    """
    input_ids = batch_data['input_ids']
    labels = batch_data['labels'].clone()
    micro_batch_size, seq_length = input_ids.shape

    # Filled element by element from Python ints, so build it on CPU and move
    # it to the input device once at the end.
    inner_sample_lengths = torch.zeros((micro_batch_size, seq_length), dtype=torch.int)

    position_ids = None
    if reset_position_ids:
        # `repeat` (not `expand`) so every row owns its memory; an expanded
        # view would make an in-place edit of one row leak into all rows.
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).repeat(micro_batch_size, 1)

    for b in range(micro_batch_size):
        # Indices of the EOS tokens in this row, as plain Python ints.
        eos_positions = torch.nonzero(input_ids[b] == eos_token_id, as_tuple=False).flatten().tolist()

        sample_start = 0
        for j, eos_pos in enumerate(eos_positions):
            sample_end = eos_pos + 1  # exclusive; the EOS belongs to this sample
            inner_sample_lengths[b, j] = sample_end - sample_start
            if sample_end < seq_length:
                # First token of the next sample: do not learn it from this EOS.
                labels[b, sample_end] = -100
                if reset_position_ids:
                    # Shift everything after this sample so the next one starts at 0.
                    position_ids[b, sample_end:] -= sample_end - sample_start
            sample_start = sample_end

        # Tokens after the last EOS (or the whole row if there is no EOS).
        if sample_start < seq_length:
            inner_sample_lengths[b, len(eos_positions)] = seq_length - sample_start

    batch_data['labels'] = labels
    batch_data['attention_mask'] = inner_sample_lengths.to(input_ids.device)

    if reset_position_ids:
        batch_data['position_ids'] = position_ids

In [None]:
## Keep position ids increasing across the whole packed sequence (no per-sample
## reset); the guess is that this works better because most samples are short.
## mask_concat_samples updates `batch` in place (labels and attention_mask).
mask_concat_samples(batch, eos_token_id, reset_position_ids=False)

outputs = model(**batch, use_cache=False)
loss = outputs.loss