diff --git a/docs/contributing/model/basic.md b/docs/contributing/model/basic.md
index a7b54f015c2d..d7f5d2f311a3 100644
--- a/docs/contributing/model/basic.md
+++ b/docs/contributing/model/basic.md
@@ -146,6 +146,7 @@ We use "mamba-like" to refer to layers that posses a state that is updated in-pl
 For implementing new custom mamba-like layers, one should inherit from `MambaBase` and implement the methods `get_state_dtype`, `get_state_shape` to calculate the data types and state shapes at runtime, as well as `mamba_type` and `get_attn_backend`.
 It is also necessary to implement the "attention meta-data" class which handles the meta-data that is common across all layers.
 Please see [`LinearAttentionMetadata`](../../../vllm/v1/attention/backends/linear_attn.py) or [`ShortConvAttentionMetadata`](../../../vllm/v1/attention/backends/short_conv_attn.py) for examples of this.
+When adding a new mamba backend, `MAMBA_TYPE_TO_BACKEND_MAP` and `MambaAttentionBackendEnum` in [`registry.py`](../../../vllm/attention/backends/registry.py) must also be updated.
 Finally, if one wants to support torch compile and CUDA graphs, it necessary to wrap the call to the mamba-like layer inside a custom op and register it.
 Please see the calls to `direct_register_custom_op` in [vllm/model_executor/models/minimax_text_01.py](../../../vllm/model_executor/models/minimax_text_01.py) or [vllm/model_executor/layers/mamba/short_conv.py](../../../vllm/model_executor/layers/mamba/short_conv.py) for examples of this.
 The new custom op should then be added to the list `_attention_ops` in [vllm/config/compilation.py](../../../vllm/config/compilation.py) to ensure that piecewise CUDA graphs works as intended.
diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py
index dd35165d5415..8b4dc4013362 100644
--- a/vllm/attention/__init__.py
+++ b/vllm/attention/__init__.py
@@ -7,7 +7,7 @@
     AttentionType,
 )
 from vllm.attention.layer import Attention
-from vllm.attention.selector import get_attn_backend
+from vllm.attention.selector import get_attn_backend, get_mamba_attn_backend
 
 __all__ = [
     "Attention",
@@ -15,4 +15,5 @@
     "AttentionMetadata",
     "AttentionType",
     "get_attn_backend",
+    "get_mamba_attn_backend",
 ]
diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py
index f07a6059be37..51899b023591 100644
--- a/vllm/attention/backends/registry.py
+++ b/vllm/attention/backends/registry.py
@@ -2,8 +2,8 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention backend registry"""
 
-import enum
 from collections.abc import Callable
+from enum import Enum, EnumMeta
 from typing import TYPE_CHECKING, cast
 
 from vllm.logger import init_logger
@@ -15,7 +15,7 @@
 logger = init_logger(__name__)
 
 
-class _AttentionBackendEnumMeta(enum.EnumMeta):
+class _AttentionBackendEnumMeta(EnumMeta):
     """Metaclass for AttentionBackendEnum to provide better error messages."""
 
     def __getitem__(cls, name: str):
@@ -23,15 +23,15 @@ def __getitem__(cls, name: str):
         try:
             return super().__getitem__(name)
         except KeyError:
-            members = cast("dict[str, AttentionBackendEnum]", cls.__members__).values()
-            valid_backends = ", ".join(m.name for m in members)
+            members = cast("dict[str, Enum]", cls.__members__).keys()
+            valid_backends = ", ".join(members)
             raise ValueError(
                 f"Unknown attention backend: '{name}'. "
" f"Valid options are: {valid_backends}" ) from None -class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta): +class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta): """Enumeration of all supported attention backends. The enum value is the default class path, but this can be overridden @@ -83,7 +83,7 @@ def get_path(self, include_classname: bool = True) -> str: Raises: ValueError: If Backend.CUSTOM is used without being registered """ - path = _OVERRIDES.get(self, self.value) + path = _ATTN_OVERRIDES.get(self, self.value) if not path: raise ValueError( f"Backend {self.name} must be registered before use. " @@ -111,18 +111,93 @@ def is_overridden(self) -> bool: Returns: True if the backend has a registered override """ - return self in _OVERRIDES + return self in _ATTN_OVERRIDES def clear_override(self) -> None: """Clear any override for this backend, reverting to the default.""" - _OVERRIDES.pop(self, None) + _ATTN_OVERRIDES.pop(self, None) -_OVERRIDES: dict[AttentionBackendEnum, str] = {} +class MambaAttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta): + """Enumeration of all supported mamba attention backends. + + The enum value is the default class path, but this can be overridden + at runtime using register_backend(). + + To get the actual backend class (respecting overrides), use: + backend.get_class() + """ + + MAMBA1 = "vllm.v1.attention.backends.mamba1_attn.Mamba1AttentionBackend" + MAMBA2 = "vllm.v1.attention.backends.mamba2_attn.Mamba2AttentionBackend" + SHORT_CONV = "vllm.v1.attention.backends.short_conv_attn.ShortConvAttentionBackend" + LINEAR = "vllm.v1.attention.backends.linear_attn.LinearAttentionBackend" + GDN_ATTN = "vllm.v1.attention.backends.gdn_attn.GDNAttentionBackend" + # Placeholder for third-party/custom backends - must be registered before use + CUSTOM = "" + + def get_path(self, include_classname: bool = True) -> str: + """Get the class path for this backend (respects overrides). + + Returns: + The fully qualified class path string + + Raises: + ValueError: If Backend.CUSTOM is used without being registered + """ + path = _MAMBA_ATTN_OVERRIDES.get(self, self.value) + if not path: + raise ValueError( + f"Backend {self.name} must be registered before use. " + f"Use register_backend(Backend.{self.name}, 'your.module.YourClass')" + ) + if not include_classname: + path = path.rsplit(".", 1)[0] + return path + + def get_class(self) -> "type[AttentionBackend]": + """Get the backend class (respects overrides). + + Returns: + The backend class + + Raises: + ImportError: If the backend class cannot be imported + ValueError: If Backend.CUSTOM is used without being registered + """ + return resolve_obj_by_qualname(self.get_path()) + + def is_overridden(self) -> bool: + """Check if this backend has been overridden. 
+
+        Returns:
+            True if the backend has a registered override
+        """
+        return self in _MAMBA_ATTN_OVERRIDES
+
+    def clear_override(self) -> None:
+        """Clear any override for this backend, reverting to the default."""
+        _MAMBA_ATTN_OVERRIDES.pop(self, None)
+
+
+MAMBA_TYPE_TO_BACKEND_MAP = {
+    "mamba1": MambaAttentionBackendEnum.MAMBA1.name,
+    "mamba2": MambaAttentionBackendEnum.MAMBA2.name,
+    "short_conv": MambaAttentionBackendEnum.SHORT_CONV.name,
+    "linear_attention": MambaAttentionBackendEnum.LINEAR.name,
+    "gdn_attention": MambaAttentionBackendEnum.GDN_ATTN.name,
+    "custom": MambaAttentionBackendEnum.CUSTOM.name,
+}
+
+
+_ATTN_OVERRIDES: dict[AttentionBackendEnum, str] = {}
+_MAMBA_ATTN_OVERRIDES: dict[MambaAttentionBackendEnum, str] = {}
 
 
 def register_backend(
-    backend: AttentionBackendEnum, class_path: str | None = None
+    backend: AttentionBackendEnum | MambaAttentionBackendEnum,
+    is_mamba: bool = False,
+    class_path: str | None = None,
 ) -> Callable[[type], type]:
     """Register or override a backend implementation.
 
@@ -135,12 +210,17 @@
     Returns:
         Decorator function if class_path is None, otherwise a no-op
 
     Examples:
-        # Override an existing backend
+        # Override an existing attention backend
         @register_backend(AttentionBackendEnum.FLASH_ATTN)
         class MyCustomFlashAttn:
             ...
 
-        # Register a custom third-party backend
+        # Override an existing mamba attention backend
+        @register_backend(MambaAttentionBackendEnum.LINEAR, is_mamba=True)
+        class MyCustomMambaAttn:
+            ...
+
+        # Register a custom third-party attention backend
         @register_backend(AttentionBackendEnum.CUSTOM)
         class MyCustomBackend:
             ...
@@ -153,11 +233,17 @@ class MyCustomBackend:
     """
 
     def decorator(cls: type) -> type:
-        _OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}"
+        if is_mamba:
+            _MAMBA_ATTN_OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}"  # type: ignore[index]
+        else:
+            _ATTN_OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}"  # type: ignore[index]
         return cls
 
     if class_path is not None:
-        _OVERRIDES[backend] = class_path
+        if is_mamba:
+            _MAMBA_ATTN_OVERRIDES[backend] = class_path  # type: ignore[index]
+        else:
+            _ATTN_OVERRIDES[backend] = class_path  # type: ignore[index]
         return lambda x: x
 
     return decorator
diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index 1a092db9ce37..e9af08b2316d 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -12,7 +12,11 @@
 
 import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionBackend
-from vllm.attention.backends.registry import AttentionBackendEnum
+from vllm.attention.backends.registry import (
+    MAMBA_TYPE_TO_BACKEND_MAP,
+    AttentionBackendEnum,
+    MambaAttentionBackendEnum,
+)
 from vllm.config.cache import CacheDType
 from vllm.logger import init_logger
 from vllm.utils import STR_BACKEND_ENV_VAR
@@ -197,6 +201,33 @@ def _cached_get_attn_backend(
     return backend
 
 
+def get_mamba_attn_backend(
+    mamba_type: str,
+) -> type[AttentionBackend]:
+    """Select which mamba attention backend to use and lazily import it."""
+    return _cached_get_mamba_attn_backend(mamba_type)
+
+
+@cache
+def _cached_get_mamba_attn_backend(
+    mamba_type: str,
+) -> type[AttentionBackend]:
+    assert mamba_type and isinstance(mamba_type, str)
+
+    selected_backend = None
+    try:
+        backend_name = MAMBA_TYPE_TO_BACKEND_MAP[mamba_type]
+        selected_backend = MambaAttentionBackendEnum[backend_name]
+    except KeyError as e:
+        raise ValueError(
+            f"Invalid mamba attention backend type: '{mamba_type}'. Valid "
Valid " + f"backends are: {list(MambaAttentionBackendEnum.__members__.keys())}" + ) from e + + mamba_attn_backend = selected_backend.get_class() + return mamba_attn_backend + + @contextmanager def global_force_attn_backend_context_manager( attn_backend: AttentionBackendEnum, diff --git a/vllm/model_executor/layers/kda.py b/vllm/model_executor/layers/kda.py index 2e7500bac718..27cc3884517f 100644 --- a/vllm/model_executor/layers/kda.py +++ b/vllm/model_executor/layers/kda.py @@ -5,7 +5,6 @@ from einops import rearrange from torch import nn -from vllm.attention import AttentionBackend from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config from vllm.distributed import ( @@ -83,12 +82,7 @@ def kda_attention_fake( class KimiDeltaAttention(nn.Module, MambaBase): @property def mamba_type(self) -> str: - return "linear_attention" - - def get_attn_backend(self) -> type["AttentionBackend"]: - from vllm.v1.attention.backends.gdn_attn import GDNAttentionBackend - - return GDNAttentionBackend + return "gdn_attention" def get_state_dtype( self, diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py index e68b09b4d81f..aa919d6fdc35 100644 --- a/vllm/model_executor/layers/mamba/abstract.py +++ b/vllm/model_executor/layers/mamba/abstract.py @@ -6,6 +6,7 @@ import torch +from vllm.attention.selector import get_mamba_attn_backend from vllm.config import VllmConfig from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec @@ -38,11 +39,6 @@ def get_state_shape(self) -> Iterable[tuple[int, ...]]: def mamba_type(self) -> str: pass - @abstractmethod - def get_attn_backend(self) -> type["AttentionBackend"]: - """Get the attention backend class for this Mamba layer.""" - pass - @abstractmethod def get_state_dtype(self) -> tuple[torch.dtype, ...]: pass @@ -69,3 +65,7 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: else 0 ), ) + + def get_attn_backend(self) -> type["AttentionBackend"]: + """Get the attention backend class for this Mamba layer.""" + return get_mamba_attn_backend(self.mamba_type) diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear_attn.py index 0a2742ff49a4..d85b3e61c5d6 100644 --- a/vllm/model_executor/layers/mamba/linear_attn.py +++ b/vllm/model_executor/layers/mamba/linear_attn.py @@ -2,12 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from vllm.attention.backends.abstract import AttentionBackend - -from typing import TYPE_CHECKING import torch import torch.nn.functional as F @@ -37,9 +31,6 @@ from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata -if TYPE_CHECKING: - from vllm.attention.backends.abstract import AttentionBackend - class MiniMaxText01RMSNormTP(CustomOp): name = "MiniMaxText01RMSNormTP" @@ -123,11 +114,6 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase): def mamba_type(self) -> str: return "linear_attention" - def get_attn_backend(self) -> type["AttentionBackend"]: - from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend - - return LinearAttentionBackend - def get_state_dtype(self) -> tuple[torch.dtype]: assert self.model_config is not None assert self.cache_config is not None diff --git 
index b6345b8af7f0..90e520e24441 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer.py
@@ -1,10 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import TYPE_CHECKING, NamedTuple
-
-if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionBackend
+from typing import NamedTuple
 
 import torch
 from torch import nn
@@ -452,11 +449,6 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
     def mamba_type(self) -> str:
         return "mamba1"
 
-    def get_attn_backend(self) -> type["AttentionBackend"]:
-        from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
-
-        return Mamba1AttentionBackend
-
     def _time_proj_bias(self) -> torch.Tensor | None:
         if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None:
             return self.dt_proj.bias.float()
diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index 57313990b820..900701c46348 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -1,10 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionBackend
 
 import torch
 from torch import nn
@@ -908,11 +904,6 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
     def mamba_type(self) -> str:
         return "mamba2"
 
-    def get_attn_backend(self) -> type["AttentionBackend"]:
-        from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
-
-        return Mamba2AttentionBackend
-
 
 def mamba_mixer2(
     projected_states: torch.Tensor,
diff --git a/vllm/model_executor/layers/mamba/short_conv.py b/vllm/model_executor/layers/mamba/short_conv.py
index 04efa8a8b373..0bbad17d7ebc 100644
--- a/vllm/model_executor/layers/mamba/short_conv.py
+++ b/vllm/model_executor/layers/mamba/short_conv.py
@@ -1,10 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionBackend
 
 
 import torch
@@ -232,11 +228,6 @@ def get_state_shape(self) -> tuple[tuple[int, ...]]:
     def mamba_type(self) -> str:
         return "short_conv"
 
-    def get_attn_backend(self) -> type["AttentionBackend"]:
-        from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionBackend
-
-        return ShortConvAttentionBackend
-
 
 def short_conv(
     hidden_states: torch.Tensor,
diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py
index 0c87f5000ff4..52c9755e0e0e 100644
--- a/vllm/model_executor/models/plamo2.py
+++ b/vllm/model_executor/models/plamo2.py
@@ -4,10 +4,6 @@
 
 from collections.abc import Iterable
 from itertools import islice
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionBackend
 
 import torch
 from torch import nn
@@ -467,11 +463,6 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
     def mamba_type(self) -> str:
         return "mamba2"
 
-    def get_attn_backend(self) -> type["AttentionBackend"]:
-        from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
-
-        return Mamba2AttentionBackend
-
 
 def plamo2_mamba_mixer(
     hidden_states: torch.Tensor,
diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py
index 0415c8e00fdf..ad631f61e4b9 100644
--- a/vllm/model_executor/models/qwen3_next.py
+++ b/vllm/model_executor/models/qwen3_next.py
@@ -10,7 +10,7 @@
 from torch import nn
 from transformers.activations import ACT2FN
 
-from vllm.attention import Attention, AttentionBackend, AttentionMetadata
+from vllm.attention import Attention, AttentionMetadata
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (
     CacheConfig,
@@ -216,12 +216,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
     @property
     def mamba_type(self) -> str:
-        return "linear_attention"
-
-    def get_attn_backend(self) -> type["AttentionBackend"]:
-        from vllm.v1.attention.backends.gdn_attn import GDNAttentionBackend
-
-        return GDNAttentionBackend
+        return "gdn_attention"
 
     def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
         return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
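
To make the new flow concrete: with this change a mamba-like layer only declares its `mamba_type`, the inherited `MambaBase.get_attn_backend()` forwards that string to `get_mamba_attn_backend()`, and the registry resolves the backend class. The snippet below is a minimal sketch rather than part of this diff; the backend class name is a made-up placeholder, and it uses the `CUSTOM` slot together with the `"custom"` entry that `MAMBA_TYPE_TO_BACKEND_MAP` already defines.

```python
# Sketch only: an out-of-tree mamba attention backend plugged into the new
# registry. "MyMambaBackend" is a hypothetical placeholder class.
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.registry import (
    MambaAttentionBackendEnum,
    register_backend,
)
from vllm.attention.selector import get_mamba_attn_backend


@register_backend(MambaAttentionBackendEnum.CUSTOM, is_mamba=True)
class MyMambaBackend(AttentionBackend):
    """Hypothetical third-party mamba attention backend."""


# A layer whose `mamba_type` property returns "custom" now resolves to the
# registered class through the inherited MambaBase.get_attn_backend(),
# which simply calls get_mamba_attn_backend(self.mamba_type).
backend_cls = get_mamba_attn_backend("custom")
assert backend_cls is MyMambaBackend
```

Note that `_cached_get_mamba_attn_backend` is wrapped in `@cache`, so a custom backend should be registered before the first lookup of its mamba type.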