diff --git a/extension/audio/TARGETS b/extension/audio/TARGETS
new file mode 100644
index 00000000000..fe8d35faf82
--- /dev/null
+++ b/extension/audio/TARGETS
@@ -0,0 +1,27 @@
+load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")
+
+oncall("executorch")
+
+python_library(
+    name = "mel_spectrogram_lib",
+    srcs = ["mel_spectrogram.py"],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/devtools/backend_debug:delegation_info",
+        "//executorch/backends/xnnpack/partition:xnnpack_partitioner",
+        "//executorch/runtime:runtime",
+        "fbsource//third-party/pypi/datasets:datasets",
+        "fbsource//third-party/pypi/transformers:transformers",
+        "fbsource//third-party/pypi/librosa:librosa",
+        "fbsource//third-party/pypi/soundfile:soundfile",
+    ],
+)
+
+python_binary(
+    name = "mel_spectrogram",
+    main_module = "executorch.extension.audio.mel_spectrogram",
+    deps = [
+        ":mel_spectrogram_lib",
+    ],
+)
diff --git a/extension/audio/mel_spectrogram.py b/extension/audio/mel_spectrogram.py
new file mode 100644
index 00000000000..bafa3a088ac
--- /dev/null
+++ b/extension/audio/mel_spectrogram.py
@@ -0,0 +1,211 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+
+from executorch.exir import (
+    EdgeCompileConfig,
+    EdgeProgramManager,
+    to_edge_transform_and_lower,
+)
+
+from torch.export import Dim, export, ExportedProgram
+
+
+class WhisperAudioProcessor(nn.Module):
+    """
+    Computes log-Mel spectrograms from mono audio input.
+    Equivalent to HuggingFace's WhisperFeatureExtractor, but implemented in PyTorch.
+    """
+
+    def __init__(
+        self,
+        feature_size=80,
+        sampling_rate=16000,
+        hop_length=160,
+        chunk_length=30,
+        n_fft=400,
+        padding_value=0.0,
+    ):
+        super().__init__()
+        self.feature_size = feature_size
+        self.sampling_rate = sampling_rate
+        self.padding_value = padding_value
+
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.chunk_length = chunk_length
+        self.n_samples = chunk_length * sampling_rate
+        self.nb_max_frames = self.n_samples // hop_length
+        self.mel_filters = self.get_mel_filters(
+            sampling_rate, n_fft, n_mels=feature_size
+        )
+
+    def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=torch.float32):
+        # Initialize the weights
+        n_mels = int(n_mels)
+        weights = torch.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
+
+        # Center freqs of each FFT bin
+        fftfreqs = torch.fft.rfftfreq(n=n_fft, d=1.0 / sr, dtype=dtype)
+
+        # 'Center freqs' of mel bands - uniformly spaced between limits
+        min_mel = 0.0
+        max_mel = 45.245640471924965  # Slaney-scale mel for 8000 Hz (Nyquist at 16 kHz)
+
+        mels = torch.linspace(min_mel, max_mel, n_mels + 2, dtype=dtype)
+
+        # Fill in the linear scale
+        f_min = 0.0
+        f_sp = 200.0 / 3
+        freqs = f_min + f_sp * mels
+
+        # And now the nonlinear scale
+        min_log_hz = 1000.0  # beginning of log region (Hz)
+        min_log_mel = (min_log_hz - f_min) / f_sp  # same (Mels)
+        logstep = (
+            torch.log(torch.tensor(6.4, dtype=dtype)) / 27.0
+        )  # step size for log region
+
+        # If we have vector data, vectorize
+        log_t = mels >= min_log_mel
+        freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))
+
+        mel_f = freqs
+
+        fdiff = torch.diff(mel_f)
+        ramps = torch.subtract(mel_f.unsqueeze(1), fftfreqs.unsqueeze(0))
+
+        for i in range(n_mels):
+            # lower and upper slopes for all bins
+            lower = -ramps[i] / fdiff[i]
+            upper = ramps[i + 2] / fdiff[i + 1]
+
+            # .. then intersect them with each other and zero
+            weights[i] = torch.maximum(
+                torch.tensor(0.0, dtype=dtype), torch.minimum(lower, upper)
+            )
+
+        # Slaney-style mel is scaled to be approx constant energy per channel
+        enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
+        weights *= enorm[:, None]
+
+        return weights
+
+    def forward(self, waveform):
+        # Pad to n_samples - 1 so that stft(..., center=True) below produces
+        # exactly nb_max_frames (3000) frames; HuggingFace pads to n_samples
+        # and drops the last frame instead.
+        waveform = F.pad(
+            waveform,
+            (0, self.n_samples - waveform.shape[0] - 1),
+            mode="constant",
+            value=0,
+        )
+        window = 0.5 * (
+            1
+            - torch.cos(
+                2
+                * torch.pi
+                * torch.linspace(0, self.n_fft - 1, self.n_fft, dtype=torch.float32)
+                / self.n_fft
+            )
+        )
+        # Ideally we would use window = torch.hann_window(self.n_fft) instead,
+        # but that is not currently supported when lowering.
+        # torch.hann_window has slightly better numerics (worst discrepancy
+        # is <1e-5 instead of 1e-4).
+        stft = torch.stft(
+            waveform,
+            n_fft=self.n_fft,
+            hop_length=self.hop_length,
+            window=window,
+            center=True,
+            return_complex=True,
+        )
+        magnitudes = torch.abs(stft) ** 2
+
+        mel_spec = self.mel_filters @ magnitudes
+
+        log_spec = torch.log10(torch.clamp(mel_spec, min=1e-10))
+        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
+        log_spec = (log_spec + 4.0) / 4.0
+
+        return log_spec.unsqueeze(0)
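+
+
+# Optional sanity check, not wired into main(): compare this module against
+# HuggingFace's WhisperFeatureExtractor, which the class docstring claims to
+# match. A sketch only; expects a 1-D float32 waveform under 30 s at 16 kHz,
+# and the tolerance is a loose assumption, not a measured bound.
+def check_against_transformers(waveform: torch.Tensor) -> None:
+    from transformers import WhisperFeatureExtractor
+
+    extractor = WhisperFeatureExtractor(feature_size=80)
+    expected = extractor(
+        waveform.numpy(), sampling_rate=16000, return_tensors="pt"
+    ).input_features
+    actual = WhisperAudioProcessor()(waveform)
+    torch.testing.assert_close(actual, expected, atol=1e-3, rtol=1e-3)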
+
+
+def export_processor():
+    model = WhisperAudioProcessor()
+    audio_tensor = torch.randn(480000)
+    chunk_tensor = audio_tensor[:93680]  # arbitrary length to exercise the dynamic dim
+    with torch.no_grad():
+        # Export with the waveform length as a dynamic dimension.
+        # TODO: what is the right minimum waveform length? 1600 (0.1 s) for now.
+        dim = Dim("waveform", min=1600, max=audio_tensor.size(0))
+        ep: ExportedProgram = export(
+            model, (chunk_tensor,), dynamic_shapes={"waveform": {0: dim}}, strict=True
+        )
+        logging.debug(ep)
+
+        # to edge
+        edge: EdgeProgramManager = to_edge_transform_and_lower(
+            ep,
+            partitioner=[XnnpackPartitioner()],
+            compile_config=EdgeCompileConfig(
+                _check_ir_validity=False,
+            ),
+        )
+        logging.debug(edge.exported_program())
+
+        # to executorch
+        exec_prog = edge.to_executorch()
+        output_file = "whisper_preprocess.pte"
+        with open(output_file, "wb") as file:
+            exec_prog.write_to_file(file)
+
+    logging.debug("Done")
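+
+
+# Optional smoke test for the exported artifact, assuming the ExecuTorch Python
+# runtime bindings (executorch.runtime, provided by //executorch/runtime:runtime
+# in TARGETS above). A sketch only; it is not called by main().
+def run_exported(path: str = "whisper_preprocess.pte") -> None:
+    from executorch.runtime import Runtime
+
+    program = Runtime.get().load_program(path)
+    method = program.load_method("forward")
+    # Any waveform length within the exported [1600, 480000] bounds works here.
+    (features,) = method.execute([torch.randn(160000)])
+    logging.debug(features.shape)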
+
+
+def main():
+    export_processor()
+
+
+if __name__ == "__main__":
+    main()