11 changes: 11 additions & 0 deletions vllm/model_executor/layers/fused_moe/all2all_utils.py
@@ -67,6 +67,7 @@ def maybe_roundup_layer_hidden_size(
def maybe_make_prepare_finalize(
moe: FusedMoEConfig,
quant_config: FusedMoEQuantConfig | None,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> FusedMoEPrepareAndFinalize | None:
if not moe.moe_parallel_config.use_all2all_kernels:
return None
@@ -134,6 +135,13 @@ def maybe_make_prepare_finalize(

elif moe.use_deepep_ll_kernels:
assert quant_config is not None
global_to_physical = physical_to_global = local_expert_global_ids = None
if routing_tables is not None:
(
global_to_physical,
physical_to_global,
local_expert_global_ids,
) = routing_tables
all_to_all_args = dict(
max_num_tokens_per_dp_rank=moe.max_num_tokens,
token_hidden_size=moe.hidden_dim,
@@ -155,6 +163,9 @@
max_tokens_per_rank=moe.max_num_tokens,
num_dispatchers=all2all_manager.world_size,
use_fp8_dispatch=use_fp8_dispatch,
global_to_physical=global_to_physical,
physical_to_global=physical_to_global,
local_expert_global_ids=local_expert_global_ids,
)

return prepare_finalize
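For orientation, here is a minimal, self-contained sketch of what the three tensors in routing_tables hold under round-robin placement. The toy sizes are made up for illustration, and the layout mirrors the ensure_round_robin_expert_routing_tables helper added to layer.py further down: global_to_physical maps a router-produced global expert id to its physical slot, physical_to_global is the inverse permutation, and local_expert_global_ids lists the global expert ids hosted on the current EP rank.

```python
import torch

# Hypothetical toy configuration: 8 global experts, EP world size 4, this rank = 1.
global_num_experts, ep_size, ep_rank = 8, 4, 1

# Round-robin placement: global expert g lives on rank (g % ep_size).
global_ids = torch.arange(global_num_experts)
owner = global_ids % ep_size           # owning EP rank of each global expert
local_index = global_ids // ep_size    # slot within the owning rank

# Physical ids are laid out rank-major, so each rank's experts are contiguous.
global_to_physical = owner * (global_num_experts // ep_size) + local_index
physical_to_global = torch.empty_like(global_to_physical)
physical_to_global[global_to_physical] = global_ids
local_expert_global_ids = torch.arange(ep_rank, global_num_experts, ep_size)

routing_tables = (global_to_physical, physical_to_global, local_expert_global_ids)
print(global_to_physical)       # tensor([0, 2, 4, 6, 1, 3, 5, 7])
print(physical_to_global)       # tensor([0, 4, 1, 5, 2, 6, 3, 7])
print(local_expert_global_ids)  # tensor([1, 5]) -> rank 1 hosts global experts 1 and 5
```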
@@ -85,6 +85,9 @@ def __init__(
max_tokens_per_rank: int,
num_dispatchers: int,
use_fp8_dispatch: bool = False,
global_to_physical: torch.Tensor | None = None,
physical_to_global: torch.Tensor | None = None,
local_expert_global_ids: torch.Tensor | None = None,
):
super().__init__()

@@ -97,6 +100,17 @@ def __init__(
self.handles: list[tuple | None] = [None, None]
self.num_dispatchers_ = num_dispatchers

topk_indices_dtype = self.topk_indices_dtype()

def _maybe_cast(tensor: torch.Tensor | None) -> torch.Tensor | None:
if tensor is None or topk_indices_dtype is None:
return tensor
return tensor.to(dtype=topk_indices_dtype)

self.global_to_physical = _maybe_cast(global_to_physical)
self.physical_to_global = _maybe_cast(physical_to_global)
self.local_expert_global_ids = _maybe_cast(local_expert_global_ids)

# We don't have enough information to determine if we should dispatch
# activation scales in a packed ue8m0 format during object construction
# time. This setting is handled by post_init_setup.
@@ -136,6 +150,16 @@ def max_num_tokens_per_rank(self) -> int | None:
def topk_indices_dtype(self) -> torch.dtype | None:
return torch.int64

def _map_global_to_physical_ids(self, topk_ids: torch.Tensor) -> torch.Tensor:
if self.global_to_physical is None:
return topk_ids
return self.global_to_physical[topk_ids]

def _map_local_to_global_ids(self, expert_topk_ids: torch.Tensor) -> torch.Tensor:
if self.local_expert_global_ids is None:
return expert_topk_ids
return self.local_expert_global_ids[expert_topk_ids]

def _do_quant(
self,
x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
@@ -226,9 +250,10 @@ def prepare_async(
a1 = a1 * topk_weights.to(a1.dtype)

# Dispatch
dispatch_topk_ids = self._map_global_to_physical_ids(topk_ids)
expert_x, expert_num_tokens, handle, _, hook = self.buffer.low_latency_dispatch(
a1,
topk_ids,
dispatch_topk_ids,
self.max_tokens_per_rank,
num_experts,
use_fp8=self.use_fp8_dispatch,
@@ -313,11 +338,12 @@ def _finalize(
# weights have already been applied.
combine_topk_weights = torch.ones_like(topk_weights)

combine_topk_ids = self._map_global_to_physical_ids(topk_ids)
# TODO (varun) : Enable zero copy mode
dbo_maybe_run_recv_hook()
_, _, recv_hook = self.buffer.low_latency_combine(
fused_expert_output,
topk_ids,
combine_topk_ids,
combine_topk_weights,
handle,
async_finish=False,
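As a small illustration of the new mapping helpers, here is a pure-torch sketch with made-up ids, independent of DeepEP and reusing the same toy 8-expert/4-rank layout as above: _map_global_to_physical_ids is a gather into global_to_physical applied to topk_ids right before low_latency_dispatch and low_latency_combine, while _map_local_to_global_ids (its call site is not shown in these hunks) translates rank-local expert indices back to global ids via local_expert_global_ids.

```python
import torch

# Made-up tables for 8 experts on 4 ranks, round-robin placement, EP rank 1.
global_to_physical = torch.tensor([0, 2, 4, 6, 1, 3, 5, 7])
local_expert_global_ids = torch.tensor([1, 5])

# topk_ids as produced by the router: global expert ids, one row per token.
topk_ids = torch.tensor([[1, 6], [4, 3]])

# Equivalent of _map_global_to_physical_ids: remap to physical slots for dispatch.
dispatch_topk_ids = global_to_physical[topk_ids]
print(dispatch_topk_ids)  # tensor([[2, 5], [1, 6]])

# Equivalent of _map_local_to_global_ids on the receiving rank: indices into the
# rank-local expert list are translated back to global expert ids.
expert_topk_ids = torch.tensor([0, 1])           # local experts 0 and 1 on rank 1
print(local_expert_global_ids[expert_topk_ids])  # tensor([1, 5])
```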
@@ -50,10 +50,15 @@ def uses_weight_scale_2_pattern(self) -> bool:
"""
return False

def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> FusedMoEPrepareAndFinalize | None:
from .all2all_utils import maybe_make_prepare_finalize

return maybe_make_prepare_finalize(self.moe, self.moe_quant_config)
return maybe_make_prepare_finalize(
self.moe, self.moe_quant_config, routing_tables
)

def select_gemm_impl(
self,
157 changes: 135 additions & 22 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -5,7 +5,7 @@
from contextlib import nullcontext
from enum import Enum
from functools import partial
from typing import Literal, get_args, overload
from typing import Literal, cast, get_args, overload

import torch
import torch.nn.functional as F
@@ -192,6 +192,42 @@ def determine_expert_map(
return (local_num_experts, expert_map, expert_mask)


def determine_expert_placement_strategy(
expert_placement_strategy: ExpertPlacementStrategy,
moe_parallel_config: FusedMoEParallelConfig,
num_expert_group: int | None,
num_redundant_experts: int,
enable_eplb: bool,
) -> ExpertPlacementStrategy:
if expert_placement_strategy == "round_robin":
round_robin_supported = (
(num_expert_group is not None and num_expert_group > 1)
and num_redundant_experts == 0
and not enable_eplb
)

if not round_robin_supported:
logger.warning(
"Round-robin expert placement is only supported for "
"models with multiple expert groups and no redundant "
"experts. Falling back to linear expert placement."
)
return "linear"
if (
moe_parallel_config.use_all2all_kernels
and not moe_parallel_config.use_deepep_ll_kernels
):
logger.warning(
"Round-robin expert placement currently only supports "
"the DeepEP low-latency backend, but '%s' was configured. "
"Falling back to linear expert placement.",
moe_parallel_config.all2all_backend,
)
return "linear"

return expert_placement_strategy


def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
"""
Compresses the expert map by removing any -1 entries.
@@ -400,6 +436,9 @@ def __init__(
self.expert_load_view: torch.Tensor | None = None
self.logical_to_physical_map: torch.Tensor | None = None
self.logical_replica_count: torch.Tensor | None = None
self.expert_placement_strategy: ExpertPlacementStrategy = (
vllm_config.parallel_config.expert_placement_strategy
)

# ROCm aiter shared experts fusion
self.rocm_aiter_fmoe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
@@ -433,46 +472,35 @@ def __init__(
"Redundant experts are only supported with EPLB."
)

expert_placement_strategy = (
vllm_config.parallel_config.expert_placement_strategy
self.expert_placement_strategy = determine_expert_placement_strategy(
expert_placement_strategy=self.expert_placement_strategy,
moe_parallel_config=self.moe_parallel_config,
num_expert_group=num_expert_group,
num_redundant_experts=num_redundant_experts,
enable_eplb=self.enable_eplb,
)
if expert_placement_strategy == "round_robin":
# TODO(Bruce): will support round robin expert placement with
# EPLB enabled in the future.
round_robin_supported = (
(num_expert_group is not None and num_expert_group > 1)
and num_redundant_experts == 0
and not self.enable_eplb
)

if not round_robin_supported:
logger.warning(
"Round-robin expert placement is only supported for "
"models with multiple expert groups and no redundant "
"experts. Falling back to linear expert placement."
)
expert_placement_strategy = "linear"

self.expert_map: torch.Tensor | None
local_num_experts, expert_map, expert_mask = determine_expert_map(
ep_size=self.ep_size,
ep_rank=self.ep_rank,
global_num_experts=self.global_num_experts,
expert_placement_strategy=expert_placement_strategy,
expert_placement_strategy=self.expert_placement_strategy,
num_fused_shared_experts=self.num_fused_shared_experts,
return_expert_mask=self.rocm_aiter_fmoe_enabled,
)
self.local_num_experts = local_num_experts
self.register_buffer("expert_map", expert_map)
self.register_buffer("expert_mask", expert_mask)
self._maybe_init_expert_routing_tables()
logger.info_once(
"[EP Rank %s/%s] Expert parallelism is enabled. Expert "
"placement strategy: %s. Local/global"
" number of experts: %s/%s. Experts local to global index map:"
" %s.",
self.ep_rank,
self.ep_size,
expert_placement_strategy,
self.expert_placement_strategy,
self.local_num_experts,
self.global_num_experts,
get_compressed_expert_map(self.expert_map),
@@ -621,7 +649,12 @@ def _get_quant_method() -> FusedMoEMethodBase:
# should be safe to swap out the quant_method.
def maybe_init_modular_kernel(self) -> None:
self.ensure_moe_quant_config_init()
prepare_finalize = self.quant_method.maybe_make_prepare_finalize()
# routing_tables is only needed for round-robin expert placement with
# the DeepEP low-latency all2all backend.
routing_tables = self._maybe_init_expert_routing_tables()
prepare_finalize = self.quant_method.maybe_make_prepare_finalize(
routing_tables=routing_tables
)
if prepare_finalize is not None:
logger.debug(
"%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self)
@@ -703,6 +736,84 @@ def is_internal_router(self) -> bool:
# By default, router/gate is called before FusedMoE forward pass
return False

def _maybe_init_expert_routing_tables(
self,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None:
# Currently, routing tables are only needed for round-robin expert
# placement with the DeepEP low-latency all2all backend.
if (
self.expert_placement_strategy != "round_robin"
or not self.use_deepep_ll_kernels
):
return None

if hasattr(self, "expert_global_to_physical"):
return cast(
tuple[torch.Tensor, torch.Tensor, torch.Tensor],
(
self.expert_global_to_physical,
self.expert_physical_to_global,
self.expert_local_to_global,
),
)

if self.expert_map is None:
return None

routing_tables = self.ensure_round_robin_expert_routing_tables(
global_num_experts=self.global_num_experts,
ep_size=self.ep_size,
ep_rank=self.ep_rank,
local_num_experts=self.local_num_experts,
device=self.expert_map.device,
)

global_to_physical, physical_to_global, local_global = routing_tables
self.register_buffer("expert_global_to_physical", global_to_physical)
self.register_buffer("expert_physical_to_global", physical_to_global)
self.register_buffer("expert_local_to_global", local_global)

return routing_tables

@staticmethod
def ensure_round_robin_expert_routing_tables(
global_num_experts: int,
ep_size: int,
ep_rank: int,
local_num_experts: int,
device: torch.device | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
device_kwargs = {"device": device} if device is not None else {}
global_indices = torch.arange(
global_num_experts, dtype=torch.long, **device_kwargs
)
owner = torch.remainder(global_indices, ep_size)
local_index = torch.div(global_indices, ep_size, rounding_mode="floor")
base = global_num_experts // ep_size
remainder = global_num_experts % ep_size
physical_offset = owner * base
if remainder > 0:
remainder_tensor = torch.tensor(
remainder, dtype=torch.long, **device_kwargs
)
physical_offset = physical_offset + torch.minimum(owner, remainder_tensor)

global_to_physical = physical_offset + local_index
physical_to_global = torch.empty_like(global_to_physical)
physical_to_global[global_to_physical] = global_indices

local_global = torch.arange(
ep_rank,
global_num_experts,
ep_size,
dtype=torch.long,
**device_kwargs,
)
if local_global.numel() != local_num_experts:
local_global = local_global[:local_num_experts]

return (global_to_physical, physical_to_global, local_global)

def update_expert_map(self):
# ep_size and ep_rank should already be updated
assert self.expert_map is not None
@@ -711,12 +822,14 @@ def update_expert_map(self):
ep_size=self.ep_size,
ep_rank=self.ep_rank,
global_num_experts=self.global_num_experts,
expert_placement_strategy=self.expert_placement_strategy,
num_fused_shared_experts=self.num_fused_shared_experts,
return_expert_mask=self.rocm_aiter_fmoe_enabled,
)
self.local_num_experts = local_num_experts
self.register_buffer("expert_map", expert_map)
self.register_buffer("expert_mask", expert_mask)
self._maybe_init_expert_routing_tables()
if self.aiter_fmoe_shared_expert_enabled:
self._init_aiter_shared_experts_topK_buffer(
vllm_config=get_current_vllm_config(),
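To make the remainder handling concrete, here is a standalone sketch that mirrors the arithmetic of ensure_round_robin_expert_routing_tables above (the toy sizes are made up): with 10 experts across 4 EP ranks, the first global_num_experts % ep_size ranks host one extra expert, and each rank's physical slots remain contiguous.

```python
import torch

# Toy sizes (hypothetical): 10 global experts, 4 EP ranks, this rank = 1.
global_num_experts, ep_size, ep_rank = 10, 4, 1

global_indices = torch.arange(global_num_experts, dtype=torch.long)
owner = torch.remainder(global_indices, ep_size)
local_index = torch.div(global_indices, ep_size, rounding_mode="floor")

base = global_num_experts // ep_size      # 2 experts per rank ...
remainder = global_num_experts % ep_size  # ... plus 1 extra on the first 2 ranks
physical_offset = owner * base + torch.minimum(
    owner, torch.tensor(remainder, dtype=torch.long)
)

global_to_physical = physical_offset + local_index
physical_to_global = torch.empty_like(global_to_physical)
physical_to_global[global_to_physical] = global_indices
local_expert_global_ids = torch.arange(ep_rank, global_num_experts, ep_size)

print(global_to_physical)       # tensor([0, 3, 6, 8, 1, 4, 7, 9, 2, 5])
print(physical_to_global)       # tensor([0, 4, 8, 1, 5, 9, 2, 6, 3, 7])
print(local_expert_global_ids)  # tensor([1, 5, 9]) -> rank 1 hosts globals 1, 5, 9
```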
@@ -108,11 +108,14 @@ def supports_eplb(self) -> bool:
def allow_inplace(self) -> bool:
return True

def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> FusedMoEPrepareAndFinalize | None:
if self.rocm_aiter_moe_enabled:
return None
else:
return super().maybe_make_prepare_finalize()
return super().maybe_make_prepare_finalize(routing_tables)

def select_gemm_impl(
self,