
scaled_dot_product_attention with alibi bias issue on long sequence length #107959

Closed
nlpcat opened this issue Aug 25, 2023 · 3 comments
Labels
oncall: transformer/mha Issues related to Transformers and MultiheadAttention

Comments


nlpcat commented Aug 25, 2023

🐛 Describe the bug

Hardware: A100 40GB.

scaled_dot_product_attention + alibi_attention_mask + enable_mem_efficient + seq_length 8192: throws an error, "bias_4d_view.stride(0) overflows" (a short arithmetic sketch follows this list)

scaled_dot_product_attention + alibi_attention_mask + enable_math + seq_length 8192: no error

scaled_dot_product_attention + alibi_attention_mask + enable_mem_efficient + seq_length 4096: no error
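
For reference, my guess at the cause (an assumption on my part, not stated in the error beyond the overflow message) is that the batch stride of a contiguous [1, num_heads, seq_length, seq_length] bias no longer fits in a 32-bit integer at this size:

num_heads, seq_length = 32, 8192
# stride(0) of a contiguous [1, num_heads, seq_length, seq_length] bias tensor
stride_0 = num_heads * seq_length * seq_length
print(stride_0)   # 2147483648
print(2**31 - 1)  # 2147483647 (int32 max), so stride_0 exceeds it by exactly 1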

import torch
dtype = torch.bfloat16
seq_length = 8192
num_heads = 32
query_layer = torch.rand([1, num_heads, seq_length, 64], dtype=dtype, device="cuda:0")
key_layer = torch.rand([1, num_heads, seq_length, 64], dtype=dtype, device="cuda:0")
value_layer = torch.rand([1, num_heads, seq_length, 64], dtype=dtype, device="cuda:0")
alibi = torch.rand([1, num_heads, seq_length, seq_length], dtype=dtype, device="cuda:0")

# throws an error: bias_4d_view.stride(0) overflows
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    context_layer = torch.nn.functional.scaled_dot_product_attention(
        query_layer, key_layer, value_layer, attn_mask=alibi, dropout_p=0.0
    )

# no error
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
    context_layer = torch.nn.functional.scaled_dot_product_attention(
        query_layer, key_layer, value_layer, attn_mask=alibi, dropout_p=0.0
    )
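
For context, the alibi tensor in the repro is just random data of the right shape. In a real model it would be an additive ALiBi bias; below is a minimal sketch of how such a bias could be built, assuming the standard ALiBi slope schedule (build_alibi_bias is a hypothetical helper, not part of the original code):

def build_alibi_bias(num_heads, seq_length, dtype, device):
    # per-head slopes: 2^(-8h/num_heads) for h = 1..num_heads (ALiBi paper schedule)
    slopes = torch.tensor(
        [2.0 ** (-8.0 * h / num_heads) for h in range(1, num_heads + 1)],
        device=device,
    )
    # relative position j - i, clamped so the diagonal and future positions contribute 0
    pos = torch.arange(seq_length, device=device)
    distance = (pos[None, :] - pos[:, None]).clamp(max=0).to(slopes.dtype)
    # broadcast to a [1, num_heads, seq_length, seq_length] additive bias
    return (slopes[None, :, None, None] * distance[None, None, :, :]).to(dtype)

# drop-in replacement for the random alibi above
alibi = build_alibi_bias(num_heads, seq_length, dtype, device="cuda:0")

Either way, at seq_length 8192 with 32 heads the bias tensor alone is 32 * 8192 * 8192 * 2 bytes = 4 GiB in bf16, so the shapes in the repro are representative of a real long-context workload.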

Versions

nightly version
torch==2.1.0-dev20230825+cu118

cc @jbschlosser @bhosmer @cpuhrsch @erichan1 @drisspg


nlpcat commented Aug 25, 2023

Follow-up to issues #96099 and #104310.


nlpcat commented Aug 25, 2023

@drisspg

@awgu awgu added the oncall: transformer/mha Issues related to Transformers and MultiheadAttention label Aug 25, 2023

drisspg commented Aug 25, 2023

I'll take a look.

voznesenskym pushed a commit that referenced this issue Aug 27, 2023
Fixes #107959
This should have been fixed by #103201
Edit: looking at git blame, it appears the dropout revert squashed the changes from that PR.
Pull Request resolved: #107968
Approved by: https://github.com/cpuhrsch