-
Notifications
You must be signed in to change notification settings - Fork 21.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SDPA] implement ordering #91362
[SDPA] implement ordering #91362
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/91362
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 FailuresAs of commit 6b22fa1: FLAKY - The following jobs failed but were likely due to flakiness present on master:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
@pytorchbot rebase |
@pytorchbot successfully started a rebase job. Check the current status here |
Successfully rebased |
b281936
to
44f8382
Compare
Profiling the composite MHA with head_dim = 128 Time for flash_attention: Time for efficient_attention: |
In an extreme case with: Flash time: Efficient-attention time: |
namespace sdp { | ||
|
||
constexpr int32_t num_backends = 3; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems like a disturbingly generic name even if it's within the sdp namespace.
@@ -29,6 +30,46 @@ struct sdp_params { | |||
bool is_causal; | |||
}; | |||
|
|||
inline std::array<SDPBackend, num_backends> priority_order(sdp_params params) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Another way to implement this (and I think it's kind of what this is), is to modify use_flash_attention
and use_mem_efficient_attention
to return an integer or to create new functions that return integers.
These integers then are the estimated number of operations performed by the respective fused kernel. This is similar to estimate_matmul_time
.
You then pick the one that returns the lowest number of operations. And if the number of operations is negative, well then the kernel doesn't apply.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems fine for now, but I think generalizing this to a cost model would be more future proof.
@pytorchbot merge -l |
Merge startedThe Your change will be merged once all checks on your PR pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: This PR is too stale; the last push date was more than 3 days ago. Please rebase and try again. You can rebase by leaving the following comment on this PR: Details for Dev Infra teamRaised by workflow job |
@pytorchbot rebase |
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: This PR is too stale; the last push date was more than 3 days ago. Please rebase and try again. You can rebase by leaving the following comment on this PR: Details for Dev Infra teamRaised by workflow job |
@pytorchbot successfully started a rebase job. Check the current status here |
Successfully rebased |
44f8382
to
6b22fa1
Compare
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: 2 additional jobs have failed, first few of them are: trunk ,trunk / linux-focal-rocm5.3-py3.8 / test (default, 2, 2, linux.rocm.gpu) Details for Dev Infra teamRaised by workflow job |
@pytorchbot merge -f "unrelated failure" |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Summary
In some cases, dependent on input, flash-attention is not the fastest fused kernel and memory-efficient attention is better. This implements a simple heuristic function for deciding the ordering of kernel functions. This was based off of the xformer function found here: https://github.com/fairinternal/xformers/blob/15bff4986c3a4376176a4e6fa3dc0f2a120fa0bb/xformers/ops/fmha/dispatch.py#L13
cc @ngimel