[FP8] Extend per-token-group quantization support to QuantFP8 #24342
Conversation
Signed-off-by: Tahsin Tunan <tahsintunan@gmail.com>
Code Review
This pull request successfully refactors FP8 quantization by extending QuantFP8 to support per-token-group quantization, unifying the quantization paths. The changes are well-structured, improving code clarity and maintainability. I've identified one potential issue where a group size of 1 is not handled correctly, which could lead to a runtime error. My review includes a suggestion to fix this.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: Tahsin Tunan <tahsintunan@gmail.com>
Hey @ProExpertProg, it's not ready for review yet. I mainly opened this draft PR for early feedback. I'll address the comments and also add the torch implementation.
Okay, please keep us posted. Really looking forward to this work!
Signed-off-by: Tahsin Tunan <tahsintunan@gmail.com>
This is moving in the right direction, but we should still try to move towards using QuantFP8 instead of a bunch of free function calls scattered around. If you want to address that in a separate PR, that's fine, but then we shouldn't touch the MoE layers in this one at all and just add support for group quant to QuantFP8 here.
Sadly, this refactor is going to be a massive can of worms, but it has to be done. All of these free functions need to become classes; as it stands, we can't pass QuantFP8 instances through.
@ProExpertProg understood. Then I'll just revert the latest changes and keep only the torch impl in this PR. The refactoring seems more involved than I had imagined. This will also keep the PR short and easy to review.
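For context, a rough sketch of the direction being described above (layers owning a configured QuantFP8 instance instead of calling free quantization functions). The class name, import paths, and the QuantFP8 constructor/return signature here are assumptions for illustration, not actual vLLM code:

```python
import torch

# Hypothetical imports; the exact module paths in vLLM may differ.
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape


class MoEInputQuantSketch(torch.nn.Module):
    """Illustration only: the layer owns a QuantFP8 instance configured once,
    instead of reaching for a free quantization function at call time."""

    def __init__(self, group_shape: GroupShape = GroupShape(1, 128)):
        super().__init__()
        # Group quantization only supports dynamic mode, hence static=False.
        self.input_quant = QuantFP8(static=False, group_shape=group_shape)

    def forward(self, hidden_states: torch.Tensor):
        # Assumed to return (quantized activations, per-group scales).
        x_q, x_scale = self.input_quant(hidden_states)
        return x_q, x_scale
```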
Signed-off-by: Tahsin Tunan <tahsintunan@gmail.com>
Force-pushed from db3dc5e to 2662be1 (compare)
@tahsintunan can you post performance numbers for this? We might need to enable the custom path by default if the torch path is not fast enough.
B200 results
Looks good, just a few nits! I think we can actually extend the current benchmark instead; I can either push it to your branch or open a new PR. Either way, you should remove the benchmark you added.
if padded_dim != hidden_dim:
    padding = padded_dim - hidden_dim
    x = F.pad(x, (0, padding), mode='constant', value=0.0)
I wonder if there's a way to do this without padding - I worry the generated Triton kernel won't be able to eliminate the copy
I'm realizing the padding is unlikely anyway, as the group shape will likely divide the hidden_size.
Yes, padding will only be used for non-standard dimensions, which should be rare.
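To make the condition concrete, here is a minimal sketch of the padding path being discussed (helper and variable names are hypothetical). When the group size divides hidden_dim, which is the common case, no copy is made:

```python
import torch
import torch.nn.functional as F


def pad_to_group_multiple(x: torch.Tensor, group_size: int) -> torch.Tensor:
    """Pad the last dim up to the next multiple of group_size; no-op when it already divides."""
    hidden_dim = x.shape[-1]
    padded_dim = ((hidden_dim + group_size - 1) // group_size) * group_size
    if padded_dim != hidden_dim:  # only for non-standard hidden sizes
        x = F.pad(x, (0, padded_dim - hidden_dim), mode="constant", value=0.0)
    return x


# 4096 is a multiple of 128, so the same tensor comes back untouched.
x = torch.randn(8, 4096)
assert pad_to_group_multiple(x, 128) is x
```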
@tahsintunan do you have a timeline for implementing the feedback? Based on the performance numbers we want to merge this as soon as feasible. I'd also be happy to finish your work and merge (you'll still be a coauthor on the commit) - let me know! Thanks for working on this.
@ProExpertProg I can take care of the remaining items by Monday. If you need this merged sooner, feel free to wrap it up. If not, I'll get it done by Monday.
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
@tahsintunan yeah today or early tomorrow is totally fine! Thanks for taking this on.
Signed-off-by: Tahsin Tunan <tahsintunan@gmail.com>
Signed-off-by: Tahsin Tunan <tahsintunan@gmail.com>
I ran the benchmark on H100 and found it was slower for 1x128 (the case we care about).
# recompile for different shapes
fwd = torch.compile(fn, fullgraph=True, dynamic=False)
Why aren't we compiling with dynamic=True? I don't think we should be targeting shape specialization since we won't use that in practice
In practice we specialize on all shapes except the first dim (num_tokens). The with_dyn_arg marks that shape as dynamic to fully simulate vLLM usage 👍
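A small, self-contained illustration of that idea (not the benchmark's actual helper): compile once, then mark only dim 0 (num_tokens) as dynamic so the hidden size stays specialized:

```python
import torch


def fake_quant(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the quantization op being benchmarked.
    return (x * 2.0).to(torch.float32)


compiled = torch.compile(fake_quant, fullgraph=True)

for num_tokens in (16, 256, 4096):
    x = torch.randn(num_tokens, 4096)
    # Keep hidden_size static; only num_tokens varies without recompilation.
    torch._dynamo.mark_dynamic(x, 0)
    compiled(x)
```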
batch_size_range = [1, 16, 32, 64, 128]
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
hidden_sizes = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
batch_sizes = [1, 16, 32, 64, 128]
group_shapes = [
    GroupShape.PER_TENSOR,
    GroupShape.PER_TOKEN,
    GroupShape(1, 64),
    GroupShape(1, 128),
]
column_major_scales = [True, False]
config_gen = itertools.product(
    group_shapes,
    column_major_scales,
    batch_sizes,
    hidden_sizes,
)
Could we make some of these CLI args? It takes a really long time to run all of this by default.
Can that be a follow-up? It requires reworking the structure a lot (because currently this is passed to the function annotation).
Addressed. Now you should be able to use something like this:
python3 benchmarks/kernels/bench_per_token_quant_fp8.py --hidden-sizes 1024 2048 4096 --batch-sizes 32 --group-sizes 128 --no-column-major
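For reference, a sketch of how such flags might narrow the sweep (flag names follow the command above; defaults and wiring are illustrative, not the benchmark's actual code):

```python
import argparse
import itertools

parser = argparse.ArgumentParser(description="Per-token-group FP8 quant benchmark (sketch)")
parser.add_argument("--hidden-sizes", type=int, nargs="+", default=[1024, 2048, 4096])
parser.add_argument("--batch-sizes", type=int, nargs="+", default=[1, 16, 32, 64, 128])
parser.add_argument("--group-sizes", type=int, nargs="+", default=[64, 128])
parser.add_argument("--no-column-major", action="store_true")
args = parser.parse_args()

# Build (1, g) group shapes plus the scale-layout axis, as in the itertools.product above.
group_shapes = [(1, g) for g in args.group_sizes]
column_major_scales = [False] if args.no_column_major else [True, False]
configs = list(itertools.product(group_shapes, column_major_scales,
                                 args.batch_sizes, args.hidden_sizes))
```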
    assert not static, "Group quantization only supports dynamic mode"
    self.group_size = group_shape.col
else:
    assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR}
Can we add an assert that column_major_scales is False if non-group?
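Pulling the pieces together, the constructor-time checks (including the assert requested here) might look roughly like this; the GroupShape import path and exact attribute names are assumptions, and this is a sketch rather than the merged implementation:

```python
# Illustrative sketch only, not the merged implementation.
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape  # path assumed


def check_quant_config(group_shape: GroupShape, static: bool, column_major_scales: bool):
    if group_shape.is_per_group():
        assert not static, "Group quantization only supports dynamic mode"
        return group_shape.col  # stored as self.group_size
    assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR}
    # The nit above: column-major scales only make sense for group quantization.
    assert not column_major_scales, \
        "column_major_scales is only supported with group quantization"
    return None
```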
x_quant = x_quant.view(-1, padded_dim)
if padded_dim != hidden_dim:
    x_quant = x_quant[..., :hidden_dim]
Should we make sure this is contiguous after stripping padding?
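A tiny illustration of why contiguity matters here: slicing the padding off the last dim leaves a view with the padded stride (the shapes below are hypothetical, and the dtype is irrelevant to the stride argument):

```python
import torch

padded_dim, hidden_dim = 4224, 4160  # e.g. 4160 padded up to the next multiple of 128
x_quant = torch.zeros(8, padded_dim)
stripped = x_quant[..., :hidden_dim]  # drops the padding but keeps stride (4224, 1)

print(stripped.is_contiguous())       # False
stripped = stripped.contiguous()      # materializes a compact copy, as suggested above
print(stripped.is_contiguous())       # True
```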
Signed-off-by: Tahsin Tunan <tahsintunan@gmail.com>
Ran this on a 5090.
Merging for now since it is valid and isn't on by default. Thanks for the nice work!
…roject#24342) Signed-off-by: Tahsin Tunan <tahsintunan@gmail.com> Signed-off-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Luka Govedič <lgovedic@redhat.com>
…roject#24342) Signed-off-by: Tahsin Tunan <tahsintunan@gmail.com> Signed-off-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: charlifu <charlifu@amd.com>
Purpose
Extends QuantFP8 to support per-token-group quantization and adds a torch implementation of the QuantFP8 group quantization. Addresses #24185.
Changes
- Added is_per_tensor(), is_per_token(), is_per_group() helper methods to GroupShape (see the sketch after this list)
- Extended QuantFP8 to support arbitrary group sizes like GroupShape(1, 128)
- Added test_fp8_quant_group.py and benchmark_quantfp8_group.py
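As a sketch of what the new helpers could look like, here is a simplified stand-in for GroupShape; the real class lives in vLLM and may differ in details, only the helper semantics are illustrated:

```python
from typing import NamedTuple


class GroupShape(NamedTuple):
    """Simplified stand-in: (rows, cols) covered by one scale; -1 means the whole dim."""
    row: int
    col: int

    def is_per_tensor(self) -> bool:
        return self.row == -1 and self.col == -1

    def is_per_token(self) -> bool:
        return self.row == 1 and self.col == -1

    def is_per_group(self) -> bool:
        return self.row == 1 and self.col >= 1


GroupShape.PER_TENSOR = GroupShape(-1, -1)
GroupShape.PER_TOKEN = GroupShape(1, -1)

assert GroupShape(1, 128).is_per_group()
assert GroupShape.PER_TOKEN.is_per_token() and not GroupShape.PER_TOKEN.is_per_group()
```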
Test Plan
Tested with the existing test suite:
- tests/kernels/quantization/test_fp8_quant.py
- tests/kernels/quantization/test_fp8_quant_group.py