Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
7271dff
try to merge jni
kirklandsign Aug 31, 2024
c81a8f1
new activity!
kirklandsign Aug 31, 2024
b406fbb
remove unused
kirklandsign Aug 31, 2024
11b3d6a
Remove API forwardOnes
kirklandsign Aug 31, 2024
08c7664
Merge remote-tracking branch 'origin/main' into experiments-jni
kirklandsign Sep 5, 2024
4b3437d
Merge remote-tracking branch 'origin/main' into experiments-jni
kirklandsign Sep 5, 2024
4259eb1
Remove managed_tensors
kirklandsign Sep 5, 2024
3516aae
[Android] Remove forwardOnes
kirklandsign Sep 6, 2024
0cc883e
fix build
kirklandsign Sep 6, 2024
5947dd5
Merge remote-tracking branch 'origin/main' into experiments-jni
kirklandsign Sep 6, 2024
7c369a3
copy qnn part
kirklandsign Sep 6, 2024
bab4d66
load method first
kirklandsign Sep 6, 2024
1d6a86e
Need to load method
kirklandsign Sep 6, 2024
38ae11c
linter
kirklandsign Sep 6, 2024
ff1fe3c
Merge branch 'android-api-change' into experiments-jni
kirklandsign Sep 6, 2024
7864c62
Merge remote-tracking branch 'origin/main' into experiments-jni
kirklandsign Sep 6, 2024
e384d07
add qnn
kirklandsign Sep 7, 2024
608e020
Merge remote-tracking branch 'origin/main' into experiments-jni
kirklandsign Sep 8, 2024
e688c2e
Merge branch 'experiments-jni' of github.com:kirklandsign/executorch …
kirklandsign Sep 8, 2024
be3abec
Merge remote-tracking branch 'origin/experiments-jni' into experiment…
kirklandsign Sep 8, 2024
2b1d2e7
Update Java Activity part
kirklandsign Sep 9, 2024
5152950
update cmake
kirklandsign Sep 9, 2024
19cec4e
add a way to build non llm
kirklandsign Sep 9, 2024
0265e6b
Remove reference libexecutorch_llama_jni.so
kirklandsign Sep 9, 2024
7163807
Remove app related stuff for now
kirklandsign Sep 9, 2024
fd13fcb
Copy LLM benchmarking activity from LlamaDemo app
kirklandsign Sep 9, 2024
6d268f3
Rename StatsInfo
kirklandsign Sep 9, 2024
73bdcd9
linter
kirklandsign Sep 9, 2024
4f59446
Link custom_ops
kirklandsign Sep 10, 2024
d85a621
Merge remote-tracking branch 'origin/main' into app-side-change
kirklandsign Sep 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions extension/android/benchmark/app/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ dependencies {
implementation(files("libs/executorch.aar"))
implementation("com.facebook.soloader:soloader:0.10.5")
implementation("com.facebook.fbjni:fbjni:0.5.1")
implementation("com.google.code.gson:gson:2.8.6")
testImplementation("junit:junit:4.13.2")
androidTestImplementation("androidx.test.ext:junit:1.2.1")
androidTestImplementation("androidx.test.espresso:espresso-core:3.6.1")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@
</intent-filter>
</activity>

<!--
Declares LlmBenchmarkActivity and marks it exported; its intent filter matches the
org.pytorch.minibench.BENCHMARK action.
NOTE(review): presumably exported so a benchmark harness can start it from outside the
app (for example via adb); confirm that this exposure is intended.
-->
<activity
android:name=".LlmBenchmarkActivity"
android:exported="true">
<intent-filter>
<action android:name="org.pytorch.minibench.BENCHMARK" />
</intent-filter>
</activity>

</application>

</manifest>
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

package org.pytorch.minibench;

import android.app.Activity;
import android.content.Intent;
import android.os.Bundle;
import android.util.Log;
import com.google.gson.Gson;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Arrays;

/**
 * Activity that runs a single LLM benchmark and writes the collected timings to the app's files
 * directory.
 *
 * <p>Intent extras read in {@link #onCreate(Bundle)}:
 *
 * <ul>
 *   <li>{@code model_dir} (String, required): directory searched for the first {@code .pte} file.
 *   <li>{@code tokenizer_path} (String): forwarded to {@link ModelRunner} unchanged.
 *   <li>{@code temperature} (float, default {@code 0.8f}).
 *   <li>{@code prompt} (String, default {@code "The ultimate answer"}).
 * </ul>
 *
 * <p>Results are written to {@code benchmark_results.txt} and {@code benchmark_results.json} under
 * {@link #getFilesDir()}.
 */
public class LlmBenchmarkActivity extends Activity implements ModelRunnerCallback {
  private static final String TAG = "LlmBenchmarkRunner";

  ModelRunner mModelRunner;

  String mPrompt;
  StatsInfo mStatsInfo;

  /**
   * Reads the benchmark configuration from the launching intent and kicks off model loading.
   *
   * <p>If {@code model_dir} is missing, cannot be listed, or contains no {@code .pte} file, the
   * error is logged and the activity finishes instead of crashing.
   */
  @Override
  protected void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);

    Intent intent = getIntent();

    String modelDirPath = intent.getStringExtra("model_dir");
    if (modelDirPath == null) {
      Log.e(TAG, "Missing required intent extra: model_dir");
      finish();
      return;
    }
    File model = findModelFile(new File(modelDirPath));
    if (model == null) {
      Log.e(TAG, "No readable .pte model file found in: " + modelDirPath);
      finish();
      return;
    }
    String tokenizerPath = intent.getStringExtra("tokenizer_path");

    float temperature = intent.getFloatExtra("temperature", 0.8f);
    mPrompt = intent.getStringExtra("prompt");
    if (mPrompt == null) {
      mPrompt = "The ultimate answer";
    }

    mStatsInfo = new StatsInfo();
    // Stamp the start time BEFORE constructing ModelRunner: its constructor immediately queues
    // the load on a worker thread, so stamping afterwards would under-measure the load time and
    // could even let loadEnd be written before loadStart.
    mStatsInfo.loadStart = System.currentTimeMillis();
    mModelRunner = new ModelRunner(model.getPath(), tokenizerPath, temperature, this);
  }

  /**
   * Records the load end time and starts generation when loading succeeded.
   *
   * @param status load result; any non-zero value is treated as a failure, in which case the
   *     (partial) results are still written via {@link #onGenerationStopped()}.
   */
  @Override
  public void onModelLoaded(int status) {
    mStatsInfo.loadEnd = System.currentTimeMillis();
    if (status != 0) {
      Log.e(TAG, "Loaded failed: " + status);
      onGenerationStopped();
      return;
    }
    mStatsInfo.generateStart = System.currentTimeMillis();
    mModelRunner.generate(mPrompt);
  }

  /** Individual tokens are not needed for benchmarking, so they are ignored. */
  @Override
  public void onTokenGenerated(String token) {}

  /** Stores the stats text reported by the runner so it is included in the results. */
  @Override
  public void onStats(String stats) {
    mStatsInfo.tokens = stats;
  }

  /** Records the generation end time and writes the results in both text and JSON form. */
  @Override
  public void onGenerationStopped() {
    mStatsInfo.generateEnd = System.currentTimeMillis();

    // TODO (huydhn): Remove txt files here once the JSON format is ready
    writeResultFile("benchmark_results.txt", mStatsInfo.toString());

    // TODO (huydhn): Figure out what the final JSON results look like; we need something
    // with the same number of fields as https://github.com/pytorch/pytorch/pull/135042
    writeResultFile("benchmark_results.json", new Gson().toJson(mStatsInfo));
  }

  /**
   * Returns the first file in {@code modelDir} whose name ends with {@code .pte}, or {@code null}
   * when the directory cannot be listed or holds no such file.
   */
  private static File findModelFile(File modelDir) {
    // listFiles() returns null when the path is not a directory or an I/O error occurs.
    File[] candidates = modelDir.listFiles();
    if (candidates == null) {
      return null;
    }
    return Arrays.stream(candidates)
        .filter(file -> file.getName().endsWith(".pte"))
        .findFirst()
        .orElse(null);
  }

  /** Writes {@code content} to {@code fileName} under the files directory, logging failures. */
  private void writeResultFile(String fileName, String content) {
    File target = new File(getFilesDir(), fileName);
    try (FileWriter writer = new FileWriter(target)) {
      writer.write(content);
    } catch (IOException e) {
      Log.e(TAG, "Failed to write " + target, e);
    }
  }
}

/**
 * Mutable holder for the measurements of one benchmark run: load and generate start/end times in
 * milliseconds plus the stats text reported by the runner.
 */
class StatsInfo {
  long loadStart;
  long loadEnd;
  long generateStart;
  long generateEnd;
  String tokens;

  /** Renders one {@code name: value} line per timestamp, followed by the stats text. */
  @Override
  public String toString() {
    StringBuilder text = new StringBuilder();
    text.append("loadStart: ").append(loadStart);
    text.append("\nloadEnd: ").append(loadEnd);
    text.append("\ngenerateStart: ").append(generateStart);
    text.append("\ngenerateEnd: ").append(generateEnd);
    text.append("\n").append(tokens);
    return text.toString();
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

package org.pytorch.minibench;

import android.os.Handler;
import android.os.HandlerThread;
import android.os.Looper;
import android.os.Message;
import org.pytorch.executorch.LlamaCallback;
import org.pytorch.executorch.LlamaModule;

/**
 * A helper class that owns the {@link LlamaModule} and runs load/generate requests on a dedicated
 * {@link HandlerThread}, reporting progress through a {@link ModelRunnerCallback}.
 */
public class ModelRunner implements LlamaCallback {
  LlamaModule mModule = null;

  String mModelFilePath = "";
  String mTokenizerFilePath = "";

  ModelRunnerCallback mCallback = null;

  HandlerThread mHandlerThread = null;
  Handler mHandler = null;

  /**
   * Helper class to separate UI logic from model runner logic. Starts the worker thread and
   * immediately queues a model load on it; generate() requests are handled on the same thread.
   *
   * @param modelFilePath path of the model file handed to {@link LlamaModule}
   * @param tokenizerFilePath path of the tokenizer file handed to {@link LlamaModule}
   * @param temperature temperature handed to {@link LlamaModule}
   * @param callback receiver of load, token, stats and completion notifications; it is invoked
   *     from the worker thread
   */
  ModelRunner(
      String modelFilePath,
      String tokenizerFilePath,
      float temperature,
      ModelRunnerCallback callback) {
    mModelFilePath = modelFilePath;
    mTokenizerFilePath = tokenizerFilePath;
    mCallback = callback;

    // Pass the caller's temperature through; a hard-coded 0.8f here silently ignored it.
    mModule = new LlamaModule(mModelFilePath, mTokenizerFilePath, temperature);
    mHandlerThread = new HandlerThread("ModelRunner");
    mHandlerThread.start();
    mHandler = new ModelRunnerHandler(mHandlerThread.getLooper(), this);

    mHandler.sendEmptyMessage(ModelRunnerHandler.MESSAGE_LOAD_MODEL);
  }

  /**
   * Queues a generation request for {@code prompt} on the worker thread.
   *
   * @param prompt text passed to the module's generate call
   * @return always {@code 0}; the request is only enqueued here, completion is reported through
   *     {@link ModelRunnerCallback#onGenerationStopped()}
   */
  int generate(String prompt) {
    Message msg = Message.obtain(mHandler, ModelRunnerHandler.MESSAGE_GENERATE, prompt);
    msg.sendToTarget();
    return 0;
  }

  /** Forwards a stop request to the underlying {@link LlamaModule}. */
  void stop() {
    mModule.stop();
  }

  /** Relays each result chunk from the module to the callback as a generated token. */
  @Override
  public void onResult(String result) {
    mCallback.onTokenGenerated(result);
  }

  /** Relays the reported rate to the callback as the text {@code "tokens/second: <tps>"}. */
  @Override
  public void onStats(float tps) {
    mCallback.onStats("tokens/second: " + tps);
  }
}

/**
 * Handler bound to the {@link ModelRunner} worker looper; executes load and generate requests and
 * notifies the runner's callback when each one returns.
 */
class ModelRunnerHandler extends Handler {
  // Message ids are constants: declared final so they cannot be reassigned at runtime.
  public static final int MESSAGE_LOAD_MODEL = 1;
  public static final int MESSAGE_GENERATE = 2;

  private final ModelRunner mModelRunner;

  /**
   * @param looper looper whose thread will run {@link #handleMessage(android.os.Message)}
   * @param modelRunner runner whose module and callback are used to serve requests
   */
  public ModelRunnerHandler(Looper looper, ModelRunner modelRunner) {
    super(looper);
    mModelRunner = modelRunner;
  }

  /**
   * Serves one request: {@link #MESSAGE_LOAD_MODEL} loads the module and reports its status;
   * {@link #MESSAGE_GENERATE} runs generation with {@code msg.obj} as the prompt and reports
   * completion once the call returns. Other message ids are ignored.
   */
  @Override
  public void handleMessage(android.os.Message msg) {
    switch (msg.what) {
      case MESSAGE_LOAD_MODEL:
        int status = mModelRunner.mModule.load();
        mModelRunner.mCallback.onModelLoaded(status);
        break;
      case MESSAGE_GENERATE:
        mModelRunner.mModule.generate((String) msg.obj, mModelRunner);
        mModelRunner.mCallback.onGenerationStopped();
        break;
      default:
        break;
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

package org.pytorch.minibench;

/**
 * A helper interface within the app for MainActivity and Benchmarking to handle callbacks from
 * ModelRunner.
 */
public interface ModelRunnerCallback {

  /**
   * Called after a model load attempt has returned.
   *
   * @param status status value returned by the load call
   */
  void onModelLoaded(int status);

  /**
   * Called for each piece of generated output.
   *
   * @param token the newly generated text
   */
  void onTokenGenerated(String token);

  /**
   * Called when the runner reports statistics.
   *
   * @param stats human-readable statistics text
   */
  void onStats(String stats);

  /** Called once a generation request has returned. */
  void onGenerationStopped();
}
Loading