[CPU] Add concat-linear fusion pass for da8w4 #2476
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2476
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 1 Unrelated Failure
As of commit e125d05 with merge base 64c1ce3:
NEW FAILURE - The following job has failed.
BROKEN TRUNK - The following job failed but was present on the merge base.
👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hi @CaoE, could you please also review this PR? I cannot add you as a reviewer.
Hi @jerryzh168, could you please review this PR? It adds new Inductor passes, and we would like to hear your suggestions on where to put the code. Thanks.
```python
def register_da8w4_concat_linear_cpu_pass():
    from torch._inductor import config as inductor_config

    inductor_config.post_grad_custom_post_pass = _concat_linear_dq8w4_cpu
```
I found that we can get the gm via graph.owning_module, so we can use post_grad_custom_post_pass to apply the pass. Then we don't need to use the register_backend_for_device API. Thanks.
But this one may silently cause a conflict, right?
Yes, exactly. We need to extend the passes to lists in PyTorch.
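The conflict discussed here comes from the hook being a single slot: a second assignment to `post_grad_custom_post_pass` silently overwrites whatever was registered before. A minimal pure-Python sketch of the failure mode and a chaining workaround (`FakeInductorConfig` is a hypothetical stand-in for `torch._inductor.config`, and the "graph" is just a list that passes append tags to):

```python
# Hypothetical stand-in for torch._inductor.config: one slot for a custom pass.
class FakeInductorConfig:
    post_grad_custom_post_pass = None

def pass_a(graph):
    graph.append("a")

def pass_b(graph):
    graph.append("b")

# Naive registration: the second assignment silently drops pass_a.
config = FakeInductorConfig()
config.post_grad_custom_post_pass = pass_a
config.post_grad_custom_post_pass = pass_b

graph = []
config.post_grad_custom_post_pass(graph)
print(graph)  # ['b'] -- pass_a was lost

# Chaining workaround: wrap whatever pass was registered before.
def register(cfg, new_pass):
    prev = cfg.post_grad_custom_post_pass
    def chained(g):
        if prev is not None:
            prev(g)
        new_pass(g)
    cfg.post_grad_custom_post_pass = chained

cfg2 = FakeInductorConfig()
register(cfg2, pass_a)
register(cfg2, pass_b)

graph2 = []
cfg2.post_grad_custom_post_pass(graph2)
print(graph2)  # ['a', 'b'] -- both passes run, in registration order
```

Chaining avoids the silent overwrite, but every out-of-tree registrar must cooperate, which is why a first-class list in Inductor is the cleaner fix.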
I still feel that extending the current design to avoid this conflict would be a better solution. Let's at least add some notes about the potential conflict.
Hi @jansel @eellison, I think we need to extend the custom passes in Inductor (to make them lists). I will probably submit a PR for it. The custom pass support added by pytorch/pytorch#154841 does not meet our needs. Do you have comments? Thanks.
List registration makes sense to me.
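The list-based registration proposed above could look roughly like the following pure-Python sketch. All names here (`post_grad_custom_post_passes`, `register_post_grad_pass`, `run_post_grad_passes`) are hypothetical; the actual PyTorch change may differ:

```python
# Hypothetical sketch of list-based custom-pass registration, so multiple
# out-of-tree passes can coexist without overwriting each other.
post_grad_custom_post_passes = []  # stand-in for an Inductor config field

def register_post_grad_pass(fn):
    """Append a pass instead of assigning to a single slot."""
    post_grad_custom_post_passes.append(fn)
    return fn

def run_post_grad_passes(graph):
    # Inductor would call this once after its built-in post-grad passes.
    for p in post_grad_custom_post_passes:
        p(graph)

@register_post_grad_pass
def concat_linear_pass(graph):
    graph.append("concat_linear")

@register_post_grad_pass
def some_other_pass(graph):
    graph.append("other")

graph = []
run_post_grad_passes(graph)
print(graph)  # ['concat_linear', 'other']
```

With a list, registration order defines execution order and no registrar can clobber another, at the cost of a slightly wider config surface.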
Thanks
LGTM. @Xia-Weiwen, I sent you a message in Slack wondering if you want to migrate the CPU-related stuff to the experimental folder, to be more consistent with the rest of the CPU kernels: https://github.com/pytorch/ao/tree/main/torchao/experimental
Thanks for reviewing and sorry for the late reply.
@pytorchbot merge
This PR has pending changes requested. Please address the comments and update the PR before merging.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed, first few of them are: Run TorchAO Experimental Tests / test-mps-ops (macos-m1-stable). Details for Dev Infra team: raised by workflow job.
@pytorchbot merge -f "CI failures are unrelated"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Summary
This PR adds a concat-linear fusion pass for da8w4 on CPU. The pass fuses multiple da8w4 linear ops that share the same input activation into a single da8w4 linear op with concatenated weights.
The fusion pass is registered as a custom post_grad pass in Inductor and takes effect only when
torch._inductor.config.cpp.enable_concat_linear
is true. Benchmarks show that the total CPU time of linear ops is reduced by >5% with concat-linear when running Llama3.1-8B with 32 cores on a 6th-gen Intel(R) Xeon(R).
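The idea behind concat-linear fusion can be sketched in plain Python, leaving aside the da8w4 quantization details (float weights are used here purely for clarity): several linear ops that consume the same activation are replaced by one GEMM over concatenated weights, whose output is then split.

```python
# Plain-Python sketch of concat-linear fusion. The actual pass operates on
# da8w4-quantized linear ops; float math is used here only to show the shape
# algebra.

def matmul(x, w_t):
    # x: [m][k], w_t: [k][n] -> [m][n]
    return [[sum(x[i][p] * w_t[p][j] for p in range(len(w_t)))
             for j in range(len(w_t[0]))] for i in range(len(x))]

def transpose(w):
    return [list(col) for col in zip(*w)]

x = [[1.0, 2.0], [3.0, 4.0]]    # shared activation, shape [m=2, k=2]
w1 = [[1.0, 0.0], [0.0, 1.0]]   # linear 1 weight, shape [n1=2, k=2]
w2 = [[2.0, 2.0]]               # linear 2 weight, shape [n2=1, k=2]

# Unfused: two separate linears over the same input.
y1 = matmul(x, transpose(w1))
y2 = matmul(x, transpose(w2))

# Fused: concatenate weights along the output dim, run one GEMM, then split.
w_cat = w1 + w2                      # shape [n1 + n2, k]
y_cat = matmul(x, transpose(w_cat))
y1_f = [row[:2] for row in y_cat]    # first n1 output columns
y2_f = [[row[2]] for row in y_cat]   # remaining n2 output columns

print(y1 == y1_f and y2 == y2_f)  # True: fused outputs match the unfused ones
```

The single larger GEMM amortizes weight loads and dispatch overhead across what were separate small matmuls, which is where the reported >5% reduction in linear CPU time comes from.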
Test plan