From 019f8e951e5a250e68e6826ffb0efe3cdcc2d3c9 Mon Sep 17 00:00:00 2001 From: rohansjoshi Date: Tue, 19 Aug 2025 13:28:34 -0700 Subject: [PATCH 1/4] Whisper JNI first commit Added EXECUTORCH_BUILD_WHISPER_JNI flag --- .../whisper/qnn_whisper_runner.cpp | 2 +- .../oss_scripts/whisper/runner/runner.cpp | 12 +- .../oss_scripts/whisper/runner/runner.h | 4 +- extension/android/CMakeLists.txt | 23 +++ .../extension/audio/WhisperCallback.java | 31 +++ .../extension/audio/WhisperModule.java | 60 ++++++ .../extension/audio/package-info.java | 2 + extension/android/jni/BUCK | 29 +++ extension/android/jni/jni_layer.cpp | 7 + extension/android/jni/jni_layer_whisper.cpp | 177 ++++++++++++++++++ scripts/build_android_library.sh | 1 + 11 files changed, 339 insertions(+), 9 deletions(-) create mode 100644 extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/WhisperCallback.java create mode 100644 extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/WhisperModule.java create mode 100644 extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/package-info.java create mode 100644 extension/android/jni/jni_layer_whisper.cpp diff --git a/examples/qualcomm/oss_scripts/whisper/qnn_whisper_runner.cpp b/examples/qualcomm/oss_scripts/whisper/qnn_whisper_runner.cpp index e61b2f444c0..4590bc5aca8 100644 --- a/examples/qualcomm/oss_scripts/whisper/qnn_whisper_runner.cpp +++ b/examples/qualcomm/oss_scripts/whisper/qnn_whisper_runner.cpp @@ -97,7 +97,7 @@ std::vector>> parse_input_list_file( int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); // create llama runner - example::Runner runner(FLAGS_model_path, FLAGS_tokenizer_json_path); + example::WhisperRunner runner(FLAGS_model_path, FLAGS_tokenizer_json_path); std::vector>> multi_turns_input_buffers = parse_input_list_file(FLAGS_input_list_path); diff --git a/examples/qualcomm/oss_scripts/whisper/runner/runner.cpp b/examples/qualcomm/oss_scripts/whisper/runner/runner.cpp index c98326778bf..75a81584f4f 100644 --- a/examples/qualcomm/oss_scripts/whisper/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/whisper/runner/runner.cpp @@ -27,7 +27,7 @@ static constexpr auto kDecoderStartTokenId = "decoder_start_token_id"; static constexpr auto kEosId = "get_eos_id"; static constexpr auto kMaxContextLen = "get_max_context_len"; } // namespace -Runner::Runner( +WhisperRunner::WhisperRunner( const std::string& model_path, const std::string& tokenizer_json_path) : tokenizer_json_path_(tokenizer_json_path) { @@ -35,12 +35,12 @@ Runner::Runner( decoder_ = std::make_unique(model_path); tokenizer_ = std::make_unique(); } -bool Runner::is_loaded() const { +bool WhisperRunner::is_loaded() const { return encoder_->is_method_loaded() && decoder_->is_method_loaded() && tokenizer_->is_loaded() && sampler_; } -Error Runner::load() { +Error WhisperRunner::load() { if (is_loaded()) { return Error::Ok; } @@ -108,7 +108,7 @@ Error Runner::load() { return Error::Ok; } -uint64_t Runner::logits_to_token( +uint64_t WhisperRunner::logits_to_token( const executorch::aten::Tensor& logits_tensor) { return sampler_->sample(logits_tensor.data_ptr()); } @@ -117,7 +117,7 @@ uint64_t Runner::logits_to_token( * encodes a float tensor in little-endian byte order. * */ -Error Runner::transcribe( +Error WhisperRunner::transcribe( int32_t seq_len, std::vector>& inputs, std::function token_callback) { @@ -188,7 +188,7 @@ Error Runner::transcribe( return Error::Ok; } -Error Runner::print_performance() { +Error WhisperRunner::print_performance() { ET_LOG(Info, "\tTotal Generated token:\t\t\t\t%ld", num_generated_token_); ET_LOG( diff --git a/examples/qualcomm/oss_scripts/whisper/runner/runner.h b/examples/qualcomm/oss_scripts/whisper/runner/runner.h index de7c38d0e32..22999f06fc2 100644 --- a/examples/qualcomm/oss_scripts/whisper/runner/runner.h +++ b/examples/qualcomm/oss_scripts/whisper/runner/runner.h @@ -24,9 +24,9 @@ namespace example { -class Runner { +class WhisperRunner { public: - explicit Runner( + explicit WhisperRunner( const std::string& model_path, const std::string& tokenizer_json_path); diff --git a/extension/android/CMakeLists.txt b/extension/android/CMakeLists.txt index 2599d202e61..80fc5fa6745 100644 --- a/extension/android/CMakeLists.txt +++ b/extension/android/CMakeLists.txt @@ -69,10 +69,14 @@ set_target_properties( executorch_target_link_options_shared_lib(executorch) +<<<<<<< HEAD add_library( executorch_jni SHARED jni/jni_layer.cpp jni/log.cpp jni/jni_layer_runtime.cpp jni/jni_helper.cpp ) +======= +add_library(executorch_jni SHARED jni/jni_layer.cpp jni/log.cpp jni/jni_layer_runtime.cpp) +>>>>>>> 37d0a6944a (Whisper JNI first commit) set(link_libraries) list( @@ -221,6 +225,25 @@ if(EXECUTORCH_BUILD_LLAMA_JNI) endif() endif() +if(EXECUTORCH_BUILD_WHISPER_JNI) + target_sources(executorch_jni PRIVATE jni/jni_layer_whisper.cpp jni/log.cpp) + target_compile_definitions(executorch_jni PUBLIC EXECUTORCH_BUILD_WHISPER_JNI=1) + if(QNN_SDK_ROOT) + target_sources( + executorch_jni + PRIVATE + ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/whisper/runner/encoder.cpp + ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/whisper/runner/decoder.cpp + ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/whisper/runner/runner.cpp + ) + + target_include_directories( + executorch_jni + PRIVATE ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/whisper/runner + ) + endif() +endif() + target_include_directories( executorch_jni PRIVATE diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/WhisperCallback.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/WhisperCallback.java new file mode 100644 index 00000000000..a0c0a41fc3a --- /dev/null +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/WhisperCallback.java @@ -0,0 +1,31 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch.extension.audio; + +import com.facebook.jni.annotations.DoNotStrip; +import org.pytorch.executorch.annotations.Experimental; + +/** + * Callback interface for Whisper model. Users can implement this interface to receive the generated + * tokens and statistics. + * + *

Warning: These APIs are experimental and subject to change without notice + */ +@Experimental +public interface WhisperCallback { + /** + * Called when a new result is available from JNI. Users will keep getting onResult() invocations + * until generate() finishes. + * + * @param result Last generated token + */ + @DoNotStrip + public void onResult(String result); + +} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/WhisperModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/WhisperModule.java new file mode 100644 index 00000000000..ad4b21cd924 --- /dev/null +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/WhisperModule.java @@ -0,0 +1,60 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch.extension.audio; +import com.facebook.jni.HybridData; +import com.facebook.jni.annotations.DoNotStrip; +import java.io.File; +import org.pytorch.executorch.ExecuTorchRuntime; +import org.pytorch.executorch.annotations.Experimental; + +/** + * WhisperModule is a wrapper around the Executorch LLM. It provides a simple interface to generate text + * from the model. + * + *

Warning: These APIs are experimental and subject to change without notice + */ +@Experimental +public class WhisperModule { + + @DoNotStrip private final HybridData mHybridData; + + @DoNotStrip + private static native HybridData initHybrid( + String modulePath, String tokenizerPath); + + public WhisperModule( + String modulePath, String tokenizerPath) { + ExecuTorchRuntime runtime = ExecuTorchRuntime.getRuntime(); + + File modelFile = new File(modulePath); + if (!modelFile.canRead() || !modelFile.isFile()) { + throw new RuntimeException("Cannot load model path " + modulePath); + } + File tokenizerFile = new File(tokenizerPath); + if (!tokenizerFile.canRead() || !tokenizerFile.isFile()) { + throw new RuntimeException("Cannot load tokenizer path " + tokenizerPath); + } + mHybridData = initHybrid(modulePath, tokenizerPath); + } + + public void resetNative() { + mHybridData.resetNative(); + } + + @DoNotStrip + public native int transcribe( + int seqLen, + byte[][] inputs, + WhisperCallback callback); + + + /** Force loading the module. Otherwise the model is loaded during first generate(). */ + @DoNotStrip + public native int load(); +} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/package-info.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/package-info.java new file mode 100644 index 00000000000..ddd1b48d80e --- /dev/null +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/package-info.java @@ -0,0 +1,2 @@ +/** Extension for LLM related use cases for ExecuTorch Android Java/JNI package. */ +package org.pytorch.executorch.extension.audio; diff --git a/extension/android/jni/BUCK b/extension/android/jni/BUCK index a6f4fe186cf..cbb27043ebb 100644 --- a/extension/android/jni/BUCK +++ b/extension/android/jni/BUCK @@ -121,6 +121,35 @@ non_fbcode_target(_kind = fb_android_cxx_library, ], ) +non_fbcode_target(_kind = fb_android_cxx_library, + name = "executorch_whisper_jni", + srcs = [ + "jni_layer.cpp", + "jni_layer_whisper.cpp", + "jni_layer_runtime.cpp", + ], + allow_jni_merging = False, + compiler_flags = ET_JNI_COMPILER_FLAGS + [ + "-DEXECUTORCH_BUILD_WHISPER_JNI", + ], + soname = "libexecutorch.$(ext)", + visibility = ["PUBLIC"], + deps = [ + ":jni_headers", + ":log_provider_static", + "//fbandroid/libraries/fbjni:fbjni", + "//fbandroid/native/fb:fb", + "//third-party/glog:glog", + "//xplat/executorch/backends/xnnpack:xnnpack_backend_static", + "//xplat/executorch/examples/oss_scripts/qualcomm/whisper/runner:runner_static", + "//xplat/executorch/extension/module:module_static", + "//xplat/executorch/extension/runner_util:inputs_static", + "//xplat/executorch/extension/tensor:tensor_static", + "//xplat/executorch/extension/threadpool:cpuinfo_utils_static", + "//xplat/executorch/extension/threadpool:threadpool_static", + ], +) + non_fbcode_target(_kind = runtime.cxx_library, name = "log_provider", srcs = ["log.cpp"], diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index e7ef6e62c74..0ce88abee5e 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -508,6 +508,12 @@ class ExecuTorchJni : public facebook::jni::HybridClass { }; } // namespace executorch::extension +#ifdef EXECUTORCH_BUILD_WHISPER_JNI +extern void register_natives_for_whisper(); +#else +void register_natives_for_whisper() {} +#endif + #ifdef EXECUTORCH_BUILD_LLAMA_JNI extern void register_natives_for_llm(); #else @@ -526,6 +532,7 @@ void register_natives_for_training() {} JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) { return facebook::jni::initialize(vm, [] { executorch::extension::ExecuTorchJni::registerNatives(); + register_natives_for_whisper(); register_natives_for_llm(); register_natives_for_runtime(); register_natives_for_training(); diff --git a/extension/android/jni/jni_layer_whisper.cpp b/extension/android/jni/jni_layer_whisper.cpp new file mode 100644 index 00000000000..f007977f13e --- /dev/null +++ b/extension/android/jni/jni_layer_whisper.cpp @@ -0,0 +1,177 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#if defined(ET_USE_THREADPOOL) +#include +#include +#endif + +#include +#include + +using ::executorch::runtime::Error; + +namespace { +bool utf8_check_validity(const char* str, size_t length) { + for (size_t i = 0; i < length; ++i) { + uint8_t byte = static_cast(str[i]); + if (byte >= 0x80) { // Non-ASCII byte + if (i + 1 >= length) { // Incomplete sequence + return false; + } + uint8_t next_byte = static_cast(str[i + 1]); + if ((byte & 0xE0) == 0xC0 && + (next_byte & 0xC0) == 0x80) { // 2-byte sequence + i += 1; + } else if ( + (byte & 0xF0) == 0xE0 && (next_byte & 0xC0) == 0x80 && + (i + 2 < length) && + (static_cast(str[i + 2]) & 0xC0) == + 0x80) { // 3-byte sequence + i += 2; + } else if ( + (byte & 0xF8) == 0xF0 && (next_byte & 0xC0) == 0x80 && + (i + 2 < length) && + (static_cast(str[i + 2]) & 0xC0) == 0x80 && + (i + 3 < length) && + (static_cast(str[i + 3]) & 0xC0) == + 0x80) { // 4-byte sequence + i += 3; + } else { + return false; // Invalid sequence + } + } + } + return true; // All bytes were valid +} + +std::string token_buffer; +} // namespace + +namespace executorch_jni { + +class ExecuTorchWhisperCallbackJni + : public facebook::jni::JavaClass { + public: + constexpr static const char* kJavaDescriptor = + "Lorg/pytorch/executorch/extension/audio/WhisperCallback;"; + + void onResult(std::string result) const { + static auto cls = ExecuTorchWhisperCallbackJni::javaClassStatic(); + static const auto method = + cls->getMethod)>("onResult"); + + token_buffer += result; + if (!utf8_check_validity(token_buffer.c_str(), token_buffer.size())) { + ET_LOG( + Info, "Current token buffer is not valid UTF-8. Waiting for more."); + return; + } + result = token_buffer; + token_buffer = ""; + facebook::jni::local_ref s = facebook::jni::make_jstring(result); + method(self(), s); + } +}; + +class ExecuTorchWhisperJni + : public facebook::jni::HybridClass { + private: + friend HybridBase; + std::unique_ptr runner_; + + public: + constexpr static auto kJavaDescriptor = + "Lorg/pytorch/executorch/extension/audio/WhisperModule;"; + + static facebook::jni::local_ref initHybrid( + facebook::jni::alias_ref, + facebook::jni::alias_ref model_path, + facebook::jni::alias_ref tokenizer_path) { + return makeCxxInstance(model_path, tokenizer_path); + } + + ExecuTorchWhisperJni( + facebook::jni::alias_ref model_path, + facebook::jni::alias_ref tokenizer_path) { +#if defined(ET_USE_THREADPOOL) + // Reserve 1 thread for the main thread. + int32_t num_performant_cores = + ::executorch::extension::cpuinfo::get_num_performant_cores() - 1; + if (num_performant_cores > 0) { + ET_LOG(Info, "Resetting threadpool to %d threads", num_performant_cores); + ::executorch::extension::threadpool::get_threadpool() + ->_unsafe_reset_threadpool(num_performant_cores); + } +#endif + + // create runner + runner_ = std::make_unique( + model_path->toStdString(), tokenizer_path->toStdString()); + } + + jint transcribe( + jint seq_len, + facebook::jni::alias_ref< + facebook::jni::JArrayClass::javaobject> inputs, + facebook::jni::alias_ref callback) { + // Convert Java byte[][] to C++ vector> + std::vector> cppData; + auto input_size = inputs->size(); + + for (jsize i = 0; i < input_size; i++) { + auto byte_array = inputs->getElement(i); + if (byte_array) { + auto array_length = byte_array->size(); + auto bytes = byte_array->getRegion(0, array_length); + std::vector charVector; + charVector.reserve(array_length); + for (jsize j = 0; j < array_length; j++) { + charVector.push_back(static_cast(bytes[j])); + } + cppData.push_back(std::move(charVector)); + } + } + + runner_->transcribe(seq_len, cppData, [callback](std::string result) { + callback->onResult(result); + }); + return 0; + } + + jint load() { + return static_cast(runner_->load()); + } + + static void registerNatives() { + registerHybrid({ + makeNativeMethod("initHybrid", ExecuTorchWhisperJni::initHybrid), + makeNativeMethod("transcribe", ExecuTorchWhisperJni::transcribe), + makeNativeMethod("load", ExecuTorchWhisperJni::load), + }); + } +}; + +} // namespace executorch_jni + +void register_natives_for_whisper() { + executorch_jni::ExecuTorchWhisperJni::registerNatives(); +} + diff --git a/scripts/build_android_library.sh b/scripts/build_android_library.sh index a50d15709bd..e9149180de6 100755 --- a/scripts/build_android_library.sh +++ b/scripts/build_android_library.sh @@ -43,6 +43,7 @@ build_android_native_library() { -DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER="${EXECUTORCH_BUILD_EXTENSION_LLM:-ON}" \ -DEXECUTORCH_BUILD_EXTENSION_TRAINING=ON \ -DEXECUTORCH_BUILD_LLAMA_JNI="${EXECUTORCH_BUILD_EXTENSION_LLM:-ON}" \ + -DEXECUTORCH_BUILD_WHISPER_JNI="${EXECUTORCH_BUILD_EXTENSION_AUDIO:-ON}" \ -DEXECUTORCH_BUILD_NEURON="${EXECUTORCH_BUILD_NEURON}" \ -DNEURON_BUFFER_ALLOCATOR_LIB="${NEURON_BUFFER_ALLOCATOR_LIB}" \ -DEXECUTORCH_BUILD_QNN="${EXECUTORCH_BUILD_QNN}" \ From 66c9463ba9b147bb857dc8670156166b29409fc3 Mon Sep 17 00:00:00 2001 From: rohansjoshi Date: Thu, 28 Aug 2025 10:22:35 -0700 Subject: [PATCH 2/4] Misc changes --- extension/android/CMakeLists.txt | 8 +++----- extension/android/jni/jni_layer_whisper.cpp | 1 - 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/extension/android/CMakeLists.txt b/extension/android/CMakeLists.txt index 80fc5fa6745..512599189da 100644 --- a/extension/android/CMakeLists.txt +++ b/extension/android/CMakeLists.txt @@ -69,14 +69,10 @@ set_target_properties( executorch_target_link_options_shared_lib(executorch) -<<<<<<< HEAD add_library( executorch_jni SHARED jni/jni_layer.cpp jni/log.cpp jni/jni_layer_runtime.cpp jni/jni_helper.cpp ) -======= -add_library(executorch_jni SHARED jni/jni_layer.cpp jni/log.cpp jni/jni_layer_runtime.cpp) ->>>>>>> 37d0a6944a (Whisper JNI first commit) set(link_libraries) list( @@ -227,7 +223,9 @@ endif() if(EXECUTORCH_BUILD_WHISPER_JNI) target_sources(executorch_jni PRIVATE jni/jni_layer_whisper.cpp jni/log.cpp) - target_compile_definitions(executorch_jni PUBLIC EXECUTORCH_BUILD_WHISPER_JNI=1) + target_compile_definitions( + executorch_jni PUBLIC EXECUTORCH_BUILD_WHISPER_JNI=1 + ) if(QNN_SDK_ROOT) target_sources( executorch_jni diff --git a/extension/android/jni/jni_layer_whisper.cpp b/extension/android/jni/jni_layer_whisper.cpp index f007977f13e..209b76076f5 100644 --- a/extension/android/jni/jni_layer_whisper.cpp +++ b/extension/android/jni/jni_layer_whisper.cpp @@ -174,4 +174,3 @@ class ExecuTorchWhisperJni void register_natives_for_whisper() { executorch_jni::ExecuTorchWhisperJni::registerNatives(); } - From 899206bb7718fa7d682dcb5075e171ab2e7395ec Mon Sep 17 00:00:00 2001 From: rohansjoshi Date: Wed, 10 Sep 2025 11:27:41 -0700 Subject: [PATCH 3/4] Changed API Whisper -> ASR --- CMakeLists.txt | 5 ++ .../oss_scripts/whisper/CMakeLists.txt | 1 + .../oss_scripts/whisper/runner/runner.h | 3 +- extension/android/CMakeLists.txt | 14 ++--- ...{WhisperCallback.java => ASRCallback.java} | 2 +- .../{WhisperModule.java => ASRModule.java} | 9 ++- .../extension/audio/package-info.java | 2 +- extension/android/jni/BUCK | 8 +-- extension/android/jni/jni_layer.cpp | 10 +-- ...ni_layer_whisper.cpp => jni_layer_asr.cpp} | 38 ++++++------ extension/llm/runner/asr_runner.h | 61 +++++++++++++++++++ scripts/build_android_library.sh | 3 +- 12 files changed, 113 insertions(+), 43 deletions(-) rename extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/{WhisperCallback.java => ASRCallback.java} (96%) rename extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/{WhisperModule.java => ASRModule.java} (88%) rename extension/android/jni/{jni_layer_whisper.cpp => jni_layer_asr.cpp} (83%) create mode 100644 extension/llm/runner/asr_runner.h diff --git a/CMakeLists.txt b/CMakeLists.txt index fc427d517a9..da71e839990 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -630,7 +630,12 @@ if(EXECUTORCH_BUILD_EXTENSION_MODULE) list(APPEND _executorch_extensions extension_module_static) endif() +if(EXECUTORCH_BUILD_EXTENSION_AUDIO) + message(STATUS "Audio/ASR extension enabled") +endif() + if(EXECUTORCH_BUILD_EXTENSION_LLM) + message(STATUS "LLM extension enabled") if(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER) set(SUPPORT_REGEX_LOOKAHEAD ON) # llama/runner/CMakeLists.txt builds a shared library libllama_runner.so diff --git a/examples/qualcomm/oss_scripts/whisper/CMakeLists.txt b/examples/qualcomm/oss_scripts/whisper/CMakeLists.txt index 8f7d0f9a9be..347d84621db 100644 --- a/examples/qualcomm/oss_scripts/whisper/CMakeLists.txt +++ b/examples/qualcomm/oss_scripts/whisper/CMakeLists.txt @@ -14,6 +14,7 @@ set(_qnn_whisper_runner__srcs ${CMAKE_CURRENT_LIST_DIR}/runner/runner.cpp ${CMAKE_CURRENT_LIST_DIR}/runner/runner.h ${EXECUTORCH_ROOT}/extension/llm/sampler/sampler.cpp + ${EXECUTORCH_ROOT}/extension/llm/runner/asr_runner.h ) # build qnn whisper runner diff --git a/examples/qualcomm/oss_scripts/whisper/runner/runner.h b/examples/qualcomm/oss_scripts/whisper/runner/runner.h index 22999f06fc2..ad80d7de2f5 100644 --- a/examples/qualcomm/oss_scripts/whisper/runner/runner.h +++ b/examples/qualcomm/oss_scripts/whisper/runner/runner.h @@ -13,6 +13,7 @@ #include #include +#include #include #include #include @@ -24,7 +25,7 @@ namespace example { -class WhisperRunner { +class WhisperRunner : public executorch::extension::llm::ASRRunner { public: explicit WhisperRunner( const std::string& model_path, diff --git a/extension/android/CMakeLists.txt b/extension/android/CMakeLists.txt index 512599189da..e287385abeb 100644 --- a/extension/android/CMakeLists.txt +++ b/extension/android/CMakeLists.txt @@ -166,10 +166,10 @@ if(EXECUTORCH_BUILD_EXTENSION_TRAINING) ) endif() -if(EXECUTORCH_BUILD_LLAMA_JNI) +if(EXECUTORCH_BUILD_EXTENSION_LLM) target_sources(executorch_jni PRIVATE jni/jni_layer_llama.cpp jni/log.cpp) list(APPEND link_libraries extension_llm_runner) - target_compile_definitions(executorch_jni PUBLIC EXECUTORCH_BUILD_LLAMA_JNI=1) + target_compile_definitions(executorch_jni PUBLIC EXECUTORCH_BUILD_EXTENSION_LLM=1) if(QNN_SDK_ROOT) target_sources( @@ -221,11 +221,10 @@ if(EXECUTORCH_BUILD_LLAMA_JNI) endif() endif() -if(EXECUTORCH_BUILD_WHISPER_JNI) - target_sources(executorch_jni PRIVATE jni/jni_layer_whisper.cpp jni/log.cpp) - target_compile_definitions( - executorch_jni PUBLIC EXECUTORCH_BUILD_WHISPER_JNI=1 - ) +if(EXECUTORCH_BUILD_EXTENSION_AUDIO) + target_sources(executorch_jni PRIVATE jni/jni_layer_asr.cpp jni/log.cpp) + target_compile_definitions(executorch_jni PUBLIC EXECUTORCH_BUILD_EXTENSION_AUDIO=1) + if(QNN_SDK_ROOT) target_sources( executorch_jni @@ -239,6 +238,7 @@ if(EXECUTORCH_BUILD_WHISPER_JNI) executorch_jni PRIVATE ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/whisper/runner ) + target_compile_definitions(executorch_jni PRIVATE EXECUTORCH_BUILD_QNN=1) endif() endif() diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/WhisperCallback.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/ASRCallback.java similarity index 96% rename from extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/WhisperCallback.java rename to extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/ASRCallback.java index a0c0a41fc3a..5ec3462eb05 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/WhisperCallback.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/ASRCallback.java @@ -18,7 +18,7 @@ *

Warning: These APIs are experimental and subject to change without notice */ @Experimental -public interface WhisperCallback { +public interface ASRCallback { /** * Called when a new result is available from JNI. Users will keep getting onResult() invocations * until generate() finishes. diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/WhisperModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/ASRModule.java similarity index 88% rename from extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/WhisperModule.java rename to extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/ASRModule.java index ad4b21cd924..36d12e0c0b0 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/WhisperModule.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/ASRModule.java @@ -14,13 +14,12 @@ import org.pytorch.executorch.annotations.Experimental; /** - * WhisperModule is a wrapper around the Executorch LLM. It provides a simple interface to generate text - * from the model. + * ASRModule is a wrapper around the Executorch ASR runners like Whisper runner. * *

Warning: These APIs are experimental and subject to change without notice */ @Experimental -public class WhisperModule { +public class ASRModule { @DoNotStrip private final HybridData mHybridData; @@ -28,7 +27,7 @@ public class WhisperModule { private static native HybridData initHybrid( String modulePath, String tokenizerPath); - public WhisperModule( + public ASRModule( String modulePath, String tokenizerPath) { ExecuTorchRuntime runtime = ExecuTorchRuntime.getRuntime(); @@ -51,7 +50,7 @@ public void resetNative() { public native int transcribe( int seqLen, byte[][] inputs, - WhisperCallback callback); + ASRCallback callback); /** Force loading the module. Otherwise the model is loaded during first generate(). */ diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/package-info.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/package-info.java index ddd1b48d80e..49f3d5ad2b3 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/package-info.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/package-info.java @@ -1,2 +1,2 @@ -/** Extension for LLM related use cases for ExecuTorch Android Java/JNI package. */ +/** Extension for ASR related use cases for ExecuTorch Android Java/JNI package. */ package org.pytorch.executorch.extension.audio; diff --git a/extension/android/jni/BUCK b/extension/android/jni/BUCK index cbb27043ebb..bb75b68932b 100644 --- a/extension/android/jni/BUCK +++ b/extension/android/jni/BUCK @@ -101,7 +101,7 @@ non_fbcode_target(_kind = fb_android_cxx_library, srcs = ["jni_layer.cpp", "jni_layer_llama.cpp", "jni_layer_runtime.cpp", "jni_helper.cpp"], allow_jni_merging = False, compiler_flags = ET_JNI_COMPILER_FLAGS + [ - "-DEXECUTORCH_BUILD_LLAMA_JNI", + "-DEXECUTORCH_BUILD_EXTENSION_LLM", ], soname = "libexecutorch.$(ext)", visibility = ["PUBLIC"], @@ -122,15 +122,15 @@ non_fbcode_target(_kind = fb_android_cxx_library, ) non_fbcode_target(_kind = fb_android_cxx_library, - name = "executorch_whisper_jni", + name = "executorch_asr_jni", srcs = [ "jni_layer.cpp", - "jni_layer_whisper.cpp", + "jni_layer_asr.cpp", "jni_layer_runtime.cpp", ], allow_jni_merging = False, compiler_flags = ET_JNI_COMPILER_FLAGS + [ - "-DEXECUTORCH_BUILD_WHISPER_JNI", + "-DEXECUTORCH_BUILD_EXTENSION_AUDIO", ], soname = "libexecutorch.$(ext)", visibility = ["PUBLIC"], diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index 0ce88abee5e..1d3a6991d5e 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -508,13 +508,13 @@ class ExecuTorchJni : public facebook::jni::HybridClass { }; } // namespace executorch::extension -#ifdef EXECUTORCH_BUILD_WHISPER_JNI -extern void register_natives_for_whisper(); +#ifdef EXECUTORCH_BUILD_EXTENSION_AUDIO +extern void register_natives_for_asr(); #else -void register_natives_for_whisper() {} +void register_natives_for_asr() {} #endif -#ifdef EXECUTORCH_BUILD_LLAMA_JNI +#ifdef EXECUTORCH_BUILD_EXTENSION_LLM extern void register_natives_for_llm(); #else // No op if we don't build LLM @@ -532,7 +532,7 @@ void register_natives_for_training() {} JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) { return facebook::jni::initialize(vm, [] { executorch::extension::ExecuTorchJni::registerNatives(); - register_natives_for_whisper(); + register_natives_for_asr(); register_natives_for_llm(); register_natives_for_runtime(); register_natives_for_training(); diff --git a/extension/android/jni/jni_layer_whisper.cpp b/extension/android/jni/jni_layer_asr.cpp similarity index 83% rename from extension/android/jni/jni_layer_whisper.cpp rename to extension/android/jni/jni_layer_asr.cpp index 209b76076f5..a2cabdf9c92 100644 --- a/extension/android/jni/jni_layer_whisper.cpp +++ b/extension/android/jni/jni_layer_asr.cpp @@ -13,7 +13,7 @@ #include #include -#include +#include #include #include #include @@ -23,6 +23,10 @@ #include #endif +#if defined(EXECUTORCH_BUILD_QNN) +#include +#endif + #include #include @@ -67,14 +71,14 @@ std::string token_buffer; namespace executorch_jni { -class ExecuTorchWhisperCallbackJni - : public facebook::jni::JavaClass { +class ExecuTorchASRCallbackJni + : public facebook::jni::JavaClass { public: constexpr static const char* kJavaDescriptor = - "Lorg/pytorch/executorch/extension/audio/WhisperCallback;"; + "Lorg/pytorch/executorch/extension/audio/ASRCallback;"; void onResult(std::string result) const { - static auto cls = ExecuTorchWhisperCallbackJni::javaClassStatic(); + static auto cls = ExecuTorchASRCallbackJni::javaClassStatic(); static const auto method = cls->getMethod)>("onResult"); @@ -91,15 +95,14 @@ class ExecuTorchWhisperCallbackJni } }; -class ExecuTorchWhisperJni - : public facebook::jni::HybridClass { +class ExecuTorchASRJni : public facebook::jni::HybridClass { private: friend HybridBase; - std::unique_ptr runner_; + std::unique_ptr<::executorch::extension::llm::ASRRunner> runner_; public: constexpr static auto kJavaDescriptor = - "Lorg/pytorch/executorch/extension/audio/WhisperModule;"; + "Lorg/pytorch/executorch/extension/audio/ASRModule;"; static facebook::jni::local_ref initHybrid( facebook::jni::alias_ref, @@ -108,7 +111,7 @@ class ExecuTorchWhisperJni return makeCxxInstance(model_path, tokenizer_path); } - ExecuTorchWhisperJni( + ExecuTorchASRJni( facebook::jni::alias_ref model_path, facebook::jni::alias_ref tokenizer_path) { #if defined(ET_USE_THREADPOOL) @@ -121,17 +124,18 @@ class ExecuTorchWhisperJni ->_unsafe_reset_threadpool(num_performant_cores); } #endif - +#if defined(EXECUTORCH_BUILD_QNN) // create runner runner_ = std::make_unique( model_path->toStdString(), tokenizer_path->toStdString()); +#endif } jint transcribe( jint seq_len, facebook::jni::alias_ref< facebook::jni::JArrayClass::javaobject> inputs, - facebook::jni::alias_ref callback) { + facebook::jni::alias_ref callback) { // Convert Java byte[][] to C++ vector> std::vector> cppData; auto input_size = inputs->size(); @@ -162,15 +166,15 @@ class ExecuTorchWhisperJni static void registerNatives() { registerHybrid({ - makeNativeMethod("initHybrid", ExecuTorchWhisperJni::initHybrid), - makeNativeMethod("transcribe", ExecuTorchWhisperJni::transcribe), - makeNativeMethod("load", ExecuTorchWhisperJni::load), + makeNativeMethod("initHybrid", ExecuTorchASRJni::initHybrid), + makeNativeMethod("transcribe", ExecuTorchASRJni::transcribe), + makeNativeMethod("load", ExecuTorchASRJni::load), }); } }; } // namespace executorch_jni -void register_natives_for_whisper() { - executorch_jni::ExecuTorchWhisperJni::registerNatives(); +void register_natives_for_asr() { + executorch_jni::ExecuTorchASRJni::registerNatives(); } diff --git a/extension/llm/runner/asr_runner.h b/extension/llm/runner/asr_runner.h new file mode 100644 index 00000000000..9a455bf513c --- /dev/null +++ b/extension/llm/runner/asr_runner.h @@ -0,0 +1,61 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Interface for audio-to-text model runners. Currently only used for +// supporting QNN Whisper Runner + +#pragma once + +#include +#include +#include +#include + +#include +#include + +namespace executorch { +namespace extension { +namespace llm { + +class ET_EXPERIMENTAL ASRRunner { + public: + virtual ~ASRRunner() = default; + + /** + * Check if the runner is loaded and ready for inference. + * + * @return true if the runner is loaded, false otherwise + */ + virtual bool is_loaded() const = 0; + + /** + * Load the model and prepare for inference. + * + * @return Error::Ok if successful, an error otherwise + */ + virtual runtime::Error load() = 0; + + /** + * Generate text from raw audio. + * + * @param seq_len Length of input sequence + * @param inputs A vector containing one element: a vector of bytes that + * encodes a float tensor in little-endian byte order + * @param token_callback Callback function called for each generated token + * @return Error::Ok if successful, an error otherwise + */ + virtual runtime::Error transcribe( + int32_t seq_len, + std::vector>& inputs, + std::function token_callback = {}) = 0; +}; + +} // namespace llm +} // namespace extension +} // namespace executorch diff --git a/scripts/build_android_library.sh b/scripts/build_android_library.sh index e9149180de6..0083dd6ff65 100755 --- a/scripts/build_android_library.sh +++ b/scripts/build_android_library.sh @@ -42,8 +42,7 @@ build_android_native_library() { -DEXECUTORCH_BUILD_EXTENSION_LLM="${EXECUTORCH_BUILD_EXTENSION_LLM:-ON}" \ -DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER="${EXECUTORCH_BUILD_EXTENSION_LLM:-ON}" \ -DEXECUTORCH_BUILD_EXTENSION_TRAINING=ON \ - -DEXECUTORCH_BUILD_LLAMA_JNI="${EXECUTORCH_BUILD_EXTENSION_LLM:-ON}" \ - -DEXECUTORCH_BUILD_WHISPER_JNI="${EXECUTORCH_BUILD_EXTENSION_AUDIO:-ON}" \ + -DEXECUTORCH_BUILD_EXTENSION_AUDIO="${EXECUTORCH_BUILD_EXTENSION_AUDIO:-ON}" \ -DEXECUTORCH_BUILD_NEURON="${EXECUTORCH_BUILD_NEURON}" \ -DNEURON_BUFFER_ALLOCATOR_LIB="${NEURON_BUFFER_ALLOCATOR_LIB}" \ -DEXECUTORCH_BUILD_QNN="${EXECUTORCH_BUILD_QNN}" \ From 6afb2213f42e1dcc2ace0be301804e95c7ffe97a Mon Sep 17 00:00:00 2001 From: rohansjoshi Date: Thu, 18 Sep 2025 15:50:45 -0700 Subject: [PATCH 4/4] Addressed feedback --- backends/qualcomm/scripts/build.sh | 2 + .../oss_scripts/whisper/CMakeLists.txt | 2 +- .../whisper/qnn_whisper_runner.cpp | 10 +- .../oss_scripts/whisper/runner/runner.cpp | 17 +- .../oss_scripts/whisper/runner/runner.h | 12 +- extension/android/CMakeLists.txt | 10 +- .../extension/audio/ASRCallback.java | 31 --- .../executorch/extension/audio/ASRModule.java | 6 +- .../extension/audio/package-info.java | 2 +- extension/android/jni/BUCK | 29 --- extension/android/jni/jni_layer_asr.cpp | 180 ------------------ extension/android/jni/jni_layer_llama.cpp | 87 +++++++++ extension/{llm => audio}/runner/asr_runner.h | 14 +- 13 files changed, 134 insertions(+), 268 deletions(-) delete mode 100644 extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/ASRCallback.java delete mode 100644 extension/android/jni/jni_layer_asr.cpp rename extension/{llm => audio}/runner/asr_runner.h (78%) diff --git a/backends/qualcomm/scripts/build.sh b/backends/qualcomm/scripts/build.sh index c84911cf851..b05ad2c38ea 100755 --- a/backends/qualcomm/scripts/build.sh +++ b/backends/qualcomm/scripts/build.sh @@ -81,6 +81,7 @@ if [ "$BUILD_AARCH64" = true ]; then -DCMAKE_BUILD_TYPE=$BUILD_TYPE \ -DEXECUTORCH_BUILD_QNN=ON \ -DEXECUTORCH_BUILD_DEVTOOLS=ON \ + -DEXECUTORCH_BUILD_EXTENSION_AUDIO=ON \ -DEXECUTORCH_BUILD_EXTENSION_LLM=ON \ -DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER=ON \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ @@ -150,6 +151,7 @@ if [ "$BUILD_X86_64" = true ]; then -DQNN_SDK_ROOT=${QNN_SDK_ROOT} \ -DEXECUTORCH_BUILD_QNN=ON \ -DEXECUTORCH_BUILD_DEVTOOLS=ON \ + -DEXECUTORCH_BUILD_EXTENSION_AUDIO=ON \ -DEXECUTORCH_BUILD_EXTENSION_LLM=ON \ -DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER=ON \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ diff --git a/examples/qualcomm/oss_scripts/whisper/CMakeLists.txt b/examples/qualcomm/oss_scripts/whisper/CMakeLists.txt index 347d84621db..137318fcce8 100644 --- a/examples/qualcomm/oss_scripts/whisper/CMakeLists.txt +++ b/examples/qualcomm/oss_scripts/whisper/CMakeLists.txt @@ -14,7 +14,7 @@ set(_qnn_whisper_runner__srcs ${CMAKE_CURRENT_LIST_DIR}/runner/runner.cpp ${CMAKE_CURRENT_LIST_DIR}/runner/runner.h ${EXECUTORCH_ROOT}/extension/llm/sampler/sampler.cpp - ${EXECUTORCH_ROOT}/extension/llm/runner/asr_runner.h + ${EXECUTORCH_ROOT}/extension/audio/runner/asr_runner.h ) # build qnn whisper runner diff --git a/examples/qualcomm/oss_scripts/whisper/qnn_whisper_runner.cpp b/examples/qualcomm/oss_scripts/whisper/qnn_whisper_runner.cpp index 4590bc5aca8..f47923e4884 100644 --- a/examples/qualcomm/oss_scripts/whisper/qnn_whisper_runner.cpp +++ b/examples/qualcomm/oss_scripts/whisper/qnn_whisper_runner.cpp @@ -14,6 +14,7 @@ */ #include +#include #include #include #include @@ -110,7 +111,14 @@ int main(int argc, char** argv) { } }; // generate tokens - runner.transcribe(FLAGS_seq_len, multi_turns_input_buffers[iter], callback); + executorch::extension::llm::Audio audio{ + std::vector( + multi_turns_input_buffers[iter][0].begin(), + multi_turns_input_buffers[iter][0].end()), + 1, + 80, + 3000}; + runner.transcribe(FLAGS_seq_len, audio, callback); auto output_file_name = FLAGS_output_folder_path + "/output_" + std::to_string(iter) + ".txt"; std::ofstream fout(output_file_name); diff --git a/examples/qualcomm/oss_scripts/whisper/runner/runner.cpp b/examples/qualcomm/oss_scripts/whisper/runner/runner.cpp index 75a81584f4f..b55c961912c 100644 --- a/examples/qualcomm/oss_scripts/whisper/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/whisper/runner/runner.cpp @@ -112,29 +112,24 @@ uint64_t WhisperRunner::logits_to_token( const executorch::aten::Tensor& logits_tensor) { return sampler_->sample(logits_tensor.data_ptr()); } -/** - * @param inputs: A vector containing one element: a vector of bytes that - * encodes a float tensor in little-endian byte order. - * - */ Error WhisperRunner::transcribe( int32_t seq_len, - std::vector>& inputs, - std::function token_callback) { + executorch::extension::llm::Audio& audio, + std::function token_callback, + std::function + stats_callback) { if (!is_loaded()) { stats_.model_load_start_ms = time_in_ms(); ET_CHECK_OK_OR_RETURN_ERROR(load()); stats_.model_load_end_ms = time_in_ms(); } - ET_CHECK_MSG(inputs.size() == 1, "The input size of whisper should be one."); - ET_LOG(Info, "Start Encoding"); stats_.encoder_inference_start_ms = time_in_ms(); auto input_features_tensor_ptr = from_blob( - inputs[0].data(), + audio.data.data(), // (1, processor.feature_extractor.feature_size, // processor.feature_extractor.nb_max_frames) - {1, 80, 3000}, + {audio.batch_size, audio.n_bins, audio.n_frames}, // {1, 80, 3000} ScalarType::Float); Result encoder_out = encoder_->encode(input_features_tensor_ptr); auto encoder_out_tensor_ptr = make_tensor_ptr(encoder_out.get()); diff --git a/examples/qualcomm/oss_scripts/whisper/runner/runner.h b/examples/qualcomm/oss_scripts/whisper/runner/runner.h index ad80d7de2f5..9f3b2046f92 100644 --- a/examples/qualcomm/oss_scripts/whisper/runner/runner.h +++ b/examples/qualcomm/oss_scripts/whisper/runner/runner.h @@ -13,7 +13,9 @@ #include #include -#include +#include +#include +#include #include #include #include @@ -25,7 +27,7 @@ namespace example { -class WhisperRunner : public executorch::extension::llm::ASRRunner { +class WhisperRunner : public executorch::extension::audio::ASRRunner { public: explicit WhisperRunner( const std::string& model_path, @@ -52,8 +54,10 @@ class WhisperRunner : public executorch::extension::llm::ASRRunner { executorch::runtime::Error load(); executorch::runtime::Error transcribe( int32_t seq_len, - std::vector>& inputs, - std::function token_callback = {}); + executorch::extension::llm::Audio& audio, + std::function token_callback = {}, + std::function + stats_callback = {}); private: executorch::runtime::Error print_performance(); diff --git a/extension/android/CMakeLists.txt b/extension/android/CMakeLists.txt index e287385abeb..9214576e288 100644 --- a/extension/android/CMakeLists.txt +++ b/extension/android/CMakeLists.txt @@ -169,7 +169,9 @@ endif() if(EXECUTORCH_BUILD_EXTENSION_LLM) target_sources(executorch_jni PRIVATE jni/jni_layer_llama.cpp jni/log.cpp) list(APPEND link_libraries extension_llm_runner) - target_compile_definitions(executorch_jni PUBLIC EXECUTORCH_BUILD_EXTENSION_LLM=1) + target_compile_definitions( + executorch_jni PUBLIC EXECUTORCH_BUILD_EXTENSION_LLM=1 + ) if(QNN_SDK_ROOT) target_sources( @@ -222,8 +224,10 @@ if(EXECUTORCH_BUILD_EXTENSION_LLM) endif() if(EXECUTORCH_BUILD_EXTENSION_AUDIO) - target_sources(executorch_jni PRIVATE jni/jni_layer_asr.cpp jni/log.cpp) - target_compile_definitions(executorch_jni PUBLIC EXECUTORCH_BUILD_EXTENSION_AUDIO=1) + target_sources(executorch_jni PRIVATE jni/jni_layer_llama.cpp jni/log.cpp) + target_compile_definitions( + executorch_jni PUBLIC EXECUTORCH_BUILD_EXTENSION_AUDIO=1 + ) if(QNN_SDK_ROOT) target_sources( diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/ASRCallback.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/ASRCallback.java deleted file mode 100644 index 5ec3462eb05..00000000000 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/ASRCallback.java +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.executorch.extension.audio; - -import com.facebook.jni.annotations.DoNotStrip; -import org.pytorch.executorch.annotations.Experimental; - -/** - * Callback interface for Whisper model. Users can implement this interface to receive the generated - * tokens and statistics. - * - *

Warning: These APIs are experimental and subject to change without notice - */ -@Experimental -public interface ASRCallback { - /** - * Called when a new result is available from JNI. Users will keep getting onResult() invocations - * until generate() finishes. - * - * @param result Last generated token - */ - @DoNotStrip - public void onResult(String result); - -} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/ASRModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/ASRModule.java index 36d12e0c0b0..862f39891bf 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/ASRModule.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/ASRModule.java @@ -11,6 +11,7 @@ import com.facebook.jni.annotations.DoNotStrip; import java.io.File; import org.pytorch.executorch.ExecuTorchRuntime; +import org.pytorch.executorch.extension.llm.LlmCallback; import org.pytorch.executorch.annotations.Experimental; /** @@ -50,8 +51,9 @@ public void resetNative() { public native int transcribe( int seqLen, byte[][] inputs, - ASRCallback callback); - + LlmCallback callback, + int n_bins, + int n_frames); /** Force loading the module. Otherwise the model is loaded during first generate(). */ @DoNotStrip diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/package-info.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/package-info.java index 49f3d5ad2b3..8c7e2d27dd5 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/package-info.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/package-info.java @@ -1,2 +1,2 @@ -/** Extension for ASR related use cases for ExecuTorch Android Java/JNI package. */ +/** Extension for audio and ASR related use cases for ExecuTorch Android Java/JNI package. */ package org.pytorch.executorch.extension.audio; diff --git a/extension/android/jni/BUCK b/extension/android/jni/BUCK index bb75b68932b..f79e0ccb98a 100644 --- a/extension/android/jni/BUCK +++ b/extension/android/jni/BUCK @@ -121,35 +121,6 @@ non_fbcode_target(_kind = fb_android_cxx_library, ], ) -non_fbcode_target(_kind = fb_android_cxx_library, - name = "executorch_asr_jni", - srcs = [ - "jni_layer.cpp", - "jni_layer_asr.cpp", - "jni_layer_runtime.cpp", - ], - allow_jni_merging = False, - compiler_flags = ET_JNI_COMPILER_FLAGS + [ - "-DEXECUTORCH_BUILD_EXTENSION_AUDIO", - ], - soname = "libexecutorch.$(ext)", - visibility = ["PUBLIC"], - deps = [ - ":jni_headers", - ":log_provider_static", - "//fbandroid/libraries/fbjni:fbjni", - "//fbandroid/native/fb:fb", - "//third-party/glog:glog", - "//xplat/executorch/backends/xnnpack:xnnpack_backend_static", - "//xplat/executorch/examples/oss_scripts/qualcomm/whisper/runner:runner_static", - "//xplat/executorch/extension/module:module_static", - "//xplat/executorch/extension/runner_util:inputs_static", - "//xplat/executorch/extension/tensor:tensor_static", - "//xplat/executorch/extension/threadpool:cpuinfo_utils_static", - "//xplat/executorch/extension/threadpool:threadpool_static", - ], -) - non_fbcode_target(_kind = runtime.cxx_library, name = "log_provider", srcs = ["log.cpp"], diff --git a/extension/android/jni/jni_layer_asr.cpp b/extension/android/jni/jni_layer_asr.cpp deleted file mode 100644 index a2cabdf9c92..00000000000 --- a/extension/android/jni/jni_layer_asr.cpp +++ /dev/null @@ -1,180 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -#if defined(ET_USE_THREADPOOL) -#include -#include -#endif - -#if defined(EXECUTORCH_BUILD_QNN) -#include -#endif - -#include -#include - -using ::executorch::runtime::Error; - -namespace { -bool utf8_check_validity(const char* str, size_t length) { - for (size_t i = 0; i < length; ++i) { - uint8_t byte = static_cast(str[i]); - if (byte >= 0x80) { // Non-ASCII byte - if (i + 1 >= length) { // Incomplete sequence - return false; - } - uint8_t next_byte = static_cast(str[i + 1]); - if ((byte & 0xE0) == 0xC0 && - (next_byte & 0xC0) == 0x80) { // 2-byte sequence - i += 1; - } else if ( - (byte & 0xF0) == 0xE0 && (next_byte & 0xC0) == 0x80 && - (i + 2 < length) && - (static_cast(str[i + 2]) & 0xC0) == - 0x80) { // 3-byte sequence - i += 2; - } else if ( - (byte & 0xF8) == 0xF0 && (next_byte & 0xC0) == 0x80 && - (i + 2 < length) && - (static_cast(str[i + 2]) & 0xC0) == 0x80 && - (i + 3 < length) && - (static_cast(str[i + 3]) & 0xC0) == - 0x80) { // 4-byte sequence - i += 3; - } else { - return false; // Invalid sequence - } - } - } - return true; // All bytes were valid -} - -std::string token_buffer; -} // namespace - -namespace executorch_jni { - -class ExecuTorchASRCallbackJni - : public facebook::jni::JavaClass { - public: - constexpr static const char* kJavaDescriptor = - "Lorg/pytorch/executorch/extension/audio/ASRCallback;"; - - void onResult(std::string result) const { - static auto cls = ExecuTorchASRCallbackJni::javaClassStatic(); - static const auto method = - cls->getMethod)>("onResult"); - - token_buffer += result; - if (!utf8_check_validity(token_buffer.c_str(), token_buffer.size())) { - ET_LOG( - Info, "Current token buffer is not valid UTF-8. Waiting for more."); - return; - } - result = token_buffer; - token_buffer = ""; - facebook::jni::local_ref s = facebook::jni::make_jstring(result); - method(self(), s); - } -}; - -class ExecuTorchASRJni : public facebook::jni::HybridClass { - private: - friend HybridBase; - std::unique_ptr<::executorch::extension::llm::ASRRunner> runner_; - - public: - constexpr static auto kJavaDescriptor = - "Lorg/pytorch/executorch/extension/audio/ASRModule;"; - - static facebook::jni::local_ref initHybrid( - facebook::jni::alias_ref, - facebook::jni::alias_ref model_path, - facebook::jni::alias_ref tokenizer_path) { - return makeCxxInstance(model_path, tokenizer_path); - } - - ExecuTorchASRJni( - facebook::jni::alias_ref model_path, - facebook::jni::alias_ref tokenizer_path) { -#if defined(ET_USE_THREADPOOL) - // Reserve 1 thread for the main thread. - int32_t num_performant_cores = - ::executorch::extension::cpuinfo::get_num_performant_cores() - 1; - if (num_performant_cores > 0) { - ET_LOG(Info, "Resetting threadpool to %d threads", num_performant_cores); - ::executorch::extension::threadpool::get_threadpool() - ->_unsafe_reset_threadpool(num_performant_cores); - } -#endif -#if defined(EXECUTORCH_BUILD_QNN) - // create runner - runner_ = std::make_unique( - model_path->toStdString(), tokenizer_path->toStdString()); -#endif - } - - jint transcribe( - jint seq_len, - facebook::jni::alias_ref< - facebook::jni::JArrayClass::javaobject> inputs, - facebook::jni::alias_ref callback) { - // Convert Java byte[][] to C++ vector> - std::vector> cppData; - auto input_size = inputs->size(); - - for (jsize i = 0; i < input_size; i++) { - auto byte_array = inputs->getElement(i); - if (byte_array) { - auto array_length = byte_array->size(); - auto bytes = byte_array->getRegion(0, array_length); - std::vector charVector; - charVector.reserve(array_length); - for (jsize j = 0; j < array_length; j++) { - charVector.push_back(static_cast(bytes[j])); - } - cppData.push_back(std::move(charVector)); - } - } - - runner_->transcribe(seq_len, cppData, [callback](std::string result) { - callback->onResult(result); - }); - return 0; - } - - jint load() { - return static_cast(runner_->load()); - } - - static void registerNatives() { - registerHybrid({ - makeNativeMethod("initHybrid", ExecuTorchASRJni::initHybrid), - makeNativeMethod("transcribe", ExecuTorchASRJni::transcribe), - makeNativeMethod("load", ExecuTorchASRJni::load), - }); - } -}; - -} // namespace executorch_jni - -void register_natives_for_asr() { - executorch_jni::ExecuTorchASRJni::registerNatives(); -} diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index cabf30c42e4..ead8dca9774 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -13,6 +13,8 @@ #include #include +#include +#include #include #include #include @@ -33,6 +35,7 @@ #if defined(EXECUTORCH_BUILD_QNN) #include +#include #endif #if defined(EXECUTORCH_BUILD_MEDIATEK) @@ -317,8 +320,92 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { } }; +class ExecuTorchASRJni : public facebook::jni::HybridClass { + private: + friend HybridBase; + std::unique_ptr<::executorch::extension::audio::ASRRunner> runner_; + + public: + constexpr static auto kJavaDescriptor = + "Lorg/pytorch/executorch/extension/audio/ASRModule;"; + + static facebook::jni::local_ref initHybrid( + facebook::jni::alias_ref, + facebook::jni::alias_ref model_path, + facebook::jni::alias_ref tokenizer_path) { + return makeCxxInstance(model_path, tokenizer_path); + } + + ExecuTorchASRJni( + facebook::jni::alias_ref model_path, + facebook::jni::alias_ref tokenizer_path) { +#if defined(ET_USE_THREADPOOL) + // Reserve 1 thread for the main thread. + int32_t num_performant_cores = + ::executorch::extension::cpuinfo::get_num_performant_cores() - 1; + if (num_performant_cores > 0) { + ET_LOG(Info, "Resetting threadpool to %d threads", num_performant_cores); + ::executorch::extension::threadpool::get_threadpool() + ->_unsafe_reset_threadpool(num_performant_cores); + } +#endif +#if defined(EXECUTORCH_BUILD_QNN) + // create runner + runner_ = std::make_unique( + model_path->toStdString(), tokenizer_path->toStdString()); +#endif + } + + jint transcribe( + jint seq_len, + facebook::jni::alias_ref< + facebook::jni::JArrayClass::javaobject> inputs, + facebook::jni::alias_ref callback, + jint n_bins = 80, // whisper defaults + jint n_frames = 3000) { + // Convert Java byte[][] to C++ vector> + std::vector> cppData; + auto input_size = inputs->size(); + cppData.reserve(input_size); + // TODO: add support for larger batch sizes + for (jsize i = 0; i < input_size; i++) { + auto byte_array = inputs->getElement(i); + if (byte_array) { + auto array_length = byte_array->size(); + auto bytes = byte_array->getRegion(0, array_length); + std::vector uint8Vector; + uint8Vector.reserve(array_length); + for (jsize j = 0; j < array_length; j++) { + uint8Vector.push_back(static_cast(bytes[j])); + } + cppData.push_back(std::move(uint8Vector)); + } + } + executorch::extension::llm::Audio audio{cppData[0], 1, n_bins, n_frames}; + runner_->transcribe(seq_len, audio, [callback](std::string result) { + callback->onResult(result); + }); + return 0; + } + + jint load() { + return static_cast(runner_->load()); + } + + static void registerNatives() { + registerHybrid({ + makeNativeMethod("initHybrid", ExecuTorchASRJni::initHybrid), + makeNativeMethod("transcribe", ExecuTorchASRJni::transcribe), + makeNativeMethod("load", ExecuTorchASRJni::load), + }); + } +}; + } // namespace executorch_jni void register_natives_for_llm() { executorch_jni::ExecuTorchLlmJni::registerNatives(); } +void register_natives_for_asr() { + executorch_jni::ExecuTorchASRJni::registerNatives(); +} diff --git a/extension/llm/runner/asr_runner.h b/extension/audio/runner/asr_runner.h similarity index 78% rename from extension/llm/runner/asr_runner.h rename to extension/audio/runner/asr_runner.h index 9a455bf513c..1178b999f9a 100644 --- a/extension/llm/runner/asr_runner.h +++ b/extension/audio/runner/asr_runner.h @@ -16,12 +16,13 @@ #include #include +#include #include #include namespace executorch { namespace extension { -namespace llm { +namespace audio { class ET_EXPERIMENTAL ASRRunner { public: @@ -45,17 +46,20 @@ class ET_EXPERIMENTAL ASRRunner { * Generate text from raw audio. * * @param seq_len Length of input sequence - * @param inputs A vector containing one element: a vector of bytes that + * @param audio processed audio input, which contains a vector of bytes that * encodes a float tensor in little-endian byte order * @param token_callback Callback function called for each generated token + * @param stats_callback Callback function for generation statistics * @return Error::Ok if successful, an error otherwise */ virtual runtime::Error transcribe( int32_t seq_len, - std::vector>& inputs, - std::function token_callback = {}) = 0; + ::executorch::extension::llm::Audio& audio, + std::function token_callback = {}, + std::function + stats_callback = {}) = 0; }; -} // namespace llm +} // namespace audio } // namespace extension } // namespace executorch