
[Flex Attention] Errors with Dynamic Shapes (Cannot determine truth value of Relational) #146745


🐛 Describe the bug

Thanks to the team for the great work! It seems, though, that the latest version (torch==2.6.0) still hasn't resolved the issue with dynamic-shape inputs. I can easily reproduce the problem with a few lines of chunked-prefill code. Is this the same issue as reported in #139064, and if so, how can it be solved?

I have narrowed the issue down to whether the create_block_mask function is compiled. If it is not compiled, the program runs normally; however, for longer sequence masks (e.g., 64K*64K), the uncompiled create_block_mask incurs a huge GPU memory overhead and causes an OOM. I am not sure whether this is because a full bf16 mask tensor is created in the background. If I do compile the function, the same LoweringException: TypeError: cannot determine truth value of Relational error as in #139064 occurs.
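
Concretely, the two variants I am comparing differ only in the `_compile` flag (Q_LEN and KV_LEN stand in for the actual sizes used in the repro below):

# Eager: runs, but for long masks (e.g., 64K*64K) it OOMs, presumably
# because a full mask tensor is materialized before block compression.
block_mask = create_block_mask(mask_mod, 1, 1, Q_LEN, KV_LEN, device="cuda", _compile=False)

# Compiled: avoids the memory blow-up, but fails with
# LoweringException: TypeError: cannot determine truth value of Relational
block_mask = create_block_mask(mask_mod, 1, 1, Q_LEN, KV_LEN, device="cuda", _compile=True)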

You can easily reproduce this with the following code:

import argparse
import math

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from tqdm import tqdm

parser = argparse.ArgumentParser()
parser.add_argument("--seq_len", type=int, default=32*1024)
parser.add_argument("--head_num", type=int, default=32)
parser.add_argument("--head_dim", type=int, default=128)
parser.add_argument("--chunk_size", type=int, default=2*1024)
args = parser.parse_args()

flex_attention = torch.compile(flex_attention, dynamic=False, mode="max-autotune")

def get_dynamic_mod(recent_token_num):
    def get_mask(b, h, q_idx, kv_idx):
        # All previously cached ("recent") tokens are visible to every query.
        recent_mask = kv_idx < recent_token_num
        # Shift kv indices past the cached prefix, then apply causal masking
        # within the current chunk.
        real_kv_idx = kv_idx - recent_token_num
        causal_mask = q_idx >= real_kv_idx

        return recent_mask | causal_mask

    return get_mask

@torch.no_grad
def main():
    q = torch.randn(1, args.head_num, args.seq_len, args.head_dim, dtype=torch.bfloat16).cuda()
    k = torch.randn(1, args.head_num, args.seq_len, args.head_dim, dtype=torch.bfloat16).cuda()
    v = torch.randn(1, args.head_num, args.seq_len, args.head_dim, dtype=torch.bfloat16).cuda()
    iter_num = math.ceil(args.seq_len / args.chunk_size)
    num_past_tokens = 0
    for i in tqdm(range(iter_num)):
        # Current chunk of queries; keys/values span all past tokens plus the chunk.
        query_states = q[:, :, i*args.chunk_size:(i+1)*args.chunk_size, :]
        key_states = k[:, :, i*args.chunk_size-num_past_tokens:(i+1)*args.chunk_size, :]
        value_states = v[:, :, i*args.chunk_size-num_past_tokens:(i+1)*args.chunk_size, :]

        print(query_states.shape, key_states.shape, value_states.shape)

        mask_mod = get_dynamic_mod(num_past_tokens)

        # whether to use `_compile=True` here is important!
        block_mask = create_block_mask(mask_mod, 1, 1, args.chunk_size, args.chunk_size+num_past_tokens, device="cuda", BLOCK_SIZE=(128, 64), _compile=True)

        attn_output = flex_attention(query_states, key_states, value_states, block_mask=block_mask)

        num_past_tokens = args.chunk_size * (i+1)
        # num_past_tokens = 0


if __name__ == "__main__":
    main()
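
As a sanity check on the memory question (a debugging sketch, not part of the repro; mask_mod and the imports come from the script above, and the 64K sizes are just an example), peak allocation around the eager call can be measured directly:

# Measure how much memory the eager (non-compiled) path allocates.
torch.cuda.reset_peak_memory_stats()
block_mask = create_block_mask(mask_mod, 1, 1, 64*1024, 64*1024, device="cuda", _compile=False)
print(f"peak allocated: {torch.cuda.max_memory_allocated() / 2**30:.2f} GiB")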

Versions

torch==2.6.0
GPU: NVIDIA A100 40GB SXM

cc @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @yf225 @Chillee @drisspg @yanboliang @BoyuanFeng
