Skip to content

Commit 072d7e5

Browse files
Authored by Vadim Gimpelson
[PERF] Add conv1d metadata to GDN attn (#25105)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
1 parent 01a583f commit 072d7e5

File tree

5 files changed

+24
-8
lines changed

5 files changed

+24
-8
lines changed

vllm/model_executor/layers/mamba/mamba2_metadata.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
PlaceholderAttentionMetadata)
1212
from vllm.attention.backends.utils import PAD_SLOT_ID
1313
from vllm.platforms import current_platform
14+
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
1415
from vllm.v1.attention.backends.mamba2_attn import (
1516
Mamba2AttentionMetadata, _query_start_loc_to_chunk_indices_offsets)
1617

@@ -45,8 +46,8 @@ class Mamba2Metadata:
4546
"""
4647
nums_dict: Optional[dict] = None
4748
cu_seqlen: Optional[int] = None
48-
batch_ptr: Optional[torch.tensor] = None
49-
token_chunk_offset_ptr: Optional[torch.tensor] = None
49+
batch_ptr: Optional[torch.Tensor] = None
50+
token_chunk_offset_ptr: Optional[torch.Tensor] = None
5051

5152

5253
def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
@@ -117,7 +118,8 @@ def prepare_mamba2_metadata(
117118

118119
def update_metadata(x: torch.Tensor, query_start_loc: torch.Tensor,
119120
mamba2_metadata: Union[Mamba2Metadata,
120-
Mamba2AttentionMetadata]):
121+
Mamba2AttentionMetadata,
122+
GDNAttentionMetadata]):
121123
"""
122124
this is triggered upon handling a new input at the first layer
123125
"""

vllm/model_executor/models/qwen3_next.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
RowParallelLinear)
3636
from vllm.model_executor.layers.logits_processor import LogitsProcessor
3737
from vllm.model_executor.layers.mamba.abstract import MambaBase
38+
from vllm.model_executor.layers.mamba.mamba2_metadata import update_metadata
3839
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
3940
mamba_v2_sharded_weight_loader)
4041
from vllm.model_executor.layers.mamba.mamba_utils import (
@@ -414,6 +415,7 @@ def _forward(
414415

415416
assert isinstance(attn_metadata, dict)
416417
attn_metadata = attn_metadata[self.prefix]
418+
conv_metadata = attn_metadata
417419
assert isinstance(attn_metadata, GDNAttentionMetadata)
418420
has_initial_state = attn_metadata.has_initial_state
419421
spec_query_start_loc = attn_metadata.spec_query_start_loc
@@ -475,17 +477,23 @@ def _forward(
475477

476478
# 2.2: process the remaining part
477479
if attn_metadata.num_prefills > 0:
480+
mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1)
481+
if conv_metadata.cu_seqlen is None:
482+
conv_metadata = update_metadata(mixed_qkv_non_spec_T,
483+
non_spec_query_start_loc,
484+
conv_metadata)
478485
# - "cache_indices" updates the conv_state cache in positions
479486
# pointed to by "mamba_cache_params.state_indices_tensor"
480487
mixed_qkv_non_spec = causal_conv1d_fn(
481-
mixed_qkv_non_spec.transpose(0, 1),
488+
mixed_qkv_non_spec_T,
482489
conv_weights,
483490
self.conv1d.bias,
484491
activation=self.activation,
485492
conv_states=conv_state,
486493
has_initial_state=has_initial_state,
487494
cache_indices=non_spec_state_indices_tensor,
488495
query_start_loc=non_spec_query_start_loc,
496+
metadata=conv_metadata,
489497
).transpose(0, 1)
490498
elif attn_metadata.num_decodes > 0:
491499
mixed_qkv_non_spec = causal_conv1d_update(

vllm/v1/attention/backends/gdn_attn.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ class GDNAttentionMetadata:
5050
Tensor] = None # shape: [num_prefill_tokens + num_decode_tokens,]
5151
num_accepted_tokens: Optional[torch.Tensor] = None # shape: [batch,]
5252

53+
# The following attributes are for triton implementation of causal_conv1d
54+
nums_dict: Optional[dict] = None
55+
cu_seqlen: Optional[int] = None
56+
batch_ptr: Optional[torch.Tensor] = None
57+
token_chunk_offset_ptr: Optional[torch.Tensor] = None
58+
5359

5460
class GDNAttentionMetadataBuilder(
5561
AttentionMetadataBuilder[GDNAttentionMetadata]):

vllm/v1/attention/backends/mamba2_attn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,8 @@ class Mamba2AttentionMetadata:
132132
# The following attributes are for triton implementation of causal_conv1d
133133
nums_dict: Optional[dict] = None
134134
cu_seqlen: Optional[int] = None
135-
batch_ptr: Optional[torch.tensor] = None
136-
token_chunk_offset_ptr: Optional[torch.tensor] = None
135+
batch_ptr: Optional[torch.Tensor] = None
136+
token_chunk_offset_ptr: Optional[torch.Tensor] = None
137137

138138

139139
class Mamba2AttentionMetadataBuilder(

vllm/v1/attention/backends/short_conv_attn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ class ShortConvAttentionMetadata:
3434
# For causal_conv1d
3535
nums_dict: Optional[dict] = None
3636
cu_seqlen: Optional[int] = None
37-
batch_ptr: Optional[torch.tensor] = None
38-
token_chunk_offset_ptr: Optional[torch.tensor] = None
37+
batch_ptr: Optional[torch.Tensor] = None
38+
token_chunk_offset_ptr: Optional[torch.Tensor] = None
3939

4040

4141
class ShortConvAttentionMetadataBuilder(

0 commit comments

Comments (0)