diff --git a/.github/workflows/upload-android-test-specs.yml b/.github/workflows/upload-android-test-specs.yml index 5a468da44f1..04f7cf40d73 100644 --- a/.github/workflows/upload-android-test-specs.yml +++ b/.github/workflows/upload-android-test-specs.yml @@ -41,7 +41,7 @@ jobs: with: # Just use a small model here with a minimal amount of configuration to test the spec models: stories110M - devices: samsung_galaxy_s2x + devices: samsung_galaxy_s22 delegates: xnnpack test_spec: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifact/android-llm-device-farm-test-spec.yml diff --git a/examples/demo-apps/android/LlamaDemo/android-llm-device-farm-test-spec.yml b/examples/demo-apps/android/LlamaDemo/android-llm-device-farm-test-spec.yml index cac83b8e6f5..896e7b73fbf 100644 --- a/examples/demo-apps/android/LlamaDemo/android-llm-device-farm-test-spec.yml +++ b/examples/demo-apps/android/LlamaDemo/android-llm-device-farm-test-spec.yml @@ -73,8 +73,30 @@ phases: fi fi; + # Run the new generic benchmark activity https://developer.android.com/tools/adb#am + - echo "Run LLM benchmark" + - | + adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n com.example.executorchllamademo/.LlmBenchmarkRunner \ + --es "model_dir" "/data/local/tmp/llama" \ + --es "tokenizer_path" "/data/local/tmp/llama/tokenizer.bin" + post_test: commands: + - echo "Gather LLM benchmark results" + - | + BENCHMARK_RESULTS="" + ATTEMPT=0 + MAX_ATTEMPT=10 + while [ -z "${BENCHMARK_RESULTS}" ] && [ $ATTEMPT -lt $MAX_ATTEMPT ]; do + echo "Waiting for benchmark results..." + BENCHMARK_RESULTS=$(adb -s $DEVICEFARM_DEVICE_UDID shell run-as com.example.executorchllamademo cat files/benchmark_results.json) + sleep 30 + ((ATTEMPT++)) + done + + adb -s $DEVICEFARM_DEVICE_UDID shell run-as com.example.executorchllamademo ls -la files/ + # Trying to pull the file using adb ends up with permission error, but this works too, so why not + echo "${BENCHMARK_RESULTS}" > $DEVICEFARM_LOG_DIR/benchmark_results.json artifacts: # By default, Device Farm will collect your artifacts from the $DEVICEFARM_LOG_DIR directory. diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LlmBenchmarkRunner.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LlmBenchmarkRunner.java index 33b230b1dff..cee623507fd 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LlmBenchmarkRunner.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LlmBenchmarkRunner.java @@ -14,8 +14,11 @@ import android.util.Log; import android.widget.TextView; import androidx.annotation.NonNull; +import com.google.gson.Gson; +import java.io.File; import java.io.FileWriter; import java.io.IOException; +import java.util.Arrays; public class LlmBenchmarkRunner extends Activity implements ModelRunnerCallback { ModelRunner mModelRunner; @@ -32,7 +35,12 @@ protected void onCreate(Bundle savedInstanceState) { Intent intent = getIntent(); - String modelPath = intent.getStringExtra("model_path"); + File modelDir = new File(intent.getStringExtra("model_dir")); + File model = + Arrays.stream(modelDir.listFiles()) + .filter(file -> file.getName().endsWith(".pte")) + .findFirst() + .get(); String tokenizerPath = intent.getStringExtra("tokenizer_path"); float temperature = intent.getFloatExtra("temperature", 0.8f); @@ -42,7 +50,7 @@ protected void onCreate(Bundle savedInstanceState) { } mStatsDump = new StatsDump(); - mModelRunner = new ModelRunner(modelPath, tokenizerPath, temperature, this); + mModelRunner = new ModelRunner(model.getPath(), tokenizerPath, temperature, this); mStatsDump.loadStart = System.currentTimeMillis(); } @@ -79,11 +87,21 @@ public void onGenerationStopped() { mTextView.append(mStatsDump.toString()); }); + // TODO (huydhn): Remove txt files here once the JSON format is ready try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.txt")) { writer.write(mStatsDump.toString()); } catch (IOException e) { e.printStackTrace(); } + + // TODO (huydhn): Figure out on what the final JSON results looks like, we need something + // with the same number of fields as https://github.com/pytorch/pytorch/pull/135042 + try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) { + Gson gson = new Gson(); + writer.write(gson.toJson(mStatsDump)); + } catch (IOException e) { + e.printStackTrace(); + } } }