
[SDPA] Update SDPA API and make function Public #92189

Closed
wants to merge 15 commits

Conversation

drisspg (Contributor) commented Jan 14, 2023

Summary

In preparation for the PyTorch 2.0 launch, this PR updates SDPA's API and makes the function public under nn.functional.

Changes

API

Previously, the function signature was:
scaled_dot_product_attention(query, key, value, attn_mask=None, need_attn_weights=False, dropout_p=0.0, is_causal=False) -> (Tensor, Tensor)
Updated signature:
scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) -> Tensor

This PR removes the optional boolean argument need_attn_weights and changes the return type to a single tensor.
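
For illustration only (the tensor shapes below are arbitrary and not taken from this PR), a minimal sketch of a call against the updated signature:

    import torch
    import torch.nn.functional as F

    # Arbitrary example shapes: (batch, num_heads, seq_len, head_dim)
    query = torch.randn(2, 8, 128, 64)
    key = torch.randn(2, 8, 128, 64)
    value = torch.randn(2, 8, 128, 64)

    # The call now returns only the attention output; there is no second
    # attention-weights tensor to unpack.
    out = F.scaled_dot_product_attention(query, key, value,
                                         attn_mask=None, dropout_p=0.0,
                                         is_causal=True)
    print(out.shape)  # torch.Size([2, 8, 128, 64])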

Reasoning:

The main goal of this function is to provide an easy interface for users to call into fused attention kernels (e.g., FlashAttention). The fused kernels do not currently support an arbitrary attn_mask or dropout, but there is a PR against memory-efficient attention to enable these. We want the API surface to be ready for when the backing kernels are updated.
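
As a rough illustration of steering dispatch toward a fused kernel, the sketch below uses the torch.backends.cuda.sdp_kernel context manager; the exact flag names are an assumption and may differ across PyTorch builds:

    import torch
    import torch.nn.functional as F

    if torch.cuda.is_available():
        # FlashAttention expects half-precision CUDA tensors.
        q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
        k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
        v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

        # Restrict dispatch to the FlashAttention backend; inputs it cannot
        # handle (e.g. an arbitrary attn_mask) will raise an error instead of
        # silently falling back to the unfused math path.
        with torch.backends.cuda.sdp_kernel(enable_flash=True,
                                            enable_math=False,
                                            enable_mem_efficient=False):
            out = F.scaled_dot_product_attention(q, k, v, is_causal=True)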

The fused kernels save memory by never materializing the attention weights, and it is unlikely that a fast fused implementation will ever support returning them, so we are removing the option.
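
For users who still need the attention weights that need_attn_weights used to return, they can be recovered with an unfused reference computation; a minimal sketch (note that this materializes the full seq_len-by-seq_len weight matrix, which is exactly the memory cost the fused kernels avoid):

    import math
    import torch

    def attention_with_weights(query, key, value, is_causal=False):
        # Unfused reference path that also returns the attention weights.
        scale = 1.0 / math.sqrt(query.size(-1))
        scores = query @ key.transpose(-2, -1) * scale
        if is_causal:
            L, S = scores.size(-2), scores.size(-1)
            mask = torch.ones(L, S, dtype=torch.bool, device=scores.device).tril()
            scores = scores.masked_fill(~mask, float("-inf"))
        weights = torch.softmax(scores, dim=-1)
        return weights @ value, weights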

This API change was discussed with folks at FAIR/xFormers, who +1'd it.

Make function Public

In preparation for the PyTorch 2.0 launch, we are making the function public in order to start gathering user feedback.

cc @mcarilli @ptrblck @leslie-fang-intel @jgong5 @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire

pytorch-bot (bot) commented Jan 14, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/92189

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 0587d66:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

drisspg added the ciflow/trunk (Trigger trunk jobs on your pull request) label Jan 20, 2023
facebook-github-bot (Contributor) commented:
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@@ -13943,21 +13943,27 @@
CUDA, NestedTensorCUDA: native_multi_head_attention_cuda
autogen: _native_multi_head_attention.out

# TODO: THIS NEEDS TO BE REMOVED BUT PEOPLE HAVE TRAINED THEIR MODELS WITH THIS OP BUILTIN
- func: _scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool need_attn_weights=False, bool is_causal=False) -> (Tensor, Tensor)
drisspg (Contributor, Author) commented on this diff:

@cpuhrsch I added this back in since your review; it appears some models may have been packaged with this builtin aten op.

drisspg (Contributor, Author) commented Jan 23, 2023

@pytorchbot merge

pytorchmergebot (Collaborator) commented:
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

drisspg added the release notes: nn (release notes category) label Feb 10, 2023
Labels: ciflow/trunk (Trigger trunk jobs on your pull request), Merged, module: amp (automated mixed precision), autocast, module: inductor, release notes: nn (release notes category)