diff --git a/examples/models/voxtral/multimodal.cpp b/examples/models/voxtral/multimodal.cpp index d7183f3c662..b086a04363c 100644 --- a/examples/models/voxtral/multimodal.cpp +++ b/examples/models/voxtral/multimodal.cpp @@ -12,6 +12,10 @@ #include +#include +#include +#include + #include #include #include @@ -36,6 +40,11 @@ DEFINE_string(prompt, "What is happening in this audio?", "Text prompt."); DEFINE_string(audio_path, "", "Path to input audio file."); +DEFINE_string( + processor_path, + "", + "Path to processor .pte file for raw audio processing."); + DEFINE_double( temperature, 0.8f, @@ -50,16 +59,48 @@ DEFINE_bool(warmup, false, "Whether to run a warmup run."); namespace { +using ::executorch::extension::from_blob; +using ::executorch::extension::Module; using ::executorch::extension::llm::Image; using ::executorch::extension::llm::make_image_input; using ::executorch::extension::llm::make_text_input; using ::executorch::extension::llm::MultimodalInput; +using ::executorch::runtime::EValue; bool ends_with(const std::string& str, const std::string& suffix) { return str.size() >= suffix.size() && str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0; } +/** + * @brief Loads float data from a binary file + * + * @param audio_path Path to the binary audio file (.bin) + * @return Vector of float data loaded from the file + * @throws std::runtime_error if file loading fails + */ +std::vector loadBinaryFloatData(const std::string& audio_path) { + std::ifstream f(audio_path, std::ios::binary | std::ios::ate); + if (!f.is_open()) { + ET_LOG(Error, "Failed to open audio file: %s", audio_path.c_str()); + throw std::runtime_error("Failed to open audio file"); + } + + std::size_t n_floats = + f.tellg() / sizeof(float); // Number of floats in the audio file + f.seekg(0, std::ios::beg); + + std::vector audio_data(n_floats); + f.read( + reinterpret_cast(audio_data.data()), + audio_data.size() * sizeof(float)); + f.close(); + + ET_LOG( + Info, "Loaded .bin file: %s, %zu floats", audio_path.c_str(), n_floats); + return audio_data; +} + /** * @brief Loads preprocessed audio data from a binary file * @@ -73,22 +114,19 @@ bool ends_with(const std::string& str, const std::string& suffix) { * @return MultimodalInput containing the loaded audio data */ MultimodalInput loadPreprocessedAudio(const std::string& audio_path) { - std::ifstream f(audio_path, std::ios::binary | std::ios::ate); + std::vector audio_data = loadBinaryFloatData(audio_path); + int32_t n_bins = 128; int32_t n_frames = 3000; - std::size_t n_floats = - f.tellg() / sizeof(float); // Number of floats in the audio file. - f.seekg(0, std::ios::beg); + + std::size_t n_floats = audio_data.size(); int32_t batch_size = ceil( n_floats / (n_bins * n_frames)); // Batch in increments of n_frames, rounding up. - std::vector audio_data(batch_size * n_bins * n_frames); - f.read( - reinterpret_cast(audio_data.data()), - audio_data.size() * sizeof(float)); ET_LOG(Info, "audio_data len = %d", audio_data.size()); + // Create Audio multimodal input auto audio = std::make_unique<::executorch::extension::llm::Audio>(); audio->batch_size = batch_size; audio->n_bins = n_bins; @@ -100,29 +138,140 @@ MultimodalInput loadPreprocessedAudio(const std::string& audio_path) { } /** - * @brief Processes audio files for multimodal input + * @brief Loads a .bin file into a tensor and processes it using a .pte + * processor * - * Dispatches audio file processing based on file extension: - * - .bin files: Loads preprocessed mel spectrogram features directly - * - .wav/.mp3 files: Currently unsupported, throws runtime_error + * This function loads raw audio data from a .bin file (similar to + * loadPreprocessedAudio), creates a tensor from it, and then passes it through + * a processor module loaded from a .pte file to generate processed audio + * features. + * + * @param audio_path Path to the .bin audio file + * @param processor_path Path to the .pte processor file + * @return MultimodalInput containing the processed audio data + * @throws std::runtime_error if file loading or processing fails + */ +MultimodalInput processRawAudioFile( + const std::string& audio_path, + const std::string& processor_path) { + if (processor_path.empty()) { + ET_LOG(Error, "Processor path is required for raw audio processing"); + throw std::runtime_error( + "Processor path is required for raw audio processing"); + } + + // Load the audio processor .pte. + std::unique_ptr processor_module; + try { + processor_module = + std::make_unique(processor_path, Module::LoadMode::File); + auto load_error = processor_module->load(); + if (load_error != ::executorch::runtime::Error::Ok) { + ET_LOG( + Error, + "Failed to load processor module from: %s", + processor_path.c_str()); + throw std::runtime_error("Failed to load processor module"); + } + } catch (const std::exception& e) { + ET_LOG(Error, "Exception while loading processor module: %s", e.what()); + throw std::runtime_error("Exception while loading processor module"); + } + + // Load the audio data from file. + std::vector audio_data = loadBinaryFloatData(audio_path); + + // Execute the processor + std::vector tensor_shape = { + static_cast(audio_data.size())}; + auto input_tensor = from_blob( + audio_data.data(), tensor_shape, ::executorch::aten::ScalarType::Float); + + ET_LOG(Info, "Processing audio through processor module..."); + auto result = processor_module->execute("forward", input_tensor); + if (!result.ok()) { + ET_LOG(Error, "Failed to execute processor's forward method"); + throw std::runtime_error("Failed to execute processor forward method"); + } + + auto outputs = result.get(); + if (outputs.empty()) { + ET_LOG(Error, "Processor returned no outputs"); + throw std::runtime_error("Processor returned no outputs"); + } + + // Extract processed audio features + const auto& processed_tensor = outputs[0].toTensor(); + const float* processed_data = processed_tensor.const_data_ptr(); + const auto& sizes = processed_tensor.sizes(); + + ET_LOG( + Info, + "Processed audio tensor shape: [%d, %d, %d]", + static_cast(sizes[0]), + static_cast(sizes[1]), + static_cast(sizes[2])); + + // Create Audio multimodal input from processed features + auto processed_audio = + std::make_unique<::executorch::extension::llm::Audio>(); + processed_audio->batch_size = + static_cast(sizes[0]); // Note: batching for s > 30 doesn't work + // yet, so this will just be = 1. + processed_audio->n_bins = static_cast(sizes[1]); + processed_audio->n_frames = + static_cast(sizes[2]); // And this will just be = 3000. + + size_t total_elements = processed_audio->batch_size * + processed_audio->n_bins * processed_audio->n_frames; + processed_audio->data.resize(total_elements * sizeof(float)); + std::memcpy( + processed_audio->data.data(), + processed_data, + total_elements * sizeof(float)); + + ET_LOG( + Info, + "Created processed Audio: batch_size=%d, n_bins=%d, n_frames=%d", + processed_audio->batch_size, + processed_audio->n_bins, + processed_audio->n_frames); + + return ::executorch::extension::llm::make_audio_input( + std::move(*processed_audio)); +} + +/** + * @brief Processes audio files for multimodal input * - * This function provides a interface for different audio input formats - * and can be extended to support raw audio processing in the future. + * Dispatches audio file processing based on file extension and processor + * availability: + * - .bin files with processor: Loads raw audio from .bin and processes through + * processor + * - .bin files without processor: Loads preprocessed mel spectrogram features + * directly * - * @param audio_path Path to the audio file + * @param audio_path Path to the audio file (.bin) + * @param processor_path Path to the processor .pte file (optional) * @return MultimodalInput containing the processed audio data * @throws std::runtime_error if file format is unsupported or processing fails */ -MultimodalInput processAudioFile(const std::string& audio_path) { +MultimodalInput processAudioFile( + const std::string& audio_path, + const std::string& processor_path = "") { if (ends_with(audio_path, ".bin")) { - // Current behavior - load preprocessed audio stored as a binary file. - return loadPreprocessedAudio(audio_path); - } else if (ends_with(audio_path, ".wav") || ends_with(audio_path, ".mp3")) { - // New: Process raw audio files - unsupported for now - ET_LOG(Error, "Raw audio file processing (.wav/.mp3) is not yet supported"); - throw std::runtime_error("Raw audio file processing not supported"); + if (!processor_path.empty()) { + // Process raw audio from .bin file through the processor + return processRawAudioFile(audio_path, processor_path); + } else { + // Load preprocessed audio stored as a binary file (existing behavior) + return loadPreprocessedAudio(audio_path); + } } else { - ET_LOG(Error, "Unsupported audio file format: %s", audio_path.c_str()); + ET_LOG( + Error, + "Unsupported audio file format: %s (only .bin files are supported)", + audio_path.c_str()); throw std::runtime_error("Unsupported audio file format"); } } @@ -137,6 +286,7 @@ int32_t main(int32_t argc, char** argv) { const char* tokenizer_path = FLAGS_tokenizer_path.c_str(); const char* prompt = FLAGS_prompt.c_str(); const char* audio_path = FLAGS_audio_path.c_str(); + const char* processor_path = FLAGS_processor_path.c_str(); float temperature = FLAGS_temperature; int32_t cpu_threads = FLAGS_cpu_threads; bool warmup = FLAGS_warmup; @@ -184,7 +334,7 @@ int32_t main(int32_t argc, char** argv) { inputs.emplace_back(make_text_input("[INST][BEGIN_AUDIO]")); // 2. Add audio input - inputs.emplace_back(processAudioFile(audio_path)); + inputs.emplace_back(processAudioFile(audio_path, processor_path)); // 3. Add text input (the actual user-submitted prompt) inputs.emplace_back(make_text_input(std::string(prompt) + "[/INST]"));