diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 6c2a5bda7cba..0e334fdf2404 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1017,6 +1017,79 @@ def grouped_topk(
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
 
 
+@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
+def eplb_map_to_physical_and_record(
+        topk_ids: torch.Tensor,
+        expert_load_view: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+        indices_type: Optional[torch.dtype] = None) -> torch.Tensor:
+    '''
+    Map the logical expert ids to physical expert ids
+    and record the expert load metrics.
+
+    This will select a pseudo-random replica for each logical expert.
+    Only used for EPLB.
+
+    Args:
+        topk_ids: The logical expert ids.
+        expert_load_view: The expert load view.
+        logical_to_physical_map: The logical to physical map.
+        logical_replica_count: The logical replica count.
+        indices_type: The indices type.
+
+    Returns:
+        The physical expert ids.
+    '''
+
+    # 1. Convert the logical expert ids to physical expert ids
+    # Directly select a random replica for each logical expert
+
+    # In case `indices_type` is not `torch.long` or `torch.int`,
+    # e.g. `torch.uint32` as required by dispatch/combine kernels
+    topk_ids_long = topk_ids.long()
+    # Use (token position) modulo (replica count)
+    # to deterministically choose a replica
+    replica_count = logical_replica_count[topk_ids_long]
+    # Flatten-position based index, reshaped back to `topk_ids` shape
+    pos_indices = torch.arange(topk_ids.numel(),
+                               device=topk_ids.device,
+                               dtype=torch.long).reshape_as(topk_ids)
+    # Compute pseudo-random indices by modulo
+    replica_indices = (pos_indices % replica_count).unsqueeze(-1)
+    physical_ids = logical_to_physical_map[topk_ids_long].gather(
+        -1, replica_indices).squeeze(-1)
+
+    topk_ids = physical_ids
+
+    # 2. Record expert load metrics.
+
+    # TODO(bowen): When using `FusedMoEModularKernel`, this
+    # can be done in a more unified way, since
+    # `FusedMoEPrepareAndFinalize` will return the expert
+    # token count, in some cases directly from the kernel.
+    # However, now there are many code paths not using
+    # the modular kernel, e.g. calling `fused_experts`,
+    # so we decide to keep the logic here.
+    #
+    # If later refactor moved all the MoE kernel calls
+    # to the modular kernel, we can move this logic there
+    # to achieve better efficiency.
+
+    # `expert_load_view`: (num_physical_experts,)
+
+    # `torch.bincount` is not compilable, so use `scatter_add_` instead.
+    topk_ids_flatten = topk_ids.flatten()
+    expert_load_view.scatter_add_(
+        dim=0,
+        index=topk_ids_flatten.long(),
+        src=torch.ones_like(topk_ids_flatten).to(expert_load_view))
+
+    if indices_type is not None:
+        topk_ids = topk_ids.to(dtype=indices_type)
+    return topk_ids
+
+
 def fused_grouped_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index da513d75da4d..17ad75584a3f 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -43,7 +43,8 @@
 
 if current_platform.is_cuda_alike():
     from .fused_batched_moe import BatchedTritonExperts
-    from .fused_moe import TritonExperts, fused_experts
+    from .fused_moe import (TritonExperts, eplb_map_to_physical_and_record,
+                            fused_experts)
     if has_pplx():
         from .pplx_prepare_finalize import (PplxPrepareAndFinalize,
                                             pplx_hidden_dim_scale_bytes)
@@ -55,6 +56,16 @@
     fused_experts = None  # type: ignore
     FusedMoEPermuteExpertsUnpermute = None  # type: ignore
     FusedMoEPrepareAndFinalize = None  # type: ignore
+
+    def eplb_map_to_physical_and_record(
+            topk_ids: torch.Tensor, expert_load_view: torch.Tensor,
+            logical_to_physical_map: torch.Tensor,
+            logical_replica_count: torch.Tensor,
+            indices_type: Optional[torch.dtype]) -> torch.Tensor:
+        # CPU fallback: no EPLB so just return as is
+        return topk_ids
+
+
 if is_rocm_aiter_moe_enabled():
     from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
         rocm_aiter_grouped_topk as grouped_topk)
@@ -1616,55 +1627,13 @@ def select_experts(
             assert logical_to_physical_map is not None
             assert logical_replica_count is not None
 
-            # 1. Convert the logical expert ids to physical expert ids
-            # Directly select a random replica for each logical expert
-
-            # TODO: maybe optimize this by using specified kernels,
-            # or compute pseudo-random indices by modulo
-
-            # In case `indices_type` is not `torch.long` or `torch.int`,
-            # e.g. `torch.uint32` as required by dispatch/combine kernels
-            topk_ids_long = topk_ids.long()
-            replica_indices = (
-                torch.rand_like(topk_ids, dtype=torch.float) *
-                logical_replica_count[topk_ids_long]).long().unsqueeze(-1)
-            physical_ids = logical_to_physical_map[topk_ids_long].gather(
-                -1, replica_indices).squeeze(-1)
-
-            topk_ids = physical_ids
-
-            # 2. Record expert load metrics.
-
-            # TODO(bowen): When using `FusedMoEModularKernel`, this
-            # can be done in a more unified way, since
-            # `FusedMoEPrepareAndFinalize` will return the expert
-            # token count, in some cases directly from the kernel.
-            # However, now there are many code paths not using
-            # the modular kernel, e.g. calling `fused_experts`,
-            # so we decide to keep the logic here.
-            #
-            # If later refactor moved all the MoE kernel calls
-            # to the modular kernel, we can move this logic there
-            # to achieve better efficiency.
-
-            # `expert_load_view`: (num_physical_experts,)
-
-            topk_ids_flatten = topk_ids.flatten()
-
-            # Performance optimization:
-            # `masked_fill` is significantly faster than `masked_select`
-            invalid_mask = topk_ids_flatten < 0
-            # Replace invalid expert ids with 0 (just a dummy position)
-            # to avoid out-of-bounds errors in scatter_add_
-            index = topk_ids_flatten.masked_fill_(invalid_mask, 0)
-            # `src` is the valid mask, which is 1 for valid and 0 for invalid
-            src = ~invalid_mask
-
-            expert_load_view.scatter_add_(dim=0,
-                                          index=index.long(),
-                                          src=src.to(expert_load_view))
-
-            topk_ids = topk_ids.to(dtype=indices_type)
+            topk_ids = eplb_map_to_physical_and_record(
+                topk_ids=topk_ids,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
+                indices_type=indices_type,
+            )
 
         assert topk_ids.dtype == indices_type or indices_type is None
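For readers unfamiliar with the EPLB bookkeeping, below is a minimal, self-contained sketch of the two steps the new `eplb_map_to_physical_and_record` helper performs: deterministic replica selection via (flattened token position) modulo (replica count), then load recording via `scatter_add_`. The toy tensor sizes and the `-1` padding of the mapping table are illustrative assumptions, not values taken from vLLM.

```python
import torch

# 2 tokens x top-2 routing: logical expert ids chosen by the router
topk_ids = torch.tensor([[0, 2], [1, 2]])

# Logical -> physical mapping for 4 logical experts spread over 6 physical
# slots; rows are padded with -1 where a logical expert has fewer replicas
logical_to_physical_map = torch.tensor([
    [0, 3],   # logical 0 has replicas at physical slots 0 and 3
    [1, -1],  # logical 1 has a single replica at slot 1
    [2, 4],   # logical 2 has replicas at slots 2 and 4
    [5, -1],  # logical 3 has a single replica at slot 5
])
logical_replica_count = torch.tensor([2, 1, 2, 1])
expert_load_view = torch.zeros(6, dtype=torch.int64)

# 1. Deterministic pseudo-random replica choice: flattened token position
#    modulo the replica count of the chosen logical expert
topk_ids_long = topk_ids.long()
replica_count = logical_replica_count[topk_ids_long]
pos_indices = torch.arange(topk_ids.numel()).reshape_as(topk_ids)
replica_indices = (pos_indices % replica_count).unsqueeze(-1)
physical_ids = logical_to_physical_map[topk_ids_long].gather(
    -1, replica_indices).squeeze(-1)

# 2. Load recording without `torch.bincount`: one scatter_add_ of ones
#    onto the per-physical-expert counters
flat = physical_ids.flatten()
expert_load_view.scatter_add_(0, flat.long(), torch.ones_like(flat))

print(physical_ids)      # physical slot chosen for each (token, top-k) pair
print(expert_load_view)  # tokens routed to each physical expert
```

Running the sketch prints the chosen physical slots `[[0, 4], [1, 4]]` and the per-physical-expert counts `[1, 1, 0, 0, 2, 0]`; using `scatter_add_` rather than `torch.bincount` keeps the helper traceable under `torch.compile`, as noted in the diff.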