28 changes: 28 additions & 0 deletions extension/audio/TARGETS
@@ -0,0 +1,28 @@
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

oncall("executorch")

python_library(
name = "mel_spectrogram_lib",
srcs = ["mel_spectrogram.py"],
deps = [
"//caffe2:torch",
"//executorch/devtools/backend_debug:delegation_info",
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
"//executorch/runtime:runtime",
"fbsource//third-party/pypi/datasets:datasets",
"fbsource//third-party/pypi/transformers:transformers",
"fbsource//third-party/pypi/librosa:librosa",
"fbsource//third-party/pypi/soundfile:soundfile"
]
)

python_binary(
name = "mel_spectrogram",
main_module = "executorch.extension.audio.mel_spectrogram",
deps = [
":mel_spectrogram_lib",
],
)
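Usage note (a sketch; the exact invocation depends on the local Buck setup): the python_binary target can be run from the fbcode root with something like

buck2 run //executorch/extension/audio:mel_spectrogram

which executes main() in mel_spectrogram.py below and writes whisper_preprocess.pte to the working directory.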
180 changes: 180 additions & 0 deletions extension/audio/mel_spectrogram.py
@@ -0,0 +1,180 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
import torch.nn as nn
import torch.nn.functional as F

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

from executorch.exir import (
EdgeCompileConfig,
EdgeProgramManager,
to_edge_transform_and_lower,
)

from torch.export import Dim, export, ExportedProgram


class WhisperAudioProcessor(nn.Module):
"""
Computes Whisper-style log-Mel spectrograms from mono audio input.
Equivalent to HuggingFace's WhisperFeatureExtractor, but implemented in
PyTorch so it can be exported and lowered with ExecuTorch.
"""

def __init__(
self,
feature_size=80,
sampling_rate=16000,
hop_length=160,
chunk_length=30,
n_fft=400,
padding_value=0.0,
):
super().__init__()
self.feature_size = feature_size
self.sampling_rate = sampling_rate
self.padding_value = padding_value

self.n_fft = n_fft
self.hop_length = hop_length
self.chunk_length = chunk_length
self.n_samples = chunk_length * sampling_rate
self.nb_max_frames = self.n_samples // hop_length
self.mel_filters = self.get_mel_filters(
sampling_rate, n_fft, n_mels=feature_size
)

def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=torch.float32):
# Initialize the weights
n_mels = int(n_mels)
weights = torch.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)

# Center freqs of each FFT bin
fftfreqs = torch.fft.rfftfreq(n=n_fft, d=1.0 / sr, dtype=dtype)

# 'Center freqs' of mel bands - uniformly spaced between limits
min_mel = 0.0
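# 45.245640471924965 is 8000 Hz (the Nyquist frequency at 16 kHz) on the
# Slaney mel scale.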
max_mel = 45.245640471924965

mels = torch.linspace(min_mel, max_mel, n_mels + 2, dtype=dtype)

# Fill in the linear scale
f_min = 0.0
f_sp = 200.0 / 3
freqs = f_min + f_sp * mels

# And now the nonlinear scale
min_log_hz = 1000.0 # beginning of log region (Hz)
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = (
torch.log(torch.tensor(6.4, dtype=dtype)) / 27.0
) # step size for log region

# Fill in the log-spaced region above min_log_mel
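# (Slaney scale: f = 1000 * exp(log(6.4) / 27 * (mel - min_log_mel)) for mel >= min_log_mel)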
log_t = mels >= min_log_mel
freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))

mel_f = freqs

fdiff = torch.diff(mel_f)
ramps = torch.subtract(mel_f.unsqueeze(1), fftfreqs.unsqueeze(0))

for i in range(n_mels):
# lower and upper slopes for all bins
lower = -ramps[i] / fdiff[i]
upper = ramps[i + 2] / fdiff[i + 1]

# ... then intersect them with each other and with zero
weights[i] = torch.maximum(
torch.tensor(0.0, dtype=dtype), torch.minimum(lower, upper)
)

# Slaney-style mel is scaled to be approx constant energy per channel
enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
weights *= enorm[:, None]

return weights

def forward(self, waveform):
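# Pad to n_samples - 1 samples so that torch.stft with center=True yields
# exactly nb_max_frames (3000) frames; the HuggingFace extractor instead
# pads to n_samples and drops the final STFT frame.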
waveform = F.pad(
waveform,
(0, self.n_samples - waveform.shape[0] - 1),
mode="constant",
value=0,
)
window = 0.5 * (
1
- torch.cos(
2
* torch.pi
* torch.linspace(0, self.n_fft - 1, self.n_fft, dtype=torch.float32)
/ self.n_fft
)
)
# Ideally this would just be torch.hann_window(self.n_fft), which also has
# slightly better numerics (worst-case discrepancy < 1e-5 instead of ~1e-4),
# but that op is not currently supported when lowering.
stft = torch.stft(
waveform,
n_fft=self.n_fft,
hop_length=self.hop_length,
window=window,
center=True,
return_complex=True,
)
magnitudes = torch.abs(stft) ** 2  # power spectrum of the STFT

mel_spec = self.mel_filters @ magnitudes

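# Whisper-style log-mel normalization: floor at 1e-10, cap the dynamic
# range at 8 decades below the peak, then shift and rescale.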
log_spec = torch.log10(torch.clamp(mel_spec, min=1e-10))
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0

return log_spec.unsqueeze(0)


def export_processor():
model = WhisperAudioProcessor()
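# 480000 samples is the full 30 s chunk at 16 kHz; slice off an arbitrary
# shorter chunk so that export traces with a non-maximal dynamic length.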
audio_tensor = torch.randn(480000)
chunk_tensor = audio_tensor[:93680]
with torch.no_grad():
# Export with a dynamic waveform dimension (the 1600-sample minimum is an
# arbitrary lower bound).
dim = Dim("waveform", min=1600, max=audio_tensor.size(0))
ep: ExportedProgram = export(
model, (chunk_tensor,), dynamic_shapes={"waveform": {0: dim}}, strict=True
)
logging.debug(ep)

# to edge
edge: EdgeProgramManager = to_edge_transform_and_lower(
ep,
partitioner=[XnnpackPartitioner()],
compile_config=EdgeCompileConfig(
_check_ir_validity=False,
),
)
logging.debug(edge.exported_program())

# to executorch
exec_prog = edge.to_executorch()
output_file = "whisper_preprocess.pte"
with open(output_file, "wb") as file:
exec_prog.write_to_file(file)

logging.debug("Done")


def main():
export_processor()


if __name__ == "__main__":
main()
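Usage note: a minimal sketch of exercising the exported artifact with the ExecuTorch Python runtime bindings (assuming executorch.runtime is available in the environment; the .pte file name matches the one written by export_processor above):

import torch
from executorch.runtime import Runtime

# Load the exported program and its single "forward" method.
runtime = Runtime.get()
program = runtime.load_program("whisper_preprocess.pte")
method = program.load_method("forward")

# Any mono 16 kHz waveform within the dynamic bounds declared at export
# time (1600 to 480000 samples) should work.
waveform = torch.randn(93680)
(log_mel,) = method.execute([waveform])
print(log_mel.shape)  # expected: torch.Size([1, 80, 3000])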