From c9ca1024a13246b488008f6de29d8b05930cbf48 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 11 Sep 2025 13:38:11 +0000 Subject: [PATCH 01/39] Move apply_w8a8_block_fp8_linear to an op class Signed-off-by: ElizaWszola --- .../model_executor/layers/quantization/fp8.py | 12 +- .../layers/quantization/utils/fp8_utils.py | 167 +++++++++--------- 2 files changed, 87 insertions(+), 92 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 65e0b7062153..6b38eb97ad0a 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -30,7 +30,8 @@ register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, select_cutlass_fp8_gemm_impl, swap_w13_to_w31) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace) + W8A8BlockFp8LinearOp, get_col_major_tma_aligned_tensor, + requant_weight_ue8m0_inplace) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin, prepare_moe_fp8_layer_for_marlin) @@ -249,6 +250,11 @@ def __init__(self, quant_config: Fp8Config): act_quant_static=self.act_q_static, act_quant_group_shape=self.act_q_group_shape) + self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( + self.cutlass_block_fp8_supported, + self.use_aiter_and_is_supported, + ) + def create_weights( self, layer: torch.nn.Module, @@ -480,15 +486,13 @@ def apply(self, if self.block_quant: assert self.quant_config.weight_block_size is not None - return torch.ops.vllm.apply_w8a8_block_fp8_linear( + return self.w8a8_block_fp8_linear.apply( input=x, weight=layer.weight, block_size=self.quant_config.weight_block_size, weight_scale=layer.weight_scale_inv, input_scale=layer.input_scale, bias=bias, - cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, - use_aiter_and_is_supported=self.use_aiter_and_is_supported, ) return self.fp8_linear.apply(input=x, diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 7b324dce3c36..394b46d50153 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -109,108 +109,99 @@ def dispatch_w8a8_blockscale_func( return w8a8_block_fp8_matmul -# TODO fix ROCm->Triton custom path: -# https://github.com/vllm-project/vllm/issues/14397 -def apply_w8a8_block_fp8_linear( - input: torch.Tensor, - weight: torch.Tensor, - block_size: list[int], - weight_scale: torch.Tensor, - input_scale: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, - cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, - use_aiter_and_is_supported: bool = False, -) -> torch.Tensor: - assert input_scale is None - # View input as 2D matrix for fp8 methods - input_2d = input.view(-1, input.shape[-1]) - output_shape = [*input.shape[:-1], weight.shape[0]] - output_dtype = input.dtype - - if should_use_deepgemm_for_fp8_linear(output_dtype, weight): +class W8A8BlockFp8LinearOp: + """ + This class executes a Blocked FP8 linear layer using cutlass if supported and + torch.scaled_mm otherwise. 
+ """ + # TODO where to put + # cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, + # use_aiter_and_is_supported: bool = False, + def __init__( + self, + cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, + use_aiter_and_is_supported: bool = False, + ): + self.cutlass_block_fp8_supported = cutlass_block_fp8_supported + self.use_aiter_and_is_supported = use_aiter_and_is_supported + + def apply( + self, + input: torch.Tensor, + weight: torch.Tensor, + block_size: list[int], + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert input_scale is None + # View input as 2D matrix for fp8 methods input_2d = input.view(-1, input.shape[-1]) output_shape = [*input.shape[:-1], weight.shape[0]] + output_dtype = input.dtype - q_input, x_scale = per_token_group_quant_fp8( - input_2d, - block_size[1], - column_major_scales=True, - ) - - # ensure DeepGEMM-backed custom op is registered before use - import vllm.model_executor.layers.quantization.deepgemm # noqa: F401 + if should_use_deepgemm_for_fp8_linear(output_dtype, weight): - output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm( - q_input, - weight, - x_scale, - weight_scale, - block_size, - output_dtype=output_dtype) - if bias is not None: - output += bias - return output.to(dtype=output_dtype).view(*output_shape) + input_2d = input.view(-1, input.shape[-1]) + output_shape = [*input.shape[:-1], weight.shape[0]] - if current_platform.is_cuda(): - if current_platform.has_device_capability(100): + q_input, x_scale = per_token_group_quant_fp8( + input_2d, + block_size[1], + column_major_scales=True, + ) - use_cutlass = cutlass_block_fp8_supported and ( - cdiv(weight.shape[0], 128) == weight_scale.shape[0] - and cdiv(weight.shape[1], 128) == weight_scale.shape[1]) + # ensure DeepGEMM-backed custom op is registered before use + import vllm.model_executor.layers.quantization.deepgemm # noqa: F401 + + output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm( + q_input, + weight, + x_scale, + weight_scale, + block_size, + output_dtype=output_dtype) + if bias is not None: + output += bias + return output.to(dtype=output_dtype).view(*output_shape) + + if current_platform.is_cuda(): + if current_platform.has_device_capability(100): + + use_cutlass = self.cutlass_block_fp8_supported and ( + cdiv(weight.shape[0], 128) == weight_scale.shape[0] + and cdiv(weight.shape[1], 128) == weight_scale.shape[1]) + else: + # TODO: update this after switching to public sm90 block scale gemm + # as it also supports weight.shape % 128 != 0 + use_cutlass = self.cutlass_block_fp8_supported and ( + weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0) else: - # TODO: update this after switching to public sm90 block scale gemm - # as it also supports weight.shape % 128 != 0 - use_cutlass = cutlass_block_fp8_supported and ( - weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0) - else: - use_cutlass = False + use_cutlass = False - w8a8_blockscale_func = dispatch_w8a8_blockscale_func( - use_cutlass, use_aiter_and_is_supported) - if use_cutlass: - q_input, x_scale = per_token_group_quant_fp8( - input_2d, block_size[1], column_major_scales=use_cutlass) - output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, - block_size, input.dtype) - - else: - if use_aiter_and_is_supported: - q_input, x_scale = aiter_per1x128_quant( - input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8) - else: + w8a8_blockscale_func = dispatch_w8a8_blockscale_func( + 
use_cutlass, self.use_aiter_and_is_supported) + if use_cutlass: q_input, x_scale = per_token_group_quant_fp8( input_2d, block_size[1], column_major_scales=use_cutlass) + output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, + block_size, input.dtype) - output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, - block_size, input.dtype) - - if bias is not None: - output = output + bias - return output.to(dtype=input.dtype).view(*output_shape) - - -def apply_w8a8_block_fp8_linear_fake( - input: torch.Tensor, - weight: torch.Tensor, - block_size: list[int], - weight_scale: torch.Tensor, - input_scale: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, - cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, - use_aiter_and_is_supported: bool = False, -) -> torch.Tensor: - output_shape = [*input.shape[:-1], weight.shape[0]] - return torch.empty(output_shape, dtype=input.dtype, device=input.device) + else: + if self.use_aiter_and_is_supported: + q_input, x_scale = aiter_per1x128_quant( + input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8) + else: + q_input, x_scale = per_token_group_quant_fp8( + input_2d, block_size[1], column_major_scales=use_cutlass) + output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, + block_size, input.dtype) -if not current_platform.is_cpu(): - direct_register_custom_op( - op_name="apply_w8a8_block_fp8_linear", - op_func=apply_w8a8_block_fp8_linear, - mutates_args=[], - fake_impl=apply_w8a8_block_fp8_linear_fake, - ) + if bias is not None: + output = output + bias + return output.to(dtype=input.dtype).view(*output_shape) def input_to_float8( From eef43494adb2062b8e5634e0583db81d73096a5c Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 11 Sep 2025 13:43:59 +0000 Subject: [PATCH 02/39] Remove TODO, bring back old one Signed-off-by: ElizaWszola --- vllm/model_executor/layers/quantization/utils/fp8_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 394b46d50153..6bd703aad339 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -109,15 +109,14 @@ def dispatch_w8a8_blockscale_func( return w8a8_block_fp8_matmul +# TODO fix ROCm->Triton custom path: +# https://github.com/vllm-project/vllm/issues/14397 class W8A8BlockFp8LinearOp: """ This class executes a Blocked FP8 linear layer using cutlass if supported and torch.scaled_mm otherwise. 
""" - # TODO where to put - # cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, - # use_aiter_and_is_supported: bool = False, def __init__( self, cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, From dd53183e381037abbe36cf64daa0853ce9d10b7e Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 11 Sep 2025 15:02:43 +0000 Subject: [PATCH 03/39] CUDA graphs fix Signed-off-by: ElizaWszola --- .../model_executor/layers/quantization/fp8.py | 13 ++++++++--- .../layers/quantization/utils/fp8_utils.py | 23 +++++++++++++++---- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 6b38eb97ad0a..a0f5596e11c4 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -49,7 +49,9 @@ from vllm.platforms import current_platform from vllm.scalar_type import scalar_types from vllm.utils import has_deep_gemm -from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported +from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used, + is_deep_gemm_supported, + should_use_deepgemm_for_fp8_linear) from vllm.utils.flashinfer import has_flashinfer_moe if TYPE_CHECKING: @@ -251,8 +253,10 @@ def __init__(self, quant_config: Fp8Config): act_quant_group_shape=self.act_q_group_shape) self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( - self.cutlass_block_fp8_supported, - self.use_aiter_and_is_supported, + cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, + use_aiter_and_is_supported=self.use_aiter_and_is_supported, + ue8m0_deepgemm_supported=is_deep_gemm_e8m0_used(), + is_blackwell=current_platform.has_device_capability(100), ) def create_weights( @@ -365,6 +369,9 @@ def create_weights( else: layer.register_parameter("input_scale", None) + self.w8a8_block_fp8_linear.set_should_use_deepgemm( + should_use_deepgemm_for_fp8_linear(self.out_dtype, weight)) + def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor: # Pad the weight tensor. 
This is an optimization on ROCm platform, which # can benefit from tensors located far enough from one another in memory diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 6bd703aad339..2e6a3157a1a1 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -121,9 +121,20 @@ def __init__( self, cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, use_aiter_and_is_supported: bool = False, + ue8m0_deepgemm_supported: bool = False, + is_blackwell: bool = False, ): self.cutlass_block_fp8_supported = cutlass_block_fp8_supported self.use_aiter_and_is_supported = use_aiter_and_is_supported + self.ue8m0_deepgemm_supported = ue8m0_deepgemm_supported + self.is_blackwell = is_blackwell + self.should_use_deepgemm = False + + def set_should_use_deepgemm( + self, + should_use_deepgemm: bool, + ): + self.should_use_deepgemm = should_use_deepgemm def apply( self, @@ -140,7 +151,7 @@ def apply( output_shape = [*input.shape[:-1], weight.shape[0]] output_dtype = input.dtype - if should_use_deepgemm_for_fp8_linear(output_dtype, weight): + if self.should_use_deepgemm: input_2d = input.view(-1, input.shape[-1]) output_shape = [*input.shape[:-1], weight.shape[0]] @@ -149,6 +160,7 @@ def apply( input_2d, block_size[1], column_major_scales=True, + use_ue8m0=self.ue8m0_deepgemm_supported, ) # ensure DeepGEMM-backed custom op is registered before use @@ -166,8 +178,7 @@ def apply( return output.to(dtype=output_dtype).view(*output_shape) if current_platform.is_cuda(): - if current_platform.has_device_capability(100): - + if self.is_blackwell: use_cutlass = self.cutlass_block_fp8_supported and ( cdiv(weight.shape[0], 128) == weight_scale.shape[0] and cdiv(weight.shape[1], 128) == weight_scale.shape[1]) @@ -183,7 +194,8 @@ def apply( use_cutlass, self.use_aiter_and_is_supported) if use_cutlass: q_input, x_scale = per_token_group_quant_fp8( - input_2d, block_size[1], column_major_scales=use_cutlass) + input_2d, block_size[1], column_major_scales=use_cutlass, + use_ue8m0=self.ue8m0_deepgemm_supported) output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, block_size, input.dtype) @@ -193,7 +205,8 @@ def apply( input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8) else: q_input, x_scale = per_token_group_quant_fp8( - input_2d, block_size[1], column_major_scales=use_cutlass) + input_2d, block_size[1], column_major_scales=use_cutlass, + use_ue8m0=self.ue8m0_deepgemm_supported) output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, block_size, input.dtype) From bb248819ce5e0b84b147d97adae552a82965a858 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 11 Sep 2025 15:13:27 +0000 Subject: [PATCH 04/39] Clean up Signed-off-by: ElizaWszola --- vllm/model_executor/layers/quantization/fp8.py | 7 ++----- .../layers/quantization/utils/fp8_utils.py | 12 ++++-------- vllm/utils/deep_gemm.py | 5 +++-- 3 files changed, 9 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index a0f5596e11c4..69a78119a0e0 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -50,8 +50,7 @@ from vllm.scalar_type import scalar_types from vllm.utils import has_deep_gemm from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used, - is_deep_gemm_supported, - should_use_deepgemm_for_fp8_linear) + 
is_deep_gemm_supported) from vllm.utils.flashinfer import has_flashinfer_moe if TYPE_CHECKING: @@ -255,6 +254,7 @@ def __init__(self, quant_config: Fp8Config): self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, use_aiter_and_is_supported=self.use_aiter_and_is_supported, + is_deep_gemm_supported=is_deep_gemm_supported(), ue8m0_deepgemm_supported=is_deep_gemm_e8m0_used(), is_blackwell=current_platform.has_device_capability(100), ) @@ -369,9 +369,6 @@ def create_weights( else: layer.register_parameter("input_scale", None) - self.w8a8_block_fp8_linear.set_should_use_deepgemm( - should_use_deepgemm_for_fp8_linear(self.out_dtype, weight)) - def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor: # Pad the weight tensor. This is an optimization on ROCm platform, which # can benefit from tensors located far enough from one another in memory diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 2e6a3157a1a1..9ea1ac71b667 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -121,20 +121,15 @@ def __init__( self, cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, use_aiter_and_is_supported: bool = False, + is_deep_gemm_supported: bool = False, ue8m0_deepgemm_supported: bool = False, is_blackwell: bool = False, ): self.cutlass_block_fp8_supported = cutlass_block_fp8_supported self.use_aiter_and_is_supported = use_aiter_and_is_supported + self.is_deep_gemm_supported = is_deep_gemm_supported self.ue8m0_deepgemm_supported = ue8m0_deepgemm_supported self.is_blackwell = is_blackwell - self.should_use_deepgemm = False - - def set_should_use_deepgemm( - self, - should_use_deepgemm: bool, - ): - self.should_use_deepgemm = should_use_deepgemm def apply( self, @@ -151,7 +146,8 @@ def apply( output_shape = [*input.shape[:-1], weight.shape[0]] output_dtype = input.dtype - if self.should_use_deepgemm: + if should_use_deepgemm_for_fp8_linear(self.is_deep_gemm_supported, + output_dtype, weight): input_2d = input.view(-1, input.shape[-1]) output_shape = [*input.shape[:-1], weight.shape[0]] diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 90cdd396209c..3a5c6af5eaee 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -194,9 +194,10 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): return 1 - sim -def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype, +def should_use_deepgemm_for_fp8_linear(is_deep_gemm_supported: bool, + output_dtype: torch.dtype, weight: torch.Tensor): - return (is_deep_gemm_supported() and output_dtype == torch.bfloat16 + return (is_deep_gemm_supported and output_dtype == torch.bfloat16 and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0) From 1ba47cd51b821b26022f6857f9bb2a4531206632 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 11 Sep 2025 15:39:32 +0000 Subject: [PATCH 05/39] Create linear op objects conditionally, move some arch checks to blocked op constructor Signed-off-by: ElizaWszola --- .../model_executor/layers/quantization/fp8.py | 20 +++++++++---------- .../layers/quantization/utils/fp8_utils.py | 14 ++++++------- 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 69a78119a0e0..24d002b689c9 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ 
b/vllm/model_executor/layers/quantization/fp8.py @@ -247,17 +247,15 @@ def __init__(self, quant_config: Fp8Config): else: self.act_q_group_shape = GroupShape.PER_TENSOR - self.fp8_linear = Fp8LinearOp( - act_quant_static=self.act_q_static, - act_quant_group_shape=self.act_q_group_shape) - - self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( - cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, - use_aiter_and_is_supported=self.use_aiter_and_is_supported, - is_deep_gemm_supported=is_deep_gemm_supported(), - ue8m0_deepgemm_supported=is_deep_gemm_e8m0_used(), - is_blackwell=current_platform.has_device_capability(100), - ) + if self.block_quant: + self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( + cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, + use_aiter_and_is_supported=self.use_aiter_and_is_supported, + ) + else: + self.fp8_linear = Fp8LinearOp( + act_quant_static=self.act_q_static, + act_quant_group_shape=self.act_q_group_shape) def create_weights( self, diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 9ea1ac71b667..6103fe0cb9bb 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -21,6 +21,7 @@ from vllm.triton_utils import tl, triton from vllm.utils import cdiv, direct_register_custom_op from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used, + is_deep_gemm_supported, should_use_deepgemm_for_fp8_linear) logger = init_logger(__name__) @@ -113,23 +114,20 @@ def dispatch_w8a8_blockscale_func( # https://github.com/vllm-project/vllm/issues/14397 class W8A8BlockFp8LinearOp: """ - This class executes a Blocked FP8 linear layer using cutlass if supported and - torch.scaled_mm otherwise. + This class executes a Blocked FP8 linear layer using cutlass if supported + and torch.scaled_mm otherwise. 
""" def __init__( self, cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, use_aiter_and_is_supported: bool = False, - is_deep_gemm_supported: bool = False, - ue8m0_deepgemm_supported: bool = False, - is_blackwell: bool = False, ): self.cutlass_block_fp8_supported = cutlass_block_fp8_supported self.use_aiter_and_is_supported = use_aiter_and_is_supported - self.is_deep_gemm_supported = is_deep_gemm_supported - self.ue8m0_deepgemm_supported = ue8m0_deepgemm_supported - self.is_blackwell = is_blackwell + self.is_deep_gemm_supported = is_deep_gemm_supported() + self.ue8m0_deepgemm_supported = is_deep_gemm_e8m0_used() + self.is_blackwell = current_platform.has_device_capability(100) def apply( self, From 02793b96943589e5e8bc0c65f61d158ef97ddae7 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 11 Sep 2025 17:00:18 +0000 Subject: [PATCH 06/39] format Signed-off-by: ElizaWszola --- .../model_executor/layers/quantization/fp8.py | 7 +++--- .../layers/quantization/utils/fp8_utils.py | 22 ++++++++++++------- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 24d002b689c9..f1924100fbf8 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -49,8 +49,7 @@ from vllm.platforms import current_platform from vllm.scalar_type import scalar_types from vllm.utils import has_deep_gemm -from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used, - is_deep_gemm_supported) +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported from vllm.utils.flashinfer import has_flashinfer_moe if TYPE_CHECKING: @@ -254,8 +253,8 @@ def __init__(self, quant_config: Fp8Config): ) else: self.fp8_linear = Fp8LinearOp( - act_quant_static=self.act_q_static, - act_quant_group_shape=self.act_q_group_shape) + act_quant_static=self.act_q_static, + act_quant_group_shape=self.act_q_group_shape) def create_weights( self, diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 6103fe0cb9bb..9e7662dcb066 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -177,8 +177,8 @@ def apply( cdiv(weight.shape[0], 128) == weight_scale.shape[0] and cdiv(weight.shape[1], 128) == weight_scale.shape[1]) else: - # TODO: update this after switching to public sm90 block scale gemm - # as it also supports weight.shape % 128 != 0 + # TODO: update this after switching to public sm90 block scale + # gemm as it also supports weight.shape % 128 != 0 use_cutlass = self.cutlass_block_fp8_supported and ( weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0) else: @@ -188,10 +188,13 @@ def apply( use_cutlass, self.use_aiter_and_is_supported) if use_cutlass: q_input, x_scale = per_token_group_quant_fp8( - input_2d, block_size[1], column_major_scales=use_cutlass, + input_2d, + block_size[1], + column_major_scales=use_cutlass, use_ue8m0=self.ue8m0_deepgemm_supported) - output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, - block_size, input.dtype) + output = w8a8_blockscale_func(q_input, weight, x_scale, + weight_scale, block_size, + input.dtype) else: if self.use_aiter_and_is_supported: @@ -199,11 +202,14 @@ def apply( input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8) else: q_input, x_scale = per_token_group_quant_fp8( - input_2d, block_size[1], 
column_major_scales=use_cutlass, + input_2d, + block_size[1], + column_major_scales=use_cutlass, use_ue8m0=self.ue8m0_deepgemm_supported) - output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, - block_size, input.dtype) + output = w8a8_blockscale_func(q_input, weight, x_scale, + weight_scale, block_size, + input.dtype) if bias is not None: output = output + bias From b72c9f2d2290d6e72e3413ec3ac91a0758541f03 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 12 Sep 2025 04:48:23 +0000 Subject: [PATCH 07/39] clean up repetitive code Signed-off-by: ElizaWszola --- vllm/model_executor/layers/quantization/utils/fp8_utils.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 9e7662dcb066..f5ce4ce1ac77 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -146,10 +146,6 @@ def apply( if should_use_deepgemm_for_fp8_linear(self.is_deep_gemm_supported, output_dtype, weight): - - input_2d = input.view(-1, input.shape[-1]) - output_shape = [*input.shape[:-1], weight.shape[0]] - q_input, x_scale = per_token_group_quant_fp8( input_2d, block_size[1], From d51f35c1a948329fee1d034b89851d650011a4fd Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 12 Sep 2025 12:40:11 +0000 Subject: [PATCH 08/39] More aggressive dispatch of blockscale ops Signed-off-by: ElizaWszola --- .../layers/quantization/utils/fp8_utils.py | 131 ++++++++++++------ 1 file changed, 87 insertions(+), 44 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index f5ce4ce1ac77..5307e73d47fd 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -93,7 +93,7 @@ def rocm_aiter_gemm_w8a8_blockscale_fake( aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128) -def dispatch_w8a8_blockscale_func( +def dispatch_w8a8_blockscale_op( use_cutlass: bool, use_aiter_and_is_supported: bool ) -> Callable[[ torch.Tensor, @@ -146,26 +146,12 @@ def apply( if should_use_deepgemm_for_fp8_linear(self.is_deep_gemm_supported, output_dtype, weight): - q_input, x_scale = per_token_group_quant_fp8( - input_2d, - block_size[1], - column_major_scales=True, - use_ue8m0=self.ue8m0_deepgemm_supported, - ) - - # ensure DeepGEMM-backed custom op is registered before use - import vllm.model_executor.layers.quantization.deepgemm # noqa: F401 + output = self._run_deepgemm(input, weight, block_size, + weight_scale, input_scale, bias) - output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm( - q_input, - weight, - x_scale, - weight_scale, - block_size, - output_dtype=output_dtype) if bias is not None: - output += bias - return output.to(dtype=output_dtype).view(*output_shape) + output = output + bias + return output.to(dtype=input.dtype).view(*output_shape) if current_platform.is_cuda(): if self.is_blackwell: @@ -180,37 +166,94 @@ def apply( else: use_cutlass = False - w8a8_blockscale_func = dispatch_w8a8_blockscale_func( + w8a8_blockscale_op = self._dispatch_w8a8_blockscale_op( use_cutlass, self.use_aiter_and_is_supported) - if use_cutlass: - q_input, x_scale = per_token_group_quant_fp8( - input_2d, - block_size[1], - column_major_scales=use_cutlass, - use_ue8m0=self.ue8m0_deepgemm_supported) - output = w8a8_blockscale_func(q_input, weight, x_scale, - 
weight_scale, block_size, - input.dtype) - - else: - if self.use_aiter_and_is_supported: - q_input, x_scale = aiter_per1x128_quant( - input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8) - else: - q_input, x_scale = per_token_group_quant_fp8( - input_2d, - block_size[1], - column_major_scales=use_cutlass, - use_ue8m0=self.ue8m0_deepgemm_supported) - - output = w8a8_blockscale_func(q_input, weight, x_scale, - weight_scale, block_size, - input.dtype) + output = w8a8_blockscale_op(input_2d, weight, block_size, weight_scale) if bias is not None: output = output + bias return output.to(dtype=input.dtype).view(*output_shape) + def _run_deepgemm( + self, + input_2d: torch.Tensor, + weight: torch.Tensor, + block_size: list[int], + weight_scale: torch.Tensor, + ) -> torch.Tensor: + # ensure DeepGEMM-backed custom op is registered before use + import vllm.model_executor.layers.quantization.deepgemm # noqa: F401 + + q_input, x_scale = per_token_group_quant_fp8( + input_2d, + block_size[1], + column_major_scales=True, + use_ue8m0=self.ue8m0_deepgemm_supported, + ) + return torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm( + q_input, + weight, + x_scale, + weight_scale, + block_size, + output_dtype=input_2d.dtype) + + def _run_cutlass( + self, + input_2d: torch.Tensor, + weight: torch.Tensor, + block_size: list[int], + weight_scale: torch.Tensor, + ) -> torch.Tensor: + q_input, x_scale = per_token_group_quant_fp8(input_2d, + block_size[1], + column_major_scales=True, + use_ue8m0=False) + return cutlass_scaled_mm(q_input, weight, x_scale, weight_scale, + block_size, input_2d.dtype) + + def _run_aiter( + self, + input_2d: torch.Tensor, + weight: torch.Tensor, + block_size: list[int], + weight_scale: torch.Tensor, + ) -> torch.Tensor: + q_input, x_scale = aiter_per1x128_quant( + input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8) + return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale( + q_input, weight, x_scale, weight_scale, block_size, input_2d.dtype) + + def _w8a8_block_fp8_matmul( + self, + input_2d: torch.Tensor, + weight: torch.Tensor, + block_size: list[int], + weight_scale: torch.Tensor, + ) -> torch.Tensor: + q_input, x_scale = per_token_group_quant_fp8(input_2d, + block_size[1], + column_major_scales=False, + use_ue8m0=False) + return w8a8_block_fp8_matmul(q_input, weight, x_scale, weight_scale, + block_size, input_2d.dtype) + + def _dispatch_w8a8_blockscale_op( + self, use_cutlass: bool, use_aiter_and_is_supported: bool + ) -> Callable[[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + list[int], + torch.dtype, + ], torch.Tensor]: + if use_cutlass: + return self._run_cutlass + if (use_aiter_and_is_supported): + return self._run_aiter + return self._run_w8a8_block_fp8_matmul + def input_to_float8( x: torch.Tensor, From a6ae6893988db66871ae076deb145a30dbe0d5aa Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 12 Sep 2025 12:57:48 +0000 Subject: [PATCH 09/39] fix Signed-off-by: ElizaWszola --- vllm/model_executor/layers/quantization/utils/fp8_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 5307e73d47fd..60cda6e8d3b6 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -93,7 +93,7 @@ def rocm_aiter_gemm_w8a8_blockscale_fake( aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128) -def dispatch_w8a8_blockscale_op( +def 
dispatch_w8a8_blockscale_func( use_cutlass: bool, use_aiter_and_is_supported: bool ) -> Callable[[ torch.Tensor, From 3238ff687c908ce43fe9e39cf0ec2b7d15faeca2 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 12 Sep 2025 09:51:38 -0400 Subject: [PATCH 10/39] Deep_gemm fix Signed-off-by: ElizaWszola --- vllm/model_executor/layers/quantization/utils/fp8_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 60cda6e8d3b6..a8596c68a150 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -147,7 +147,7 @@ def apply( if should_use_deepgemm_for_fp8_linear(self.is_deep_gemm_supported, output_dtype, weight): output = self._run_deepgemm(input, weight, block_size, - weight_scale, input_scale, bias) + weight_scale) if bias is not None: output = output + bias From 9b09b60f2967158fab4cc4d5c3f5b9cf7db98182 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 12 Sep 2025 10:52:23 -0400 Subject: [PATCH 11/39] Post-merge fixes, better dispatch Signed-off-by: ElizaWszola --- .../layers/quantization/utils/fp8_utils.py | 55 +++++++++++-------- 1 file changed, 31 insertions(+), 24 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index ce03edb69ff6..5e35d833aeb4 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -40,15 +40,17 @@ def cutlass_scaled_mm( Bs: torch.Tensor, block_size: list[int], output_dtype: torch.dtype = torch.float16, + is_hopper: Optional[bool] = None, ) -> torch.Tensor: + if is_hopper is None: + is_hopper = current_platform.is_device_capability(90) return ops.cutlass_scaled_mm( A, B.T, out_dtype=output_dtype, scale_a=As, # SM90 block FP8 requires row-major scale_b, which we do ahead of time - scale_b=Bs if block_size is not None - and current_platform.is_device_capability(90) else Bs.T) + scale_b=Bs if block_size is not None and is_hopper else Bs.T) def rocm_aiter_gemm_w8a8_blockscale_impl( @@ -126,13 +128,11 @@ def __init__( cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, use_aiter_and_is_supported: bool = False, ): - self.cutlass_block_fp8_supported = cutlass_block_fp8_supported - self.use_aiter_and_is_supported = use_aiter_and_is_supported self.is_deep_gemm_supported = is_deep_gemm_supported() self.ue8m0_deepgemm_supported = is_deep_gemm_e8m0_used() - self.is_hopper = current_platform.has_device_capability(90) + self.is_hopper = current_platform.is_device_capability(90) self.w8a8_blockscale_op = self._dispatch_w8a8_blockscale_op( - self.cutlass_block_fp8_supported, self.use_aiter_and_is_supported) + cutlass_block_fp8_supported, use_aiter_and_is_supported) def apply( self, @@ -157,22 +157,8 @@ def apply( output = output + bias return output.to(dtype=input.dtype).view(*output_shape) - num_pad = 0 - if cutlass_block_fp8_supported and self.is_hopper: - # pad first dimension to be divisible by 4 due to - # cutlass blockwise gemm limitation for hopper - num_pad = 4 - (input_2d.shape[0] % 4) - if num_pad > 0: - input_2d = torch.nn.functional.pad(input_2d, - (0, 0, 0, num_pad), - "constant", 0) - output = self.w8a8_blockscale_op(input_2d, weight, block_size, weight_scale) - - if num_pad > 0: - output = output[:-num_pad] - if bias is not None: output = output + bias 
return output.to(dtype=input.dtype).view(*output_shape) @@ -209,7 +195,7 @@ def _run_cutlass( weight_scale: torch.Tensor, ) -> torch.Tensor: num_pad = 0 - if current_platform.is_device_capability(90): + if self.is_hopper: # pad first dimension to be divisible by 4 due to # cutlass blockwise gemm limitation for hopper num_pad = 4 - (input_2d.shape[0] % 4) @@ -221,8 +207,11 @@ def _run_cutlass( block_size[1], column_major_scales=True, use_ue8m0=False) - return cutlass_scaled_mm(q_input, weight, x_scale, weight_scale, - block_size, input_2d.dtype) + output = cutlass_scaled_mm(q_input, weight, x_scale, weight_scale, + block_size, input_2d.dtype, self.is_hopper) + if num_pad > 0: + output = output[:-num_pad] + return output def _run_aiter( self, @@ -236,7 +225,7 @@ def _run_aiter( return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale( q_input, weight, x_scale, weight_scale, block_size, input_2d.dtype) - def _w8a8_block_fp8_matmul( + def _run_w8a8_block_fp8_matmul( self, input_2d: torch.Tensor, weight: torch.Tensor, @@ -250,6 +239,24 @@ def _w8a8_block_fp8_matmul( return w8a8_block_fp8_matmul(q_input, weight, x_scale, weight_scale, block_size, input_2d.dtype) + def _dispatch_w8a8_blockscale_op( + self, + use_cutlass: bool, + use_aiter_and_is_supported: bool, + ) -> Callable[[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + list[int], + torch.dtype, + ], torch.Tensor]: + if use_cutlass: + return self._run_cutlass + if use_aiter_and_is_supported: + return self._run_aiter + return self._run_w8a8_block_fp8_matmul + def input_to_float8( x: torch.Tensor, From e6b0028764a0bfd32e1833668c0cd9373f1f36fd Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 12 Sep 2025 11:25:04 -0400 Subject: [PATCH 12/39] small fixes Signed-off-by: ElizaWszola --- vllm/model_executor/layers/quantization/fp8.py | 2 +- vllm/model_executor/layers/quantization/utils/fp8_utils.py | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 8ed98296d639..0d38ce00ed13 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -473,7 +473,7 @@ def process_weights_after_loading(self, layer: Module) -> None: if (self.block_quant and current_platform.is_device_capability(90) and self.cutlass_block_fp8_supported and not should_use_deepgemm_for_fp8_linear( - torch.bfloat16, layer.weight)): + is_deep_gemm_supported(), torch.bfloat16, layer.weight)): layer.weight_scale_inv = Parameter( layer.weight_scale_inv.data.T.contiguous(), requires_grad=False) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 5e35d833aeb4..f4b12bde1f8a 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -148,7 +148,6 @@ def apply( input_2d = input.view(-1, input.shape[-1]) output_shape = [*input.shape[:-1], weight.shape[0]] output_dtype = input.dtype - if should_use_deepgemm_for_fp8_linear(self.is_deep_gemm_supported, output_dtype, weight): output = self._run_deepgemm(input, weight, block_size, @@ -244,12 +243,10 @@ def _dispatch_w8a8_blockscale_op( use_cutlass: bool, use_aiter_and_is_supported: bool, ) -> Callable[[ - torch.Tensor, - torch.Tensor, torch.Tensor, torch.Tensor, list[int], - torch.dtype, + torch.Tensor, ], torch.Tensor]: if use_cutlass: return self._run_cutlass From 
ef6f1e2d3b75ee66a020ba525a1ff9f202448206 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 17 Sep 2025 12:02:57 +0000 Subject: [PATCH 13/39] Fix cutlass compilation issue on Hopper Signed-off-by: ElizaWszola --- .../layers/quantization/utils/fp8_utils.py | 76 +++++++++++++++---- 1 file changed, 61 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index f4b12bde1f8a..4c2fab07addb 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -98,6 +98,56 @@ def rocm_aiter_gemm_w8a8_blockscale_fake( aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128) +def _padded_cutlass( + qx: torch.Tensor, + weight: torch.Tensor, + x_scale: torch.Tensor, + weight_scale: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype, +) -> torch.Tensor: + pad_multiple = 4 + dim = qx.shape[0] + padded = dim if dim % pad_multiple == 0 else dim + pad_multiple - ( + dim % pad_multiple) + + padded_shape = [padded, *qx.shape[1:]] + padded_qx = torch.zeros(padded_shape, device=qx.device, dtype=qx.dtype) + padded_qx[0:qx.shape[0], ...].copy_(qx) + + padded_x_scale_shape = [*x_scale.shape[1:], padded] + padded_x_scale = torch.ones(padded_x_scale_shape, + device=x_scale.device, + dtype=x_scale.dtype).permute(-1, -2) + padded_x_scale[0:x_scale.shape[0], ...].copy_(x_scale) + + output = cutlass_scaled_mm(padded_qx, weight, padded_x_scale, weight_scale, + block_size, output_dtype, True) + return output[0:qx.shape[0], ...] + + +def _padded_cutlass_fake( + qx: torch.Tensor, + weight: torch.Tensor, + x_scale: torch.Tensor, + weight_scale: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype, +) -> torch.Tensor: + return torch.empty((qx.size(0), weight.size(0)), + dtype=output_dtype, + device=qx.device) + + +direct_register_custom_op( + "padded_cutlass", + _padded_cutlass, + mutates_args=[], + fake_impl=_padded_cutlass_fake, + dispatch_key="CUDA", +) + + def dispatch_w8a8_blockscale_func( use_cutlass: bool, use_aiter_and_is_supported: bool ) -> Callable[[ @@ -148,6 +198,7 @@ def apply( input_2d = input.view(-1, input.shape[-1]) output_shape = [*input.shape[:-1], weight.shape[0]] output_dtype = input.dtype + if should_use_deepgemm_for_fp8_linear(self.is_deep_gemm_supported, output_dtype, weight): output = self._run_deepgemm(input, weight, block_size, @@ -193,23 +244,17 @@ def _run_cutlass( block_size: list[int], weight_scale: torch.Tensor, ) -> torch.Tensor: - num_pad = 0 - if self.is_hopper: - # pad first dimension to be divisible by 4 due to - # cutlass blockwise gemm limitation for hopper - num_pad = 4 - (input_2d.shape[0] % 4) - if num_pad > 0: - input_2d = torch.nn.functional.pad(input_2d, - (0, 0, 0, num_pad), - "constant", 0) q_input, x_scale = per_token_group_quant_fp8(input_2d, block_size[1], column_major_scales=True, use_ue8m0=False) - output = cutlass_scaled_mm(q_input, weight, x_scale, weight_scale, - block_size, input_2d.dtype, self.is_hopper) - if num_pad > 0: - output = output[:-num_pad] + if self.is_hopper: + output = torch.ops.vllm.padded_cutlass(q_input, weight, x_scale, + weight_scale, block_size, + input_2d.dtype) + else: + output = cutlass_scaled_mm(q_input, weight, x_scale, weight_scale, + block_size, input_2d.dtype, False) return output def _run_aiter( @@ -235,8 +280,9 @@ def _run_w8a8_block_fp8_matmul( block_size[1], column_major_scales=False, use_ue8m0=False) - return 
w8a8_block_fp8_matmul(q_input, weight, x_scale, weight_scale, - block_size, input_2d.dtype) + return w8a8_block_fp8_matmul(q_input, weight, x_scale, + weight_scale.t(), block_size, + input_2d.dtype) def _dispatch_w8a8_blockscale_op( self, From 77335de6cc466971865dcea5fe4cb9b7f32b1ec6 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 17 Sep 2025 12:24:13 +0000 Subject: [PATCH 14/39] Cleanup bad transpose Signed-off-by: ElizaWszola --- vllm/model_executor/layers/quantization/utils/fp8_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 4c2fab07addb..383bbc88ddad 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -280,9 +280,8 @@ def _run_w8a8_block_fp8_matmul( block_size[1], column_major_scales=False, use_ue8m0=False) - return w8a8_block_fp8_matmul(q_input, weight, x_scale, - weight_scale.t(), block_size, - input_2d.dtype) + return w8a8_block_fp8_matmul(q_input, weight, x_scale, weight_scale, + block_size, input_2d.dtype) def _dispatch_w8a8_blockscale_op( self, From e036dac2b12ccdb0016d6e2f5322f980989f22b6 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 17 Sep 2025 12:37:16 +0000 Subject: [PATCH 15/39] Wrap w8a8_block_fp8_matmul Signed-off-by: ElizaWszola --- .../layers/quantization/utils/fp8_utils.py | 38 ++++++++++++++++++- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 383bbc88ddad..477393d5494d 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -148,6 +148,40 @@ def _padded_cutlass_fake( ) +def _w8a8_block_fp8_matmul_func( + qx: torch.Tensor, + weight: torch.Tensor, + x_scale: torch.Tensor, + weight_scale: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype, +) -> torch.Tensor: + return w8a8_block_fp8_matmul(qx, weight, x_scale, weight_scale, block_size, + output_dtype) + + +def _w8a8_block_fp8_matmul_func_fake( + qx: torch.Tensor, + weight: torch.Tensor, + x_scale: torch.Tensor, + weight_scale: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype, +) -> torch.Tensor: + return torch.empty((qx.size(0), weight.size(0)), + dtype=output_dtype, + device=qx.device) + + +direct_register_custom_op( + "w8a8_block_fp8_matmul_func", + _w8a8_block_fp8_matmul_func, + mutates_args=[], + fake_impl=_w8a8_block_fp8_matmul_func_fake, + dispatch_key="CUDA", +) + + def dispatch_w8a8_blockscale_func( use_cutlass: bool, use_aiter_and_is_supported: bool ) -> Callable[[ @@ -280,8 +314,8 @@ def _run_w8a8_block_fp8_matmul( block_size[1], column_major_scales=False, use_ue8m0=False) - return w8a8_block_fp8_matmul(q_input, weight, x_scale, weight_scale, - block_size, input_2d.dtype) + return torch.ops.vllm.w8a8_block_fp8_matmul_func( + q_input, weight, x_scale, weight_scale, block_size, input_2d.dtype) def _dispatch_w8a8_blockscale_op( self, From 233e874ea11fbe8331537bf78ba527157f81018f Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 17 Sep 2025 11:50:44 -0400 Subject: [PATCH 16/39] Rename padded_cutlass to padded_cutlass_scaled_mm, add todo Signed-off-by: ElizaWszola --- .../layers/quantization/utils/fp8_utils.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git 
a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 477393d5494d..d2db324cdab8 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -98,7 +98,10 @@ def rocm_aiter_gemm_w8a8_blockscale_fake( aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128) -def _padded_cutlass( +# TODO: ideally, we would like to wrap only the padding computation +# and unwrap the rest. This should be possible after solving +# https://github.com/vllm-project/vllm/issues/25080 +def _padded_cutlass_scaled_mm( qx: torch.Tensor, weight: torch.Tensor, x_scale: torch.Tensor, @@ -126,7 +129,7 @@ def _padded_cutlass( return output[0:qx.shape[0], ...] -def _padded_cutlass_fake( +def _padded_cutlass_scaled_mm_fake( qx: torch.Tensor, weight: torch.Tensor, x_scale: torch.Tensor, @@ -140,10 +143,10 @@ def _padded_cutlass_fake( direct_register_custom_op( - "padded_cutlass", - _padded_cutlass, + "padded_cutlass_scaled_mm", + _padded_cutlass_scaled_mm, mutates_args=[], - fake_impl=_padded_cutlass_fake, + fake_impl=_padded_cutlass_scaled_mm_fake, dispatch_key="CUDA", ) @@ -283,9 +286,9 @@ def _run_cutlass( column_major_scales=True, use_ue8m0=False) if self.is_hopper: - output = torch.ops.vllm.padded_cutlass(q_input, weight, x_scale, - weight_scale, block_size, - input_2d.dtype) + output = torch.ops.vllm.padded_cutlass_scaled_mm( + q_input, weight, x_scale, weight_scale, block_size, + input_2d.dtype) else: output = cutlass_scaled_mm(q_input, weight, x_scale, weight_scale, block_size, input_2d.dtype, False) From 1edfedc50592daca9c5786347782de5287fb5631 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 17 Sep 2025 14:00:16 -0400 Subject: [PATCH 17/39] Cleanup dispatch_w8a8_blockscale_func Signed-off-by: ElizaWszola --- .../model_executor/test_enabled_custom_ops.py | 30 ------------------- .../layers/quantization/utils/fp8_utils.py | 17 ----------- 2 files changed, 47 deletions(-) diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 86139d598582..343c0db4086c 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -16,8 +16,6 @@ from vllm.model_executor.layers.layernorm import (RMSNorm, dispatch_rocm_rmsnorm_func, fused_add_rms_norm, rms_norm) -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul) from vllm.platforms import current_platform RMS_NORM_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16] @@ -109,34 +107,6 @@ def test_enabled_ops_invalid(env: str): RMSNorm(1024).enabled() -@pytest.mark.skipif( - not current_platform.is_rocm() or not current_platform.is_fp8_fnuz(), - reason="AITER is a feature exclusive for ROCm and FP8_FNUZ") -@pytest.mark.parametrize("use_cutlass", [True, False]) -@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) -@pytest.mark.parametrize("use_rocm_aiter_gemm_w8a8_blockscale", ["0", "1"]) -def test_w8a8_blockscale_dispatch(use_cutlass: bool, use_rocm_aiter: str, - use_rocm_aiter_gemm_w8a8_blockscale: str, - monkeypatch): - - monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) - monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR", - use_rocm_aiter_gemm_w8a8_blockscale) - - use_aiter_and_is_supported = (bool(int(use_rocm_aiter)) and bool( - int(use_rocm_aiter_gemm_w8a8_blockscale))) - block_scale_func = 
dispatch_w8a8_blockscale_func( - use_cutlass, use_aiter_and_is_supported=use_aiter_and_is_supported) - if use_cutlass: - assert block_scale_func == cutlass_scaled_mm - elif current_platform.is_rocm() and int(use_rocm_aiter) and int( - use_rocm_aiter_gemm_w8a8_blockscale): - assert block_scale_func == ( - torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale) - else: - assert block_scale_func == w8a8_block_fp8_matmul - - @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) def test_topk_dispatch(use_rocm_aiter: str, monkeypatch): monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index d2db324cdab8..817133370775 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -185,23 +185,6 @@ def _w8a8_block_fp8_matmul_func_fake( ) -def dispatch_w8a8_blockscale_func( - use_cutlass: bool, use_aiter_and_is_supported: bool -) -> Callable[[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - list[int], - torch.dtype, -], torch.Tensor]: - if use_cutlass: - return cutlass_scaled_mm - if (use_aiter_and_is_supported): - return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale - return w8a8_block_fp8_matmul - - # TODO fix ROCm->Triton custom path: # https://github.com/vllm-project/vllm/issues/14397 class W8A8BlockFp8LinearOp: From 0ac3a1e1c8e9af833f31f9f7017bc206323dc311 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 18 Sep 2025 10:19:09 -0400 Subject: [PATCH 18/39] Deep gemm warmup fix Signed-off-by: ElizaWszola --- vllm/model_executor/warmup/deep_gemm_warmup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/warmup/deep_gemm_warmup.py b/vllm/model_executor/warmup/deep_gemm_warmup.py index a636a714145c..4d1829cd228c 100644 --- a/vllm/model_executor/warmup/deep_gemm_warmup.py +++ b/vllm/model_executor/warmup/deep_gemm_warmup.py @@ -36,7 +36,7 @@ def _extract_data_from_linear_base_module( assert m.quant_method.quant_config is not None w = m.weight - ws = m.weight_scale_inv + ws = m.weight_scale quant_block_size = m.quant_method.quant_config.weight_block_size assert isinstance(w, torch.Tensor) From 9a4810016053e16c33a105e0a501cfa6c0ad94db Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 18 Sep 2025 10:25:31 -0400 Subject: [PATCH 19/39] Fix deep gemm support function Signed-off-by: ElizaWszola --- vllm/utils/deep_gemm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 882b41be4341..c50107cc223d 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -175,10 +175,10 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): def should_use_deepgemm_for_fp8_linear( output_dtype: torch.dtype, weight: torch.Tensor, - is_deep_gemm_supported: Optional[bool] = None): - if is_deep_gemm_supported is None: - is_deep_gemm_supported = is_deep_gemm_supported() - return (is_deep_gemm_supported and output_dtype == torch.bfloat16 + supports_deep_gemm: Optional[bool] = None): + if supports_deep_gemm is None: + supports_deep_gemm = is_deep_gemm_supported() + return (supports_deep_gemm and output_dtype == torch.bfloat16 and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0) From b6a8fb85a63ba8d7526d1b179cbb929e51216d2b Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 19 Sep 2025 04:08:39 -0400 Subject: [PATCH 20/39] Feedback Signed-off-by: ElizaWszola --- 
.../schemes/compressed_tensors_w8a8_fp8.py | 7 +- .../model_executor/layers/quantization/fp8.py | 7 +- .../layers/quantization/input_quant_fp8.py | 17 ++-- .../layers/quantization/utils/fp8_utils.py | 83 ++++++++++--------- 4 files changed, 66 insertions(+), 48 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 103c24ba41e8..a470a2a31562 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -49,7 +49,10 @@ def __init__(self, weight_quant: QuantizationArgs, self.use_aiter_and_is_supported = check_aiter_fp8_linear_support() if self.weight_block_size is not None: + assert not self.is_static_input_scheme self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( + weight_group_shape=self.weight_block_size, + act_quant_group_shape=GroupShape(1, self.weight_block_size[1]), cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, use_aiter_and_is_supported=self.use_aiter_and_is_supported, ) @@ -148,11 +151,11 @@ def apply_weights(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - if layer.weight_block_size is not None: + if self.weight_block_size is not None: return self.w8a8_block_fp8_linear.apply( input=x, weight=layer.weight, - block_size=layer.weight_block_size, + block_size=self.weight_block_size, weight_scale=layer.weight_scale, input_scale=layer.input_scale, bias=bias, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index ad04a163aecf..804935797deb 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -247,7 +247,10 @@ def __init__(self, quant_config: Fp8Config): self.act_q_group_shape = GroupShape.PER_TENSOR if self.block_quant: + assert not self.act_q_static self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( + weight_group_shape=self.weight_block_size, + act_quant_group_shape=GroupShape(1, self.weight_block_size[1]), cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, use_aiter_and_is_supported=self.use_aiter_and_is_supported, ) @@ -403,12 +406,12 @@ def apply(self, bias=bias) if self.block_quant: - assert layer.weight_block_size is not None + assert self.weight_block_size is not None return self.w8a8_block_fp8_linear.apply( input=x, weight=layer.weight, - block_size=layer.weight_block_size, + block_size=self.weight_block_size, weight_scale=layer.weight_scale, input_scale=layer.input_scale, bias=bias, diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py index 31182f40b48f..95d5a8d22581 100644 --- a/vllm/model_executor/layers/quantization/input_quant_fp8.py +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -27,11 +27,14 @@ class QuantFP8(CustomOp): This CustomOp supports both static and dynamic quantization. 
""" - def __init__(self, - static: bool, - group_shape: GroupShape, - num_token_padding: Optional[int] = None, - column_major_scales: bool = False): + def __init__( + self, + static: bool, + group_shape: GroupShape, + num_token_padding: Optional[int] = None, + column_major_scales: bool = False, + use_ue8m0: Optional[bool] = None, # for Torch compile + ): """ :param static: static or dynamic quantization :param group_shape: quantization group shape (PER_TOKEN, PER_TENSOR, @@ -46,6 +49,7 @@ def __init__(self, self.group_shape = group_shape self.num_token_padding = num_token_padding self.column_major_scales = column_major_scales + self.use_ue8m0 = use_ue8m0 self.is_group_quant = group_shape.is_per_group() if self.is_group_quant: @@ -70,7 +74,8 @@ def forward_cuda( x, group_size=self.group_size, column_major_scales=self.column_major_scales, - dtype=_FP8_DTYPE) + dtype=_FP8_DTYPE, + use_ue8m0=self.use_ue8m0) assert (scale is not None) == self.static assert scale_ub is None or (not self.static and self.group_shape diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index fc0173f10e9e..21e0ac3d067b 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -13,8 +13,9 @@ import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( - group_broadcast) + GroupShape, group_broadcast) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( CUTLASS_BLOCK_FP8_SUPPORTED) from vllm.model_executor.parameter import (BlockQuantScaleParameter, @@ -36,6 +37,8 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool: return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz +# We need to pass in the is_hopper flag as argument because the function +# current_platform.is_device_capability() is not supported by Torch compiler. 
def cutlass_scaled_mm( A: torch.Tensor, B: torch.Tensor, @@ -198,20 +201,30 @@ class W8A8BlockFp8LinearOp: def __init__( self, + weight_group_shape: GroupShape, + act_quant_group_shape: GroupShape, cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, use_aiter_and_is_supported: bool = False, ): + self.weight_group_shape = weight_group_shape + self.act_quant_group_shape = act_quant_group_shape self.is_deep_gemm_supported = is_deep_gemm_supported() - self.ue8m0_deepgemm_supported = is_deep_gemm_e8m0_used() self.is_hopper = current_platform.is_device_capability(90) - self.w8a8_blockscale_op = self._dispatch_w8a8_blockscale_op( + self.w8a8_blockscale_op, self.input_quant_op = \ + self._dispatch_w8a8_blockscale_op( cutlass_block_fp8_supported, use_aiter_and_is_supported) + self.deepgemm_input_quant_op = (QuantFP8( + False, + self.act_quant_group_shape, + column_major_scales=True, + use_ue8m0=is_deep_gemm_e8m0_used()) if self.is_deep_gemm_supported + else None) def apply( self, input: torch.Tensor, weight: torch.Tensor, - block_size: list[int], + block_size: GroupShape, weight_scale: torch.Tensor, input_scale: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, @@ -224,14 +237,12 @@ def apply( if should_use_deepgemm_for_fp8_linear(output_dtype, weight, self.is_deep_gemm_supported): - output = self._run_deepgemm(input, weight, block_size, - weight_scale) + output = self._run_deepgemm(input, weight, weight_scale) if bias is not None: output = output + bias return output.to(dtype=input.dtype).view(*output_shape) - output = self.w8a8_blockscale_op(input_2d, weight, block_size, - weight_scale) + output = self.w8a8_blockscale_op(input_2d, weight, weight_scale) if bias is not None: output = output + bias return output.to(dtype=input.dtype).view(*output_shape) @@ -240,87 +251,83 @@ def _run_deepgemm( self, input_2d: torch.Tensor, weight: torch.Tensor, - block_size: list[int], weight_scale: torch.Tensor, ) -> torch.Tensor: # ensure DeepGEMM-backed custom op is registered before use import vllm.model_executor.layers.quantization.deepgemm # noqa: F401 - q_input, x_scale = per_token_group_quant_fp8( - input_2d, - block_size[1], - column_major_scales=True, - use_ue8m0=self.ue8m0_deepgemm_supported, - ) + assert self.deepgemm_input_quant_op is not None + q_input, x_scale = self.deepgemm_input_quant_op.forward_cuda(input_2d) return torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm( q_input, weight, x_scale, weight_scale, - block_size, + self.weight_group_shape, output_dtype=input_2d.dtype) def _run_cutlass( self, input_2d: torch.Tensor, weight: torch.Tensor, - block_size: list[int], weight_scale: torch.Tensor, ) -> torch.Tensor: - q_input, x_scale = per_token_group_quant_fp8(input_2d, - block_size[1], - column_major_scales=True, - use_ue8m0=False) + q_input, x_scale = self.input_quant_op.forward_cuda(input_2d) if self.is_hopper: output = torch.ops.vllm.padded_cutlass_scaled_mm( - q_input, weight, x_scale, weight_scale, block_size, - input_2d.dtype) + q_input, weight, x_scale, weight_scale, + self.weight_group_shape, input_2d.dtype) else: output = cutlass_scaled_mm(q_input, weight, x_scale, weight_scale, - block_size, input_2d.dtype, False) + self.weight_group_shape, input_2d.dtype, + False) return output def _run_aiter( self, input_2d: torch.Tensor, weight: torch.Tensor, - block_size: list[int], weight_scale: torch.Tensor, ) -> torch.Tensor: + assert self.act_quant_group_shape == GroupShape(1, 128) q_input, x_scale = aiter_per1x128_quant( input_2d.contiguous(), 
quant_dtype=rocm_aiter.dtypes.fp8) return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale( - q_input, weight, x_scale, weight_scale, block_size, input_2d.dtype) + q_input, weight, x_scale, weight_scale, self.weight_group_shape, + input_2d.dtype) def _run_w8a8_block_fp8_matmul( self, input_2d: torch.Tensor, weight: torch.Tensor, - block_size: list[int], weight_scale: torch.Tensor, ) -> torch.Tensor: - q_input, x_scale = per_token_group_quant_fp8(input_2d, - block_size[1], - column_major_scales=False, - use_ue8m0=False) + q_input, x_scale = self.input_quant_op.forward_cuda(input_2d) return torch.ops.vllm.w8a8_block_fp8_matmul_func( - q_input, weight, x_scale, weight_scale, block_size, input_2d.dtype) + q_input, weight, x_scale, weight_scale, self.weight_group_shape, + input_2d.dtype) def _dispatch_w8a8_blockscale_op( self, use_cutlass: bool, use_aiter_and_is_supported: bool, - ) -> Callable[[ + ) -> tuple[Callable[[ torch.Tensor, torch.Tensor, - list[int], torch.Tensor, - ], torch.Tensor]: + ], torch.Tensor], Optional[QuantFP8]]: if use_cutlass: - return self._run_cutlass + return self._run_cutlass, (QuantFP8(False, + self.act_quant_group_shape, + column_major_scales=True, + use_ue8m0=False)) if use_aiter_and_is_supported: - return self._run_aiter - return self._run_w8a8_block_fp8_matmul + return self._run_aiter, None + return self._run_w8a8_block_fp8_matmul, (QuantFP8( + False, + self.act_quant_group_shape, + column_major_scales=False, + use_ue8m0=False)) def input_to_float8( From e89ecd852ea9092e63f46aa3c75cd7af4078d528 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 19 Sep 2025 04:25:36 -0400 Subject: [PATCH 21/39] Pre-commit fixes Signed-off-by: ElizaWszola --- .../schemes/compressed_tensors_w8a8_fp8.py | 3 ++- vllm/model_executor/layers/quantization/fp8.py | 3 ++- .../layers/quantization/utils/fp8_utils.py | 12 ++++++++---- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index a470a2a31562..acfc3276acf7 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -51,7 +51,8 @@ def __init__(self, weight_quant: QuantizationArgs, if self.weight_block_size is not None: assert not self.is_static_input_scheme self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( - weight_group_shape=self.weight_block_size, + weight_group_shape=GroupShape(self.weight_block_size[0], + self.weight_block_size[1]), act_quant_group_shape=GroupShape(1, self.weight_block_size[1]), cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, use_aiter_and_is_supported=self.use_aiter_and_is_supported, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 804935797deb..fd2a54bf8242 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -249,7 +249,8 @@ def __init__(self, quant_config: Fp8Config): if self.block_quant: assert not self.act_q_static self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( - weight_group_shape=self.weight_block_size, + weight_group_shape=GroupShape(self.weight_block_size[0], + self.weight_block_size[1]), act_quant_group_shape=GroupShape(1, self.weight_block_size[1]), 
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, use_aiter_and_is_supported=self.use_aiter_and_is_supported, diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 21e0ac3d067b..8ff357d8db4a 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -272,15 +272,18 @@ def _run_cutlass( weight: torch.Tensor, weight_scale: torch.Tensor, ) -> torch.Tensor: + assert self.input_quant_op is not None q_input, x_scale = self.input_quant_op.forward_cuda(input_2d) if self.is_hopper: output = torch.ops.vllm.padded_cutlass_scaled_mm( q_input, weight, x_scale, weight_scale, - self.weight_group_shape, input_2d.dtype) + [self.weight_group_shape[0], self.weight_group_shape[1]], + input_2d.dtype) else: - output = cutlass_scaled_mm(q_input, weight, x_scale, weight_scale, - self.weight_group_shape, input_2d.dtype, - False) + output = cutlass_scaled_mm( + q_input, weight, x_scale, weight_scale, + [self.weight_group_shape[0], self.weight_group_shape[1]], + input_2d.dtype, False) return output def _run_aiter( @@ -302,6 +305,7 @@ def _run_w8a8_block_fp8_matmul( weight: torch.Tensor, weight_scale: torch.Tensor, ) -> torch.Tensor: + assert self.input_quant_op is not None q_input, x_scale = self.input_quant_op.forward_cuda(input_2d) return torch.ops.vllm.w8a8_block_fp8_matmul_func( q_input, weight, x_scale, weight_scale, self.weight_group_shape, From 00cb05c39f77f71dbff88783bbd9d3a4822eff57 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 19 Sep 2025 04:39:55 -0400 Subject: [PATCH 22/39] Pre-commit fixes 2 Signed-off-by: ElizaWszola --- .../compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py | 1 - vllm/model_executor/layers/quantization/fp8.py | 2 +- vllm/model_executor/layers/quantization/utils/fp8_utils.py | 1 - 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index acfc3276acf7..b0c566cccfcd 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -156,7 +156,6 @@ def apply_weights(self, return self.w8a8_block_fp8_linear.apply( input=x, weight=layer.weight, - block_size=self.weight_block_size, weight_scale=layer.weight_scale, input_scale=layer.input_scale, bias=bias, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index fd2a54bf8242..512983aa6532 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -248,6 +248,7 @@ def __init__(self, quant_config: Fp8Config): if self.block_quant: assert not self.act_q_static + assert self.weight_block_size is not None self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( weight_group_shape=GroupShape(self.weight_block_size[0], self.weight_block_size[1]), @@ -412,7 +413,6 @@ def apply(self, return self.w8a8_block_fp8_linear.apply( input=x, weight=layer.weight, - block_size=self.weight_block_size, weight_scale=layer.weight_scale, input_scale=layer.input_scale, bias=bias, diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 
8ff357d8db4a..d6355daafba6 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -224,7 +224,6 @@ def apply( self, input: torch.Tensor, weight: torch.Tensor, - block_size: GroupShape, weight_scale: torch.Tensor, input_scale: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, From 66c89e6ee4881635983b4a07f1ae03b4e5bbff74 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 19 Sep 2025 12:03:36 -0400 Subject: [PATCH 23/39] Feedback Signed-off-by: ElizaWszola --- vllm/config/__init__.py | 4 ++++ .../schemes/compressed_tensors_w8a8_fp8.py | 13 +++++++----- .../model_executor/layers/quantization/fp8.py | 16 +++++++------- .../layers/quantization/input_quant_fp8.py | 1 + .../layers/quantization/utils/fp8_utils.py | 21 +++++++++++-------- 5 files changed, 34 insertions(+), 21 deletions(-) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 25daca00c02d..dbecb34369aa 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -2824,6 +2824,10 @@ def __post_init__(self): # local attention. self.scheduler_config.disable_hybrid_kv_cache_manager = True + custom_ops = self.compilation_config.custom_ops + if "none" not in custom_ops and "-quant_fp8" not in custom_ops: + custom_ops.append("+quant_fp8") + def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list: # remove the sizes that not multiple of tp_size when diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index b0c566cccfcd..fa0816959fcd 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -41,19 +41,22 @@ def __init__(self, weight_quant: QuantizationArgs, self.strategy = weight_quant.strategy self.out_dtype = torch.get_default_dtype() self.is_static_input_scheme = is_static_input_scheme - self.act_q_group_shape = GroupShape.PER_TENSOR \ - if is_static_input_scheme else GroupShape.PER_TOKEN self.weight_block_size = self.weight_quant.block_structure + if self.weight_block_size is not None: + self.act_q_group_shape = GroupShape(1, self.weight_block_size[0]) + else: + self.act_q_group_shape = GroupShape.PER_TENSOR \ + if is_static_input_scheme else GroupShape.PER_TOKEN + self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() self.use_aiter_and_is_supported = check_aiter_fp8_linear_support() if self.weight_block_size is not None: assert not self.is_static_input_scheme self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( - weight_group_shape=GroupShape(self.weight_block_size[0], - self.weight_block_size[1]), - act_quant_group_shape=GroupShape(1, self.weight_block_size[1]), + weight_group_shape=GroupShape(*self.weight_block_size), + act_quant_group_shape=self.act_q_group_shape, cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, use_aiter_and_is_supported=self.use_aiter_and_is_supported, ) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 512983aa6532..59cca9ef84b0 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -240,19 +240,21 @@ def __init__(self, quant_config: Fp8Config): self.weight_block_size = self.quant_config.weight_block_size 
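[Editor's note] For orientation on the group shapes chosen in this patch (and in the fp8.py hunk that continues below): a rough illustration of how many scales each GroupShape implies for a (tokens, hidden) activation. The sizes are invented; only the PER_TOKEN/PER_TENSOR names come from the patches.

    # One scale per quantization group of a (tokens, hidden) activation.
    tokens, hidden, group = 4, 512, 128
    per_token_group = (tokens, hidden // group)  # GroupShape(1, 128)   -> (4, 4) scales
    per_token = (tokens, 1)                      # GroupShape.PER_TOKEN -> (4, 1) scales
    per_tensor = (1, 1)                          # GroupShape.PER_TENSOR -> a single scale
    print(per_token_group, per_token, per_tensor)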
self.block_quant = self.weight_block_size is not None self.act_q_static = self.quant_config.activation_scheme == "static" - # Use per-token quantization for better perf if dynamic and cutlass - if not self.act_q_static and cutlass_fp8_supported(): - self.act_q_group_shape = GroupShape.PER_TOKEN + if self.weight_block_size: + self.act_q_group_shape = GroupShape(1, self.weight_block_size[0]) else: - self.act_q_group_shape = GroupShape.PER_TENSOR + # Use per-token quantization for better perf if dynamic and cutlass + if not self.act_q_static and cutlass_fp8_supported(): + self.act_q_group_shape = GroupShape.PER_TOKEN + else: + self.act_q_group_shape = GroupShape.PER_TENSOR if self.block_quant: assert not self.act_q_static assert self.weight_block_size is not None self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( - weight_group_shape=GroupShape(self.weight_block_size[0], - self.weight_block_size[1]), - act_quant_group_shape=GroupShape(1, self.weight_block_size[1]), + weight_group_shape=GroupShape(*self.weight_block_size), + act_quant_group_shape=self.act_q_group_shape, cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, use_aiter_and_is_supported=self.use_aiter_and_is_supported, ) diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py index 95d5a8d22581..a5efb8723805 100644 --- a/vllm/model_executor/layers/quantization/input_quant_fp8.py +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -96,6 +96,7 @@ def forward_native( ): if self.is_group_quant: assert scale is None, "Group quantization is always dynamic" + assert self.use_ue8m0 is None or not self.use_ue8m0 return self._quantize_group_native(x) assert (scale is not None) == self.static diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index d6355daafba6..a06685063457 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -210,6 +210,11 @@ def __init__( self.act_quant_group_shape = act_quant_group_shape self.is_deep_gemm_supported = is_deep_gemm_supported() self.is_hopper = current_platform.is_device_capability(90) + + # Get the correct blockscale mul and input quant operations. + # We can't use _dispatch_w8a8_blockscale_op to figure out if we want + # to use deepgemm because we don't know the shape of weights (and + # whether deepgemm supports it) at the init time. 
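[Editor's note] The comment above describes a two-stage dispatch: backends that depend only on platform/build flags are picked once in __init__, while the DeepGEMM decision depends on the weight shape and so must wait until apply(). A minimal sketch of that pattern, with toy backends that are not the vLLM implementations:

    from typing import Callable
    import torch

    class ToyBlockLinear:
        """Not the vLLM class; only the two-stage dispatch pattern."""

        def __init__(self, use_cutlass: bool):
            # Platform/build flags are known at init time, so choose once here.
            self._fallback: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = (
                self._run_cutlass if use_cutlass else self._run_triton)

        def apply(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
            # The weight is first seen here, so the shape-dependent
            # DeepGEMM-style choice happens per call, not in __init__.
            if w.shape[0] % 128 == 0 and w.shape[1] % 128 == 0:  # stand-in shape check
                return self._run_deepgemm(x, w)
            return self._fallback(x, w)

        def _run_cutlass(self, x, w):  return x @ w.t()
        def _run_triton(self, x, w):   return x @ w.t()
        def _run_deepgemm(self, x, w): return x @ w.t()

    out = ToyBlockLinear(use_cutlass=False).apply(torch.randn(2, 64), torch.randn(32, 64))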
self.w8a8_blockscale_op, self.input_quant_op = \ self._dispatch_w8a8_blockscale_op( cutlass_block_fp8_supported, use_aiter_and_is_supported) @@ -256,7 +261,7 @@ def _run_deepgemm( import vllm.model_executor.layers.quantization.deepgemm # noqa: F401 assert self.deepgemm_input_quant_op is not None - q_input, x_scale = self.deepgemm_input_quant_op.forward_cuda(input_2d) + q_input, x_scale = self.deepgemm_input_quant_op(input_2d) return torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm( q_input, weight, @@ -272,17 +277,15 @@ def _run_cutlass( weight_scale: torch.Tensor, ) -> torch.Tensor: assert self.input_quant_op is not None - q_input, x_scale = self.input_quant_op.forward_cuda(input_2d) + q_input, x_scale = self.input_quant_op(input_2d) if self.is_hopper: output = torch.ops.vllm.padded_cutlass_scaled_mm( q_input, weight, x_scale, weight_scale, - [self.weight_group_shape[0], self.weight_group_shape[1]], - input_2d.dtype) + tuple(self.weight_group_shape), input_2d.dtype) else: - output = cutlass_scaled_mm( - q_input, weight, x_scale, weight_scale, - [self.weight_group_shape[0], self.weight_group_shape[1]], - input_2d.dtype, False) + output = cutlass_scaled_mm(q_input, weight, x_scale, weight_scale, + tuple(self.weight_group_shape), + input_2d.dtype, False) return output def _run_aiter( @@ -305,7 +308,7 @@ def _run_w8a8_block_fp8_matmul( weight_scale: torch.Tensor, ) -> torch.Tensor: assert self.input_quant_op is not None - q_input, x_scale = self.input_quant_op.forward_cuda(input_2d) + q_input, x_scale = self.input_quant_op(input_2d) return torch.ops.vllm.w8a8_block_fp8_matmul_func( q_input, weight, x_scale, weight_scale, self.weight_group_shape, input_2d.dtype) From d9b412136cc62c3d269209019f93fc37c57963e8 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 19 Sep 2025 12:38:17 -0400 Subject: [PATCH 24/39] fix type issue Signed-off-by: ElizaWszola --- vllm/model_executor/layers/quantization/utils/fp8_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index a06685063457..1bc47dde5acb 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -281,10 +281,10 @@ def _run_cutlass( if self.is_hopper: output = torch.ops.vllm.padded_cutlass_scaled_mm( q_input, weight, x_scale, weight_scale, - tuple(self.weight_group_shape), input_2d.dtype) + list(self.weight_group_shape), input_2d.dtype) else: output = cutlass_scaled_mm(q_input, weight, x_scale, weight_scale, - tuple(self.weight_group_shape), + list(self.weight_group_shape), input_2d.dtype, False) return output From 1bc81a1c7dd639b2f24d4f84a3d4c6a73b86fa55 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 19 Sep 2025 13:47:29 -0400 Subject: [PATCH 25/39] Add use_ue8m0 support to _quantize_group_native Signed-off-by: ElizaWszola --- tests/kernels/quantization/test_fp8_quant_group.py | 9 ++++++--- .../layers/quantization/input_quant_fp8.py | 6 ++++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/tests/kernels/quantization/test_fp8_quant_group.py b/tests/kernels/quantization/test_fp8_quant_group.py index 720eee62760d..16bae6b95221 100644 --- a/tests/kernels/quantization/test_fp8_quant_group.py +++ b/tests/kernels/quantization/test_fp8_quant_group.py @@ -68,8 +68,9 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int, @pytest.mark.parametrize("seed", [42]) 
+@pytest.mark.parametrize("use_ue8m0", [True, False]) @torch.inference_mode() -def test_quantfp8_group_multidimensional(seed: int) -> None: +def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None: current_platform.seed_everything(seed) group_size = 64 @@ -82,7 +83,8 @@ def test_quantfp8_group_multidimensional(seed: int) -> None: group_shape = GroupShape(1, group_size) quant_op = QuantFP8(static=False, group_shape=group_shape, - column_major_scales=False) + column_major_scales=False, + use_ue8m0=use_ue8m0) x_quant, scales = quant_op.forward_native(x_3d.clone()) assert x_quant.shape == x_3d.shape @@ -91,7 +93,8 @@ def test_quantfp8_group_multidimensional(seed: int) -> None: # Test column_major_scales with multi-dim quant_op_col = QuantFP8(static=False, group_shape=group_shape, - column_major_scales=True) + column_major_scales=True, + use_ue8m0=use_ue8m0) _, scales_col = quant_op_col.forward_native(x_3d.clone()) assert scales_col.shape == (batch1, hidden_dim // group_size, batch2) diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py index a5efb8723805..e500248e4a70 100644 --- a/vllm/model_executor/layers/quantization/input_quant_fp8.py +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -96,7 +96,6 @@ def forward_native( ): if self.is_group_quant: assert scale is None, "Group quantization is always dynamic" - assert self.use_ue8m0 is None or not self.use_ue8m0 return self._quantize_group_native(x) assert (scale is not None) == self.static @@ -143,7 +142,10 @@ def _quantize_group_native( x_grouped = x.view(-1, num_groups, self.group_size) absmax = x_grouped.abs().max(dim=-1, keepdim=True)[0].float() - scales = (absmax / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR) + scale_raw = absmax / _FP8_MAX + if self.use_ue8m0: + scale_raw = torch.exp2(torch.ceil(torch.log2(scale_raw))) + scales = (scale_raw).clamp(min=_FP8_MIN_SCALING_FACTOR) x_scaled = x_grouped / scales x_quant = x_scaled.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) From ec73268c4e172303173efa1194ca29db928f8c1d Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Mon, 22 Sep 2025 03:21:55 -0400 Subject: [PATCH 26/39] Fix padding compilation issue Signed-off-by: ElizaWszola --- .../layers/quantization/utils/fp8_utils.py | 65 +++---------------- 1 file changed, 8 insertions(+), 57 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 1bc47dde5acb..7215597286de 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -104,59 +104,6 @@ def rocm_aiter_gemm_w8a8_blockscale_fake( aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128) -# TODO: ideally, we would like to wrap only the padding computation -# and unwrap the rest. 
This should be possible after solving -# https://github.com/vllm-project/vllm/issues/25080 -def _padded_cutlass_scaled_mm( - qx: torch.Tensor, - weight: torch.Tensor, - x_scale: torch.Tensor, - weight_scale: torch.Tensor, - block_size: list[int], - output_dtype: torch.dtype, -) -> torch.Tensor: - pad_multiple = 4 - dim = qx.shape[0] - padded = dim if dim % pad_multiple == 0 else dim + pad_multiple - ( - dim % pad_multiple) - - padded_shape = [padded, *qx.shape[1:]] - padded_qx = torch.zeros(padded_shape, device=qx.device, dtype=qx.dtype) - padded_qx[0:qx.shape[0], ...].copy_(qx) - - padded_x_scale_shape = [*x_scale.shape[1:], padded] - padded_x_scale = torch.ones(padded_x_scale_shape, - device=x_scale.device, - dtype=x_scale.dtype).permute(-1, -2) - padded_x_scale[0:x_scale.shape[0], ...].copy_(x_scale) - - output = cutlass_scaled_mm(padded_qx, weight, padded_x_scale, weight_scale, - block_size, output_dtype, True) - return output[0:qx.shape[0], ...] - - -def _padded_cutlass_scaled_mm_fake( - qx: torch.Tensor, - weight: torch.Tensor, - x_scale: torch.Tensor, - weight_scale: torch.Tensor, - block_size: list[int], - output_dtype: torch.dtype, -) -> torch.Tensor: - return torch.empty((qx.size(0), weight.size(0)), - dtype=output_dtype, - device=qx.device) - - -direct_register_custom_op( - "padded_cutlass_scaled_mm", - _padded_cutlass_scaled_mm, - mutates_args=[], - fake_impl=_padded_cutlass_scaled_mm_fake, - dispatch_key="CUDA", -) - - def _w8a8_block_fp8_matmul_func( qx: torch.Tensor, weight: torch.Tensor, @@ -277,12 +224,16 @@ def _run_cutlass( weight_scale: torch.Tensor, ) -> torch.Tensor: assert self.input_quant_op is not None - q_input, x_scale = self.input_quant_op(input_2d) if self.is_hopper: - output = torch.ops.vllm.padded_cutlass_scaled_mm( - q_input, weight, x_scale, weight_scale, - list(self.weight_group_shape), input_2d.dtype) + padded_x = torch.nn.functional.pad( + input_2d, (0, 0, 0, -input_2d.shape[0] % 4)) + q_input, x_scale = self.input_quant_op(padded_x) + output = cutlass_scaled_mm(q_input, weight, x_scale, weight_scale, + list(self.weight_group_shape), + input_2d.dtype, True) + output = output[0:input_2d.shape[0], ...] 
else: + q_input, x_scale = self.input_quant_op(input_2d) output = cutlass_scaled_mm(q_input, weight, x_scale, weight_scale, list(self.weight_group_shape), input_2d.dtype, False) From d19bf4b791af890a121c1666ea224fbb95bce3c4 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Mon, 22 Sep 2025 08:46:57 -0400 Subject: [PATCH 27/39] Feedback Signed-off-by: ElizaWszola --- .../cutlass_benchmarks/w8a8_benchmarks.py | 4 +- .../benchmark_fp8_block_dense_gemm.py | 4 +- tests/kernels/quantization/test_block_fp8.py | 5 ++- .../quantization/test_fp8_quant_group.py | 10 +++-- vllm/config/__init__.py | 13 +++++-- .../layers/quantization/deepgemm.py | 10 ++--- .../layers/quantization/utils/fp8_utils.py | 38 ++++++++++--------- 7 files changed, 49 insertions(+), 35 deletions(-) diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index a5a5b52f6039..02f8c593392c 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -17,7 +17,7 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - w8a8_block_fp8_matmul, + w8a8_triton_block_scaled_mm, ) from vllm.utils import FlexibleArgumentParser, cdiv @@ -158,7 +158,7 @@ def bench_fp8( "cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm( a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16) ), - "triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_block_fp8_matmul( + "triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_triton_block_scaled_mm( a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128) ), "cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm( diff --git a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py index b99c2099f2c3..b3c3742825de 100644 --- a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py +++ b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py @@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( get_col_major_tma_aligned_tensor, per_token_group_quant_fp8, - w8a8_block_fp8_matmul, + w8a8_triton_block_scaled_mm, ) from vllm.triton_utils import triton from vllm.utils.deep_gemm import calc_diff, fp8_gemm_nt, per_block_cast_to_fp8 @@ -59,7 +59,7 @@ def deepgemm_gemm(): # === vLLM Triton Implementation === def vllm_triton_gemm(): - return w8a8_block_fp8_matmul(A_vllm, + return w8a8_triton_block_scaled_mm(A_vllm, B_vllm, A_scale_vllm, B_scale_vllm, diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index c440747316b8..c0b934fc55ae 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -12,7 +12,7 @@ from vllm.config import VllmConfig from vllm.model_executor.layers.quantization.utils.fp8_utils import ( cutlass_scaled_mm, get_col_major_tma_aligned_tensor, - per_token_group_quant_fp8, w8a8_block_fp8_matmul) + per_token_group_quant_fp8, w8a8_triton_block_scaled_mm) from vllm.platforms import current_platform from vllm.utils import has_deep_gemm from vllm.utils.deep_gemm import fp8_gemm_nt, per_block_cast_to_fp8 @@ -90,7 +90,8 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) - out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) + out = 
w8a8_triton_block_scaled_mm(A_fp8, B_fp8, As, Bs, block_size, + out_dtype) rel_diff = (torch.mean( torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / diff --git a/tests/kernels/quantization/test_fp8_quant_group.py b/tests/kernels/quantization/test_fp8_quant_group.py index 16bae6b95221..9150a66818cb 100644 --- a/tests/kernels/quantization/test_fp8_quant_group.py +++ b/tests/kernels/quantization/test_fp8_quant_group.py @@ -20,9 +20,11 @@ (8, 513, 64), # Non-divisible (native only) ]) @pytest.mark.parametrize("seed", [42]) +@pytest.mark.parametrize("use_ue8m0", [True, False]) @torch.inference_mode() def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int, - group_size: int, seed: int) -> None: + group_size: int, seed: int, + use_ue8m0: bool) -> None: """Test QuantFP8 group quantization with various configurations. Tests both CUDA and native implementations, column-major scales, @@ -38,7 +40,8 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int, group_shape = GroupShape(1, group_size) quant_op = QuantFP8(static=False, group_shape=group_shape, - column_major_scales=False) + column_major_scales=False, + use_ue8m0=use_ue8m0) # 1. Test native implementation (always available) x_quant_native, scales_native = quant_op.forward_native(x.clone()) @@ -48,7 +51,8 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int, # 2. Test column-major scales configuration quant_op_col = QuantFP8(static=False, group_shape=group_shape, - column_major_scales=True) + column_major_scales=True, + use_ue8m0=use_ue8m0) _, scales_col = quant_op_col.forward_native(x.clone()) assert scales_col.shape == (expected_num_groups, batch_size) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index dbecb34369aa..65c06e69304c 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -2824,9 +2824,16 @@ def __post_init__(self): # local attention. 
self.scheduler_config.disable_hybrid_kv_cache_manager = True - custom_ops = self.compilation_config.custom_ops - if "none" not in custom_ops and "-quant_fp8" not in custom_ops: - custom_ops.append("+quant_fp8") + # Enable quant_fp8 CUDA ops when we do block-wise quantization + # due to an incorrect group shape issue when compiling native + # + # Also, on H100 since it's faster than native implementation + # https://github.com/vllm-project/vllm/issues/25094 + if (self.quant_config is not None + and self.quant_config.weight_block_size is not None): + custom_ops = self.compilation_config.custom_ops + if "none" not in custom_ops and "-quant_fp8" not in custom_ops: + custom_ops.append("+quant_fp8") def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list: diff --git a/vllm/model_executor/layers/quantization/deepgemm.py b/vllm/model_executor/layers/quantization/deepgemm.py index d26a932eddb2..c2b3ccf19fca 100644 --- a/vllm/model_executor/layers/quantization/deepgemm.py +++ b/vllm/model_executor/layers/quantization/deepgemm.py @@ -43,7 +43,7 @@ def prepare_block_fp8_matmul_inputs( return M, N, K, C -def w8a8_block_fp8_matmul_deepgemm( +def w8a8_deepgemm_block_scaled_mm( A: torch.Tensor, B: torch.Tensor, As: torch.Tensor, @@ -59,7 +59,7 @@ def w8a8_block_fp8_matmul_deepgemm( return C -def w8a8_block_fp8_matmul_deepgemm_fake( +def w8a8_deepgemm_block_scaled_mm_fake( A: torch.Tensor, B: torch.Tensor, As: torch.Tensor, @@ -73,9 +73,9 @@ def w8a8_block_fp8_matmul_deepgemm_fake( direct_register_custom_op( - op_name="w8a8_block_fp8_matmul_deepgemm", - op_func=w8a8_block_fp8_matmul_deepgemm, + op_name="w8a8_deepgemm_block_scaled_mm", + op_func=w8a8_deepgemm_block_scaled_mm, mutates_args=[], - fake_impl=w8a8_block_fp8_matmul_deepgemm_fake, + fake_impl=w8a8_deepgemm_block_scaled_mm_fake, dispatch_key=current_platform.dispatch_key, ) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 7215597286de..8d3d2b7f0263 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -104,7 +104,10 @@ def rocm_aiter_gemm_w8a8_blockscale_fake( aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128) -def _w8a8_block_fp8_matmul_func( +# TODO we should be able to change the type of block_size to GroupShape +# after we resolve GroupShape compilation issue +# https://github.com/vllm-project/vllm/issues/25270 +def _w8a8_triton_block_scaled_mm_func( qx: torch.Tensor, weight: torch.Tensor, x_scale: torch.Tensor, @@ -112,11 +115,11 @@ def _w8a8_block_fp8_matmul_func( block_size: list[int], output_dtype: torch.dtype, ) -> torch.Tensor: - return w8a8_block_fp8_matmul(qx, weight, x_scale, weight_scale, block_size, - output_dtype) + return w8a8_triton_block_scaled_mm(qx, weight, x_scale, weight_scale, + block_size, output_dtype) -def _w8a8_block_fp8_matmul_func_fake( +def _w8a8_triton_block_scaled_mm_fake( qx: torch.Tensor, weight: torch.Tensor, x_scale: torch.Tensor, @@ -130,10 +133,10 @@ def _w8a8_block_fp8_matmul_func_fake( direct_register_custom_op( - "w8a8_block_fp8_matmul_func", - _w8a8_block_fp8_matmul_func, + "w8a8_triton_block_scaled_mm_func", + _w8a8_triton_block_scaled_mm_func, mutates_args=[], - fake_impl=_w8a8_block_fp8_matmul_func_fake, + fake_impl=_w8a8_triton_block_scaled_mm_fake, dispatch_key="CUDA", ) @@ -209,7 +212,7 @@ def _run_deepgemm( assert self.deepgemm_input_quant_op is not None q_input, x_scale = 
self.deepgemm_input_quant_op(input_2d) - return torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm( + return torch.ops.vllm.w8a8_deepgemm_block_scaled_mm( q_input, weight, x_scale, @@ -252,7 +255,7 @@ def _run_aiter( q_input, weight, x_scale, weight_scale, self.weight_group_shape, input_2d.dtype) - def _run_w8a8_block_fp8_matmul( + def _run_triton( self, input_2d: torch.Tensor, weight: torch.Tensor, @@ -260,7 +263,7 @@ def _run_w8a8_block_fp8_matmul( ) -> torch.Tensor: assert self.input_quant_op is not None q_input, x_scale = self.input_quant_op(input_2d) - return torch.ops.vllm.w8a8_block_fp8_matmul_func( + return torch.ops.vllm.w8a8_triton_block_scaled_mm_func( q_input, weight, x_scale, weight_scale, self.weight_group_shape, input_2d.dtype) @@ -280,11 +283,10 @@ def _dispatch_w8a8_blockscale_op( use_ue8m0=False)) if use_aiter_and_is_supported: return self._run_aiter, None - return self._run_w8a8_block_fp8_matmul, (QuantFP8( - False, - self.act_quant_group_shape, - column_major_scales=False, - use_ue8m0=False)) + return self._run_triton, (QuantFP8(False, + self.act_quant_group_shape, + column_major_scales=False, + use_ue8m0=False)) def input_to_float8( @@ -536,7 +538,7 @@ def per_token_group_quant_fp8( @triton.jit -def _w8a8_block_fp8_matmul( +def _w8a8_triton_block_scaled_mm( # Pointers to inputs and output A, B, @@ -661,7 +663,7 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int, return None -def w8a8_block_fp8_matmul( +def w8a8_triton_block_scaled_mm( A: torch.Tensor, B: torch.Tensor, As: torch.Tensor, @@ -721,7 +723,7 @@ def grid(META): return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ) - _w8a8_block_fp8_matmul[grid]( + _w8a8_triton_block_scaled_mm[grid]( A, B, C, From 1f895e9e46be556dba829bc956fe0560ad7c8845 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Mon, 22 Sep 2025 14:47:44 +0200 Subject: [PATCH 28/39] Update vllm/model_executor/layers/quantization/utils/fp8_utils.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Luka Govedič Signed-off-by: ElizaWszola Signed-off-by: ElizaWszola --- .../layers/quantization/utils/fp8_utils.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 8d3d2b7f0263..a89c3e5b0681 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -228,18 +228,18 @@ def _run_cutlass( ) -> torch.Tensor: assert self.input_quant_op is not None if self.is_hopper: - padded_x = torch.nn.functional.pad( + # We pad unconditionally (even if shape is already divisible by 4) + # to support dynamic shape for input_2d.shape[0] in torch.compile + x = torch.nn.functional.pad( input_2d, (0, 0, 0, -input_2d.shape[0] % 4)) - q_input, x_scale = self.input_quant_op(padded_x) - output = cutlass_scaled_mm(q_input, weight, x_scale, weight_scale, - list(self.weight_group_shape), - input_2d.dtype, True) - output = output[0:input_2d.shape[0], ...] else: - q_input, x_scale = self.input_quant_op(input_2d) - output = cutlass_scaled_mm(q_input, weight, x_scale, weight_scale, + x = input_2d + + q_input, x_scale = self.input_quant_op(x) + output = cutlass_scaled_mm(q_input, weight, x_scale, weight_scale, list(self.weight_group_shape), - input_2d.dtype, False) + input_2d.dtype, self.is_hopper) + output = output[0:input_2d.shape[0], ...] 
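[Editor's note] A worked example of the padding arithmetic used in the hunk above, with toy sizes: pad the token dimension up to the next multiple of 4, then slice the result back to the original row count.

    import torch

    m, k = 6, 8                                   # 6 tokens, hidden size 8
    x = torch.randn(m, k)
    pad_rows = -m % 4                             # 2, since the next multiple of 4 is 8
    x_pad = torch.nn.functional.pad(x, (0, 0, 0, pad_rows))  # (cols: 0,0), (rows: 0,pad_rows)
    assert x_pad.shape == (8, k)
    assert torch.equal(x_pad[:m], x)              # slicing recovers the original rows
    # When m is already a multiple of 4, -m % 4 == 0 and the pad is a no-op,
    # which is what allows the call to stay unconditional under torch.compile.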
return output def _run_aiter( From be3ac58015cc84023ef2072a88c49535bef093b8 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Mon, 22 Sep 2025 08:57:14 -0400 Subject: [PATCH 29/39] Link bad group shape issue Signed-off-by: ElizaWszola --- vllm/config/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 65c06e69304c..a60478d4b9f0 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -2826,6 +2826,7 @@ def __post_init__(self): # Enable quant_fp8 CUDA ops when we do block-wise quantization # due to an incorrect group shape issue when compiling native + # https://github.com/vllm-project/vllm/issues/25382 # # Also, on H100 since it's faster than native implementation # https://github.com/vllm-project/vllm/issues/25094 From 3772f2f2fd7f647234073cdfb19064ccad0a9aa9 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Mon, 22 Sep 2025 09:11:25 -0400 Subject: [PATCH 30/39] format Signed-off-by: ElizaWszola --- .../layers/quantization/utils/fp8_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index a89c3e5b0681..b80ccbee886a 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -230,15 +230,15 @@ def _run_cutlass( if self.is_hopper: # We pad unconditionally (even if shape is already divisible by 4) # to support dynamic shape for input_2d.shape[0] in torch.compile - x = torch.nn.functional.pad( - input_2d, (0, 0, 0, -input_2d.shape[0] % 4)) + x = torch.nn.functional.pad(input_2d, + (0, 0, 0, -input_2d.shape[0] % 4)) else: x = input_2d - + q_input, x_scale = self.input_quant_op(x) output = cutlass_scaled_mm(q_input, weight, x_scale, weight_scale, - list(self.weight_group_shape), - input_2d.dtype, self.is_hopper) + list(self.weight_group_shape), + input_2d.dtype, self.is_hopper) output = output[0:input_2d.shape[0], ...] 
return output From 2a87a3b035d4d6c85d7caf4d09fc3bc591e5cb36 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Mon, 22 Sep 2025 09:56:06 -0400 Subject: [PATCH 31/39] fix quant config condition Signed-off-by: ElizaWszola --- vllm/config/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index aafb188e9ec4..9251d0cc9ab7 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -692,6 +692,7 @@ def __post_init__(self): # Also, on H100 since it's faster than native implementation # https://github.com/vllm-project/vllm/issues/25094 if (self.quant_config is not None + and hasattr(self.quant_config, "weight_block_size") and self.quant_config.weight_block_size is not None): custom_ops = self.compilation_config.custom_ops if "none" not in custom_ops and "-quant_fp8" not in custom_ops: From e7f6ec92be1f47e57a2e4457dc03fe4f5378a286 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Mon, 22 Sep 2025 14:55:03 -0700 Subject: [PATCH 32/39] fix quant issue (TODO test) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- vllm/model_executor/layers/quantization/input_quant_fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py index e500248e4a70..caee7a84b82c 100644 --- a/vllm/model_executor/layers/quantization/input_quant_fp8.py +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -159,6 +159,6 @@ def _quantize_group_native( scales = scales.reshape(orig_shape[:-1] + (num_groups, )) if self.column_major_scales: - scales = scales.transpose(-2, -1).contiguous() + scales = scales.transpose(-2, -1).contiguous().transpose(-1, -2) return x_quant, scales From 10829d39f3536acff67b72a9e510eaf9493ede69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Mon, 22 Sep 2025 15:40:52 -0700 Subject: [PATCH 33/39] fix custom op test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- .../model_executor/test_enabled_custom_ops.py | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 343c0db4086c..52bfc0283e33 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional import pytest import torch @@ -32,15 +33,15 @@ class Relu3(ReLUSquaredActivation): [ # Default values based on compile level # - All by default (no Inductor compilation) - ("", 0, False, [True] * 4, True), - ("", 1, True, [True] * 4, True), - ("", 2, False, [True] * 4, True), + (None, 0, False, [True] * 4, True), + (None, 1, True, [True] * 4, True), + (None, 2, False, [True] * 4, True), # - None by default (with Inductor) - ("", 3, True, [False] * 4, False), - ("", 4, True, [False] * 4, False), + (None, 3, True, [False] * 4, False), + (None, 4, True, [False] * 4, False), # - All by default (without Inductor) - ("", 3, False, [True] * 4, True), - ("", 4, False, [True] * 4, True), + (None, 3, False, [True] * 4, True), + (None, 4, False, [True] * 4, True), # Explicitly enabling/disabling # # Default: all @@ -52,7 +53,7 @@ class 
Relu3(ReLUSquaredActivation): # All but SiluAndMul ("all,-silu_and_mul", 2, True, [1, 0, 1, 1], True), # All but ReLU3 (even if ReLU2 is on) - ("-relu3,relu2", 3, False, [1, 1, 1, 0], True), + ("-relu3,+relu2", 3, False, [1, 1, 1, 0], True), # RMSNorm and SiluAndMul ("none,-relu3,+rms_norm,+silu_and_mul", 4, False, [1, 1, 0, 0], False), # All but RMSNorm @@ -65,12 +66,13 @@ class Relu3(ReLUSquaredActivation): # All but RMSNorm ("all,-rms_norm", 4, True, [0, 1, 1, 1], True), ]) -def test_enabled_ops(env: str, torch_level: int, use_inductor: bool, +def test_enabled_ops(env: Optional[str], torch_level: int, use_inductor: bool, ops_enabled: list[int], default_on: bool): + custom_ops = [] if env is None else env.split(",") vllm_config = VllmConfig( compilation_config=CompilationConfig(use_inductor=bool(use_inductor), level=torch_level, - custom_ops=env.split(","))) + custom_ops=custom_ops)) with set_current_vllm_config(vllm_config): assert CustomOp.default_on() == default_on From ebdcb103c4e00352c3a43c2d436c230f7603c7a4 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 23 Sep 2025 01:59:48 -0400 Subject: [PATCH 34/39] CUDA condition for compressed tensors and H100 Signed-off-by: ElizaWszola --- vllm/config/__init__.py | 20 +++++++++++-------- .../compressed_tensors/compressed_tensors.py | 8 ++++++++ .../layers/quantization/input_quant_fp8.py | 6 +++--- 3 files changed, 23 insertions(+), 11 deletions(-) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 61b6e1579214..4cb5c0af47a5 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -685,15 +685,19 @@ def __post_init__(self): # local attention. self.scheduler_config.disable_hybrid_kv_cache_manager = True - # Enable quant_fp8 CUDA ops when we do block-wise quantization - # due to an incorrect group shape issue when compiling native - # https://github.com/vllm-project/vllm/issues/25382 - # - # Also, on H100 since it's faster than native implementation + def has_blocked_weights(): + if self.quant_config is not None: + if hasattr(self.quant_config, "weight_block_size"): + return self.quant_config.weight_block_size is not None + elif hasattr(self.quant_config, "has_blocked_weights"): + return self.quant_config.has_blocked_weights() + return False + + # Enable quant_fp8 CUDA ops on H100 since it's faster than + # native implementation # https://github.com/vllm-project/vllm/issues/25094 - if (self.quant_config is not None - and hasattr(self.quant_config, "weight_block_size") - and self.quant_config.weight_block_size is not None): + if current_platform.has_device_capability( + 90) and has_blocked_weights(): custom_ops = self.compilation_config.custom_ops if "none" not in custom_ops and "-quant_fp8" not in custom_ops: custom_ops.append("+quant_fp8") diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index d6550dd16892..3f771ea2abd1 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -644,6 +644,14 @@ def get_cache_scale(self, name: str) -> Optional[str]: # If no matches, return None return None + def has_blocked_weights(self) -> bool: + for scheme in self.target_scheme_map.values(): + weight_quant = scheme.get("weights") + if (weight_quant is not None + and weight_quant.strategy == QuantizationStrategy.BLOCK): + return True + return False + @staticmethod def 
supports_cutlass_24( weight_quant: Optional[QuantizationArgs], diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py index caee7a84b82c..ece3e5817116 100644 --- a/vllm/model_executor/layers/quantization/input_quant_fp8.py +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -142,10 +142,10 @@ def _quantize_group_native( x_grouped = x.view(-1, num_groups, self.group_size) absmax = x_grouped.abs().max(dim=-1, keepdim=True)[0].float() - scale_raw = absmax / _FP8_MAX + scales_raw = absmax / _FP8_MAX if self.use_ue8m0: - scale_raw = torch.exp2(torch.ceil(torch.log2(scale_raw))) - scales = (scale_raw).clamp(min=_FP8_MIN_SCALING_FACTOR) + scales_raw = torch.exp2(torch.ceil(torch.log2(scales_raw))) + scales = (scales_raw).clamp(min=_FP8_MIN_SCALING_FACTOR) x_scaled = x_grouped / scales x_quant = x_scaled.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) From 2e3d206d383b150655a1f9abf8128b780887e113 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 23 Sep 2025 02:47:01 -0400 Subject: [PATCH 35/39] Fix quantfp8 test Signed-off-by: ElizaWszola --- tests/kernels/quantization/test_fp8_quant_group.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/quantization/test_fp8_quant_group.py b/tests/kernels/quantization/test_fp8_quant_group.py index 9150a66818cb..9f64a7ea09ad 100644 --- a/tests/kernels/quantization/test_fp8_quant_group.py +++ b/tests/kernels/quantization/test_fp8_quant_group.py @@ -54,7 +54,7 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int, column_major_scales=True, use_ue8m0=use_ue8m0) _, scales_col = quant_op_col.forward_native(x.clone()) - assert scales_col.shape == (expected_num_groups, batch_size) + assert scales_col.shape == (batch_size, expected_num_groups) # 3. Test CUDA implementation (only for divisible dimensions) if is_divisible: From bd32cb932adab7e99232b723c615df7fc463e392 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 23 Sep 2025 02:52:59 -0400 Subject: [PATCH 36/39] Test scales_col vs. scales_native Signed-off-by: ElizaWszola --- tests/kernels/quantization/test_fp8_quant_group.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/kernels/quantization/test_fp8_quant_group.py b/tests/kernels/quantization/test_fp8_quant_group.py index 9f64a7ea09ad..53e01b10d605 100644 --- a/tests/kernels/quantization/test_fp8_quant_group.py +++ b/tests/kernels/quantization/test_fp8_quant_group.py @@ -56,6 +56,9 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int, _, scales_col = quant_op_col.forward_native(x.clone()) assert scales_col.shape == (batch_size, expected_num_groups) + # Test column-major scales consistency + assert torch.allclose(scales_col, scales_native, rtol=1e-9, atol=1e-8) + # 3. 
Test CUDA implementation (only for divisible dimensions) if is_divisible: x_quant_cuda, scales_cuda = quant_op.forward_cuda(x.clone()) From 1f0080475e183267931e191428be2d8348334415 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 23 Sep 2025 10:20:41 -0400 Subject: [PATCH 37/39] Add compressed tensors model test Signed-off-by: ElizaWszola --- tests/quantization/test_compressed_tensors.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index c0ab3fbb1062..5ea981d21749 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -18,6 +18,9 @@ CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, CompressedTensorsWNA16) +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + W8A8BlockFp8LinearOp) from vllm.model_executor.layers.quantization.utils.quant_utils import ( cutlass_fp4_supported) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -742,3 +745,30 @@ def test_compressed_tensors_transforms_perplexity(vllm_runner, model, prompt, perplexity = llm.generate_prompt_perplexity([prompt])[0] print(perplexity) assert perplexity <= exp_perplexity + + +def test_compressed_tensors_fp8_block_enabled(vllm_runner): + model_path = "RedHatAI/Qwen3-0.6B-FP8-BLOCK" + with vllm_runner(model_path) as llm: + + is_sm90 = current_platform.has_device_capability(90) + + def check_model(model): + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + assert isinstance(qkv_proj.quant_method, + CompressedTensorsLinearMethod) + assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8) + assert isinstance(qkv_proj.scheme.w8a8_block_fp8_linear, + W8A8BlockFp8LinearOp) + + input_quant_op = \ + qkv_proj.scheme.w8a8_block_fp8_linear.input_quant_op + assert isinstance(input_quant_op, QuantFP8) + assert input_quant_op.enabled() == is_sm90 + + llm.apply_model(check_model) + + output = llm.generate_greedy("Hello my name is", max_tokens=20) + assert output From e895df60611320a3aac2b294791355cd128bd983 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 23 Sep 2025 10:46:50 -0400 Subject: [PATCH 38/39] Extra asserts, don't use enabled() Signed-off-by: ElizaWszola --- tests/kernels/quantization/test_fp8_quant_group.py | 2 ++ tests/quantization/test_compressed_tensors.py | 10 +++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/kernels/quantization/test_fp8_quant_group.py b/tests/kernels/quantization/test_fp8_quant_group.py index 53e01b10d605..3d4c851a9b88 100644 --- a/tests/kernels/quantization/test_fp8_quant_group.py +++ b/tests/kernels/quantization/test_fp8_quant_group.py @@ -55,6 +55,8 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int, use_ue8m0=use_ue8m0) _, scales_col = quant_op_col.forward_native(x.clone()) assert scales_col.shape == (batch_size, expected_num_groups) + assert scales_col.stride(0) == 1 + assert scales_col.stride(1) == batch_size # Test column-major scales consistency assert torch.allclose(scales_col, scales_native, rtol=1e-9, atol=1e-8) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 5ea981d21749..ad89b8a27234 100644 --- a/tests/quantization/test_compressed_tensors.py +++ 
b/tests/quantization/test_compressed_tensors.py @@ -752,6 +752,7 @@ def test_compressed_tensors_fp8_block_enabled(vllm_runner): with vllm_runner(model_path) as llm: is_sm90 = current_platform.has_device_capability(90) + fp8_dtype = current_platform.fp8_dtype() def check_model(model): layer = model.model.layers[0] @@ -763,10 +764,17 @@ def check_model(model): assert isinstance(qkv_proj.scheme.w8a8_block_fp8_linear, W8A8BlockFp8LinearOp) + assert qkv_proj.weight.dtype is fp8_dtype + assert qkv_proj.weight_scale.dtype is torch.float32 + assert len(qkv_proj.weight.shape) == 2 + assert len(qkv_proj.weight_scale.shape) == 2 + input_quant_op = \ qkv_proj.scheme.w8a8_block_fp8_linear.input_quant_op assert isinstance(input_quant_op, QuantFP8) - assert input_quant_op.enabled() == is_sm90 + quant_enabled = \ + input_quant_op._forward_method == input_quant_op.forward_cuda + assert quant_enabled == is_sm90 llm.apply_model(check_model) From 9806cf851cb4669ce114f34b4f324b4abc4b7a07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Tue, 23 Sep 2025 11:48:04 -0400 Subject: [PATCH 39/39] CUDA path for quant MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/quantization/test_compressed_tensors.py | 5 +---- vllm/config/__init__.py | 6 +++--- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index ad89b8a27234..af8c7ec3b482 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -751,7 +751,6 @@ def test_compressed_tensors_fp8_block_enabled(vllm_runner): model_path = "RedHatAI/Qwen3-0.6B-FP8-BLOCK" with vllm_runner(model_path) as llm: - is_sm90 = current_platform.has_device_capability(90) fp8_dtype = current_platform.fp8_dtype() def check_model(model): @@ -772,9 +771,7 @@ def check_model(model): input_quant_op = \ qkv_proj.scheme.w8a8_block_fp8_linear.input_quant_op assert isinstance(input_quant_op, QuantFP8) - quant_enabled = \ - input_quant_op._forward_method == input_quant_op.forward_cuda - assert quant_enabled == is_sm90 + assert input_quant_op._forward_method == input_quant_op.forward_cuda llm.apply_model(check_model) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 4cb5c0af47a5..1561df2a0fcf 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -693,11 +693,11 @@ def has_blocked_weights(): return self.quant_config.has_blocked_weights() return False - # Enable quant_fp8 CUDA ops on H100 since it's faster than + # Enable quant_fp8 CUDA ops (TODO disable in follow up) + # On H100 the CUDA kernel is faster than # native implementation # https://github.com/vllm-project/vllm/issues/25094 - if current_platform.has_device_capability( - 90) and has_blocked_weights(): + if has_blocked_weights(): custom_ops = self.compilation_config.custom_ops if "none" not in custom_ops and "-quant_fp8" not in custom_ops: custom_ops.append("+quant_fp8")
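[Editor's note] Taken together, the series leaves W8A8BlockFp8LinearOp with roughly the interface sketched below. This is an illustrative reconstruction from the hunks above, not code from the patches; the 128x128 block size and the disabled cutlass/aiter flags are arbitrary choices for the sketch.

    import torch
    from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
    from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

    block_n = block_k = 128                      # arbitrary block size for the sketch
    linear_op = W8A8BlockFp8LinearOp(
        weight_group_shape=GroupShape(block_n, block_k),
        act_quant_group_shape=GroupShape(1, block_k),
        cutlass_block_fp8_supported=False,       # force the Triton fallback here
        use_aiter_and_is_supported=False,
    )
    # weight / weight_scale would come from a block-quantized checkpoint; after
    # PATCH 22 the block size is no longer passed to apply():
    # out = linear_op.apply(input=x, weight=w_fp8, weight_scale=w_scale,
    #                       input_scale=None, bias=None)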