SDPA: frontend for BSR masks #104042

Closed · wants to merge 15 commits
Conversation

@nikitaved (Collaborator) commented Jun 22, 2023

This PR implements a (still private) frontend for scaled_dot_product_attention that works with a BSR attn_mask.

With suitable masks, this function is directly comparable to torch.nn.functional.scaled_dot_product_attention when attn_mask.dtype == torch.bool, but its behavior differs when attn_mask.dtype != torch.bool. This is because torch.nn.functional.scaled_dot_product_attention assumes that masked-out positions are filled with -inf while selected positions are 0.
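
For illustration, a minimal sketch of that comparison for a boolean mask (an editorial addition, not code from the PR). It assumes a CUDA device with Triton available, hypothetical shapes, and a mask whose True/False pattern is aligned to the BSR block size; the private entry point is `torch.sparse._triton_ops._scaled_dot_product_attention`.

import torch
import torch.nn.functional as F
from torch.sparse._triton_ops import _scaled_dot_product_attention

q = torch.randn(2, 4, 128, 32, device='cuda', dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Boolean mask with a block-aligned pattern; diagonal blocks are kept so
# that no query row is fully masked out.
blocksize, nblk = 32, 128 // 32
keep = torch.rand(2, 4, nblk, nblk, device='cuda') > 0.5
keep |= torch.eye(nblk, dtype=torch.bool, device='cuda')
dense_mask = keep.repeat_interleave(blocksize, dim=-2).repeat_interleave(blocksize, dim=-1)
bsr_mask = dense_mask.to_sparse_bsr((blocksize, blocksize))

out_bsr = _scaled_dot_product_attention(q, k, v, attn_mask=bsr_mask)
out_ref = F.scaled_dot_product_attention(q, k, v, attn_mask=dense_mask)
torch.testing.assert_close(out_bsr, out_ref, atol=1e-2, rtol=1e-2)  # loose tolerances for bfloat16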

Stack from ghstack (oldest at bottom):

cc @alexsamardzic @pearu @cpuhrsch @amjames @bhosmer


pytorch-bot bot commented Jun 22, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/104042

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 8eb7cfd:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@nikitaved (Collaborator, Author) commented Jun 22, 2023

@drisspg, could you please also have a look? I might have missed some checks or conditions I am not yet aware of...

@nikitaved added the ciflow/trunk label Jun 27, 2023
@nikitaved (Collaborator, Author) commented

@amjames, @drisspg, unless you have any objections, I think this could be shipped.

@drisspg (Contributor) left a comment

Left some comments but overall I think it looks good

nikitaved added a commit that referenced this pull request Jul 11, 2023
ghstack-source-id: 36cb8e987528fd03bfe3dbf2381a4ce95e11be97
Pull Request resolved: #104042
nikitaved added a commit that referenced this pull request Jul 11, 2023
ghstack-source-id: e8140aaae957af0f8d05b95ed701048c993de4e7
Pull Request resolved: #104042
nikitaved added a commit that referenced this pull request Jul 12, 2023
ghstack-source-id: 45b273290feb14e57f2242a13e11079e4d3c7a30
Pull Request resolved: #104042
nikitaved added a commit that referenced this pull request Jul 12, 2023
ghstack-source-id: d086f557ae3e60062167872ddc849cc5efeb46fd
Pull Request resolved: #104042
@nikitaved (Collaborator, Author) commented

@pytorchbot merge

@pytorchmergebot (Collaborator) commented

Merge failed

Reason: Approval needed from one of the following:
nikitaved, IvanYashchuk, cpuhrsch, pearu

Details for Dev Infra team: raised by workflow job

Failing merge rule: Core Maintainers

@nikitaved (Collaborator, Author) commented Jul 12, 2023

@cpuhrsch, could you please comment/approve?

@nikitaved (Collaborator, Author) commented

@pytorchbot merge

@pytorchmergebot (Collaborator) commented

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@facebook-github-bot deleted the gh/nikitaved/58/head branch July 17, 2023 14:18
@davidbuterez commented

Hi,

This feature is really useful, thanks! Apologies if I'm completely missing something, but I am trying to use sparse tensor masks with arbitrary sparsity patterns (i.e. not triangular or any other common pattern) with torch.nn.functional.scaled_dot_product_attention. My intended use is to save memory when the masks are highly sparse. I am running this within the Memory-Efficient Attention context manager.

However, I encounter the following error when the code gets to the scaled_dot_product_attention line:

RuntimeError: Sparse BSR tensors do not have strides. 

I am not sure if this is a problem on the sparse tensor representation/generation side or within scaled_dot_product_attention.

I personally do not have any reason to use the BSR flavour of sparse tensors, and I am wondering if this feature could be supported for other sparse types, as they are more intuitive and easier to create?

Many thanks!
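
For context, a guess at the kind of call being described (an editorial sketch, not the reporter's actual code; shapes and block size are hypothetical). The public entry point expects a strided (dense) attn_mask, which is consistent with the error above:

import torch
import torch.nn.functional as F

q = torch.rand(2, 4, 128, 32, device='cuda', dtype=torch.float16)
k = torch.rand_like(q)
v = torch.rand_like(q)

mask = torch.rand(2, 4, 128, 128, device='cuda') > 0.5
bsr_mask = mask.to_sparse_bsr((32, 32))

with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    # Passing a BSR tensor here fails because the public API assumes a
    # strided attn_mask: "RuntimeError: Sparse BSR tensors do not have strides."
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=bsr_mask)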

@nikitaved (Collaborator, Author) commented Oct 5, 2023

@davidbuterez , this function is not yet tied to the public API. One would need to call torch.sparse._triton_ops.scaled_dot_product_attention directly. And, unfortunately, we do not have support for other sparse formats yet.

@davidbuterez commented

@nikitaved Thanks, this makes sense. However, I am encountering a new error which seems to be related to Triton. The offending line in _scaled_dot_product_attention is

sdpa = sampled_addmm(attn_mask, query, key.transpose(-2, -1), beta=0.0, skip_checks=False)

which eventually gets to inline_triton_ir and pm.run(mod) where it crashes with error RuntimeError: PassManager::run failed.

I am also getting the following error:

loc(".../torch-2-nightly-oct-2023/lib/python3.10/site-packages/triton/language/standard.py":93:11): error: Number of elements must be power-of-two, but "tt.return"(%0) : (tensor<300x300xf32>) -> () doesn't follow the rule (90000) elements

My Q, K, V tensors have shape [256, 16, 600, 16] and the attn_mask is [256, 16, 600, 600]. Before SDPA, I am converting attn_mask using

attn_mask = attn_mask.to_sparse_bsr((attn_mask.shape[-1] // 2, attn_mask.shape[-1] // 2))

I guess the RuntimeError is caused by the mask not being a power of two in size?

@nikitaved (Collaborator, Author) commented

@davidbuterez, could you please provide some minimal reproduction code?

@davidbuterez commented

@nikitaved Absolutely, here is a minimal example:

import torch
from torch.sparse._triton_ops import _scaled_dot_product_attention

qkv_size = (256, 16, 600, 16)
attn_mask_size = (256, 16, 600, 600)

Q = torch.rand(size=qkv_size, device='cuda', dtype=torch.bfloat16)
K = torch.rand(size=qkv_size, device='cuda', dtype=torch.bfloat16)
V = torch.rand(size=qkv_size, device='cuda', dtype=torch.bfloat16)

attn_mask = torch.randint(size=attn_mask_size, low=0, high=2, device='cuda', dtype=torch.bool)
blocksize = attn_mask.shape[-1] // 2
attn_mask_bsr = attn_mask.to_sparse_bsr((blocksize, blocksize))

with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    O = _scaled_dot_product_attention(Q, K, V, attn_mask=attn_mask_bsr, dropout_p=0.2)

If it helps, I am using PyTorch 2.2.0.dev20231001 and CUDA 11.8 on an Ampere GPU.

Also, a full stack trace (output from a Jupyter notebook):

loc(".../miniconda3/envs/torch-2-nightly-oct-2023/lib/python3.10/site-packages/triton/language/standard.py":93:11): error: Number of elements must be power-of-two, but "tt.return"(%0) : (tensor<300x300xf32>) -> () doesn't follow the rule (90000) elements
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[14], line 2
      1 with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
----> 2     O = _scaled_dot_product_attention(Q, K, V, attn_mask=attn_mask_bsr, dropout_p=0.2)

File ~/miniconda3/envs/torch-2-nightly-oct-2023/lib/python3.10/site-packages/torch/sparse/_triton_ops.py:877, in _scaled_dot_product_attention(query, key, value, attn_mask, dropout_p, is_causal, scale)
    874 if attn_mask.dtype is not torch.bool:
    875     check_dtype(f_name, attn_mask, query.dtype)
--> 877 sdpa = sampled_addmm(attn_mask, query, key.transpose(-2, -1), beta=0.0, skip_checks=False)
    878 if scale is None and query.size(-1) == 0 or scale == 0.0:
    879     check(
    880         False,
    881         f"{f_name}(): current value of scale == {scale} "
    882         "results in division by zero."
    883     )

File ~/miniconda3/envs/torch-2-nightly-oct-2023/lib/python3.10/site-packages/torch/sparse/_triton_ops.py:615, in sampled_addmm(input, mat1, mat2, beta, alpha, out, skip_checks, max_grid)
    612 mat2 = tile_to_blocksize(mat2, (k, blocksize[1]))
    613 tile_k = max(*blocksize)
--> 615 _run_sampled_addmm_kernel(
    616     alpha, beta, beta == 0.0,
    617     blocksize, k, tile_k,
    618     values, crow_indices, col_indices,
    619     mat1, mat2,
    620     max_grid
    621 )
    623 # If nnz x block strides are not the same in out_backup.values and values,
    624 # it means that out_backup.values and values are not the views of each other,
    625 # so we have to copy.
    626 if out_backup.values().stride()[-3:] != values.stride()[-3:]:

File ~/miniconda3/envs/torch-2-nightly-oct-2023/lib/python3.10/site-packages/torch/sparse/_triton_ops.py:544, in _run_sampled_addmm_kernel(alpha, beta, is_beta_zero, blocksize, k, tile_k, values, crow_indices, col_indices, mat1, mat2, max_grid)
    533 def kernel(grid, *sliced_tensors):
    534     _sampled_addmm_kernel[grid](
    535         alpha, beta, is_beta_zero,
    536         *blocksize, k, tile_k,
   (...)
    541         num_warps=4
    542     )
--> 544 launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks)

File ~/miniconda3/envs/torch-2-nightly-oct-2023/lib/python3.10/site-packages/torch/sparse/_triton_ops.py:151, in launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks)
    146     grid_blocks = tuple(
    147         valid_grid_dim(g, mg) for g, mg in zip(grid_blocks, cuda_max_grid)
    148     )  # type: ignore[assignment]
    150 for grid, *sliced_tensors in grid_partitioner(full_grid, grid_blocks, tensor_dims_map):
--> 151     kernel(grid, *sliced_tensors)

File ~/miniconda3/envs/torch-2-nightly-oct-2023/lib/python3.10/site-packages/torch/sparse/_triton_ops.py:534, in _run_sampled_addmm_kernel.<locals>.kernel(grid, *sliced_tensors)
    533 def kernel(grid, *sliced_tensors):
--> 534     _sampled_addmm_kernel[grid](
    535         alpha, beta, is_beta_zero,
    536         *blocksize, k, tile_k,
    537         *ptr_stride_extractor(*sliced_tensors),
    538         acc_dtype=acc_dtype,
    539         allow_tf32=allow_tf32,
    540         num_stages=1,
    541         num_warps=4
    542     )

File :74, in _sampled_addmm_kernel(alpha, beta, IS_BETA_ZERO, BLOCKSIZE_ROW, BLOCKSIZE_COL, k, TILE_K, values_ptr, values_batch_stride, values_nnz_stride, values_row_block_stride, values_col_block_stride, crow_indices_ptr, crow_indices_batch_stride, crow_indices_stride, col_indices_ptr, col_indices_batch_stride, col_indices_stride, mat1_ptr, mat1_batch_stride, mat1_tiled_row_stride, mat1_tiled_col_stride, mat1_row_block_stride, mat1_col_block_stride, mat2_ptr, mat2_batch_stride, mat2_tiled_row_stride, mat2_tiled_col_stride, mat2_row_block_stride, mat2_col_block_stride, acc_dtype, allow_tf32, grid, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, stream, warmup, device, device_type)

File ~/miniconda3/envs/torch-2-nightly-oct-2023/lib/python3.10/site-packages/triton/compiler/compiler.py:566, in compile(fn, **kwargs)
    564 path = metadata_group.get(ir_filename)
    565 if path is None:
--> 566     next_module = compile_kernel(module)
    567     if ir == "amdgcn":
    568         extra_file_name = f"{name}.hsaco_path"

File ~/miniconda3/envs/torch-2-nightly-oct-2023/lib/python3.10/site-packages/triton/compiler/compiler.py:466, in compile.<locals>.<lambda>(src)
    463 stages = dict()
    464 stages["ast"] = (lambda path: fn, None)
    465 stages["ttir"] = (lambda path: parse_mlir_module(path, context),
--> 466                   lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, arch=arch), arch))
    467 stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
    468                    lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue))
    469 stages["llir"] = (lambda path: Path(path).read_text(),
    470                   lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos))

File ~/miniconda3/envs/torch-2-nightly-oct-2023/lib/python3.10/site-packages/triton/compiler/compiler.py:54, in optimize_ttir(mod, arch)
     53 def optimize_ttir(mod, arch):
---> 54     mod = inline_triton_ir(mod)
     55     mod = ttir_compute_capability_rewrite(mod, arch)
     56     pm = ir.pass_manager(mod.context)

File ~/miniconda3/envs/torch-2-nightly-oct-2023/lib/python3.10/site-packages/triton/compiler/compiler.py:38, in inline_triton_ir(mod)
     36 pm.enable_debug()
     37 pm.add_inliner_pass()
---> 38 pm.run(mod)
     39 return mod

RuntimeError: PassManager::run failed

@davidbuterez commented

@nikitaved I was wondering if there are any plans to fix this? Thanks.

@ghwatson commented Nov 27, 2023

I'm just chiming in to confirm your suspicion @davidbuterez. When using sparse masks with Triton, the seq length needs to be a power of 2. I suppose if @nikitaved wanted to fix this, he'd have to wrap the function with some temporary padding? This is what I'll be doing myself with my Q and K before passing it into the triton sdpa.

Also interested in a fix :)

PS: I also notice I can't go higher than blocksize 64 (e.g. to 128), so maybe the power-of-two thing isn't the full picture. I get the same PassManager::run failed error.
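
An editorial sketch of the padding workaround mentioned above, under the assumptions that the inputs have shape (batch, heads, seq, embed), the mask is a dense boolean tensor, the chosen block size divides the padded length, and the constraint really is a power-of-two sequence length; this is not code from the PR:

import math
import torch
from torch.sparse._triton_ops import _scaled_dot_product_attention

def padded_bsr_sdpa(q, k, v, dense_mask, blocksize=64):
    # Pad the sequence length up to the next power of two, run the private
    # BSR SDPA, then slice the result back to the original length.
    s = q.size(-2)
    s_pad = 2 ** math.ceil(math.log2(s))
    pad = s_pad - s

    q_p = torch.nn.functional.pad(q, (0, 0, 0, pad))
    k_p = torch.nn.functional.pad(k, (0, 0, 0, pad))
    v_p = torch.nn.functional.pad(v, (0, 0, 0, pad))

    # Padded keys must not be attended to (False); padded query rows attend
    # to key 0 only, so that no mask row is completely empty.
    m_p = torch.zeros(*dense_mask.shape[:-2], s_pad, s_pad,
                      dtype=torch.bool, device=dense_mask.device)
    m_p[..., :s, :s] = dense_mask
    m_p[..., s:, 0] = True

    bsr_mask = m_p.to_sparse_bsr((blocksize, blocksize))
    out = _scaled_dot_product_attention(q_p, k_p, v_p, attn_mask=bsr_mask)
    return out[..., :s, :]

For the shapes in the example above (sequence length 600), this pads to 1024 and, with blocksize 64, keeps the tiles at a size the kernels were reported to handle.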

@nikitaved (Collaborator, Author) commented

@davidbuterez, @ghwatson, sorry, but I am no longer involved in this work. Check with @amjames or @pearu; they may be able to help.

Labels: ciflow/trunk, Merged, module: sparse, open source, release notes: sparse