diff --git a/build/build_android_llm_demo.sh b/build/build_android_llm_demo.sh
index 4d34eb95b23..7b7150de210 100644
--- a/build/build_android_llm_demo.sh
+++ b/build/build_android_llm_demo.sh
@@ -30,6 +30,7 @@ build_android_native_library() {
     -DEXECUTORCH_XNNPACK_SHARED_WORKSPACE=ON \
     -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
     -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
+    -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \
     -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
     -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
     -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
diff --git a/extension/android/CMakeLists.txt b/extension/android/CMakeLists.txt
index daa9c7c2496..6827ae79040 100644
--- a/extension/android/CMakeLists.txt
+++ b/extension/android/CMakeLists.txt
@@ -32,8 +32,15 @@ find_package(executorch CONFIG REQUIRED)
 target_link_options_shared_lib(executorch)
 
 set(link_libraries)
-list(APPEND link_libraries extension_data_loader extension_module extension_threadpool executorch
-  fbjni
+list(
+  APPEND
+  link_libraries
+  executorch
+  extension_data_loader
+  extension_module
+  extension_runner_util
+  extension_threadpool
+  fbjni
 )
 
 if(TARGET optimized_native_cpu_ops_lib)
diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp
index c70912a5451..79c6ebc5161 100644
--- a/extension/android/jni/jni_layer.cpp
+++ b/extension/android/jni/jni_layer.cpp
@@ -18,6 +18,7 @@
 #include "jni_layer_constants.h"
 
 #include
+#include
 #include
 #include
 #include
@@ -56,7 +57,7 @@ void et_pal_emit_log_message(
 
 using namespace torch::executor;
 
-namespace executorch_jni {
+namespace executorch::extension {
 class TensorHybrid : public facebook::jni::HybridClass {
  public:
   constexpr static const char* kJavaDescriptor =
@@ -352,19 +353,37 @@ class ExecuTorchJni : public facebook::jni::HybridClass {
     return jresult;
   }
 
+  // Runs "forward" with inputs produced by prepare_input_tensors() (described
+  // as "ones" on the Java side). Returns the first non-Ok error code hit while
+  // loading the method or preparing inputs; otherwise the code from execute().
+  jint forward_ones() {
+    const auto load_error = module_->load_method("forward");
+    if (load_error != Error::Ok) {
+      // methods_["forward"].method must not be dereferenced if loading failed.
+      return static_cast<jint>(load_error);
+    }
+    auto& method = *(module_->methods_["forward"].method);
+    // Keep the prepared buffers alive until execute() has returned.
+    auto inputs = prepare_input_tensors(method);
+    if (!inputs.ok()) {
+      return static_cast<jint>(inputs.error());
+    }
+    return static_cast<jint>(method.execute());
+  }
+
   static void registerNatives() {
     registerHybrid({
         makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid),
         makeNativeMethod("forward", ExecuTorchJni::forward),
         makeNativeMethod("execute", ExecuTorchJni::execute),
         makeNativeMethod("loadMethod", ExecuTorchJni::load_method),
+        makeNativeMethod("forwardOnes", ExecuTorchJni::forward_ones),
     });
   }
 };
-
-} // namespace executorch_jni
+} // namespace executorch::extension
 
 JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
   return facebook::jni::initialize(
-      vm, [] { executorch_jni::ExecuTorchJni::registerNatives(); });
+      vm, [] { executorch::extension::ExecuTorchJni::registerNatives(); });
 }
diff --git a/extension/android/jni/jni_layer_constants.h b/extension/android/jni/jni_layer_constants.h
index ac52b3a650d..43946ffab6e 100644
--- a/extension/android/jni/jni_layer_constants.h
+++ b/extension/android/jni/jni_layer_constants.h
@@ -10,7 +10,7 @@
 
 #include
 
-namespace executorch_jni {
+namespace executorch::extension {
 
 constexpr static int kTensorDTypeUInt8 = 0;
 constexpr static int kTensorDTypeInt8 = 1;
@@ -93,4 +93,4 @@ const std::unordered_map java_dtype_to_scalar_type = {
     {kTensorDTypeBits16, ScalarType::Bits16},
 };
 
-} // namespace executorch_jni
+} // namespace executorch::extension
diff --git a/extension/android/src/main/java/org/pytorch/executorch/Module.java b/extension/android/src/main/java/org/pytorch/executorch/Module.java
index 5e57174114d..981cfcd8c62 100644
--- a/extension/android/src/main/java/org/pytorch/executorch/Module.java
+++ b/extension/android/src/main/java/org/pytorch/executorch/Module.java
@@ -79,6 +79,12 @@ public static Module load(final String modelPath) {
    * @return return value from the 'forward' method.
    */
   public EValue[] forward(EValue... inputs) {
+    if (inputs.length == 0) {
+      // forward default args (ones)
+      mNativePeer.forwardOnes();
+      // discard the return value
+      return null;
+    }
     return mNativePeer.forward(inputs);
   }
 
diff --git a/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java b/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java
index 865c503765d..6eadbf05097 100644
--- a/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java
+++ b/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java
@@ -13,6 +13,7 @@
 import com.facebook.soloader.nativeloader.NativeLoader;
 import java.util.Map;
 
+/** Interface for the native peer object for entry points to the Module */
 class NativePeer {
   static {
     // Loads libexecutorch.so from jniLibs
@@ -29,16 +30,33 @@ private static native HybridData initHybrid(
     mHybridData = initHybrid(moduleAbsolutePath, extraFiles, loadMode);
   }
 
+  /** Clean up the native resources associated with this instance */
   public void resetNative() {
     mHybridData.resetNative();
   }
 
+  /** Run a "forward" call with the given inputs */
   @DoNotStrip
   public native EValue[] forward(EValue... inputs);
 
+  /**
+   * Run a "forward" call with the sample inputs (ones) to test a module
+   *
+   * @return the native error code of the call as an int (not the model outputs)
+   * @apiNote This is experimental and test-only API
+   */
+  @DoNotStrip
+  public native int forwardOnes();
+
+  /** Run an arbitrary method on the module */
   @DoNotStrip
   public native EValue[] execute(String methodName, EValue... inputs);
 
+  /**
+   * Load a method on this module.
+   *
+   * @return the Error code if there was an error loading the method
+   */
   @DoNotStrip
   public native int loadMethod(String methodName);
 }
diff --git a/extension/module/module.h b/extension/module/module.h
index 689fef5cd29..8ae7e436556 100644
--- a/extension/module/module.h
+++ b/extension/module/module.h
@@ -358,6 +358,8 @@ class Module final {
   std::unique_ptr<::executorch::runtime::MemoryAllocator> temp_allocator_;
   std::unique_ptr<::executorch::runtime::EventTracer> event_tracer_;
   std::unordered_map methods_;
+
+  friend class ExecuTorchJni;
 };
 
 } // namespace extension