diff --git a/CMakeLists.txt b/CMakeLists.txt index fc427d517a9..da71e839990 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -630,7 +630,12 @@ if(EXECUTORCH_BUILD_EXTENSION_MODULE) list(APPEND _executorch_extensions extension_module_static) endif() +if(EXECUTORCH_BUILD_EXTENSION_AUDIO) + message(STATUS "Audio/ASR extension enabled") +endif() + if(EXECUTORCH_BUILD_EXTENSION_LLM) + message(STATUS "LLM extension enabled") if(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER) set(SUPPORT_REGEX_LOOKAHEAD ON) # llama/runner/CMakeLists.txt builds a shared library libllama_runner.so diff --git a/backends/qualcomm/scripts/build.sh b/backends/qualcomm/scripts/build.sh index c84911cf851..b05ad2c38ea 100755 --- a/backends/qualcomm/scripts/build.sh +++ b/backends/qualcomm/scripts/build.sh @@ -81,6 +81,7 @@ if [ "$BUILD_AARCH64" = true ]; then -DCMAKE_BUILD_TYPE=$BUILD_TYPE \ -DEXECUTORCH_BUILD_QNN=ON \ -DEXECUTORCH_BUILD_DEVTOOLS=ON \ + -DEXECUTORCH_BUILD_EXTENSION_AUDIO=ON \ -DEXECUTORCH_BUILD_EXTENSION_LLM=ON \ -DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER=ON \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ @@ -150,6 +151,7 @@ if [ "$BUILD_X86_64" = true ]; then -DQNN_SDK_ROOT=${QNN_SDK_ROOT} \ -DEXECUTORCH_BUILD_QNN=ON \ -DEXECUTORCH_BUILD_DEVTOOLS=ON \ + -DEXECUTORCH_BUILD_EXTENSION_AUDIO=ON \ -DEXECUTORCH_BUILD_EXTENSION_LLM=ON \ -DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER=ON \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ diff --git a/examples/qualcomm/oss_scripts/whisper/CMakeLists.txt b/examples/qualcomm/oss_scripts/whisper/CMakeLists.txt index 8f7d0f9a9be..137318fcce8 100644 --- a/examples/qualcomm/oss_scripts/whisper/CMakeLists.txt +++ b/examples/qualcomm/oss_scripts/whisper/CMakeLists.txt @@ -14,6 +14,7 @@ set(_qnn_whisper_runner__srcs ${CMAKE_CURRENT_LIST_DIR}/runner/runner.cpp ${CMAKE_CURRENT_LIST_DIR}/runner/runner.h ${EXECUTORCH_ROOT}/extension/llm/sampler/sampler.cpp + ${EXECUTORCH_ROOT}/extension/audio/runner/asr_runner.h ) # build qnn whisper runner diff --git 
a/examples/qualcomm/oss_scripts/whisper/qnn_whisper_runner.cpp b/examples/qualcomm/oss_scripts/whisper/qnn_whisper_runner.cpp index e61b2f444c0..f47923e4884 100644 --- a/examples/qualcomm/oss_scripts/whisper/qnn_whisper_runner.cpp +++ b/examples/qualcomm/oss_scripts/whisper/qnn_whisper_runner.cpp @@ -14,6 +14,7 @@ */ #include +#include #include #include #include @@ -97,7 +98,7 @@ std::vector>> parse_input_list_file( int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); // create llama runner - example::Runner runner(FLAGS_model_path, FLAGS_tokenizer_json_path); + example::WhisperRunner runner(FLAGS_model_path, FLAGS_tokenizer_json_path); std::vector>> multi_turns_input_buffers = parse_input_list_file(FLAGS_input_list_path); @@ -110,7 +111,14 @@ int main(int argc, char** argv) { } }; // generate tokens - runner.transcribe(FLAGS_seq_len, multi_turns_input_buffers[iter], callback); + executorch::extension::llm::Audio audio{ + std::vector( + multi_turns_input_buffers[iter][0].begin(), + multi_turns_input_buffers[iter][0].end()), + 1, + 80, + 3000}; + runner.transcribe(FLAGS_seq_len, audio, callback); auto output_file_name = FLAGS_output_folder_path + "/output_" + std::to_string(iter) + ".txt"; std::ofstream fout(output_file_name); diff --git a/examples/qualcomm/oss_scripts/whisper/runner/runner.cpp b/examples/qualcomm/oss_scripts/whisper/runner/runner.cpp index c98326778bf..b55c961912c 100644 --- a/examples/qualcomm/oss_scripts/whisper/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/whisper/runner/runner.cpp @@ -27,7 +27,7 @@ static constexpr auto kDecoderStartTokenId = "decoder_start_token_id"; static constexpr auto kEosId = "get_eos_id"; static constexpr auto kMaxContextLen = "get_max_context_len"; } // namespace -Runner::Runner( +WhisperRunner::WhisperRunner( const std::string& model_path, const std::string& tokenizer_json_path) : tokenizer_json_path_(tokenizer_json_path) { @@ -35,12 +35,12 @@ Runner::Runner( decoder_ = 
std::make_unique(model_path); tokenizer_ = std::make_unique(); } -bool Runner::is_loaded() const { +bool WhisperRunner::is_loaded() const { return encoder_->is_method_loaded() && decoder_->is_method_loaded() && tokenizer_->is_loaded() && sampler_; } -Error Runner::load() { +Error WhisperRunner::load() { if (is_loaded()) { return Error::Ok; } @@ -108,33 +108,28 @@ Error Runner::load() { return Error::Ok; } -uint64_t Runner::logits_to_token( +uint64_t WhisperRunner::logits_to_token( const executorch::aten::Tensor& logits_tensor) { return sampler_->sample(logits_tensor.data_ptr()); } -/** - * @param inputs: A vector containing one element: a vector of bytes that - * encodes a float tensor in little-endian byte order. - * - */ -Error Runner::transcribe( +Error WhisperRunner::transcribe( int32_t seq_len, - std::vector>& inputs, - std::function token_callback) { + executorch::extension::llm::Audio& audio, + std::function token_callback, + std::function + stats_callback) { if (!is_loaded()) { stats_.model_load_start_ms = time_in_ms(); ET_CHECK_OK_OR_RETURN_ERROR(load()); stats_.model_load_end_ms = time_in_ms(); } - ET_CHECK_MSG(inputs.size() == 1, "The input size of whisper should be one."); - ET_LOG(Info, "Start Encoding"); stats_.encoder_inference_start_ms = time_in_ms(); auto input_features_tensor_ptr = from_blob( - inputs[0].data(), + audio.data.data(), // (1, processor.feature_extractor.feature_size, // processor.feature_extractor.nb_max_frames) - {1, 80, 3000}, + {audio.batch_size, audio.n_bins, audio.n_frames}, // {1, 80, 3000} ScalarType::Float); Result encoder_out = encoder_->encode(input_features_tensor_ptr); auto encoder_out_tensor_ptr = make_tensor_ptr(encoder_out.get()); @@ -188,7 +183,7 @@ Error Runner::transcribe( return Error::Ok; } -Error Runner::print_performance() { +Error WhisperRunner::print_performance() { ET_LOG(Info, "\tTotal Generated token:\t\t\t\t%ld", num_generated_token_); ET_LOG( diff --git 
a/examples/qualcomm/oss_scripts/whisper/runner/runner.h b/examples/qualcomm/oss_scripts/whisper/runner/runner.h index de7c38d0e32..9f3b2046f92 100644 --- a/examples/qualcomm/oss_scripts/whisper/runner/runner.h +++ b/examples/qualcomm/oss_scripts/whisper/runner/runner.h @@ -13,6 +13,9 @@ #include #include +#include +#include +#include #include #include #include @@ -24,9 +27,9 @@ namespace example { -class Runner { +class WhisperRunner : public executorch::extension::audio::ASRRunner { public: - explicit Runner( + explicit WhisperRunner( const std::string& model_path, const std::string& tokenizer_json_path); @@ -51,8 +54,10 @@ class Runner { executorch::runtime::Error load(); executorch::runtime::Error transcribe( int32_t seq_len, - std::vector>& inputs, - std::function token_callback = {}); + executorch::extension::llm::Audio& audio, + std::function token_callback = {}, + std::function + stats_callback = {}); private: executorch::runtime::Error print_performance(); diff --git a/extension/android/CMakeLists.txt b/extension/android/CMakeLists.txt index 2599d202e61..9214576e288 100644 --- a/extension/android/CMakeLists.txt +++ b/extension/android/CMakeLists.txt @@ -166,10 +166,12 @@ if(EXECUTORCH_BUILD_EXTENSION_TRAINING) ) endif() -if(EXECUTORCH_BUILD_LLAMA_JNI) +if(EXECUTORCH_BUILD_EXTENSION_LLM) target_sources(executorch_jni PRIVATE jni/jni_layer_llama.cpp jni/log.cpp) list(APPEND link_libraries extension_llm_runner) - target_compile_definitions(executorch_jni PUBLIC EXECUTORCH_BUILD_LLAMA_JNI=1) + target_compile_definitions( + executorch_jni PUBLIC EXECUTORCH_BUILD_EXTENSION_LLM=1 + ) if(QNN_SDK_ROOT) target_sources( @@ -221,6 +223,29 @@ if(EXECUTORCH_BUILD_LLAMA_JNI) endif() endif() +if(EXECUTORCH_BUILD_EXTENSION_AUDIO) + target_sources(executorch_jni PRIVATE jni/jni_layer_llama.cpp jni/log.cpp) + target_compile_definitions( + executorch_jni PUBLIC EXECUTORCH_BUILD_EXTENSION_AUDIO=1 + ) + + if(QNN_SDK_ROOT) + target_sources( + executorch_jni + PRIVATE + 
${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/whisper/runner/encoder.cpp + ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/whisper/runner/decoder.cpp + ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/whisper/runner/runner.cpp + ) + + target_include_directories( + executorch_jni + PRIVATE ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/whisper/runner + ) + target_compile_definitions(executorch_jni PRIVATE EXECUTORCH_BUILD_QNN=1) + endif() +endif() + target_include_directories( executorch_jni PRIVATE diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/ASRModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/ASRModule.java new file mode 100644 index 00000000000..862f39891bf --- /dev/null +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/ASRModule.java @@ -0,0 +1,61 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch.extension.audio; +import com.facebook.jni.HybridData; +import com.facebook.jni.annotations.DoNotStrip; +import java.io.File; +import org.pytorch.executorch.ExecuTorchRuntime; +import org.pytorch.executorch.extension.llm.LlmCallback; +import org.pytorch.executorch.annotations.Experimental; + +/** + * ASRModule is a wrapper around the Executorch ASR runners like Whisper runner. + * + *

Warning: These APIs are experimental and subject to change without notice + */ +@Experimental +public class ASRModule { + + @DoNotStrip private final HybridData mHybridData; + + @DoNotStrip + private static native HybridData initHybrid( + String modulePath, String tokenizerPath); + + public ASRModule( + String modulePath, String tokenizerPath) { + ExecuTorchRuntime runtime = ExecuTorchRuntime.getRuntime(); + + File modelFile = new File(modulePath); + if (!modelFile.canRead() || !modelFile.isFile()) { + throw new RuntimeException("Cannot load model path " + modulePath); + } + File tokenizerFile = new File(tokenizerPath); + if (!tokenizerFile.canRead() || !tokenizerFile.isFile()) { + throw new RuntimeException("Cannot load tokenizer path " + tokenizerPath); + } + mHybridData = initHybrid(modulePath, tokenizerPath); + } + + public void resetNative() { + mHybridData.resetNative(); + } + + @DoNotStrip + public native int transcribe( + int seqLen, + byte[][] inputs, + LlmCallback callback, + int n_bins, + int n_frames); + + /** Force loading the module. Otherwise the model is loaded during first transcribe(). */ + @DoNotStrip + public native int load(); +} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/package-info.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/package-info.java new file mode 100644 index 00000000000..8c7e2d27dd5 --- /dev/null +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/package-info.java @@ -0,0 +1,2 @@ +/** Extension for audio and ASR related use cases for ExecuTorch Android Java/JNI package. 
*/ +package org.pytorch.executorch.extension.audio; diff --git a/extension/android/jni/BUCK b/extension/android/jni/BUCK index a6f4fe186cf..f79e0ccb98a 100644 --- a/extension/android/jni/BUCK +++ b/extension/android/jni/BUCK @@ -101,7 +101,7 @@ non_fbcode_target(_kind = fb_android_cxx_library, srcs = ["jni_layer.cpp", "jni_layer_llama.cpp", "jni_layer_runtime.cpp", "jni_helper.cpp"], allow_jni_merging = False, compiler_flags = ET_JNI_COMPILER_FLAGS + [ - "-DEXECUTORCH_BUILD_LLAMA_JNI", + "-DEXECUTORCH_BUILD_EXTENSION_LLM", ], soname = "libexecutorch.$(ext)", visibility = ["PUBLIC"], diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index e7ef6e62c74..1d3a6991d5e 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -508,7 +508,13 @@ class ExecuTorchJni : public facebook::jni::HybridClass { }; } // namespace executorch::extension -#ifdef EXECUTORCH_BUILD_LLAMA_JNI +#ifdef EXECUTORCH_BUILD_EXTENSION_AUDIO +extern void register_natives_for_asr(); +#else +void register_natives_for_asr() {} +#endif + +#ifdef EXECUTORCH_BUILD_EXTENSION_LLM extern void register_natives_for_llm(); #else // No op if we don't build LLM @@ -526,6 +532,7 @@ void register_natives_for_training() {} JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) { return facebook::jni::initialize(vm, [] { executorch::extension::ExecuTorchJni::registerNatives(); + register_natives_for_asr(); register_natives_for_llm(); register_natives_for_runtime(); register_natives_for_training(); diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index cabf30c42e4..ead8dca9774 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -13,6 +13,8 @@ #include #include +#include +#include #include #include #include @@ -33,6 +35,7 @@ #if defined(EXECUTORCH_BUILD_QNN) #include +#include #endif #if defined(EXECUTORCH_BUILD_MEDIATEK) @@ -317,8 +320,92 @@ 
class ExecuTorchLlmJni : public facebook::jni::HybridClass { } }; +class ExecuTorchASRJni : public facebook::jni::HybridClass { + private: + friend HybridBase; + std::unique_ptr<::executorch::extension::audio::ASRRunner> runner_; + + public: + constexpr static auto kJavaDescriptor = + "Lorg/pytorch/executorch/extension/audio/ASRModule;"; + + static facebook::jni::local_ref initHybrid( + facebook::jni::alias_ref, + facebook::jni::alias_ref model_path, + facebook::jni::alias_ref tokenizer_path) { + return makeCxxInstance(model_path, tokenizer_path); + } + + ExecuTorchASRJni( + facebook::jni::alias_ref model_path, + facebook::jni::alias_ref tokenizer_path) { +#if defined(ET_USE_THREADPOOL) + // Reserve 1 thread for the main thread. + int32_t num_performant_cores = + ::executorch::extension::cpuinfo::get_num_performant_cores() - 1; + if (num_performant_cores > 0) { + ET_LOG(Info, "Resetting threadpool to %d threads", num_performant_cores); + ::executorch::extension::threadpool::get_threadpool() + ->_unsafe_reset_threadpool(num_performant_cores); + } +#endif +#if defined(EXECUTORCH_BUILD_QNN) + // create runner + runner_ = std::make_unique( + model_path->toStdString(), tokenizer_path->toStdString()); +#endif + } + + jint transcribe( + jint seq_len, + facebook::jni::alias_ref< + facebook::jni::JArrayClass::javaobject> inputs, + facebook::jni::alias_ref callback, + jint n_bins = 80, // whisper defaults + jint n_frames = 3000) { + // Convert Java byte[][] to C++ vector> + std::vector> cppData; + auto input_size = inputs->size(); + cppData.reserve(input_size); + // TODO: add support for larger batch sizes + for (jsize i = 0; i < input_size; i++) { + auto byte_array = inputs->getElement(i); + if (byte_array) { + auto array_length = byte_array->size(); + auto bytes = byte_array->getRegion(0, array_length); + std::vector uint8Vector; + uint8Vector.reserve(array_length); + for (jsize j = 0; j < array_length; j++) { + uint8Vector.push_back(static_cast(bytes[j])); + } + 
cppData.push_back(std::move(uint8Vector)); + } + } + executorch::extension::llm::Audio audio{cppData[0], 1, n_bins, n_frames}; + runner_->transcribe(seq_len, audio, [callback](std::string result) { + callback->onResult(result); + }); + return 0; + } + + jint load() { + return static_cast(runner_->load()); + } + + static void registerNatives() { + registerHybrid({ + makeNativeMethod("initHybrid", ExecuTorchASRJni::initHybrid), + makeNativeMethod("transcribe", ExecuTorchASRJni::transcribe), + makeNativeMethod("load", ExecuTorchASRJni::load), + }); + } +}; + } // namespace executorch_jni void register_natives_for_llm() { executorch_jni::ExecuTorchLlmJni::registerNatives(); } +void register_natives_for_asr() { + executorch_jni::ExecuTorchASRJni::registerNatives(); +} diff --git a/extension/audio/runner/asr_runner.h b/extension/audio/runner/asr_runner.h new file mode 100644 index 00000000000..1178b999f9a --- /dev/null +++ b/extension/audio/runner/asr_runner.h @@ -0,0 +1,65 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Interface for audio-to-text model runners. Currently only used for +// supporting QNN Whisper Runner + +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +namespace executorch { +namespace extension { +namespace audio { + +class ET_EXPERIMENTAL ASRRunner { + public: + virtual ~ASRRunner() = default; + + /** + * Check if the runner is loaded and ready for inference. + * + * @return true if the runner is loaded, false otherwise + */ + virtual bool is_loaded() const = 0; + + /** + * Load the model and prepare for inference. + * + * @return Error::Ok if successful, an error otherwise + */ + virtual runtime::Error load() = 0; + + /** + * Generate text from raw audio. 
+ * + * @param seq_len Length of input sequence + * @param audio processed audio input, which contains a vector of bytes that + * encodes a float tensor in little-endian byte order + * @param token_callback Callback function called for each generated token + * @param stats_callback Callback function for generation statistics + * @return Error::Ok if successful, an error otherwise + */ + virtual runtime::Error transcribe( + int32_t seq_len, + ::executorch::extension::llm::Audio& audio, + std::function token_callback = {}, + std::function + stats_callback = {}) = 0; +}; + +} // namespace audio +} // namespace extension +} // namespace executorch diff --git a/scripts/build_android_library.sh b/scripts/build_android_library.sh index a50d15709bd..0083dd6ff65 100755 --- a/scripts/build_android_library.sh +++ b/scripts/build_android_library.sh @@ -42,7 +42,7 @@ build_android_native_library() { -DEXECUTORCH_BUILD_EXTENSION_LLM="${EXECUTORCH_BUILD_EXTENSION_LLM:-ON}" \ -DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER="${EXECUTORCH_BUILD_EXTENSION_LLM:-ON}" \ -DEXECUTORCH_BUILD_EXTENSION_TRAINING=ON \ - -DEXECUTORCH_BUILD_LLAMA_JNI="${EXECUTORCH_BUILD_EXTENSION_LLM:-ON}" \ + -DEXECUTORCH_BUILD_EXTENSION_AUDIO="${EXECUTORCH_BUILD_EXTENSION_AUDIO:-ON}" \ -DEXECUTORCH_BUILD_NEURON="${EXECUTORCH_BUILD_NEURON}" \ -DNEURON_BUFFER_ALLOCATOR_LIB="${NEURON_BUFFER_ALLOCATOR_LIB}" \ -DEXECUTORCH_BUILD_QNN="${EXECUTORCH_BUILD_QNN}" \