diff --git a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
index 572307052b48..659a2d4ee5b3 100644
--- a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
@@ -6,22 +6,7 @@
 from torch.nn import functional as F
 
 from vllm import _custom_ops as ops
-
-
-def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
-    d = x.shape[-1] // 2
-    return F.silu(x[..., :d]) * x[..., d:]
-
-
-def swigluoai_and_mul(
-    x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0
-) -> torch.Tensor:
-    d = x.shape[-1] // 2
-    gate, up = x[..., :d], x[..., d:]
-    gate = gate.clamp(max=limit)
-    up = up.clamp(min=-limit, max=limit)
-    glu = gate * torch.sigmoid(alpha * gate)
-    return (up + 1) * glu
+from vllm.model_executor.layers.activation import SiluAndMul, SwigluOAIAndMul
 
 
 def grouped_topk(
@@ -227,6 +212,11 @@ def __init__(self, layer: torch.nn.Module) -> None:
         layer.w13_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
         layer.w2_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
 
+        self.act_to_impl = {
+            "silu": SiluAndMul(),
+            "swigluoai": SwigluOAIAndMul(),
+        }
+
     def __call__(
         self,
         layer: torch.nn.Module,
@@ -246,7 +236,7 @@ def __call__(
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
     ) -> torch.Tensor:
-        assert activation in {"silu", "swigluoai"}, f"{activation} is not supported."
+        assert activation in self.act_to_impl, f"{activation} is not supported."
         assert not apply_router_weight_on_input
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
@@ -283,10 +273,7 @@ def __call__(
             tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
 
             gate_up = layer.gate_up_linear[i](tokens_for_this_expert)
-            if activation == "swigluoai":
-                gate_up = swigluoai_and_mul(gate_up)
-            else:
-                gate_up = silu_and_mul(gate_up)
+            gate_up = self.act_to_impl[activation].forward_native(gate_up)
             expert_out = layer.down_linear[i](gate_up)
             outputs.append(expert_out)
             start_idx = end_idx
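
For reference, the dict-based dispatch the new code relies on can be sketched standalone as below. The activation math is copied from the helpers this diff removes; the class names `SiluAndMulSketch` and `SwigluOAIAndMulSketch` are stand-ins for illustration only, whereas the real modules come from `vllm.model_executor.layers.activation` as imported above.

```python
# Standalone sketch (not the vLLM classes): the math mirrors the removed
# silu_and_mul / swigluoai_and_mul helpers, exposed behind a forward_native()
# method so a dict can dispatch on the activation name.
import torch
from torch.nn import functional as F


class SiluAndMulSketch(torch.nn.Module):
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Split [..., 2d] into gate/up halves and take the SiLU-gated product.
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]


class SwigluOAIAndMulSketch(torch.nn.Module):
    def __init__(self, alpha: float = 1.702, limit: float = 7.0):
        super().__init__()
        self.alpha = alpha
        self.limit = limit

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        gate, up = x[..., :d], x[..., d:]
        gate = gate.clamp(max=self.limit)
        up = up.clamp(min=-self.limit, max=self.limit)
        glu = gate * torch.sigmoid(self.alpha * gate)
        return (up + 1) * glu


# Dict-based dispatch, as in the new __init__ / __call__ above.
act_to_impl = {"silu": SiluAndMulSketch(), "swigluoai": SwigluOAIAndMulSketch()}

gate_up = torch.randn(4, 2 * 128)  # [num_tokens, 2 * intermediate_size]
activation = "swigluoai"
assert activation in act_to_impl, f"{activation} is not supported."
out = act_to_impl[activation].forward_native(gate_up)  # shape [4, 128]
```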