diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java deleted file mode 100644 index 5e1dd48926b..00000000000 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java +++ /dev/null @@ -1,136 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.minibench; - -import android.app.Activity; -import android.content.Intent; -import android.os.Bundle; -import android.os.Handler; -import android.os.HandlerThread; -import android.os.Looper; -import android.system.ErrnoException; -import android.system.Os; -import com.google.gson.Gson; -import java.io.File; -import java.io.FileWriter; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -public class BenchmarkActivity extends Activity { - - File mModel; - int mNumIter; - int mNumWarmupIter; - String mTokenizerPath; - float mTemperature; - String mPrompt; - - HandlerThread mHandlerThread; - BenchmarkHandler mHandler; - - List mResult; - - @Override - protected void onCreate(Bundle savedInstanceState) { - super.onCreate(savedInstanceState); - - try { - Os.setenv("ADSP_LIBRARY_PATH", getApplicationInfo().nativeLibraryDir, true); - } catch (ErrnoException e) { - finish(); - } - - Intent intent = getIntent(); - File modelDir = new File(intent.getStringExtra("model_dir")); - File model = - Arrays.stream(modelDir.listFiles()) - .filter(file -> file.getName().endsWith(".pte")) - .findFirst() - .get(); - - int numIter = intent.getIntExtra("num_iter", 50); - int numWarmupIter = intent.getIntExtra("num_warm_up_iter", 10); - String tokenizerPath = intent.getStringExtra("tokenizer_path"); - float temperature = intent.getFloatExtra("temperature", 0.8f); - String prompt = intent.getStringExtra("prompt"); - - mModel = model; - mNumIter = numIter; - mNumWarmupIter = numWarmupIter; - mTokenizerPath = tokenizerPath; - mTemperature = temperature; - mPrompt = prompt; - if (mPrompt == null) { - mPrompt = "The ultimate answer"; - } - mResult = new ArrayList<>(); - - mHandlerThread = new HandlerThread("ModelRunner"); - mHandlerThread.start(); - mHandler = new BenchmarkHandler(mHandlerThread.getLooper(), this); - - mHandler.sendEmptyMessage(BenchmarkHandler.MESSAGE_RUN_BENCHMARK); - } - - void writeResult() { - try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) { - Gson gson = new Gson(); - writer.write(gson.toJson(mResult)); - } catch (IOException e) { - e.printStackTrace(); - } finally { - finish(); - } - } -} - -class BenchmarkHandler extends Handler { - public static int MESSAGE_RUN_BENCHMARK = 1; - public static int MESSAGE_LLM_RUN_BENCHMARK = 2; - - ModelRunner mModelRunner; - BenchmarkActivity mBenchmarkActivity; - - LlmModelRunner mLlmModelRunner; - LlmBenchmark mLlmBenchmark; - - public BenchmarkHandler(Looper looper, BenchmarkActivity benchmarkActivity) { - super(looper); - mModelRunner = new ModelRunner(); - mBenchmarkActivity = benchmarkActivity; - } - - @Override - public void handleMessage(android.os.Message msg) { - if (msg.what == MESSAGE_RUN_BENCHMARK) { - mModelRunner.runBenchmark( - mBenchmarkActivity.mModel, - mBenchmarkActivity.mNumWarmupIter, - mBenchmarkActivity.mNumIter, - mBenchmarkActivity.mResult); - - if (mBenchmarkActivity.mTokenizerPath == null) { - mBenchmarkActivity.writeResult(); - } else { - this.sendEmptyMessage(MESSAGE_LLM_RUN_BENCHMARK); - } - } else if (msg.what == MESSAGE_LLM_RUN_BENCHMARK) { - mLlmBenchmark = - new LlmBenchmark( - mBenchmarkActivity, - mBenchmarkActivity.mModel.getPath(), - mBenchmarkActivity.mTokenizerPath, - mBenchmarkActivity.mPrompt, - mBenchmarkActivity.mTemperature, - mBenchmarkActivity.mResult); - } - } -} diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.kt b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.kt new file mode 100644 index 00000000000..b1d69c5f24f --- /dev/null +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.kt @@ -0,0 +1,116 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.minibench + +import android.app.Activity +import android.os.Bundle +import android.os.Handler +import android.os.HandlerThread +import android.os.Looper +import android.os.Message +import android.system.Os +import com.google.gson.Gson +import java.io.File +import java.io.FileWriter +import java.io.IOException + +class BenchmarkActivity : Activity() { + + lateinit var model: File + var numIter: Int = 0 + var numWarmupIter: Int = 0 + var tokenizerPath: String? = null + var temperature: Float = 0.8f + var prompt: String = "The ultimate answer" + + private lateinit var handlerThread: HandlerThread + private lateinit var handler: BenchmarkHandler + + val results: MutableList = mutableListOf() + + override fun onCreate(savedInstanceState: Bundle?) { + super.onCreate(savedInstanceState) + + try { + Os.setenv("ADSP_LIBRARY_PATH", applicationInfo.nativeLibraryDir, true) + } catch (e: android.system.ErrnoException) { + finish() + return + } + + val intent = intent + val modelDir = File(intent.getStringExtra("model_dir")!!) + model = modelDir.listFiles()!!.first { it.name.endsWith(".pte") } + + numIter = intent.getIntExtra("num_iter", 50) + numWarmupIter = intent.getIntExtra("num_warm_up_iter", 10) + tokenizerPath = intent.getStringExtra("tokenizer_path") + temperature = intent.getFloatExtra("temperature", 0.8f) + prompt = intent.getStringExtra("prompt") ?: "The ultimate answer" + + handlerThread = HandlerThread("ModelRunner") + handlerThread.start() + handler = BenchmarkHandler(handlerThread.looper, this) + + handler.sendEmptyMessage(BenchmarkHandler.MESSAGE_RUN_BENCHMARK) + } + + fun writeResult() { + try { + FileWriter("${filesDir}/benchmark_results.json").use { writer -> + writer.write(Gson().toJson(results)) + } + } catch (e: IOException) { + e.printStackTrace() + } finally { + finish() + } + } +} + +private class BenchmarkHandler( + looper: Looper, + private val activity: BenchmarkActivity, +) : Handler(looper) { + + private val modelRunner = ModelRunner() + + override fun handleMessage(msg: Message) { + when (msg.what) { + MESSAGE_RUN_BENCHMARK -> { + modelRunner.runBenchmark( + activity.model, + activity.numWarmupIter, + activity.numIter, + activity.results, + ) + if (activity.tokenizerPath == null) { + activity.writeResult() + } else { + sendEmptyMessage(MESSAGE_LLM_RUN_BENCHMARK) + } + } + MESSAGE_LLM_RUN_BENCHMARK -> { + LlmBenchmark( + activity, + activity.model.path, + activity.tokenizerPath!!, + activity.prompt, + activity.temperature, + activity.results, + ) + } + } + } + + companion object { + const val MESSAGE_RUN_BENCHMARK = 1 + const val MESSAGE_LLM_RUN_BENCHMARK = 2 + } +} diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkMetric.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkMetric.java deleted file mode 100644 index 66ab50550a4..00000000000 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkMetric.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.minibench; - -import android.app.ActivityManager; -import android.os.Build; -import java.util.regex.Matcher; -import java.util.regex.Pattern; - -class BenchmarkMetric { - public static class BenchmarkModel { - // The model name, i.e. stories110M - String name; - String backend; - String quantization; - - public BenchmarkModel(final String name, final String backend, final String quantization) { - this.name = name; - this.backend = backend; - this.quantization = quantization; - } - } - - BenchmarkModel benchmarkModel; - - // The metric name, i.e. TPS - String metric; - - // The actual value and the option target value - double actualValue; - double targetValue; - - public static class DeviceInfo { - // Let's see which information we want to include here - final String device = Build.BRAND; - // The phone model and Android release version - final String arch = Build.MODEL; - final String os = "Android " + Build.VERSION.RELEASE; - final long totalMem = new ActivityManager.MemoryInfo().totalMem; - final long availMem = new ActivityManager.MemoryInfo().availMem; - } - - DeviceInfo deviceInfo = new DeviceInfo(); - - public BenchmarkMetric( - final BenchmarkModel benchmarkModel, - final String metric, - final double actualValue, - final double targetValue) { - this.benchmarkModel = benchmarkModel; - this.metric = metric; - this.actualValue = actualValue; - this.targetValue = targetValue; - } - - // TODO (huydhn): Figure out a way to extract the backend and quantization information from - // the .pte model itself instead of parsing its name - public static BenchmarkMetric.BenchmarkModel extractBackendAndQuantization(final String model) { - final Matcher m = - Pattern.compile("(?\\w+)_(?[\\w\\+]+)_(?\\w+)").matcher(model); - if (m.matches()) { - return new BenchmarkMetric.BenchmarkModel( - m.group("name"), m.group("backend"), m.group("quantization")); - } else { - return new BenchmarkMetric.BenchmarkModel(model, "", ""); - } - } -} diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkMetric.kt b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkMetric.kt new file mode 100644 index 00000000000..7bed1ab05c0 --- /dev/null +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkMetric.kt @@ -0,0 +1,54 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.minibench + +import android.app.ActivityManager +import android.os.Build + +class BenchmarkMetric( + val benchmarkModel: BenchmarkModel, + val metric: String, + val actualValue: Double, + val targetValue: Double, +) { + data class BenchmarkModel( + val name: String, + val backend: String, + val quantization: String, + ) + + class DeviceInfo { + val device: String = Build.BRAND + val arch: String = Build.MODEL + val os: String = "Android ${Build.VERSION.RELEASE}" + val totalMem: Long = ActivityManager.MemoryInfo().totalMem + val availMem: Long = ActivityManager.MemoryInfo().availMem + } + + val deviceInfo: DeviceInfo = DeviceInfo() + + companion object { + // TODO (huydhn): Figure out a way to extract the backend and quantization information from + // the .pte model itself instead of parsing its name + @JvmStatic + fun extractBackendAndQuantization(model: String): BenchmarkModel { + val pattern = Regex("(?\\w+)_(?[\\w+]+)_(?\\w+)") + val match = pattern.matchEntire(model) + return if (match != null) { + BenchmarkModel( + match.groups["name"]!!.value, + match.groups["backend"]!!.value, + match.groups["quantization"]!!.value, + ) + } else { + BenchmarkModel(model, "", "") + } + } + } +} diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.java deleted file mode 100644 index 0c0436d2676..00000000000 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.java +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.minibench; - -import android.util.Log; -import java.util.List; -import org.json.JSONException; -import org.json.JSONObject; - -public class LlmBenchmark implements LlmModelRunnerCallback { - LlmModelRunner mLlmModelRunner; - - String mPrompt; - StatsInfo mStatsInfo; - - List mResults; - BenchmarkActivity mActivity; - - LlmBenchmark( - BenchmarkActivity activity, - String modelFile, - String tokenizerPath, - String prompt, - float temperature, - List results) { - mResults = results; - mActivity = activity; - mStatsInfo = new StatsInfo(); - mStatsInfo.modelName = modelFile.substring(modelFile.lastIndexOf('/') + 1).replace(".pte", ""); - mPrompt = prompt; - mLlmModelRunner = new LlmModelRunner(modelFile, tokenizerPath, temperature, this); - mStatsInfo.loadStart = System.nanoTime(); - } - - @Override - public void onModelLoaded(int status) { - mStatsInfo.loadEnd = System.nanoTime(); - mStatsInfo.loadStatus = status; - if (status != 0) { - Log.e("LlmBenchmarkRunner", "Loaded failed: " + status); - onGenerationStopped(); - return; - } - mStatsInfo.generateStart = System.nanoTime(); - mLlmModelRunner.generate(mPrompt); - } - - @Override - public void onTokenGenerated(String token) {} - - @Override - public void onStats(String stats) { - float tps = 0; - try { - JSONObject jsonObject = new JSONObject(stats); - int numGeneratedTokens = jsonObject.getInt("generated_tokens"); - int inferenceEndMs = jsonObject.getInt("inference_end_ms"); - int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms"); - tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000; - mStatsInfo.tps = tps; - } catch (JSONException e) { - Log.e("LLM", "Error parsing JSON: " + e.getMessage()); - } - } - - @Override - public void onGenerationStopped() { - mStatsInfo.generateEnd = System.nanoTime(); - - final BenchmarkMetric.BenchmarkModel benchmarkModel = - BenchmarkMetric.extractBackendAndQuantization(mStatsInfo.modelName); - // The list of metrics we have atm includes: - // Load status - mResults.add(new BenchmarkMetric(benchmarkModel, "load_status", mStatsInfo.loadStatus, 0)); - // Model load time - mResults.add( - new BenchmarkMetric( - benchmarkModel, - "llm_model_load_time(ms)", - (mStatsInfo.loadEnd - mStatsInfo.loadStart) * 1e-6, - 0.0f)); - // LLM generate time - mResults.add( - new BenchmarkMetric( - benchmarkModel, - "generate_time(ms)", - (mStatsInfo.generateEnd - mStatsInfo.generateStart) * 1e-6, - 0.0f)); - // Token per second - mResults.add(new BenchmarkMetric(benchmarkModel, "token_per_sec", mStatsInfo.tps, 0.0f)); - mActivity.writeResult(); - } -} - -class StatsInfo { - int loadStatus; - long loadStart; - long loadEnd; - long generateStart; - long generateEnd; - float tps; - String modelName; - - @Override - public String toString() { - return "loadStart: " - + loadStart - + "\nloadEnd: " - + loadEnd - + "\ngenerateStart: " - + generateStart - + "\ngenerateEnd: " - + generateEnd - + "\n" - + tps; - } -} diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.kt b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.kt new file mode 100644 index 00000000000..5c75519f870 --- /dev/null +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.kt @@ -0,0 +1,91 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.minibench + +import android.util.Log +import org.json.JSONException +import org.json.JSONObject + +class LlmBenchmark( + private val activity: BenchmarkActivity, + modelFile: String, + tokenizerPath: String, + private val prompt: String, + temperature: Float, + private val results: MutableList, +) : LlmModelRunnerCallback { + + private val runner: LlmModelRunner + private val statsInfo = StatsInfo() + + init { + statsInfo.modelName = modelFile.substringAfterLast('/').removeSuffix(".pte") + runner = LlmModelRunner(modelFile, tokenizerPath, temperature, this) + statsInfo.loadStart = System.nanoTime() + } + + override fun onModelLoaded(status: Int) { + statsInfo.loadEnd = System.nanoTime() + statsInfo.loadStatus = status + if (status != 0) { + Log.e("LlmBenchmarkRunner", "Loaded failed: $status") + onGenerationStopped() + return + } + statsInfo.generateStart = System.nanoTime() + runner.generate(prompt) + } + + override fun onTokenGenerated(token: String) {} + + override fun onStats(stats: String) { + try { + val json = JSONObject(stats) + val numGeneratedTokens = json.getInt("generated_tokens") + val inferenceEndMs = json.getInt("inference_end_ms") + val promptEvalEndMs = json.getInt("prompt_eval_end_ms") + statsInfo.tps = numGeneratedTokens.toFloat() / (inferenceEndMs - promptEvalEndMs) * 1000 + } catch (e: JSONException) { + Log.e("LLM", "Error parsing JSON: ${e.message}") + } + } + + override fun onGenerationStopped() { + statsInfo.generateEnd = System.nanoTime() + + val benchmarkModel = BenchmarkMetric.extractBackendAndQuantization(statsInfo.modelName) + results.add(BenchmarkMetric(benchmarkModel, "load_status", statsInfo.loadStatus.toDouble(), 0.0)) + results.add( + BenchmarkMetric( + benchmarkModel, + "llm_model_load_time(ms)", + (statsInfo.loadEnd - statsInfo.loadStart) * 1e-6, + 0.0, + )) + results.add( + BenchmarkMetric( + benchmarkModel, + "generate_time(ms)", + (statsInfo.generateEnd - statsInfo.generateStart) * 1e-6, + 0.0, + )) + results.add(BenchmarkMetric(benchmarkModel, "token_per_sec", statsInfo.tps.toDouble(), 0.0)) + activity.writeResult() + } +} + +private class StatsInfo { + var loadStatus: Int = 0 + var loadStart: Long = 0 + var loadEnd: Long = 0 + var generateStart: Long = 0 + var generateEnd: Long = 0 + var tps: Float = 0f + var modelName: String = "" +} diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java deleted file mode 100644 index 3a345d3465b..00000000000 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.minibench; - -import android.os.Handler; -import android.os.HandlerThread; -import android.os.Looper; -import android.os.Message; -import android.util.Log; -import org.pytorch.executorch.extension.llm.LlmCallback; -import org.pytorch.executorch.extension.llm.LlmModule; - -/** A helper class to handle all model running logic within this class. */ -public class LlmModelRunner implements LlmCallback { - LlmModule mModule = null; - - String mModelFilePath = ""; - String mTokenizerFilePath = ""; - - LlmModelRunnerCallback mCallback = null; - - HandlerThread mHandlerThread = null; - Handler mHandler = null; - - /** - * ] Helper class to separate between UI logic and model runner logic. Automatically handle - * generate() request on worker thread. - * - * @param modelFilePath - * @param tokenizerFilePath - * @param callback - */ - LlmModelRunner( - String modelFilePath, - String tokenizerFilePath, - float temperature, - LlmModelRunnerCallback callback) { - mModelFilePath = modelFilePath; - mTokenizerFilePath = tokenizerFilePath; - mCallback = callback; - - mModule = new LlmModule(mModelFilePath, mTokenizerFilePath, 0.8f); - mHandlerThread = new HandlerThread("LlmModelRunner"); - mHandlerThread.start(); - mHandler = new LlmModelRunnerHandler(mHandlerThread.getLooper(), this); - - mHandler.sendEmptyMessage(LlmModelRunnerHandler.MESSAGE_LOAD_MODEL); - } - - int generate(String prompt) { - Message msg = Message.obtain(mHandler, LlmModelRunnerHandler.MESSAGE_GENERATE, prompt); - msg.sendToTarget(); - return 0; - } - - void stop() { - mModule.stop(); - } - - @Override - public void onResult(String result) { - mCallback.onTokenGenerated(result); - } - - @Override - public void onStats(String result) { - mCallback.onStats(result); - } -} - -class LlmModelRunnerHandler extends Handler { - public static int MESSAGE_LOAD_MODEL = 1; - public static int MESSAGE_GENERATE = 2; - - private final LlmModelRunner mLlmModelRunner; - - public LlmModelRunnerHandler(Looper looper, LlmModelRunner llmModelRunner) { - super(looper); - mLlmModelRunner = llmModelRunner; - } - - @Override - public void handleMessage(android.os.Message msg) { - if (msg.what == MESSAGE_LOAD_MODEL) { - int status = 0; - try { - mLlmModelRunner.mModule.load(); - } catch (Exception e) { - status = - (e instanceof org.pytorch.executorch.ExecutorchRuntimeException) - ? ((org.pytorch.executorch.ExecutorchRuntimeException) e).getErrorCode() - : -1; - } - mLlmModelRunner.mCallback.onModelLoaded(status); - } else if (msg.what == MESSAGE_GENERATE) { - try { - mLlmModelRunner.mModule.generate((String) msg.obj, mLlmModelRunner); - } catch (Exception e) { - Log.e("LlmModelRunner", "generate() failed", e); - } - mLlmModelRunner.mCallback.onGenerationStopped(); - } - } -} diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.kt b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.kt new file mode 100644 index 00000000000..29b9b177fb6 --- /dev/null +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.kt @@ -0,0 +1,91 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.minibench + +import android.os.Handler +import android.os.HandlerThread +import android.os.Looper +import android.os.Message +import android.util.Log +import org.pytorch.executorch.ExecutorchRuntimeException +import org.pytorch.executorch.extension.llm.LlmCallback +import org.pytorch.executorch.extension.llm.LlmModule + +/** A helper class to handle all model running logic within this class. */ +class LlmModelRunner( + modelFilePath: String, + tokenizerFilePath: String, + temperature: Float, + val callback: LlmModelRunnerCallback, +) : LlmCallback { + + val module: LlmModule = LlmModule(modelFilePath, tokenizerFilePath, temperature) + private val handlerThread: HandlerThread = HandlerThread("LlmModelRunner") + private val handler: Handler + + init { + handlerThread.start() + handler = LlmModelRunnerHandler(handlerThread.looper, this) + handler.sendEmptyMessage(LlmModelRunnerHandler.MESSAGE_LOAD_MODEL) + } + + fun generate(prompt: String): Int { + val msg = Message.obtain(handler, LlmModelRunnerHandler.MESSAGE_GENERATE, prompt) + msg.sendToTarget() + return 0 + } + + fun stop() { + module.stop() + } + + override fun onResult(result: String) { + callback.onTokenGenerated(result) + } + + override fun onStats(stats: String) { + callback.onStats(stats) + } +} + +private class LlmModelRunnerHandler( + looper: Looper, + private val runner: LlmModelRunner, +) : Handler(looper) { + + override fun handleMessage(msg: Message) { + when (msg.what) { + MESSAGE_LOAD_MODEL -> { + val status = + try { + runner.module.load() + 0 + } catch (e: ExecutorchRuntimeException) { + e.errorCode + } catch (e: Exception) { + -1 + } + runner.callback.onModelLoaded(status) + } + MESSAGE_GENERATE -> { + try { + runner.module.generate(msg.obj as String, runner) + } catch (e: Exception) { + Log.e("LlmModelRunner", "generate() failed", e) + } + runner.callback.onGenerationStopped() + } + } + } + + companion object { + const val MESSAGE_LOAD_MODEL = 1 + const val MESSAGE_GENERATE = 2 + } +} diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java deleted file mode 100644 index 915496a25af..00000000000 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.minibench; - -import android.os.Debug; -import java.io.File; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import org.pytorch.executorch.Module; - -public class ModelRunner { - /** - * @return list of #BenchmarkMetric - */ - public void runBenchmark( - File model, int numWarmupIter, int numIter, List results) { - long pssIdle = Debug.getPss(); - - List latency = new ArrayList<>(); - - long loadStart = System.nanoTime(); - Module module = Module.load(model.getPath()); - int errorCode = 0; - try { - module.loadMethod("forward"); - } catch (Exception e) { - errorCode = - (e instanceof org.pytorch.executorch.ExecutorchRuntimeException) - ? ((org.pytorch.executorch.ExecutorchRuntimeException) e).getErrorCode() - : -1; - } - long loadEnd = System.nanoTime(); - - final BenchmarkMetric.BenchmarkModel benchmarkModel = - BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", "")); - - if (errorCode != 0) { - results.add( - new BenchmarkMetric( - benchmarkModel, "model_load_time(ms)", (loadEnd - loadStart) * 1e-6, 0.0f)); - results.add(new BenchmarkMetric(benchmarkModel, "load_status", errorCode, 0)); - module.destroy(); - return; - } - - try { - for (int i = 0; i < numWarmupIter; i++) { - module.forward(); - } - - for (int i = 0; i < numIter; i++) { - long start = System.nanoTime(); - module.forward(); - double forwardMs = (System.nanoTime() - start) * 1e-6; - latency.add(forwardMs); - } - - module.etdump(); - - // Currently the result has large variance from outliers, so only use - // 80% samples in the middle (trimmean 0.2) - Collections.sort(latency); - int resultSize = latency.size(); - List usedLatencyResults = latency.subList(resultSize / 10, resultSize * 9 / 10); - - results.add( - new BenchmarkMetric( - benchmarkModel, - "avg_inference_latency(ms)", - latency.stream().mapToDouble(l -> l).average().orElse(0.0f), - 0.0f)); - results.add( - new BenchmarkMetric( - benchmarkModel, - "trimmean_inference_latency(ms)", - usedLatencyResults.stream().mapToDouble(l -> l).average().orElse(0.0f), - 0.0f)); - // Model load time - results.add( - new BenchmarkMetric( - benchmarkModel, "model_load_time(ms)", (loadEnd - loadStart) * 1e-6, 0.0f)); - // Load status - results.add(new BenchmarkMetric(benchmarkModel, "load_status", errorCode, 0)); - // RAM PSS usage - results.add( - new BenchmarkMetric( - benchmarkModel, "ram_pss_usage(mb)", (Debug.getPss() - pssIdle) / 1024, 0)); - } finally { - module.destroy(); - } - } -} diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.kt b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.kt new file mode 100644 index 00000000000..0f292b0d900 --- /dev/null +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.kt @@ -0,0 +1,90 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.minibench + +import android.os.Debug +import java.io.File +import org.pytorch.executorch.ExecutorchRuntimeException +import org.pytorch.executorch.Module + +class ModelRunner { + + fun runBenchmark( + model: File, + numWarmupIter: Int, + numIter: Int, + results: MutableList, + ) { + val pssIdle = Debug.getPss() + val latency = mutableListOf() + + val loadStart = System.nanoTime() + val module = Module.load(model.path) + var errorCode = 0 + try { + module.loadMethod("forward") + } catch (e: ExecutorchRuntimeException) { + errorCode = e.errorCode + } catch (e: Exception) { + errorCode = -1 + } + val loadEnd = System.nanoTime() + + val benchmarkModel = + BenchmarkMetric.extractBackendAndQuantization(model.name.removeSuffix(".pte")) + + if (errorCode != 0) { + results.add( + BenchmarkMetric(benchmarkModel, "model_load_time(ms)", (loadEnd - loadStart) * 1e-6, 0.0)) + results.add(BenchmarkMetric(benchmarkModel, "load_status", errorCode.toDouble(), 0.0)) + module.destroy() + return + } + + try { + repeat(numWarmupIter) { module.forward() } + + repeat(numIter) { + val start = System.nanoTime() + module.forward() + latency.add((System.nanoTime() - start) * 1e-6) + } + + module.etdump() + + // Currently the result has large variance from outliers, so only use + // 80% samples in the middle (trimmean 0.2) + latency.sort() + val trimmed = latency.subList(latency.size / 10, latency.size * 9 / 10) + + results.add( + BenchmarkMetric( + benchmarkModel, + "avg_inference_latency(ms)", + latency.average(), + 0.0, + )) + results.add( + BenchmarkMetric( + benchmarkModel, + "trimmean_inference_latency(ms)", + trimmed.average(), + 0.0, + )) + results.add( + BenchmarkMetric(benchmarkModel, "model_load_time(ms)", (loadEnd - loadStart) * 1e-6, 0.0)) + results.add(BenchmarkMetric(benchmarkModel, "load_status", errorCode.toDouble(), 0.0)) + results.add( + BenchmarkMetric( + benchmarkModel, "ram_pss_usage(mb)", (Debug.getPss() - pssIdle) / 1024.0, 0.0)) + } finally { + module.destroy() + } + } +} diff --git a/extension/benchmark/android/benchmark/app/src/test/java/org/pytorch/minibench/ExampleUnitTest.java b/extension/benchmark/android/benchmark/app/src/test/java/org/pytorch/minibench/ExampleUnitTest.kt similarity index 55% rename from extension/benchmark/android/benchmark/app/src/test/java/org/pytorch/minibench/ExampleUnitTest.java rename to extension/benchmark/android/benchmark/app/src/test/java/org/pytorch/minibench/ExampleUnitTest.kt index c6a6a76a4d8..b98a49e4bf9 100644 --- a/extension/benchmark/android/benchmark/app/src/test/java/org/pytorch/minibench/ExampleUnitTest.java +++ b/extension/benchmark/android/benchmark/app/src/test/java/org/pytorch/minibench/ExampleUnitTest.kt @@ -6,20 +6,19 @@ * LICENSE file in the root directory of this source tree. */ -package org.pytorch.minibench; +package org.pytorch.minibench -import static org.junit.Assert.*; - -import org.junit.Test; +import org.junit.Assert.assertEquals +import org.junit.Test /** * Example local unit test, which will execute on the development machine (host). * - * @see Testing documentation + * @see [Testing documentation](http://d.android.com/tools/testing) */ -public class ExampleUnitTest { +class ExampleUnitTest { @Test - public void addition_isCorrect() { - assertEquals(4, 2 + 2); + fun addition_isCorrect() { + assertEquals(4, 2 + 2) } }