From 7271dff534528050f9db26c527378bca7f9d6a31 Mon Sep 17 00:00:00 2001 From: kirklandsign Date: Sat, 31 Aug 2024 00:25:05 -0700 Subject: [PATCH 01/20] try to merge jni --- build/build_android_llm_demo.sh | 5 +-- extension/android/CMakeLists.txt | 10 +----- extension/android/jni/jni_layer.cpp | 6 +++- extension/android/jni/jni_layer_llama.cpp | 32 ++----------------- .../org/pytorch/executorch/NativePeer.java | 2 +- 5 files changed, 13 insertions(+), 42 deletions(-) diff --git a/build/build_android_llm_demo.sh b/build/build_android_llm_demo.sh index 7b7150de210..83db918fbbf 100644 --- a/build/build_android_llm_demo.sh +++ b/build/build_android_llm_demo.sh @@ -87,13 +87,14 @@ build_aar() { find jni -type f -name "libexecutorch_jni.so" -exec bash -c 'mv "$1" "${1/_jni/}"' bash {} \; # Zip all necessary files into the AAR file zip -r executorch.aar libs jni/*/libexecutorch.so AndroidManifest.xml - zip -r executorch-llama.aar libs jni/*/libexecutorch_llama_jni.so AndroidManifest.xml + zip -r executorch-llama.aar libs jni/*/libexecutorch_llama_jni.so jni/*/libexecutorch.so AndroidManifest.xml popd } build_android_llm_demo_app() { mkdir -p examples/demo-apps/android/LlamaDemo/app/libs cp ${BUILD_AAR_DIR}/executorch-llama.aar examples/demo-apps/android/LlamaDemo/app/libs + cp ${BUILD_AAR_DIR}/executorch-llama.aar extension/android/benchmark/app/libs/executorch.aar pushd examples/demo-apps/android/LlamaDemo ANDROID_HOME="${ANDROID_SDK:-/opt/android/sdk}" ./gradlew build assembleAndroidTest popd @@ -120,7 +121,7 @@ collect_artifacts_to_be_uploaded() { BUILD_AAR_DIR="$(mktemp -d)" export BUILD_AAR_DIR -ANDROID_ABIS=("arm64-v8a" "x86_64") +ANDROID_ABIS=("x86_64") export ANDROID_ABIS ARTIFACTS_DIR_NAME="$1" diff --git a/extension/android/CMakeLists.txt b/extension/android/CMakeLists.txt index 6827ae79040..914b0089294 100644 --- a/extension/android/CMakeLists.txt +++ b/extension/android/CMakeLists.txt @@ -70,14 +70,7 @@ if(TARGET vulkan_backend) list(APPEND link_libraries vulkan_backend) endif() -add_library(executorch_jni SHARED jni/jni_layer.cpp) -target_link_libraries(executorch_jni ${link_libraries}) -target_include_directories( - executorch_jni PRIVATE ${_common_include_directories} -) -target_compile_options(executorch_jni PUBLIC ${_common_compile_options}) -if(EXECUTORCH_BUILD_LLAMA_JNI) set(LLAMA_RUNNER_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../examples/models/llama2/runner/libllama_runner.a ) @@ -100,7 +93,7 @@ if(EXECUTORCH_BUILD_LLAMA_JNI) target_link_options_shared_lib(quantized_ops_lib) - set(LLAMA_JNI_SRCS jni/jni_layer_llama.cpp) + set(LLAMA_JNI_SRCS jni/jni_layer_llama.cpp jni/jni_layer.cpp) add_library(executorch_llama_jni SHARED ${LLAMA_JNI_SRCS}) if(TARGET pthreadpool) target_compile_definitions(executorch_llama_jni PRIVATE ET_USE_THREADPOOL=1) @@ -144,4 +137,3 @@ if(EXECUTORCH_BUILD_LLAMA_JNI) ) set(CMAKE_POSITION_INDEPENDENT_CODE ${_pic_flag}) target_link_libraries(executorch_llama_jni re2::re2) -endif() diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index 79c6ebc5161..d93c31b0c02 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -372,7 +372,11 @@ class ExecuTorchJni : public facebook::jni::HybridClass { }; } // namespace executorch::extension +extern void register_natives_jni(); + JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) { return facebook::jni::initialize( - vm, [] { executorch::extension::ExecuTorchJni::registerNatives(); }); + vm, [] { executorch::extension::ExecuTorchJni::registerNatives(); + register_natives_jni(); + }); } diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index 4f67d04396c..2a082181f75 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -30,32 +30,6 @@ #include #include -#ifdef __ANDROID__ -#include - -// For Android, write to logcat -void et_pal_emit_log_message( - et_timestamp_t timestamp, - et_pal_log_level_t level, - const char* filename, - const char* function, - size_t line, - const char* message, - size_t length) { - int android_log_level = ANDROID_LOG_UNKNOWN; - if (level == 'D') { - android_log_level = ANDROID_LOG_DEBUG; - } else if (level == 'I') { - android_log_level = ANDROID_LOG_INFO; - } else if (level == 'E') { - android_log_level = ANDROID_LOG_ERROR; - } else if (level == 'F') { - android_log_level = ANDROID_LOG_FATAL; - } - - __android_log_print(android_log_level, "LLAMA", "%s", message); -} -#endif using namespace torch::executor; @@ -209,7 +183,7 @@ class ExecuTorchLlamaJni } // namespace executorch_jni -JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) { - return facebook::jni::initialize( - vm, [] { executorch_jni::ExecuTorchLlamaJni::registerNatives(); }); +void register_natives_jni() { + executorch_jni::ExecuTorchLlamaJni::registerNatives(); } + diff --git a/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java b/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java index 6eadbf05097..caa005493a7 100644 --- a/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java +++ b/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java @@ -17,7 +17,7 @@ class NativePeer { static { // Loads libexecutorch.so from jniLibs - NativeLoader.loadLibrary("executorch"); + NativeLoader.loadLibrary("executorch_llama_jni"); } private final HybridData mHybridData; From c81a8f16f45d8ed1241617da65d23e665baa1089 Mon Sep 17 00:00:00 2001 From: kirklandsign Date: Sat, 31 Aug 2024 00:32:52 -0700 Subject: [PATCH 02/20] new activity! --- .../app/src/main/AndroidManifest.xml | 8 ++ .../minibench/LlmBenchmarkActivity.java | 99 +++++++++++++++++++ .../org/pytorch/minibench/ModelRunner.java | 97 ++++++++++++++++++ .../minibench/ModelRunnerCallback.java | 24 +++++ 4 files changed, 228 insertions(+) create mode 100644 extension/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java create mode 100644 extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java create mode 100644 extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java diff --git a/extension/android/benchmark/app/src/main/AndroidManifest.xml b/extension/android/benchmark/app/src/main/AndroidManifest.xml index 49711b6830e..098905c052c 100644 --- a/extension/android/benchmark/app/src/main/AndroidManifest.xml +++ b/extension/android/benchmark/app/src/main/AndroidManifest.xml @@ -16,6 +16,14 @@ + + + + + + diff --git a/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java new file mode 100644 index 00000000000..14ee6334094 --- /dev/null +++ b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java @@ -0,0 +1,99 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.minibench; + + +import android.app.Activity; +import android.content.Intent; +import android.os.Bundle; +import android.util.Log; +import android.widget.TextView; +import java.io.FileWriter; +import java.io.IOException; + +public class LlmBenchmarkActivity extends Activity implements ModelRunnerCallback { + ModelRunner mModelRunner; + + String mPrompt; + StatsDump mStatsDump; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + + Intent intent = getIntent(); + + String modelPath = intent.getStringExtra("model_path"); + String tokenizerPath = intent.getStringExtra("tokenizer_path"); + + float temperature = intent.getFloatExtra("temperature", 0.8f); + mPrompt = intent.getStringExtra("prompt"); + if (mPrompt == null) { + mPrompt = "The ultimate answer"; + } + + mStatsDump = new StatsDump(); + mModelRunner = new ModelRunner(modelPath, tokenizerPath, temperature, this); + mStatsDump.loadStart = System.currentTimeMillis(); + } + + @Override + public void onModelLoaded(int status) { + mStatsDump.loadEnd = System.currentTimeMillis(); + if (status != 0) { + Log.e("LlmBenchmarkRunner", "Loaded failed: " + status); + onGenerationStopped(); + return; + } + mStatsDump.generateStart = System.currentTimeMillis(); + mModelRunner.generate(mPrompt); + } + + @Override + public void onTokenGenerated(String token) { + } + + @Override + public void onStats(String stats) { + mStatsDump.tokens = stats; + } + + @Override + public void onGenerationStopped() { + mStatsDump.generateEnd = System.currentTimeMillis(); + + try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.txt")) { + writer.write(mStatsDump.toString()); + } catch (IOException e) { + e.printStackTrace(); + } + } +} + +class StatsDump { + long loadStart; + long loadEnd; + long generateStart; + long generateEnd; + String tokens; + + @Override + public String toString() { + return "loadStart: " + + loadStart + + "\nloadEnd: " + + loadEnd + + "\ngenerateStart: " + + generateStart + + "\ngenerateEnd: " + + generateEnd + + "\n" + + tokens; + } +} \ No newline at end of file diff --git a/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java new file mode 100644 index 00000000000..c435dafde65 --- /dev/null +++ b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java @@ -0,0 +1,97 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.minibench; + +import android.os.Handler; +import android.os.HandlerThread; +import android.os.Looper; +import android.os.Message; +import org.pytorch.executorch.LlamaCallback; +import org.pytorch.executorch.LlamaModule; + +/** A helper class to handle all model running logic within this class. */ +public class ModelRunner implements LlamaCallback { + LlamaModule mModule = null; + + String mModelFilePath = ""; + String mTokenizerFilePath = ""; + + ModelRunnerCallback mCallback = null; + + HandlerThread mHandlerThread = null; + Handler mHandler = null; + + /** + * ] Helper class to separate between UI logic and model runner logic. Automatically handle + * generate() request on worker thread. + * + * @param modelFilePath + * @param tokenizerFilePath + * @param callback + */ + ModelRunner( + String modelFilePath, + String tokenizerFilePath, + float temperature, + ModelRunnerCallback callback) { + mModelFilePath = modelFilePath; + mTokenizerFilePath = tokenizerFilePath; + mCallback = callback; + + mModule = new LlamaModule(mModelFilePath, mTokenizerFilePath, 0.8f); + mHandlerThread = new HandlerThread("ModelRunner"); + mHandlerThread.start(); + mHandler = new ModelRunnerHandler(mHandlerThread.getLooper(), this); + + mHandler.sendEmptyMessage(ModelRunnerHandler.MESSAGE_LOAD_MODEL); + } + + int generate(String prompt) { + Message msg = Message.obtain(mHandler, ModelRunnerHandler.MESSAGE_GENERATE, prompt); + msg.sendToTarget(); + return 0; + } + + void stop() { + mModule.stop(); + } + + @Override + public void onResult(String result) { + mCallback.onTokenGenerated(result); + } + + @Override + public void onStats(float tps) { + mCallback.onStats("tokens/second: " + tps); + } +} + +class ModelRunnerHandler extends Handler { + public static int MESSAGE_LOAD_MODEL = 1; + public static int MESSAGE_GENERATE = 2; + + private final ModelRunner mModelRunner; + + public ModelRunnerHandler(Looper looper, ModelRunner modelRunner) { + super(looper); + mModelRunner = modelRunner; + } + + @Override + public void handleMessage(Message msg) { + if (msg.what == MESSAGE_LOAD_MODEL) { + int status = mModelRunner.mModule.load(); + mModelRunner.mCallback.onModelLoaded(status); + } else if (msg.what == MESSAGE_GENERATE) { + mModelRunner.mModule.generate((String) msg.obj, mModelRunner); + mModelRunner.mCallback.onGenerationStopped(); + } + } +} \ No newline at end of file diff --git a/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java new file mode 100644 index 00000000000..0435be6875c --- /dev/null +++ b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java @@ -0,0 +1,24 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.minibench; + +/** + * A helper interface within the app for MainActivity and Benchmarking to handle callback from + * ModelRunner. + */ +public interface ModelRunnerCallback { + + void onModelLoaded(int status); + + void onTokenGenerated(String token); + + void onStats(String token); + + void onGenerationStopped(); +} \ No newline at end of file From b406fbbc992a62876613a081cdeb8ce758b59743 Mon Sep 17 00:00:00 2001 From: kirklandsign Date: Sat, 31 Aug 2024 15:51:33 -0700 Subject: [PATCH 03/20] remove unused --- .../main/java/org/pytorch/minibench/LlmBenchmarkActivity.java | 1 - 1 file changed, 1 deletion(-) diff --git a/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java index 14ee6334094..46f04f9eea5 100644 --- a/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java +++ b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java @@ -13,7 +13,6 @@ import android.content.Intent; import android.os.Bundle; import android.util.Log; -import android.widget.TextView; import java.io.FileWriter; import java.io.IOException; From 11b3d6abd90229a7186462273f18cbd463c6514d Mon Sep 17 00:00:00 2001 From: kirklandsign Date: Sat, 31 Aug 2024 16:27:25 -0700 Subject: [PATCH 04/20] Remove API forwardOnes Instead, in Java layer, add a note that if no args is given to forward, we will infer that the user wants to try with sample value --- .../minibench/LlmBenchmarkActivity.java | 2 +- extension/android/jni/jni_layer.cpp | 26 +++++++++++++------ .../java/org/pytorch/executorch/Module.java | 8 ++---- .../org/pytorch/executorch/NativePeer.java | 9 ------- 4 files changed, 21 insertions(+), 24 deletions(-) diff --git a/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java index 46f04f9eea5..5cfda5971e5 100644 --- a/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java +++ b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java @@ -51,7 +51,7 @@ public void onModelLoaded(int status) { return; } mStatsDump.generateStart = System.currentTimeMillis(); - mModelRunner.generate(mPrompt); + int generateStatus = mModelRunner.generate(mPrompt); } @Override diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index d93c31b0c02..8b6dde7873f 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -293,6 +293,24 @@ class ExecuTorchJni : public facebook::jni::HybridClass { facebook::jni::alias_ref< facebook::jni::JArrayClass::javaobject> jinputs) { + + // If no inputs is given, it will run with sample inputs (ones) + if (jinputs->size() == 0) { + auto&& underlying_method = module_->methods_[method].method; + auto&& buf = prepare_input_tensors(*underlying_method); + auto result = underlying_method->execute(); + if (result != Error::Ok) { + return {}; + } + facebook::jni::local_ref> jresult = + facebook::jni::JArrayClass::newArray(underlying_method->outputs_size()); + + for (int i = 0; i < underlying_method->outputs_size(); i++) { + auto jevalue = JEValue::newJEValueFromEValue(underlying_method->get_output(i)); + jresult->setElement(i, *jevalue); + } + return jresult; + } std::vector evalues = {}; std::vector managed_tensors = {}; @@ -353,20 +371,12 @@ class ExecuTorchJni : public facebook::jni::HybridClass { return jresult; } - jint forward_ones() { - auto&& load_result = module_->load_method("forward"); - auto&& buf = prepare_input_tensors(*(module_->methods_["forward"].method)); - auto&& result = module_->methods_["forward"].method->execute(); - return (jint)result; - } - static void registerNatives() { registerHybrid({ makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid), makeNativeMethod("forward", ExecuTorchJni::forward), makeNativeMethod("execute", ExecuTorchJni::execute), makeNativeMethod("loadMethod", ExecuTorchJni::load_method), - makeNativeMethod("forwardOnes", ExecuTorchJni::forward_ones), }); } }; diff --git a/extension/android/src/main/java/org/pytorch/executorch/Module.java b/extension/android/src/main/java/org/pytorch/executorch/Module.java index 981cfcd8c62..31ba049bfd2 100644 --- a/extension/android/src/main/java/org/pytorch/executorch/Module.java +++ b/extension/android/src/main/java/org/pytorch/executorch/Module.java @@ -76,15 +76,11 @@ public static Module load(final String modelPath) { * Runs the 'forward' method of this module with the specified arguments. * * @param inputs arguments for the ExecuTorch module's 'forward' method. + * Note: if method 'forward' requires inputs but no inputs are given, the + * function will not error out, but run 'forward' with sample inputs. * @return return value from the 'forward' method. */ public EValue[] forward(EValue... inputs) { - if (inputs.length == 0) { - // forward default args (ones) - mNativePeer.forwardOnes(); - // discard the return value - return null; - } return mNativePeer.forward(inputs); } diff --git a/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java b/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java index caa005493a7..a11ec39c27f 100644 --- a/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java +++ b/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java @@ -39,15 +39,6 @@ public void resetNative() { @DoNotStrip public native EValue[] forward(EValue... inputs); - /** - * Run a "forward" call with the sample inputs (ones) to test a module - * - * @return the outputs of the forward call - * @apiNote This is experimental and test-only API - */ - @DoNotStrip - public native int forwardOnes(); - /** Run an arbitrary method on the module */ @DoNotStrip public native EValue[] execute(String methodName, EValue... inputs); From 4259eb153849c0d23d25d39bb9219a166eae00d2 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 5 Sep 2024 14:42:43 -0700 Subject: [PATCH 05/20] Remove managed_tensors --- extension/android/jni/jni_layer.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index abf3c283c83..743004f777f 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -313,7 +313,6 @@ class ExecuTorchJni : public facebook::jni::HybridClass { return jresult; } - std::vector managed_tensors = {}; std::vector evalues; std::vector tensors; From 3516aae33d2a0479258a9f3f8e92e5b02a75d3ff Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 5 Sep 2024 19:22:03 -0700 Subject: [PATCH 06/20] [Android] Remove forwardOnes Instead, if we pass no args to forward(), we infer that the user wants to test against default inputs (ones) so we prepare the inputs for user --- extension/android/jni/jni_layer.cpp | 26 ++++++++++++++----- .../java/org/pytorch/executorch/Module.java | 8 ++---- .../org/pytorch/executorch/NativePeer.java | 9 ------- 3 files changed, 21 insertions(+), 22 deletions(-) diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index ef74d6480bb..48376cf2c03 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -294,6 +294,25 @@ class ExecuTorchJni : public facebook::jni::HybridClass { facebook::jni::alias_ref< facebook::jni::JArrayClass::javaobject> jinputs) { + + // If no inputs is given, it will run with sample inputs (ones) + if (jinputs->size() == 0) { + auto&& underlying_method = module_->methods_[method].method; + auto&& buf = prepare_input_tensors(*underlying_method); + auto result = underlying_method->execute(); + if (result != Error::Ok) { + return {}; + } + facebook::jni::local_ref> jresult = + facebook::jni::JArrayClass::newArray(underlying_method->outputs_size()); + + for (int i = 0; i < underlying_method->outputs_size(); i++) { + auto jevalue = JEValue::newJEValueFromEValue(underlying_method->get_output(i)); + jresult->setElement(i, *jevalue); + } + return jresult; + } + std::vector evalues; std::vector tensors; @@ -352,13 +371,6 @@ class ExecuTorchJni : public facebook::jni::HybridClass { return jresult; } - jint forward_ones() { - auto&& load_result = module_->load_method("forward"); - auto&& buf = prepare_input_tensors(*(module_->methods_["forward"].method)); - auto&& result = module_->methods_["forward"].method->execute(); - return (jint)result; - } - static void registerNatives() { registerHybrid({ makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid), diff --git a/extension/android/src/main/java/org/pytorch/executorch/Module.java b/extension/android/src/main/java/org/pytorch/executorch/Module.java index dc4bf710d9b..f41afd974b2 100644 --- a/extension/android/src/main/java/org/pytorch/executorch/Module.java +++ b/extension/android/src/main/java/org/pytorch/executorch/Module.java @@ -80,15 +80,11 @@ public static Module load(final String modelPath) { * Runs the 'forward' method of this module with the specified arguments. * * @param inputs arguments for the ExecuTorch module's 'forward' method. + * Note: if method 'forward' requires inputs but no inputs are given, the + * function will not error out, but run 'forward' with sample inputs. * @return return value from the 'forward' method. */ public EValue[] forward(EValue... inputs) { - if (inputs.length == 0) { - // forward default args (ones) - mNativePeer.forwardOnes(); - // discard the return value - return null; - } return mNativePeer.forward(inputs); } diff --git a/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java b/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java index 0e6c0a231cb..f63de985069 100644 --- a/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java +++ b/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java @@ -43,15 +43,6 @@ public void resetNative() { @DoNotStrip public native EValue[] forward(EValue... inputs); - /** - * Run a "forward" call with the sample inputs (ones) to test a module - * - * @return the outputs of the forward call - * @apiNote This is experimental and test-only API - */ - @DoNotStrip - public native int forwardOnes(); - /** Run an arbitrary method on the module */ @DoNotStrip public native EValue[] execute(String methodName, EValue... inputs); From 0cc883eee7b8d09fc5867256d45fd26f04491972 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 5 Sep 2024 19:23:30 -0700 Subject: [PATCH 07/20] fix build --- extension/android/jni/jni_layer.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index 48376cf2c03..0facfe1458a 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -377,7 +377,6 @@ class ExecuTorchJni : public facebook::jni::HybridClass { makeNativeMethod("forward", ExecuTorchJni::forward), makeNativeMethod("execute", ExecuTorchJni::execute), makeNativeMethod("loadMethod", ExecuTorchJni::load_method), - makeNativeMethod("forwardOnes", ExecuTorchJni::forward_ones), }); } }; From 7c369a3c6170cd8e662ec69fb7946028926f12a4 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Fri, 6 Sep 2024 15:38:49 -0700 Subject: [PATCH 08/20] copy qnn part --- build/build_android_llm_demo.sh | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/build/build_android_llm_demo.sh b/build/build_android_llm_demo.sh index cd1d21a70b6..2d7df8cb068 100644 --- a/build/build_android_llm_demo.sh +++ b/build/build_android_llm_demo.sh @@ -84,6 +84,19 @@ build_android_native_library() { # Copy artifacts to ABI specific directory mkdir -p "${BUILD_AAR_DIR}/jni/${ANDROID_ABI}" cp "${CMAKE_OUT}"/extension/android/*.so "${BUILD_AAR_DIR}/jni/${ANDROID_ABI}/" + + # Copy QNN related so library + if [ -n "$QNN_SDK_ROOT" ] && [ "$ANDROID_ABI" == "arm64-v8a" ]; then + cp "${CMAKE_OUT}"/lib/libqnn_executorch_backend.so "${BUILD_AAR_DIR}/jni/${ANDROID_ABI}/" + cp "${QNN_SDK_ROOT}"/lib/aarch64-android/libQnnHtp.so "${BUILD_AAR_DIR}/jni/${ANDROID_ABI}/" + cp "${QNN_SDK_ROOT}"/lib/aarch64-android/libQnnSystem.so "${BUILD_AAR_DIR}/jni/${ANDROID_ABI}/" + cp "${QNN_SDK_ROOT}"/lib/aarch64-android/libQnnHtpV69Stub.so "${BUILD_AAR_DIR}/jni/${ANDROID_ABI}/" + cp "${QNN_SDK_ROOT}"/lib/aarch64-android/libQnnHtpV73Stub.so "${BUILD_AAR_DIR}/jni/${ANDROID_ABI}/" + cp "${QNN_SDK_ROOT}"/lib/aarch64-android/libQnnHtpV75Stub.so "${BUILD_AAR_DIR}/jni/${ANDROID_ABI}/" + cp "${QNN_SDK_ROOT}"/lib/hexagon-v69/unsigned/libQnnHtpV69Skel.so "${BUILD_AAR_DIR}/jni/${ANDROID_ABI}/" + cp "${QNN_SDK_ROOT}"/lib/hexagon-v73/unsigned/libQnnHtpV73Skel.so "${BUILD_AAR_DIR}/jni/${ANDROID_ABI}/" + cp "${QNN_SDK_ROOT}"/lib/hexagon-v75/unsigned/libQnnHtpV75Skel.so "${BUILD_AAR_DIR}/jni/${ANDROID_ABI}/" + fi } build_aar() { From bab4d66a21eb4ed931fc193247736594a71940d9 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Fri, 6 Sep 2024 15:40:32 -0700 Subject: [PATCH 09/20] load method first --- extension/android/jni/jni_layer.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index 743004f777f..b88b483b2b7 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -297,6 +297,9 @@ class ExecuTorchJni : public facebook::jni::HybridClass { // If no inputs is given, it will run with sample inputs (ones) if (jinputs->size() == 0) { + if (module_->load_method(method) != Error::Ok) { + return {}; + } auto&& underlying_method = module_->methods_[method].method; auto&& buf = prepare_input_tensors(*underlying_method); auto result = underlying_method->execute(); From 1d6a86e2ac00902fc377e5e90c6db49e1acaf1e4 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Fri, 6 Sep 2024 15:47:03 -0700 Subject: [PATCH 10/20] Need to load method --- extension/android/jni/jni_layer.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index 0facfe1458a..2592e50b90e 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -297,6 +297,9 @@ class ExecuTorchJni : public facebook::jni::HybridClass { // If no inputs is given, it will run with sample inputs (ones) if (jinputs->size() == 0) { + if (module_->load_method(method) != Error::Ok) { + return {}; + } auto&& underlying_method = module_->methods_[method].method; auto&& buf = prepare_input_tensors(*underlying_method); auto result = underlying_method->execute(); From 38ae11cee0c5f09df06e8ea539c61ed535a68da8 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Fri, 6 Sep 2024 15:56:16 -0700 Subject: [PATCH 11/20] linter --- extension/android/jni/jni_layer.cpp | 7 ++++--- .../src/main/java/org/pytorch/executorch/Module.java | 6 +++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index 2592e50b90e..f2cfc4a5cff 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -294,7 +294,6 @@ class ExecuTorchJni : public facebook::jni::HybridClass { facebook::jni::alias_ref< facebook::jni::JArrayClass::javaobject> jinputs) { - // If no inputs is given, it will run with sample inputs (ones) if (jinputs->size() == 0) { if (module_->load_method(method) != Error::Ok) { @@ -307,10 +306,12 @@ class ExecuTorchJni : public facebook::jni::HybridClass { return {}; } facebook::jni::local_ref> jresult = - facebook::jni::JArrayClass::newArray(underlying_method->outputs_size()); + facebook::jni::JArrayClass::newArray( + underlying_method->outputs_size()); for (int i = 0; i < underlying_method->outputs_size(); i++) { - auto jevalue = JEValue::newJEValueFromEValue(underlying_method->get_output(i)); + auto jevalue = + JEValue::newJEValueFromEValue(underlying_method->get_output(i)); jresult->setElement(i, *jevalue); } return jresult; diff --git a/extension/android/src/main/java/org/pytorch/executorch/Module.java b/extension/android/src/main/java/org/pytorch/executorch/Module.java index f41afd974b2..de2ed78b520 100644 --- a/extension/android/src/main/java/org/pytorch/executorch/Module.java +++ b/extension/android/src/main/java/org/pytorch/executorch/Module.java @@ -79,9 +79,9 @@ public static Module load(final String modelPath) { /** * Runs the 'forward' method of this module with the specified arguments. * - * @param inputs arguments for the ExecuTorch module's 'forward' method. - * Note: if method 'forward' requires inputs but no inputs are given, the - * function will not error out, but run 'forward' with sample inputs. + * @param inputs arguments for the ExecuTorch module's 'forward' method. Note: if method 'forward' + * requires inputs but no inputs are given, the function will not error out, but run 'forward' + * with sample inputs. * @return return value from the 'forward' method. */ public EValue[] forward(EValue... inputs) { From e384d07ac618e0edb4b7b1242287d6f4d2358741 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Fri, 6 Sep 2024 17:01:42 -0700 Subject: [PATCH 12/20] add qnn --- build/build_android_llm_demo.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/build/build_android_llm_demo.sh b/build/build_android_llm_demo.sh index 2d7df8cb068..e2d899c0350 100644 --- a/build/build_android_llm_demo.sh +++ b/build/build_android_llm_demo.sh @@ -109,8 +109,8 @@ build_aar() { # between Java and JNI find jni -type f -name "libexecutorch_jni.so" -exec bash -c 'mv "$1" "${1/_jni/}"' bash {} \; # Zip all necessary files into the AAR file - zip -r executorch.aar libs jni/*/libexecutorch.so AndroidManifest.xml - zip -r executorch-llama.aar libs jni/*/libexecutorch_llama_jni.so jni/*/libexecutorch.so AndroidManifest.xml + zip -r executorch.aar libs jni/*/libexecutorch.so jni/*/libqnn*.so jni/*/libQnn*.so AndroidManifest.xml + zip -r executorch-llama.aar libs jni/*/libexecutorch_llama_jni.so jni/*/libqnn*.so jni/*/libQnn*.so AndroidManifest.xml popd } From 2b1d2e753402f752a227a92c1ea98d4d47837787 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Mon, 9 Sep 2024 10:51:30 -0700 Subject: [PATCH 13/20] Update Java Activity part --- .../android/benchmark/app/build.gradle.kts | 1 + .../minibench/LlmBenchmarkActivity.java | 148 ++++++++++-------- .../org/pytorch/minibench/ModelRunner.java | 148 +++++++++--------- .../minibench/ModelRunnerCallback.java | 10 +- 4 files changed, 162 insertions(+), 145 deletions(-) diff --git a/extension/android/benchmark/app/build.gradle.kts b/extension/android/benchmark/app/build.gradle.kts index b716f2e8bd0..dcf99ca9cd0 100644 --- a/extension/android/benchmark/app/build.gradle.kts +++ b/extension/android/benchmark/app/build.gradle.kts @@ -38,6 +38,7 @@ dependencies { implementation(files("libs/executorch.aar")) implementation("com.facebook.soloader:soloader:0.10.5") implementation("com.facebook.fbjni:fbjni:0.5.1") + implementation("com.google.code.gson:gson:2.8.6") testImplementation("junit:junit:4.13.2") androidTestImplementation("androidx.test.ext:junit:1.2.1") androidTestImplementation("androidx.test.espresso:espresso-core:3.6.1") diff --git a/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java index 5cfda5971e5..aba9f6c6799 100644 --- a/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java +++ b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java @@ -8,91 +8,107 @@ package org.pytorch.minibench; - import android.app.Activity; import android.content.Intent; import android.os.Bundle; import android.util.Log; +import com.google.gson.Gson; +import java.io.File; import java.io.FileWriter; import java.io.IOException; +import java.util.Arrays; public class LlmBenchmarkActivity extends Activity implements ModelRunnerCallback { - ModelRunner mModelRunner; - - String mPrompt; - StatsDump mStatsDump; + ModelRunner mModelRunner; - @Override - protected void onCreate(Bundle savedInstanceState) { - super.onCreate(savedInstanceState); + String mPrompt; + StatsDump mStatsDump; - Intent intent = getIntent(); + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); - String modelPath = intent.getStringExtra("model_path"); - String tokenizerPath = intent.getStringExtra("tokenizer_path"); + Intent intent = getIntent(); - float temperature = intent.getFloatExtra("temperature", 0.8f); - mPrompt = intent.getStringExtra("prompt"); - if (mPrompt == null) { - mPrompt = "The ultimate answer"; - } - - mStatsDump = new StatsDump(); - mModelRunner = new ModelRunner(modelPath, tokenizerPath, temperature, this); - mStatsDump.loadStart = System.currentTimeMillis(); - } + File modelDir = new File(intent.getStringExtra("model_dir")); + File model = + Arrays.stream(modelDir.listFiles()) + .filter(file -> file.getName().endsWith(".pte")) + .findFirst() + .get(); + String tokenizerPath = intent.getStringExtra("tokenizer_path"); - @Override - public void onModelLoaded(int status) { - mStatsDump.loadEnd = System.currentTimeMillis(); - if (status != 0) { - Log.e("LlmBenchmarkRunner", "Loaded failed: " + status); - onGenerationStopped(); - return; - } - mStatsDump.generateStart = System.currentTimeMillis(); - int generateStatus = mModelRunner.generate(mPrompt); + float temperature = intent.getFloatExtra("temperature", 0.8f); + mPrompt = intent.getStringExtra("prompt"); + if (mPrompt == null) { + mPrompt = "The ultimate answer"; } - @Override - public void onTokenGenerated(String token) { + mStatsDump = new StatsDump(); + mModelRunner = new ModelRunner(model.getPath(), tokenizerPath, temperature, this); + mStatsDump.loadStart = System.currentTimeMillis(); + } + + @Override + public void onModelLoaded(int status) { + mStatsDump.loadEnd = System.currentTimeMillis(); + if (status != 0) { + Log.e("LlmBenchmarkRunner", "Loaded failed: " + status); + onGenerationStopped(); + return; } - - @Override - public void onStats(String stats) { - mStatsDump.tokens = stats; + mStatsDump.generateStart = System.currentTimeMillis(); + mModelRunner.generate(mPrompt); + } + + @Override + public void onTokenGenerated(String token) {} + + @Override + public void onStats(String stats) { + mStatsDump.tokens = stats; + } + + @Override + public void onGenerationStopped() { + mStatsDump.generateEnd = System.currentTimeMillis(); + + // TODO (huydhn): Remove txt files here once the JSON format is ready + try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.txt")) { + writer.write(mStatsDump.toString()); + } catch (IOException e) { + e.printStackTrace(); } - @Override - public void onGenerationStopped() { - mStatsDump.generateEnd = System.currentTimeMillis(); - - try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.txt")) { - writer.write(mStatsDump.toString()); - } catch (IOException e) { - e.printStackTrace(); - } + // TODO (huydhn): Figure out on what the final JSON results looks like, we need something + // with the same number of fields as https://github.com/pytorch/pytorch/pull/135042 + try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) { + Gson gson = new Gson(); + writer.write(gson.toJson(mStatsDump)); + } catch (IOException e) { + e.printStackTrace(); } + } } class StatsDump { - long loadStart; - long loadEnd; - long generateStart; - long generateEnd; - String tokens; - - @Override - public String toString() { - return "loadStart: " - + loadStart - + "\nloadEnd: " - + loadEnd - + "\ngenerateStart: " - + generateStart - + "\ngenerateEnd: " - + generateEnd - + "\n" - + tokens; - } -} \ No newline at end of file + long loadStart; + long loadEnd; + long generateStart; + long generateEnd; + String tokens; + + @Override + public String toString() { + return "loadStart: " + + loadStart + + "\nloadEnd: " + + loadEnd + + "\ngenerateStart: " + + generateStart + + "\ngenerateEnd: " + + generateEnd + + "\n" + + tokens; + } +} diff --git a/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java index c435dafde65..9e9b9e003d8 100644 --- a/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java +++ b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java @@ -17,81 +17,81 @@ /** A helper class to handle all model running logic within this class. */ public class ModelRunner implements LlamaCallback { - LlamaModule mModule = null; - - String mModelFilePath = ""; - String mTokenizerFilePath = ""; - - ModelRunnerCallback mCallback = null; - - HandlerThread mHandlerThread = null; - Handler mHandler = null; - - /** - * ] Helper class to separate between UI logic and model runner logic. Automatically handle - * generate() request on worker thread. - * - * @param modelFilePath - * @param tokenizerFilePath - * @param callback - */ - ModelRunner( - String modelFilePath, - String tokenizerFilePath, - float temperature, - ModelRunnerCallback callback) { - mModelFilePath = modelFilePath; - mTokenizerFilePath = tokenizerFilePath; - mCallback = callback; - - mModule = new LlamaModule(mModelFilePath, mTokenizerFilePath, 0.8f); - mHandlerThread = new HandlerThread("ModelRunner"); - mHandlerThread.start(); - mHandler = new ModelRunnerHandler(mHandlerThread.getLooper(), this); - - mHandler.sendEmptyMessage(ModelRunnerHandler.MESSAGE_LOAD_MODEL); - } - - int generate(String prompt) { - Message msg = Message.obtain(mHandler, ModelRunnerHandler.MESSAGE_GENERATE, prompt); - msg.sendToTarget(); - return 0; - } - - void stop() { - mModule.stop(); - } - - @Override - public void onResult(String result) { - mCallback.onTokenGenerated(result); - } - - @Override - public void onStats(float tps) { - mCallback.onStats("tokens/second: " + tps); - } + LlamaModule mModule = null; + + String mModelFilePath = ""; + String mTokenizerFilePath = ""; + + ModelRunnerCallback mCallback = null; + + HandlerThread mHandlerThread = null; + Handler mHandler = null; + + /** + * ] Helper class to separate between UI logic and model runner logic. Automatically handle + * generate() request on worker thread. + * + * @param modelFilePath + * @param tokenizerFilePath + * @param callback + */ + ModelRunner( + String modelFilePath, + String tokenizerFilePath, + float temperature, + ModelRunnerCallback callback) { + mModelFilePath = modelFilePath; + mTokenizerFilePath = tokenizerFilePath; + mCallback = callback; + + mModule = new LlamaModule(mModelFilePath, mTokenizerFilePath, 0.8f); + mHandlerThread = new HandlerThread("ModelRunner"); + mHandlerThread.start(); + mHandler = new ModelRunnerHandler(mHandlerThread.getLooper(), this); + + mHandler.sendEmptyMessage(ModelRunnerHandler.MESSAGE_LOAD_MODEL); + } + + int generate(String prompt) { + Message msg = Message.obtain(mHandler, ModelRunnerHandler.MESSAGE_GENERATE, prompt); + msg.sendToTarget(); + return 0; + } + + void stop() { + mModule.stop(); + } + + @Override + public void onResult(String result) { + mCallback.onTokenGenerated(result); + } + + @Override + public void onStats(float tps) { + mCallback.onStats("tokens/second: " + tps); + } } class ModelRunnerHandler extends Handler { - public static int MESSAGE_LOAD_MODEL = 1; - public static int MESSAGE_GENERATE = 2; - - private final ModelRunner mModelRunner; - - public ModelRunnerHandler(Looper looper, ModelRunner modelRunner) { - super(looper); - mModelRunner = modelRunner; + public static int MESSAGE_LOAD_MODEL = 1; + public static int MESSAGE_GENERATE = 2; + + private final ModelRunner mModelRunner; + + public ModelRunnerHandler(Looper looper, ModelRunner modelRunner) { + super(looper); + mModelRunner = modelRunner; + } + + @Override + public void handleMessage(android.os.Message msg) { + if (msg.what == MESSAGE_LOAD_MODEL) { + int status = mModelRunner.mModule.load(); + mModelRunner.mCallback.onModelLoaded(status); + } else if (msg.what == MESSAGE_GENERATE) { + mModelRunner.mModule.generate((String) msg.obj, mModelRunner); + mModelRunner.mCallback.onGenerationStopped(); } - - @Override - public void handleMessage(Message msg) { - if (msg.what == MESSAGE_LOAD_MODEL) { - int status = mModelRunner.mModule.load(); - mModelRunner.mCallback.onModelLoaded(status); - } else if (msg.what == MESSAGE_GENERATE) { - mModelRunner.mModule.generate((String) msg.obj, mModelRunner); - mModelRunner.mCallback.onGenerationStopped(); - } - } -} \ No newline at end of file + } +} diff --git a/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java index 0435be6875c..63701a7bbc6 100644 --- a/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java +++ b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java @@ -14,11 +14,11 @@ */ public interface ModelRunnerCallback { - void onModelLoaded(int status); + void onModelLoaded(int status); - void onTokenGenerated(String token); + void onTokenGenerated(String token); - void onStats(String token); + void onStats(String token); - void onGenerationStopped(); -} \ No newline at end of file + void onGenerationStopped(); +} From 51529502cdc1a75659964e6cb8cd2196b3c64523 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Mon, 9 Sep 2024 14:51:21 -0700 Subject: [PATCH 14/20] update cmake --- build/build_android_llm_demo.sh | 15 +-- extension/android/CMakeLists.txt | 97 ++++++++----------- .../org/pytorch/executorch/LlamaModule.java | 2 +- .../org/pytorch/executorch/NativePeer.java | 2 +- 4 files changed, 43 insertions(+), 73 deletions(-) diff --git a/build/build_android_llm_demo.sh b/build/build_android_llm_demo.sh index e2d899c0350..07de28be833 100644 --- a/build/build_android_llm_demo.sh +++ b/build/build_android_llm_demo.sh @@ -54,20 +54,6 @@ build_android_native_library() { fi cmake --build "${CMAKE_OUT}" -j "${CMAKE_JOBS}" --target install --config Release - cmake examples/models/llama2 \ - -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ - -DANDROID_ABI="$ANDROID_ABI" \ - -DANDROID_PLATFORM=android-23 \ - -DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \ - -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \ - -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \ - -DEXECUTORCH_BUILD_XNNPACK=ON \ - -DCMAKE_BUILD_TYPE=Release \ - -B"${CMAKE_OUT}"/examples/models/llama2 - - cmake --build "${CMAKE_OUT}"/examples/models/llama2 -j "${CMAKE_JOBS}" --config Release - - cmake extension/android \ -DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake \ -DANDROID_ABI="${ANDROID_ABI}" \ @@ -75,6 +61,7 @@ build_android_native_library() { -DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \ -DEXECUTORCH_ENABLE_LOGGING=ON \ -DEXECUTORCH_LOG_LEVEL=Info \ + -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \ -DEXECUTORCH_BUILD_LLAMA_JNI=ON \ -DCMAKE_BUILD_TYPE=Release \ -B"${CMAKE_OUT}"/extension/android diff --git a/extension/android/CMakeLists.txt b/extension/android/CMakeLists.txt index f99bae1bc2b..5763306d34f 100644 --- a/extension/android/CMakeLists.txt +++ b/extension/android/CMakeLists.txt @@ -10,7 +10,6 @@ project(executorch_jni) if(NOT CMAKE_CXX_STANDARD) set(CMAKE_CXX_STANDARD 17) - # Can't set to 11 due to executor_runner.cpp make_unique endif() if(NOT ANDROID) @@ -71,70 +70,54 @@ if(TARGET vulkan_backend) list(APPEND link_libraries vulkan_backend) endif() - - set(LLAMA_RUNNER_PATH - ${CMAKE_CURRENT_BINARY_DIR}/../../examples/models/llama2/runner/libllama_runner.a - ) - add_library(llama_runner STATIC IMPORTED) - set_property( - TARGET llama_runner PROPERTY IMPORTED_LOCATION ${LLAMA_RUNNER_PATH} +if(EXECUTORCH_BUILD_KERNELS_CUSTOM) + list(APPEND link_libraries custom_ops) + add_subdirectory( + ${EXECUTORCH_ROOT}/extension/llm/custom_ops + ${CMAKE_CURRENT_BINARY_DIR}/../../extension/llm/custom_ops ) +endif() + +set(JNI_SRCS jni/jni_layer.cpp) +if(EXECUTORCH_BUILD_LLAMA_JNI) + list(APPEND JNI_SRCS jni/jni_layer_llama.cpp) + list(APPEND link_libraries llama_runner llava_runner) add_subdirectory( ${EXECUTORCH_ROOT}/examples/models/llava/runner ${CMAKE_CURRENT_BINARY_DIR}/../../examples/models/llava/runner ) - set(CUSTOM_OPS_PATH - ${CMAKE_CURRENT_BINARY_DIR}/../../extension/llm/custom_ops/libcustom_ops.a + add_subdirectory( + ${EXECUTORCH_ROOT}/examples/models/llama2/runner + ${CMAKE_CURRENT_BINARY_DIR}/../../examples/models/llama2/runner ) - add_library(custom_ops STATIC IMPORTED) - set_property(TARGET custom_ops PROPERTY IMPORTED_LOCATION ${CUSTOM_OPS_PATH}) - target_link_options_shared_lib(custom_ops) - - target_link_options_shared_lib(quantized_ops_lib) - - set(LLAMA_JNI_SRCS jni/jni_layer_llama.cpp jni/jni_layer.cpp) - add_library(executorch_llama_jni SHARED ${LLAMA_JNI_SRCS}) - if(TARGET pthreadpool) - target_compile_definitions(executorch_llama_jni PRIVATE ET_USE_THREADPOOL=1) - target_include_directories( - executorch_llama_jni - PUBLIC - ${CMAKE_CURRENT_SOURCE_DIR}/../../backends/xnnpack/third-party/cpuinfo/include - ) - target_include_directories( - executorch_llama_jni - PUBLIC - ${CMAKE_CURRENT_SOURCE_DIR}/../../backends/xnnpack/third-party/pthreadpool/include - ) - endif() +endif() + +if(TARGET quantized_kernels) + list(APPEND link_libraries quantized_kernels quantized_ops_lib) +endif() + +add_library(executorch_jni SHARED ${JNI_SRCS}) + +target_include_directories( + executorch_jni PRIVATE ${_common_include_directories} +) + +target_compile_options(executorch_jni PUBLIC ${_common_compile_options}) + +target_link_libraries(executorch_jni ${link_libraries}) + +if(TARGET pthreadpool) + target_compile_definitions(executorch_jni PRIVATE ET_USE_THREADPOOL=1) target_include_directories( - executorch_llama_jni PRIVATE ${_common_include_directories} - ) - target_link_libraries( - executorch_llama_jni - ${link_libraries} - llama_runner - llava_runner - custom_ops - cpublas - eigen_blas - quantized_kernels - quantized_ops_lib + executorch_jni + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/../../backends/xnnpack/third-party/cpuinfo/include ) - target_compile_options(executorch_llama_jni PUBLIC ${_common_compile_options}) - # link re2 - set(ABSL_ENABLE_INSTALL ON) - set(_pic_flag ${CMAKE_POSITION_INDEPENDENT_CODE}) - set(CMAKE_POSITION_INDEPENDENT_CODE ON) - add_subdirectory( - ${CMAKE_CURRENT_SOURCE_DIR}/../../extension/llm/third-party/abseil-cpp - ${CMAKE_CURRENT_BINARY_DIR}/abseil-cpp - ) - add_subdirectory( - ${CMAKE_CURRENT_SOURCE_DIR}/../../extension/llm/third-party/re2 - ${CMAKE_CURRENT_BINARY_DIR}/re2 + target_include_directories( + executorch_jni + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/../../backends/xnnpack/third-party/pthreadpool/include ) - set(CMAKE_POSITION_INDEPENDENT_CODE ${_pic_flag}) - target_link_libraries(executorch_llama_jni re2::re2) +endif() diff --git a/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java b/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java index c4de23df0ee..e3438afec39 100644 --- a/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java +++ b/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java @@ -28,7 +28,7 @@ public class LlamaModule { if (!NativeLoader.isInitialized()) { NativeLoader.init(new SystemDelegate()); } - NativeLoader.loadLibrary("executorch_llama_jni"); + NativeLoader.loadLibrary("executorch"); } private final HybridData mHybridData; diff --git a/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java b/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java index d188dc18520..f63de985069 100644 --- a/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java +++ b/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java @@ -21,7 +21,7 @@ class NativePeer { static { // Loads libexecutorch.so from jniLibs - NativeLoader.loadLibrary("executorch_llama_jni"); + NativeLoader.loadLibrary("executorch"); } private final HybridData mHybridData; From 19cec4e8d3d71a41df5c5b8430086601f5dd72da Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Mon, 9 Sep 2024 15:13:08 -0700 Subject: [PATCH 15/20] add a way to build non llm --- build/build_android_llm_demo.sh | 4 ++-- extension/android/CMakeLists.txt | 7 +++---- extension/android/jni/jni_layer.cpp | 10 +++++++--- extension/android/jni/jni_layer_llama.cpp | 3 +-- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/build/build_android_llm_demo.sh b/build/build_android_llm_demo.sh index 07de28be833..917512d71b6 100644 --- a/build/build_android_llm_demo.sh +++ b/build/build_android_llm_demo.sh @@ -97,14 +97,13 @@ build_aar() { find jni -type f -name "libexecutorch_jni.so" -exec bash -c 'mv "$1" "${1/_jni/}"' bash {} \; # Zip all necessary files into the AAR file zip -r executorch.aar libs jni/*/libexecutorch.so jni/*/libqnn*.so jni/*/libQnn*.so AndroidManifest.xml - zip -r executorch-llama.aar libs jni/*/libexecutorch_llama_jni.so jni/*/libqnn*.so jni/*/libQnn*.so AndroidManifest.xml + zip -r executorch-llama.aar libs jni/*/libexecutorch.so jni/*/libqnn*.so jni/*/libQnn*.so AndroidManifest.xml popd } build_android_demo_apps() { mkdir -p examples/demo-apps/android/LlamaDemo/app/libs cp ${BUILD_AAR_DIR}/executorch-llama.aar examples/demo-apps/android/LlamaDemo/app/libs - cp ${BUILD_AAR_DIR}/executorch-llama.aar extension/android/benchmark/app/libs/executorch.aar pushd examples/demo-apps/android/LlamaDemo ANDROID_HOME="${ANDROID_SDK:-/opt/android/sdk}" ./gradlew build assembleAndroidTest popd @@ -140,6 +139,7 @@ collect_artifacts_to_be_uploaded() { } BUILD_AAR_DIR="$(mktemp -d)" +export BUILD_AAR_DIR if [ -z "$ANDROID_ABIS" ]; then ANDROID_ABIS=("arm64-v8a" "x86_64") fi diff --git a/extension/android/CMakeLists.txt b/extension/android/CMakeLists.txt index 5763306d34f..1c3ad434b17 100644 --- a/extension/android/CMakeLists.txt +++ b/extension/android/CMakeLists.txt @@ -78,11 +78,12 @@ if(EXECUTORCH_BUILD_KERNELS_CUSTOM) ) endif() -set(JNI_SRCS jni/jni_layer.cpp) +add_library(executorch_jni SHARED jni/jni_layer.cpp) if(EXECUTORCH_BUILD_LLAMA_JNI) - list(APPEND JNI_SRCS jni/jni_layer_llama.cpp) + target_sources(executorch_jni PRIVATE jni/jni_layer_llama.cpp) list(APPEND link_libraries llama_runner llava_runner) + target_compile_definitions(executorch_jni PUBLIC EXECUTORCH_BUILD_LLAMA_JNI=1) add_subdirectory( ${EXECUTORCH_ROOT}/examples/models/llava/runner ${CMAKE_CURRENT_BINARY_DIR}/../../examples/models/llava/runner @@ -98,8 +99,6 @@ if(TARGET quantized_kernels) list(APPEND link_libraries quantized_kernels quantized_ops_lib) endif() -add_library(executorch_jni SHARED ${JNI_SRCS}) - target_include_directories( executorch_jni PRIVATE ${_common_include_directories} ) diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index 7be2ecc8d30..a6204a1965c 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -386,11 +386,15 @@ class ExecuTorchJni : public facebook::jni::HybridClass { }; } // namespace executorch::extension -extern void register_natives_jni(); - +#ifdef EXECUTORCH_BUILD_LLAMA_JNI +extern void register_natives_for_llama(); +#else +// No op if we don't build llama +void register_natives_for_llama() {} +#endif JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) { return facebook::jni::initialize( vm, [] { executorch::extension::ExecuTorchJni::registerNatives(); - register_natives_jni(); + register_natives_for_llama(); }); } diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index 68fe3c3025a..a5acf3d417f 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -265,7 +265,6 @@ class ExecuTorchLlamaJni } // namespace executorch_jni -void register_natives_jni() { +void register_natives_for_llama() { executorch_jni::ExecuTorchLlamaJni::registerNatives(); } - From 0265e6b038eca49e69db0a4d6db2645bcce0de88 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Mon, 9 Sep 2024 15:23:00 -0700 Subject: [PATCH 16/20] Remove reference libexecutorch_llama_jni.so --- examples/demo-apps/android/LlamaDemo/setup-with-qnn.sh | 4 +++- examples/demo-apps/android/LlamaDemo/setup.sh | 2 +- extension/android/jni/BUCK | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/demo-apps/android/LlamaDemo/setup-with-qnn.sh b/examples/demo-apps/android/LlamaDemo/setup-with-qnn.sh index 87d0f47c956..4deafb83487 100644 --- a/examples/demo-apps/android/LlamaDemo/setup-with-qnn.sh +++ b/examples/demo-apps/android/LlamaDemo/setup-with-qnn.sh @@ -37,6 +37,7 @@ cmake examples/models/llama2 \ -DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \ -DEXECUTORCH_USE_TIKTOKEN="${EXECUTORCH_USE_TIKTOKEN}" \ -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \ + -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ -DCMAKE_BUILD_TYPE=Release \ -B"${CMAKE_OUT}"/examples/models/llama2 @@ -47,6 +48,7 @@ cmake extension/android \ -DANDROID_ABI="${ANDROID_ABI}" \ -DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \ -DEXECUTORCH_BUILD_LLAMA_JNI=ON \ + -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ -DEXECUTORCH_USE_TIKTOKEN="${EXECUTORCH_USE_TIKTOKEN}" \ -DCMAKE_BUILD_TYPE=Release \ -B"${CMAKE_OUT}"/extension/android @@ -59,7 +61,7 @@ mkdir -p "${JNI_LIBS_PATH}/${ANDROID_ABI}" BUILD_AAR_DIR="$(mktemp -d)" mkdir -p "${BUILD_AAR_DIR}/jni/${ANDROID_ABI}" "${BUILD_AAR_DIR}/libs" JNI_LIBS_PATH="${BUILD_AAR_DIR}/jni" -cp "${CMAKE_OUT}"/extension/android/libexecutorch_llama_jni.so "${JNI_LIBS_PATH}/${ANDROID_ABI}/" +cp "${CMAKE_OUT}"/extension/android/libexecutorch_jni.so "${JNI_LIBS_PATH}/${ANDROID_ABI}/libexecutorch_jni.so" cp "${CMAKE_OUT}"/lib/libqnn_executorch_backend.so "${JNI_LIBS_PATH}/${ANDROID_ABI}/" cp "${QNN_SDK_ROOT}"/lib/aarch64-android/libQnnHtp.so "${JNI_LIBS_PATH}/${ANDROID_ABI}/" cp "${QNN_SDK_ROOT}"/lib/aarch64-android/libQnnSystem.so "${JNI_LIBS_PATH}/${ANDROID_ABI}/" diff --git a/examples/demo-apps/android/LlamaDemo/setup.sh b/examples/demo-apps/android/LlamaDemo/setup.sh index 91a68d4b88b..78816680bc7 100644 --- a/examples/demo-apps/android/LlamaDemo/setup.sh +++ b/examples/demo-apps/android/LlamaDemo/setup.sh @@ -56,7 +56,7 @@ cmake --build "${CMAKE_OUT}"/extension/android -j "${CMAKE_JOBS}" --config Relea BUILD_AAR_DIR="$(mktemp -d)" mkdir -p "${BUILD_AAR_DIR}/jni/${ANDROID_ABI}" "${BUILD_AAR_DIR}/libs" -cp "${CMAKE_OUT}"/extension/android/libexecutorch_llama_jni.so "${BUILD_AAR_DIR}/jni/${ANDROID_ABI}" +cp "${CMAKE_OUT}"/extension/android/libexecutorch_jni.so "${BUILD_AAR_DIR}/jni/${ANDROID_ABI}/libexecutorch.so" cp extension/android/build/libs/executorch.jar "${BUILD_AAR_DIR}/libs" echo \ \ diff --git a/extension/android/jni/BUCK b/extension/android/jni/BUCK index 7cdf8ef7ec4..f7e7932a21b 100644 --- a/extension/android/jni/BUCK +++ b/extension/android/jni/BUCK @@ -77,7 +77,7 @@ fb_android_cxx_library( "-fexceptions", "-Wno-format", ], - soname = "libexecutorch_llama_jni.$(ext)", + soname = "libexecutorch.$(ext)", visibility = ["PUBLIC"], deps = [ "//fbandroid/libraries/fbjni:fbjni", From 716380744184fd5e2aedde353de5670b91e786a0 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Mon, 9 Sep 2024 15:26:28 -0700 Subject: [PATCH 17/20] Remove app related stuff for now --- .../android/benchmark/app/build.gradle.kts | 1 - .../app/src/main/AndroidManifest.xml | 8 -- .../minibench/LlmBenchmarkActivity.java | 114 ------------------ .../org/pytorch/minibench/ModelRunner.java | 97 --------------- .../minibench/ModelRunnerCallback.java | 24 ---- 5 files changed, 244 deletions(-) delete mode 100644 extension/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java delete mode 100644 extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java delete mode 100644 extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java diff --git a/extension/android/benchmark/app/build.gradle.kts b/extension/android/benchmark/app/build.gradle.kts index dcf99ca9cd0..b716f2e8bd0 100644 --- a/extension/android/benchmark/app/build.gradle.kts +++ b/extension/android/benchmark/app/build.gradle.kts @@ -38,7 +38,6 @@ dependencies { implementation(files("libs/executorch.aar")) implementation("com.facebook.soloader:soloader:0.10.5") implementation("com.facebook.fbjni:fbjni:0.5.1") - implementation("com.google.code.gson:gson:2.8.6") testImplementation("junit:junit:4.13.2") androidTestImplementation("androidx.test.ext:junit:1.2.1") androidTestImplementation("androidx.test.espresso:espresso-core:3.6.1") diff --git a/extension/android/benchmark/app/src/main/AndroidManifest.xml b/extension/android/benchmark/app/src/main/AndroidManifest.xml index 098905c052c..49711b6830e 100644 --- a/extension/android/benchmark/app/src/main/AndroidManifest.xml +++ b/extension/android/benchmark/app/src/main/AndroidManifest.xml @@ -16,14 +16,6 @@ - - - - - - diff --git a/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java deleted file mode 100644 index aba9f6c6799..00000000000 --- a/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.minibench; - -import android.app.Activity; -import android.content.Intent; -import android.os.Bundle; -import android.util.Log; -import com.google.gson.Gson; -import java.io.File; -import java.io.FileWriter; -import java.io.IOException; -import java.util.Arrays; - -public class LlmBenchmarkActivity extends Activity implements ModelRunnerCallback { - ModelRunner mModelRunner; - - String mPrompt; - StatsDump mStatsDump; - - @Override - protected void onCreate(Bundle savedInstanceState) { - super.onCreate(savedInstanceState); - - Intent intent = getIntent(); - - File modelDir = new File(intent.getStringExtra("model_dir")); - File model = - Arrays.stream(modelDir.listFiles()) - .filter(file -> file.getName().endsWith(".pte")) - .findFirst() - .get(); - String tokenizerPath = intent.getStringExtra("tokenizer_path"); - - float temperature = intent.getFloatExtra("temperature", 0.8f); - mPrompt = intent.getStringExtra("prompt"); - if (mPrompt == null) { - mPrompt = "The ultimate answer"; - } - - mStatsDump = new StatsDump(); - mModelRunner = new ModelRunner(model.getPath(), tokenizerPath, temperature, this); - mStatsDump.loadStart = System.currentTimeMillis(); - } - - @Override - public void onModelLoaded(int status) { - mStatsDump.loadEnd = System.currentTimeMillis(); - if (status != 0) { - Log.e("LlmBenchmarkRunner", "Loaded failed: " + status); - onGenerationStopped(); - return; - } - mStatsDump.generateStart = System.currentTimeMillis(); - mModelRunner.generate(mPrompt); - } - - @Override - public void onTokenGenerated(String token) {} - - @Override - public void onStats(String stats) { - mStatsDump.tokens = stats; - } - - @Override - public void onGenerationStopped() { - mStatsDump.generateEnd = System.currentTimeMillis(); - - // TODO (huydhn): Remove txt files here once the JSON format is ready - try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.txt")) { - writer.write(mStatsDump.toString()); - } catch (IOException e) { - e.printStackTrace(); - } - - // TODO (huydhn): Figure out on what the final JSON results looks like, we need something - // with the same number of fields as https://github.com/pytorch/pytorch/pull/135042 - try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) { - Gson gson = new Gson(); - writer.write(gson.toJson(mStatsDump)); - } catch (IOException e) { - e.printStackTrace(); - } - } -} - -class StatsDump { - long loadStart; - long loadEnd; - long generateStart; - long generateEnd; - String tokens; - - @Override - public String toString() { - return "loadStart: " - + loadStart - + "\nloadEnd: " - + loadEnd - + "\ngenerateStart: " - + generateStart - + "\ngenerateEnd: " - + generateEnd - + "\n" - + tokens; - } -} diff --git a/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java deleted file mode 100644 index 9e9b9e003d8..00000000000 --- a/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.minibench; - -import android.os.Handler; -import android.os.HandlerThread; -import android.os.Looper; -import android.os.Message; -import org.pytorch.executorch.LlamaCallback; -import org.pytorch.executorch.LlamaModule; - -/** A helper class to handle all model running logic within this class. */ -public class ModelRunner implements LlamaCallback { - LlamaModule mModule = null; - - String mModelFilePath = ""; - String mTokenizerFilePath = ""; - - ModelRunnerCallback mCallback = null; - - HandlerThread mHandlerThread = null; - Handler mHandler = null; - - /** - * ] Helper class to separate between UI logic and model runner logic. Automatically handle - * generate() request on worker thread. - * - * @param modelFilePath - * @param tokenizerFilePath - * @param callback - */ - ModelRunner( - String modelFilePath, - String tokenizerFilePath, - float temperature, - ModelRunnerCallback callback) { - mModelFilePath = modelFilePath; - mTokenizerFilePath = tokenizerFilePath; - mCallback = callback; - - mModule = new LlamaModule(mModelFilePath, mTokenizerFilePath, 0.8f); - mHandlerThread = new HandlerThread("ModelRunner"); - mHandlerThread.start(); - mHandler = new ModelRunnerHandler(mHandlerThread.getLooper(), this); - - mHandler.sendEmptyMessage(ModelRunnerHandler.MESSAGE_LOAD_MODEL); - } - - int generate(String prompt) { - Message msg = Message.obtain(mHandler, ModelRunnerHandler.MESSAGE_GENERATE, prompt); - msg.sendToTarget(); - return 0; - } - - void stop() { - mModule.stop(); - } - - @Override - public void onResult(String result) { - mCallback.onTokenGenerated(result); - } - - @Override - public void onStats(float tps) { - mCallback.onStats("tokens/second: " + tps); - } -} - -class ModelRunnerHandler extends Handler { - public static int MESSAGE_LOAD_MODEL = 1; - public static int MESSAGE_GENERATE = 2; - - private final ModelRunner mModelRunner; - - public ModelRunnerHandler(Looper looper, ModelRunner modelRunner) { - super(looper); - mModelRunner = modelRunner; - } - - @Override - public void handleMessage(android.os.Message msg) { - if (msg.what == MESSAGE_LOAD_MODEL) { - int status = mModelRunner.mModule.load(); - mModelRunner.mCallback.onModelLoaded(status); - } else if (msg.what == MESSAGE_GENERATE) { - mModelRunner.mModule.generate((String) msg.obj, mModelRunner); - mModelRunner.mCallback.onGenerationStopped(); - } - } -} diff --git a/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java deleted file mode 100644 index 63701a7bbc6..00000000000 --- a/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.minibench; - -/** - * A helper interface within the app for MainActivity and Benchmarking to handle callback from - * ModelRunner. - */ -public interface ModelRunnerCallback { - - void onModelLoaded(int status); - - void onTokenGenerated(String token); - - void onStats(String token); - - void onGenerationStopped(); -} From 232e746e57922692d8ac6f48785250f492ddd6a3 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Mon, 9 Sep 2024 15:50:56 -0700 Subject: [PATCH 18/20] linter --- extension/android/jni/jni_layer.cpp | 8 ++++---- extension/android/jni/jni_layer_llama.cpp | 3 +-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index a6204a1965c..1ef81b20b08 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -393,8 +393,8 @@ extern void register_natives_for_llama(); void register_natives_for_llama() {} #endif JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) { - return facebook::jni::initialize( - vm, [] { executorch::extension::ExecuTorchJni::registerNatives(); - register_natives_for_llama(); - }); + return facebook::jni::initialize(vm, [] { + executorch::extension::ExecuTorchJni::registerNatives(); + register_natives_for_llama(); + }); } diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index a5acf3d417f..007412e1e7d 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -30,7 +30,6 @@ #include #include - using namespace torch::executor; namespace executorch_jni { @@ -266,5 +265,5 @@ class ExecuTorchLlamaJni } // namespace executorch_jni void register_natives_for_llama() { - executorch_jni::ExecuTorchLlamaJni::registerNatives(); + executorch_jni::ExecuTorchLlamaJni::registerNatives(); } From 50920d5698f53746a1e3411e18486783197f8a61 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Mon, 9 Sep 2024 15:50:56 -0700 Subject: [PATCH 19/20] linter --- extension/android/jni/jni_layer.cpp | 8 ++++---- extension/android/jni/jni_layer_llama.cpp | 3 +-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index a6204a1965c..1ef81b20b08 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -393,8 +393,8 @@ extern void register_natives_for_llama(); void register_natives_for_llama() {} #endif JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) { - return facebook::jni::initialize( - vm, [] { executorch::extension::ExecuTorchJni::registerNatives(); - register_natives_for_llama(); - }); + return facebook::jni::initialize(vm, [] { + executorch::extension::ExecuTorchJni::registerNatives(); + register_natives_for_llama(); + }); } diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index a5acf3d417f..007412e1e7d 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -30,7 +30,6 @@ #include #include - using namespace torch::executor; namespace executorch_jni { @@ -266,5 +265,5 @@ class ExecuTorchLlamaJni } // namespace executorch_jni void register_natives_for_llama() { - executorch_jni::ExecuTorchLlamaJni::registerNatives(); + executorch_jni::ExecuTorchLlamaJni::registerNatives(); } From 957d1f1089503e98f85b8419a1858ac948cdd845 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Mon, 9 Sep 2024 23:24:47 -0700 Subject: [PATCH 20/20] Link custom_ops --- extension/android/CMakeLists.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/extension/android/CMakeLists.txt b/extension/android/CMakeLists.txt index 1c3ad434b17..c9396a55879 100644 --- a/extension/android/CMakeLists.txt +++ b/extension/android/CMakeLists.txt @@ -71,11 +71,12 @@ if(TARGET vulkan_backend) endif() if(EXECUTORCH_BUILD_KERNELS_CUSTOM) - list(APPEND link_libraries custom_ops) add_subdirectory( ${EXECUTORCH_ROOT}/extension/llm/custom_ops ${CMAKE_CURRENT_BINARY_DIR}/../../extension/llm/custom_ops ) + list(APPEND link_libraries custom_ops) + target_link_options_shared_lib(custom_ops) endif() add_library(executorch_jni SHARED jni/jni_layer.cpp)