test_dummy_mha_with_nt_cuda fails on sm70, sm75 #129523

Open
eqy opened this issue Jun 25, 2024 · 4 comments
Labels
module: cuda, module: multi-headed-attention, module: nestedtensor, triaged

Comments

@eqy
Collaborator

eqy commented Jun 25, 2024

🐛 Describe the bug

It looks like the test dispatches to the efficient-attention backward kernel and fails one of its shape checks:

    TORCH_CHECK(
        max_seqlen_k <= key.size(1), "Invalid max_seqlen_k:", max_seqlen_k);

The failing call:

    buf0 = aten._efficient_attention_backward.default(reinterpret_tensor(tangents_1, (1, s1, 2, 3), (6*s1, 6, 3, 1), 0), unsqueeze, unsqueeze_1, unsqueeze_2, None, getitem, convert_element_type, convert_element_type_1, s2, s5, getitem_1, 0.0, getitem_2, getitem_3, 0, False)

Printing k.sizes() here shows: [1, 6, 2, 3] when max_seqlen_k is 10.
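
To make the mismatch concrete, here is a minimal sketch using hypothetical sequence lengths that reproduce the printed shapes (an illustration of the failing invariant only, not the internal code path):

    import torch

    # Hypothetical ragged batch: two sequences of 2 and 4 tokens, 6 tokens total.
    # The packed key handed to the kernel has size(1) == total token count, while
    # max_seqlen_k is supposed to be bounded by it.
    seq_lens = [2, 4]
    key = torch.randn(1, sum(seq_lens), 2, 3)   # matches the observed [1, 6, 2, 3]

    max_seqlen_k = 10                           # value seen at the failing TORCH_CHECK
    print(max_seqlen_k <= key.size(1))          # False -> the check fires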

This doesn't seem to happen on sm80+, presumably because those GPUs can dispatch to Flash Attention instead.
Interestingly, pinning the backend on sm80+ with a decorator so that only efficient attention can run gives the following (a sketch of the equivalent context-manager approach follows the log):

W0625 22:21:47.487000 140137837507200 torch/nested/_internal/sdpa.py:293] Memory efficient kernel not used because:
W0625 22:21:47.488000 140137837507200 torch/nested/_internal/sdpa.py:296] Flash attention kernel not used because:
W0625 22:21:47.488000 140137837507200 torch/nested/_internal/sdpa.py:101] For NestedTensor inputs, Flash attention requires q,k,v to have the same last dimension and to be a multiple of 8 and less than or equal to 256. Got Query.size(-1): 3, Key.size(-1): 3, Value.size(-1): 3 instead.
W0625 22:21:47.488000 140137837507200 torch/nested/_internal/sdpa.py:299] Math attention kernel not used because:
/workspace/pytorch/test/test_nestedtensor.py:5317: UserWarning: Mem efficient attention requires last dimension of inputs to be divisible by 4. Got Query.size(-1): 3, Key.size(-1): 3, Value.size(-1): 3 instead. (Triggered internally at /workspace/pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:164.)
  attn_output = torch.nn.functional.scaled_dot_product_attention(
/workspace/pytorch/test/test_nestedtensor.py:5317: UserWarning: Flash attention has been runtime disabled. (Triggered internally at /workspace/pytorch/aten/src/ATen/native/transformers/sdp_utils_cpp.h:494.)
  attn_output = torch.nn.functional.scaled_dot_product_attention(
E
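
For reference, outside the test harness the same backend pinning can be expressed with the public sdpa_kernel context manager; a minimal sketch (not the decorator the test uses), with a hypothetical head_dim of 8 chosen to satisfy the divisibility requirements from the warnings above:

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import sdpa_kernel, SDPBackend

    # (batch, heads, seq, head_dim) with head_dim divisible by 8 so both flash and
    # memory-efficient attention are eligible on this input.
    q = k = v = torch.randn(1, 2, 6, 8, device="cuda", dtype=torch.float16)

    # Restrict dispatch to the memory-efficient backend only.
    with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
        out = F.scaled_dot_product_attention(q, k, v)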

Simply removing the max_seqlen_k <= key.size(1) shape check allows the test to pass, but I'm not sure that's correct. Is there some inductor/symbolic-tracing accounting for shapes that needs to be done here?

CC @drisspg

Versions

Current 2024/06/25 source build

cc @ptrblck @msaroufim @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer

@eqy eqy added the module: multi-headed-attention, module: nestedtensor, and module: cuda labels on Jun 25, 2024
@malfet malfet added the triaged label on Jun 26, 2024
@drisspg
Contributor

drisspg commented Jul 2, 2024

This is weird; is it only happening under some compile context?

@jbschlosser
Contributor

@drisspg yes, this test exercises torch.compile() behavior with NJT + SDPA in a way that emulates what FIRST is doing

@drisspg
Contributor

drisspg commented Jul 10, 2024

Hmm @danthe3rd, do you know whether it's possible to iterate into bad memory when max_seq_len > sum(seq_len)? I think max_seq_len sets a maximum iteration bound, but there are still checks to ensure that the current token indices are valid, right?
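
For what it's worth, a pure-Python illustration of the question (not the actual xFormers kernel): max_seqlen_k only sets how far the loop iterates, and the per-token bound check is what keeps accesses in range, so an oversized max_seqlen_k should cost wasted iterations rather than bad reads, provided that check is in place.

    def attend_keys(key_tokens, seq_len, max_seqlen_k):
        acc = 0.0
        for token in range(max_seqlen_k):   # iteration bound comes from max_seqlen_k
            if token >= seq_len:            # per-token validity check
                continue                    # skip / mask out-of-range tokens
            acc += key_tokens[token]
        return acc

    print(attend_keys([1.0, 2.0], seq_len=2, max_seqlen_k=10))  # 3.0, no out-of-range access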

@danthe3rd
Contributor

You can probably remove that check. It doesn't make sense to have max_seq_len > sum(seq_len), though, since it's always bounded by the sum, but the kernel should support it, I guess?
This is code I wrote some time ago, so I don't have the context off the top of my head.
