From 8f412ef1918b4973b4c91ee6661ad61e04a5dc79 Mon Sep 17 00:00:00 2001
From: Rohan Joshi
Date: Tue, 26 Aug 2025 08:42:59 -0700
Subject: [PATCH] Added documentation and typing to WhisperAudioProcessor (#13661)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/13661

Reviewed By: cccclai

Differential Revision: D80907444
---
 extension/audio/README.md          | 46 ++++++++++++++++++++++++++++++
 extension/audio/mel_spectrogram.py | 55 ++++++++++++++++++++++--------
 2 files changed, 87 insertions(+), 14 deletions(-)
 create mode 100644 extension/audio/README.md

diff --git a/extension/audio/README.md b/extension/audio/README.md
new file mode 100644
index 00000000000..ba2b87c73f2
--- /dev/null
+++ b/extension/audio/README.md
@@ -0,0 +1,46 @@
+# Audio Processing with ExecuTorch
+
+The file `mel_spectrogram.py` contains the class `WhisperAudioProcessor`, a module that converts mono waveform audio input (a 1D tensor) into a Mel spectrogram. It applies a short-time Fourier transform (via `torch.stft`) followed by a Mel filterbank. It is equivalent to the `WhisperFeatureExtractor` class in HuggingFace Transformers, but is implemented in PyTorch rather than NumPy. `WhisperFeatureExtractor` is used by Whisper, Voxtral, Qwen2 Audio, and Qwen2.5 Omni; for example, the output Mel spectrograms can be fed directly into the Whisper model (encoder + decoder) exported from HF Transformers. A usage sketch is included at the end of this README.
+
+Since `WhisperAudioProcessor` is written in PyTorch, it can be exported with ExecuTorch and run on device. The defaults for `WhisperAudioProcessor` are 16 kHz audio, 80 Mel spectrogram bins, and 30-second audio chunks.
+
+Run it as a script
+
+```
+python mel_spectrogram.py
+```
+
+to export `WhisperAudioProcessor` (with default constructor arguments) as `whisper_preprocess.pte`, which can run on device (on CPU).
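+
+For a quick eager-mode sanity check, here is a minimal sketch (assuming it is run from this directory, so that `mel_spectrogram` is importable):
+
+```python
+import torch
+
+from mel_spectrogram import WhisperAudioProcessor
+
+processor = WhisperAudioProcessor()  # defaults: 16 kHz, 80 Mel bins, 30 s chunks
+waveform = torch.randn(16000 * 5)  # e.g. 5 seconds of mono audio at 16 kHz
+features = processor(waveform)  # shorter inputs are padded to one 30 s chunk
+print(features.shape)  # torch.Size([1, 80, 3000])
+```
+
+The export performed by the script follows the standard ExecuTorch lowering flow. The sketch below is illustrative rather than the script's exact code; in particular, the dynamic-shape bounds are assumptions:
+
+```python
+import torch
+
+from executorch.exir import to_edge
+from mel_spectrogram import WhisperAudioProcessor
+
+model = WhisperAudioProcessor()
+# forward() pads shorter inputs up to one 30 s chunk, so the sample count is dynamic
+num_samples = torch.export.Dim("num_samples", min=160, max=16000 * 30 - 1)
+example_input = (torch.randn(16000 * 30 - 1),)
+exported = torch.export.export(
+    model, example_input, dynamic_shapes={"waveform": {0: num_samples}}
+)
+executorch_program = to_edge(exported).to_executorch()
+with open("whisper_preprocess.pte", "wb") as f:
+    f.write(executorch_program.buffer)
+```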
""" def __init__( self, - feature_size=80, - sampling_rate=16000, - hop_length=160, - chunk_length=30, - n_fft=400, - padding_value=0.0, - ): + feature_size: int = 80, + sampling_rate: int = 16000, + hop_length: int = 160, + chunk_length: int = 30, + n_fft: int = 400, + padding_value: float = 0.0, + ) -> None: super().__init__() self.feature_size = feature_size self.sampling_rate = sampling_rate @@ -51,7 +66,9 @@ def __init__( sampling_rate, n_fft, n_mels=feature_size ) - def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=torch.float32): + def get_mel_filters( + self, sr: int, n_fft: int, n_mels: int = 128, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: # Initialize the weights n_mels = int(n_mels) weights = torch.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype) @@ -97,17 +114,27 @@ def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=torch.float32): ) # Slaney-style mel is scaled to be approx constant energy per channel - enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels]) - weights *= enorm[:, None] + enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels]) # pyre-ignore[58] + weights *= enorm[:, None] # pyre-ignore[16] return weights - def forward(self, waveform): + def forward(self, waveform: torch.Tensor) -> torch.Tensor: + r""" + Args: + waveform (`torch.Tensor`): Mono waveform input, tensor of (dynamic) shape [num_samples], + where num_samples < n_samples. n_samples is 480000 for 16kHz and chunk length 30 + + Returns: + torch.Tensor: Output of fixed shape [1, feature_size, nb_max_frames] + [1, 80, 3000] with default options + """ + # TODO: pad up to multiples of chunk_length (currently 1 chunk of 30 sec) waveform = F.pad( waveform, (0, self.n_samples - waveform.shape[0] - 1), mode="constant", - value=0, + value=self.padding_value, ) window = 0.5 * ( 1 @@ -130,7 +157,7 @@ def forward(self, waveform): center=True, return_complex=True, ) - magnitudes = torch.abs(stft) ** 2 + magnitudes = torch.abs(stft) ** 2 # pyre-ignore[58] mel_spec = self.mel_filters @ magnitudes