[SDPA] Update SDPA API and make function Public #92189
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
@@ -13943,21 +13943,27 @@
CUDA, NestedTensorCUDA: native_multi_head_attention_cuda
autogen: _native_multi_head_attention.out

# TODO: THIS NEEDS TO BE REMOVED BUT PEOPLE HAVE TRAINED THEIR MODELS WITH THIS OP BUILTIN
- func: _scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool need_attn_weights=False, bool is_causal=False) -> (Tensor, Tensor)
@cpuhrsch I added this back in since your review; it appears some models may have been packaged with this builtin aten op.
Summary
In preparation for the pt 2.0 launch, this PR updates SDPA's API and makes the function a public nn.functional function.
Changes
API
Previously the function signature was:
scaled_dot_product_attention(query, key, value, attn_mask=None, need_attn_weights=False, dropout_p=0.0, is_causal=False) -> (Tensor, Tensor)
Updated signature:
scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) -> Tensor
This PR removes the optional boolean argument need_attn_weights and changes the return type from a tuple of two tensors to a single tensor.
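To make the updated signature concrete, here is a minimal pure-Python sketch of the semantics it describes. It is an illustration only, not the PyTorch implementation: the name `sdpa_reference` is hypothetical, inputs are plain 2-D lists rather than tensors, dropout is omitted (equivalent to `dropout_p=0.0`), and the boolean mask convention (True means "may attend") and the way `attn_mask` and `is_causal` are combined are assumptions of this sketch.

```python
import math


def sdpa_reference(query, key, value, attn_mask=None, is_causal=False):
    """Hypothetical reference for 2-D inputs given as lists of rows.

    query is L x E, key is S x E, value is S x Ev. attn_mask, if given,
    is an L x S list of booleans where True means "may attend" (an
    assumption of this sketch). Only the L x Ev output is returned,
    mirroring the updated single-tensor return type.
    Assumes every query row has at least one allowed key position.
    """
    scale = 1.0 / math.sqrt(len(query[0]))
    output = []
    for i, q in enumerate(query):
        scores = []
        for j, k in enumerate(key):
            allowed = attn_mask is None or attn_mask[i][j]
            if is_causal and j > i:
                allowed = False  # causal: position i ignores later positions
            s = sum(a * b for a, b in zip(q, k)) * scale
            scores.append(s if allowed else float("-inf"))
        # Numerically stable softmax over the scores of this row.
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Weighted sum of value rows; the weights are discarded afterwards.
        row = [
            sum(w * v[c] for w, v in zip(weights, value))
            for c in range(len(value[0]))
        ]
        output.append(row)
    return output
```

Under the old signature a caller would have unpacked two results (output and attention weights); with the new one there is only the output to bind.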
Reasoning:
The main goal of this function is to provide an easy interface for users to call into fused attention kernels (e.g. FlashAttention). The fused kernels do not currently support an arbitrary attn_mask or dropout, but there is a PR to mem-efficient attention to enable these. We want to have the API surface ready for when the backing kernels get updated.
The fused kernels save memory by not materializing the attention weights, and it is unlikely that a fast fused implementation will be able to return them, so we are removing need_attn_weights.
Discussed with folks at FAIR/xFormers, who +1 this API change.
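As a rough illustration of why the weights need not be materialized, the sketch below computes the attention output for one query row with a running (online) softmax, keeping only a running maximum, a running denominator, and an accumulator. This is the general idea used by kernels such as FlashAttention, not their actual implementation; the function name `attend_row_streaming` is hypothetical and the inputs are plain Python lists.

```python
import math


def attend_row_streaming(q, keys, values):
    """Attention output for one query row without storing the weights.

    q is a list of length E, keys is S x E, values is S x Ev.
    Hypothetical helper for illustration only.
    """
    scale = 1.0 / math.sqrt(len(q))
    running_max = float("-inf")
    denom = 0.0
    acc = [0.0] * len(values[0])
    for k, v in zip(keys, values):
        s = sum(a * b for a, b in zip(q, k)) * scale
        new_max = max(running_max, s)
        # Rescale what has been accumulated so far to the new maximum.
        correction = math.exp(running_max - new_max)
        p = math.exp(s - new_max)
        denom = denom * correction + p
        acc = [a * correction + p * x for a, x in zip(acc, v)]
        running_max = new_max
    return [a / denom for a in acc]
```

At no point does the loop hold more than one score at a time, so there is no per-row weight vector left over to return, which is consistent with dropping need_attn_weights from the API.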
Make function Public
In preparation for the pt 2.0 launch, we make the function public so that we can start collecting user feedback.
cc @mcarilli @ptrblck @leslie-fang-intel @jgong5 @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire