From a56734de4ec1571923194e494fc2616898087f00 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Thu, 9 Oct 2025 15:09:03 -0700 Subject: [PATCH] Add a wav loader (#14923) Summary: This pull request adds support for loading and processing `.wav` audio files in the multimodal runner, alongside existing `.bin` file support. It introduces a dedicated WAV loader utility, updates the runner to dispatch audio file processing based on file type, and adds comprehensive tests for WAV file parsing and normalization. These changes improve flexibility and robustness when handling audio inputs. **WAV file support and audio processing:** * Added a new utility `wav_loader.h` that provides functions to parse WAV file headers and load normalized PCM audio data from `.wav` files, supporting 16-bit and 32-bit PCM formats. * Updated `multimodal.cpp` to support loading audio from both `.bin` and `.wav` files, including input validation and error handling for unsupported formats. The runner now uses the processor for both file types and enforces processor requirements for `.wav` files. [[1]](diffhunk://#diff-0ac16dbe4eaefa08e21fbda582fe2cd2b482f43aaedfc1bf2f31becf5e7bb843L138-R149) [[2]](diffhunk://#diff-0ac16dbe4eaefa08e21fbda582fe2cd2b482f43aaedfc1bf2f31becf5e7bb843L166-R191) [[3]](diffhunk://#diff-0ac16dbe4eaefa08e21fbda582fe2cd2b482f43aaedfc1bf2f31becf5e7bb843R247-L255) * Added a new command-line flag `data_path` and passed it to the multimodal runner to facilitate data file handling. [[1]](diffhunk://#diff-0ac16dbe4eaefa08e21fbda582fe2cd2b482f43aaedfc1bf2f31becf5e7bb843R38) [[2]](diffhunk://#diff-0ac16dbe4eaefa08e21fbda582fe2cd2b482f43aaedfc1bf2f31becf5e7bb843R294) [[3]](diffhunk://#diff-0ac16dbe4eaefa08e21fbda582fe2cd2b482f43aaedfc1bf2f31becf5e7bb843L297-R322) **Testing and build integration:** * Introduced `test_wav_loader.cpp`, which provides unit tests for WAV header parsing, sample normalization, error handling, and unsupported format detection. * Registered the new utility and tests in build configuration files, ensuring proper header exports and test coverage. [[1]](diffhunk://#diff-8a73187dfda9c5479db6911bee649164ff4434d36e8f4eb881cc1f049c4e3271R108) [[2]](diffhunk://#diff-24b61cfeb7f1fc9a646df385ece0c31ea2ab18b3c7e34fc62117c62538e111ffL22-R22) [[3]](diffhunk://#diff-c8ef93f128805fc48fe2d7c1dadb9ff5d2f4dc5ee7c00b638fd193d3dfb1f06cR47-R56) [[4]](diffhunk://#diff-d755455ed59da7a902bb5a5c1e540a1924f63e8f70a9dc78b455f2c569a19db6R17) Reviewed By: mergennachin Differential Revision: D84214903 Pulled By: larryliu0820 --- examples/models/voxtral/README.md | 21 +- examples/models/voxtral/multimodal.cpp | 105 +++++---- extension/llm/runner/targets.bzl | 1 + extension/llm/runner/test/CMakeLists.txt | 2 +- extension/llm/runner/test/targets.bzl | 10 + extension/llm/runner/test/test_wav_loader.cpp | 155 +++++++++++++ extension/llm/runner/wav_loader.h | 210 ++++++++++++++++++ extension/testing_util/targets.bzl | 1 + 8 files changed, 460 insertions(+), 45 deletions(-) create mode 100644 extension/llm/runner/test/test_wav_loader.cpp create mode 100644 extension/llm/runner/wav_loader.h diff --git a/examples/models/voxtral/README.md b/examples/models/voxtral/README.md index 8cac4264bba..4e9ddcf34a4 100644 --- a/examples/models/voxtral/README.md +++ b/examples/models/voxtral/README.md @@ -41,8 +41,8 @@ To run the model, we will use the Voxtral runner, which utilizes ExecuTorch's Mu The Voxtral runner will do the following things: - Audio Input: - - Option A: Pass the raw audio tensor into exported preprocessor to produce a mel spectrogram tensor. - - Option B: If starting directly with an already processed audio input tensor, format the inputs to the multimodal runner (metadata tokens, audio tokens, text tokens, etc.). + - Option A: Pass raw audio data from a `.wav` file into the exported preprocessor to produce a mel spectrogram tensor. + - Option B: If starting directly with an already processed audio input tensor (preprocessed mel spectrogram), format the inputs to the multimodal runner (metadata tokens, audio tokens, text tokens, etc.). - Feed the formatted inputs to the multimodal modal runner. @@ -66,13 +66,26 @@ cmake -DCMAKE_INSTALL_PREFIX=cmake-out -DBUILD_TESTING=OFF -DCMAKE_BUILD_TYPE=Re ## Running the model You can download the `tekken.json` tokenizer from [Voxtral's HuggingFace repo](https://huggingface.co/mistralai/Voxtral-Mini-3B-2507). + +### Running with raw audio (.wav file) +For raw audio files (`.wav`), you must provide a preprocessor to convert the audio into mel spectrogram format: +``` +./cmake-out/examples/models/voxtral/voxtral_runner \ + --model_path path/to/model.pte \ + --tokenizer_path path/to/tekken.json \ + --prompt "What can you tell me about this audio?" \ + --audio_path path/to/audio_input.wav \ + --processor_path path/to/voxtral_preprocessor.pte +``` + +### Running with preprocessed audio (.bin file) +If you already have a preprocessed mel spectrogram saved as a `.bin` file, you can skip the preprocessor: ``` ./cmake-out/examples/models/voxtral/voxtral_runner \ --model_path path/to/model.pte \ --tokenizer_path path/to/tekken.json \ --prompt "What can you tell me about this audio?" \ - --audio_path path/to/audio_input.bin \ - --processor_path path/to/voxtral_preprocessor.pte # If you're passing raw audio file in audio_path + --audio_path path/to/preprocessed_audio.bin ``` Example output: diff --git a/examples/models/voxtral/multimodal.cpp b/examples/models/voxtral/multimodal.cpp index 081df27cd67..b3dd5e3ab68 100644 --- a/examples/models/voxtral/multimodal.cpp +++ b/examples/models/voxtral/multimodal.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -34,6 +35,7 @@ DEFINE_string( "multimodal.pte", "Model serialized in flatbuffer format."); +DEFINE_string(data_path, "", "Path to data file."); DEFINE_string(tokenizer_path, "tekken.json", "Tokenizer stuff."); DEFINE_string(prompt, "What is happening in this audio?", "Text prompt."); @@ -113,15 +115,15 @@ MultimodalInput loadPreprocessedAudio(const std::string& audio_path) { } /** - * @brief Loads a .bin file into a tensor and processes it using a .pte - * processor + * @brief Loads raw audio from a .bin or .wav file and processes it using a + * .pte processor * - * This function loads raw audio data from a .bin file (similar to - * loadPreprocessedAudio), creates a tensor from it, and then passes it through - * a processor module loaded from a .pte file to generate processed audio - * features. + * This function loads raw audio data from either a .bin file (raw float array) + * or a .wav file (WAV format with headers), creates a tensor from it, and then + * passes it through a processor module loaded from a .pte file to generate + * processed audio features. * - * @param audio_path Path to the .bin audio file + * @param audio_path Path to the .bin or .wav audio file * @param processor_path Path to the .pte processor file * @return MultimodalInput containing the processed audio data * @throws std::runtime_error if file loading or processing fails @@ -135,6 +137,41 @@ MultimodalInput processRawAudioFile( "Processor path is required for raw audio processing"); } + // Load the audio data from file (.bin or .wav) + std::vector audio_data; + if (ends_with(audio_path, ".wav")) { + audio_data = ::executorch::extension::llm::load_wav_audio_data(audio_path); + ET_LOG( + Info, + "Loaded WAV file: %s, %zu samples", + audio_path.c_str(), + audio_data.size()); + } else if (ends_with(audio_path, ".bin")) { + std::ifstream f(audio_path, std::ios::binary | std::ios::ate); + if (!f.is_open()) { + ET_LOG(Error, "Failed to open audio file: %s", audio_path.c_str()); + throw std::runtime_error("Failed to open audio file"); + } + + std::size_t n_floats = f.tellg() / sizeof(float); + f.seekg(0, std::ios::beg); + + audio_data.resize(n_floats); + f.read( + reinterpret_cast(audio_data.data()), + audio_data.size() * sizeof(float)); + f.close(); + + ET_LOG( + Info, "Loaded .bin file: %s, %zu floats", audio_path.c_str(), n_floats); + } else { + ET_LOG( + Error, + "Unsupported audio file format: %s (only .bin and .wav files are supported)", + audio_path.c_str()); + throw std::runtime_error("Unsupported audio file format"); + } + // Load the audio processor .pte. std::unique_ptr processor_module; try { @@ -153,25 +190,6 @@ MultimodalInput processRawAudioFile( throw std::runtime_error("Exception while loading processor module"); } - // Load the audio data from file. - std::ifstream f(audio_path, std::ios::binary | std::ios::ate); - if (!f.is_open()) { - ET_LOG(Error, "Failed to open audio file: %s", audio_path.c_str()); - throw std::runtime_error("Failed to open audio file"); - } - - std::size_t n_floats = f.tellg() / sizeof(float); - f.seekg(0, std::ios::beg); - - std::vector audio_data(n_floats); - f.read( - reinterpret_cast(audio_data.data()), - audio_data.size() * sizeof(float)); - f.close(); - - ET_LOG( - Info, "Loaded .bin file: %s, %zu floats", audio_path.c_str(), n_floats); - // Execute the processor std::vector tensor_shape = { static_cast(audio_data.size())}; @@ -226,33 +244,39 @@ MultimodalInput processRawAudioFile( * * Dispatches audio file processing based on file extension and processor * availability: + * - .wav files: Requires processor, processes raw audio through processor * - .bin files with processor: Loads raw audio from .bin and processes through * processor * - .bin files without processor: Loads preprocessed mel spectrogram features * directly * - * @param audio_path Path to the audio file (.bin) - * @param processor_path Path to the processor .pte file (optional) + * @param audio_path Path to the audio file (.bin or .wav) + * @param processor_path Path to the processor .pte file (optional for .bin, + * required for .wav) * @return MultimodalInput containing the processed audio data * @throws std::runtime_error if file format is unsupported or processing fails */ MultimodalInput processAudioFile( const std::string& audio_path, const std::string& processor_path = "") { - if (ends_with(audio_path, ".bin")) { - if (!processor_path.empty()) { - // Process raw audio from .bin file through the processor - return processRawAudioFile(audio_path, processor_path); - } else { - // Load preprocessed audio stored as a binary file (existing behavior) - return loadPreprocessedAudio(audio_path); + if (ends_with(audio_path, ".wav") || ends_with(audio_path, ".bin")) { + if (processor_path.empty()) { + if (ends_with(audio_path, ".wav")) { + ET_CHECK_MSG( + false, + "Processor path is required for .wav file processing: %s", + audio_path.c_str()); + } else { + // Load preprocessed audio stored as a binary file (existing behavior) + return loadPreprocessedAudio(audio_path); + } } + return processRawAudioFile(audio_path, processor_path); } else { - ET_LOG( - Error, - "Unsupported audio file format: %s (only .bin files are supported)", + ET_CHECK_MSG( + false, + "Unsupported audio file format: %s (only .bin and .wav files are supported)", audio_path.c_str()); - throw std::runtime_error("Unsupported audio file format"); } } @@ -267,6 +291,7 @@ int32_t main(int32_t argc, char** argv) { const char* prompt = FLAGS_prompt.c_str(); const char* audio_path = FLAGS_audio_path.c_str(); const char* processor_path = FLAGS_processor_path.c_str(); + const char* data_path = FLAGS_data_path.c_str(); float temperature = FLAGS_temperature; int32_t cpu_threads = FLAGS_cpu_threads; bool warmup = FLAGS_warmup; @@ -294,7 +319,7 @@ int32_t main(int32_t argc, char** argv) { // Create multimodal runner std::unique_ptr<::executorch::extension::llm::MultimodalRunner> runner = ::executorch::extension::llm::create_multimodal_runner( - model_path, std::move(tokenizer)); + model_path, std::move(tokenizer), data_path); if (runner == nullptr) { ET_LOG(Error, "Failed to create multimodal runner"); return 1; diff --git a/extension/llm/runner/targets.bzl b/extension/llm/runner/targets.bzl index 242860a195a..e001e8fc154 100644 --- a/extension/llm/runner/targets.bzl +++ b/extension/llm/runner/targets.bzl @@ -105,6 +105,7 @@ def define_common_targets(): exported_headers = [ "audio.h", "image.h", + "wav_loader.h", "multimodal_input.h", "multimodal_runner.h", "multimodal_prefiller.h", diff --git a/extension/llm/runner/test/CMakeLists.txt b/extension/llm/runner/test/CMakeLists.txt index 2aa18000831..934a5797da1 100644 --- a/extension/llm/runner/test/CMakeLists.txt +++ b/extension/llm/runner/test/CMakeLists.txt @@ -19,7 +19,7 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake) set(_test_srcs test_generation_config.cpp test_text_llm_runner.cpp test_text_prefiller.cpp - test_text_decoder_runner.cpp test_multimodal_input.cpp + test_text_decoder_runner.cpp test_multimodal_input.cpp test_wav_loader.cpp ) # Add LSan stub for Apple platforms diff --git a/extension/llm/runner/test/targets.bzl b/extension/llm/runner/test/targets.bzl index 3339b3b8584..0571b39ccdb 100644 --- a/extension/llm/runner/test/targets.bzl +++ b/extension/llm/runner/test/targets.bzl @@ -44,3 +44,13 @@ def define_common_targets(): "//executorch/extension/llm/runner:multimodal_runner_lib", ], ) + + runtime.cxx_test( + name = "test_wav_loader", + srcs = ["test_wav_loader.cpp"], + deps = [ + "//executorch/extension/testing_util:temp_file", + "//executorch/extension/llm/runner:multimodal_runner_lib", + "//executorch/runtime/platform:platform", + ], + ) diff --git a/extension/llm/runner/test/test_wav_loader.cpp b/extension/llm/runner/test/test_wav_loader.cpp new file mode 100644 index 00000000000..bc3ac0ff324 --- /dev/null +++ b/extension/llm/runner/test/test_wav_loader.cpp @@ -0,0 +1,155 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include +#include +#include +#include + +#include + +using executorch::extension::llm::kOneOverIntMax; +using executorch::extension::llm::kOneOverShortMax; +using executorch::extension::llm::load_wav_audio_data; +using executorch::extension::llm::load_wav_header; +using executorch::extension::llm::WavHeader; +using executorch::extension::testing::TempFile; + +namespace { + +// Test fixture to ensure PAL initialization +class WavLoaderTest : public ::testing::Test { + protected: + void SetUp() override { + // Ensure PAL is initialized before tests run + executorch::runtime::runtime_init(); + } +}; + +void append_bytes(std::vector& out, const char* literal) { + out.insert(out.end(), literal, literal + 4); +} + +void append_le16(std::vector& out, uint16_t value) { + out.push_back(static_cast(value & 0xFF)); + out.push_back(static_cast((value >> 8) & 0xFF)); +} + +void append_le32(std::vector& out, uint32_t value) { + out.push_back(static_cast(value & 0xFF)); + out.push_back(static_cast((value >> 8) & 0xFF)); + out.push_back(static_cast((value >> 16) & 0xFF)); + out.push_back(static_cast((value >> 24) & 0xFF)); +} + +std::vector make_pcm_wav_bytes( + int bits_per_sample, + const std::vector& samples, + uint16_t num_channels = 1, + uint32_t sample_rate = 16000) { + const size_t bytes_per_sample = static_cast(bits_per_sample / 8); + const uint32_t subchunk2_size = + static_cast(samples.size() * bytes_per_sample); + const uint32_t byte_rate = sample_rate * num_channels * bytes_per_sample; + const uint16_t block_align = num_channels * bytes_per_sample; + const uint32_t chunk_size = 36 + subchunk2_size; + + std::vector bytes; + bytes.reserve(44 + subchunk2_size); + + append_bytes(bytes, "RIFF"); + append_le32(bytes, chunk_size); + append_bytes(bytes, "WAVE"); + append_bytes(bytes, "fmt "); + append_le32(bytes, 16); // PCM + append_le16(bytes, 1); // AudioFormat PCM + append_le16(bytes, num_channels); + append_le32(bytes, sample_rate); + append_le32(bytes, byte_rate); + append_le16(bytes, block_align); + append_le16(bytes, static_cast(bits_per_sample)); + append_bytes(bytes, "data"); + append_le32(bytes, subchunk2_size); + + for (int32_t sample : samples) { + const uint32_t encoded = + static_cast(static_cast(sample)); + for (size_t byte_idx = 0; byte_idx < bytes_per_sample; ++byte_idx) { + bytes.push_back(static_cast((encoded >> (8 * byte_idx)) & 0xFF)); + } + } + + return bytes; +} + +} // namespace + +TEST_F(WavLoaderTest, LoadHeaderParsesPcmMetadata) { + const std::vector wav_bytes = + make_pcm_wav_bytes(16, {0, 32767, -32768}); + TempFile file(wav_bytes.data(), wav_bytes.size()); + + std::unique_ptr header = load_wav_header(file.path()); + ASSERT_NE(header, nullptr); + + EXPECT_EQ(header->AudioFormat, 1); + EXPECT_EQ(header->NumOfChan, 1); + EXPECT_EQ(header->SamplesPerSec, 16000); + EXPECT_EQ(header->bitsPerSample, 16); + EXPECT_EQ(header->blockAlign, 2); + EXPECT_EQ(header->bytesPerSec, 32000); + EXPECT_EQ(header->dataOffset, 44); + EXPECT_EQ(header->Subchunk2Size, 6); +} + +TEST_F(WavLoaderTest, LoadAudioData16BitNormalizesSamples) { + const std::vector samples = {0, 32767, -32768}; + const std::vector wav_bytes = make_pcm_wav_bytes(16, samples); + TempFile file(wav_bytes.data(), wav_bytes.size()); + + std::vector audio = load_wav_audio_data(file.path()); + ASSERT_EQ(audio.size(), samples.size()); + + EXPECT_NEAR(audio[0], 0.0f, 1e-6f); + EXPECT_NEAR(audio[1], 32767.0f * kOneOverShortMax, 1e-6f); + EXPECT_NEAR(audio[2], -32768.0f * kOneOverShortMax, 1e-6f); +} + +TEST_F(WavLoaderTest, LoadAudioData32BitNormalizesSamples) { + const std::vector samples = { + 0, + std::numeric_limits::max(), + std::numeric_limits::min()}; + const std::vector wav_bytes = make_pcm_wav_bytes(32, samples); + TempFile file(wav_bytes.data(), wav_bytes.size()); + + std::vector audio = load_wav_audio_data(file.path()); + ASSERT_EQ(audio.size(), samples.size()); + + EXPECT_NEAR(audio[0], 0.0f, 1e-8f); + EXPECT_NEAR( + audio[1], + static_cast(static_cast(samples[1]) * kOneOverIntMax), + 1e-6f); + EXPECT_NEAR( + audio[2], + static_cast(static_cast(samples[2]) * kOneOverIntMax), + 1e-6f); +} + +TEST_F(WavLoaderTest, LoadHeaderReturnsNullWhenMagicMissing) { + const std::string bogus_contents = "not a wav file"; + TempFile file(bogus_contents); + + std::unique_ptr header = load_wav_header(file.path()); + EXPECT_EQ(header, nullptr); +} diff --git a/extension/llm/runner/wav_loader.h b/extension/llm/runner/wav_loader.h new file mode 100644 index 00000000000..f49a4d1723e --- /dev/null +++ b/extension/llm/runner/wav_loader.h @@ -0,0 +1,210 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// A simple WAV file loader. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace executorch::extension::llm { + +constexpr float kOneOverIntMax = 1 / static_cast(INT32_MAX); +constexpr float kOneOverShortMax = 1 / static_cast(INT16_MAX); + +struct WavHeader { + /* RIFF Chunk Descriptor */ + uint8_t RIFF[4]; + uint32_t ChunkSize; + uint8_t WAVE[4]; + /* "fmt" sub-chunk */ + uint8_t fmt[4]; + uint32_t Subchunk1Size; + uint16_t AudioFormat; + uint16_t NumOfChan; + uint32_t SamplesPerSec; + uint32_t bytesPerSec; + uint16_t blockAlign; + uint16_t bitsPerSample; + /* "data" sub-chunk */ + uint32_t dataOffset; + uint32_t Subchunk2Size; +}; + +inline std::unique_ptr load_wav_header(const std::string& fp) { + std::ifstream file(fp, std::ios::binary); + if (!file.is_open()) { + ET_CHECK_MSG(false, "Failed to open WAV file: %s", fp.c_str()); + } + + file.seekg(0, std::ios::end); + size_t file_size = file.tellg(); + file.seekg(0, std::ios::beg); + + std::vector buffer(file_size); + file.read(buffer.data(), file_size); + file.close(); + + const char* data = buffer.data(); + size_t data_size = buffer.size(); + + bool has_riff = false; + bool has_wave = false; + + if (data_size >= 4 && std::memcmp(data, "RIFF", 4) == 0) { + has_riff = true; + } + + if (data_size >= 12 && std::memcmp(data + 8, "WAVE", 4) == 0) { + has_wave = true; + } + + bool is_wav_file = has_riff && has_wave; + std::unique_ptr header; + + if (is_wav_file) { + header = std::make_unique(); + size_t default_header_size = sizeof(WavHeader); + + size_t data_offset = 0; + for (size_t i = 0; i + 4 < data_size; i++) { + if (std::memcmp(data + i, "data", 4) == 0) { + data_offset = i; + break; + } + } + + if (data_size >= default_header_size) { + std::memcpy( + reinterpret_cast(header.get()), data, default_header_size); + + ET_LOG(Info, "WAV header detected, getting raw audio data."); + ET_LOG( + Info, + "RIFF Header: %c%c%c%c", + header->RIFF[0], + header->RIFF[1], + header->RIFF[2], + header->RIFF[3]); + ET_LOG(Info, "Chunk Size: %d", header->ChunkSize); + ET_LOG( + Info, + "WAVE Header: %c%c%c%c", + header->WAVE[0], + header->WAVE[1], + header->WAVE[2], + header->WAVE[3]); + ET_LOG( + Info, + "Format Header: %c%c%c%c", + header->fmt[0], + header->fmt[1], + header->fmt[2], + header->fmt[3]); + ET_LOG(Info, "Format Chunk Size: %d", header->Subchunk1Size); + ET_LOG(Info, "Audio Format: %d", header->AudioFormat); + ET_LOG(Info, "Number of Channels: %d", header->NumOfChan); + ET_LOG(Info, "Sample Rate: %d", header->SamplesPerSec); + ET_LOG(Info, "Byte Rate: %d", header->bytesPerSec); + ET_LOG(Info, "Block Align: %d", header->blockAlign); + ET_LOG(Info, "Bits per Sample: %d", header->bitsPerSample); + + if (data_offset != 0) { + header->Subchunk2Size = + *reinterpret_cast(data + data_offset + 4); + ET_LOG(Info, "Subchunk2Size: %d", header->Subchunk2Size); + header->dataOffset = static_cast(data_offset + 8); + } else { + ET_LOG( + Error, + "WAV file structure is invalid, missing Subchunk2ID 'data' field."); + throw std::runtime_error("Invalid WAV file structure"); + } + } else { + ET_CHECK_MSG( + false, + "WAV header detected but file is too small to contain a complete header"); + } + } + + return header; +} + +inline std::vector load_wav_audio_data(const std::string& fp) { + std::ifstream file(fp, std::ios::binary); + if (!file.is_open()) { + ET_CHECK_MSG(false, "Failed to open WAV file: %s", fp.c_str()); + } + + file.seekg(0, std::ios::end); + size_t file_size = file.tellg(); + file.seekg(0, std::ios::beg); + + std::vector buffer(file_size); + file.read(buffer.data(), file_size); + file.close(); + + auto header = load_wav_header(fp); + + if (header.get() == nullptr) { + ET_CHECK_MSG(false, "WAV header not detected in file: %s", fp.c_str()); + } + + const char* data = buffer.data(); + size_t data_offset = header->dataOffset; + size_t data_size = header->Subchunk2Size; + int bits_per_sample = header->bitsPerSample; + + std::vector audio_data; + + if (bits_per_sample == 32) { + size_t num_samples = data_size / 4; + audio_data.resize(num_samples); + const int32_t* input_buffer = + reinterpret_cast(data + data_offset); + + for (size_t i = 0; i < num_samples; ++i) { + audio_data[i] = static_cast( + static_cast(input_buffer[i]) * kOneOverIntMax); + } + } else if (bits_per_sample == 16) { + size_t num_samples = data_size / 2; + audio_data.resize(num_samples); + const int16_t* input_buffer = + reinterpret_cast(data + data_offset); + + for (size_t i = 0; i < num_samples; ++i) { + audio_data[i] = static_cast( + static_cast(input_buffer[i]) * kOneOverShortMax); + } + } else { + ET_CHECK_MSG( + false, + "Unsupported bits per sample: %d. Only support 32 and 16.", + bits_per_sample); + } + + ET_LOG( + Info, + "Loaded %zu audio samples from WAV file: %s", + audio_data.size(), + fp.c_str()); + + return audio_data; +} + +} // namespace executorch::extension::llm diff --git a/extension/testing_util/targets.bzl b/extension/testing_util/targets.bzl index 05b825645e8..a5ad1fb9b8c 100644 --- a/extension/testing_util/targets.bzl +++ b/extension/testing_util/targets.bzl @@ -14,6 +14,7 @@ def define_common_targets(): visibility = [ "//executorch/devtools/etdump/tests/...", "//executorch/extension/data_loader/test/...", + "//executorch/extension/llm/runner/test/...", "//executorch/extension/testing_util/test/...", "//executorch/extension/fb/ptez/decompression_methods/test/...", "//executorch/extension/fb/ptez/test/...",