
Conversation

@shen-shanshan
Contributor

@shen-shanshan shen-shanshan commented Oct 9, 2025

Purpose

Motivation:

Like the attention and MLA attention modules, we want to use device-specific kernels for mamba layers and customize the processing of the mamba attention backend. This is important for running mamba-like models (e.g., Qwen3-Next) on the Ascend platform.

Main changes:

  • Add get_mamba_attn_backend() to the attention selector, and make all mamba layers obtain their attention backend by calling this method.
  • Add get_mamba_attn_backend_cls() to the platform interface, so that devices other than GPU can customize their mamba attention backend.

Backend selection priority:

  1. Select the device-specific mamba backend according to mamba_type. If the platform provides no customization, fall back to the default logic.
  2. Get the default backend for mamba_type from mamba_type_to_backend_map.

Mamba layer:

Each mamba layer passes its mamba_type to get_mamba_attn_backend() at initialization to obtain its backend.
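
A minimal sketch of this selection flow (the backend class paths are the in-tree defaults quoted later in this PR; the mamba_type keys, the helper names, and the explicit platform argument are simplifications for illustration):

import importlib

# Default mapping from mamba_type to a backend class path.
mamba_type_to_backend_map = {
    "mamba1": "vllm.v1.attention.backends.mamba1_attn.Mamba1AttentionBackend",
    "mamba2": "vllm.v1.attention.backends.mamba2_attn.Mamba2AttentionBackend",
    "linear_attention": "vllm.v1.attention.backends.linear_attn.LinearAttentionBackend",
}

def _resolve_cls(qualname: str):
    """Turn 'pkg.module.ClassName' into the class object."""
    module_name, cls_name = qualname.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), cls_name)

def get_mamba_attn_backend(mamba_type: str, platform):
    # 1. Device-specific override: the platform (e.g. Ascend) can return
    #    its own backend class path; an empty result means no override.
    backend_path = platform.get_mamba_attn_backend_cls(mamba_type)
    if not backend_path:
        # 2. Fall back to the default backend for this mamba_type.
        backend_path = mamba_type_to_backend_map[mamba_type]
    return _resolve_cls(backend_path)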


Update 2025/10/28:

Add _MambaBackend and MAMBA_BACKEND_MAP in registry.py.
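
One plausible shape for these two definitions (the member names and class paths come from the snippet quoted in the review below; the exact split between the enum and the map in the final code may differ):

import enum

class _MambaBackend(enum.Enum):
    MAMBA1 = enum.auto()
    MAMBA2 = enum.auto()
    LINEAR = enum.auto()
    GDN = enum.auto()

# Maps each backend to its fully qualified class path.
MAMBA_BACKEND_MAP: dict[_MambaBackend, str] = {
    _MambaBackend.MAMBA1: "vllm.v1.attention.backends.mamba1_attn.Mamba1AttentionBackend",
    _MambaBackend.MAMBA2: "vllm.v1.attention.backends.mamba2_attn.Mamba2AttentionBackend",
    _MambaBackend.LINEAR: "vllm.v1.attention.backends.linear_attn.LinearAttentionBackend",
    _MambaBackend.GDN: "vllm.v1.attention.backends.gdn_attn.GDNAttentionBackend",
}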


Update 2025/11/10:

Add another argument, linear_attn_type, which together with mamba_type specifies special linear attention backends, e.g., GDNAttention.

linear_attn_type is optional:

  • None: use the normal linear attention backend or another mamba backend.
  • not None: combine mamba_type and linear_attn_type to select the related backend (see the sketch below).
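
A rough sketch of how the two arguments could combine (the key construction is an illustrative assumption; mamba_type_to_backend_map and _resolve_cls refer to the sketch above):

def get_mamba_attn_backend(mamba_type: str, linear_attn_type: str | None = None):
    if linear_attn_type is None:
        # Plain linear attention or another mamba backend.
        key = mamba_type
    else:
        # Special linear attention variants, e.g. GDNAttention.
        key = f"{mamba_type}_{linear_attn_type}"
    backend_path = mamba_type_to_backend_map[key]
    return _resolve_cls(backend_path)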

Other changes:

  • Refactor KimiDeltaAttention to use get_mamba_attn_backend().
  • Update the documentation on adding a new mamba-type layer.

Update 2025/11/17:

Refactor following #24794.


Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@shen-shanshan shen-shanshan marked this pull request as draft October 9, 2025 12:40
@mergify mergify bot added the qwen Related to Qwen models label Oct 9, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a pluggable selector for Mamba attention backends, refactoring the model layers to use this new centralized mechanism instead of hardcoded imports. This is a good architectural improvement for modularity. My review focuses on the robustness of the new selector. I've identified a potential issue in the error handling within vllm/attention/selector.py where the check for a valid backend class could be more robust and the error message more informative. I've provided a suggestion to address this.

@shen-shanshan shen-shanshan changed the title [Refactor][Mamba] Add selector for mamba attention backend and make it pluggable for other device [Mamba] Add selector for mamba attention backend and make it pluggable for other device Oct 10, 2025
@shen-shanshan shen-shanshan marked this pull request as ready for review October 10, 2025 07:47
mamba_type: str = "",
) -> str:
"""Get mamba attention backend class of a device."""
mamba_type_to_backend_map = {
Contributor


Could you add a _MambaBackend enum and MAMBA_BACKEND_MAP in registry.py and use these instead?

Contributor Author


OK

Contributor Author


@MatthewBonanni Hello, could you please ask one of the maintainers whether anything else needs to be updated, or whether this PR can be merged?

@mergify mergify bot added the new-model Requests to new models label Oct 28, 2025
Member

@tdoublep tdoublep left a comment


Sorry for the delay in reviewing. I think this change looks fine - have some minor questions + suggestions.

MAMBA1 = "vllm.v1.attention.backends.mamba1_attn.Mamba1AttentionBackend"
MAMBA2 = "vllm.v1.attention.backends.mamba2_attn.Mamba2AttentionBackend"
LINEAR = "vllm.v1.attention.backends.linear_attn.LinearAttentionBackend"
GDN = "vllm.v1.attention.backends.gdn_attn.GDNAttentionBackend"
Member


I think we also need to handle the Kimi Linear case here.

Contributor Author


I think we also need to handle the Kimi Linear case here.

Thanks for your suggestion! I will add it later.

@mergify

mergify bot commented Nov 10, 2025

Documentation preview: https://vllm--26487.org.readthedocs.build/en/26487/

@mergify mergify bot added the documentation Improvements or additions to documentation label Nov 10, 2025
@shen-shanshan
Contributor Author

CC @tdoublep @MatthewBonanni

I have updated the description in the Purpose section of this PR.

@LucasWilkinson
Collaborator

LucasWilkinson commented Nov 10, 2025

Not super familiar with this area, but is _MambaBackend the best name here, given the diverse set of implementations being added? Would something like _LinearBackend, _LinearAttentionBackend, or _SSMAttentionBackend be better? cc @tdoublep ?

@tdoublep
Member

@LucasWilkinson I believe _MambaBackend is consistent with the rest of the codebase right now. We use "mamba" as a catch-all for mamba + linear attention mechanisms.

@shen-shanshan shen-shanshan changed the title [Mamba] Add selector for mamba attention backend and make it pluggable for other device [Model][Mamba] Add selector for mamba attention backend and make it pluggable for other device Nov 11, 2025
@MatthewBonanni
Contributor

MatthewBonanni commented Nov 11, 2025

Now that #24794 has landed, could we refactor this PR to use the pattern from registry.py? i.e. _MambaBackendEnumMeta and MambaBackendEnum instead of _MambaBackend? You can make a separate _MAMBA_OVERRIDES too.

@shen-shanshan
Contributor Author

Now that #24794 has landed, could we refactor this PR to use the pattern from registry.py? i.e. _MambaBackendEnumMeta and MambaBackendEnum instead of _MambaBackend? You can make a separate _MAMBA_OVERRIDES too.

OK, I will update it soon.

mamba_type: str,
linear_attn_type: str | None,
) -> str:
"""Get mamba attention backend class of a device."""
Contributor


Please expand the docstring here to describe how the arguments are used. Thanks.
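
For example, something along these lines might cover it (the wording is only a suggestion, assuming this is the platform hook named in the PR description):

def get_mamba_attn_backend_cls(
    mamba_type: str,
    linear_attn_type: str | None,
) -> str:
    """Get the mamba attention backend class of a device.

    Args:
        mamba_type: The kind of mamba layer requesting a backend
            (e.g. a Mamba1/Mamba2 mixer or a linear attention layer).
        linear_attn_type: Optional refinement used to pick special
            linear attention backends such as GDNAttention; None means
            the plain backend for mamba_type is used.

    Returns:
        The fully qualified class path of the backend to use.
    """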

from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend

return LinearAttentionBackend
return self.mamba_attn_backend
Contributor


I notice that all of the get_attn_backend implementations are the same; why not implement it in the MambaBase class?

raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self

self.mamba_attn_backend = get_mamba_attn_backend(self.mamba_type)
Contributor


self.mamba_attn_backend is only used by get_attn_backend, so why not let get_attn_backend call get_mamba_attn_backend directly? The self.mamba_attn_backend attribute looks unnecessary.
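
To illustrate the two suggestions above, a rough sketch of resolving the backend in the shared base class instead of caching it on every layer (the class layout is assumed; get_mamba_attn_backend is the selector added in this PR):

class MambaBase:
    mamba_type: str  # set by each concrete mamba layer

    def get_attn_backend(self):
        # No cached self.mamba_attn_backend: ask the selector directly.
        return get_mamba_attn_backend(self.mamba_type)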

@shen-shanshan
Contributor Author

CC @tdoublep @MatthewBonanni I have updated the code following #24794.

@MatthewBonanni
Contributor

LGTM!

@Yikun Yikun added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 17, 2025
Collaborator

@jikunshang jikunshang left a comment


Thanks for the refactor!

@shen-shanshan
Contributor Author

@jikunshang Hello, the CI is broken, possibly due to something unrelated to this PR. Could you please help retrigger it? Thanks.

@jikunshang
Collaborator

The failed case is irrelevant. I think it's unnecessary to retrigger the full test suite, to save CI resources; we can request a force merge instead.
cc @DarkLight1337 Please also take a look, thanks!

@gcanlin
Contributor

gcanlin commented Nov 19, 2025

The CI breakage has been fixed by #28908. Please try merging the main branch into this PR.

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) November 19, 2025 05:46
@shen-shanshan
Contributor Author

shen-shanshan commented Nov 19, 2025

All the CI failures are due to the same error shown below:

RuntimeError: This flash attention build does not support headdim not being a multiple of 32.

@gcanlin
Contributor

gcanlin commented Nov 19, 2025

All the CI failures are due to the same error shown below:

RuntimeError: This flash attention build does not support headdim not being a multiple of 32.

I also ran into this issue locally.

@tdoublep tdoublep self-requested a review November 19, 2025 13:57
Member

@tdoublep tdoublep left a comment


Great work - thank you

@DarkLight1337 DarkLight1337 merged commit d44e9df into vllm-project:main Nov 19, 2025
54 checks passed
Victor49152 pushed a commit to Victor49152/vllm that referenced this pull request Nov 20, 2025
…luggable for other device (vllm-project#26487)

Signed-off-by: shen-shanshan <467638484@qq.com>
LuminolT pushed a commit to LuminolT/vllm that referenced this pull request Nov 21, 2025
…luggable for other device (vllm-project#26487)

Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: LuminolT <lumischen01@gmail.com>
bigPYJ1151 pushed a commit that referenced this pull request Nov 25, 2025
…luggable for other device (#26487)

Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: jiang1.li <jiang1.li@intel.com>
bringlein pushed a commit to bringlein/vllm that referenced this pull request Nov 26, 2025
…luggable for other device (vllm-project#26487)

Signed-off-by: shen-shanshan <467638484@qq.com>

Labels

  • documentation: Improvements or additions to documentation
  • new-model: Requests to new models
  • qwen: Related to Qwen models
  • ready: ONLY add when PR is ready to merge/full CI is needed
