From bc08785ae8fc9ff9d655f0d5a2e773be2d384aee Mon Sep 17 00:00:00 2001
From: Tyler Michael Smith
Date: Mon, 14 Oct 2024 14:17:13 +0000
Subject: [PATCH 1/4] [Bugfix] Clean up some cruft in mamba.py

---
 vllm/model_executor/models/mamba.py | 24 ------------------------
 1 file changed, 24 deletions(-)

diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py
index 1112a2181135..34cb4117ea4b 100644
--- a/vllm/model_executor/models/mamba.py
+++ b/vllm/model_executor/models/mamba.py
@@ -1,6 +1,5 @@
 # coding=utf-8
 """PyTorch MAMBA model."""
-from dataclasses import dataclass
 from typing import Iterable, List, Optional, Tuple

 import torch
@@ -39,13 +38,6 @@
 KVCache = Tuple[torch.Tensor, torch.Tensor]


-@dataclass
-class MambaCacheParams:
-    is_prompt: bool = False
-    conv_state: torch.Tensor = torch.Tensor()
-    ssm_state: torch.Tensor = torch.Tensor()
-
-
 # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
 class MambaMixer(nn.Module):
     """
@@ -346,18 +338,8 @@ def forward(


 class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
-    packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-    }
-
     # LoRA specific attributes
     supported_lora_modules = [
-        "qkv_proj",
-        "o_proj",
         "embed_tokens",
         "lm_head",
     ]
@@ -459,9 +441,6 @@ def sample(
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
@@ -474,9 +453,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             if "A_log" in name:
                 name = name.replace("A_log", "A")

-            if ".self_attn." in name:
-                name = name.replace(".self_attn", "")
-
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue

From 104281c72ac7ad68d183726f36dea8fd5c219b2c Mon Sep 17 00:00:00 2001
From: Tyler Michael Smith
Date: Mon, 14 Oct 2024 14:20:01 +0000
Subject: [PATCH 2/4] Another one

---
 vllm/model_executor/models/mamba.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py
index 34cb4117ea4b..2f1dcc9a7eb0 100644
--- a/vllm/model_executor/models/mamba.py
+++ b/vllm/model_executor/models/mamba.py
@@ -447,9 +447,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters())

         for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-
             if "A_log" in name:
                 name = name.replace("A_log", "A")

From e7a5bb7c02c999ca50717517ecddca8366ebaa9e Mon Sep 17 00:00:00 2001
From: Tyler Michael Smith
Date: Mon, 14 Oct 2024 18:44:44 +0000
Subject: =?UTF-8?q?=F0=9F=9A=AA?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 docs/source/models/supported_models.rst |   2 +-
 vllm/model_executor/models/mamba.py     | 102 ++++--------------------
 2 files changed, 17 insertions(+), 87 deletions(-)

diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index f5d53edcebd3..8d6e5b7649ba 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -155,7 +155,7 @@ Text Generation
   * - :code:`MambaForCausalLM`
     - Mamba
     - :code:`state-spaces/mamba-130m-hf`, :code:`state-spaces/mamba-790m-hf`, :code:`state-spaces/mamba-2.8b-hf`, etc.
-    - ✅︎
+    -
     -
   * - :code:`MiniCPMForCausalLM`
     - MiniCPM
diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py
index 2f1dcc9a7eb0..48ea1e34054c 100644
--- a/vllm/model_executor/models/mamba.py
+++ b/vllm/model_executor/models/mamba.py
@@ -9,7 +9,6 @@
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
@@ -201,37 +200,6 @@ def forward(self, hidden_states: torch.Tensor,
         return contextualized_states


-class MambaMLP(nn.Module):
-
-    def __init__(
-        self,
-        config: MambaConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-    ) -> None:
-        super().__init__()
-        hidden_size = config.hidden_size
-        intermediate_size = config.intermediate_size
-        hidden_act = config.hidden_act
-        self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config)
-        if hidden_act != "silu":
-            raise ValueError(f"Unsupported activation: {hidden_act}. "
-                             "Only silu is supported for now.")
-        self.act_fn = SiluAndMul()
-
-    def forward(self, x):
-        gate_up, _ = self.gate_up_proj(x)
-        x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
-        return x
-
-
 class MambaDecoderLayer(nn.Module):

     def __init__(self,
@@ -244,7 +212,6 @@ def __init__(self,
         self.config = config
         self.mixer = MambaMixer(config, layer_idx)
-        self.feed_forward = MambaMLP(config, quant_config=quant_config)
         self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
         self.pre_ff_layernorm = RMSNorm(config.hidden_size,
                                         eps=config.layer_norm_epsilon)
@@ -253,24 +220,16 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
-        residual: Optional[torch.Tensor],
         conv_state: torch.Tensor,
         ssm_state: torch.Tensor,
         **kwargs,
     ):
-        if residual is None:
-            residual = hidden_states
-            hidden_states = self.norm(hidden_states)
-        else:
-            hidden_states, residual = self.norm(hidden_states, residual)
-
+        residual = hidden_states
+        hidden_states = self.norm(hidden_states)
         hidden_states = self.mixer(hidden_states, attn_metadata, conv_state,
                                    ssm_state)
-        # Fully Connected
-        hidden_states, residual = self.pre_ff_layernorm(
-            hidden_states, residual)
-        hidden_states = self.feed_forward(hidden_states)
-        return hidden_states, residual
+        hidden_states = hidden_states + residual
+        return hidden_states


 class MambaModel(nn.Module):
@@ -311,7 +270,6 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         conv_state: torch.Tensor,
         ssm_state: torch.Tensor,
@@ -324,7 +282,7 @@ def forward(
             current_ssm_state = ssm_state[i]
             current_conv_state = conv_state[i]

-            hidden_states, residual = layer(
+            hidden_states = layer(
                 positions=positions,
                 hidden_states=hidden_states,
                 attn_metadata=attn_metadata,
@@ -332,22 +290,12 @@ def forward(
                 conv_state=current_conv_state,
                 ssm_state=current_ssm_state,
             )
-        hidden_states, residual = self.norm_f(hidden_states, residual)
+        hidden_states = self.norm_f(hidden_states)

         return hidden_states


 class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
-    # LoRA specific attributes
-    supported_lora_modules = [
-        "embed_tokens",
-        "lm_head",
-    ]
-    embedding_modules = {
-        "embeddings": "input_embeddings",
-        "lm_head": "output_embeddings",
-    }
-    embedding_padding_modules = ["lm_head"]

     def __init__(
         self,
@@ -398,8 +346,8 @@ def forward(self,
         mamba_cache_tensors = self.mamba_cache.current_run_tensors(
             input_ids, attn_metadata, **kwargs)

-        hidden_states = self.backbone(input_ids, positions, kv_caches,
-                                      attn_metadata, mamba_cache_tensors[0],
+        hidden_states = self.backbone(input_ids, positions, attn_metadata,
+                                      mamba_cache_tensors[0],
                                       mamba_cache_tensors[1])

         return hidden_states
@@ -439,34 +387,16 @@ def sample(
         return next_tokens

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             if "A_log" in name:
                 name = name.replace("A_log", "A")

-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
+
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)

From 5f220370c100121c5c5cf4fd9aa9286836421497 Mon Sep 17 00:00:00 2001
From: Tyler Michael Smith
Date: Mon, 14 Oct 2024 19:09:32 +0000
Subject: [PATCH 4/4] Revert to previous residual style

---
 vllm/model_executor/models/mamba.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py
index 48ea1e34054c..b86b687a9c36 100644
--- a/vllm/model_executor/models/mamba.py
+++ b/vllm/model_executor/models/mamba.py
@@ -220,16 +220,20 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
+        residual: Optional[torch.Tensor],
         conv_state: torch.Tensor,
         ssm_state: torch.Tensor,
         **kwargs,
     ):
-        residual = hidden_states
-        hidden_states = self.norm(hidden_states)
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.norm(hidden_states)
+        else:
+            hidden_states, residual = self.norm(hidden_states, residual)
+
         hidden_states = self.mixer(hidden_states, attn_metadata, conv_state,
                                    ssm_state)
-        hidden_states = hidden_states + residual
-        return hidden_states
+        return hidden_states, residual


 class MambaModel(nn.Module):
@@ -282,7 +286,7 @@ def forward(
             current_ssm_state = ssm_state[i]
             current_conv_state = conv_state[i]

-            hidden_states = layer(
+            hidden_states, residual = layer(
                 positions=positions,
                 hidden_states=hidden_states,
                 attn_metadata=attn_metadata,
@@ -290,7 +294,7 @@ def forward(
                 conv_state=current_conv_state,
                 ssm_state=current_ssm_state,
             )
-        hidden_states = self.norm_f(hidden_states)
+        hidden_states, _ = self.norm_f(hidden_states, residual)

         return hidden_states
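
The net effect of PATCH 3/4 and PATCH 4/4 is that MambaDecoderLayer keeps the fused-residual norm interface used elsewhere in vLLM: each layer takes an optional residual, returns (hidden_states, residual), and the model's final norm_f folds the last residual back in. Below is a minimal, self-contained sketch of that pattern in plain PyTorch. It is illustrative only and not part of the patches; the names RMSNormWithResidual and ToyMixerLayer are stand-ins, not vLLM classes, and the placeholder Linear mixer stands in for the real MambaMixer.

# Illustrative sketch of the fused-residual style restored by PATCH 4/4.
# Not part of the patches; class names are hypothetical stand-ins.
from typing import Optional, Tuple

import torch
import torch.nn as nn


class RMSNormWithResidual(nn.Module):
    """RMSNorm that can optionally fold a residual add into the call."""

    def __init__(self, hidden_size: int, eps: float = 1e-5) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(variance + self.eps) * self.weight

    def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None):
        if residual is None:
            # Start of the residual stream: plain normalization.
            return self._norm(x)
        # Fused path: add the residual first, return both the normalized
        # activations and the updated residual stream.
        x = x + residual
        return self._norm(x), x


class ToyMixerLayer(nn.Module):
    """Stand-in for a decoder layer with the (hidden_states, residual) API."""

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.norm = RMSNormWithResidual(hidden_size)
        self.mixer = nn.Linear(hidden_size, hidden_size)  # placeholder mixer

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            residual = hidden_states
            hidden_states = self.norm(hidden_states)
        else:
            hidden_states, residual = self.norm(hidden_states, residual)
        hidden_states = self.mixer(hidden_states)
        return hidden_states, residual


if __name__ == "__main__":
    torch.manual_seed(0)
    layers = nn.ModuleList([ToyMixerLayer(16) for _ in range(3)])
    norm_f = RMSNormWithResidual(16)

    hidden_states = torch.randn(2, 4, 16)
    residual = None
    for layer in layers:
        hidden_states, residual = layer(hidden_states, residual)
    # Final norm folds in the last residual, mirroring norm_f in the patch.
    hidden_states, _ = norm_f(hidden_states, residual)
    print(hidden_states.shape)  # torch.Size([2, 4, 16])

The design choice this illustrates: deferring the residual add into the next layer's norm keeps one fewer elementwise pass per layer and lets a fused norm kernel do the add, which is why the layer returns the pair instead of summing before returning.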