🐛 Describe the bug
Thanks to the team for the great work! However, it seems that the latest version (torch==2.6.0) still hasn't resolved the issue with dynamic-shape inputs. I can easily reproduce the problem with a few lines of chunked-prefill code. I am curious whether this is the same issue reported in #139064, and how it can be solved.
I have narrowed the issue down to whether the create_block_mask function is compiled. If this function is not compiled, the program runs normally. However, for longer sequence masks (e.g., 64K×64K), not compiling create_block_mask leads to huge GPU memory overhead, causing an OOM. I'm not sure whether this is because a whole bf16 mask tensor is created in the background. But if I do compile this function, the same LoweringException: TypeError: cannot determine truth value of Relational error as in #139064 occurs.
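If the uncompiled path really does materialize a dense mask, the arithmetic alone would explain the OOM: a 64K×64K bf16 tensor is 64·1024 × 64·1024 × 2 bytes ≈ 8 GiB, and even a bool mask of that shape is ≈ 4 GiB. For concreteness, the two call variants in question look roughly like this (Q_LEN and KV_LEN are placeholder names for the per-chunk shapes used in the full script below):

```python
# Uncompiled path: runs fine, but appears to build the mask eagerly in the
# background, which OOMs for long sequences (e.g., 64K x 64K).
block_mask = create_block_mask(mask_mod, 1, 1, Q_LEN, KV_LEN, device="cuda")

# Compiled path: avoids the memory blow-up, but fails with
# LoweringException: TypeError: cannot determine truth value of Relational
block_mask = create_block_mask(mask_mod, 1, 1, Q_LEN, KV_LEN, device="cuda", _compile=True)
```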
You can easily reproduce this with the following code:
```python
from torch.nn.attention.flex_attention import flex_attention, create_block_mask, _DEFAULT_SPARSE_BLOCK_SIZE
import torch
import argparse
import math
from tqdm import tqdm

parser = argparse.ArgumentParser()
parser.add_argument("--seq_len", type=int, default=32 * 1024)
parser.add_argument("--head_num", type=int, default=32)
parser.add_argument("--head_dim", type=int, default=128)
parser.add_argument("--chunk_size", type=int, default=2 * 1024)
args = parser.parse_args()

flex_attention = torch.compile(flex_attention, dynamic=False, mode="max-autotune")


def get_dynamic_mod(recent_token_num):
    # Chunked-prefill mask: every cached ("recent") token is visible, and the
    # current chunk attends causally to itself.
    def get_mask(b, h, q_idx, kv_idx):
        recent_mask = kv_idx < recent_token_num
        real_kv_idx = kv_idx - recent_token_num
        causal_mask = q_idx >= real_kv_idx
        return recent_mask | causal_mask

    return get_mask


@torch.no_grad
def main():
    q = torch.randn(1, args.head_num, args.seq_len, args.head_dim, dtype=torch.bfloat16).cuda()
    k = torch.randn(1, args.head_num, args.seq_len, args.head_dim, dtype=torch.bfloat16).cuda()
    v = torch.randn(1, args.head_num, args.seq_len, args.head_dim, dtype=torch.bfloat16).cuda()
    iter_num = math.ceil(args.seq_len / args.chunk_size)
    num_past_tokens = 0
    for i in tqdm(range(iter_num)):
        # Each iteration prefills one chunk of queries against all past keys/values.
        query_states = q[:, :, i * args.chunk_size:(i + 1) * args.chunk_size, :]
        key_states = k[:, :, i * args.chunk_size - num_past_tokens:(i + 1) * args.chunk_size, :]
        value_states = v[:, :, i * args.chunk_size - num_past_tokens:(i + 1) * args.chunk_size, :]
        print(query_states.shape, key_states.shape, value_states.shape)
        mask_mod = get_dynamic_mod(num_past_tokens)
        # Whether to use `_compile=True` here is important!
        block_mask = create_block_mask(
            mask_mod, 1, 1, args.chunk_size, args.chunk_size + num_past_tokens,
            device="cuda", BLOCK_SIZE=(128, 64), _compile=True,
        )
        attn_output = flex_attention(query_states, key_states, value_states, block_mask=block_mask)
        num_past_tokens = args.chunk_size * (i + 1)
        # num_past_tokens = 0


if __name__ == "__main__":
    main()
```

Versions
torch==2.6.0
GPU: Nvidia A100-40G SXM
cc @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @yf225 @Chillee @drisspg @yanboliang @BoyuanFeng