diff --git a/extension/android/benchmark/app/build.gradle.kts b/extension/android/benchmark/app/build.gradle.kts index b716f2e8bd0..dcf99ca9cd0 100644 --- a/extension/android/benchmark/app/build.gradle.kts +++ b/extension/android/benchmark/app/build.gradle.kts @@ -38,6 +38,7 @@ dependencies { implementation(files("libs/executorch.aar")) implementation("com.facebook.soloader:soloader:0.10.5") implementation("com.facebook.fbjni:fbjni:0.5.1") + implementation("com.google.code.gson:gson:2.8.6") testImplementation("junit:junit:4.13.2") androidTestImplementation("androidx.test.ext:junit:1.2.1") androidTestImplementation("androidx.test.espresso:espresso-core:3.6.1") diff --git a/extension/android/benchmark/app/src/main/AndroidManifest.xml b/extension/android/benchmark/app/src/main/AndroidManifest.xml index 49711b6830e..098905c052c 100644 --- a/extension/android/benchmark/app/src/main/AndroidManifest.xml +++ b/extension/android/benchmark/app/src/main/AndroidManifest.xml @@ -16,6 +16,14 @@ + + + + + + diff --git a/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java new file mode 100644 index 00000000000..496cbde53d6 --- /dev/null +++ b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java @@ -0,0 +1,114 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.minibench; + +import android.app.Activity; +import android.content.Intent; +import android.os.Bundle; +import android.util.Log; +import com.google.gson.Gson; +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.util.Arrays; + +public class LlmBenchmarkActivity extends Activity implements ModelRunnerCallback { + ModelRunner mModelRunner; + + String mPrompt; + StatsInfo mStatsInfo; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + + Intent intent = getIntent(); + + File modelDir = new File(intent.getStringExtra("model_dir")); + File model = + Arrays.stream(modelDir.listFiles()) + .filter(file -> file.getName().endsWith(".pte")) + .findFirst() + .get(); + String tokenizerPath = intent.getStringExtra("tokenizer_path"); + + float temperature = intent.getFloatExtra("temperature", 0.8f); + mPrompt = intent.getStringExtra("prompt"); + if (mPrompt == null) { + mPrompt = "The ultimate answer"; + } + + mStatsInfo = new StatsInfo(); + mModelRunner = new ModelRunner(model.getPath(), tokenizerPath, temperature, this); + mStatsInfo.loadStart = System.currentTimeMillis(); + } + + @Override + public void onModelLoaded(int status) { + mStatsInfo.loadEnd = System.currentTimeMillis(); + if (status != 0) { + Log.e("LlmBenchmarkRunner", "Loaded failed: " + status); + onGenerationStopped(); + return; + } + mStatsInfo.generateStart = System.currentTimeMillis(); + mModelRunner.generate(mPrompt); + } + + @Override + public void onTokenGenerated(String token) {} + + @Override + public void onStats(String stats) { + mStatsInfo.tokens = stats; + } + + @Override + public void onGenerationStopped() { + mStatsInfo.generateEnd = System.currentTimeMillis(); + + // TODO (huydhn): Remove txt files here once the JSON format is ready + try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.txt")) { + writer.write(mStatsInfo.toString()); + } catch (IOException e) { + e.printStackTrace(); + } + + // TODO (huydhn): Figure out on what the final JSON results looks like, we need something + // with the same number of fields as https://github.com/pytorch/pytorch/pull/135042 + try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) { + Gson gson = new Gson(); + writer.write(gson.toJson(mStatsInfo)); + } catch (IOException e) { + e.printStackTrace(); + } + } +} + +class StatsInfo { + long loadStart; + long loadEnd; + long generateStart; + long generateEnd; + String tokens; + + @Override + public String toString() { + return "loadStart: " + + loadStart + + "\nloadEnd: " + + loadEnd + + "\ngenerateStart: " + + generateStart + + "\ngenerateEnd: " + + generateEnd + + "\n" + + tokens; + } +} diff --git a/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java new file mode 100644 index 00000000000..9e9b9e003d8 --- /dev/null +++ b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java @@ -0,0 +1,97 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.minibench; + +import android.os.Handler; +import android.os.HandlerThread; +import android.os.Looper; +import android.os.Message; +import org.pytorch.executorch.LlamaCallback; +import org.pytorch.executorch.LlamaModule; + +/** A helper class to handle all model running logic within this class. */ +public class ModelRunner implements LlamaCallback { + LlamaModule mModule = null; + + String mModelFilePath = ""; + String mTokenizerFilePath = ""; + + ModelRunnerCallback mCallback = null; + + HandlerThread mHandlerThread = null; + Handler mHandler = null; + + /** + * ] Helper class to separate between UI logic and model runner logic. Automatically handle + * generate() request on worker thread. + * + * @param modelFilePath + * @param tokenizerFilePath + * @param callback + */ + ModelRunner( + String modelFilePath, + String tokenizerFilePath, + float temperature, + ModelRunnerCallback callback) { + mModelFilePath = modelFilePath; + mTokenizerFilePath = tokenizerFilePath; + mCallback = callback; + + mModule = new LlamaModule(mModelFilePath, mTokenizerFilePath, 0.8f); + mHandlerThread = new HandlerThread("ModelRunner"); + mHandlerThread.start(); + mHandler = new ModelRunnerHandler(mHandlerThread.getLooper(), this); + + mHandler.sendEmptyMessage(ModelRunnerHandler.MESSAGE_LOAD_MODEL); + } + + int generate(String prompt) { + Message msg = Message.obtain(mHandler, ModelRunnerHandler.MESSAGE_GENERATE, prompt); + msg.sendToTarget(); + return 0; + } + + void stop() { + mModule.stop(); + } + + @Override + public void onResult(String result) { + mCallback.onTokenGenerated(result); + } + + @Override + public void onStats(float tps) { + mCallback.onStats("tokens/second: " + tps); + } +} + +class ModelRunnerHandler extends Handler { + public static int MESSAGE_LOAD_MODEL = 1; + public static int MESSAGE_GENERATE = 2; + + private final ModelRunner mModelRunner; + + public ModelRunnerHandler(Looper looper, ModelRunner modelRunner) { + super(looper); + mModelRunner = modelRunner; + } + + @Override + public void handleMessage(android.os.Message msg) { + if (msg.what == MESSAGE_LOAD_MODEL) { + int status = mModelRunner.mModule.load(); + mModelRunner.mCallback.onModelLoaded(status); + } else if (msg.what == MESSAGE_GENERATE) { + mModelRunner.mModule.generate((String) msg.obj, mModelRunner); + mModelRunner.mCallback.onGenerationStopped(); + } + } +} diff --git a/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java new file mode 100644 index 00000000000..63701a7bbc6 --- /dev/null +++ b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java @@ -0,0 +1,24 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.minibench; + +/** + * A helper interface within the app for MainActivity and Benchmarking to handle callback from + * ModelRunner. + */ +public interface ModelRunnerCallback { + + void onModelLoaded(int status); + + void onTokenGenerated(String token); + + void onStats(String token); + + void onGenerationStopped(); +}