From 44743bdeaf1e1672c95bfb37d9bd4908d22494cf Mon Sep 17 00:00:00 2001
From: Jack Zhang
Date: Fri, 12 Sep 2025 12:42:56 -0700
Subject: [PATCH] Mel spectrogram output stacking along batch dim (#14275)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/14275

Differential Revision: D81798729
---
 extension/audio/mel_spectrogram.py | 45 +++++++++++++++++++++++-------
 1 file changed, 35 insertions(+), 10 deletions(-)

diff --git a/extension/audio/mel_spectrogram.py b/extension/audio/mel_spectrogram.py
index b5d8763d67a..d8577829ffc 100644
--- a/extension/audio/mel_spectrogram.py
+++ b/extension/audio/mel_spectrogram.py
@@ -12,14 +12,12 @@
 import torch.nn.functional as F
 
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
-
 from executorch.exir import (
     EdgeCompileConfig,
     EdgeProgramManager,
     to_edge_transform_and_lower,
 )
-
-from torch.export import Dim, export, ExportedProgram
+from torch.export import Dim
 
 
 class WhisperAudioProcessor(nn.Module):
@@ -51,6 +49,8 @@ def __init__(
         chunk_length: int = 30,
         n_fft: int = 400,
         padding_value: float = 0.0,
+        max_audio_len: int = 600,
+        stack_output: bool = False,
     ) -> None:
         super().__init__()
         self.feature_size = feature_size
@@ -66,6 +66,8 @@ def __init__(
         self.mel_filters = self.get_mel_filters(
             sampling_rate, n_fft, n_mels=feature_size
         )
+        self.max_audio_len = max_audio_len
+        self.stack_output = stack_output
 
     def get_mel_filters(
         self, sr: int, n_fft: int, n_mels: int = 128, dtype: torch.dtype = torch.float32
@@ -137,6 +139,7 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
             mode="constant",
             value=self.padding_value,
         )
+
         # Ideally we should do:
         # window = torch.hann_window(self.n_fft)
         # but this is not currently supported when lowering.
@@ -166,18 +169,27 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
         log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
         log_spec = (log_spec + 4.0) / 4.0
 
-        return log_spec.unsqueeze(0)
+        if self.stack_output:
+            log_spec = log_spec.reshape(self.feature_size, -1, self.nb_max_frames)
+            log_spec = log_spec.transpose(0, 1)
+            return log_spec
+        else:
+            return log_spec.unsqueeze(0)
 
 
 def export_processor(model=None, output_file="whisper_preprocess.pte"):
     if model is None:
         model = WhisperAudioProcessor()
-    audio_tensor = torch.randn(480000)
-    chunk_tensor = audio_tensor[:93680]
-    with torch.no_grad():
-        dim = Dim("waveform", min=1600, max=audio_tensor.size(0) * 10)  # 10 chunks max
-        ep: ExportedProgram = export(
-            model, (chunk_tensor,), dynamic_shapes={"waveform": {0: dim}}, strict=True
+
+    audio_tensor = torch.randn(93680)
+    shapes_collection = torch.export.ShapesCollection()
+    max_n_chunks = int(model.max_audio_len * model.n_samples)
+    shapes_collection[audio_tensor] = {0: Dim.DYNAMIC(max=max_n_chunks)}
+    with torch.no_grad(), torch.fx.experimental._config.patch(
+        backed_size_oblivious=True
+    ):
+        ep = torch.export.export(
+            model, (audio_tensor,), dynamic_shapes=shapes_collection, strict=True
         )
     logging.debug(ep)
 
@@ -236,6 +248,17 @@ def main():
         default="whisper_preprocess.pte",
         help="Output file path for the exported model",
     )
+    parser.add_argument(
+        "--max_audio_len",
+        type=int,
+        default=600,
+        help="Max audio length that can be processed, in seconds.",
+    )
+    parser.add_argument(
+        "--stack_output",
+        action="store_true",
+        help="Whether to stack output along the batch dimension, one per chunk. Used by models such as Voxtral, see https://github.com/huggingface/transformers/blob/main/src/transformers/models/voxtral/processing_voxtral.py#L94 for more information.",
+    )
 
     args = parser.parse_args()
 
@@ -245,6 +268,8 @@ def main():
         hop_length=args.hop_length,
         chunk_length=args.chunk_length,
         n_fft=args.n_fft,
+        max_audio_len=args.max_audio_len,
+        stack_output=args.stack_output,
     )
 
     export_processor(model, args.output_file)
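
Not part of the patch: a minimal sketch of what the new stack_output branch does to the returned shape, using a stand-in for the (n_mels, total_frames) log-mel matrix that forward() computes. It assumes the Whisper defaults implied by the class (feature_size=80 mel bins and nb_max_frames=3000 frames per 30 s chunk, i.e. n_samples // hop_length):

import torch

# Stand-in log-mel matrix covering two 30 s chunks.
n_mels, nb_max_frames, n_chunks = 80, 3000, 2
log_spec = torch.randn(n_mels, n_chunks * nb_max_frames)

# stack_output=False path: one long spectrogram with a singleton batch dim.
flat = log_spec.unsqueeze(0)  # shape (1, 80, 6000)

# stack_output=True path: split the time axis into whole chunks, then move
# the chunk axis to the front so each 30 s chunk becomes one batch entry.
stacked = log_spec.reshape(n_mels, -1, nb_max_frames).transpose(0, 1)  # (2, 80, 3000)

# Each batch entry is exactly the matching time slice of the flat output.
assert torch.equal(stacked[0], flat[0, :, :nb_max_frames])
assert torch.equal(stacked[1], flat[0, :, nb_max_frames:])

The -1 in the reshape is also why the padding step earlier in forward() matters: the total frame count must be a whole multiple of nb_max_frames for the per-chunk split to line up. The resulting (n_chunks, n_mels, frames) layout appears to match what Voxtral's processor consumes, per the link in the new --stack_output help text.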
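
Likewise a sketch rather than part of the patch: driving the new options end to end. The import path is assumed from the repo layout, and the runtime calls assume the executorch.runtime Python API:

import torch

from executorch.extension.audio.mel_spectrogram import (  # module path assumed
    WhisperAudioProcessor,
    export_processor,
)
from executorch.runtime import Runtime  # Python runtime API, assumed available

# Export with per-chunk stacking enabled, accepting up to 600 s of audio.
model = WhisperAudioProcessor(max_audio_len=600, stack_output=True)
export_processor(model, "whisper_preprocess.pte")

# Load and run the exported program.
runtime = Runtime.get()
program = runtime.load_program("whisper_preprocess.pte")
method = program.load_method("forward")

waveform = torch.randn(960_000)  # two 30 s chunks at 16 kHz
(features,) = method.execute([waveform])
print(features.shape)  # expected (2, 80, 3000) when stack_output=True

With stack_output left off, the same input should come back as a single (1, 80, 6000) spectrogram, matching the pre-patch behavior apart from the longer dynamic input range.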