vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -13,6 +13,7 @@
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
@@ -31,6 +32,9 @@
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
select_nvfp4_gemm_impl)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
expert_weight_is_col_major, get_col_major_tma_aligned_tensor,
requant_weight_ue8m0_inplace)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_moe_marlin_supports_layer, marlin_make_workspace_new,
marlin_moe_permute_scales)
@@ -45,6 +49,7 @@
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used

logger = init_logger(__name__)

@@ -505,10 +510,12 @@ def __init__(
self.weight_quant.strategy == QuantizationStrategy.CHANNEL
and self.input_quant.strategy == QuantizationStrategy.TOKEN)
if not (per_tensor or per_channel):
raise ValueError(
"For FP8 Fused MoE layers, we require per tensor "
"or channelwise, dynamic per token quantization. Found "
f"{self.weight_quant}, {self.input_quant}")
assert self.weight_quant.strategy == QuantizationStrategy.BLOCK
self.weight_block_size = self.weight_quant.block_structure
assert self.weight_quant.dynamic is not None
else:
self.weight_block_size = None
self.block_quant = self.weight_block_size is not None

self.static_input_scales = not self.input_quant.dynamic
if self.static_input_scales and per_channel:
@@ -519,7 +526,8 @@
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self.use_marlin = (not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
or envs.VLLM_TEST_FORCE_FP8_MARLIN
and not self.block_quant)
# Disable marlin for rocm
if current_platform.is_rocm():
self.use_marlin = False
@@ -531,8 +539,9 @@
# cutlass path
self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100(
self.weight_quant, self.input_quant)
self.use_cutlass = (quant_config._is_fp8_w8a8_sm90(
self.weight_quant, self.input_quant) or self.is_fp8_w8a8_sm100)
self.use_cutlass = not self.block_quant and (
quant_config._is_fp8_w8a8_sm90(self.weight_quant, self.input_quant)
or self.is_fp8_w8a8_sm100)
self.disable_expert_map = False

def create_weights(self, layer: torch.nn.Module, num_experts: int,
@@ -547,6 +556,31 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,

params_dtype = torch.float8_e4m3fn

if self.block_quant:
assert self.weight_block_size is not None
layer.weight_block_size = self.weight_block_size
tp_size = get_tensor_model_parallel_world_size()
block_n, block_k = (
self.weight_block_size[0],
self.weight_block_size[1],
)
# NOTE: To ensure proper alignment of the block-wise quantization
# scales, the output_size of the weights for both the gate and up
# layers must be divisible by block_n.
# Required by column parallel or enabling merged weights
if intermediate_size_per_partition % block_n != 0:
raise ValueError(
f"The output_size of gate's and up's weight = "
f"{intermediate_size_per_partition} is not divisible by "
f"weight quantization block_n = {block_n}.")
if (tp_size > 1
and intermediate_size_per_partition % block_k != 0):
# Required by row parallel
raise ValueError(
f"The input_size of down's weight = "
f"{intermediate_size_per_partition} is not divisible by "
f"weight quantization block_k = {block_k}.")

# WEIGHTS
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
@@ -602,6 +636,27 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)

elif self.weight_quant.strategy == QuantizationStrategy.BLOCK:
w13_weight_scale = torch.nn.Parameter(torch.ones(
num_experts,
2 *
((intermediate_size_per_partition + block_n - 1) // block_n),
(hidden_size + block_k - 1) // block_k,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(torch.ones(
num_experts, (hidden_size + block_n - 1) // block_n,
(intermediate_size_per_partition + block_k - 1) // block_k,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add PER-BLOCK quantization for FusedMoE.weight_loader.
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value})
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)

# INPUT_SCALES
if self.static_input_scales:
w13_input_scale = torch.nn.Parameter(torch.ones(
@@ -706,6 +761,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
del layer.w2_input_scale

if self.use_cutlass:
assert self.weight_quant.strategy != QuantizationStrategy.BLOCK
device = layer.w13_weight.device
# ab_strides1 and c_strides2 are the same
self.ab_strides1_c_strides2 = torch.full(
@@ -724,6 +780,29 @@
device=device,
dtype=torch.int64)

if is_deep_gemm_e8m0_used() and self.block_quant:
assert layer.weight_block_size is not None
# Re-quantise the expert weights so their scales are UE8M0.
block_sz = tuple(layer.weight_block_size)
requant_weight_ue8m0_inplace(
layer.w13_weight.data,
layer.w13_weight_scale.data,
block_sz,
)
requant_weight_ue8m0_inplace(
layer.w2_weight.data,
layer.w2_weight_scale.data,
block_sz,
)

# Ensure column-major TMA alignment expected by DeepGEMM.
if expert_weight_is_col_major(layer.w13_weight_scale):
layer.w13_weight_scale = get_col_major_tma_aligned_tensor(
layer.w13_weight_scale)
if expert_weight_is_col_major(layer.w2_weight_scale):
layer.w2_weight_scale = get_col_major_tma_aligned_tensor(
layer.w2_weight_scale)

def maybe_make_prepare_finalize(
self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
if self.use_marlin or self.rocm_aiter_moe_enabled:
@@ -777,9 +856,10 @@ def select_gemm_impl(
return experts

# triton path
from vllm.model_executor.layers.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
BatchedTritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)

assert not self.rocm_aiter_moe_enabled and not self.use_marlin

@@ -790,14 +870,16 @@
assert max_num_tokens_per_rank is not None

logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
return BatchedTritonExperts(
return BatchedTritonOrDeepGemmExperts(
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=self.moe_quant_config,
)
else:
logger.debug("TritonExperts(%s)", self.__class__.__name__)
return TritonExperts(self.moe_quant_config)
logger.debug("TritonOrDeepGemmExperts(%s)",
self.__class__.__name__)
return TritonOrDeepGemmExperts(self.moe_quant_config,
allow_deep_gemm=True)

def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
@@ -816,6 +898,7 @@ def get_fused_moe_quant_config(
a2_scale=layer.w2_input_scale,
per_act_token_quant=per_act_token,
per_out_ch_quant=per_channel_quant,
block_shape=layer.weight_block_size,
)

def apply(
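Note on the BLOCK branch above: each expert weight gets one float32 scale per (block_n x block_k) tile, so the scale shapes are just ceil-divisions of the weight dimensions, with the factor of 2 on the w13 scale because the gate and up projections are merged. A small standalone sketch of that arithmetic; the ceil_div helper and the concrete sizes are illustrative only, not part of the change:

```python
# Sketch of the BLOCK-strategy scale shapes created above.
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


num_experts = 8
hidden_size = 4096
intermediate_size_per_partition = 1536
block_n, block_k = 128, 128  # layer.weight_block_size == [block_n, block_k]

# w13_weight is [E, 2 * I, H]; one float32 scale per (block_n x block_k) tile,
# so the gate/up scale keeps the factor of 2 on its second dimension.
w13_scale_shape = (
    num_experts,
    2 * ceil_div(intermediate_size_per_partition, block_n),
    ceil_div(hidden_size, block_k),
)
# w2_weight is [E, H, I]; its scale tiles the dimensions the same way.
w2_scale_shape = (
    num_experts,
    ceil_div(hidden_size, block_n),
    ceil_div(intermediate_size_per_partition, block_k),
)
print(w13_scale_shape)  # (8, 24, 32)
print(w2_scale_shape)   # (8, 32, 12)
```

With these sizes, intermediate_size_per_partition is divisible by both block_n and block_k, which is exactly what the new validation in create_weights enforces.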
22 changes: 8 additions & 14 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -33,10 +33,10 @@
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
apply_fp8_block_linear, check_aiter_fp8_linear_support,
create_fp8_input_scale, create_fp8_scale_parameter,
create_fp8_weight_parameter, get_col_major_tma_aligned_tensor,
maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy,
process_fp8_weight_tensor_strategy, requant_weight_ue8m0_inplace,
validate_fp8_block_shape)
create_fp8_weight_parameter, expert_weight_is_col_major,
get_col_major_tma_aligned_tensor, maybe_post_process_fp8_weight_block,
process_fp8_weight_block_strategy, process_fp8_weight_tensor_strategy,
requant_weight_ue8m0_inplace, validate_fp8_block_shape)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
prepare_moe_fp8_layer_for_marlin)
@@ -64,12 +64,6 @@
logger = init_logger(__name__)


def _is_col_major(x: torch.Tensor) -> bool:
assert x.dim() == 3
b, m, n = x.shape
return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m


class Fp8Config(QuantizationConfig):
"""Config class for FP8."""

@@ -660,10 +654,10 @@ def process_weights_after_loading(self, layer: Module) -> None:
# DeepGemm scales need to be transposed and aligned. We try to do
# it ahead of time for performance reasons.
if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
if _is_col_major(layer.w13_weight_scale_inv):
if expert_weight_is_col_major(layer.w13_weight_scale_inv):
layer.w13_weight_scale_inv = \
get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv)
if _is_col_major(layer.w2_weight_scale_inv):
if expert_weight_is_col_major(layer.w2_weight_scale_inv):
layer.w2_weight_scale_inv = \
get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv)

@@ -811,10 +805,10 @@ def process_weights_after_loading(self, layer: Module) -> None:
)

# Ensure column-major TMA alignment expected by DeepGEMM.
if _is_col_major(layer.w13_weight_scale_inv):
if expert_weight_is_col_major(layer.w13_weight_scale_inv):
layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
layer.w13_weight_scale_inv)
if _is_col_major(layer.w2_weight_scale_inv):
if expert_weight_is_col_major(layer.w2_weight_scale_inv):
layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
layer.w2_weight_scale_inv)

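For the DeepGEMM E8M0 path referenced above, requant_weight_ue8m0_inplace rewrites each block scale as a UE8M0-representable value, i.e. a power of two, and rescales the quantized weights so dequantization stays consistent. The toy below only sketches that idea; the function name, the float32 stand-in for the FP8 weight, and the round-up choice are assumptions, not the vLLM implementation:

```python
import torch


def requant_to_pow2_scales_inplace(w: torch.Tensor, scale: torch.Tensor,
                                   block: tuple[int, int]) -> None:
    """Toy UE8M0-style requantization for a single 2-D weight.

    `w` is a float32 stand-in for the quantized weight; `scale` holds one
    float32 value per (block_n x block_k) tile. Every scale is rounded up to
    a power of two and the weight tile is rescaled so that per-block
    dequantization (tile * scale) is approximately unchanged.
    """
    block_n, block_k = block
    new_scale = torch.exp2(torch.ceil(torch.log2(scale)))
    for i in range(scale.shape[0]):
        for j in range(scale.shape[1]):
            rows = slice(i * block_n, (i + 1) * block_n)
            cols = slice(j * block_k, (j + 1) * block_k)
            w[rows, cols] *= scale[i, j] / new_scale[i, j]
    scale.copy_(new_scale)


w = torch.randn(256, 256)
scale = torch.rand(2, 2) + 0.5
dequant_before = w.clone() * scale.repeat_interleave(128, 0).repeat_interleave(128, 1)
requant_to_pow2_scales_inplace(w, scale, (128, 128))
dequant_after = w * scale.repeat_interleave(128, 0).repeat_interleave(128, 1)
assert torch.allclose(dequant_before, dequant_after)
assert torch.all(torch.log2(scale) == torch.log2(scale).round())  # power-of-two scales
```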
6 changes: 6 additions & 0 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -1014,3 +1014,9 @@ def apply_fp8_block_linear(layer: torch.nn.Module, input: torch.Tensor,
cutlass_block_fp8_supported=cutlass_block_fp8_supported,
use_aiter_and_is_supported=use_aiter_and_is_supported,
)


def expert_weight_is_col_major(x: torch.Tensor) -> bool:
assert x.dim() == 3
b, m, n = x.shape
return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m
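Since expert_weight_is_col_major only inspects strides, a quick way to see what it accepts is to compare a contiguous per-expert scale tensor with one whose last two dimensions were materialized column-major (illustrative usage; the sizes are arbitrary):

```python
import torch

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    expert_weight_is_col_major)

e, m, n = 8, 24, 32
row_major = torch.ones(e, m, n)
# Shape stays (e, m, n), but the last two dims end up laid out column-major:
# strides become (m * n, 1, m) instead of (m * n, n, 1).
col_major = row_major.transpose(1, 2).contiguous().transpose(1, 2)

print(row_major.stride(), expert_weight_is_col_major(row_major))  # (768, 32, 1) False
print(col_major.stride(), expert_weight_is_col_major(col_major))  # (768, 1, 24) True
```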
4 changes: 2 additions & 2 deletions vllm/model_executor/warmup/deep_gemm_warmup.py
@@ -53,9 +53,9 @@ def _extract_data_from_fused_moe_module(
"""
assert isinstance(m, FusedMoE)
w13 = m.w13_weight
w13_s = m.w13_weight_scale_inv
w13_s = getattr(m, "w13_weight_scale_inv", m.w13_weight_scale)
w2 = m.w2_weight
w2_s = m.w2_weight_scale_inv
w2_s = getattr(m, "w2_weight_scale_inv", m.w2_weight_scale)
num_topk = m.top_k

assert isinstance(w13, torch.Tensor)
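This fallback is needed because the two block-quant back-ends register the scales under different names: fp8.py uses w13_weight_scale_inv / w2_weight_scale_inv, while the compressed-tensors MoE method above registers w13_weight_scale / w2_weight_scale. A minimal sketch of the getattr pattern with invented stand-in classes:

```python
import torch


class _Fp8StyleLayer:
    """Stands in for an fp8.py block-quant FusedMoE layer (scale_inv naming)."""
    def __init__(self) -> None:
        self.w13_weight_scale_inv = torch.ones(8, 24, 32)
        self.w13_weight_scale = None  # present but unused for this back-end


class _CompressedTensorsStyleLayer:
    """Stands in for the compressed-tensors block-quant layer added in this PR."""
    def __init__(self) -> None:
        self.w13_weight_scale = torch.ones(8, 24, 32)


def pick_w13_scale(m) -> torch.Tensor:
    # Same pattern as _extract_data_from_fused_moe_module: prefer the
    # *_scale_inv attribute, otherwise fall back to *_scale. Note that the
    # getattr default is evaluated eagerly, so m.w13_weight_scale must exist.
    return getattr(m, "w13_weight_scale_inv", m.w13_weight_scale)


assert pick_w13_scale(_Fp8StyleLayer()).shape == (8, 24, 32)
assert pick_w13_scale(_CompressedTensorsStyleLayer()).shape == (8, 24, 32)
```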