193 changes: 122 additions & 71 deletions vllm/model_executor/models/midashenglm.py
@@ -22,6 +22,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiDashengLM model compatible with HuggingFace weights."""

import collections
import collections.abc
from collections.abc import Iterable, Mapping, Sequence
@@ -30,18 +31,17 @@
import numpy as np
import torch
import torch.nn as nn
import torchaudio.transforms as audio_transforms
import torchaudio.functional as F
from torch.nn.functional import scaled_dot_product_attention
from transformers import BatchFeature

from vllm.attention.layer import MultiHeadAttention
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems)
@@ -147,15 +147,19 @@ def __init__(
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = ColumnParallelLinear(input_size=in_features,
output_size=hidden_features,
quant_config=quant_config,
prefix=f"{prefix}.fc1")
self.fc1 = ColumnParallelLinear(
input_size=in_features,
output_size=hidden_features,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
)
self.act = get_act_fn("gelu")
self.fc2 = RowParallelLinear(input_size=hidden_features,
output_size=out_features,
quant_config=quant_config,
prefix=f"{prefix}.fc2")
self.fc2 = RowParallelLinear(
input_size=hidden_features,
output_size=out_features,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x, _ = self.fc1(x)
@@ -171,7 +175,6 @@ def __init__(
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
causal: bool = False,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
@@ -205,33 +208,30 @@ def __init__(
quant_config=quant_config,
prefix=f"{prefix}.qkv",
)
self.attn = MultiHeadAttention(
self.num_heads,
self.head_dim,
self.scale,
num_kv_heads=self.num_kv_heads,
)
self.proj = RowParallelLinear(
input_size=dim,
output_size=dim,
quant_config=quant_config,
prefix=f"{prefix}.proj",
)
self.causal = causal

def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
B, N, C = x.shape

qkv_out, _ = self.qkv(x)
q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
qkv, _ = self.qkv(x)
qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)

attn_out = self.attn(q, k, v)
C_local = attn_out.numel() // (B * N) # C_local for parallel
attn_out = attn_out.view(B, N, C_local)

x, _ = self.proj(attn_out)
x = scaled_dot_product_attention(
q,
k,
v,
attn_mask=mask[:, None, None, :] if mask is not None else None,
)

x = x.transpose(1, 2).reshape(B, N, C)
x, _ = self.proj(x)
return x
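
For context, the rewritten forward replaces vLLM's MultiHeadAttention layer with torch's scaled_dot_product_attention plus an optional key-padding mask broadcast to [B, 1, 1, N]. A minimal sketch of that pattern in isolation (shapes and lengths below are illustrative, not taken from the model):

import torch
from torch.nn.functional import scaled_dot_product_attention

B, N, num_heads, head_dim = 2, 6, 4, 16            # illustrative sizes
q = torch.randn(B, num_heads, N, head_dim)
k = torch.randn(B, num_heads, N, head_dim)
v = torch.randn(B, num_heads, N, head_dim)

# Boolean key-padding mask: True = attend, False = padded position.
lengths = torch.tensor([6, 4])
key_mask = torch.arange(N)[None, :] < lengths[:, None]        # [B, N]

# Broadcast to [B, 1, 1, N] so the same mask applies to every head and
# query position, matching mask[:, None, None, :] in the forward above.
out = scaled_dot_product_attention(q, k, v, attn_mask=key_mask[:, None, None, :])
print(out.shape)   # torch.Size([2, 4, 6, 16])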


@@ -280,6 +280,63 @@ def forward(
return x


class DashengFrontend(nn.Module):

def __init__(self, config: DashengConfig):
super().__init__()
self.config = config

spectrogram_window = torch.hann_window(self.config.win_length)
self.register_buffer(
"spectrogram_window",
spectrogram_window,
persistent=False,
)
self.spectrogram_window: torch.Tensor

melscale_fbanks = F.melscale_fbanks(
n_freqs=self.config.n_fft // 2 + 1,
f_min=self.config.f_min,
f_max=self.config.f_max,
n_mels=self.config.n_mels,
sample_rate=self.config.sample_rate,
)
self.register_buffer("melscale_fbanks",
melscale_fbanks,
persistent=False)
self.melscale_fbanks: torch.Tensor

def forward(self, waveform: torch.Tensor) -> torch.Tensor:
spectrogram = F.spectrogram(
waveform=waveform.to(torch.float32),
pad=0,
window=self.spectrogram_window,
n_fft=self.config.n_fft,
hop_length=self.config.hop_length,
win_length=self.config.win_length,
power=2,
normalized=False,
center=self.config.center,
)
mel_spectrogram = (
spectrogram.mT @ self.melscale_fbanks.to(torch.float32)).mT
# mel_spectrogram has shape [batch, freq, time].
# F.amplitude_to_DB accepts inputs shaped as:
# - [freq, time]
# - [channel, freq, time]
# - [..., channel, freq, time]
# Here we insert a channel dimension of size 1 before calling it,
# then remove that extra dimension afterward.
log_mel_spectrogram = F.amplitude_to_DB(
mel_spectrogram.unsqueeze(1),
multiplier=10,
amin=1e-10,
db_multiplier=0,
top_db=120,
).squeeze(1)
return log_mel_spectrogram.to(waveform.dtype)
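
The new DashengFrontend reproduces the previous torchaudio.transforms pipeline with explicit torchaudio.functional calls: power spectrogram, mel filterbank projection, then amplitude_to_DB. A standalone sketch of the same computation follows; the parameter values are assumptions for illustration, not the model's actual config:

import torch
import torchaudio.functional as F

# Assumed values; the real ones come from DashengConfig.
sample_rate, n_fft, win_length, hop_length, n_mels = 16000, 512, 512, 160, 64

waveform = torch.randn(1, sample_rate)             # one second of audio
window = torch.hann_window(win_length)
fbanks = F.melscale_fbanks(n_freqs=n_fft // 2 + 1, f_min=0.0, f_max=8000.0,
                           n_mels=n_mels, sample_rate=sample_rate)

spec = F.spectrogram(waveform=waveform, pad=0, window=window, n_fft=n_fft,
                     hop_length=hop_length, win_length=win_length,
                     power=2, normalized=False, center=True)   # [1, n_freqs, frames]
mel = (spec.mT @ fbanks).mT                                    # [1, n_mels, frames]
log_mel = F.amplitude_to_DB(mel.unsqueeze(1), multiplier=10,
                            amin=1e-10, db_multiplier=0, top_db=120).squeeze(1)
print(log_mel.shape)   # torch.Size([1, 64, 101]) with center=True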


class DashengAudioTransformer(nn.Module):

def __init__(
@@ -293,7 +350,7 @@ def __init__(
self.target_length = config.target_length
self.hop_length = config.hop_length

self._init_front_end(config)
self.front_end = DashengFrontend(config)

self.init_bn = nn.BatchNorm2d(config.n_mels, momentum=0.01)

@@ -318,34 +375,10 @@ def __init__(
qkv_bias=config.qkv_bias,
init_values=config.init_values,
quant_config=quant_config,
prefix=f"{prefix}.block{i}",
prefix=f"{prefix}.blocks.{i}",
) for i in range(config.depth))
self.norm = nn.LayerNorm(config.embed_dim, eps=1e-6)

def _init_front_end(self, config):
with set_default_torch_dtype(torch.float32):
self.front_end = nn.Sequential(
audio_transforms.MelSpectrogram(
f_min=config.f_min,
f_max=config.f_max,
center=config.center,
win_length=config.win_length,
hop_length=config.hop_length,
sample_rate=config.sample_rate,
n_fft=config.n_fft,
n_mels=config.n_mels,
),
audio_transforms.AmplitudeToDB(top_db=120),
)

mel_spectrogram = self.front_end[0]
fb = mel_spectrogram.mel_scale.fb
win = mel_spectrogram.spectrogram.window
mel_spectrogram.mel_scale.fb = fb.to(torch.bfloat16).to(
torch.float32)
mel_spectrogram.spectrogram.window = win.to(torch.bfloat16).to(
torch.float32)

def forward_features(
self,
x: torch.Tensor,
@@ -430,14 +463,16 @@ def __init__(
quant_config=quant_config,
prefix=f"{prefix}.net.0",
return_bias=False,
), get_act_fn("gelu"),
),
get_act_fn("gelu"),
RowParallelLinear(
input_size=out_dim,
output_size=out_dim,
quant_config=quant_config,
prefix=f"{prefix}.net.2",
return_bias=False,
))
),
)

def forward(self, x, mask=None):
batch_size, seq_len, dim = x.shape
@@ -534,9 +569,12 @@ def _call_hf_processor(
# + Padding
min_audio_len = self.info.get_min_audio_len()
processed_audios = [
np.pad(audio, (0, min_audio_len - audio.shape[-1]),
mode='constant',
constant_values=0) if isinstance(audio, np.ndarray)
np.pad(
audio,
(0, min_audio_len - audio.shape[-1]),
mode="constant",
constant_values=0,
) if isinstance(audio, np.ndarray)
and audio.shape[-1] < min_audio_len else audio for audio in audios
]
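
Audio clips shorter than the processor's minimum length are right-padded with zeros before feature extraction. A minimal illustration of that np.pad call (the minimum length here is an assumed value; the real one comes from get_min_audio_len()):

import numpy as np

min_audio_len = 3200                 # assumed value for illustration
audio = np.random.randn(1000).astype(np.float32)

if audio.shape[-1] < min_audio_len:
    # Pad only on the right, up to the minimum length.
    audio = np.pad(audio, (0, min_audio_len - audio.shape[-1]),
                   mode="constant", constant_values=0)
print(audio.shape)   # (3200,)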

@@ -585,8 +623,8 @@ def _get_prompt_updates(
if audio_length is None:
audio_output_lengths = []
else:
audio_length_np = audio_length.cpu().numpy() if isinstance(
audio_length, torch.Tensor) else audio_length
audio_length_np = (audio_length.cpu().numpy() if isinstance(
audio_length, torch.Tensor) else audio_length)
audio_output_lengths = [
max(1, calculate_mel_frames_dasheng(
int(length))) # at least one frame
@@ -617,6 +655,17 @@ def get_replacement_midashenglm(item_idx: int):
dummy_inputs=MiDashengLMDummyInputsBuilder,
)
class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
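
packed_modules_mapping tells vLLM's weight loader that the checkpoint's separate q_proj/k_proj/v_proj (and gate_proj/up_proj) tensors belong to the fused qkv_proj and gate_up_proj modules of the decoder. Conceptually the fusion is just a concatenation along the output dimension; the toy sketch below illustrates that equivalence and is not vLLM's actual loading code (names and sizes are made up):

import torch

hidden = 8
q_w = torch.randn(hidden, hidden)    # stand-ins for checkpoint tensors
k_w = torch.randn(hidden, hidden)
v_w = torch.randn(hidden, hidden)

# The fused projection weight is the three matrices stacked on the output
# dimension, so a single matmul yields q, k and v, which are split apart.
qkv_w = torch.cat([q_w, k_w, v_w], dim=0)            # [3 * hidden, hidden]
x = torch.randn(2, hidden)
q, k, v = (x @ qkv_w.T).split(hidden, dim=-1)
assert torch.allclose(q, x @ q_w.T)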

@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
@@ -660,8 +709,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def _validate_and_reshape_mm_tensor(self, mm_input: object,
name: str) -> torch.Tensor:
if not isinstance(mm_input, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {name}. "
f"Got type: {type(mm_input)}")
raise ValueError(
f"Incorrect type of {name}. Got type: {type(mm_input)}")
if isinstance(mm_input, torch.Tensor):
return mm_input.reshape(-1, *mm_input.shape[2:])

@@ -710,8 +759,8 @@ def _process_audio_input(
audio_input["input_values"].dtype)
batch_size, max_audio_tokens, embed_dim = audio_embeddings.shape

audio_length_np = audio_length.cpu().numpy() if isinstance(
audio_length, torch.Tensor) else audio_length
audio_length_np = (audio_length.cpu().numpy() if isinstance(
audio_length, torch.Tensor) else audio_length)
audio_output_lengths = [
max(1, calculate_mel_frames_dasheng(
int(length))) # at least one frame
@@ -720,11 +769,11 @@
audio_output_lengths = torch.tensor(audio_output_lengths).to(
audio_embeddings.device)

audio_feature_mask = (torch.arange(
audio_feature_mask = torch.arange(
max_audio_tokens,
device=audio_embeddings.device).unsqueeze(0).expand(
batch_size, max_audio_tokens)
< audio_output_lengths.unsqueeze(1))
batch_size,
max_audio_tokens) < audio_output_lengths.unsqueeze(1)

masked_audio_features = audio_embeddings[audio_feature_mask].view(
-1, embed_dim)
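
The valid (non-padded) audio embeddings are selected with a boolean mask built by comparing a torch.arange over token positions against each sample's output length. The same pattern in isolation (sizes are illustrative):

import torch

batch_size, max_audio_tokens, embed_dim = 2, 5, 4    # illustrative sizes
audio_embeddings = torch.randn(batch_size, max_audio_tokens, embed_dim)
audio_output_lengths = torch.tensor([5, 3])

# True where the token index is below that sample's valid length.
mask = (torch.arange(max_audio_tokens)[None, :]
        < audio_output_lengths[:, None])             # [batch, max_audio_tokens]

# Boolean indexing flattens the batch, keeping only the valid frames.
masked = audio_embeddings[mask].view(-1, embed_dim)
print(masked.shape)   # torch.Size([8, 4]) -> 5 + 3 valid frames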
@@ -762,10 +811,12 @@ def forward(
)
input_ids = None

return self.decoder.model(input_ids,
positions,
intermediate_tensors,
inputs_embeds=inputs_embeds)
return self.decoder.model(
input_ids,
positions,
intermediate_tensors,
inputs_embeds=inputs_embeds,
)

def compute_logits(
self,