Conversation

jerryzh168
Contributor

@jerryzh168 jerryzh168 commented Aug 26, 2025

Stacked PRs:


Fix Float8Tensor quantize op kernel preference dispatch

Summary:
Previously, if the user specified kernel_preference == "fbgemm", we would still use torch ops like _choose_scale_float8 and _quantize_affine_float8 to quantize the high precision Tensor into a float8 Tensor.

This PR makes sure we use fbgemm kernels when kernel_preference is "fbgemm": torch.ops.triton.quantize_fp8_row for per-row quantization, and torch.ops.fbgemm.quantize_fp8_per_tensor for per-tensor quantization (torch.ops.fbgemm.quantize_fp8_per_tensor has some issues right now, so we'll enable it later once it's fixed).

This has no impact on BC: old serialized models can still be loaded and run. The only change is fixing the kernel choice for the fbgemm kernel preference, which means users who requested the FBGEMM KernelPreference now actually run the fbgemm quantize op instead of the torch op.
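The intended dispatch can be sketched as follows (a minimal illustration, not the actual torchao code; the helper name and the granularity strings are assumptions, and the sketch only returns the name of the op that would be called):

```python
# Hypothetical sketch of the quantize-op dispatch this PR fixes.
def choose_quantize_op(kernel_preference: str, granularity: str) -> str:
    if kernel_preference == "fbgemm":
        if granularity == "per_row":
            # fbgemm preference + per-row: use the triton kernel
            # shipped with fbgemm_gpu_genai
            return "torch.ops.triton.quantize_fp8_row"
        # per-tensor: torch.ops.fbgemm.quantize_fp8_per_tensor has
        # known issues right now, so fall back to the torch op for now
        return "_quantize_affine_float8"
    # default: torch ops (_choose_scale_float8 + _quantize_affine_float8)
    return "_quantize_affine_float8"
```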

Test Plan:
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_expected_gpu_kernel_fbgemm

Reviewers:

Subscribers:

Tasks:

Tags:


pytorch-bot bot commented Aug 26, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2883

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 4516f6e with merge base 2a53216:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Aug 26, 2025
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/59 branch from 6935cc8 to bacbe8c Compare August 26, 2025 22:22
@jerryzh168 jerryzh168 requested review from vkuzo and drisspg August 26, 2025 22:23
@jerryzh168 jerryzh168 added the topic: improvement Use this tag if this PR is an improvement (doesn't fit into any of the other categories) label Aug 26, 2025
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/59 branch 2 times, most recently from 6cf26bd to 815a964 Compare August 27, 2025 23:47
@jerryzh168 jerryzh168 changed the title Fix Float8Tensor quantize op kernrel preference dispatch Fix Float8Tensor quantize op kernel preference dispatch Aug 28, 2025
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/59 branch from 815a964 to a78fc11 Compare August 28, 2025 00:19
@jerryzh168 jerryzh168 changed the title Fix Float8Tensor quantize op kernel preference dispatch Fix Float8Tensor quantize op kernrel preference dispatch Aug 28, 2025
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/59 branch 4 times, most recently from be69537 to 5f6ec32 Compare August 28, 2025 00:46
@jerryzh168 jerryzh168 requested review from vkuzo and drisspg August 28, 2025 01:43
@@ -256,6 +303,9 @@ def _(func, types, args, kwargs):
kernel_choice = "fbgemm"
elif weight_tensor.kernel_preference == KernelPreference.FBGEMM:
kernel_choice = "fbgemm"
elif weight_tensor.kernel_preference == KernelPreference.TRITON:
# no triton gemm op is available, so we'll fallback to torch
kernel_choice = "torch"
Contributor

Also weird: your kernel choice is doing double duty and is now a recipe. That recipe is also not very clear from your initial doc.

Contributor Author

@jerryzh168 jerryzh168 Aug 28, 2025

do you mean kernel choice is used for both quantize and gemm?

what is the initial doc you are referring to?

Contributor

Yeah exactly. The doc is just the code block and reading the kernel choice docstring.

Contributor Author

@jerryzh168 jerryzh168 Aug 28, 2025

kernel choice is used for both quantize and gemm

Yeah, that's a decision we made before; according to Josh there is no need to have a kernel-level choice for now, just to keep things simple.

We did mention this in the KernelPreference doc, I think?

Contributor

IMO it should be:

  • fbgemm - use all fbgemm kernels, error out if something is not supported
  • torch - use all torch kernels, error out if something is not supported
  • auto - torchao decides what to do

we should not use torch kernels in the fbgemm setting, as that is not honoring what the user asked for
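A minimal sketch of these proposed strict semantics (the function name and the availability flag are hypothetical, not torchao's actual implementation):

```python
from enum import Enum

class KernelPreference(str, Enum):
    AUTO = "auto"
    TORCH = "torch"
    FBGEMM = "fbgemm"

def resolve_kernel(preference: KernelPreference, fbgemm_available: bool) -> str:
    # fbgemm: fbgemm kernels only; error out instead of silently falling back
    if preference is KernelPreference.FBGEMM:
        if not fbgemm_available:
            raise RuntimeError(
                "KernelPreference.FBGEMM requested but fbgemm kernels are unavailable"
            )
        return "fbgemm"
    # torch: torch kernels only
    if preference is KernelPreference.TORCH:
        return "torch"
    # auto: torchao decides, e.g. prefer fbgemm when present
    return "fbgemm" if fbgemm_available else "torch"
```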

@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/59 branch 2 times, most recently from 8ea051d to 74dd7dd Compare August 28, 2025 04:15
@@ -30,5 +30,9 @@ class KernelPreference(str, Enum):
"""
FBGEMM = "fbgemm"

"""Use triton quantize and quantized mm kernels (if available), requires fbgemm_gpu_genai library, if no triton kernel for the quantize op or mm kernel is available, we'll fallback to torch ops
"""
TRITON = "triton"
Contributor

this name isn't coherent with the rest of the enum. We already have an FBGEMM option which does not say anything about cutlass vs triton and therefore already includes these kernels. I think you have two options:

  1. have the fbgemm option pick the best kernel (cutlass vs triton) for the user. I prefer this one.
  2. make it clear that "FBGEMM" does not mean "FBGEMM", but really means "FBGEMM_CUTLASS", and also add "FBGEMM_TRITON". I don't really like this option.

Contributor Author

OK thanks, yeah, option 1 seems easiest for now; will update to that unless there is a request to distinguish these in the future.

@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/59 branch from 74dd7dd to 0b2ab3e Compare August 28, 2025 17:13
@jerryzh168 jerryzh168 requested a review from vkuzo August 28, 2025 18:52
@jerryzh168 jerryzh168 requested a review from drisspg August 28, 2025 18:52
if (
isinstance(granularity, PerTensor)
and kernel_preference == KernelPreference.FBGEMM
):
Contributor

Nit: let's xfail this

Contributor Author

we are using unittest; it seems we can't do return unittest.expectedFailure("...")?

.../ao/test/quantization/quantize_/workflows/float8/test_float8_tensor.py", line 92, in test_fp8_linear_variants
    return unittest.expectedFailure(
  File ".../python3.10/unittest/case.py", line 148, in expectedFailure
    test_item.__unittest_expecting_failure__ = True
AttributeError: 'str' object has no attribute '__unittest_expecting_failure__'

but let me know if there is a way to apply expectedFailure conditionally instead of skipping the entire test
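For reference, unittest.expectedFailure is a decorator that receives the test function itself, not a message string, which is what the AttributeError above is complaining about. One way to apply it conditionally at decoration time (a sketch; the expected_failure_if helper is made up, not a unittest API):

```python
import unittest

def expected_failure_if(condition: bool):
    """Apply unittest.expectedFailure only when condition holds."""
    def decorator(test_func):
        return unittest.expectedFailure(test_func) if condition else test_func
    return decorator

class Example(unittest.TestCase):
    @expected_failure_if(True)
    def test_known_broken(self):
        # fails, but gets reported as an expected failure, not an error
        self.assertEqual(1, 2)
```

Note this decides at class-definition time; for conditions that depend on values only known inside the test body, raising self.skipTest(...) or asserting the failure explicitly with self.assertRaises(...) are the usual alternatives.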

@vkuzo
Contributor

vkuzo commented Aug 28, 2025

Previously we didn't handle kernel_preference == "fbgemm" properly for the quantize op,
this PR makes sure we dispatch to fbgemm kernels when kernel_preference is fbgemm

Can you explain specifically what did not work, and why it works after this PR? It would also be good to have a test which fails before this PR and passes after this PR.

@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/59 branch 3 times, most recently from a1f4504 to a73fa51 Compare August 28, 2025 20:28
@jerryzh168
Contributor Author

@vkuzo sure, updated the PR summary and added a test for this one and the next PR as well

@vkuzo
Contributor

vkuzo commented Aug 28, 2025

fixing the kernel choice for fbgemm kernel preference, which is supposed to be a developer facing API

note that I don't think this is true, all values of KernelPreference are user facing

@@ -399,6 +413,29 @@ def test_moe_weight_reshape_ops(self):
config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
self._test_moe_weight_reshape_ops(config)

def test_expected_gpu_kernel_fbgemm(self):
Contributor

I think this test should be together with the other tests we have which check the same thing for other settings of this config, currently in test_affine_quantized_float.py. Can we add a TODO to unify?

Contributor Author

yeah I think we can put everything here after we deprecate the AQT path in 9 months

@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/59 branch from a73fa51 to f685f8b Compare August 28, 2025 22:17
@jerryzh168
Contributor Author

fixing the kernel choice for fbgemm kernel preference, which is supposed to be a developer facing API

note that I don't think this is true, all values of KernelPreference are user facing

makes sense, it is user facing

@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/59 branch from f685f8b to 8bea88e Compare August 28, 2025 22:35
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/59 branch from 8bea88e to 4516f6e Compare August 28, 2025 22:43
@jerryzh168 jerryzh168 requested a review from vkuzo August 29, 2025 16:56