SDPA: frontend for BSR masks #104042
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/104042
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 8eb7cfd.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@drisspg, could you please also have a look? I might have missed some checks or conditions I am not yet aware of...
Left some comments but overall I think it looks good
@pytorchbot merge
Merge failed. Reason: Approval needed from one of the following:
@cpuhrsch, could you please comment/approve?
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Hi, this feature is really useful, thanks! Apologies if I'm completely missing something, but I am trying to use sparse tensor masks of arbitrary shape (i.e. not triangular or any common pattern) with `scaled_dot_product_attention`. However, I encounter the following error when the code gets to the
I am not sure if this is a problem on the sparse tensor representation/generation side or within `scaled_dot_product_attention`. I personally do not have any reason to use the BSR flavour of sparse tensors, and I am wondering if this feature could be supported for other sparse types, as they are more intuitive and easier to create. Many thanks!
@davidbuterez, this function is not yet tied to the public API. One would need to call `_scaled_dot_product_attention` from `torch.sparse._triton_ops` directly, with the `attn_mask` already in BSR format.
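For illustration, a minimal sketch of that direct call. The shapes, block size, and device below are assumptions for the example (a power-of-two sequence length and 64×64 blocks, in line with the constraints discussed further down the thread), not requirements stated in this PR:

```python
import torch
from torch.sparse._triton_ops import _scaled_dot_product_attention

# Hypothetical shapes: (batch, heads, seq_len, head_dim), with a sequence
# length that is a power of two and divisible by the mask block size.
q = torch.rand(2, 4, 512, 64, device="cuda", dtype=torch.bfloat16)
k = torch.rand(2, 4, 512, 64, device="cuda", dtype=torch.bfloat16)
v = torch.rand(2, 4, 512, 64, device="cuda", dtype=torch.bfloat16)

# Build a dense boolean mask of arbitrary pattern and convert it to BSR;
# the block size must evenly divide the last two mask dimensions.
dense_mask = torch.randint(0, 2, (2, 4, 512, 512), device="cuda", dtype=torch.bool)
bsr_mask = dense_mask.to_sparse_bsr((64, 64))

out = _scaled_dot_product_attention(q, k, v, attn_mask=bsr_mask)
```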
@nikitaved Thanks, this makes sense. However, I am encountering a new error which seems to be related to Triton. The offending line in
which eventually gets to
I am also getting the following error:
My Q, K, V tensors have shape (256, 16, 600, 16) and the attention mask has shape (256, 16, 600, 600).
I guess the sequence length (600) needs to be a power of 2?
@davidbuterez, could you please provide some minimal reproduction code?
@nikitaved Absolutely, here is a minimal example:

```python
import torch
from torch.sparse._triton_ops import _scaled_dot_product_attention

qkv_size = (256, 16, 600, 16)
attn_mask_size = (256, 16, 600, 600)

Q = torch.rand(size=qkv_size, device='cuda', dtype=torch.bfloat16)
K = torch.rand(size=qkv_size, device='cuda', dtype=torch.bfloat16)
V = torch.rand(size=qkv_size, device='cuda', dtype=torch.bfloat16)

attn_mask = torch.randint(size=attn_mask_size, low=0, high=2, device='cuda', dtype=torch.bool)

blocksize = attn_mask.shape[-1] // 2
attn_mask_bsr = attn_mask.to_sparse_bsr((blocksize, blocksize))

with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    O = _scaled_dot_product_attention(Q, K, V, attn_mask=attn_mask_bsr, dropout_p=0.2)
```

If it helps, I am using PyTorch 2.2.0.dev20231001 and CUDA 11.8 on an Ampere GPU. Also, a full stack trace (output from a Jupyter notebook):
@nikitaved I was wondering if there are any plans to fix this? Thanks.
I'm just chiming in to confirm your suspicion @davidbuterez. When using sparse masks with Triton, the seq length needs to be a power of 2. I suppose if @nikitaved wanted to fix this, he'd have to wrap the function with some temporary padding? This is what I'll be doing myself with my Q and K before passing them into the Triton SDPA. Also interested in a fix :)

PS: I also notice I can't go higher than blocksize 64 (e.g. to 128), so maybe the power-of-2 thing isn't the full picture. I get the same "run failed" error.
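For what it's worth, here is a rough sketch of the padding workaround described above. It is an assumption about how one might do it, not something shipped with this PR; the helper name and usage are hypothetical:

```python
import torch
import torch.nn.functional as F

def pad_seq_to_pow2(t):
    """Zero-pad the sequence dimension of a (B, H, S, E) tensor up to the next power of two."""
    s = t.size(-2)
    target = 1 << (s - 1).bit_length()  # next power of two >= s
    # F.pad takes pairs from the last dimension backwards: (E_left, E_right, S_left, S_right).
    return F.pad(t, (0, 0, 0, target - s)), s

# Hypothetical usage: pad Q, K, V, build the BSR mask at the padded length so the
# padded key columns stay masked out, call the Triton SDPA, then slice the padding
# off the output:
#   q_p, s = pad_seq_to_pow2(q)
#   k_p, _ = pad_seq_to_pow2(k)
#   v_p, _ = pad_seq_to_pow2(v)
#   out = _scaled_dot_product_attention(q_p, k_p, v_p, attn_mask=bsr_mask_padded)[..., :s, :]
```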
@davidbuterez, @ghwatson, sorry guys, but I am no longer involved in this work. Check with @amjames or @pearu; they might be able to help.
This PR implements a (still private) frontend for `scaled_dot_product_attention` that works with a BSR `attn_mask`. This function is directly comparable (with suitable masks) to `torch.nn.functional.scaled_dot_product_attention` when `attn_mask.dtype == torch.bool`, but its behavior differs when `attn_mask.dtype != torch.bool`. This is because `torch.nn.functional.scaled_dot_product_attention` assumes that irrelevant values are filled with `-inf`, while the selected ones should be `0`.

Stack from ghstack (oldest at bottom):

cc @alexsamardzic @pearu @cpuhrsch @amjames @bhosmer
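To make the dense-mask convention above concrete, a small sketch using only the public dense API. The BSR frontend's own convention is defined by this PR, so this only illustrates what `torch.nn.functional.scaled_dot_product_attention` expects from boolean versus floating-point masks; shapes are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: (batch, heads, seq_len, head_dim).
q, k, v = (torch.rand(1, 2, 8, 4) for _ in range(3))

# Boolean convention: True means "attend", False means "mask out".
# The diagonal is forced to True so no query row is fully masked (avoids NaN).
bool_mask = (torch.rand(8, 8) > 0.5) | torch.eye(8, dtype=torch.bool)

# Floating-point convention for the dense API: the mask is added to the
# attention scores, so kept positions are 0 and masked positions are -inf.
additive_mask = torch.zeros(8, 8).masked_fill(~bool_mask, float("-inf"))

out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)
out_add = F.scaled_dot_product_attention(q, k, v, attn_mask=additive_mask)
assert torch.allclose(out_bool, out_add, atol=1e-6)
```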