73 changes: 73 additions & 0 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1017,6 +1017,79 @@ def grouped_topk(
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def eplb_map_to_physical_and_record(
topk_ids: torch.Tensor,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
indices_type: Optional[torch.dtype] = None) -> torch.Tensor:
    '''
    Map the logical expert ids to physical expert ids
    and record the expert load metrics.

    This will select a pseudo-random replica for each logical expert.
    Only used for EPLB.

    Args:
        topk_ids: The logical expert ids selected for each token.
        expert_load_view: 1D tensor of shape `(num_physical_experts,)`
            that accumulates the number of tokens routed to each
            physical expert.
        logical_to_physical_map: For each logical expert id, the ids
            of its physical replicas.
        logical_replica_count: The number of physical replicas of
            each logical expert.
        indices_type: Optional dtype to cast the returned ids to.

    Returns:
        The physical expert ids.
    '''

# 1. Convert the logical expert ids to physical expert ids
# Directly select a random replica for each logical expert

    # `indices_type` may not be `torch.long` or `torch.int`,
    # e.g. `torch.uint32` as required by dispatch/combine kernels,
    # so cast to `torch.long` before indexing
topk_ids_long = topk_ids.long()
# Use (token position) modulo (replica count)
# to deterministically choose a replica
replica_count = logical_replica_count[topk_ids_long]
    # Position of each element in the flattened tensor,
    # reshaped back to the shape of `topk_ids`
pos_indices = torch.arange(topk_ids.numel(),
device=topk_ids.device,
dtype=torch.long).reshape_as(topk_ids)
# Compute pseudo-random indices by modulo
replica_indices = (pos_indices % replica_count).unsqueeze(-1)
physical_ids = logical_to_physical_map[topk_ids_long].gather(
-1, replica_indices).squeeze(-1)

topk_ids = physical_ids

# 2. Record expert load metrics.

# TODO(bowen): When using `FusedMoEModularKernel`, this
# can be done in a more unified way, since
# `FusedMoEPrepareAndFinalize` will return the expert
# token count, in some cases directly from the kernel.
    # However, there are currently many code paths that do not
    # use the modular kernel, e.g. those calling `fused_experts`
    # directly, so we keep the logic here for now.
    #
    # If a later refactor moves all the MoE kernel calls to the
    # modular kernel, this logic can move there for better
    # efficiency.

# `expert_load_view`: (num_physical_experts,)

# `torch.bincount` is not compilable, so use `scatter_add_` instead.
topk_ids_flatten = topk_ids.flatten()
expert_load_view.scatter_add_(
dim=0,
index=topk_ids_flatten.long(),
src=torch.ones_like(topk_ids_flatten).to(expert_load_view))

if indices_type is not None:
topk_ids = topk_ids.to(dtype=indices_type)
return topk_ids


def fused_grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
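To make the replica selection concrete, here is a minimal standalone sketch of the position-modulo mapping implemented above, using toy tensors (the map contents, replica counts, and -1 padding are illustrative assumptions, not values taken from vLLM):

import torch

# 3 logical experts; logical expert 0 has two physical replicas (0 and 3),
# experts 1 and 2 have one each. Rows are padded to the max replica count.
logical_to_physical_map = torch.tensor([[0, 3],
                                        [1, -1],
                                        [2, -1]])
logical_replica_count = torch.tensor([2, 1, 1])

# Two tokens with top-2 routing: their logical expert ids.
topk_ids = torch.tensor([[0, 1],
                         [2, 0]])

topk_ids_long = topk_ids.long()
replica_count = logical_replica_count[topk_ids_long]
pos_indices = torch.arange(topk_ids.numel()).reshape_as(topk_ids)
replica_indices = (pos_indices % replica_count).unsqueeze(-1)
physical_ids = logical_to_physical_map[topk_ids_long].gather(
    -1, replica_indices).squeeze(-1)

print(physical_ids)
# tensor([[0, 1],
#         [2, 3]])
# The two tokens routed to logical expert 0 (flattened positions 0 and 3)
# land on different physical replicas (0 and 3), spreading load
# deterministically and without an RNG call that would have to be traced
# by torch.compile.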
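The load-recording step can likewise be checked in isolation. A small sketch (sizes are illustrative) showing that `scatter_add_` with a source of ones reproduces `torch.bincount`, which the comment above notes is not compilable:

import torch

num_physical_experts = 4  # illustrative size
expert_load_view = torch.zeros(num_physical_experts)

physical_ids = torch.tensor([[0, 1],
                             [2, 0]])
flat = physical_ids.flatten()
expert_load_view.scatter_add_(
    dim=0,
    index=flat.long(),
    src=torch.ones_like(flat).to(expert_load_view))

print(expert_load_view)  # tensor([2., 1., 1., 0.])
assert torch.equal(
    expert_load_view.long(),
    torch.bincount(flat, minlength=num_physical_experts))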
69 changes: 19 additions & 50 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -43,7 +43,8 @@

if current_platform.is_cuda_alike():
from .fused_batched_moe import BatchedTritonExperts
from .fused_moe import TritonExperts, fused_experts
from .fused_moe import (TritonExperts, eplb_map_to_physical_and_record,
fused_experts)
if has_pplx():
from .pplx_prepare_finalize import (PplxPrepareAndFinalize,
pplx_hidden_dim_scale_bytes)
@@ -55,6 +56,16 @@
fused_experts = None # type: ignore
FusedMoEPermuteExpertsUnpermute = None # type: ignore
FusedMoEPrepareAndFinalize = None # type: ignore

def eplb_map_to_physical_and_record(
topk_ids: torch.Tensor, expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
indices_type: Optional[torch.dtype]) -> torch.Tensor:
        # CPU fallback: EPLB is not supported here,
        # so return `topk_ids` unchanged
return topk_ids


if is_rocm_aiter_moe_enabled():
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_grouped_topk as grouped_topk)
@@ -1616,55 +1627,13 @@ def select_experts(
assert logical_to_physical_map is not None
assert logical_replica_count is not None

# 1. Convert the logical expert ids to physical expert ids
# Directly select a random replica for each logical expert

# TODO: maybe optimize this by using specified kernels,
# or compute pseudo-random indices by modulo

# In case `indices_type` is not `torch.long` or `torch.int`,
# e.g. `torch.uint32` as required by dispatch/combine kernels
topk_ids_long = topk_ids.long()
replica_indices = (
torch.rand_like(topk_ids, dtype=torch.float) *
logical_replica_count[topk_ids_long]).long().unsqueeze(-1)
physical_ids = logical_to_physical_map[topk_ids_long].gather(
-1, replica_indices).squeeze(-1)

topk_ids = physical_ids

# 2. Record expert load metrics.

# TODO(bowen): When using `FusedMoEModularKernel`, this
# can be done in a more unified way, since
# `FusedMoEPrepareAndFinalize` will return the expert
# token count, in some cases directly from the kernel.
# However, now there are many code paths not using
# the modular kernel, e.g. calling `fused_experts`,
# so we decide to keep the logic here.
#
# If later refactor moved all the MoE kernel calls
# to the modular kernel, we can move this logic there
# to achieve better efficiency.

# `expert_load_view`: (num_physical_experts,)

topk_ids_flatten = topk_ids.flatten()

# Performance optimization:
# `masked_fill` is significantly faster than `masked_select`
invalid_mask = topk_ids_flatten < 0
# Replace invalid expert ids with 0 (just a dummy position)
# to avoid out-of-bounds errors in scatter_add_
index = topk_ids_flatten.masked_fill_(invalid_mask, 0)
# `src` is the valid mask, which is 1 for valid and 0 for invalid
src = ~invalid_mask

expert_load_view.scatter_add_(dim=0,
index=index.long(),
src=src.to(expert_load_view))

topk_ids = topk_ids.to(dtype=indices_type)
topk_ids = eplb_map_to_physical_and_record(
topk_ids=topk_ids,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
indices_type=indices_type,
)

assert topk_ids.dtype == indices_type or indices_type is None

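For reference, a hedged sketch of how the new call site behaves end to end, assuming `eplb_map_to_physical_and_record` is importable from `vllm.model_executor.layers.fused_moe.fused_moe` as in the diff above (toy tensors as in the earlier sketch; `torch.uint32` stands in for a dtype a dispatch/combine kernel might require):

import torch
from vllm.model_executor.layers.fused_moe.fused_moe import (
    eplb_map_to_physical_and_record)

expert_load_view = torch.zeros(4)  # one slot per physical expert
topk_ids = eplb_map_to_physical_and_record(
    topk_ids=torch.tensor([[0, 1],
                           [2, 0]]),
    expert_load_view=expert_load_view,
    logical_to_physical_map=torch.tensor([[0, 3],
                                          [1, -1],
                                          [2, -1]]),
    logical_replica_count=torch.tensor([2, 1, 1]),
    indices_type=torch.uint32,
)
# Physical ids come back in the requested dtype, satisfying the
# assertion in `select_experts`; the load counters were updated
# in place as a side effect.
assert topk_ids.dtype == torch.uint32
print(expert_load_view)  # tensor([1., 1., 1., 1.])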