diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py
index 2dd625054339..86c50f39f007 100644
--- a/vllm/model_executor/layers/fused_moe/all2all_utils.py
+++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py
@@ -67,6 +67,7 @@ def maybe_roundup_layer_hidden_size(
 def maybe_make_prepare_finalize(
     moe: FusedMoEConfig,
     quant_config: FusedMoEQuantConfig | None,
+    routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
 ) -> FusedMoEPrepareAndFinalize | None:
     if not moe.moe_parallel_config.use_all2all_kernels:
         return None
@@ -134,6 +135,13 @@ def maybe_make_prepare_finalize(
 
     elif moe.use_deepep_ll_kernels:
         assert quant_config is not None
+        global_to_physical = physical_to_global = local_expert_global_ids = None
+        if routing_tables is not None:
+            (
+                global_to_physical,
+                physical_to_global,
+                local_expert_global_ids,
+            ) = routing_tables
         all_to_all_args = dict(
             max_num_tokens_per_dp_rank=moe.max_num_tokens,
             token_hidden_size=moe.hidden_dim,
@@ -155,6 +163,9 @@ def maybe_make_prepare_finalize(
             max_tokens_per_rank=moe.max_num_tokens,
             num_dispatchers=all2all_manager.world_size,
             use_fp8_dispatch=use_fp8_dispatch,
+            global_to_physical=global_to_physical,
+            physical_to_global=physical_to_global,
+            local_expert_global_ids=local_expert_global_ids,
         )
 
     return prepare_finalize
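Note: the `routing_tables` argument threaded through `maybe_make_prepare_finalize` above is a triple of index tensors. A minimal sketch (toy sizes, not vLLM code) of what each entry holds and of the unpack-or-None idiom used in the hunk:

```python
# Toy illustration of the routing_tables triple. Values assume 5 experts
# striped round-robin over 2 ranks (rank 0 owns globals 0, 2, 4).
import torch

routing_tables = (
    torch.tensor([0, 3, 1, 4, 2]),  # global_to_physical: global id -> physical slot
    torch.tensor([0, 2, 4, 1, 3]),  # physical_to_global: physical slot -> global id
    torch.tensor([0, 2, 4]),        # local_expert_global_ids for ep_rank=0, ep_size=2
)

# The same unpack-or-None idiom as in the diff:
global_to_physical = physical_to_global = local_expert_global_ids = None
if routing_tables is not None:
    (
        global_to_physical,
        physical_to_global,
        local_expert_global_ids,
    ) = routing_tables
```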
diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
index 06c9df317f7c..e0db248958b4 100644
--- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
@@ -85,6 +85,9 @@ def __init__(
         max_tokens_per_rank: int,
         num_dispatchers: int,
         use_fp8_dispatch: bool = False,
+        global_to_physical: torch.Tensor | None = None,
+        physical_to_global: torch.Tensor | None = None,
+        local_expert_global_ids: torch.Tensor | None = None,
     ):
         super().__init__()
 
@@ -97,6 +100,17 @@ def __init__(
         self.handles: list[tuple | None] = [None, None]
         self.num_dispatchers_ = num_dispatchers
 
+        topk_indices_dtype = self.topk_indices_dtype()
+
+        def _maybe_cast(tensor: torch.Tensor | None) -> torch.Tensor | None:
+            if tensor is None or topk_indices_dtype is None:
+                return tensor
+            return tensor.to(dtype=topk_indices_dtype)
+
+        self.global_to_physical = _maybe_cast(global_to_physical)
+        self.physical_to_global = _maybe_cast(physical_to_global)
+        self.local_expert_global_ids = _maybe_cast(local_expert_global_ids)
+
         # We don't have enough information to determine if we should dispatch
         # activation scales in a packed ue8m0 format during object construction
         # time. This setting is handled by post_init_setup.
@@ -136,6 +150,16 @@ def max_num_tokens_per_rank(self) -> int | None:
 
     def topk_indices_dtype(self) -> torch.dtype | None:
         return torch.int64
 
+    def _map_global_to_physical_ids(self, topk_ids: torch.Tensor) -> torch.Tensor:
+        if self.global_to_physical is None:
+            return topk_ids
+        return self.global_to_physical[topk_ids]
+
+    def _map_local_to_global_ids(self, expert_topk_ids: torch.Tensor) -> torch.Tensor:
+        if self.local_expert_global_ids is None:
+            return expert_topk_ids
+        return self.local_expert_global_ids[expert_topk_ids]
+
     def _do_quant(
         self,
         x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
@@ -226,9 +250,10 @@ def prepare_async(
             a1 = a1 * topk_weights.to(a1.dtype)
 
         # Dispatch
+        dispatch_topk_ids = self._map_global_to_physical_ids(topk_ids)
         expert_x, expert_num_tokens, handle, _, hook = self.buffer.low_latency_dispatch(
             a1,
-            topk_ids,
+            dispatch_topk_ids,
             self.max_tokens_per_rank,
             num_experts,
             use_fp8=self.use_fp8_dispatch,
@@ -313,11 +338,12 @@ def _finalize(
             # weights have already been applied.
             combine_topk_weights = torch.ones_like(topk_weights)
 
+        combine_topk_ids = self._map_global_to_physical_ids(topk_ids)
         # TODO (varun) : Enable zero copy mode
         dbo_maybe_run_recv_hook()
         _, _, recv_hook = self.buffer.low_latency_combine(
             fused_expert_output,
-            topk_ids,
+            combine_topk_ids,
             combine_topk_weights,
             handle,
             async_finish=False,
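The `_map_global_to_physical_ids` / `_map_local_to_global_ids` helpers added above are plain advanced-indexing gathers. A standalone check with assumed toy tables (the real ones come from the layer) shows that `table[topk_ids]` translates ids element-wise while preserving the `[num_tokens, topk]` shape:

```python
# Standalone check of the gather used by the two _map_* helpers.
# Toy tables: 5 experts round-robin over 2 ranks, viewed from rank 0.
import torch

global_to_physical = torch.tensor([0, 3, 1, 4, 2], dtype=torch.int64)
local_expert_global_ids = torch.tensor([0, 2, 4], dtype=torch.int64)

topk_ids = torch.tensor([[1, 4], [0, 2]], dtype=torch.int64)  # global expert ids
dispatch_ids = global_to_physical[topk_ids]  # ids handed to low_latency_dispatch
assert dispatch_ids.tolist() == [[3, 2], [0, 1]]

# On the receiving rank, local expert indices map back to global ids with
# the same gather, via local_expert_global_ids:
assert local_expert_global_ids[torch.tensor([0, 2])].tolist() == [0, 4]
```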
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
index 87f8c8d75a9b..073e90a4e680 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
@@ -50,10 +50,15 @@ def uses_weight_scale_2_pattern(self) -> bool:
         """
         return False
 
-    def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
+    def maybe_make_prepare_finalize(
+        self,
+        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
+    ) -> FusedMoEPrepareAndFinalize | None:
         from .all2all_utils import maybe_make_prepare_finalize
 
-        return maybe_make_prepare_finalize(self.moe, self.moe_quant_config)
+        return maybe_make_prepare_finalize(
+            self.moe, self.moe_quant_config, routing_tables
+        )
 
     def select_gemm_impl(
         self,
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 023132acfed3..c41995e4a913 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -5,7 +5,7 @@
 from contextlib import nullcontext
 from enum import Enum
 from functools import partial
-from typing import Literal, get_args, overload
+from typing import Literal, cast, get_args, overload
 
 import torch
 import torch.nn.functional as F
@@ -192,6 +192,42 @@ def determine_expert_map(
     return (local_num_experts, expert_map, expert_mask)
 
 
+def determine_expert_placement_strategy(
+    expert_placement_strategy: ExpertPlacementStrategy,
+    moe_parallel_config: FusedMoEParallelConfig,
+    num_expert_group: int | None,
+    num_redundant_experts: int,
+    enable_eplb: bool,
+) -> ExpertPlacementStrategy:
+    if expert_placement_strategy == "round_robin":
+        round_robin_supported = (
+            (num_expert_group is not None and num_expert_group > 1)
+            and num_redundant_experts == 0
+            and not enable_eplb
+        )
+
+        if not round_robin_supported:
+            logger.warning(
+                "Round-robin expert placement requires multiple expert "
+                "groups, no redundant experts, and EPLB disabled. "
+                "Falling back to linear expert placement."
+            )
+            return "linear"
+        if (
+            moe_parallel_config.use_all2all_kernels
+            and not moe_parallel_config.use_deepep_ll_kernels
+        ):
+            logger.warning(
+                "Round-robin expert placement currently only supports "
+                "the DeepEP low-latency backend, but '%s' was configured. "
+                "Falling back to linear expert placement.",
+                moe_parallel_config.all2all_backend,
+            )
+            return "linear"
+
+    return expert_placement_strategy
+
+
 def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
     """
     Compresses the expert map by removing any -1 entries.
@@ -400,6 +436,9 @@ def __init__(
         self.expert_load_view: torch.Tensor | None = None
         self.logical_to_physical_map: torch.Tensor | None = None
         self.logical_replica_count: torch.Tensor | None = None
+        self.expert_placement_strategy: ExpertPlacementStrategy = (
+            vllm_config.parallel_config.expert_placement_strategy
+        )
 
         # ROCm aiter shared experts fusion
         self.rocm_aiter_fmoe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
@@ -433,38 +472,27 @@ def __init__(
                 "Redundant experts are only supported with EPLB."
             )
 
-        expert_placement_strategy = (
-            vllm_config.parallel_config.expert_placement_strategy
+        self.expert_placement_strategy = determine_expert_placement_strategy(
+            expert_placement_strategy=self.expert_placement_strategy,
+            moe_parallel_config=self.moe_parallel_config,
+            num_expert_group=num_expert_group,
+            num_redundant_experts=num_redundant_experts,
+            enable_eplb=self.enable_eplb,
         )
-        if expert_placement_strategy == "round_robin":
-            # TODO(Bruce): will support round robin expert placement with
-            # EPLB enabled in the future.
-            round_robin_supported = (
-                (num_expert_group is not None and num_expert_group > 1)
-                and num_redundant_experts == 0
-                and not self.enable_eplb
-            )
-
-            if not round_robin_supported:
-                logger.warning(
-                    "Round-robin expert placement is only supported for "
-                    "models with multiple expert groups and no redundant "
-                    "experts. Falling back to linear expert placement."
-                )
-                expert_placement_strategy = "linear"
 
         self.expert_map: torch.Tensor | None
         local_num_experts, expert_map, expert_mask = determine_expert_map(
             ep_size=self.ep_size,
             ep_rank=self.ep_rank,
             global_num_experts=self.global_num_experts,
-            expert_placement_strategy=expert_placement_strategy,
+            expert_placement_strategy=self.expert_placement_strategy,
             num_fused_shared_experts=self.num_fused_shared_experts,
             return_expert_mask=self.rocm_aiter_fmoe_enabled,
         )
         self.local_num_experts = local_num_experts
         self.register_buffer("expert_map", expert_map)
         self.register_buffer("expert_mask", expert_mask)
+        self._maybe_init_expert_routing_tables()
         logger.info_once(
             "[EP Rank %s/%s] Expert parallelism is enabled. Expert "
             "placement strategy: %s. Local/global"
@@ -472,7 +500,7 @@ def __init__(
             " %s.",
             self.ep_rank,
             self.ep_size,
-            expert_placement_strategy,
+            self.expert_placement_strategy,
             self.local_num_experts,
             self.global_num_experts,
             get_compressed_expert_map(self.expert_map),
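For context on the fallback logic in `determine_expert_placement_strategy`, here is a hedged sketch (illustrative sizes; the real mapping lives in `determine_expert_map`) contrasting how the two strategies lay out 8 experts across 2 ranks:

```python
# Toy contrast of the two placement strategies; numbers are illustrative only.
import torch

global_num_experts, ep_size = 8, 2
per_rank = global_num_experts // ep_size

# "linear": each rank owns one contiguous block of experts.
linear = [torch.arange(r * per_rank, (r + 1) * per_rank) for r in range(ep_size)]
assert [t.tolist() for t in linear] == [[0, 1, 2, 3], [4, 5, 6, 7]]

# "round_robin": experts are striped, so each rank samples every expert
# group instead of holding whole groups.
round_robin = [torch.arange(r, global_num_experts, ep_size) for r in range(ep_size)]
assert [t.tolist() for t in round_robin] == [[0, 2, 4, 6], [1, 3, 5, 7]]
```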
@@ -621,7 +649,12 @@ def _get_quant_method() -> FusedMoEMethodBase:
     # should be safe to swap out the quant_method.
     def maybe_init_modular_kernel(self) -> None:
         self.ensure_moe_quant_config_init()
-        prepare_finalize = self.quant_method.maybe_make_prepare_finalize()
+        # routing_tables is only needed for round-robin expert placement
+        # with the DeepEP low-latency all2all backend.
+        routing_tables = self._maybe_init_expert_routing_tables()
+        prepare_finalize = self.quant_method.maybe_make_prepare_finalize(
+            routing_tables=routing_tables
+        )
         if prepare_finalize is not None:
             logger.debug(
                 "%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self)
@@ -703,6 +736,84 @@ def is_internal_router(self) -> bool:
         # By default, router/gate is called before FusedMoE forward pass
         return False
 
+    def _maybe_init_expert_routing_tables(
+        self,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None:
+        # Routing tables are currently only needed for round-robin expert
+        # placement with the DeepEP low-latency all2all backend.
+        if (
+            self.expert_placement_strategy != "round_robin"
+            or not self.use_deepep_ll_kernels
+        ):
+            return None
+
+        if hasattr(self, "expert_global_to_physical"):
+            return cast(
+                tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+                (
+                    self.expert_global_to_physical,
+                    self.expert_physical_to_global,
+                    self.expert_local_to_global,
+                ),
+            )
+
+        if self.expert_map is None:
+            return None
+
+        routing_tables = self.ensure_round_robin_expert_routing_tables(
+            global_num_experts=self.global_num_experts,
+            ep_size=self.ep_size,
+            ep_rank=self.ep_rank,
+            local_num_experts=self.local_num_experts,
+            device=self.expert_map.device,
+        )
+
+        global_to_physical, physical_to_global, local_global = routing_tables
+        self.register_buffer("expert_global_to_physical", global_to_physical)
+        self.register_buffer("expert_physical_to_global", physical_to_global)
+        self.register_buffer("expert_local_to_global", local_global)
+
+        return routing_tables
+
+    @staticmethod
+    def ensure_round_robin_expert_routing_tables(
+        global_num_experts: int,
+        ep_size: int,
+        ep_rank: int,
+        local_num_experts: int,
+        device: torch.device | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        device_kwargs = {"device": device} if device is not None else {}
+        global_indices = torch.arange(
+            global_num_experts, dtype=torch.long, **device_kwargs
+        )
+        owner = torch.remainder(global_indices, ep_size)
+        local_index = torch.div(global_indices, ep_size, rounding_mode="floor")
+        base = global_num_experts // ep_size
+        remainder = global_num_experts % ep_size
+        physical_offset = owner * base
+        if remainder > 0:
+            remainder_tensor = torch.tensor(
+                remainder, dtype=torch.long, **device_kwargs
+            )
+            physical_offset = physical_offset + torch.minimum(owner, remainder_tensor)
+
+        global_to_physical = physical_offset + local_index
+        physical_to_global = torch.empty_like(global_to_physical)
+        physical_to_global[global_to_physical] = global_indices
+
+        local_global = torch.arange(
+            ep_rank,
+            global_num_experts,
+            ep_size,
+            dtype=torch.long,
+            **device_kwargs,
+        )
+        if local_global.numel() != local_num_experts:
+            local_global = local_global[:local_num_experts]
+
+        return (global_to_physical, physical_to_global, local_global)
+
     def update_expert_map(self):
         # ep_size and ep_rank should already be updated
         assert self.expert_map is not None
@@ -711,12 +822,14 @@ def update_expert_map(self):
             ep_size=self.ep_size,
             ep_rank=self.ep_rank,
             global_num_experts=self.global_num_experts,
+            expert_placement_strategy=self.expert_placement_strategy,
             num_fused_shared_experts=self.num_fused_shared_experts,
             return_expert_mask=self.rocm_aiter_fmoe_enabled,
        )
         self.local_num_experts = local_num_experts
         self.register_buffer("expert_map", expert_map)
         self.register_buffer("expert_mask", expert_mask)
+        self._maybe_init_expert_routing_tables()
 
         if self.aiter_fmoe_shared_expert_enabled:
             self._init_aiter_shared_experts_topK_buffer(
                 vllm_config=get_current_vllm_config(),
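The arithmetic in `ensure_round_robin_expert_routing_tables` is easiest to verify in isolation. A standalone sketch with assumed toy sizes, mirroring the code above for 5 experts on 2 ranks (the uneven case exercises the remainder handling):

```python
# Mirrors the table construction above for global_num_experts=5, ep_size=2:
# rank 0 owns globals [0, 2, 4], rank 1 owns [1, 3], and physical slots are
# laid out rank-major, i.e. as [0, 2, 4, 1, 3].
import torch

global_num_experts, ep_size = 5, 2
g = torch.arange(global_num_experts)
owner = g % ep_size                 # owning rank of each global expert
local_index = g // ep_size          # position within that rank
base, remainder = divmod(global_num_experts, ep_size)
physical_offset = owner * base + torch.minimum(owner, torch.tensor(remainder))

global_to_physical = physical_offset + local_index
assert global_to_physical.tolist() == [0, 3, 1, 4, 2]

physical_to_global = torch.empty_like(global_to_physical)
physical_to_global[global_to_physical] = g
assert physical_to_global.tolist() == [0, 2, 4, 1, 3]

# Round trip: the two tables are mutually inverse permutations.
assert torch.equal(physical_to_global[global_to_physical], g)
```

The round-trip assert is the invariant the dispatch/combine paths rely on: ids remapped to physical slots for `low_latency_dispatch` can always be mapped back unambiguously.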
diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
index 2e0376553b91..63b0e6f573d6 100644
--- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
+++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
@@ -108,11 +108,14 @@ def supports_eplb(self) -> bool:
     def allow_inplace(self) -> bool:
         return True
 
-    def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
+    def maybe_make_prepare_finalize(
+        self,
+        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
+    ) -> FusedMoEPrepareAndFinalize | None:
         if self.rocm_aiter_moe_enabled:
             return None
         else:
-            return super().maybe_make_prepare_finalize()
+            return super().maybe_make_prepare_finalize(routing_tables)
 
     def select_gemm_impl(
         self,
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 06ee96d55419..22b3c477f420 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -380,11 +380,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             (layer.w2_input_global_scale), requires_grad=False
         )
 
-    def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
+    def maybe_make_prepare_finalize(
+        self,
+        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
+    ) -> mk.FusedMoEPrepareAndFinalize | None:
         if self.use_marlin:
             return None
         elif not self.allow_flashinfer:
-            return super().maybe_make_prepare_finalize()
+            return super().maybe_make_prepare_finalize(routing_tables)
 
         prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(self.moe)
         logger.debug_once("%s", prepare_finalize.__class__.__name__)
@@ -890,11 +893,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             layer.w2_weight_scale
         )
 
-    def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
+    def maybe_make_prepare_finalize(
+        self,
+        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
+    ) -> mk.FusedMoEPrepareAndFinalize | None:
         if self.use_marlin or self.rocm_aiter_moe_enabled:
             return None
         else:
-            return super().maybe_make_prepare_finalize()
+            return super().maybe_make_prepare_finalize(routing_tables)
 
     def select_gemm_impl(
         self,
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 0479bec33840..92fbdd709348 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -1018,7 +1018,10 @@ def process_weights_after_loading(self, layer: Module) -> None:
             del layer.w13_input_scale
             del layer.w2_input_scale
 
-    def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
+    def maybe_make_prepare_finalize(
+        self,
+        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
+    ) -> mk.FusedMoEPrepareAndFinalize | None:
         if (
             self.rocm_aiter_moe_enabled
             or self.use_marlin
@@ -1039,7 +1042,7 @@ def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
             logger.debug_once("%s", prepare_finalize.__class__.__name__)
             return prepare_finalize
         else:
-            return super().maybe_make_prepare_finalize()
+            return super().maybe_make_prepare_finalize(routing_tables)
 
     def select_gemm_impl(
         self,
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 476521813f46..38ab7cd4f115 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -373,6 +373,7 @@ def __init__(
 
     def maybe_make_prepare_finalize(
         self,
+        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
     ) -> mk.FusedMoEPrepareAndFinalize | None:
         # TRT LLM not supported with all2all yet.
         if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
@@ -384,7 +385,7 @@ def maybe_make_prepare_finalize(
             logger.debug_once("%s", prepare_finalize.__class__.__name__)
             return prepare_finalize
         else:
-            return super().maybe_make_prepare_finalize()
+            return super().maybe_make_prepare_finalize(routing_tables)
 
     def select_gemm_impl(
         self,
@@ -1179,7 +1180,10 @@ def __init__(
             " for ModelOptNvFp4FusedMoE."
         )
 
-    def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
+    def maybe_make_prepare_finalize(
+        self,
+        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
+    ) -> mk.FusedMoEPrepareAndFinalize | None:
         if self.use_marlin or (
             self.allow_flashinfer
             and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
@@ -1196,7 +1200,7 @@ def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
             logger.debug_once("%s", prepare_finalize.__class__.__name__)
             return prepare_finalize
         else:
-            return super().maybe_make_prepare_finalize()
+            return super().maybe_make_prepare_finalize(routing_tables)
 
     def select_gemm_impl(
         self,
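The quantization hunks above all repeat one mechanical change: each `maybe_make_prepare_finalize` override gains the optional `routing_tables` parameter and forwards it whenever it defers to the base class, while backends that bypass all2all (Marlin, ROCm AITER, TRT-LLM) return `None` and never consume it. A minimal sketch of that pattern with hypothetical stand-in classes (not vLLM's):

```python
# Illustrative reduction of the signature threading above; BaseMethod and
# MarlinLikeMethod are hypothetical stand-ins for FusedMoEMethodBase and
# one of its quantized subclasses.
import torch

RoutingTables = tuple[torch.Tensor, torch.Tensor, torch.Tensor]


class BaseMethod:
    def maybe_make_prepare_finalize(
        self, routing_tables: RoutingTables | None = None
    ) -> object | None:
        # In vLLM this is where the tables reach the DeepEP low-latency
        # prepare/finalize object; a placeholder suffices here.
        return object()


class MarlinLikeMethod(BaseMethod):
    use_marlin = True

    def maybe_make_prepare_finalize(
        self, routing_tables: RoutingTables | None = None
    ) -> object | None:
        if self.use_marlin:
            return None  # backend without all2all: tables are irrelevant
        return super().maybe_make_prepare_finalize(routing_tables)
```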