
[SDPA] implement ordering #91362

Closed · wants to merge 1 commit
Conversation

@drisspg (Contributor) commented Dec 23, 2022

Summary

Depending on the input, flash-attention is not always the fastest fused kernel; in some cases memory-efficient attention is better. This implements a simple heuristic function for deciding the order in which the kernels are tried. It is based on the xformers dispatch function found here: https://github.com/fairinternal/xformers/blob/15bff4986c3a4376176a4e6fa3dc0f2a120fa0bb/xformers/ops/fmha/dispatch.py#L13

cc @ngimel

@pytorch-bot (bot) commented Dec 23, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/91362

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Failure

As of commit 6b22fa1:

FLAKY - The following jobs failed but were likely due to flakiness present on master:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@drisspg (Contributor, author) commented Dec 26, 2022

@pytorchbot rebase

@pytorchmergebot (Collaborator)

@pytorchbot successfully started a rebase job. Check the current status here

@pytorchmergebot (Collaborator)

Successfully rebased heuristic_ordering onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout heuristic_ordering && git pull --rebase)

@drisspg (Contributor, author) commented Dec 29, 2022

Profiling the composite MHA with head_dim = 128:

flash_attention:     34762.25 us
efficient_attention: 32185.51 us

@drisspg (Contributor, author) commented Dec 29, 2022

In an extreme case with:

batch_size = 1
num_heads = 8
max_seq_len = 2048
head_dim = 128

flash_attention:     3734.93 us
efficient_attention:  495.53 us

namespace sdp {

constexpr int32_t num_backends = 3;
Review comment (Contributor):

This seems like a disturbingly generic name even if it's within the sdp namespace.

@@ -29,6 +30,46 @@ struct sdp_params {
bool is_causal;
};

inline std::array<SDPBackend, num_backends> priority_order(sdp_params params) {
Review comment (Contributor):

Another way to implement this (and I think it's kind of what this is), is to modify use_flash_attention and use_mem_efficient_attention to return an integer or to create new functions that return integers.

These integers then are the estimated number of operations performed by the respective fused kernel. This is similar to estimate_matmul_time.

You then pick the one that returns the lowest number of operations. And if the number of operations is negative, well then the kernel doesn't apply.

@cpuhrsch (Contributor) left a comment:

Seems fine for now, but I think generalizing this to a cost model would be more future proof.

@drisspg (Contributor, author) commented Jan 3, 2023

@pytorchbot merge -l

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jan 3, 2023
@pytorchmergebot (Collaborator)

Merge started

The -l land checks flag is deprecated and no longer needed. Instead we now automatically add the ciflow/trunk label to your PR once it's approved.

Your change will be merged once all checks on your PR pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: This PR is too stale; the last push date was more than 3 days ago. Please rebase and try again. You can rebase by leaving the following comment on this PR:
@pytorchbot rebase

Details for Dev Infra team · raised by workflow job

@drisspg (Contributor, author) commented Jan 3, 2023

@pytorchbot rebase

@drisspg (Contributor, author) commented Jan 3, 2023

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator)

Merge failed

Reason: This PR is too stale; the last push date was more than 3 days ago. Please rebase and try again. You can rebase by leaving the following comment on this PR:
@pytorchbot rebase

Details for Dev Infra team · raised by workflow job

@pytorchmergebot (Collaborator)

@pytorchbot successfully started a rebase job. Check the current status here

@pytorchmergebot (Collaborator)

Successfully rebased heuristic_ordering onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout heuristic_ordering && git pull --rebase)

@drisspg (Contributor, author) commented Jan 3, 2023

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator)

Merge failed

Reason: 2 additional jobs have failed; the first few of them are: trunk, trunk / linux-focal-rocm5.3-py3.8 / test (default, 2, 2, linux.rocm.gpu)

Details for Dev Infra team · raised by workflow job

@drisspg (Contributor, author) commented Jan 3, 2023

@pytorchbot merge -f "unrelated failure"

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@drisspg drisspg changed the title implement ordering [SDPA] implement ordering Jan 10, 2023
@drisspg drisspg added module: performance Issues related to performance, either of kernel code or framework glue module: multi-headed-attention labels Feb 9, 2023
Labels
ciflow/trunk (Trigger trunk jobs on your pull request), Merged, module: multi-headed-attention, module: performance (Issues related to performance, either of kernel code or framework glue)

3 participants