Skip to content

[Inductor] Support scaled mm on inductor #2411

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from

Conversation

shiyang-weng
Copy link
Contributor

@shiyang-weng shiyang-weng commented Jun 19, 2025

Fuse following pattern to scaled_mm

    #   + - - - - | - - - - - -  | - - - -  +
    #   |    dq_per_tensor  dq_per_tensor   |
    #   |         |              |          |
    #   |    OPT(to_bf16)    OPT(to_bf16)   |
    #   |         |             |           |
    #   |    OPT(reshape)     permute       |
    #   |          \           /            |
    #   |             addmm/mm              |
    #   |                |                  |
    #   |      OPT(quant_per_tensor)        |
    #   |                |                  |
    #   |          OPT(reshape)             |

Copy link

pytorch-bot bot commented Jun 19, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2411

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 0c7f8ea with merge base 8b57afe (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@shiyang-weng shiyang-weng marked this pull request as draft June 19, 2025 01:48
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 19, 2025
@@ -392,5 +392,59 @@ def test_dynamic_scale_numeric_parity(
assert torch.equal(float8_eager._data, float8_compile._data)


@pytest.mark.parametrize(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this is the training float8 test file, float8 inference is using https://github.com/pytorch/ao/blob/main/test/dtypes/test_affine_quantized_float.py

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this is the training float8 test file, float8 inference is using https://github.com/pytorch/ao/blob/main/test/dtypes/test_affine_quantized_float.py

Ok. I change the ut path on last pr #2379

Copy link
Collaborator

@Xia-Weiwen Xia-Weiwen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. nit: This PR adds a fusion pass for fp8 q-dq-linear, not scaled_mm. scaled_mm is the fusion result. Please update the PR title.

@parametrize("dtype", [torch.float32, torch.bfloat16])
@parametrize("input_dim_exceeds_two", [True, False])
@parametrize("check_reuse_input", [True, False])
def test_scaled_mm(self, has_bias, dtype, input_dim_exceeds_two, check_reuse_input):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be better to call it test_fp8_qlinear

return dequant_fp8_linear_bias_pattern, dequant_fp8_linear_no_bias_pattern


def _is_valid_scaled_mm_pattern(dtype, input_dim_exceeds_two):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The pattern is fp8 qlinear, not scaled_mm. scaled_mm is the fusion result. So, better we call it fp8_qlinear_pattern

return _inner


def _register_scaled_mm_pass(pattern, dtype, input_dim_exceeds_two):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here. scaled_mm -> fp8_qlinear.

counters["inductor"]["scaled_mm_matcher_nodes"] += len(match.nodes)


def _register_scaled_mm():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here. scaled_mm -> fp8_qlinear.

@Xia-Weiwen Xia-Weiwen added the topic: not user facing Use this tag if you don't want this PR to show up in release notes label Jun 25, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. topic: not user facing Use this tag if you don't want this PR to show up in release notes
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants