11 changes: 11 additions & 0 deletions extension/audio/README.md
@@ -0,0 +1,11 @@
# Audio Processing with ExecuTorch

The file `mel_spectrogram.py` contains the class `WhisperAudioProcessor`, a module that converts a mono waveform audio input (a 1D tensor) into Mel spectrograms. It applies a Short-Time Fourier Transform (via `torch.stft`) followed by a Mel filterbank. It is equivalent to the `WhisperFeatureExtractor` class in HuggingFace Transformers, but is implemented in PyTorch instead of NumPy. `WhisperFeatureExtractor` is used for Whisper, Voxtral, Qwen2 Audio, and Qwen2.5 Omni. For example, the output Mel spectrograms can be fed directly into the Whisper model (encoder+decoder) exported from HF Transformers.

Since `WhisperAudioProcessor` is written in PyTorch, we can export it with ExecuTorch and run it on device. The defaults for `WhisperAudioProcessor` are 16 kHz audio, 80 Mel spectrogram bins, and 30-second audio chunks.
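
For orientation, here is a minimal sketch of calling the module eagerly, before any export. It assumes Python is run from the `extension/audio` directory so that `mel_spectrogram` is importable; the random waveform is only a stand-in for real audio:

```python
import torch
from mel_spectrogram import WhisperAudioProcessor

processor = WhisperAudioProcessor()  # defaults: 80 Mel bins, 16 kHz, 30 s chunks
waveform = torch.randn(16000)        # 1 s of synthetic mono audio at 16 kHz
mel = processor(waveform)            # fixed output shape [1, 80, 3000]
```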

Run it as a script:

```
python mel_spectrogram.py
```

to export `WhisperAudioProcessor` (with default constructor arguments) as `whisper_preprocess.pte`, which can run on device (on CPU).
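
The export step inside the script is not reproduced in this diff; the following is only a rough sketch of what it might look like with current ExecuTorch APIs (`torch.export.export` plus `executorch.exir.to_edge`), and the actual script may differ:

```python
import torch
from executorch.exir import to_edge
from mel_spectrogram import WhisperAudioProcessor

model = WhisperAudioProcessor()
sample = torch.randn(480000)  # one 30 s chunk at 16 kHz

# Export to an ExecuTorch program and serialize it for on-device CPU execution.
exported = torch.export.export(model, (sample,))
program = to_edge(exported).to_executorch()
with open("whisper_preprocess.pte", "wb") as f:
    f.write(program.buffer)
```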
55 changes: 41 additions & 14 deletions extension/audio/mel_spectrogram.py
@@ -22,20 +22,35 @@


class WhisperAudioProcessor(nn.Module):
"""
r"""
Computes Mel spectrograms from mono audio input.
Same as HuggingFace WhisperFeatureExtractor, but implemented in PyTorch

Args:
feature_size (`int`, defaults to 80):
The feature dimension of the extracted features.
sampling_rate (`int`, defaults to 16000):
The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
hop_length (`int`, defaults to 160):
Length of the overlaping windows for the STFT used to obtain the Mel Frequency coefficients.
chunk_length (`int`, defaults to 30):
The maximum number of chuncks of `sampling_rate` samples used to trim and pad longer or shorter audio
sequences.
n_fft (`int`, defaults to 400):
Size of the Fourier transform.
padding_value (`float`, *optional*, defaults to 0.0):
Padding value used to pad the audio. Should correspond to silences.
"""

    def __init__(
        self,
        feature_size: int = 80,
        sampling_rate: int = 16000,
        hop_length: int = 160,
        chunk_length: int = 30,
        n_fft: int = 400,
        padding_value: float = 0.0,
    ) -> None:
        super().__init__()
        self.feature_size = feature_size
        self.sampling_rate = sampling_rate
@@ -51,7 +66,9 @@ def __init__(
            sampling_rate, n_fft, n_mels=feature_size
        )

    def get_mel_filters(
        self, sr: int, n_fft: int, n_mels: int = 128, dtype: torch.dtype = torch.float32
    ) -> torch.Tensor:
        # Initialize the weights
        n_mels = int(n_mels)
        weights = torch.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
@@ -97,17 +114,27 @@ def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=torch.float32):
        )

        # Slaney-style mel is scaled to be approx constant energy per channel
        enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])  # pyre-ignore[58]
        weights *= enorm[:, None]  # pyre-ignore[16]

        return weights

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        r"""
        Args:
            waveform (`torch.Tensor`): Mono waveform input, tensor of (dynamic) shape [num_samples],
                where num_samples < n_samples. n_samples is 480000 for 16 kHz audio and a chunk length of 30.

        Returns:
            torch.Tensor: Output of fixed shape [1, feature_size, nb_max_frames],
                which is [1, 80, 3000] with default options.
        """
        # TODO: pad up to multiples of chunk_length (currently 1 chunk of 30 sec)
        waveform = F.pad(
            waveform,
            (0, self.n_samples - waveform.shape[0] - 1),
            mode="constant",
            value=self.padding_value,
        )
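        # Note (assumed rationale, not stated in the diff): padding to n_samples - 1
        # means torch.stft with center=True yields exactly n_samples // hop_length
        # = 3000 frames, so no trailing frame needs to be dropped afterwards.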
        window = 0.5 * (
            1
@@ -130,7 +157,7 @@ def forward(self, waveform):
            center=True,
            return_complex=True,
        )
        magnitudes = torch.abs(stft) ** 2  # pyre-ignore[58]

        mel_spec = self.mel_filters @ magnitudes