diff --git a/benchmarks/kernels/bench_per_token_quant_fp8.py b/benchmarks/kernels/bench_per_token_quant_fp8.py index 923d678f1f2d..9170361e974b 100644 --- a/benchmarks/kernels/bench_per_token_quant_fp8.py +++ b/benchmarks/kernels/bench_per_token_quant_fp8.py @@ -2,14 +2,25 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools from typing import Callable +from unittest.mock import patch +import pandas as pd import torch -from vllm import _custom_ops as ops -from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.triton_utils import triton +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser + + +def with_triton_mode(fn): + """Temporarily force the Triton fallback path""" + + def wrapped(*args, **kwargs): + with patch("vllm.platforms.current_platform.is_cuda", return_value=False): + return fn(*args, **kwargs) + + return wrapped # TODO(luka): use standalone_compile utility @@ -21,78 +32,236 @@ def inner(*args): return inner -torch._dynamo.config.recompile_limit = 8888 -compilation_config = CompilationConfig(custom_ops=["none"]) -with set_current_vllm_config(VllmConfig(compilation_config=compilation_config)): - torch_per_token_quant_fp8 = torch.compile( - QuantFP8(False, GroupShape.PER_TOKEN), - fullgraph=True, - dynamic=False, # recompile for different shapes - ) +def bench_compile(fn: Callable): + # recompile for different shapes + fwd = torch.compile(fn, fullgraph=True, dynamic=False) # First dim is explicitly dynamic to simulate vLLM usage - torch_per_token_quant_fp8 = with_dyn_arg(torch_per_token_quant_fp8, 0, 0) + return with_dyn_arg(fwd, 0, 0) -def cuda_per_token_quant_fp8( - input: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor]: - return ops.scaled_fp8_quant(input) +torch._dynamo.config.recompile_limit = 8888 -def calculate_diff(batch_size: int, seq_len: int): - """Calculate difference between Triton and CUDA implementations.""" +def calculate_diff( + batch_size: int, + hidden_size: int, + group_shape: GroupShape, + dtype: torch.dtype, +): + """Calculate the difference between Inductor and CUDA implementations.""" device = torch.device("cuda") - x = torch.rand((batch_size * seq_len, 4096), dtype=torch.float16, device=device) + x = torch.rand((batch_size * hidden_size, 4096), dtype=dtype, device=device) + + quant_fp8 = QuantFP8(False, group_shape, column_major_scales=False) - torch_out, torch_scale = torch_per_token_quant_fp8(x) - cuda_out, cuda_scale = cuda_per_token_quant_fp8(x) + torch_out, torch_scale = bench_compile(quant_fp8.forward_native)(x) + torch_eager_out, torch_eager_scale = quant_fp8.forward_native(x) + cuda_out, cuda_scale = quant_fp8.forward_cuda(x) - if torch.allclose( - cuda_out.to(torch.float32), torch_out.to(torch.float32), rtol=1e-3, atol=1e-5 - ) and torch.allclose(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5): + out_allclose = lambda o1, o2: torch.allclose( + o1.to(torch.float32), + o2.to(torch.float32), + rtol=1e-3, + atol=1e-5, + ) + scale_allclose = lambda s1, s2: torch.allclose(s1, s2, rtol=1e-3, atol=1e-5) + + if ( + out_allclose(cuda_out, torch_out) + and scale_allclose(cuda_scale, torch_scale) + and out_allclose(cuda_out, torch_eager_out) + and scale_allclose(cuda_scale, torch_eager_scale) + ): print("✅ All implementations match") else: print("❌ Implementations differ") -batch_size_range = [1, 16, 32, 64, 
128] -seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096] - -configs = list(itertools.product(batch_size_range, seq_len_range)) +configs = [] -@triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["batch_size", "seq_len"], - x_vals=configs, - line_arg="provider", - line_vals=["torch", "cuda"], - line_names=["Torch", "CUDA"], - styles=[("blue", "-"), ("green", "-")], - ylabel="us", - plot_name="per-token-dynamic-quant-fp8-performance", - args={}, - ) -) -def benchmark_quantization(batch_size, seq_len, provider): - dtype = torch.float16 +def benchmark_quantization( + batch_size, + hidden_size, + provider, + group_shape: GroupShape, + col_major: bool, + dtype: torch.dtype, +): device = torch.device("cuda") - x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype) + x = torch.randn(batch_size * hidden_size, 4096, device=device, dtype=dtype) quantiles = [0.5, 0.2, 0.8] + quant_fp8 = QuantFP8(False, group_shape, column_major_scales=col_major) if provider == "torch": - fn = lambda: torch_per_token_quant_fp8(x.clone()) + fn = lambda: bench_compile(quant_fp8.forward_native)(x.clone()) elif provider == "cuda": - fn = lambda: cuda_per_token_quant_fp8(x.clone()) + fn = lambda: quant_fp8.forward_cuda(x.clone()) + elif provider == "triton": + if not group_shape.is_per_group(): + # Triton only supported for per-group + return 0, 0, 0 + + fn = lambda: with_triton_mode(quant_fp8.forward_cuda)(x.clone()) ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles) return 1000 * ms, 1000 * max_ms, 1000 * min_ms +# TODO(luka) extract to utils +def compute_geomean_speedups( + df: pd.DataFrame, + baseline_col: str, + speedup_cols: list[str], + groupby_cols: list[str] | None = None, +) -> pd.DataFrame: + """ + Compute geometric mean speedups over a baseline column. + + Args: + df: Input dataframe + baseline_col: Column to use as baseline + speedup_cols: Columns to compute speedups for + groupby_cols: Columns to group by. If None, compute over entire df. + + Returns: + pd.DataFrame with geometric mean speedups + """ + from scipy.stats import gmean + + def geo_speedup(group: pd.DataFrame) -> pd.Series: + ratios = { + col: (group[baseline_col] / group[col]).values for col in speedup_cols + } + return pd.Series({col: gmean(vals) for col, vals in ratios.items()}) + + if groupby_cols is None: + result = geo_speedup(df).to_frame().T + else: + result = ( + df.groupby(groupby_cols) + .apply(geo_speedup, include_groups=False) + .reset_index() + ) + + return result + + if __name__ == "__main__": - calculate_diff(batch_size=4, seq_len=4096) - benchmark_quantization.run(print_data=True) + parser = FlexibleArgumentParser( + description="Benchmark the various implementations of QuantFP8 (dynamic-only)" + ) + parser.add_argument("-c", "--check", action="store_true") + parser.add_argument( + "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half" + ) + parser.add_argument( + "--hidden-sizes", + type=int, + nargs="+", + default=None, + help="Hidden sizes to benchmark (default: 1,16,64,128,256,512,1024,2048,4096)", + ) + parser.add_argument( + "--batch-sizes", + type=int, + nargs="+", + default=None, + help="Batch sizes to benchmark (default: 1,16,32,64,128)", + ) + parser.add_argument( + "--group-sizes", + type=int, + nargs="+", + default=None, + help="Group sizes for GroupShape(1,N) to benchmark. 
" + "Use 0 for PER_TENSOR, -1 for PER_TOKEN (default: 0,-1,64,128)", + ) + parser.add_argument( + "--no-column-major", + action="store_true", + help="Disable column-major scales testing", + ) + + args = parser.parse_args() + assert args + + dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype] + + hidden_sizes = args.hidden_sizes or [1, 16, 64, 128, 256, 512, 1024, 2048, 4096] + batch_sizes = args.batch_sizes or [1, 16, 32, 64, 128] + + if args.group_sizes is not None: + group_shapes = [] + for size in args.group_sizes: + if size == 0: + group_shapes.append(GroupShape.PER_TENSOR) + elif size == -1: + group_shapes.append(GroupShape.PER_TOKEN) + else: + group_shapes.append(GroupShape(1, size)) + else: + group_shapes = [ + GroupShape.PER_TENSOR, + GroupShape.PER_TOKEN, + GroupShape(1, 64), + GroupShape(1, 128), + ] + + column_major_scales = [False] if args.no_column_major else [True, False] + + config_gen = itertools.product( + group_shapes, + column_major_scales, + batch_sizes, + hidden_sizes, + ) + + # filter out column-major scales for non-group, reverse order + configs.extend(c[::-1] for c in config_gen if (c[0].is_per_group() or not c[1])) + + print(f"Running {len(configs)} configurations:") + print(f" Hidden sizes: {hidden_sizes}") + print(f" Batch sizes: {batch_sizes}") + print(f" Group shapes: {[str(g) for g in group_shapes]}") + print(f" Column major scales: {column_major_scales}") + print() + + if args.check: + for group_shape in group_shapes: + group_size = group_shape[1] + print(f"{group_size=}") + calculate_diff( + batch_size=4, hidden_size=4096, group_shape=group_shape, dtype=dtype + ) + + benchmark = triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["hidden_size", "batch_size", "col_major", "group_shape"], + x_vals=configs, + line_arg="provider", + line_vals=["torch", "cuda", "triton"], + line_names=["Torch (Compiled)", "CUDA", "Triton"], + styles=[("blue", "-"), ("green", "-"), ("black", "-")], + ylabel="us", + plot_name="QuantFP8 performance", + args={}, + ) + )(benchmark_quantization) + + df = benchmark.run(print_data=True, dtype=dtype, return_df=True) + + # Print geomean speedups + geo_table_grouped = compute_geomean_speedups( + df, + baseline_col="Torch (Compiled)", + speedup_cols=["CUDA", "Triton"], + groupby_cols=["col_major", "group_shape"], + ) + + print("Speedup over Torch (Compiled)") + print(geo_table_grouped.to_string(index=False)) diff --git a/tests/kernels/quantization/test_fp8_quant_group.py b/tests/kernels/quantization/test_fp8_quant_group.py new file mode 100644 index 000000000000..720eee62760d --- /dev/null +++ b/tests/kernels/quantization/test_fp8_quant_group.py @@ -0,0 +1,150 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for QuantFP8 Group Quantization implementation.""" + +import pytest +import torch + +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) +from vllm.platforms import current_platform + + +@pytest.mark.parametrize( + "batch_size,hidden_dim,group_size", + [ + (16, 256, 32), # Small + (64, 1024, 64), # Medium + (128, 2048, 128), # Large + (8, 513, 64), # Non-divisible (native only) + ]) +@pytest.mark.parametrize("seed", [42]) +@torch.inference_mode() +def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int, + group_size: int, seed: int) -> None: + """Test QuantFP8 group quantization with various configurations. 
+ + Tests both CUDA and native implementations, column-major scales, + and verifies consistency between implementations. + """ + current_platform.seed_everything(seed) + + x = torch.randn( + (batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8 + expected_num_groups = (hidden_dim + group_size - 1) // group_size + is_divisible = hidden_dim % group_size == 0 + + group_shape = GroupShape(1, group_size) + quant_op = QuantFP8(static=False, + group_shape=group_shape, + column_major_scales=False) + + # 1. Test native implementation (always available) + x_quant_native, scales_native = quant_op.forward_native(x.clone()) + assert x_quant_native.shape == x.shape + assert scales_native.shape == (batch_size, expected_num_groups) + + # 2. Test column-major scales configuration + quant_op_col = QuantFP8(static=False, + group_shape=group_shape, + column_major_scales=True) + _, scales_col = quant_op_col.forward_native(x.clone()) + assert scales_col.shape == (expected_num_groups, batch_size) + + # 3. Test CUDA implementation (only for divisible dimensions) + if is_divisible: + x_quant_cuda, scales_cuda = quant_op.forward_cuda(x.clone()) + assert x_quant_cuda.shape == x.shape + assert scales_cuda.shape == (batch_size, expected_num_groups) + + # Verify CUDA/native consistency + assert torch.allclose(scales_cuda, scales_native, rtol=1e-9, atol=1e-8) + + # Quantized values should mostly match + diff_count = (x_quant_cuda != x_quant_native).sum().item() + diff_ratio = diff_count / x_quant_cuda.numel() + assert diff_ratio < 0.002, f"Too many differences: {diff_ratio:.4%}" + + +@pytest.mark.parametrize("seed", [42]) +@torch.inference_mode() +def test_quantfp8_group_multidimensional(seed: int) -> None: + current_platform.seed_everything(seed) + + group_size = 64 + + # Test with 3D input + batch1, batch2, hidden_dim = 4, 8, 512 + x_3d = torch.randn( + (batch1, batch2, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8 + + group_shape = GroupShape(1, group_size) + quant_op = QuantFP8(static=False, + group_shape=group_shape, + column_major_scales=False) + + x_quant, scales = quant_op.forward_native(x_3d.clone()) + assert x_quant.shape == x_3d.shape + assert scales.shape == (batch1, batch2, hidden_dim // group_size) + + # Test column_major_scales with multi-dim + quant_op_col = QuantFP8(static=False, + group_shape=group_shape, + column_major_scales=True) + _, scales_col = quant_op_col.forward_native(x_3d.clone()) + assert scales_col.shape == (batch1, hidden_dim // group_size, batch2) + + # Test with 4D input + batch1, batch2, batch3, hidden_dim = 2, 3, 4, 256 + x_4d = torch.randn((batch1, batch2, batch3, hidden_dim), + dtype=torch.bfloat16, + device="cuda") * 8 + + x_quant_4d, scales_4d = quant_op.forward_native(x_4d.clone()) + assert x_quant_4d.shape == x_4d.shape + assert scales_4d.shape == (batch1, batch2, batch3, + hidden_dim // group_size) + + _, scales_4d_col = quant_op_col.forward_native(x_4d.clone()) + assert scales_4d_col.shape == (batch1, batch2, hidden_dim // group_size, + batch3) + + +@pytest.mark.parametrize("seed", [42]) +@torch.inference_mode() +def test_quantfp8_group_edge_cases(seed: int) -> None: + current_platform.seed_everything(seed) + + batch_size = 16 + group_size = 64 + + # Test with single group (group_size >= hidden_dim) + x_small = torch.randn( + (batch_size, 32), dtype=torch.bfloat16, device="cuda") * 8 + group_shape = GroupShape(1, group_size) + quant_op = QuantFP8(static=False, + group_shape=group_shape, + column_major_scales=False) + + x_quant_small, scales_small = 
quant_op.forward_native(x_small.clone()) + assert x_quant_small.shape == x_small.shape + assert scales_small.shape == (batch_size, 1) + + # Test with zero inputs + x_zero = torch.zeros((batch_size, 256), + dtype=torch.bfloat16, + device="cuda") + x_quant_zero, scales_zero = quant_op.forward_native(x_zero.clone()) + assert x_quant_zero.shape == x_zero.shape + assert (scales_zero > 0).all(), "Scales should be clamped to minimum" + + # Test very large values + x_large = torch.full((batch_size, 256), + 1000.0, + dtype=torch.bfloat16, + device="cuda") + x_quant_large, scales_large = quant_op.forward_native(x_large.clone()) + assert x_quant_large.shape == x_large.shape + # FP8 max is typically 448 or 224, so scales should be > 1 + assert (scales_large > 1.0).all(), "Large values should have scales > 1" diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index eb3e14180ecf..42d75aa3f1eb 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -32,9 +32,11 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP) from vllm.model_executor.layers.fused_moe.utils import ( - _resize_cache, moe_kernel_quantize_input, per_token_group_quant_fp8) + _resize_cache, moe_kernel_quantize_input) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( calculate_tile_tokens_dim) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8) from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( dequant_mxfp4) from vllm.platforms import current_platform diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py index e1a9bdde9334..31182f40b48f 100644 --- a/vllm/model_executor/layers/quantization/input_quant_fp8.py +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -23,28 +23,39 @@ @CustomOp.register("quant_fp8") class QuantFP8(CustomOp): """ - Quantize input tensor to per-tensor or per-token FP8. + Quantize input tensor to FP8 (per-tensor, per-token, or per-group). This CustomOp supports both static and dynamic quantization. """ def __init__(self, static: bool, group_shape: GroupShape, - num_token_padding: Optional[int] = None): + num_token_padding: Optional[int] = None, + column_major_scales: bool = False): """ - :param static: static or dynamic quantization - :param group_shape: quantization group shape (PER_TOKEN or PER_TENSOR) - :param num_token_padding: Pad the token dimension of output to this size + :param group_shape: quantization group shape (PER_TOKEN, PER_TENSOR, + or arbitrary block size) + :param num_token_padding: Pad the token dimension of output to this + size + :param column_major_scales: For group quantization, output scales in + column major format """ super().__init__() - self.num_token_padding = num_token_padding - assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR} - assert not static or group_shape == GroupShape.PER_TENSOR, \ - "Only per-tensor scales supported for static quantization." 
self.static = static self.group_shape = group_shape - self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN + self.num_token_padding = num_token_padding + self.column_major_scales = column_major_scales + + self.is_group_quant = group_shape.is_per_group() + if self.is_group_quant: + assert not static, "Group quantization only supports dynamic mode" + self.group_size = group_shape.col + else: + assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR} + assert not static or group_shape == GroupShape.PER_TENSOR, \ + "Only per-tensor scales supported for static quantization." + self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN def forward_cuda( self, @@ -52,11 +63,19 @@ def forward_cuda( scale: Optional[torch.Tensor] = None, scale_ub: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: + if self.is_group_quant: + assert scale is None, "Group quantization is always dynamic" + from vllm.model_executor.layers.quantization.utils import fp8_utils + return fp8_utils.per_token_group_quant_fp8( + x, + group_size=self.group_size, + column_major_scales=self.column_major_scales, + dtype=_FP8_DTYPE) + assert (scale is not None) == self.static assert scale_ub is None or (not self.static and self.group_shape == GroupShape.PER_TOKEN and scale_ub.numel() == 1) - return ops.scaled_fp8_quant( x, scale, @@ -70,6 +89,10 @@ def forward_native( scale: Optional[torch.Tensor] = None, scale_ub: Optional[torch.Tensor] = None, ): + if self.is_group_quant: + assert scale is None, "Group quantization is always dynamic" + return self._quantize_group_native(x) + assert (scale is not None) == self.static assert scale_ub is None or (not self.static and self.group_shape == GroupShape.PER_TOKEN @@ -84,8 +107,7 @@ def forward_native( else: x_max = x.abs().max().unsqueeze(-1).to(torch.float32) - scale = x_max / _FP8_MAX - scale = scale.clamp(min=_FP8_MIN_SCALING_FACTOR) + scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR) # Even for dynamic per-token scales, # reciprocal performs slightly better than division @@ -101,3 +123,34 @@ def forward_native( out = F.pad(out, (0, 0, 0, padding), "constant", 0.0) return out, scale + + def _quantize_group_native( + self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + orig_shape = x.shape + hidden_dim = x.shape[-1] + num_groups = (hidden_dim + self.group_size - 1) // self.group_size + padded_dim = num_groups * self.group_size + + if padded_dim != hidden_dim: + padding = padded_dim - hidden_dim + x = F.pad(x, (0, padding), mode='constant', value=0.0) + + x_grouped = x.view(-1, num_groups, self.group_size) + absmax = x_grouped.abs().max(dim=-1, keepdim=True)[0].float() + scales = (absmax / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR) + + x_scaled = x_grouped / scales + x_quant = x_scaled.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) + + x_quant = x_quant.view(-1, padded_dim) + if padded_dim != hidden_dim: + x_quant = x_quant[..., :hidden_dim] + x_quant = x_quant.view(orig_shape) + + scales = scales.squeeze(-1) + scales = scales.reshape(orig_shape[:-1] + (num_groups, )) + + if self.column_major_scales: + scales = scales.transpose(-2, -1).contiguous() + + return x_quant, scales diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index f4ff875adb21..5339c6043cc1 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -34,6 +34,15 @@ class 
GroupShape(_GroupShape): PER_TENSOR: ClassVar['GroupShape'] PER_TOKEN: ClassVar['GroupShape'] + def is_per_tensor(self) -> bool: + return self.row == -1 and self.col == -1 + + def is_per_token(self) -> bool: + return self.row == 1 and self.col == -1 + + def is_per_group(self) -> bool: + return self.row == 1 and self.col >= 1 + GroupShape.PER_TENSOR = GroupShape(-1, -1) GroupShape.PER_TOKEN = GroupShape(1, -1)
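
Not part of the patch itself: a minimal usage sketch of the per-group `QuantFP8` path this diff introduces, mirroring the API exercised in `test_fp8_quant_group.py`. The tensor shape (16x1024) and group size (128) are illustrative only; a CUDA device is assumed for the `forward_cuda` call.

```python
# Illustrative sketch, assuming the QuantFP8 group-quant API added in this patch.
import torch

from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

x = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda")

# Group quantization is dynamic-only; GroupShape(1, 128) selects the per-group path.
quant_op = QuantFP8(static=False,
                    group_shape=GroupShape(1, 128),
                    column_major_scales=False)

# Native (torch) path: FP8 output plus one scale per 128-wide group along the last dim.
x_q, scales = quant_op.forward_native(x)
assert x_q.shape == x.shape
assert scales.shape == (16, 1024 // 128)

# CUDA path dispatches to per_token_group_quant_fp8 for per-group shapes and
# should agree with the native path when the hidden dim divides the group size.
x_q_cuda, scales_cuda = quant_op.forward_cuda(x)
```

The `is_per_group()` helper added to `GroupShape` above is what drives this dispatch inside `QuantFP8`.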