From 0d20070bb8457f4a2447d88faf139b03174535a1 Mon Sep 17 00:00:00 2001 From: Rohan Joshi Date: Wed, 27 Aug 2025 13:41:43 -0700 Subject: [PATCH] Enable long audio for WhisperAudioProcessor Summary: Enabled WhisperAudioProcessor for audio which is longer than one chunk (i.e. > 30 sec) Reviewed By: jackzhxng Differential Revision: D81093558 --- extension/audio/mel_spectrogram.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/extension/audio/mel_spectrogram.py b/extension/audio/mel_spectrogram.py index 1befd2ec031..b94e39c1fe1 100644 --- a/extension/audio/mel_spectrogram.py +++ b/extension/audio/mel_spectrogram.py @@ -123,19 +123,23 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor: r""" Args: waveform (`torch.Tensor`): Mono waveform input, tensor of (dynamic) shape [num_samples], - where num_samples < n_samples. n_samples is 480000 for 16kHz and chunk length 30 Returns: - torch.Tensor: Output of fixed shape [1, feature_size, nb_max_frames] - [1, 80, 3000] with default options + torch.Tensor: Output of shape [1, feature_size, nb_max_frames * n_chunks] + n_chunks is the number of chunks of `n_samples` samples (480000 for 16kHz and chunk length 30) needed to cover the input waveform. + [1, 80, 3000] with default options and 1 chunk """ - # TODO: pad up to multiples of chunk_length (currently 1 chunk of 30 sec) + n_chunks = (waveform.shape[0] - 1) // self.n_samples + 1 waveform = F.pad( waveform, - (0, self.n_samples - waveform.shape[0] - 1), + (0, self.n_samples * n_chunks - waveform.shape[0]), mode="constant", value=self.padding_value, ) + # Ideally we should do: + # window = torch.hann_window(self.n_fft) + # but this is not currently supported when lowering. 
+ # torch.hann_window has slightly better numerics (worst discrepancy is <1e-5 instead of 1e-4) window = 0.5 * ( 1 - torch.cos( @@ -145,10 +149,6 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor: / self.n_fft ) ) - # Ideally we should do instead - # window = torch.hann_window(self.n_fft) - # but this is not currently supported when lowering - # torch.hann_window has slightly better numerics (worst discrepancy is <1e-5 instead of 1e-4) stft = torch.stft( waveform, n_fft=self.n_fft, @@ -157,7 +157,7 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor: center=True, return_complex=True, ) - magnitudes = torch.abs(stft) ** 2 # pyre-ignore[58] + magnitudes = torch.abs(stft)[..., :-1] ** 2 # pyre-ignore[58] mel_spec = self.mel_filters @ magnitudes @@ -173,8 +173,7 @@ def export_processor(): audio_tensor = torch.randn(480000) chunk_tensor = audio_tensor[:93680] with torch.no_grad(): - # export. What is the min of waveforms? - dim = Dim("waveform", min=1600, max=audio_tensor.size(0)) + dim = Dim("waveform", min=1600, max=audio_tensor.size(0) * 10) # 10 chunks max ep: ExportedProgram = export( model, (chunk_tensor,), dynamic_shapes={"waveform": {0: dim}}, strict=True )