2 changes: 1 addition & 1 deletion docs/source/models/supported_models.rst
@@ -155,7 +155,7 @@ Text Generation
* - :code:`MambaForCausalLM`
- Mamba
- :code:`state-spaces/mamba-130m-hf`, :code:`state-spaces/mamba-790m-hf`, :code:`state-spaces/mamba-2.8b-hf`, etc.
- ✅︎
-
-
* - :code:`MiniCPMForCausalLM`
- MiniCPM
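The table edit above appears to clear the LoRA column for :code:`MambaForCausalLM`, matching the removal of the unused LoRA attributes from the model class later in this PR. Running one of the listed checkpoints is unaffected; a minimal smoke test, assuming the standard vLLM offline API:

    from vllm import LLM, SamplingParams

    # Any checkpoint from the table above; the 130m variant is the cheapest
    # to download for a quick check.
    llm = LLM(model="state-spaces/mamba-130m-hf")
    params = SamplingParams(temperature=0.0, max_tokens=16)
    outputs = llm.generate(["The state-space model Mamba is"], params)
    print(outputs[0].outputs[0].text)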
113 changes: 10 additions & 103 deletions vllm/model_executor/models/mamba.py
@@ -1,6 +1,5 @@
# coding=utf-8
"""PyTorch MAMBA model."""
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple

import torch
@@ -10,7 +9,6 @@
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
@@ -39,13 +37,6 @@
KVCache = Tuple[torch.Tensor, torch.Tensor]


@dataclass
class MambaCacheParams:
is_prompt: bool = False
conv_state: torch.Tensor = torch.Tensor()
ssm_state: torch.Tensor = torch.Tensor()


# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
class MambaMixer(nn.Module):
"""
Expand Down Expand Up @@ -209,37 +200,6 @@ def forward(self, hidden_states: torch.Tensor,
return contextualized_states


class MambaMLP(nn.Module):

def __init__(
self,
config: MambaConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
hidden_size = config.hidden_size
intermediate_size = config.intermediate_size
hidden_act = config.hidden_act
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()

def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x


class MambaDecoderLayer(nn.Module):

def __init__(self,
@@ -252,7 +212,6 @@ def __init__(self,
self.config = config
self.mixer = MambaMixer(config, layer_idx)

self.feed_forward = MambaMLP(config, quant_config=quant_config)
self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.pre_ff_layernorm = RMSNorm(config.hidden_size,
eps=config.layer_norm_epsilon)
@@ -274,10 +233,6 @@ def forward(

hidden_states = self.mixer(hidden_states, attn_metadata, conv_state,
ssm_state)
# Fully Connected
hidden_states, residual = self.pre_ff_layernorm(
hidden_states, residual)
hidden_states = self.feed_forward(hidden_states)
return hidden_states, residual
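With MambaMLP and the pre-FF layernorm call gone, each decoder layer reduces to a single RMSNorm followed by the mixer. A sketch of roughly how the slimmed-down layer forward reads after this change, reconstructed from the visible context (the exact signature may differ):

    def forward(self, hidden_states, attn_metadata, residual,
                conv_state, ssm_state, **kwargs):
        # Pre-norm with fused residual handling, as in other vLLM decoder layers.
        if residual is None:
            residual = hidden_states
            hidden_states = self.norm(hidden_states)
        else:
            hidden_states, residual = self.norm(hidden_states, residual)
        # The selective-SSM mixer is the only sublayer left.
        hidden_states = self.mixer(hidden_states, attn_metadata, conv_state,
                                   ssm_state)
        return hidden_states, residual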


@@ -319,7 +274,6 @@ def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
conv_state: torch.Tensor,
ssm_state: torch.Tensor,
@@ -346,26 +300,6 @@


class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
}

# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"embed_tokens",
"lm_head",
]
embedding_modules = {
"embeddings": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]

def __init__(
self,
@@ -416,8 +350,8 @@ def forward(self,
mamba_cache_tensors = self.mamba_cache.current_run_tensors(
input_ids, attn_metadata, **kwargs)

hidden_states = self.backbone(input_ids, positions, kv_caches,
attn_metadata, mamba_cache_tensors[0],
hidden_states = self.backbone(input_ids, positions, attn_metadata,
mamba_cache_tensors[0],
mamba_cache_tensors[1])

return hidden_states
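Because the model is attention-free, kv_caches no longer needs to be threaded into the backbone; the recurrent state arrives as the two tensors returned by the Mamba cache manager. A sketch of how that tuple is consumed, with conv_state and ssm_state as illustrative local names (their shapes are owned by MambaCacheManager and not assumed here):

    # Unpack the per-run cache tensors: convolution state and SSM state.
    conv_state, ssm_state = mamba_cache_tensors
    hidden_states = self.backbone(input_ids, positions, attn_metadata,
                                  conv_state, ssm_state)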
Expand Down Expand Up @@ -457,43 +391,16 @@ def sample(
return next_tokens

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]

params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue

if "A_log" in name:
name = name.replace("A_log", "A")

if ".self_attn." in name:
name = name.replace(".self_attn", "")

for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue

param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue

param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
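With no attention or gated-MLP weights to fuse, the stacked-parameter mapping (qkv_proj, gate_up_proj) is dead code and the loader collapses to the generic per-parameter path. A sketch of the full simplified load_weights, assuming the A_log to A rename from the surrounding context is kept (HF Mamba checkpoints store the state matrix in log space):

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            # HF checkpoints name the state matrix "A_log"; vLLM's parameter is "A".
            if "A_log" in name:
                name = name.replace("A_log", "A")
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)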