From 971e948adf89f3b2f02d29e91caa4c38f537906d Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 19 Sep 2025 01:57:28 +0000 Subject: [PATCH 1/6] seems to work Signed-off-by: Bill Nell --- .../compressed_tensors_moe.py | 131 +++++++++++++++--- .../model_executor/warmup/deep_gemm_warmup.py | 10 +- 2 files changed, 122 insertions(+), 19 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 85adae32f4cd..b222b48a3c98 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -13,6 +13,7 @@ import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase, @@ -31,6 +32,8 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1, select_nvfp4_gemm_impl) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( check_moe_marlin_supports_layer, marlin_make_workspace_new, marlin_moe_permute_scales) @@ -45,10 +48,17 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.scalar_type import scalar_types +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used logger = init_logger(__name__) +def _is_col_major(x: torch.Tensor) -> bool: + assert x.dim() == 3 + b, m, n = x.shape + return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m + + class GPTQMarlinState(Enum): REPACK = enum.auto() READY = enum.auto() @@ -505,10 +515,13 @@ def __init__( self.weight_quant.strategy == QuantizationStrategy.CHANNEL and self.input_quant.strategy == QuantizationStrategy.TOKEN) if not (per_tensor or per_channel): - raise ValueError( - "For FP8 Fused MoE layers, we require per tensor " - "or channelwise, dynamic per token quantization. 
Found " - f"{self.weight_quant}, {self.input_quant}") + assert self.weight_quant.strategy == QuantizationStrategy.BLOCK + logger.debug("WQ = %s", str(self.weight_quant)) + self.weight_block_size = self.weight_quant.block_structure + # TODO: self.weight_quant.dynamic + else: + self.weight_block_size = None + self.block_quant = self.weight_block_size is not None self.static_input_scales = not self.input_quant.dynamic if self.static_input_scales and per_channel: @@ -519,7 +532,8 @@ def __init__( # For GPUs that lack FP8 hardware support, we can leverage the Marlin # kernel for fast weight-only FP8 quantization self.use_marlin = (not current_platform.has_device_capability(89) - or envs.VLLM_TEST_FORCE_FP8_MARLIN) + or envs.VLLM_TEST_FORCE_FP8_MARLIN + and not self.block_quant) # Disable marlin for rocm if current_platform.is_rocm(): self.use_marlin = False @@ -531,13 +545,20 @@ def __init__( # cutlass path self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100( self.weight_quant, self.input_quant) - self.use_cutlass = (quant_config._is_fp8_w8a8_sm90( - self.weight_quant, self.input_quant) or self.is_fp8_w8a8_sm100) + self.use_cutlass = not self.block_quant and ( + quant_config._is_fp8_w8a8_sm90(self.weight_quant, self.input_quant) + or self.is_fp8_w8a8_sm100) self.disable_expert_map = False - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): layer.intermediate_size_per_partition = intermediate_size_per_partition layer.hidden_size = hidden_size @@ -547,6 +568,31 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, params_dtype = torch.float8_e4m3fn + if self.block_quant: + assert self.weight_block_size is not None + layer.weight_block_size = self.weight_block_size + tp_size = get_tensor_model_parallel_world_size() + block_n, block_k = ( + self.weight_block_size[0], + self.weight_block_size[1], + ) + # NOTE: To ensure proper alignment of the block-wise quantization + # scales, the output_size of the weights for both the gate and up + # layers must be divisible by block_n. 
+ # Required by column parallel or enabling merged weights + if intermediate_size_per_partition % block_n != 0: + raise ValueError( + f"The output_size of gate's and up's weight = " + f"{intermediate_size_per_partition} is not divisible by " + f"weight quantization block_n = {block_n}.") + if (tp_size > 1 + and intermediate_size_per_partition % block_k != 0): + # Required by row parallel + raise ValueError( + f"The input_size of down's weight = " + f"{intermediate_size_per_partition} is not divisible by " + f"weight quantization block_k = {block_k}.") + # WEIGHTS w13_weight = torch.nn.Parameter(torch.empty( num_experts, @@ -602,6 +648,29 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs) + elif self.weight_quant.strategy == QuantizationStrategy.BLOCK: + w13_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, + 2 * + ((intermediate_size_per_partition + block_n - 1) // block_n), + (hidden_size + block_k - 1) // block_k, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + w2_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, (hidden_size + block_n - 1) // block_n, + (intermediate_size_per_partition + block_k - 1) // block_k, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add PER-CHANNEL quantization for FusedMoE.weight_loader. + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + layer.register_parameter("w13_weight_scale_inv", w13_weight_scale) + layer.register_parameter("w2_weight_scale_inv", w2_weight_scale) + # INPUT_SCALES if self.static_input_scales: w13_input_scale = torch.nn.Parameter(torch.ones( @@ -623,6 +692,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Fp8 moe kernels require a single activation scale. # We take the max of all the scales in case they differ. if self.static_input_scales: + # TODO(bnell): Is this assert right? assert self.input_quant.strategy == QuantizationStrategy.TENSOR if (layer.w13_input_scale is None or layer.w2_input_scale is None): raise ValueError( @@ -706,6 +776,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: del layer.w2_input_scale if self.use_cutlass: + assert self.weight_quant.strategy != QuantizationStrategy.BLOCK device = layer.w13_weight.device # ab_strides1 and c_strides2 are the same self.ab_strides1_c_strides2 = torch.full( @@ -724,6 +795,30 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: device=device, dtype=torch.int64) + # XXXXXXXXXXXXX + if is_deep_gemm_e8m0_used() and self.block_quant: + assert layer.weight_block_size is not None + # Re-quantise the expert weights so their scales are UE8M0. + block_sz = tuple(layer.weight_block_size) + requant_weight_ue8m0_inplace( + layer.w13_weight.data, + layer.w13_weight_scale.data, + block_sz, + ) + requant_weight_ue8m0_inplace( + layer.w2_weight.data, + layer.w2_weight_scale.data, + block_sz, + ) + + # Ensure column-major TMA alignment expected by DeepGEMM. 
+ if _is_col_major(layer.w13_weight_scale_inv): + layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor( + layer.w13_weight_scale) + if _is_col_major(layer.w2_weight_scale_inv): + layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor( + layer.w2_weight_scale) + def maybe_make_prepare_finalize( self) -> Optional[mk.FusedMoEPrepareAndFinalize]: if self.use_marlin or self.rocm_aiter_moe_enabled: @@ -777,9 +872,10 @@ def select_gemm_impl( return experts # triton path - from vllm.model_executor.layers.fused_moe import TritonExperts - from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedTritonExperts) + from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 + BatchedTritonOrDeepGemmExperts) + from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( + TritonOrDeepGemmExperts) assert not self.rocm_aiter_moe_enabled and not self.use_marlin @@ -790,14 +886,16 @@ def select_gemm_impl( assert max_num_tokens_per_rank is not None logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__) - return BatchedTritonExperts( + return BatchedTritonOrDeepGemmExperts( max_num_tokens=max_num_tokens_per_rank, num_dispatchers=prepare_finalize.num_dispatchers(), quant_config=self.moe_quant_config, ) else: - logger.debug("TritonExperts(%s)", self.__class__.__name__) - return TritonExperts(self.moe_quant_config) + logger.debug("TritonOrDeepGemmExperts(%s)", + self.__class__.__name__) + return TritonOrDeepGemmExperts(self.moe_quant_config, + allow_deep_gemm=True) def get_fused_moe_quant_config( self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: @@ -816,6 +914,7 @@ def get_fused_moe_quant_config( a2_scale=layer.w2_input_scale, per_act_token_quant=per_act_token, per_out_ch_quant=per_channel_quant, + block_shape=layer.weight_block_size, ) def apply( diff --git a/vllm/model_executor/warmup/deep_gemm_warmup.py b/vllm/model_executor/warmup/deep_gemm_warmup.py index 4d1829cd228c..472cce22d8c5 100644 --- a/vllm/model_executor/warmup/deep_gemm_warmup.py +++ b/vllm/model_executor/warmup/deep_gemm_warmup.py @@ -75,9 +75,13 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool: and module.quant_method.block_quant): return False - w, _, block_sizes = _extract_data_from_linear_base_module(module) - return (block_sizes == deep_gemm_block_shape() and w.ndim == 2 - and w.shape[0] % block_size == 0 and w.shape[1] % block_size == 0) + try: + w, _, block_sizes = _extract_data_from_linear_base_module(module) + return (block_sizes == deep_gemm_block_shape() and w.ndim == 2 + and w.shape[0] % block_size == 0 + and w.shape[1] % block_size == 0) + except Exception: + return False def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool: From 3ad06568889b0b0739320deda63e0b0628243e0a Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 19 Sep 2025 16:02:53 +0000 Subject: [PATCH 2/6] remove _inv Signed-off-by: Bill Nell --- .../compressed_tensors/compressed_tensors_moe.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index b222b48a3c98..da2def13348f 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -668,8 +668,6 @@ def create_weights( {"quant_method": 
FusedMoeWeightScaleSupported.BLOCK.value}) set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs) - layer.register_parameter("w13_weight_scale_inv", w13_weight_scale) - layer.register_parameter("w2_weight_scale_inv", w2_weight_scale) # INPUT_SCALES if self.static_input_scales: @@ -812,11 +810,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) # Ensure column-major TMA alignment expected by DeepGEMM. - if _is_col_major(layer.w13_weight_scale_inv): - layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor( + if _is_col_major(layer.w13_weight_scale): + layer.w13_weight_scale = get_col_major_tma_aligned_tensor( layer.w13_weight_scale) - if _is_col_major(layer.w2_weight_scale_inv): - layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor( + if _is_col_major(layer.w2_weight_scale): + layer.w2_weight_scale = get_col_major_tma_aligned_tensor( layer.w2_weight_scale) def maybe_make_prepare_finalize( From 1ce411492e6eb62323c96415ae974c1ac7f694e0 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 22 Sep 2025 19:07:10 +0000 Subject: [PATCH 3/6] cleanup Signed-off-by: Bill Nell --- .../compressed_tensors_moe.py | 26 +++++++------------ .../model_executor/warmup/deep_gemm_warmup.py | 10 +++---- 2 files changed, 13 insertions(+), 23 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index da2def13348f..fd73695c0a18 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -518,7 +518,7 @@ def __init__( assert self.weight_quant.strategy == QuantizationStrategy.BLOCK logger.debug("WQ = %s", str(self.weight_quant)) self.weight_block_size = self.weight_quant.block_structure - # TODO: self.weight_quant.dynamic + assert self.weight_quant.dynamic is not None else: self.weight_block_size = None self.block_quant = self.weight_block_size is not None @@ -550,15 +550,9 @@ def __init__( or self.is_fp8_w8a8_sm100) self.disable_expert_map = False - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): layer.intermediate_size_per_partition = intermediate_size_per_partition layer.hidden_size = hidden_size @@ -668,6 +662,8 @@ def create_weights( {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}) set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs) + layer.register_parameter("w13_weight_scale_inv", w13_weight_scale) + layer.register_parameter("w2_weight_scale_inv", w2_weight_scale) # INPUT_SCALES if self.static_input_scales: @@ -690,7 +686,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Fp8 moe kernels require a single activation scale. # We take the max of all the scales in case they differ. if self.static_input_scales: - # TODO(bnell): Is this assert right? 
assert self.input_quant.strategy == QuantizationStrategy.TENSOR if (layer.w13_input_scale is None or layer.w2_input_scale is None): raise ValueError( @@ -793,7 +788,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: device=device, dtype=torch.int64) - # XXXXXXXXXXXXX if is_deep_gemm_e8m0_used() and self.block_quant: assert layer.weight_block_size is not None # Re-quantise the expert weights so their scales are UE8M0. @@ -810,11 +804,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) # Ensure column-major TMA alignment expected by DeepGEMM. - if _is_col_major(layer.w13_weight_scale): - layer.w13_weight_scale = get_col_major_tma_aligned_tensor( + if _is_col_major(layer.w13_weight_scale_inv): + layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor( layer.w13_weight_scale) - if _is_col_major(layer.w2_weight_scale): - layer.w2_weight_scale = get_col_major_tma_aligned_tensor( + if _is_col_major(layer.w2_weight_scale_inv): + layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor( layer.w2_weight_scale) def maybe_make_prepare_finalize( diff --git a/vllm/model_executor/warmup/deep_gemm_warmup.py b/vllm/model_executor/warmup/deep_gemm_warmup.py index 472cce22d8c5..4d1829cd228c 100644 --- a/vllm/model_executor/warmup/deep_gemm_warmup.py +++ b/vllm/model_executor/warmup/deep_gemm_warmup.py @@ -75,13 +75,9 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool: and module.quant_method.block_quant): return False - try: - w, _, block_sizes = _extract_data_from_linear_base_module(module) - return (block_sizes == deep_gemm_block_shape() and w.ndim == 2 - and w.shape[0] % block_size == 0 - and w.shape[1] % block_size == 0) - except Exception: - return False + w, _, block_sizes = _extract_data_from_linear_base_module(module) + return (block_sizes == deep_gemm_block_shape() and w.ndim == 2 + and w.shape[0] % block_size == 0 and w.shape[1] % block_size == 0) def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool: From 8250bea8810e2d307bc2715e41864fbc5f986c32 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 22 Sep 2025 22:01:43 +0000 Subject: [PATCH 4/6] remove _inv suffix from weight scales Signed-off-by: Bill Nell --- .../compressed_tensors/compressed_tensors_moe.py | 10 ++++------ vllm/model_executor/warmup/deep_gemm_warmup.py | 4 ++-- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index fd73695c0a18..a90e38de3611 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -662,8 +662,6 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}) set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs) - layer.register_parameter("w13_weight_scale_inv", w13_weight_scale) - layer.register_parameter("w2_weight_scale_inv", w2_weight_scale) # INPUT_SCALES if self.static_input_scales: @@ -804,11 +802,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) # Ensure column-major TMA alignment expected by DeepGEMM. 
- if _is_col_major(layer.w13_weight_scale_inv): - layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor( + if _is_col_major(layer.w13_weight_scale): + layer.w13_weight_scale = get_col_major_tma_aligned_tensor( layer.w13_weight_scale) - if _is_col_major(layer.w2_weight_scale_inv): - layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor( + if _is_col_major(layer.w2_weight_scale): + layer.w2_weight_scale = get_col_major_tma_aligned_tensor( layer.w2_weight_scale) def maybe_make_prepare_finalize( diff --git a/vllm/model_executor/warmup/deep_gemm_warmup.py b/vllm/model_executor/warmup/deep_gemm_warmup.py index 4d1829cd228c..f6df85a50238 100644 --- a/vllm/model_executor/warmup/deep_gemm_warmup.py +++ b/vllm/model_executor/warmup/deep_gemm_warmup.py @@ -53,9 +53,9 @@ def _extract_data_from_fused_moe_module( """ assert isinstance(m, FusedMoE) w13 = m.w13_weight - w13_s = m.w13_weight_scale_inv + w13_s = getattr(m, "w13_weight_scale_inv", m.w13_weight_scale) w2 = m.w2_weight - w2_s = m.w2_weight_scale_inv + w2_s = getattr(m, "w2_weight_scale_inv", m.w2_weight_scale) num_topk = m.top_k assert isinstance(w13, torch.Tensor) From dc3c2330675da9ccf013e55d896802f105d330d9 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 22 Sep 2025 22:32:26 +0000 Subject: [PATCH 5/6] make util Signed-off-by: Bill Nell --- .../compressed_tensors_moe.py | 13 ++++------- .../model_executor/layers/quantization/fp8.py | 22 +++++++------------ .../layers/quantization/utils/fp8_utils.py | 6 +++++ 3 files changed, 18 insertions(+), 23 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index a90e38de3611..9924cc721a2f 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -33,7 +33,8 @@ build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1, select_nvfp4_gemm_impl) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace) + expert_weight_is_col_major, get_col_major_tma_aligned_tensor, + requant_weight_ue8m0_inplace) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( check_moe_marlin_supports_layer, marlin_make_workspace_new, marlin_moe_permute_scales) @@ -53,12 +54,6 @@ logger = init_logger(__name__) -def _is_col_major(x: torch.Tensor) -> bool: - assert x.dim() == 3 - b, m, n = x.shape - return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m - - class GPTQMarlinState(Enum): REPACK = enum.auto() READY = enum.auto() @@ -802,10 +797,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) # Ensure column-major TMA alignment expected by DeepGEMM. 
- if _is_col_major(layer.w13_weight_scale): + if expert_weight_is_col_major(layer.w13_weight_scale): layer.w13_weight_scale = get_col_major_tma_aligned_tensor( layer.w13_weight_scale) - if _is_col_major(layer.w2_weight_scale): + if expert_weight_is_col_major(layer.w2_weight_scale): layer.w2_weight_scale = get_col_major_tma_aligned_tensor( layer.w2_weight_scale) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index aec9c79f1ea8..2b24e052053c 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -33,10 +33,10 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( apply_fp8_block_linear, check_aiter_fp8_linear_support, create_fp8_input_scale, create_fp8_scale_parameter, - create_fp8_weight_parameter, get_col_major_tma_aligned_tensor, - maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy, - process_fp8_weight_tensor_strategy, requant_weight_ue8m0_inplace, - validate_fp8_block_shape) + create_fp8_weight_parameter, expert_weight_is_col_major, + get_col_major_tma_aligned_tensor, maybe_post_process_fp8_weight_block, + process_fp8_weight_block_strategy, process_fp8_weight_tensor_strategy, + requant_weight_ue8m0_inplace, validate_fp8_block_shape) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin, prepare_moe_fp8_layer_for_marlin) @@ -64,12 +64,6 @@ logger = init_logger(__name__) -def _is_col_major(x: torch.Tensor) -> bool: - assert x.dim() == 3 - b, m, n = x.shape - return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m - - class Fp8Config(QuantizationConfig): """Config class for FP8.""" @@ -660,10 +654,10 @@ def process_weights_after_loading(self, layer: Module) -> None: # DeepGemm scales need to be transposed and aligned. We try to do # it ahead of time for performance reasons. if self.allow_deep_gemm and not is_deep_gemm_e8m0_used(): - if _is_col_major(layer.w13_weight_scale_inv): + if expert_weight_is_col_major(layer.w13_weight_scale_inv): layer.w13_weight_scale_inv = \ get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv) - if _is_col_major(layer.w2_weight_scale_inv): + if expert_weight_is_col_major(layer.w2_weight_scale_inv): layer.w2_weight_scale_inv = \ get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv) @@ -811,10 +805,10 @@ def process_weights_after_loading(self, layer: Module) -> None: ) # Ensure column-major TMA alignment expected by DeepGEMM. 
- if _is_col_major(layer.w13_weight_scale_inv): + if expert_weight_is_col_major(layer.w13_weight_scale_inv): layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor( layer.w13_weight_scale_inv) - if _is_col_major(layer.w2_weight_scale_inv): + if expert_weight_is_col_major(layer.w2_weight_scale_inv): layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor( layer.w2_weight_scale_inv) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index fc12483de0c0..d1d87b7ba12e 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -1014,3 +1014,9 @@ def apply_fp8_block_linear(layer: torch.nn.Module, input: torch.Tensor, cutlass_block_fp8_supported=cutlass_block_fp8_supported, use_aiter_and_is_supported=use_aiter_and_is_supported, ) + + +def expert_weight_is_col_major(x: torch.Tensor) -> bool: + assert x.dim() == 3 + b, m, n = x.shape + return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m From a7d60665a404cc9b9c242a3d8db781d26da64a5a Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Mon, 22 Sep 2025 19:46:58 -0600 Subject: [PATCH 6/6] Remove debug log for weight quantization --- .../quantization/compressed_tensors/compressed_tensors_moe.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 9924cc721a2f..10f9085be4d1 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -511,7 +511,6 @@ def __init__( and self.input_quant.strategy == QuantizationStrategy.TOKEN) if not (per_tensor or per_channel): assert self.weight_quant.strategy == QuantizationStrategy.BLOCK - logger.debug("WQ = %s", str(self.weight_quant)) self.weight_block_size = self.weight_quant.block_structure assert self.weight_quant.dynamic is not None else:
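
For reference, the expert_weight_is_col_major helper that PATCH 5/6 adds to
fp8_utils.py can be exercised on its own. The sketch below reproduces the
helper from the patch with an explanatory comment added; the tensor shapes
and the row_major/col_major variables are illustrative assumptions for the
example, not values taken from vLLM.

import torch


def expert_weight_is_col_major(x: torch.Tensor) -> bool:
    # As added in fp8_utils.py: a (num_experts, m, n) scale tensor is
    # column-major per expert when elements of a column are contiguous
    # (stride 1 along dim 1) and successive columns sit m elements apart
    # (stride m along dim 2).
    assert x.dim() == 3
    b, m, n = x.shape
    return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m


# Illustrative shapes only: 8 experts with 64 x 32 block scales.
row_major = torch.ones(8, 64, 32)
col_major = row_major.transpose(1, 2).contiguous().transpose(1, 2)

print(expert_weight_is_col_major(row_major))  # False
print(expert_weight_is_col_major(col_major))  # True

The two layouts hold the same values and differ only in strides, which is why
the check inspects x.stride() rather than the data. process_weights_after_loading
uses the result to decide whether to pass the weight scales through
get_col_major_tma_aligned_tensor so they match the layout DeepGEMM expects.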
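
Similarly, the BLOCK-strategy scale shapes registered by create_weights in
PATCH 1/6 are plain ceil-divisions of the layer dimensions by the block
structure. A quick arithmetic sketch, using made-up layer sizes and a
hypothetical [128, 128] block_structure (the real values come from the
checkpoint's weight_quant config):

# Illustrative numbers only; block_n/block_k come from
# weight_quant.block_structure and the sizes are invented for the example.
num_experts = 8
hidden_size = 4096
intermediate_size_per_partition = 1536
block_n, block_k = 128, 128


def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b


# w13_weight_scale: one scale per (block_n x block_k) tile of the merged
# gate/up projection, hence the factor of 2 on the intermediate dimension.
w13_scale_shape = (num_experts,
                   2 * cdiv(intermediate_size_per_partition, block_n),
                   cdiv(hidden_size, block_k))

# w2_weight_scale: one scale per tile of the down projection.
w2_scale_shape = (num_experts,
                  cdiv(hidden_size, block_n),
                  cdiv(intermediate_size_per_partition, block_k))

print(w13_scale_shape)  # (8, 24, 32)
print(w2_scale_shape)   # (8, 32, 12)

The factor of 2 on the w13 rows mirrors the merged gate/up weight, and the
divisibility check earlier in create_weights guarantees that the intermediate
dimension divides evenly by block_n.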