45 changes: 35 additions & 10 deletions extension/audio/mel_spectrogram.py
@@ -12,14 +12,12 @@
 import torch.nn.functional as F
 
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
-
 from executorch.exir import (
     EdgeCompileConfig,
     EdgeProgramManager,
     to_edge_transform_and_lower,
 )
-
-from torch.export import Dim, export, ExportedProgram
+from torch.export import Dim
 
 
 class WhisperAudioProcessor(nn.Module):
@@ -51,6 +49,8 @@ def __init__(
         chunk_length: int = 30,
         n_fft: int = 400,
         padding_value: float = 0.0,
+        max_audio_len: int = 600,
+        stack_output: bool = False,
     ) -> None:
         super().__init__()
         self.feature_size = feature_size
@@ -66,6 +66,8 @@ def __init__(
         self.mel_filters = self.get_mel_filters(
             sampling_rate, n_fft, n_mels=feature_size
         )
+        self.max_audio_len = max_audio_len
+        self.stack_output = stack_output
 
     def get_mel_filters(
         self, sr: int, n_fft: int, n_mels: int = 128, dtype: torch.dtype = torch.float32
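
A minimal sketch of how the two new constructor arguments compose (only the parameters added in this diff are shown; all other defaults are assumed unchanged):

    # Sketch: exercises only the two parameters introduced above.
    processor = WhisperAudioProcessor(
        max_audio_len=600,   # seconds of audio the exported program must accept
        stack_output=True,   # one batch entry per chunk (Voxtral-style)
    )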
@@ -137,6 +139,7 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
             mode="constant",
             value=self.padding_value,
         )
+
         # Ideally we should do:
         # window = torch.hann_window(self.n_fft)
         # but this is not currently supported when lowering.
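
The comment above notes that torch.hann_window cannot currently be created at lowering time. A common workaround, shown here as an assumption rather than what this file necessarily does, is to precompute the periodic Hann window from its cosine definition once in __init__ and reuse it in forward:

    import math
    import torch

    n_fft = 400  # assumed, matching the default above
    n = torch.arange(n_fft, dtype=torch.float32)
    # Equivalent to torch.hann_window(n_fft, periodic=True).
    window = 0.5 * (1.0 - torch.cos(2.0 * math.pi * n / n_fft))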
@@ -166,18 +169,27 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
         log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
         log_spec = (log_spec + 4.0) / 4.0
 
-        return log_spec.unsqueeze(0)
+        if self.stack_output:
+            log_spec = log_spec.reshape(self.feature_size, -1, self.nb_max_frames)
+            log_spec = log_spec.transpose(0, 1)
+            return log_spec
+        else:
+            return log_spec.unsqueeze(0)
 
 
 def export_processor(model=None, output_file="whisper_preprocess.pte"):
     if model is None:
         model = WhisperAudioProcessor()
-    audio_tensor = torch.randn(480000)
-    chunk_tensor = audio_tensor[:93680]
-    with torch.no_grad():
-        dim = Dim("waveform", min=1600, max=audio_tensor.size(0) * 10)  # 10 chunks max
-        ep: ExportedProgram = export(
-            model, (chunk_tensor,), dynamic_shapes={"waveform": {0: dim}}, strict=True
+
+    audio_tensor = torch.randn(93680)
+    shapes_collection = torch.export.ShapesCollection()
+    max_n_chunks = int(model.max_audio_len * model.n_samples)
+    shapes_collection[audio_tensor] = {0: Dim.DYNAMIC(max=max_n_chunks)}
+    with torch.no_grad(), torch.fx.experimental._config.patch(
+        backed_size_oblivious=True
+    ):
+        ep = torch.export.export(
+            model, (audio_tensor,), dynamic_shapes=shapes_collection, strict=True
         )
     logging.debug(ep)
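
As a sanity check on the new branch, a hypothetical eager-mode comparison of the two output layouts (feature_size=80 and nb_max_frames=3000 are assumed Whisper defaults, not values visible in this diff):

    import torch

    audio = torch.randn(60 * 16000)  # 60 s at 16 kHz, i.e. two 30 s chunks

    stacked = WhisperAudioProcessor(stack_output=True)(audio)
    # One batch entry per chunk: (n_chunks, feature_size, nb_max_frames),
    # here expected to be (2, 80, 3000).

    flat = WhisperAudioProcessor()(audio)
    # Previous behavior, kept as the default: a single batch entry,
    # (1, feature_size, total_frames).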

@@ -236,6 +248,17 @@ def main():
         default="whisper_preprocess.pte",
         help="Output file path for the exported model",
     )
+    parser.add_argument(
+        "--max_audio_len",
+        type=int,
+        default=600,
+        help="Max audio length that can be processed, in seconds.",
+    )
+    parser.add_argument(
+        "--stack_output",
+        action="store_true",
+        help="Whether to stack output along the batch dimension, one per chunk. Used by models such as Voxtral, see https://github.com/huggingface/transformers/blob/main/src/transformers/models/voxtral/processing_voxtral.py#L94 for more information.",
+    )
 
     args = parser.parse_args()

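A hypothetical invocation of the script with the new flags, leaving the other arguments at their defaults:

    python extension/audio/mel_spectrogram.py --max_audio_len 600 --stack_output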
@@ -245,6 +268,8 @@ def main():
         hop_length=args.hop_length,
         chunk_length=args.chunk_length,
         n_fft=args.n_fft,
+        max_audio_len=args.max_audio_len,
+        stack_output=args.stack_output,
     )
 
     export_processor(model, args.output_file)
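
For completeness, a sketch of running the exported program with the ExecuTorch Python runtime; the runtime API below is an assumption of this note, not something exercised by the diff:

    import torch
    from executorch.runtime import Runtime

    runtime = Runtime.get()
    program = runtime.load_program("whisper_preprocess.pte")
    method = program.load_method("forward")
    # Any waveform length within the exported dynamic bound should work.
    (features,) = method.execute([torch.randn(93680)])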