Fix Float8Tensor quantize op kernel preference dispatch #2883
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2883
Note: links to docs will display an error until the docs builds have completed. ✅ No failures as of commit 4516f6e with merge base 2a53216. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@@ -256,6 +303,9 @@ def _(func, types, args, kwargs):
        kernel_choice = "fbgemm"
    elif weight_tensor.kernel_preference == KernelPreference.FBGEMM:
        kernel_choice = "fbgemm"
    elif weight_tensor.kernel_preference == KernelPreference.TRITON:
        # no triton gemm op is available, so we'll fallback to torch
        kernel_choice = "torch"
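To make the dispatch in the hunk above concrete, here is a minimal, pure-Python sketch of the gemm kernel selection it describes. The enum values mirror `KernelPreference`, but `choose_gemm_kernel` and the `fbgemm_available` flag are illustrative names, not torchao's actual API.

```python
from enum import Enum

class KernelPreference(str, Enum):
    AUTO = "auto"
    TORCH = "torch"
    FBGEMM = "fbgemm"
    TRITON = "triton"

def choose_gemm_kernel(pref: KernelPreference, fbgemm_available: bool = True) -> str:
    """Sketch of the gemm dispatch: TRITON has no gemm op, so it falls back to torch."""
    if pref == KernelPreference.AUTO:
        # torchao decides; prefer fbgemm when the library is present
        return "fbgemm" if fbgemm_available else "torch"
    if pref == KernelPreference.FBGEMM:
        return "fbgemm"
    if pref == KernelPreference.TRITON:
        # no triton gemm op is available, so we fall back to torch
        return "torch"
    return "torch"
```

Note that under this scheme a user asking for TRITON silently gets a torch gemm, which is exactly the double-duty behavior questioned below.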
also weird: your kernel choice is doing double duty and is now a recipe. that recipe is also not very clear from your initial doc
do you mean kernel choice is used for both quantize and gemm?
what is the initial doc you are referring to?
Yeah exactly. The doc is just the code block plus the KernelPreference docstring
kernel choice is used for both quantize and gemm
yeah, that's a decision we made before; according to Josh there is no need to have a kernel-level choice for now, just to keep things simple.
we did mention this in the KernelPreference doc I think?
IMO it should be:
- fbgemm - use all fbgemm kernels, error out if something is not supported
- torch - use all torch kernels, error out if something is not supported
- auto - torchao decides what to do
we should not use torch kernels in the fbgemm setting, as that is not honoring what the user asked for
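The stricter semantics proposed above can be sketched in a few lines of pure Python. This is illustrative only (the function name and availability flag are made up): honor the user's preference exactly, and raise rather than silently fall back.

```python
from enum import Enum

class KernelPreference(str, Enum):
    AUTO = "auto"
    TORCH = "torch"
    FBGEMM = "fbgemm"

def choose_kernel(pref: KernelPreference, fbgemm_available: bool) -> str:
    """Strict dispatch: fbgemm means fbgemm, torch means torch, auto lets torchao decide."""
    if pref == KernelPreference.FBGEMM:
        if not fbgemm_available:
            # error out instead of quietly using torch kernels
            raise RuntimeError("fbgemm kernels requested but not available")
        return "fbgemm"
    if pref == KernelPreference.TORCH:
        return "torch"
    # AUTO: torchao decides what to do
    return "fbgemm" if fbgemm_available else "torch"
```

The design point is that only AUTO is allowed to make a choice on the user's behalf; the explicit preferences are contracts.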
@@ -30,5 +30,9 @@ class KernelPreference(str, Enum):
    """
    FBGEMM = "fbgemm"

    """Use triton quantize and quantized mm kernels (if available), requires fbgemm_gpu_genai library, if no triton kernel for the quantize op or mm kernel is available, we'll fallback to torch ops
    """
    TRITON = "triton"
this name isn't coherent with the rest of the enum. We already have an FBGEMM option which does not say anything about cutlass vs triton and therefore already includes these kernels. I think you have two options:
- have the fbgemm option pick the best kernel (cutlass vs triton) for the user. I prefer this one.
- make it clear that "FBGEMM" does not mean "FBGEMM", but really means "FBGEMM_CUTLASS", and also add "FBGEMM_TRITON". I don't really like this option.
OK thanks, yeah (1) seems easiest for now, will update to that unless there is a request to distinguish these in the future
if (
    isinstance(granularity, PerTensor)
    and kernel_preference == KernelPreference.FBGEMM
):
Nit: let's xfail this
we are using unittest, seems like we can't do `return unittest.expectedFailure("...")`?

.../ao/test/quantization/quantize_/workflows/float8/test_float8_tensor.py", line 92, in test_fp8_linear_variants
    return unittest.expectedFailure(
  File ".../python3.10/unittest/case.py", line 148, in expectedFailure
    test_item.__unittest_expecting_failure__ = True
AttributeError: 'str' object has no attribute '__unittest_expecting_failure__'

but let me know if there is an example of doing expectedFailure conditionally instead of skipping the entire test
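One common workaround for a conditional expected failure with `unittest` (since `@unittest.expectedFailure` is an unconditional decorator, not a callable you return) is to assert that the known-broken configuration raises. Sketch below; `quantize`, its arguments, and the error type are placeholders standing in for the real op:

```python
import unittest

def quantize(granularity: str, kernel_preference: str) -> str:
    # stand-in for the real op; per-tensor fbgemm is the known-broken case
    if granularity == "per_tensor" and kernel_preference == "fbgemm":
        raise NotImplementedError("per-tensor fbgemm quantize not supported yet")
    return "ok"

class TestFp8LinearVariants(unittest.TestCase):
    def test_known_broken_combo(self):
        # xfail-style: the broken combo must raise; the supported combo must pass.
        # If the combo is later fixed, assertRaises fails and flags the stale xfail.
        with self.assertRaises(NotImplementedError):
            quantize("per_tensor", "fbgemm")
        self.assertEqual(quantize("per_row", "fbgemm"), "ok")
```

This keeps the rest of the parametrized test running while still documenting, and guarding, the failing combination.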
Can you explain specifically what did not work, and why it works after this PR? It would also be good to have a test which fails before this PR and passes after this PR. |
@vkuzo sure, updated the PR summary and added a test for this one and the next PR as well
note that I don't think this is true, all values of
@@ -399,6 +413,29 @@ def test_moe_weight_reshape_ops(self):
        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
        self._test_moe_weight_reshape_ops(config)

    def test_expected_gpu_kernel_fbgemm(self):
I think this test should be together with the other tests we have which check the same thing for other settings of this config, currently in test_affine_quantized_float.py. Can we add a TODO to unify?
yeah I think we can put everything here after we deprecate the AQT path in 9 months
makes sense, it is user facing
Summary:

Previously, if the user specified kernel_preference == "fbgemm", we would still use torch ops like `_choose_scale_float8` and `_quantize_affine_float8` to quantize the high precision Tensor into a float8 Tensor.

This PR makes sure we use fbgemm kernels when kernel_preference is "fbgemm": `torch.ops.triton.quantize_fp8_row` for per-row quantization, and `torch.ops.fbgemm.quantize_fp8_per_tensor` for per-tensor (though `torch.ops.fbgemm.quantize_fp8_per_tensor` has some issues right now, so we'll enable it later once it's fixed).

This doesn't impact BC: old serialized models can still be loaded and run. Fixing the kernel choice for the fbgemm kernel preference means users who requested the FBGEMM kernel preference now actually run the fbgemm quantize op instead of the torch op.

Test Plan:
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_expected_gpu_kernel_fbgemm

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2883, branch: jerryzh168/stack/59
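The routing the summary describes can be sketched as a small pure-Python dispatch. The op names in the returned strings match the summary; the function itself and its string-based arguments are hypothetical, not torchao code:

```python
def choose_quantize_kernel(kernel_preference: str, granularity: str) -> str:
    """Sketch of the fixed quantize-op dispatch for Float8Tensor."""
    if kernel_preference == "fbgemm":
        if granularity == "per_row":
            # maps to torch.ops.triton.quantize_fp8_row
            return "triton.quantize_fp8_row"
        # torch.ops.fbgemm.quantize_fp8_per_tensor has known issues,
        # so per-tensor stays on the torch path until it is fixed
        return "torch._quantize_affine_float8"
    # torch preference (and default): torch quantize ops
    return "torch._quantize_affine_float8"
```

Before this PR, the first branch did not exist, so the fbgemm preference always took the torch path for the quantize op.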