From b0b9d48b3c2c39a1d689f2dc54bb5e6f210b5a80 Mon Sep 17 00:00:00 2001 From: Tahsin Tunan Date: Sat, 6 Sep 2025 00:30:36 +0600 Subject: [PATCH 1/8] add per-token-group quantization support to QuantFP8 Signed-off-by: Tahsin Tunan --- vllm/model_executor/layers/fused_moe/utils.py | 27 +++++-- .../layers/quantization/input_quant_fp8.py | 75 +++++++++++++------ .../layers/quantization/utils/quant_utils.py | 9 +++ 3 files changed, 82 insertions(+), 29 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 1aeb3f92bc3e..138045a88ac7 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -5,15 +5,15 @@ import torch -from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8) +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.int8_utils import ( per_token_group_quant_int8, per_token_quant_int8) from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( quant_dequant_mxfp4) from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( mxfp8_quantize) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import cdiv @@ -122,15 +122,26 @@ def _fp8_quantize( is provided, the output will be blocked. """ if block_shape is None: - # TODO(luka): use QuantFP8 custom op - # https://github.com/vllm-project/vllm/issues/20711 - A, A_scale = ops.scaled_fp8_quant( - A, A_scale, use_per_token_if_dynamic=per_act_token) + if per_act_token: + group_shape = GroupShape.PER_TOKEN + else: + group_shape = GroupShape.PER_TENSOR + + quant_op = QuantFP8(static=(A_scale is not None), + group_shape=group_shape) + A, A_scale = quant_op(A, A_scale) else: assert not per_act_token assert len(block_shape) == 2 _, block_k = block_shape[0], block_shape[1] - A, A_scale = per_token_group_quant_fp8(A, block_k) + + group_shape = GroupShape(1, block_k) + quant_op = QuantFP8( + static=False, # Group quantization is always dynamic + group_shape=group_shape, + column_major_scales=False # Use row-major for MoE + ) + A, A_scale = quant_op(A) assert cdiv(A.size(-1), block_k) == A_scale.size(-1) return A, A_scale diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py index e1a9bdde9334..f104f656de43 100644 --- a/vllm/model_executor/layers/quantization/input_quant_fp8.py +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -23,28 +23,63 @@ @CustomOp.register("quant_fp8") class QuantFP8(CustomOp): """ - Quantize input tensor to per-tensor or per-token FP8. + Quantize input tensor to FP8 (per-tensor, per-token, or per-group). This CustomOp supports both static and dynamic quantization. 
""" def __init__(self, static: bool, group_shape: GroupShape, - num_token_padding: Optional[int] = None): + num_token_padding: Optional[int] = None, + column_major_scales: bool = False): """ - :param static: static or dynamic quantization - :param group_shape: quantization group shape (PER_TOKEN or PER_TENSOR) - :param num_token_padding: Pad the token dimension of output to this size + :param group_shape: quantization group shape (PER_TOKEN, PER_TENSOR, + or arbitrary block size) + :param num_token_padding: Pad the token dimension of output to this + size + :param column_major_scales: For group quantization, output scales in + column major format """ super().__init__() - self.num_token_padding = num_token_padding - assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR} - assert not static or group_shape == GroupShape.PER_TENSOR, \ - "Only per-tensor scales supported for static quantization." self.static = static self.group_shape = group_shape - self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN + self.num_token_padding = num_token_padding + self.column_major_scales = column_major_scales + + self.is_group_quant = group_shape.is_per_group() + if self.is_group_quant: + assert not static, "Group quantization only supports dynamic mode" + self.group_size = group_shape.col + else: + assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR} + assert not static or group_shape == GroupShape.PER_TENSOR, \ + "Only per-tensor scales supported for static quantization." + self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN + + def _quantize_group(self, + x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8) + return per_token_group_quant_fp8( + x, + group_size=self.group_size, + column_major_scales=self.column_major_scales, + dtype=_FP8_DTYPE) + + def _compute_dynamic_scale( + self, x: torch.Tensor, + scale_ub: Optional[torch.Tensor]) -> torch.Tensor: + if self.group_shape == GroupShape.PER_TOKEN: + x_max, _ = x.abs().max(dim=-1) + x_max = x_max.unsqueeze(-1).to(torch.float32) + if scale_ub is not None: + x_max = x_max.clamp(max=scale_ub) + else: + x_max = x.abs().max().unsqueeze(-1).to(torch.float32) + + scale = x_max / _FP8_MAX + return scale.clamp(min=_FP8_MIN_SCALING_FACTOR) def forward_cuda( self, @@ -52,11 +87,14 @@ def forward_cuda( scale: Optional[torch.Tensor] = None, scale_ub: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: + if self.is_group_quant: + assert scale is None, "Group quantization is always dynamic" + return self._quantize_group(x) + assert (scale is not None) == self.static assert scale_ub is None or (not self.static and self.group_shape == GroupShape.PER_TOKEN and scale_ub.numel() == 1) - return ops.scaled_fp8_quant( x, scale, @@ -70,22 +108,17 @@ def forward_native( scale: Optional[torch.Tensor] = None, scale_ub: Optional[torch.Tensor] = None, ): + if self.is_group_quant: + assert scale is None, "Group quantization is always dynamic" + return self._quantize_group(x) + assert (scale is not None) == self.static assert scale_ub is None or (not self.static and self.group_shape == GroupShape.PER_TOKEN and scale_ub.numel() == 1) if scale is None: - if self.group_shape == GroupShape.PER_TOKEN: - x_max, _ = x.abs().max(dim=-1) - x_max = x_max.unsqueeze(-1).to(torch.float32) - if scale_ub is not None: - x_max = x_max.clamp(max=scale_ub) - else: - x_max = x.abs().max().unsqueeze(-1).to(torch.float32) - - scale 
= x_max / _FP8_MAX - scale = scale.clamp(min=_FP8_MIN_SCALING_FACTOR) + scale = self._compute_dynamic_scale(x, scale_ub) # Even for dynamic per-token scales, # reciprocal performs slightly better than division diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index f4ff875adb21..bea9c8d51580 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -34,6 +34,15 @@ class GroupShape(_GroupShape): PER_TENSOR: ClassVar['GroupShape'] PER_TOKEN: ClassVar['GroupShape'] + def is_per_tensor(self) -> bool: + return self.row == -1 and self.col == -1 + + def is_per_token(self) -> bool: + return self.row == 1 and self.col == -1 + + def is_per_group(self) -> bool: + return self.row == 1 and self.col > 1 + GroupShape.PER_TENSOR = GroupShape(-1, -1) GroupShape.PER_TOKEN = GroupShape(1, -1) From 74bd08458c2d5862d9c9bbd41fc7f349bdb97654 Mon Sep 17 00:00:00 2001 From: Tahsin Tunan Date: Sat, 6 Sep 2025 00:39:33 +0600 Subject: [PATCH 2/8] Update vllm/model_executor/layers/quantization/utils/quant_utils.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: Tahsin Tunan --- vllm/model_executor/layers/quantization/utils/quant_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index bea9c8d51580..5339c6043cc1 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -41,7 +41,7 @@ def is_per_token(self) -> bool: return self.row == 1 and self.col == -1 def is_per_group(self) -> bool: - return self.row == 1 and self.col > 1 + return self.row == 1 and self.col >= 1 GroupShape.PER_TENSOR = GroupShape(-1, -1) From b50d1633de776d3abe94d3b20b7a0e8512dfb2fe Mon Sep 17 00:00:00 2001 From: Tahsin Tunan Date: Mon, 8 Sep 2025 05:41:13 +0600 Subject: [PATCH 3/8] Add PyTorch implementation for QuantFP8 group quantization Signed-off-by: Tahsin Tunan --- .../kernels/benchmark_quantfp8_group.py | 148 +++++++++++++ .../quantization/test_fp8_quant_group.py | 206 ++++++++++++++++++ .../layers/fused_moe/fused_moe.py | 4 +- vllm/model_executor/layers/fused_moe/utils.py | 33 +-- .../layers/quantization/input_quant_fp8.py | 60 ++--- .../quantization/utils/fp8_quant_ops.py | 110 ++++++++++ 6 files changed, 501 insertions(+), 60 deletions(-) create mode 100644 benchmarks/kernels/benchmark_quantfp8_group.py create mode 100644 tests/kernels/quantization/test_fp8_quant_group.py create mode 100644 vllm/model_executor/layers/quantization/utils/fp8_quant_ops.py diff --git a/benchmarks/kernels/benchmark_quantfp8_group.py b/benchmarks/kernels/benchmark_quantfp8_group.py new file mode 100644 index 000000000000..d8555a00f824 --- /dev/null +++ b/benchmarks/kernels/benchmark_quantfp8_group.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Benchmark for QuantFP8 Group Quantization implementation.""" + +import argparse + +import torch + +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.platforms import current_platform + + +def _time_cuda( + fn, + warmup_iters: int, 
+ bench_iters: int, +) -> float: + # warmup + for _ in range(warmup_iters): + fn() + torch.cuda.synchronize() + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + for _ in range(bench_iters): + fn() + end.record() + torch.cuda.synchronize() + + return start.elapsed_time(end) / bench_iters # ms/iter + + +def run_benchmark( + shape: tuple[int, int], + group_size: int, + column_major: bool, + warmup_iters: int, + bench_iters: int, +) -> None: + """Benchmark QuantFP8 with group quantization using different backends.""" + num_tokens, hidden_dim = shape + + device = torch.device("cuda") + torch.manual_seed(42) + x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16) * 8 + + group_shape = GroupShape(1, group_size) + quant_op = QuantFP8( + static=False, group_shape=group_shape, column_major_scales=column_major + ) + + def cuda_impl(): + return quant_op.forward_cuda(x.clone()) + + def native_impl(): + return quant_op.forward_native(x.clone()) + + cuda_ms = _time_cuda(cuda_impl, warmup_iters, bench_iters) + native_ms = _time_cuda(native_impl, warmup_iters, bench_iters) + + speedup = cuda_ms / native_ms if native_ms else 0 + + cfg_desc = f"shape={shape} gs={group_size:<3} col_major={column_major}" + print(f"{cfg_desc:45} | {cuda_ms:7.3f} | {native_ms:7.3f} | {speedup:6.2f}x") + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark QuantFP8 group quantization implementation" + ) + parser.add_argument( + "--warmup-iters", type=int, default=10, help="Number of warmup iterations" + ) + parser.add_argument( + "--bench-iters", type=int, default=100, help="Number of benchmark iterations" + ) + parser.add_argument( + "--shapes", + type=str, + default="32,128;64,256;16,512;128,1024;256,2048", + help="Shapes to benchmark as 'tokens,hidden;...' 
(default: multiple shapes)", + ) + parser.add_argument( + "--group-sizes", + type=str, + default="64,128", + help="Group sizes to benchmark (comma-separated)", + ) + parser.add_argument( + "--no-column-major", + action="store_true", + help="Skip column-major scale benchmarks", + ) + return parser.parse_args() + + +def main(): + if not current_platform.is_cuda(): + raise RuntimeError("CUDA device is required to run this benchmark.") + + args = parse_args() + + shapes = [] + for shape_str in args.shapes.split(";"): + tokens, hidden = map(int, shape_str.split(",")) + shapes.append((tokens, hidden)) + + group_sizes = list(map(int, args.group_sizes.split(","))) + + print("\n" + "=" * 80) + print("QuantFP8 Group Quantization Benchmark (CUDA kernel vs PyTorch native)") + print("=" * 80) + print(f"Device: {torch.cuda.get_device_name()}") + print(f"Warmup iterations: {args.warmup_iters}") + print(f"Benchmark iterations: {args.bench_iters}") + print("=" * 80) + + print(f"{'Configuration':45} | {'CUDA':^9} | {'Native':^9} | {'Speedup':^8}") + print("-" * 80) + + for shape in shapes: + for gs in group_sizes: + run_benchmark( + shape, + gs, + column_major=False, + warmup_iters=args.warmup_iters, + bench_iters=args.bench_iters, + ) + + if not args.no_column_major: + run_benchmark( + shape, + gs, + column_major=True, + warmup_iters=args.warmup_iters, + bench_iters=args.bench_iters, + ) + + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/tests/kernels/quantization/test_fp8_quant_group.py b/tests/kernels/quantization/test_fp8_quant_group.py new file mode 100644 index 000000000000..47c877d22731 --- /dev/null +++ b/tests/kernels/quantization/test_fp8_quant_group.py @@ -0,0 +1,206 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for QuantFP8 Group Quantization implementation.""" + +import pytest +import torch + +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) +from vllm.platforms import current_platform + + +@pytest.mark.parametrize("batch_size", [16, 32]) +@pytest.mark.parametrize("hidden_dim", + [256, 512, 513]) # Include non-divisible +@pytest.mark.parametrize("group_size", [32, 64, 128]) +@pytest.mark.parametrize("seed", [42]) +@torch.inference_mode() +def test_quantfp8_group_basic(batch_size: int, hidden_dim: int, + group_size: int, seed: int) -> None: + current_platform.seed_everything(seed) + + x = torch.randn( + (batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8 + + # Create QuantFP8 with group quantization + group_shape = GroupShape(1, group_size) + quant_op = QuantFP8(static=False, + group_shape=group_shape, + column_major_scales=False) + + expected_num_groups = (hidden_dim + group_size - 1) // group_size + + # Test CUDA implementation (only supports divisible dimensions) + if hidden_dim % group_size == 0: + x_quant_cuda, scales_cuda = quant_op.forward_cuda(x.clone()) + assert x_quant_cuda.shape == x.shape + assert scales_cuda.shape == (batch_size, expected_num_groups) + + # Test PyTorch native implementation + x_quant_native, scales_native = quant_op.forward_native(x.clone()) + assert x_quant_native.shape == x.shape + assert scales_native.shape == (batch_size, expected_num_groups) + + # Test column_major_scales + quant_op_col = QuantFP8(static=False, + group_shape=group_shape, + column_major_scales=True) + _, scales_col = quant_op_col.forward_native(x.clone()) + assert 
scales_col.shape == (expected_num_groups, batch_size) + + +@pytest.mark.parametrize("seed", [42]) +@torch.inference_mode() +def test_quantfp8_group_multidimensional(seed: int) -> None: + current_platform.seed_everything(seed) + + group_size = 64 + + # Test with 3D input + batch1, batch2, hidden_dim = 4, 8, 512 + x_3d = torch.randn( + (batch1, batch2, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8 + + group_shape = GroupShape(1, group_size) + quant_op = QuantFP8(static=False, + group_shape=group_shape, + column_major_scales=False) + + x_quant, scales = quant_op.forward_native(x_3d.clone()) + assert x_quant.shape == x_3d.shape + assert scales.shape == (batch1, batch2, hidden_dim // group_size) + + # Test column_major_scales with multi-dim + quant_op_col = QuantFP8(static=False, + group_shape=group_shape, + column_major_scales=True) + _, scales_col = quant_op_col.forward_native(x_3d.clone()) + assert scales_col.shape == (batch1, hidden_dim // group_size, batch2) + + # Test with 4D input + batch1, batch2, batch3, hidden_dim = 2, 3, 4, 256 + x_4d = torch.randn((batch1, batch2, batch3, hidden_dim), + dtype=torch.bfloat16, + device="cuda") * 8 + + x_quant_4d, scales_4d = quant_op.forward_native(x_4d.clone()) + assert x_quant_4d.shape == x_4d.shape + assert scales_4d.shape == (batch1, batch2, batch3, + hidden_dim // group_size) + + _, scales_4d_col = quant_op_col.forward_native(x_4d.clone()) + assert scales_4d_col.shape == (batch1, batch2, hidden_dim // group_size, + batch3) + + +@pytest.mark.parametrize("batch_size", [32]) +@pytest.mark.parametrize("hidden_dim", [1024]) +@pytest.mark.parametrize("group_size", [128]) +@pytest.mark.parametrize("seed", [42]) +@torch.inference_mode() +def test_quantfp8_group_cuda_native_consistency(batch_size: int, + hidden_dim: int, + group_size: int, + seed: int) -> None: + """Compare CUDA and native implementations for consistency.""" + current_platform.seed_everything(seed) + + x = torch.randn( + (batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8 + + group_shape = GroupShape(1, group_size) + quant_op = QuantFP8(static=False, + group_shape=group_shape, + column_major_scales=False) + + # Run both implementations + x_quant_cuda, scales_cuda = quant_op.forward_cuda(x.clone()) + x_quant_native, scales_native = quant_op.forward_native(x.clone()) + + # Check shapes match + assert x_quant_cuda.shape == x_quant_native.shape + assert scales_cuda.shape == scales_native.shape + + # Scales should match + assert torch.allclose(scales_cuda, scales_native, rtol=1e-9, atol=1e-8) + + # Quantized values should mostly match, with rare rounding differences + # FP8 rounding at boundaries can differ between CUDA and PyTorch + diff_count = (x_quant_cuda != x_quant_native).sum().item() + diff_ratio = diff_count / x_quant_cuda.numel() + assert diff_ratio < 0.002, f"Too many differences: {diff_ratio:.4%}" + + +@pytest.mark.parametrize("seed", [42]) +@torch.inference_mode() +def test_quantfp8_group_edge_cases(seed: int) -> None: + current_platform.seed_everything(seed) + + batch_size = 16 + group_size = 64 + + # Test with single group (group_size >= hidden_dim) + x_small = torch.randn( + (batch_size, 32), dtype=torch.bfloat16, device="cuda") * 8 + group_shape = GroupShape(1, group_size) + quant_op = QuantFP8(static=False, + group_shape=group_shape, + column_major_scales=False) + + x_quant_small, scales_small = quant_op.forward_native(x_small.clone()) + assert x_quant_small.shape == x_small.shape + assert scales_small.shape == (batch_size, 1) + + # Test with zero 
inputs + x_zero = torch.zeros((batch_size, 256), + dtype=torch.bfloat16, + device="cuda") + x_quant_zero, scales_zero = quant_op.forward_native(x_zero.clone()) + assert x_quant_zero.shape == x_zero.shape + assert (scales_zero > 0).all(), "Scales should be clamped to minimum" + + # Test very large values + x_large = torch.full((batch_size, 256), + 1000.0, + dtype=torch.bfloat16, + device="cuda") + x_quant_large, scales_large = quant_op.forward_native(x_large.clone()) + assert x_quant_large.shape == x_large.shape + # FP8 max is typically 448 or 224, so scales should be > 1 + assert (scales_large > 1.0).all(), "Large values should have scales > 1" + + +@pytest.mark.parametrize( + "batch_size,hidden_dim,group_size", + [ + (16, 256, 16), # Small + (64, 1024, 64), # Medium + (128, 2048, 128), # Large + (8, 513, 64), # Non-divisible (native only) + ]) +@pytest.mark.parametrize("seed", [42]) +@torch.inference_mode() +def test_quantfp8_group_various_configs(batch_size: int, hidden_dim: int, + group_size: int, seed: int) -> None: + current_platform.seed_everything(seed) + + x = torch.randn( + (batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8 + group_shape = GroupShape(1, group_size) + quant_op = QuantFP8(static=False, + group_shape=group_shape, + column_major_scales=False) + + expected_num_groups = (hidden_dim + group_size - 1) // group_size + + x_quant_native, scales_native = quant_op.forward_native(x.clone()) + assert x_quant_native.shape == x.shape + assert scales_native.shape == (batch_size, expected_num_groups) + + if hidden_dim % group_size == 0: + x_quant_cuda, scales_cuda = quant_op.forward_cuda(x.clone()) + assert x_quant_cuda.shape == x.shape + assert scales_cuda.shape == (batch_size, expected_num_groups) + assert torch.allclose(scales_cuda, scales_native, rtol=1e-9, atol=1e-8) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index eb3e14180ecf..42d75aa3f1eb 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -32,9 +32,11 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP) from vllm.model_executor.layers.fused_moe.utils import ( - _resize_cache, moe_kernel_quantize_input, per_token_group_quant_fp8) + _resize_cache, moe_kernel_quantize_input) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( calculate_tile_tokens_dim) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8) from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( dequant_mxfp4) from vllm.platforms import current_platform diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 138045a88ac7..feeac0e042cf 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -5,15 +5,14 @@ import torch -from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.fp8_quant_ops import ( + quantize_fp8_per_group, quantize_fp8_per_tensor, quantize_fp8_per_token) from vllm.model_executor.layers.quantization.utils.int8_utils import ( per_token_group_quant_int8, per_token_quant_int8) from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( quant_dequant_mxfp4) from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( mxfp8_quantize) -from 
vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import cdiv @@ -123,28 +122,18 @@ def _fp8_quantize( """ if block_shape is None: if per_act_token: - group_shape = GroupShape.PER_TOKEN + return quantize_fp8_per_token(A, A_scale) else: - group_shape = GroupShape.PER_TENSOR - - quant_op = QuantFP8(static=(A_scale is not None), - group_shape=group_shape) - A, A_scale = quant_op(A, A_scale) + return quantize_fp8_per_tensor(A, A_scale) else: - assert not per_act_token - assert len(block_shape) == 2 + assert not per_act_token, \ + "per_act_token not supported with block_shape" + assert A_scale is None, \ + "Group quantization doesn't support static scales" + assert len(block_shape) == 2, "block_shape must be [m, k]" _, block_k = block_shape[0], block_shape[1] - - group_shape = GroupShape(1, block_k) - quant_op = QuantFP8( - static=False, # Group quantization is always dynamic - group_shape=group_shape, - column_major_scales=False # Use row-major for MoE - ) - A, A_scale = quant_op(A) - assert cdiv(A.size(-1), block_k) == A_scale.size(-1) - - return A, A_scale + return quantize_fp8_per_group( + A, block_k, column_major_scales=False) # Use row-major for MoE def _int8_quantize( diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py index f104f656de43..fe5761ad549b 100644 --- a/vllm/model_executor/layers/quantization/input_quant_fp8.py +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -7,6 +7,8 @@ from vllm import _custom_ops as ops from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.quantization.utils.fp8_quant_ops import ( + quantize_fp8_per_group, quantize_fp8_per_tensor, quantize_fp8_per_token) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape) from vllm.platforms import current_platform @@ -14,10 +16,6 @@ # Using the default value (240.0) from pytorch will cause accuracy # issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm. _FP8_DTYPE = current_platform.fp8_dtype() -_FP8_FINFO = torch.finfo(_FP8_DTYPE) -_FP8_MAX = 224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.max -_FP8_MIN = -224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.min -_FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0) @CustomOp.register("quant_fp8") @@ -57,30 +55,6 @@ def __init__(self, "Only per-tensor scales supported for static quantization." 
self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN - def _quantize_group(self, - x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8) - return per_token_group_quant_fp8( - x, - group_size=self.group_size, - column_major_scales=self.column_major_scales, - dtype=_FP8_DTYPE) - - def _compute_dynamic_scale( - self, x: torch.Tensor, - scale_ub: Optional[torch.Tensor]) -> torch.Tensor: - if self.group_shape == GroupShape.PER_TOKEN: - x_max, _ = x.abs().max(dim=-1) - x_max = x_max.unsqueeze(-1).to(torch.float32) - if scale_ub is not None: - x_max = x_max.clamp(max=scale_ub) - else: - x_max = x.abs().max().unsqueeze(-1).to(torch.float32) - - scale = x_max / _FP8_MAX - return scale.clamp(min=_FP8_MIN_SCALING_FACTOR) - def forward_cuda( self, x: torch.Tensor, @@ -89,7 +63,7 @@ def forward_cuda( ) -> tuple[torch.Tensor, torch.Tensor]: if self.is_group_quant: assert scale is None, "Group quantization is always dynamic" - return self._quantize_group(x) + return self._quantize_group_cuda(x) assert (scale is not None) == self.static assert scale_ub is None or (not self.static and self.group_shape @@ -110,20 +84,17 @@ def forward_native( ): if self.is_group_quant: assert scale is None, "Group quantization is always dynamic" - return self._quantize_group(x) + return self._quantize_group_native(x) assert (scale is not None) == self.static assert scale_ub is None or (not self.static and self.group_shape == GroupShape.PER_TOKEN and scale_ub.numel() == 1) - if scale is None: - scale = self._compute_dynamic_scale(x, scale_ub) - - # Even for dynamic per-token scales, - # reciprocal performs slightly better than division - out = x.to(torch.float32) * scale.reciprocal() - out = out.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) + if self.use_per_token_if_dynamic and scale is None: + out, scale = quantize_fp8_per_token(x, scale, scale_ub) + else: + out, scale = quantize_fp8_per_tensor(x, scale) # This currently generates an extra Triton kernel in compilation. # Fortunately, we don't use padding if compiling. 
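For reference, and not part of the patch itself: the group path added in the next hunk computes, for each token, one scale per group_size contiguous channels (the group's absmax divided by the FP8 max, clamped away from zero) and divides the group by that scale before casting. A minimal standalone PyTorch sketch of the same math follows; the helper name is illustrative, and it assumes the last dimension divides evenly by group_size and the non-fnuz e4m3fn limits (the patch substitutes 224.0 on ROCm fnuz):

    import torch

    def fp8_group_quant_sketch(x: torch.Tensor, group_size: int):
        # Illustrative helper, mirroring the per-group math in this series.
        fp8_max = torch.finfo(torch.float8_e4m3fn).max   # 448.0; ROCm fnuz uses 224.0
        min_scale = 1.0 / (fp8_max * 512.0)              # same floor as _FP8_MIN_SCALING_FACTOR
        num_groups = x.shape[-1] // group_size           # assumes even divisibility (the patch pads otherwise)
        xg = x.reshape(-1, num_groups, group_size)
        # One scale per (token, group): absmax / fp8_max, clamped away from zero.
        scales = (xg.abs().max(dim=-1, keepdim=True)[0].float() / fp8_max).clamp(min=min_scale)
        xq = (xg / scales).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
        return xq.reshape(x.shape), scales.squeeze(-1).reshape(*x.shape[:-1], num_groups)

    x_q, s = fp8_group_quant_sketch(torch.randn(16, 512, dtype=torch.bfloat16), 128)
    # x_q: (16, 512) float8_e4m3fn, s: (16, 4) float32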
@@ -134,3 +105,18 @@ def forward_native( out = F.pad(out, (0, 0, 0, padding), "constant", 0.0) return out, scale + + def _quantize_group_cuda( + self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8) + return per_token_group_quant_fp8( + x, + group_size=self.group_size, + column_major_scales=self.column_major_scales, + dtype=_FP8_DTYPE) + + def _quantize_group_native( + self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + return quantize_fp8_per_group(x, self.group_size, + self.column_major_scales) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_quant_ops.py b/vllm/model_executor/layers/quantization/utils/fp8_quant_ops.py new file mode 100644 index 000000000000..be1cf3f85fc3 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/fp8_quant_ops.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch +import torch.nn.functional as F + +from vllm.platforms import current_platform + +_FP8_DTYPE = current_platform.fp8_dtype() +_FP8_FINFO = torch.finfo(_FP8_DTYPE) +_FP8_MAX = 224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.max +_FP8_MIN = -224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.min +_FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0) + + +def quantize_fp8_per_tensor( + x: torch.Tensor, + scale: Optional[torch.Tensor] = None +) -> tuple[torch.Tensor, torch.Tensor]: + """Compute per-tensor FP8 quantization. + + Args: + x: Input tensor to quantize + scale: Optional pre-computed scale (for static quantization) + + Returns: + Quantized tensor and scale + """ + if scale is None: + x_max = x.abs().max().unsqueeze(-1).to(torch.float32) + scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR) + + # Even for dynamic per-token scales, + # reciprocal performs slightly better than division + out = x.to(torch.float32) * scale.reciprocal() + out = out.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) + return out, scale + + +def quantize_fp8_per_token( + x: torch.Tensor, + scale: Optional[torch.Tensor] = None, + scale_ub: Optional[torch.Tensor] = None +) -> tuple[torch.Tensor, torch.Tensor]: + """Compute per-token FP8 quantization. + + Args: + x: Input tensor to quantize + scale: Optional pre-computed scale (for static quantization) + scale_ub: Optional upper bound for scale + + Returns: + Quantized tensor and scale + """ + if scale is None: + x_max, _ = x.abs().max(dim=-1) + x_max = x_max.unsqueeze(-1).to(torch.float32) + if scale_ub is not None: + x_max = x_max.clamp(max=scale_ub) + scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR) + + out = x.to(torch.float32) * scale.reciprocal() + out = out.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) + return out, scale + + +def quantize_fp8_per_group(x: torch.Tensor, + group_size: int, + column_major_scales: bool = False + ) -> tuple[torch.Tensor, torch.Tensor]: + """Compute per-group FP8 quantization. 
+ + Args: + x: Input tensor to quantize + group_size: Size of quantization groups + column_major_scales: If True, output scales in column-major format + + Returns: + Quantized tensor and per-group scales + """ + orig_shape = x.shape + hidden_dim = x.shape[-1] + num_groups = (hidden_dim + group_size - 1) // group_size + padded_dim = num_groups * group_size + + if padded_dim != hidden_dim: + padding = padded_dim - hidden_dim + x = F.pad(x, (0, padding), mode='constant', value=0.0) + + x_grouped = x.view(-1, num_groups, group_size) + absmax = x_grouped.abs().max(dim=-1, keepdim=True)[0].float() + scales = (absmax / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR) + + x_scaled = x_grouped / scales + x_quant = x_scaled.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) + + x_quant = x_quant.view(-1, padded_dim) + if padded_dim != hidden_dim: + x_quant = x_quant[..., :hidden_dim] + x_quant = x_quant.view(orig_shape) + + scales = scales.squeeze(-1) + scales = scales.reshape(orig_shape[:-1] + (num_groups, )) + + if column_major_scales: + scales = scales.transpose(-2, -1).contiguous() + + return x_quant, scales From 2662be15d60e2efcf7e89af9887b9b77ea4ecdff Mon Sep 17 00:00:00 2001 From: Tahsin Tunan Date: Wed, 10 Sep 2025 19:31:14 +0600 Subject: [PATCH 4/8] refactor: move FP8 quantization functions into QuantFP8 Signed-off-by: Tahsin Tunan --- vllm/model_executor/layers/fused_moe/utils.py | 26 ++--- .../layers/quantization/input_quant_fp8.py | 56 ++++++++- .../quantization/utils/fp8_quant_ops.py | 110 ------------------ 3 files changed, 63 insertions(+), 129 deletions(-) delete mode 100644 vllm/model_executor/layers/quantization/utils/fp8_quant_ops.py diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index feeac0e042cf..1aeb3f92bc3e 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -5,8 +5,9 @@ import torch -from vllm.model_executor.layers.quantization.utils.fp8_quant_ops import ( - quantize_fp8_per_group, quantize_fp8_per_tensor, quantize_fp8_per_token) +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8) from vllm.model_executor.layers.quantization.utils.int8_utils import ( per_token_group_quant_int8, per_token_quant_int8) from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( @@ -121,19 +122,18 @@ def _fp8_quantize( is provided, the output will be blocked. 
""" if block_shape is None: - if per_act_token: - return quantize_fp8_per_token(A, A_scale) - else: - return quantize_fp8_per_tensor(A, A_scale) + # TODO(luka): use QuantFP8 custom op + # https://github.com/vllm-project/vllm/issues/20711 + A, A_scale = ops.scaled_fp8_quant( + A, A_scale, use_per_token_if_dynamic=per_act_token) else: - assert not per_act_token, \ - "per_act_token not supported with block_shape" - assert A_scale is None, \ - "Group quantization doesn't support static scales" - assert len(block_shape) == 2, "block_shape must be [m, k]" + assert not per_act_token + assert len(block_shape) == 2 _, block_k = block_shape[0], block_shape[1] - return quantize_fp8_per_group( - A, block_k, column_major_scales=False) # Use row-major for MoE + A, A_scale = per_token_group_quant_fp8(A, block_k) + assert cdiv(A.size(-1), block_k) == A_scale.size(-1) + + return A, A_scale def _int8_quantize( diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py index fe5761ad549b..ab71b41c3264 100644 --- a/vllm/model_executor/layers/quantization/input_quant_fp8.py +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -7,8 +7,6 @@ from vllm import _custom_ops as ops from vllm.model_executor.custom_op import CustomOp -from vllm.model_executor.layers.quantization.utils.fp8_quant_ops import ( - quantize_fp8_per_group, quantize_fp8_per_tensor, quantize_fp8_per_token) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape) from vllm.platforms import current_platform @@ -16,6 +14,10 @@ # Using the default value (240.0) from pytorch will cause accuracy # issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm. _FP8_DTYPE = current_platform.fp8_dtype() +_FP8_FINFO = torch.finfo(_FP8_DTYPE) +_FP8_MAX = 224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.max +_FP8_MIN = -224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.min +_FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0) @CustomOp.register("quant_fp8") @@ -92,9 +94,25 @@ def forward_native( and scale_ub.numel() == 1) if self.use_per_token_if_dynamic and scale is None: - out, scale = quantize_fp8_per_token(x, scale, scale_ub) + # Per-token quantization logic + x_max, _ = x.abs().max(dim=-1) + x_max = x_max.unsqueeze(-1).to(torch.float32) + if scale_ub is not None: + x_max = x_max.clamp(max=scale_ub) + scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR) + + out = x.to(torch.float32) * scale.reciprocal() + out = out.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) else: - out, scale = quantize_fp8_per_tensor(x, scale) + # Per-tensor quantization logic + if scale is None: + x_max = x.abs().max().unsqueeze(-1).to(torch.float32) + scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR) + + # Even for dynamic per-token scales, + # reciprocal performs slightly better than division + out = x.to(torch.float32) * scale.reciprocal() + out = out.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) # This currently generates an extra Triton kernel in compilation. # Fortunately, we don't use padding if compiling. 
@@ -118,5 +136,31 @@ def _quantize_group_cuda( def _quantize_group_native( self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - return quantize_fp8_per_group(x, self.group_size, - self.column_major_scales) + orig_shape = x.shape + hidden_dim = x.shape[-1] + num_groups = (hidden_dim + self.group_size - 1) // self.group_size + padded_dim = num_groups * self.group_size + + if padded_dim != hidden_dim: + padding = padded_dim - hidden_dim + x = F.pad(x, (0, padding), mode='constant', value=0.0) + + x_grouped = x.view(-1, num_groups, self.group_size) + absmax = x_grouped.abs().max(dim=-1, keepdim=True)[0].float() + scales = (absmax / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR) + + x_scaled = x_grouped / scales + x_quant = x_scaled.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) + + x_quant = x_quant.view(-1, padded_dim) + if padded_dim != hidden_dim: + x_quant = x_quant[..., :hidden_dim] + x_quant = x_quant.view(orig_shape) + + scales = scales.squeeze(-1) + scales = scales.reshape(orig_shape[:-1] + (num_groups, )) + + if self.column_major_scales: + scales = scales.transpose(-2, -1).contiguous() + + return x_quant, scales diff --git a/vllm/model_executor/layers/quantization/utils/fp8_quant_ops.py b/vllm/model_executor/layers/quantization/utils/fp8_quant_ops.py deleted file mode 100644 index be1cf3f85fc3..000000000000 --- a/vllm/model_executor/layers/quantization/utils/fp8_quant_ops.py +++ /dev/null @@ -1,110 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Optional - -import torch -import torch.nn.functional as F - -from vllm.platforms import current_platform - -_FP8_DTYPE = current_platform.fp8_dtype() -_FP8_FINFO = torch.finfo(_FP8_DTYPE) -_FP8_MAX = 224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.max -_FP8_MIN = -224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.min -_FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0) - - -def quantize_fp8_per_tensor( - x: torch.Tensor, - scale: Optional[torch.Tensor] = None -) -> tuple[torch.Tensor, torch.Tensor]: - """Compute per-tensor FP8 quantization. - - Args: - x: Input tensor to quantize - scale: Optional pre-computed scale (for static quantization) - - Returns: - Quantized tensor and scale - """ - if scale is None: - x_max = x.abs().max().unsqueeze(-1).to(torch.float32) - scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR) - - # Even for dynamic per-token scales, - # reciprocal performs slightly better than division - out = x.to(torch.float32) * scale.reciprocal() - out = out.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) - return out, scale - - -def quantize_fp8_per_token( - x: torch.Tensor, - scale: Optional[torch.Tensor] = None, - scale_ub: Optional[torch.Tensor] = None -) -> tuple[torch.Tensor, torch.Tensor]: - """Compute per-token FP8 quantization. 
- - Args: - x: Input tensor to quantize - scale: Optional pre-computed scale (for static quantization) - scale_ub: Optional upper bound for scale - - Returns: - Quantized tensor and scale - """ - if scale is None: - x_max, _ = x.abs().max(dim=-1) - x_max = x_max.unsqueeze(-1).to(torch.float32) - if scale_ub is not None: - x_max = x_max.clamp(max=scale_ub) - scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR) - - out = x.to(torch.float32) * scale.reciprocal() - out = out.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) - return out, scale - - -def quantize_fp8_per_group(x: torch.Tensor, - group_size: int, - column_major_scales: bool = False - ) -> tuple[torch.Tensor, torch.Tensor]: - """Compute per-group FP8 quantization. - - Args: - x: Input tensor to quantize - group_size: Size of quantization groups - column_major_scales: If True, output scales in column-major format - - Returns: - Quantized tensor and per-group scales - """ - orig_shape = x.shape - hidden_dim = x.shape[-1] - num_groups = (hidden_dim + group_size - 1) // group_size - padded_dim = num_groups * group_size - - if padded_dim != hidden_dim: - padding = padded_dim - hidden_dim - x = F.pad(x, (0, padding), mode='constant', value=0.0) - - x_grouped = x.view(-1, num_groups, group_size) - absmax = x_grouped.abs().max(dim=-1, keepdim=True)[0].float() - scales = (absmax / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR) - - x_scaled = x_grouped / scales - x_quant = x_scaled.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) - - x_quant = x_quant.view(-1, padded_dim) - if padded_dim != hidden_dim: - x_quant = x_quant[..., :hidden_dim] - x_quant = x_quant.view(orig_shape) - - scales = scales.squeeze(-1) - scales = scales.reshape(orig_shape[:-1] + (num_groups, )) - - if column_major_scales: - scales = scales.transpose(-2, -1).contiguous() - - return x_quant, scales From 4fe4578514f6501269924265cac7c3f1e0c016d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 11 Sep 2025 14:51:27 -0700 Subject: [PATCH 5/8] Refactor benchmark to support all group shapes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- .../kernels/bench_per_token_quant_fp8.py | 192 ++++++++++++++---- 1 file changed, 154 insertions(+), 38 deletions(-) diff --git a/benchmarks/kernels/bench_per_token_quant_fp8.py b/benchmarks/kernels/bench_per_token_quant_fp8.py index 923d678f1f2d..310853b9d1ba 100644 --- a/benchmarks/kernels/bench_per_token_quant_fp8.py +++ b/benchmarks/kernels/bench_per_token_quant_fp8.py @@ -2,14 +2,25 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools from typing import Callable +from unittest.mock import patch +import pandas as pd import torch -from vllm import _custom_ops as ops -from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.triton_utils import triton +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser + + +def with_triton_mode(fn): + """Temporarily force the Triton fallback path""" + + def wrapped(*args, **kwargs): + with patch("vllm.platforms.current_platform.is_cuda", return_value=False): + return fn(*args, **kwargs) + + return wrapped # TODO(luka): use standalone_compile utility @@ -21,78 +32,183 @@ def inner(*args): return inner -torch._dynamo.config.recompile_limit = 8888 -compilation_config = 
CompilationConfig(custom_ops=["none"]) -with set_current_vllm_config(VllmConfig(compilation_config=compilation_config)): - torch_per_token_quant_fp8 = torch.compile( - QuantFP8(False, GroupShape.PER_TOKEN), - fullgraph=True, - dynamic=False, # recompile for different shapes - ) +def bench_compile(fn: Callable): + # recompile for different shapes + fwd = torch.compile(fn, fullgraph=True, dynamic=False) # First dim is explicitly dynamic to simulate vLLM usage - torch_per_token_quant_fp8 = with_dyn_arg(torch_per_token_quant_fp8, 0, 0) + return with_dyn_arg(fwd, 0, 0) -def cuda_per_token_quant_fp8( - input: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor]: - return ops.scaled_fp8_quant(input) +torch._dynamo.config.recompile_limit = 8888 -def calculate_diff(batch_size: int, seq_len: int): - """Calculate difference between Triton and CUDA implementations.""" +def calculate_diff( + batch_size: int, + hidden_size: int, + group_shape: GroupShape, + dtype: torch.dtype, +): + """Calculate the difference between Inductor and CUDA implementations.""" device = torch.device("cuda") - x = torch.rand((batch_size * seq_len, 4096), dtype=torch.float16, device=device) + x = torch.rand((batch_size * hidden_size, 4096), dtype=dtype, device=device) + + quant_fp8 = QuantFP8(False, group_shape, column_major_scales=False) - torch_out, torch_scale = torch_per_token_quant_fp8(x) - cuda_out, cuda_scale = cuda_per_token_quant_fp8(x) + torch_out, torch_scale = bench_compile(quant_fp8.forward_native)(x) + torch_eager_out, torch_eager_scale = quant_fp8.forward_native(x) + cuda_out, cuda_scale = quant_fp8.forward_cuda(x) - if torch.allclose( - cuda_out.to(torch.float32), torch_out.to(torch.float32), rtol=1e-3, atol=1e-5 - ) and torch.allclose(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5): + out_allclose = lambda o1, o2: torch.allclose( + o1.to(torch.float32), + o2.to(torch.float32), + rtol=1e-3, + atol=1e-5, + ) + scale_allclose = lambda s1, s2: torch.allclose(s1, s2, rtol=1e-3, atol=1e-5) + + if ( + out_allclose(cuda_out, torch_out) + and scale_allclose(cuda_scale, torch_scale) + and out_allclose(cuda_out, torch_eager_out) + and scale_allclose(cuda_scale, torch_eager_scale) + ): print("✅ All implementations match") else: print("❌ Implementations differ") -batch_size_range = [1, 16, 32, 64, 128] -seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096] +hidden_sizes = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096] +batch_sizes = [1, 16, 32, 64, 128] +group_shapes = [ + GroupShape.PER_TENSOR, + GroupShape.PER_TOKEN, + GroupShape(1, 64), + GroupShape(1, 128), +] +column_major_scales = [True, False] + +config_gen = itertools.product( + group_shapes, + column_major_scales, + batch_sizes, + hidden_sizes, +) -configs = list(itertools.product(batch_size_range, seq_len_range)) +# filter out column-major scales for non-group, reverse order +configs = list(c[::-1] for c in config_gen if (c[0].is_per_group() or not c[1])) @triton.testing.perf_report( triton.testing.Benchmark( - x_names=["batch_size", "seq_len"], + x_names=["hidden_size", "batch_size", "col_major", "group_shape"], x_vals=configs, line_arg="provider", - line_vals=["torch", "cuda"], - line_names=["Torch", "CUDA"], - styles=[("blue", "-"), ("green", "-")], + line_vals=["torch", "cuda", "triton"], + line_names=["Torch (Compiled)", "CUDA", "Triton"], + styles=[("blue", "-"), ("green", "-"), ("black", "-")], ylabel="us", - plot_name="per-token-dynamic-quant-fp8-performance", + plot_name="QuantFP8 performance", args={}, ) ) -def benchmark_quantization(batch_size, 
seq_len, provider): - dtype = torch.float16 +def benchmark_quantization( + batch_size, + hidden_size, + provider, + group_shape: GroupShape, + col_major: bool, + dtype: torch.dtype, +): device = torch.device("cuda") - x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype) + x = torch.randn(batch_size * hidden_size, 4096, device=device, dtype=dtype) quantiles = [0.5, 0.2, 0.8] + quant_fp8 = QuantFP8(False, group_shape, column_major_scales=col_major) if provider == "torch": - fn = lambda: torch_per_token_quant_fp8(x.clone()) + fn = lambda: bench_compile(quant_fp8.forward_native)(x.clone()) elif provider == "cuda": - fn = lambda: cuda_per_token_quant_fp8(x.clone()) + fn = lambda: quant_fp8.forward_cuda(x.clone()) + elif provider == "triton": + if not group_shape.is_per_group(): + # Triton only supported for per-group + return 0, 0, 0 + + fn = lambda: with_triton_mode(quant_fp8.forward_cuda)(x.clone()) ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles) return 1000 * ms, 1000 * max_ms, 1000 * min_ms +# TODO(luka) extract to utils +def compute_geomean_speedups( + df: pd.DataFrame, + baseline_col: str, + speedup_cols: list[str], + groupby_cols: list[str] | None = None, +) -> pd.DataFrame: + """ + Compute geometric mean speedups over a baseline column. + + Args: + df: Input dataframe + baseline_col: Column to use as baseline + speedup_cols: Columns to compute speedups for + groupby_cols: Columns to group by. If None, compute over entire df. + + Returns: + pd.DataFrame with geometric mean speedups + """ + from scipy.stats import gmean + + def geo_speedup(group: pd.DataFrame) -> pd.Series: + ratios = { + col: (group[baseline_col] / group[col]).values for col in speedup_cols + } + return pd.Series({col: gmean(vals) for col, vals in ratios.items()}) + + if groupby_cols is None: + result = geo_speedup(df).to_frame().T + else: + result = df.groupby(groupby_cols).apply(geo_speedup).reset_index() + + return result + + if __name__ == "__main__": - calculate_diff(batch_size=4, seq_len=4096) - benchmark_quantization.run(print_data=True) + parser = FlexibleArgumentParser( + description="Benchmark the various implementations of QuantFP8 (dynamic-only)" + ) + parser.add_argument("-c", "--check", action="store_true") + parser.add_argument( + "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half" + ) + + args = parser.parse_args() + assert args + + dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype] + + if args.check: + for group_shape in group_shapes: + group_size = group_shape[1] + print(f"{group_size=}") + calculate_diff( + batch_size=4, hidden_size=4096, group_shape=group_shape, dtype=dtype + ) + + df = benchmark_quantization.run(print_data=True, dtype=dtype, return_df=True) + + # Print geomean speedups + geo_table_grouped = compute_geomean_speedups( + df, + baseline_col="Torch (Compiled)", + speedup_cols=["CUDA", "Triton"], + groupby_cols=["col_major", "group_shape"], + ) + + print("Speedup over Torch (Compiled)") + print(geo_table_grouped.to_string(index=False)) From 100b11c38d595c7c0e0e6290574b82728ec9ee71 Mon Sep 17 00:00:00 2001 From: Tahsin Tunan Date: Mon, 15 Sep 2025 20:31:31 +0600 Subject: [PATCH 6/8] refactor: clean up QuantFP8 forward methods and consolidate tests Signed-off-by: Tahsin Tunan --- .../kernels/benchmark_quantfp8_group.py | 148 ------------------ .../quantization/test_fp8_quant_group.py | 123 +++++---------- .../layers/quantization/input_quant_fp8.py | 48 +++--- 3 files changed, 54 insertions(+), 265 deletions(-) delete mode 
100644 benchmarks/kernels/benchmark_quantfp8_group.py diff --git a/benchmarks/kernels/benchmark_quantfp8_group.py b/benchmarks/kernels/benchmark_quantfp8_group.py deleted file mode 100644 index d8555a00f824..000000000000 --- a/benchmarks/kernels/benchmark_quantfp8_group.py +++ /dev/null @@ -1,148 +0,0 @@ -#!/usr/bin/env python -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Benchmark for QuantFP8 Group Quantization implementation.""" - -import argparse - -import torch - -from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape -from vllm.platforms import current_platform - - -def _time_cuda( - fn, - warmup_iters: int, - bench_iters: int, -) -> float: - # warmup - for _ in range(warmup_iters): - fn() - torch.cuda.synchronize() - - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - - start.record() - for _ in range(bench_iters): - fn() - end.record() - torch.cuda.synchronize() - - return start.elapsed_time(end) / bench_iters # ms/iter - - -def run_benchmark( - shape: tuple[int, int], - group_size: int, - column_major: bool, - warmup_iters: int, - bench_iters: int, -) -> None: - """Benchmark QuantFP8 with group quantization using different backends.""" - num_tokens, hidden_dim = shape - - device = torch.device("cuda") - torch.manual_seed(42) - x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16) * 8 - - group_shape = GroupShape(1, group_size) - quant_op = QuantFP8( - static=False, group_shape=group_shape, column_major_scales=column_major - ) - - def cuda_impl(): - return quant_op.forward_cuda(x.clone()) - - def native_impl(): - return quant_op.forward_native(x.clone()) - - cuda_ms = _time_cuda(cuda_impl, warmup_iters, bench_iters) - native_ms = _time_cuda(native_impl, warmup_iters, bench_iters) - - speedup = cuda_ms / native_ms if native_ms else 0 - - cfg_desc = f"shape={shape} gs={group_size:<3} col_major={column_major}" - print(f"{cfg_desc:45} | {cuda_ms:7.3f} | {native_ms:7.3f} | {speedup:6.2f}x") - - -def parse_args(): - parser = argparse.ArgumentParser( - description="Benchmark QuantFP8 group quantization implementation" - ) - parser.add_argument( - "--warmup-iters", type=int, default=10, help="Number of warmup iterations" - ) - parser.add_argument( - "--bench-iters", type=int, default=100, help="Number of benchmark iterations" - ) - parser.add_argument( - "--shapes", - type=str, - default="32,128;64,256;16,512;128,1024;256,2048", - help="Shapes to benchmark as 'tokens,hidden;...' 
(default: multiple shapes)", - ) - parser.add_argument( - "--group-sizes", - type=str, - default="64,128", - help="Group sizes to benchmark (comma-separated)", - ) - parser.add_argument( - "--no-column-major", - action="store_true", - help="Skip column-major scale benchmarks", - ) - return parser.parse_args() - - -def main(): - if not current_platform.is_cuda(): - raise RuntimeError("CUDA device is required to run this benchmark.") - - args = parse_args() - - shapes = [] - for shape_str in args.shapes.split(";"): - tokens, hidden = map(int, shape_str.split(",")) - shapes.append((tokens, hidden)) - - group_sizes = list(map(int, args.group_sizes.split(","))) - - print("\n" + "=" * 80) - print("QuantFP8 Group Quantization Benchmark (CUDA kernel vs PyTorch native)") - print("=" * 80) - print(f"Device: {torch.cuda.get_device_name()}") - print(f"Warmup iterations: {args.warmup_iters}") - print(f"Benchmark iterations: {args.bench_iters}") - print("=" * 80) - - print(f"{'Configuration':45} | {'CUDA':^9} | {'Native':^9} | {'Speedup':^8}") - print("-" * 80) - - for shape in shapes: - for gs in group_sizes: - run_benchmark( - shape, - gs, - column_major=False, - warmup_iters=args.warmup_iters, - bench_iters=args.bench_iters, - ) - - if not args.no_column_major: - run_benchmark( - shape, - gs, - column_major=True, - warmup_iters=args.warmup_iters, - bench_iters=args.bench_iters, - ) - - print("=" * 80) - - -if __name__ == "__main__": - main() diff --git a/tests/kernels/quantization/test_fp8_quant_group.py b/tests/kernels/quantization/test_fp8_quant_group.py index 47c877d22731..1377148102fc 100644 --- a/tests/kernels/quantization/test_fp8_quant_group.py +++ b/tests/kernels/quantization/test_fp8_quant_group.py @@ -11,45 +11,64 @@ from vllm.platforms import current_platform -@pytest.mark.parametrize("batch_size", [16, 32]) -@pytest.mark.parametrize("hidden_dim", - [256, 512, 513]) # Include non-divisible -@pytest.mark.parametrize("group_size", [32, 64, 128]) +@pytest.mark.parametrize( + "batch_size,hidden_dim,group_size", + [ + (16, 256, 32), # Small + (64, 1024, 64), # Medium + (128, 2048, 128), # Large + (8, 513, 64), # Non-divisible (native only) + ]) @pytest.mark.parametrize("seed", [42]) @torch.inference_mode() -def test_quantfp8_group_basic(batch_size: int, hidden_dim: int, - group_size: int, seed: int) -> None: +def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int, + group_size: int, seed: int) -> None: + """Test QuantFP8 group quantization with various configurations. + + Tests both CUDA and native implementations, column-major scales, + and verifies consistency between implementations. + """ current_platform.seed_everything(seed) x = torch.randn( (batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8 + expected_num_groups = (hidden_dim + group_size - 1) // group_size + is_divisible = hidden_dim % group_size == 0 - # Create QuantFP8 with group quantization group_shape = GroupShape(1, group_size) quant_op = QuantFP8(static=False, group_shape=group_shape, column_major_scales=False) - expected_num_groups = (hidden_dim + group_size - 1) // group_size + # 1. Test native implementation (always available) + x_quant_native, scales_native = quant_op.forward_native(x.clone()) + assert x_quant_native.shape == x.shape + assert scales_native.shape == (batch_size, expected_num_groups) - # Test CUDA implementation (only supports divisible dimensions) - if hidden_dim % group_size == 0: + # 2. 
Test CUDA implementation (only for divisible dimensions) + x_quant_cuda = None + scales_cuda = None + if is_divisible: x_quant_cuda, scales_cuda = quant_op.forward_cuda(x.clone()) assert x_quant_cuda.shape == x.shape assert scales_cuda.shape == (batch_size, expected_num_groups) - # Test PyTorch native implementation - x_quant_native, scales_native = quant_op.forward_native(x.clone()) - assert x_quant_native.shape == x.shape - assert scales_native.shape == (batch_size, expected_num_groups) - - # Test column_major_scales + # 3. Test column-major scales configuration quant_op_col = QuantFP8(static=False, group_shape=group_shape, column_major_scales=True) _, scales_col = quant_op_col.forward_native(x.clone()) assert scales_col.shape == (expected_num_groups, batch_size) + # 4. Verify CUDA/native consistency (when CUDA is available) + if is_divisible: + assert torch.allclose(scales_cuda, scales_native, rtol=1e-9, atol=1e-8) + + # Quantized values should mostly match + diff_count = (x_quant_cuda != x_quant_native).sum().item() + diff_ratio = diff_count / x_quant_cuda.numel() + assert diff_ratio < 0.002, f"Too many differences: {diff_ratio:.4%}" + @pytest.mark.parametrize("seed", [42]) @torch.inference_mode() @@ -95,44 +114,6 @@ def test_quantfp8_group_multidimensional(seed: int) -> None: batch3) -@pytest.mark.parametrize("batch_size", [32]) -@pytest.mark.parametrize("hidden_dim", [1024]) -@pytest.mark.parametrize("group_size", [128]) -@pytest.mark.parametrize("seed", [42]) -@torch.inference_mode() -def test_quantfp8_group_cuda_native_consistency(batch_size: int, - hidden_dim: int, - group_size: int, - seed: int) -> None: - """Compare CUDA and native implementations for consistency.""" - current_platform.seed_everything(seed) - - x = torch.randn( - (batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8 - - group_shape = GroupShape(1, group_size) - quant_op = QuantFP8(static=False, - group_shape=group_shape, - column_major_scales=False) - - # Run both implementations - x_quant_cuda, scales_cuda = quant_op.forward_cuda(x.clone()) - x_quant_native, scales_native = quant_op.forward_native(x.clone()) - - # Check shapes match - assert x_quant_cuda.shape == x_quant_native.shape - assert scales_cuda.shape == scales_native.shape - - # Scales should match - assert torch.allclose(scales_cuda, scales_native, rtol=1e-9, atol=1e-8) - - # Quantized values should mostly match, with rare rounding differences - # FP8 rounding at boundaries can differ between CUDA and PyTorch - diff_count = (x_quant_cuda != x_quant_native).sum().item() - diff_ratio = diff_count / x_quant_cuda.numel() - assert diff_ratio < 0.002, f"Too many differences: {diff_ratio:.4%}" - - @pytest.mark.parametrize("seed", [42]) @torch.inference_mode() def test_quantfp8_group_edge_cases(seed: int) -> None: @@ -170,37 +151,3 @@ def test_quantfp8_group_edge_cases(seed: int) -> None: assert x_quant_large.shape == x_large.shape # FP8 max is typically 448 or 224, so scales should be > 1 assert (scales_large > 1.0).all(), "Large values should have scales > 1" - - -@pytest.mark.parametrize( - "batch_size,hidden_dim,group_size", - [ - (16, 256, 16), # Small - (64, 1024, 64), # Medium - (128, 2048, 128), # Large - (8, 513, 64), # Non-divisible (native only) - ]) -@pytest.mark.parametrize("seed", [42]) -@torch.inference_mode() -def test_quantfp8_group_various_configs(batch_size: int, hidden_dim: int, - group_size: int, seed: int) -> None: - current_platform.seed_everything(seed) - - x = torch.randn( - (batch_size, hidden_dim), 
dtype=torch.bfloat16, device="cuda") * 8 - group_shape = GroupShape(1, group_size) - quant_op = QuantFP8(static=False, - group_shape=group_shape, - column_major_scales=False) - - expected_num_groups = (hidden_dim + group_size - 1) // group_size - - x_quant_native, scales_native = quant_op.forward_native(x.clone()) - assert x_quant_native.shape == x.shape - assert scales_native.shape == (batch_size, expected_num_groups) - - if hidden_dim % group_size == 0: - x_quant_cuda, scales_cuda = quant_op.forward_cuda(x.clone()) - assert x_quant_cuda.shape == x.shape - assert scales_cuda.shape == (batch_size, expected_num_groups) - assert torch.allclose(scales_cuda, scales_native, rtol=1e-9, atol=1e-8) diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py index ab71b41c3264..31182f40b48f 100644 --- a/vllm/model_executor/layers/quantization/input_quant_fp8.py +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -65,7 +65,12 @@ def forward_cuda( ) -> tuple[torch.Tensor, torch.Tensor]: if self.is_group_quant: assert scale is None, "Group quantization is always dynamic" - return self._quantize_group_cuda(x) + from vllm.model_executor.layers.quantization.utils import fp8_utils + return fp8_utils.per_token_group_quant_fp8( + x, + group_size=self.group_size, + column_major_scales=self.column_major_scales, + dtype=_FP8_DTYPE) assert (scale is not None) == self.static assert scale_ub is None or (not self.static and self.group_shape @@ -93,26 +98,21 @@ def forward_native( == GroupShape.PER_TOKEN and scale_ub.numel() == 1) - if self.use_per_token_if_dynamic and scale is None: - # Per-token quantization logic - x_max, _ = x.abs().max(dim=-1) - x_max = x_max.unsqueeze(-1).to(torch.float32) - if scale_ub is not None: - x_max = x_max.clamp(max=scale_ub) - scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR) - - out = x.to(torch.float32) * scale.reciprocal() - out = out.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) - else: - # Per-tensor quantization logic - if scale is None: + if scale is None: + if self.group_shape == GroupShape.PER_TOKEN: + x_max, _ = x.abs().max(dim=-1) + x_max = x_max.unsqueeze(-1).to(torch.float32) + if scale_ub is not None: + x_max = x_max.clamp(max=scale_ub) + else: x_max = x.abs().max().unsqueeze(-1).to(torch.float32) - scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR) - # Even for dynamic per-token scales, - # reciprocal performs slightly better than division - out = x.to(torch.float32) * scale.reciprocal() - out = out.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) + scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR) + + # Even for dynamic per-token scales, + # reciprocal performs slightly better than division + out = x.to(torch.float32) * scale.reciprocal() + out = out.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) # This currently generates an extra Triton kernel in compilation. # Fortunately, we don't use padding if compiling. 
@@ -124,16 +124,6 @@ def forward_native( return out, scale - def _quantize_group_cuda( - self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8) - return per_token_group_quant_fp8( - x, - group_size=self.group_size, - column_major_scales=self.column_major_scales, - dtype=_FP8_DTYPE) - def _quantize_group_native( self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: orig_shape = x.shape From dd452274ab45653c4c182de01a471522b48e4775 Mon Sep 17 00:00:00 2001 From: Tahsin Tunan Date: Mon, 15 Sep 2025 23:17:00 +0600 Subject: [PATCH 7/8] refactor: test_fp8_quant_group to avoid mypy type errors Signed-off-by: Tahsin Tunan --- .../quantization/test_fp8_quant_group.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/tests/kernels/quantization/test_fp8_quant_group.py b/tests/kernels/quantization/test_fp8_quant_group.py index 1377148102fc..720eee62760d 100644 --- a/tests/kernels/quantization/test_fp8_quant_group.py +++ b/tests/kernels/quantization/test_fp8_quant_group.py @@ -45,23 +45,20 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int, assert x_quant_native.shape == x.shape assert scales_native.shape == (batch_size, expected_num_groups) - # 2. Test CUDA implementation (only for divisible dimensions) - x_quant_cuda = None - scales_cuda = None - if is_divisible: - x_quant_cuda, scales_cuda = quant_op.forward_cuda(x.clone()) - assert x_quant_cuda.shape == x.shape - assert scales_cuda.shape == (batch_size, expected_num_groups) - - # 3. Test column-major scales configuration + # 2. Test column-major scales configuration quant_op_col = QuantFP8(static=False, group_shape=group_shape, column_major_scales=True) _, scales_col = quant_op_col.forward_native(x.clone()) assert scales_col.shape == (expected_num_groups, batch_size) - # 4. Verify CUDA/native consistency (when CUDA is available) + # 3. 
Test CUDA implementation (only for divisible dimensions) if is_divisible: + x_quant_cuda, scales_cuda = quant_op.forward_cuda(x.clone()) + assert x_quant_cuda.shape == x.shape + assert scales_cuda.shape == (batch_size, expected_num_groups) + + # Verify CUDA/native consistency assert torch.allclose(scales_cuda, scales_native, rtol=1e-9, atol=1e-8) # Quantized values should mostly match From ff0855ae0e2bed4e30f305aaee0e1304db8ad7f5 Mon Sep 17 00:00:00 2001 From: Tahsin Tunan Date: Tue, 16 Sep 2025 22:03:37 +0600 Subject: [PATCH 8/8] bench: add CLI args for FP8 benchmark configuration Signed-off-by: Tahsin Tunan --- .../kernels/bench_per_token_quant_fp8.py | 125 +++++++++++++----- 1 file changed, 89 insertions(+), 36 deletions(-) diff --git a/benchmarks/kernels/bench_per_token_quant_fp8.py b/benchmarks/kernels/bench_per_token_quant_fp8.py index 310853b9d1ba..9170361e974b 100644 --- a/benchmarks/kernels/bench_per_token_quant_fp8.py +++ b/benchmarks/kernels/bench_per_token_quant_fp8.py @@ -78,40 +78,9 @@ def calculate_diff( print("❌ Implementations differ") -hidden_sizes = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096] -batch_sizes = [1, 16, 32, 64, 128] -group_shapes = [ - GroupShape.PER_TENSOR, - GroupShape.PER_TOKEN, - GroupShape(1, 64), - GroupShape(1, 128), -] -column_major_scales = [True, False] - -config_gen = itertools.product( - group_shapes, - column_major_scales, - batch_sizes, - hidden_sizes, -) - -# filter out column-major scales for non-group, reverse order -configs = list(c[::-1] for c in config_gen if (c[0].is_per_group() or not c[1])) - - -@triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["hidden_size", "batch_size", "col_major", "group_shape"], - x_vals=configs, - line_arg="provider", - line_vals=["torch", "cuda", "triton"], - line_names=["Torch (Compiled)", "CUDA", "Triton"], - styles=[("blue", "-"), ("green", "-"), ("black", "-")], - ylabel="us", - plot_name="QuantFP8 performance", - args={}, - ) -) +configs = [] + + def benchmark_quantization( batch_size, hidden_size, @@ -173,7 +142,11 @@ def geo_speedup(group: pd.DataFrame) -> pd.Series: if groupby_cols is None: result = geo_speedup(df).to_frame().T else: - result = df.groupby(groupby_cols).apply(geo_speedup).reset_index() + result = ( + df.groupby(groupby_cols) + .apply(geo_speedup, include_groups=False) + .reset_index() + ) return result @@ -186,12 +159,78 @@ def geo_speedup(group: pd.DataFrame) -> pd.Series: parser.add_argument( "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half" ) + parser.add_argument( + "--hidden-sizes", + type=int, + nargs="+", + default=None, + help="Hidden sizes to benchmark (default: 1,16,64,128,256,512,1024,2048,4096)", + ) + parser.add_argument( + "--batch-sizes", + type=int, + nargs="+", + default=None, + help="Batch sizes to benchmark (default: 1,16,32,64,128)", + ) + parser.add_argument( + "--group-sizes", + type=int, + nargs="+", + default=None, + help="Group sizes for GroupShape(1,N) to benchmark. 
" + "Use 0 for PER_TENSOR, -1 for PER_TOKEN (default: 0,-1,64,128)", + ) + parser.add_argument( + "--no-column-major", + action="store_true", + help="Disable column-major scales testing", + ) args = parser.parse_args() assert args dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype] + hidden_sizes = args.hidden_sizes or [1, 16, 64, 128, 256, 512, 1024, 2048, 4096] + batch_sizes = args.batch_sizes or [1, 16, 32, 64, 128] + + if args.group_sizes is not None: + group_shapes = [] + for size in args.group_sizes: + if size == 0: + group_shapes.append(GroupShape.PER_TENSOR) + elif size == -1: + group_shapes.append(GroupShape.PER_TOKEN) + else: + group_shapes.append(GroupShape(1, size)) + else: + group_shapes = [ + GroupShape.PER_TENSOR, + GroupShape.PER_TOKEN, + GroupShape(1, 64), + GroupShape(1, 128), + ] + + column_major_scales = [False] if args.no_column_major else [True, False] + + config_gen = itertools.product( + group_shapes, + column_major_scales, + batch_sizes, + hidden_sizes, + ) + + # filter out column-major scales for non-group, reverse order + configs.extend(c[::-1] for c in config_gen if (c[0].is_per_group() or not c[1])) + + print(f"Running {len(configs)} configurations:") + print(f" Hidden sizes: {hidden_sizes}") + print(f" Batch sizes: {batch_sizes}") + print(f" Group shapes: {[str(g) for g in group_shapes]}") + print(f" Column major scales: {column_major_scales}") + print() + if args.check: for group_shape in group_shapes: group_size = group_shape[1] @@ -200,7 +239,21 @@ def geo_speedup(group: pd.DataFrame) -> pd.Series: batch_size=4, hidden_size=4096, group_shape=group_shape, dtype=dtype ) - df = benchmark_quantization.run(print_data=True, dtype=dtype, return_df=True) + benchmark = triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["hidden_size", "batch_size", "col_major", "group_shape"], + x_vals=configs, + line_arg="provider", + line_vals=["torch", "cuda", "triton"], + line_names=["Torch (Compiled)", "CUDA", "Triton"], + styles=[("blue", "-"), ("green", "-"), ("black", "-")], + ylabel="us", + plot_name="QuantFP8 performance", + args={}, + ) + )(benchmark_quantization) + + df = benchmark.run(print_data=True, dtype=dtype, return_df=True) # Print geomean speedups geo_table_grouped = compute_geomean_speedups(