
[CPU] Add concat-linear fusion pass for da8w4 #2476


Closed · 14 commits

Conversation

@Xia-Weiwen Xia-Weiwen commented Jul 2, 2025

Summary
This PR adds a concat-linear fusion pass for da8w4 on CPU. The pass fuses the following pattern

    da8w4_linear_cpu(x, ..., w1, ...) -- y1
  /
x -- da8w4_linear_cpu(x, ..., w2, ...) -- y2
  \ ...
    da8w4_linear_cpu(x, ..., wN, ...) -- yN

to

x -- da8w4_linear_cpu(x, ..., w_concat, ...) -- y_concat -- split -- (y1, y2, ..., yN)
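The rewrite is numerically exact: concatenating the weight matrices along the output dimension and splitting the fused output along that same dimension reproduces each yi. A minimal plain-Python sketch of the idea (toy float matrices, no quantization, no torch; `matmul` is a naive helper written here for illustration, not part of the PR):

```python
# Sketch: N separate linears vs. one concat-linear followed by a split.

def matmul(x, w):
    # x: (m, k) rows, w: (k, n) columns -- naive matrix multiply.
    return [[sum(x[i][t] * w[t][j] for t in range(len(w)))
             for j in range(len(w[0]))] for i in range(len(x))]

x = [[1.0, 2.0]]                # one input row, k = 2
w1 = [[1.0], [0.0]]             # (2, 1)
w2 = [[0.0, 1.0], [1.0, 0.0]]   # (2, 2)

# Separate linears (the pattern before fusion).
y1, y2 = matmul(x, w1), matmul(x, w2)

# Concatenated weight: (2, 3); one matmul, then split along columns.
w_cat = [r1 + r2 for r1, r2 in zip(w1, w2)]
y_cat = matmul(x, w_cat)
y1_split = [row[:1] for row in y_cat]
y2_split = [row[1:] for row in y_cat]

assert y1_split == y1 and y2_split == y2
```

The performance win comes from replacing N small GEMMs over the same activation with one larger GEMM, not from any numerical change.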

The fusion pass is registered as a custom `post_grad` pass in Inductor. It takes effect only when `torch._inductor.config.cpp.enable_concat_linear` is `True`.
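From the user's side, opting in looks roughly like the following config fragment (a hedged sketch: the flag name is taken from this description, while `model` and `x` are hypothetical placeholders for a da8w4-quantized module and its input):

```python
import torch

# Opt in to the CPU concat-linear fusion described in this PR.
torch._inductor.config.cpp.enable_concat_linear = True

# `model` and `x` are placeholders; the pass fires during Inductor's
# post-grad graph passes when the compiled model runs on CPU.
compiled = torch.compile(model)
y = compiled(x)
```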

Benchmarks show that the total CPU time spent in linear layers is reduced by more than 5% with concat-linear fusion when running Llama3.1-8B with 32 cores on a 6th-gen Intel(R) Xeon(R) processor.

Test plan

pytest test/quantization/test_da8w4_cpu.py -k test_8da4w_concat_linear_cpu


pytorch-bot bot commented Jul 2, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2476

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Unrelated Failure

As of commit e125d05 with merge base 64c1ce3:

NEW FAILURE - The following job has failed:

BROKEN TRUNK - The following job failed but was already failing on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jul 2, 2025
@Xia-Weiwen Xia-Weiwen added the topic: not user facing Use this tag if you don't want this PR to show up in release notes label Jul 2, 2025
@Xia-Weiwen
Collaborator Author

Hi @CaoE Could you please also review this PR? I cannot add you as a reviewer.

@Xia-Weiwen Xia-Weiwen added the cpu label Jul 2, 2025
@Xia-Weiwen Xia-Weiwen requested a review from jerryzh168 July 4, 2025 09:27
@Xia-Weiwen
Collaborator Author

Hi @jerryzh168 Could you please review this PR? This PR adds new Inductor passes and we would like to hear your suggestions on where to put the code. Thanks.

    def register_da8w4_concat_linear_cpu_pass():
        from torch._inductor import config as inductor_config

        inductor_config.post_grad_custom_post_pass = _concat_linear_dq8w4_cpu
Collaborator

> I found that we can get `gm` via `graph.owning_module`, so we can use `post_grad_custom_post_pass` to apply the pass. Then we don't need to use the `register_backend_for_device` API. Thanks.

But this one may silently cause a conflict, right?

Collaborator Author

Yes, exactly. We need to extend the passes to lists in PyTorch.


Collaborator

> Yes, exactly. We need to extend the passes to lists in PyTorch.

Still, it feels like extending the current design to avoid this conflict would be a better solution. Let's at least add a note about the potential conflict.
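Besides adding a note, a conservative interim workaround (until Inductor accepts a list of passes) would be to chain the new pass with whatever is already registered instead of overwriting the slot. A sketch with stub passes (pure Python; `chain_passes` and the pass names are hypothetical illustrations, not the PR's code):

```python
def chain_passes(*passes):
    """Compose graph passes so registering a new one doesn't drop the old."""
    def combined(graph):
        for p in passes:
            if p is not None:   # skip an empty slot
                p(graph)
    return combined

# Stand-ins for real FX graph passes; each records that it ran.
calls = []

def existing_pass(graph):
    calls.append("existing")

def concat_linear_pass(graph):
    calls.append("concat_linear")

# Instead of overwriting the single post-pass slot, wrap both.
post_grad_custom_post_pass = chain_passes(existing_pass, concat_linear_pass)
post_grad_custom_post_pass(graph=None)

assert calls == ["existing", "concat_linear"]
```

The ordering of the chained passes still matters, which is why first-class list support in Inductor remains the cleaner fix.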

Collaborator Author

Hi @jansel @eellison I think we need to extend the custom passes in Inductor (to make them lists). I will probably submit a PR for it. The custom passes added by pytorch/pytorch#154841 do not meet our needs. Do you have comments? Thanks.

Contributor

List registration makes sense to me.


Collaborator Author

Thanks

@jerryzh168
Contributor

LGTM, @Xia-Weiwen I had a message in slack, wondering if you want to migrate the cpu related stuff to experimental folder to be more consistent with the rest of the CPU kernels: https://github.com/pytorch/ao/tree/main/torchao/experimental

@Xia-Weiwen
Collaborator Author

> LGTM, @Xia-Weiwen I had a message in slack, wondering if you want to migrate the cpu related stuff to experimental folder to be more consistent with the rest of the CPU kernels: https://github.com/pytorch/ao/tree/main/torchao/experimental

Thanks for reviewing and sorry for the late reply.

@Xia-Weiwen Xia-Weiwen requested a review from jerryzh168 July 8, 2025 03:42
@Xia-Weiwen
Collaborator Author

@pytorchbot merge


pytorch-bot bot commented Jul 8, 2025

This PR has pending changes requested. Please address the comments and update the PR before merging.

@Xia-Weiwen
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: Run TorchAO Experimental Tests / test-mps-ops (macos-m1-stable)


@Xia-Weiwen
Collaborator Author

@pytorchbot merge -f "CI failures are unrelated"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.


Labels: CLA Signed · cpu · Merged · topic: not user facing
7 participants