diff --git a/vllm/model_executor/models/midashenglm.py b/vllm/model_executor/models/midashenglm.py
index 0bf04e0e7e2f..0b62fbd40b07 100644
--- a/vllm/model_executor/models/midashenglm.py
+++ b/vllm/model_executor/models/midashenglm.py
@@ -22,6 +22,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only MiDashengLM model compatible with HuggingFace weights."""
+
 import collections
 import collections.abc
 from collections.abc import Iterable, Mapping, Sequence
@@ -30,10 +31,10 @@
 import numpy as np
 import torch
 import torch.nn as nn
-import torchaudio.transforms as audio_transforms
+import torchaudio.functional as F
+from torch.nn.functional import scaled_dot_product_attention
 from transformers import BatchFeature
 
-from vllm.attention.layer import MultiHeadAttention
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -41,7 +42,6 @@
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems)
@@ -147,15 +147,19 @@ def __init__(
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
-        self.fc1 = ColumnParallelLinear(input_size=in_features,
-                                        output_size=hidden_features,
-                                        quant_config=quant_config,
-                                        prefix=f"{prefix}.fc1")
+        self.fc1 = ColumnParallelLinear(
+            input_size=in_features,
+            output_size=hidden_features,
+            quant_config=quant_config,
+            prefix=f"{prefix}.fc1",
+        )
         self.act = get_act_fn("gelu")
-        self.fc2 = RowParallelLinear(input_size=hidden_features,
-                                     output_size=out_features,
-                                     quant_config=quant_config,
-                                     prefix=f"{prefix}.fc2")
+        self.fc2 = RowParallelLinear(
+            input_size=hidden_features,
+            output_size=out_features,
+            quant_config=quant_config,
+            prefix=f"{prefix}.fc2",
+        )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x, _ = self.fc1(x)
@@ -171,7 +175,6 @@ def __init__(
         dim: int,
         num_heads: int = 8,
         qkv_bias: bool = False,
-        causal: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -205,33 +208,30 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.qkv",
         )
-        self.attn = MultiHeadAttention(
-            self.num_heads,
-            self.head_dim,
-            self.scale,
-            num_kv_heads=self.num_kv_heads,
-        )
         self.proj = RowParallelLinear(
             input_size=dim,
             output_size=dim,
             quant_config=quant_config,
             prefix=f"{prefix}.proj",
         )
-        self.causal = causal
 
     def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
         B, N, C = x.shape
-        qkv_out, _ = self.qkv(x)
-        q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size],
-                                dim=-1)
+        qkv, _ = self.qkv(x)
+        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
+        qkv = qkv.permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(0)
 
-        attn_out = self.attn(q, k, v)
-        C_local = attn_out.numel() // (B * N)  # C_local for parallel
-        attn_out = attn_out.view(B, N, C_local)
-
-        x, _ = self.proj(attn_out)
+        x = scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            attn_mask=mask[:, None, None, :] if mask is not None else None,
+        )
+        x = x.transpose(1, 2).reshape(B, N, C)
+        x, _ = self.proj(x)
         return x
 
 
@@ -280,6 +280,63 @@
         return x
 
 
+class DashengFrontend(nn.Module):
+
+    def __init__(self, config: DashengConfig):
+        super().__init__()
+        self.config = config
+
+        spectrogram_window = torch.hann_window(self.config.win_length)
+        self.register_buffer(
+            "spectrogram_window",
+            spectrogram_window,
+            persistent=False,
+        )
+        self.spectrogram_window: torch.Tensor
+
+        melscale_fbanks = F.melscale_fbanks(
+            n_freqs=self.config.n_fft // 2 + 1,
+            f_min=self.config.f_min,
+            f_max=self.config.f_max,
+            n_mels=self.config.n_mels,
+            sample_rate=self.config.sample_rate,
+        )
+        self.register_buffer("melscale_fbanks",
+                             melscale_fbanks,
+                             persistent=False)
+        self.melscale_fbanks: torch.Tensor
+
+    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+        spectrogram = F.spectrogram(
+            waveform=waveform.to(torch.float32),
+            pad=0,
+            window=self.spectrogram_window,
+            n_fft=self.config.n_fft,
+            hop_length=self.config.hop_length,
+            win_length=self.config.win_length,
+            power=2,
+            normalized=False,
+            center=self.config.center,
+        )
+        mel_spectrogram = (
+            spectrogram.mT @ self.melscale_fbanks.to(torch.float32)).mT
+        # x has shape [batch, freq, time].
+        # F.amplitude_to_DB accepts inputs shaped as:
+        #  - [freq, time]
+        #  - [channel, freq, time]
+        #  - [..., channel, freq, time]
+        # Here we insert a channel dimension of size 1 before calling it,
+        # then remove that extra dimension afterward.
+        log_mel_spectrogram = F.amplitude_to_DB(
+            mel_spectrogram.unsqueeze(1),
+            multiplier=10,
+            amin=1e-10,
+            db_multiplier=0,
+            top_db=120,
+        ).squeeze(1)
+        return log_mel_spectrogram.to(waveform.dtype)
+
+
 class DashengAudioTransformer(nn.Module):
 
     def __init__(
@@ -293,7 +350,7 @@
         super().__init__()
         self.target_length = config.target_length
         self.hop_length = config.hop_length
-        self._init_front_end(config)
+        self.front_end = DashengFrontend(config)
 
         self.init_bn = nn.BatchNorm2d(config.n_mels, momentum=0.01)
 
@@ -318,34 +375,10 @@
                 qkv_bias=config.qkv_bias,
                 init_values=config.init_values,
                 quant_config=quant_config,
-                prefix=f"{prefix}.block{i}",
+                prefix=f"{prefix}.blocks.{i}",
             ) for i in range(config.depth))
         self.norm = nn.LayerNorm(config.embed_dim, eps=1e-6)
 
-    def _init_front_end(self, config):
-        with set_default_torch_dtype(torch.float32):
-            self.front_end = nn.Sequential(
-                audio_transforms.MelSpectrogram(
-                    f_min=config.f_min,
-                    f_max=config.f_max,
-                    center=config.center,
-                    win_length=config.win_length,
-                    hop_length=config.hop_length,
-                    sample_rate=config.sample_rate,
-                    n_fft=config.n_fft,
-                    n_mels=config.n_mels,
-                ),
-                audio_transforms.AmplitudeToDB(top_db=120),
-            )
-
-        mel_spectrogram = self.front_end[0]
-        fb = mel_spectrogram.mel_scale.fb
-        win = mel_spectrogram.spectrogram.window
-        mel_spectrogram.mel_scale.fb = fb.to(torch.bfloat16).to(
-            torch.float32)
-        mel_spectrogram.spectrogram.window = win.to(torch.bfloat16).to(
-            torch.float32)
-
     def forward_features(
         self,
         x: torch.Tensor,
@@ -430,14 +463,16 @@ def __init__(
                 quant_config=quant_config,
                 prefix=f"{prefix}.net.0",
                 return_bias=False,
-            ), get_act_fn("gelu"),
+            ),
+            get_act_fn("gelu"),
             RowParallelLinear(
                 input_size=out_dim,
                 output_size=out_dim,
                 quant_config=quant_config,
                 prefix=f"{prefix}.net.2",
                 return_bias=False,
-            ))
+            ),
+        )
 
     def forward(self, x, mask=None):
         batch_size, seq_len, dim = x.shape
@@ -534,9 +569,12 @@ def _call_hf_processor(
         #   + Padding
         min_audio_len = self.info.get_min_audio_len()
         processed_audios = [
-            np.pad(audio, (0, min_audio_len - audio.shape[-1]),
-                   mode='constant',
-                   constant_values=0) if isinstance(audio, np.ndarray)
+            np.pad(
+                audio,
+                (0, min_audio_len - audio.shape[-1]),
+                mode="constant",
+                constant_values=0,
+            ) if isinstance(audio, np.ndarray)
             and audio.shape[-1] < min_audio_len else audio
             for audio in audios
         ]
@@ -585,8 +623,8 @@ def _get_prompt_updates(
         if audio_length is None:
             audio_output_lengths = []
         else:
-            audio_length_np = audio_length.cpu().numpy() if isinstance(
-                audio_length, torch.Tensor) else audio_length
+            audio_length_np = (audio_length.cpu().numpy() if isinstance(
+                audio_length, torch.Tensor) else audio_length)
             audio_output_lengths = [
                 max(1, calculate_mel_frames_dasheng(
                     int(length)))  # at least one frame
@@ -617,6 +655,17 @@ def get_replacement_midashenglm(item_idx: int):
     dummy_inputs=MiDashengLMDummyInputsBuilder,
 )
 class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
 
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
@@ -660,8 +709,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
     def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                         name: str) -> torch.Tensor:
         if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
+            raise ValueError(
+                f"Incorrect type of {name}. Got type: {type(mm_input)}")
         if isinstance(mm_input, torch.Tensor):
             return mm_input.reshape(-1, *mm_input.shape[2:])
 
@@ -710,8 +759,8 @@ def _process_audio_input(
             audio_input["input_values"].dtype)
         batch_size, max_audio_tokens, embed_dim = audio_embeddings.shape
 
-        audio_length_np = audio_length.cpu().numpy() if isinstance(
-            audio_length, torch.Tensor) else audio_length
+        audio_length_np = (audio_length.cpu().numpy() if isinstance(
+            audio_length, torch.Tensor) else audio_length)
         audio_output_lengths = [
             max(1, calculate_mel_frames_dasheng(
                 int(length)))  # at least one frame
@@ -720,11 +769,11 @@
         audio_output_lengths = torch.tensor(audio_output_lengths).to(
             audio_embeddings.device)
 
-        audio_feature_mask = (torch.arange(
+        audio_feature_mask = torch.arange(
            max_audio_tokens,
            device=audio_embeddings.device).unsqueeze(0).expand(
-                batch_size, max_audio_tokens)
-                              < audio_output_lengths.unsqueeze(1))
+               batch_size,
+               max_audio_tokens) < audio_output_lengths.unsqueeze(1)
 
         masked_audio_features = audio_embeddings[audio_feature_mask].view(
             -1, embed_dim)
@@ -762,10 +811,12 @@ def forward(
             )
             input_ids = None
 
-        return self.decoder.model(input_ids,
-                                  positions,
-                                  intermediate_tensors,
-                                  inputs_embeds=inputs_embeds)
+        return self.decoder.model(
+            input_ids,
+            positions,
+            intermediate_tensors,
+            inputs_embeds=inputs_embeds,
+        )
 
     def compute_logits(
         self,
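
Note (reviewer sketch, not part of the patch): DashengFrontend re-expresses the removed torchaudio.transforms frontend with torchaudio.functional primitives. The standalone snippet below is one way to sanity-check that the functional path reproduces a plain MelSpectrogram + AmplitudeToDB pipeline. The config values are illustrative placeholders rather than the real DashengConfig defaults, and the bfloat16 round-trip of the filterbank/window performed by the removed _init_front_end is deliberately not modeled.

# Standalone sanity check, not part of the patch. Assumed, illustrative
# config values; swap in the model's real DashengConfig fields to test
# the exact configuration.
import torch
import torchaudio.functional as F
import torchaudio.transforms as audio_transforms

sample_rate, n_fft, win_length, hop_length = 16000, 512, 512, 160
n_mels, f_min, f_max, center = 64, 0.0, 8000.0, True

waveform = torch.randn(1, sample_rate)  # [batch, time], 1 s of noise

# Transforms-based frontend (the shape the removed code had).
old_front_end = torch.nn.Sequential(
    audio_transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        f_min=f_min,
        f_max=f_max,
        n_mels=n_mels,
        center=center,
    ),
    audio_transforms.AmplitudeToDB(top_db=120),
)
expected = old_front_end(waveform)

# Functional frontend: the same computation spelled out step by step,
# mirroring DashengFrontend.forward.
window = torch.hann_window(win_length)
fbanks = F.melscale_fbanks(
    n_freqs=n_fft // 2 + 1,
    f_min=f_min,
    f_max=f_max,
    n_mels=n_mels,
    sample_rate=sample_rate,
)
spec = F.spectrogram(
    waveform=waveform,
    pad=0,
    window=window,
    n_fft=n_fft,
    hop_length=hop_length,
    win_length=win_length,
    power=2,
    normalized=False,
    center=center,
)
mel = (spec.mT @ fbanks).mT  # [batch, n_mels, time]
actual = F.amplitude_to_DB(
    mel.unsqueeze(1),  # temporary channel dim for amplitude_to_DB
    multiplier=10,
    amin=1e-10,
    db_multiplier=0,
    top_db=120,
).squeeze(1)

torch.testing.assert_close(actual, expected)
print("functional frontend matches transforms frontend")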