diff --git a/torchchat/edge/android/torchchat/app/.gitignore b/torchchat/edge/android/torchchat/app/.gitignore
index 42afabfd2..796b96d1c 100644
--- a/torchchat/edge/android/torchchat/app/.gitignore
+++ b/torchchat/edge/android/torchchat/app/.gitignore
@@ -1 +1 @@
-/build
\ No newline at end of file
+/build
diff --git a/torchchat/edge/android/torchchat/app/build.gradle.kts b/torchchat/edge/android/torchchat/app/build.gradle.kts
index 1001c9f81..e0c9c196b 100644
--- a/torchchat/edge/android/torchchat/app/build.gradle.kts
+++ b/torchchat/edge/android/torchchat/app/build.gradle.kts
@@ -1,45 +1,97 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
plugins {
- id("com.android.application")
+ id("com.android.application")
+ id("org.jetbrains.kotlin.android")
}
android {
- namespace = "org.pytorch.torchchat"
- compileSdk = 33
+ namespace = "org.pytorch.torchchat"
+ compileSdk = 34
- defaultConfig {
- applicationId = "org.pytorch.torchchat"
- minSdk = 24
- targetSdk = 33
- versionCode = 1
- versionName = "1.0"
+ defaultConfig {
+ applicationId = "org.pytorch.torchchat"
+ minSdk = 28
+ targetSdk = 33
+ versionCode = 1
+ versionName = "1.0"
- testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
- }
+ testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
+ vectorDrawables { useSupportLibrary = true }
+ externalNativeBuild { cmake { cppFlags += "" } }
+ }
- buildTypes {
- release {
- isMinifyEnabled = false
- proguardFiles(
- getDefaultProguardFile("proguard-android-optimize.txt"),
- "proguard-rules.pro"
- )
- }
+ buildTypes {
+ release {
+ isMinifyEnabled = false
+ proguardFiles(getDefaultProguardFile("proguard-android-optimize.txt"), "proguard-rules.pro")
}
- compileOptions {
- sourceCompatibility = JavaVersion.VERSION_1_8
- targetCompatibility = JavaVersion.VERSION_1_8
+ }
+ compileOptions {
+ sourceCompatibility = JavaVersion.VERSION_1_8
+ targetCompatibility = JavaVersion.VERSION_1_8
+ }
+ kotlinOptions { jvmTarget = "1.8" }
+ buildFeatures { compose = true }
+ composeOptions { kotlinCompilerExtensionVersion = "1.4.3" }
+ packaging { resources { excludes += "/META-INF/{AL2.0,LGPL2.1}" } }
+}
+
+dependencies {
+ implementation("androidx.core:core-ktx:1.9.0")
+ implementation("androidx.lifecycle:lifecycle-runtime-ktx:2.6.1")
+ implementation("androidx.activity:activity-compose:1.7.0")
+ implementation(platform("androidx.compose:compose-bom:2023.03.00"))
+ implementation("androidx.compose.ui:ui")
+ implementation("androidx.compose.ui:ui-graphics")
+ implementation("androidx.compose.ui:ui-tooling-preview")
+ implementation("androidx.compose.material3:material3")
+ implementation("androidx.appcompat:appcompat:1.6.1")
+ implementation("androidx.camera:camera-core:1.3.0-rc02")
+ implementation("androidx.constraintlayout:constraintlayout:2.2.0-alpha12")
+ implementation("com.facebook.fbjni:fbjni:0.5.1")
+ implementation("com.google.code.gson:gson:2.8.6")
+ implementation(files("libs/executorch-llama.aar"))
+ implementation("com.google.android.material:material:1.12.0")
+ implementation("androidx.activity:activity:1.9.0")
+ testImplementation("junit:junit:4.13.2")
+ androidTestImplementation("androidx.test.ext:junit:1.1.5")
+ androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")
+ androidTestImplementation(platform("androidx.compose:compose-bom:2023.03.00"))
+ androidTestImplementation("androidx.compose.ui:ui-test-junit4")
+ debugImplementation("androidx.compose.ui:ui-tooling")
+ debugImplementation("androidx.compose.ui:ui-test-manifest")
+}
+
+tasks.register("setup") {
+ doFirst {
+ exec {
+ commandLine("sh", "examples/demo-apps/android/LlamaDemo/setup.sh")
+ workingDir("../../../../../")
}
- buildFeatures {
- viewBinding = true
+ }
+}
+
+tasks.register("setupQnn") {
+ doFirst {
+ exec {
+ commandLine("sh", "examples/demo-apps/android/LlamaDemo/setup-with-qnn.sh")
+ workingDir("../../../../../")
}
+ }
}
-dependencies {
- implementation("androidx.appcompat:appcompat:1.6.1")
- implementation("androidx.constraintlayout:constraintlayout:2.1.4")
- implementation("com.facebook.fbjni:fbjni:0.5.1")
- implementation(files("libs/executorch.aar"))
- testImplementation("junit:junit:4.13.2")
- androidTestImplementation("androidx.test.ext:junit:1.1.5")
- androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")
+tasks.register("download_prebuilt_lib") {
+ doFirst {
+ exec {
+ commandLine("sh", "examples/demo-apps/android/LlamaDemo/download_prebuilt_lib.sh")
+ workingDir("../../../../../")
+ }
+ }
}
diff --git a/torchchat/edge/android/torchchat/app/proguard-rules.pro b/torchchat/edge/android/torchchat/app/proguard-rules.pro
index f1b424510..481bb4348 100644
--- a/torchchat/edge/android/torchchat/app/proguard-rules.pro
+++ b/torchchat/edge/android/torchchat/app/proguard-rules.pro
@@ -18,4 +18,4 @@
# If you keep the line number information, uncomment this to
# hide the original source file name.
-#-renamesourcefileattribute SourceFile
+#-renamesourcefileattribute SourceFile
\ No newline at end of file
diff --git a/torchchat/edge/android/torchchat/app/src/androidTest/java/org/pytorch/torchchat/LlamaModuleTest.java b/torchchat/edge/android/torchchat/app/src/androidTest/java/org/pytorch/torchchat/LlamaModuleTest.java
deleted file mode 100644
index e3b4e5019..000000000
--- a/torchchat/edge/android/torchchat/app/src/androidTest/java/org/pytorch/torchchat/LlamaModuleTest.java
+++ /dev/null
@@ -1,56 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-package org.pytorch.torchchat;
-
-import androidx.test.ext.junit.runners.AndroidJUnit4;
-
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.pytorch.executorch.LlamaCallback;
-import org.pytorch.executorch.LlamaModule;
-
-import static org.junit.Assert.*;
-
-/**
- * Instrumented test, which will execute on an Android device.
- *
- * @see Testing documentation
- */
-@RunWith(AndroidJUnit4.class)
-public class LlamaModuleTest {
- @Test
- public void LlamaModule() {
- LlamaModule module = new LlamaModule("/data/local/tmp/llama/model.pte", "/data/local/tmp/llama/tokenizer.bin", 0.8f);
- assertEquals(module.load(), 0);
- MyLlamaCallback callback = new MyLlamaCallback();
- // Note: module.generate() is synchronous. Callback happens within the same thread as
- // generate() so when generate() returns, all callbacks are invoked.
- assertEquals(module.generate("Hey", callback), 0);
- assertNotEquals("", callback.result);
- }
-}
-
-/**
- * LlamaCallback for testing.
- *
- * Note: onResult() and onStats() are invoked within the same thread as LlamaModule.generate()
- *
- * @see LlamaCallback interface guide
- */
-class MyLlamaCallback implements LlamaCallback {
- String result = "";
- @Override
- public void onResult(String s) {
- result += s;
- }
-
- @Override
- public void onStats(float v) {
-
- }
-}
diff --git a/torchchat/edge/android/torchchat/app/src/androidTest/java/org/pytorch/torchchat/PerfTest.java b/torchchat/edge/android/torchchat/app/src/androidTest/java/org/pytorch/torchchat/PerfTest.java
new file mode 100644
index 000000000..e31eb7168
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/androidTest/java/org/pytorch/torchchat/PerfTest.java
@@ -0,0 +1,82 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package org.pytorch.torchchat;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+
+import android.os.Bundle;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.platform.app.InstrumentationRegistry;
+import java.io.File;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.pytorch.executorch.LlamaCallback;
+import org.pytorch.executorch.LlamaModule;
+
+@RunWith(AndroidJUnit4.class)
+public class PerfTest implements LlamaCallback {
+
+ private static final String RESOURCE_PATH = "/data/local/tmp/llama/";
+ private static final String TOKENIZER_BIN = "tokenizer.bin";
+
+ private final List results = new ArrayList<>();
+ private final List tokensPerSecond = new ArrayList<>();
+
+ @Test
+ public void testTokensPerSecond() {
+ String tokenizerPath = RESOURCE_PATH + TOKENIZER_BIN;
+ // Find out the model name
+ File directory = new File(RESOURCE_PATH);
+ Arrays.stream(directory.listFiles())
+ .filter(file -> file.getName().endsWith(".pte"))
+ .forEach(
+ model -> {
+ LlamaModule mModule = new LlamaModule(model.getPath(), tokenizerPath, 0.8f);
+ // Print the model name because there might be more than one of them
+ report("ModelName", model.getName());
+
+ int loadResult = mModule.load();
+ // Check that the model can be load successfully
+ assertEquals(0, loadResult);
+
+ // Run a testing prompt
+ mModule.generate("How do you do! I'm testing llama2 on mobile device", PerfTest.this);
+ assertFalse(tokensPerSecond.isEmpty());
+
+ final Float tps = tokensPerSecond.get(tokensPerSecond.size() - 1);
+ report("TPS", tps);
+ });
+ }
+
+ @Override
+ public void onResult(String result) {
+ results.add(result);
+ }
+
+ @Override
+ public void onStats(float tps) {
+ tokensPerSecond.add(tps);
+ }
+
+ private void report(final String metric, final Float value) {
+ Bundle bundle = new Bundle();
+ bundle.putFloat(metric, value);
+ InstrumentationRegistry.getInstrumentation().sendStatus(0, bundle);
+ }
+
+ private void report(final String key, final String value) {
+ Bundle bundle = new Bundle();
+ bundle.putString(key, value);
+ InstrumentationRegistry.getInstrumentation().sendStatus(0, bundle);
+ }
+}
diff --git a/torchchat/edge/android/torchchat/app/src/main/AndroidManifest.xml b/torchchat/edge/android/torchchat/app/src/main/AndroidManifest.xml
index 399bd1cbe..9fe80977a 100644
--- a/torchchat/edge/android/torchchat/app/src/main/AndroidManifest.xml
+++ b/torchchat/edge/android/torchchat/app/src/main/AndroidManifest.xml
@@ -1,35 +1,61 @@
-
+ xmlns:tools="http://schemas.android.com/tools"
+ package="org.pytorch.torchchat">
+
+
+
+
+
+
+
+
+ android:theme="@style/Theme.AppCompat.Light.NoActionBar"
+ tools:targetApi="34">
+
+
+
+
+
+ android:theme="@style/Theme.AppCompat.Light.NoActionBar">
+
+
+
+
+
+
+
diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/AppLog.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/AppLog.java
new file mode 100644
index 000000000..cd8462aab
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/AppLog.java
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package org.pytorch.torchchat;
+
+import java.text.SimpleDateFormat;
+import java.util.Date;
+import java.util.Locale;
+
+public class AppLog {
+ private final Long timestamp;
+ private final String message;
+
+ public AppLog(String message) {
+ this.timestamp = getCurrentTimeStamp();
+ this.message = message;
+ }
+
+ public Long getTimestamp() {
+ return timestamp;
+ }
+
+ public String getMessage() {
+ return message;
+ }
+
+ public String getFormattedLog() {
+ return "[" + getFormattedTimeStamp() + "] " + message;
+ }
+
+ private Long getCurrentTimeStamp() {
+ return System.currentTimeMillis();
+ }
+
+ private String getFormattedTimeStamp() {
+ return formatDate(timestamp);
+ }
+
+ private String formatDate(long milliseconds) {
+ SimpleDateFormat formatter = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.getDefault());
+ Date date = new Date(milliseconds);
+ return formatter.format(date);
+ }
+}
diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/DemoSharedPreferences.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/DemoSharedPreferences.java
new file mode 100644
index 000000000..6d672a5b7
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/DemoSharedPreferences.java
@@ -0,0 +1,90 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package org.pytorch.torchchat;
+
+import android.content.Context;
+import android.content.SharedPreferences;
+import com.google.gson.Gson;
+import com.google.gson.reflect.TypeToken;
+import java.lang.reflect.Type;
+import java.util.ArrayList;
+
+public class DemoSharedPreferences {
+ Context context;
+ SharedPreferences sharedPreferences;
+
+ public DemoSharedPreferences(Context context) {
+ this.context = context;
+ this.sharedPreferences = getSharedPrefs();
+ }
+
+ private SharedPreferences getSharedPrefs() {
+ return context.getSharedPreferences(
+ context.getString(R.string.demo_pref_file_key), Context.MODE_PRIVATE);
+ }
+
+ public String getSavedMessages() {
+ return sharedPreferences.getString(context.getString(R.string.saved_messages_json_key), "");
+ }
+
+ public void addMessages(MessageAdapter messageAdapter) {
+ SharedPreferences.Editor editor = sharedPreferences.edit();
+ Gson gson = new Gson();
+ String msgJSON = gson.toJson(messageAdapter.getSavedMessages());
+ editor.putString(context.getString(R.string.saved_messages_json_key), msgJSON);
+ editor.apply();
+ }
+
+ public void removeExistingMessages() {
+ SharedPreferences.Editor editor = sharedPreferences.edit();
+ editor.remove(context.getString(R.string.saved_messages_json_key));
+ editor.apply();
+ }
+
+ public void addSettings(SettingsFields settingsFields) {
+ SharedPreferences.Editor editor = sharedPreferences.edit();
+ Gson gson = new Gson();
+ String settingsJSON = gson.toJson(settingsFields);
+ editor.putString(context.getString(R.string.settings_json_key), settingsJSON);
+ editor.apply();
+ }
+
+ public String getSettings() {
+ return sharedPreferences.getString(context.getString(R.string.settings_json_key), "");
+ }
+
+ public void saveLogs() {
+ SharedPreferences.Editor editor = sharedPreferences.edit();
+ Gson gson = new Gson();
+ String msgJSON = gson.toJson(ETLogging.getInstance().getLogs());
+ editor.putString(context.getString(R.string.logs_json_key), msgJSON);
+ editor.apply();
+ }
+
+ public void removeExistingLogs() {
+ SharedPreferences.Editor editor = sharedPreferences.edit();
+ editor.remove(context.getString(R.string.logs_json_key));
+ editor.apply();
+ }
+
+ public ArrayList getSavedLogs() {
+ String logsJSONString =
+ sharedPreferences.getString(context.getString(R.string.logs_json_key), null);
+ if (logsJSONString == null || logsJSONString.isEmpty()) {
+ return new ArrayList<>();
+ }
+ Gson gson = new Gson();
+ Type type = new TypeToken>() {}.getType();
+ ArrayList appLogs = gson.fromJson(logsJSONString, type);
+ if (appLogs == null) {
+ return new ArrayList<>();
+ }
+ return appLogs;
+ }
+}
diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ETImage.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ETImage.java
new file mode 100644
index 000000000..28f0752a4
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ETImage.java
@@ -0,0 +1,126 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package org.pytorch.torchchat;
+
+import android.content.ContentResolver;
+import android.graphics.Bitmap;
+import android.graphics.BitmapFactory;
+import android.graphics.Color;
+import android.net.Uri;
+import androidx.annotation.Nullable;
+import java.io.FileNotFoundException;
+import java.io.InputStream;
+
+public class ETImage {
+ private int width;
+ private int height;
+ private final byte[] bytes;
+ private final Uri uri;
+ private final ContentResolver contentResolver;
+
+ ETImage(ContentResolver contentResolver, Uri uri) {
+ this.contentResolver = contentResolver;
+ this.uri = uri;
+ bytes = getBytesFromImageURI(uri);
+ }
+
+ public int getWidth() {
+ return width;
+ }
+
+ public int getHeight() {
+ return height;
+ }
+
+ public Uri getUri() {
+ return uri;
+ }
+
+ public byte[] getBytes() {
+ return bytes;
+ }
+
+ public int[] getInts() {
+ // We need to convert the byte array to an int array because
+ // the runner expects an int array as input.
+ int[] intArray = new int[bytes.length];
+ for (int i = 0; i < bytes.length; i++) {
+ intArray[i] = (bytes[i++] & 0xFF);
+ }
+ return intArray;
+ }
+
+ private byte[] getBytesFromImageURI(Uri uri) {
+ try {
+ int RESIZED_IMAGE_WIDTH = 336;
+ Bitmap bitmap = resizeImage(uri, RESIZED_IMAGE_WIDTH);
+
+ if (bitmap == null) {
+ ETLogging.getInstance().log("Unable to get bytes from Image URI. Bitmap is null");
+ return new byte[0];
+ }
+
+ width = bitmap.getWidth();
+ height = bitmap.getHeight();
+
+ byte[] rgbValues = new byte[width * height * 3];
+
+ for (int y = 0; y < height; y++) {
+ for (int x = 0; x < width; x++) {
+ // Get the color of the current pixel
+ int color = bitmap.getPixel(x, y);
+
+ // Extract the RGB values from the color
+ int red = Color.red(color);
+ int green = Color.green(color);
+ int blue = Color.blue(color);
+
+ // Store the RGB values in the byte array
+ rgbValues[y * width + x] = (byte) red;
+ rgbValues[(y * width + x) + height * width] = (byte) green;
+ rgbValues[(y * width + x) + 2 * height * width] = (byte) blue;
+ }
+ }
+ return rgbValues;
+ } catch (FileNotFoundException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Nullable
+ private Bitmap resizeImage(Uri uri, int maxLength) throws FileNotFoundException {
+ InputStream inputStream = contentResolver.openInputStream(uri);
+ if (inputStream == null) {
+ ETLogging.getInstance().log("Unable to resize image, input streams is null");
+ return null;
+ }
+ Bitmap bitmap = BitmapFactory.decodeStream(inputStream);
+ if (bitmap == null) {
+ ETLogging.getInstance().log("Unable to resize image, bitmap during decode stream is null");
+ return null;
+ }
+
+ float aspectRatio;
+ int finalWidth, finalHeight;
+
+ if (bitmap.getWidth() > bitmap.getHeight()) {
+ // width > height --> width = maxLength, height scale with aspect ratio
+ aspectRatio = bitmap.getWidth() / (float) bitmap.getHeight();
+ finalWidth = maxLength;
+ finalHeight = Math.round(maxLength / aspectRatio);
+ } else {
+ // height >= width --> height = maxLength, width scale with aspect ratio
+ aspectRatio = bitmap.getHeight() / (float) bitmap.getWidth();
+ finalHeight = maxLength;
+ finalWidth = Math.round(maxLength / aspectRatio);
+ }
+
+ return Bitmap.createScaledBitmap(bitmap, finalWidth, finalHeight, false);
+ }
+}
diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ETLogging.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ETLogging.java
new file mode 100644
index 000000000..c13abd02a
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ETLogging.java
@@ -0,0 +1,54 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package org.pytorch.torchchat;
+
+import android.app.Application;
+import android.util.Log;
+import java.util.ArrayList;
+
+public class ETLogging extends Application {
+ private static ETLogging singleton;
+
+ private ArrayList logs;
+ private DemoSharedPreferences mDemoSharedPreferences;
+
+ @Override
+ public void onCreate() {
+ super.onCreate();
+ singleton = this;
+ mDemoSharedPreferences = new DemoSharedPreferences(this.getApplicationContext());
+ logs = mDemoSharedPreferences.getSavedLogs();
+ if (logs == null) { // We don't have existing sharedPreference stored
+ logs = new ArrayList<>();
+ }
+ }
+
+ public static ETLogging getInstance() {
+ return singleton;
+ }
+
+ public void log(String message) {
+ AppLog appLog = new AppLog(message);
+ logs.add(appLog);
+ Log.d("ETLogging", appLog.getMessage());
+ }
+
+ public ArrayList getLogs() {
+ return logs;
+ }
+
+ public void clearLogs() {
+ logs.clear();
+ mDemoSharedPreferences.removeExistingLogs();
+ }
+
+ public void saveLogs() {
+ mDemoSharedPreferences.saveLogs();
+ }
+}
diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/LlmBenchmarkRunner.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/LlmBenchmarkRunner.java
new file mode 100644
index 000000000..d0a7cc677
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/LlmBenchmarkRunner.java
@@ -0,0 +1,223 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package org.pytorch.torchchat;
+
+import android.app.Activity;
+import android.app.ActivityManager;
+import android.content.Intent;
+import android.os.Build;
+import android.os.Bundle;
+import android.util.Log;
+import android.widget.TextView;
+import androidx.annotation.NonNull;
+import com.google.gson.Gson;
+import java.io.File;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+public class LlmBenchmarkRunner extends Activity implements ModelRunnerCallback {
+ ModelRunner mModelRunner;
+
+ String mPrompt;
+ TextView mTextView;
+ StatsDump mStatsDump;
+
+ @Override
+ protected void onCreate(Bundle savedInstanceState) {
+ super.onCreate(savedInstanceState);
+ setContentView(R.layout.activity_benchmarking);
+ mTextView = findViewById(R.id.log_view);
+
+ Intent intent = getIntent();
+
+ File modelDir = new File(intent.getStringExtra("model_dir"));
+ File model =
+ Arrays.stream(modelDir.listFiles())
+ .filter(file -> file.getName().endsWith(".pte"))
+ .findFirst()
+ .get();
+ String tokenizerPath = intent.getStringExtra("tokenizer_path");
+
+ float temperature = intent.getFloatExtra("temperature", 0.8f);
+ mPrompt = intent.getStringExtra("prompt");
+ if (mPrompt == null) {
+ mPrompt = "The ultimate answer";
+ }
+
+ mStatsDump = new StatsDump();
+ mStatsDump.modelName = model.getName().replace(".pte", "");
+ mModelRunner = new ModelRunner(model.getPath(), tokenizerPath, temperature, this);
+ mStatsDump.loadStart = System.nanoTime();
+ }
+
+ @Override
+ public void onModelLoaded(int status) {
+ mStatsDump.loadEnd = System.nanoTime();
+ mStatsDump.loadStatus = status;
+ if (status != 0) {
+ Log.e("LlmBenchmarkRunner", "Loaded failed: " + status);
+ onGenerationStopped();
+ return;
+ }
+ mStatsDump.generateStart = System.nanoTime();
+ mModelRunner.generate(mPrompt);
+ }
+
+ @Override
+ public void onTokenGenerated(String token) {
+ runOnUiThread(
+ () -> {
+ mTextView.append(token);
+ });
+ }
+
+ @Override
+ public void onStats(String stats) {
+ mStatsDump.tokens = stats;
+ }
+
+ @Override
+ public void onGenerationStopped() {
+ mStatsDump.generateEnd = System.nanoTime();
+ runOnUiThread(
+ () -> {
+ mTextView.append(mStatsDump.toString());
+ });
+
+ final BenchmarkMetric.BenchmarkModel benchmarkModel =
+ BenchmarkMetric.extractBackendAndQuantization(mStatsDump.modelName);
+ final List results = new ArrayList<>();
+ // The list of metrics we have atm includes:
+ // Load status
+ results.add(new BenchmarkMetric(benchmarkModel, "load_status", mStatsDump.loadStatus, 0));
+ // Model load time
+ results.add(
+ new BenchmarkMetric(
+ benchmarkModel,
+ "model_load_time(ms)",
+ (mStatsDump.loadEnd - mStatsDump.loadStart) * 1e-6,
+ 0.0f));
+ // LLM generate time
+ results.add(
+ new BenchmarkMetric(
+ benchmarkModel,
+ "generate_time(ms)",
+ (mStatsDump.generateEnd - mStatsDump.generateStart) * 1e-6,
+ 0.0f));
+ // Token per second
+ results.add(
+ new BenchmarkMetric(benchmarkModel, "token_per_sec", extractTPS(mStatsDump.tokens), 0.0f));
+
+ try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
+ Gson gson = new Gson();
+ writer.write(gson.toJson(results));
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+
+ private double extractTPS(final String tokens) {
+ final Matcher m = Pattern.compile("\\d+\\.?\\d*").matcher(tokens);
+ if (m.find()) {
+ return Double.parseDouble(m.group());
+ } else {
+ return 0.0f;
+ }
+ }
+}
+
+class BenchmarkMetric {
+ public static class BenchmarkModel {
+ // The model name, i.e. stories110M
+ String name;
+ String backend;
+ String quantization;
+
+ public BenchmarkModel(final String name, final String backend, final String quantization) {
+ this.name = name;
+ this.backend = backend;
+ this.quantization = quantization;
+ }
+ }
+
+ BenchmarkModel benchmarkModel;
+
+ // The metric name, i.e. TPS
+ String metric;
+
+ // The actual value and the option target value
+ double actualValue;
+ double targetValue;
+
+ public static class DeviceInfo {
+ // Let's see which information we want to include here
+ final String device = Build.BRAND;
+ // The phone model and Android release version
+ final String arch = Build.MODEL;
+ final String os = "Android " + Build.VERSION.RELEASE;
+ final long totalMem = new ActivityManager.MemoryInfo().totalMem;
+ final long availMem = new ActivityManager.MemoryInfo().availMem;
+ }
+
+ DeviceInfo deviceInfo = new DeviceInfo();
+
+ public BenchmarkMetric(
+ final BenchmarkModel benchmarkModel,
+ final String metric,
+ final double actualValue,
+ final double targetValue) {
+ this.benchmarkModel = benchmarkModel;
+ this.metric = metric;
+ this.actualValue = actualValue;
+ this.targetValue = targetValue;
+ }
+
+ // TODO (huydhn): Figure out a way to extract the backend and quantization information from
+ // the .pte model itself instead of parsing its name
+ public static BenchmarkMetric.BenchmarkModel extractBackendAndQuantization(final String model) {
+ final Matcher m =
+ Pattern.compile("(?\\w+)_(?\\w+)_(?\\w+)").matcher(model);
+ if (m.matches()) {
+ return new BenchmarkMetric.BenchmarkModel(
+ m.group("name"), m.group("backend"), m.group("quantization"));
+ } else {
+ return new BenchmarkMetric.BenchmarkModel(model, "", "");
+ }
+ }
+}
+
+class StatsDump {
+ int loadStatus;
+ long loadStart;
+ long loadEnd;
+ long generateStart;
+ long generateEnd;
+ String tokens;
+ String modelName;
+
+ @NonNull
+ @Override
+ public String toString() {
+ return "loadStart: "
+ + loadStart
+ + "\nloadEnd: "
+ + loadEnd
+ + "\ngenerateStart: "
+ + generateStart
+ + "\ngenerateEnd: "
+ + generateEnd
+ + "\n"
+ + tokens;
+ }
+}
diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/LogsActivity.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/LogsActivity.java
new file mode 100644
index 000000000..f83d691c8
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/LogsActivity.java
@@ -0,0 +1,92 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package org.pytorch.torchchat;
+
+import android.app.AlertDialog;
+import android.content.DialogInterface;
+import android.os.Build;
+import android.os.Bundle;
+import android.widget.ImageButton;
+import android.widget.ListView;
+import androidx.appcompat.app.AppCompatActivity;
+import androidx.core.content.ContextCompat;
+import androidx.core.graphics.Insets;
+import androidx.core.view.ViewCompat;
+import androidx.core.view.WindowInsetsCompat;
+
+public class LogsActivity extends AppCompatActivity {
+
+ private LogsAdapter mLogsAdapter;
+
+ @Override
+ protected void onCreate(Bundle savedInstanceState) {
+ super.onCreate(savedInstanceState);
+ setContentView(R.layout.activity_logs);
+ if (Build.VERSION.SDK_INT >= 21) {
+ getWindow().setStatusBarColor(ContextCompat.getColor(this, R.color.status_bar));
+ getWindow().setNavigationBarColor(ContextCompat.getColor(this, R.color.nav_bar));
+ }
+ ViewCompat.setOnApplyWindowInsetsListener(
+ requireViewById(R.id.main),
+ (v, insets) -> {
+ Insets systemBars = insets.getInsets(WindowInsetsCompat.Type.systemBars());
+ v.setPadding(systemBars.left, systemBars.top, systemBars.right, systemBars.bottom);
+ return insets;
+ });
+
+ setupLogs();
+ setupClearLogsButton();
+ }
+
+ @Override
+ public void onResume() {
+ super.onResume();
+ mLogsAdapter.clear();
+ mLogsAdapter.addAll(ETLogging.getInstance().getLogs());
+ mLogsAdapter.notifyDataSetChanged();
+ }
+
+ private void setupLogs() {
+ ListView mLogsListView = requireViewById(R.id.logsListView);
+ mLogsAdapter = new LogsAdapter(this, R.layout.logs_message);
+
+ mLogsListView.setAdapter(mLogsAdapter);
+ mLogsAdapter.addAll(ETLogging.getInstance().getLogs());
+ mLogsAdapter.notifyDataSetChanged();
+ }
+
+ private void setupClearLogsButton() {
+ ImageButton clearLogsButton = requireViewById(R.id.clearLogsButton);
+ clearLogsButton.setOnClickListener(
+ view -> {
+ new AlertDialog.Builder(this)
+ .setTitle("Delete Logs History")
+ .setMessage("Do you really want to delete logs history?")
+ .setIcon(android.R.drawable.ic_dialog_alert)
+ .setPositiveButton(
+ android.R.string.yes,
+ new DialogInterface.OnClickListener() {
+ public void onClick(DialogInterface dialog, int whichButton) {
+ // Clear the messageAdapter and sharedPreference
+ ETLogging.getInstance().clearLogs();
+ mLogsAdapter.clear();
+ mLogsAdapter.notifyDataSetChanged();
+ }
+ })
+ .setNegativeButton(android.R.string.no, null)
+ .show();
+ });
+ }
+
+ @Override
+ protected void onDestroy() {
+ super.onDestroy();
+ ETLogging.getInstance().saveLogs();
+ }
+}
diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/LogsAdapter.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/LogsAdapter.java
new file mode 100644
index 000000000..a793644ed
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/LogsAdapter.java
@@ -0,0 +1,45 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package org.pytorch.torchchat;
+
+import android.view.LayoutInflater;
+import android.view.View;
+import android.view.ViewGroup;
+import android.widget.ArrayAdapter;
+import android.widget.TextView;
+import androidx.annotation.NonNull;
+import java.util.Objects;
+
+public class LogsAdapter extends ArrayAdapter {
+ public LogsAdapter(android.content.Context context, int resource) {
+ super(context, resource);
+ }
+
+ static class ViewHolder {
+ private TextView logTextView;
+ }
+
+ @NonNull
+ @Override
+ public View getView(int position, View convertView, @NonNull ViewGroup parent) {
+ ViewHolder mViewHolder = null;
+
+ String logMessage = Objects.requireNonNull(getItem(position)).getFormattedLog();
+
+ if (convertView == null || convertView.getTag() == null) {
+ mViewHolder = new ViewHolder();
+ convertView = LayoutInflater.from(getContext()).inflate(R.layout.logs_message, parent, false);
+ mViewHolder.logTextView = convertView.requireViewById(R.id.logsTextView);
+ } else {
+ mViewHolder = (ViewHolder) convertView.getTag();
+ }
+ mViewHolder.logTextView.setText(logMessage);
+ return convertView;
+ }
+}
diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/MainActivity.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/MainActivity.java
index 4be7303c3..3c7dd55c7 100644
--- a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/MainActivity.java
+++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/MainActivity.java
@@ -8,245 +8,768 @@
package org.pytorch.torchchat;
-import android.app.Activity;
+import android.Manifest;
import android.app.ActivityManager;
import android.app.AlertDialog;
-import android.content.Context;
+import android.content.ContentResolver;
+import android.content.ContentValues;
+import android.content.Intent;
+import android.content.pm.PackageManager;
+import android.net.Uri;
+import android.os.Build;
import android.os.Bundle;
-import android.widget.Button;
+import android.os.Handler;
+import android.os.Looper;
+import android.os.Process;
+import android.provider.MediaStore;
+import android.system.ErrnoException;
+import android.system.Os;
+import android.util.Log;
+import android.view.View;
import android.widget.EditText;
import android.widget.ImageButton;
+import android.widget.ImageView;
+import android.widget.LinearLayout;
import android.widget.ListView;
+import android.widget.TextView;
import android.widget.Toast;
-
-import java.io.File;
+import androidx.activity.result.ActivityResultLauncher;
+import androidx.activity.result.PickVisualMediaRequest;
+import androidx.activity.result.contract.ActivityResultContracts;
+import androidx.annotation.NonNull;
+import androidx.appcompat.app.AppCompatActivity;
+import androidx.constraintlayout.widget.ConstraintLayout;
+import androidx.core.app.ActivityCompat;
+import androidx.core.content.ContextCompat;
+import com.google.gson.Gson;
+import com.google.gson.reflect.TypeToken;
+import java.lang.reflect.Type;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.Executor;
+import java.util.concurrent.Executors;
import org.pytorch.executorch.LlamaCallback;
import org.pytorch.executorch.LlamaModule;
-public class MainActivity extends Activity implements Runnable, LlamaCallback {
- private EditText mEditTextMessage;
- private Button mSendButton;
- private ImageButton mModelButton;
- private ListView mMessagesView;
- private MessageAdapter mMessageAdapter;
- private LlamaModule mModule = null;
- private Message mResultMessage = null;
-
- private String mModelFilePath = "";
- private String mTokenizerFilePath = "";
-
- @Override
- public void onResult(String result) {
- if (result.startsWith("<") && result.endsWith(">")) {
- return;
- }
+public class MainActivity extends AppCompatActivity implements Runnable, LlamaCallback {
+ private EditText mEditTextMessage;
+ private ImageButton mSendButton;
+ private ImageButton mGalleryButton;
+ private ImageButton mCameraButton;
+ private ListView mMessagesView;
+ private MessageAdapter mMessageAdapter;
+ private LlamaModule mModule = null;
+ private Message mResultMessage = null;
+ private ImageButton mSettingsButton;
+ private TextView mMemoryView;
+ private ActivityResultLauncher mPickGallery;
+ private ActivityResultLauncher mCameraRoll;
+ private List mSelectedImageUri;
+ private ConstraintLayout mMediaPreviewConstraintLayout;
+ private LinearLayout mAddMediaLayout;
+ private static final int MAX_NUM_OF_IMAGES = 5;
+ private static final int REQUEST_IMAGE_CAPTURE = 1;
+ private Uri cameraImageUri;
+ private DemoSharedPreferences mDemoSharedPreferences;
+ private SettingsFields mCurrentSettingsFields;
+ private Handler mMemoryUpdateHandler;
+ private Runnable memoryUpdater;
+ private int promptID = 0;
+ private long startPos = 0;
+ private static final int CONVERSATION_HISTORY_MESSAGE_LOOKBACK = 2;
+ private Executor executor;
+
+ @Override
+ public void onResult(String result) {
+ if (result.equals(PromptFormat.getStopToken(mCurrentSettingsFields.getModelType()))) {
+ return;
+ }
+ if (result.equals("\n\n") || result.equals("\n")) {
+ if (!mResultMessage.getText().isEmpty()) {
mResultMessage.appendText(result);
run();
+ }
+ } else {
+ mResultMessage.appendText(result);
+ run();
}
+ }
- @Override
- public void onStats(float tps) {
- runOnUiThread(
- () -> {
- if (mResultMessage != null) {
- mResultMessage.setTokensPerSecond(tps);
- mMessageAdapter.notifyDataSetChanged();
- }
- });
- }
-
- private static String[] listLocalFile(String path, String suffix) {
- File directory = new File(path);
- if (directory.exists() && directory.isDirectory()) {
- File[] files = directory.listFiles((dir, name) -> name.toLowerCase().endsWith(suffix));
- String[] result = new String[files.length];
- for (int i = 0; i < files.length; i++) {
- if (files[i].isFile() && files[i].getName().endsWith(suffix)) {
- result[i] = files[i].getAbsolutePath();
- }
- }
- return result;
- }
- return null;
- }
-
- private void setLocalModel(String modelPath, String tokenizerPath) {
- Message modelLoadingMessage = new Message("Loading model...", false);
- runOnUiThread(
- () -> {
- mSendButton.setEnabled(false);
- mMessageAdapter.add(modelLoadingMessage);
- mMessageAdapter.notifyDataSetChanged();
- });
- long runStartTime = System.currentTimeMillis();
- mModule = new LlamaModule(modelPath, tokenizerPath, 0.8f);
- int loadResult = mModule.load();
- if (loadResult != 0) {
- AlertDialog.Builder builder = new AlertDialog.Builder(this);
- builder.setTitle("Load failed: " + loadResult);
- runOnUiThread(
- () -> {
- AlertDialog alert = builder.create();
- alert.show();
- });
- }
+ @Override
+ public void onStats(float tps) {
+ runOnUiThread(
+ () -> {
+ if (mResultMessage != null) {
+ mResultMessage.setTokensPerSecond(tps);
+ mMessageAdapter.notifyDataSetChanged();
+ }
+ });
+ }
+
+ private void setLocalModel(String modelPath, String tokenizerPath, float temperature) {
+ Message modelLoadingMessage = new Message("Loading model...", false, MessageType.SYSTEM, 0);
+ ETLogging.getInstance().log("Loading model " + modelPath + " with tokenizer " + tokenizerPath);
+ runOnUiThread(
+ () -> {
+ mSendButton.setEnabled(false);
+ mMessageAdapter.add(modelLoadingMessage);
+ mMessageAdapter.notifyDataSetChanged();
+ });
+ if (mModule != null) {
+ ETLogging.getInstance().log("Start deallocating existing module instance");
+ mModule.resetNative();
+ mModule = null;
+ ETLogging.getInstance().log("Completed deallocating existing module instance");
+ }
+ long runStartTime = System.currentTimeMillis();
+ mModule =
+ new LlamaModule(
+ ModelUtils.getModelCategory(mCurrentSettingsFields.getModelType()),
+ modelPath,
+ tokenizerPath,
+ temperature);
+ int loadResult = mModule.load();
+ long loadDuration = System.currentTimeMillis() - runStartTime;
+ String modelLoadError = "";
+ String modelInfo = "";
+ if (loadResult != 0) {
+ // TODO: Map the error code to a reason to let the user know why model loading failed
+ modelInfo = "*Model could not load (Error Code: " + loadResult + ")*" + "\n";
+ loadDuration = 0;
+ AlertDialog.Builder builder = new AlertDialog.Builder(this);
+ builder.setTitle("Load failed: " + loadResult);
+ runOnUiThread(
+ () -> {
+ AlertDialog alert = builder.create();
+ alert.show();
+ });
+ } else {
+ String[] segments = modelPath.split("/");
+ String pteName = segments[segments.length - 1];
+ segments = tokenizerPath.split("/");
+ String tokenizerName = segments[segments.length - 1];
+ modelInfo =
+ "Successfully loaded model. "
+ + pteName
+ + " and tokenizer "
+ + tokenizerName
+ + " in "
+ + (float) loadDuration / 1000
+ + " sec."
+ + " You can send text or image for inference";
+
+ if (mCurrentSettingsFields.getModelType() == ModelType.LLAVA_1_5) {
+ ETLogging.getInstance().log("Llava start prefill prompt");
+ startPos = mModule.prefillPrompt(PromptFormat.getLlavaPresetPrompt(), 0, 1, 0);
+ ETLogging.getInstance().log("Llava completes prefill prompt");
+ }
+ }
+
+ Message modelLoadedMessage = new Message(modelInfo, false, MessageType.SYSTEM, 0);
+
+ String modelLoggingInfo =
+ modelLoadError
+ + "Model path: "
+ + modelPath
+ + "\nTokenizer path: "
+ + tokenizerPath
+ + "\nTemperature: "
+ + temperature
+ + "\nModel loaded time: "
+ + loadDuration
+ + " ms";
+ ETLogging.getInstance().log("Load complete. " + modelLoggingInfo);
+
+ runOnUiThread(
+ () -> {
+ mSendButton.setEnabled(true);
+ mMessageAdapter.remove(modelLoadingMessage);
+ mMessageAdapter.add(modelLoadedMessage);
+ mMessageAdapter.notifyDataSetChanged();
+ });
+ }
+
+ private void loadLocalModelAndParameters(
+ String modelFilePath, String tokenizerFilePath, float temperature) {
+ Runnable runnable =
+ new Runnable() {
+ @Override
+ public void run() {
+ setLocalModel(modelFilePath, tokenizerFilePath, temperature);
+ }
+ };
+ new Thread(runnable).start();
+ }
+
+ private void populateExistingMessages(String existingMsgJSON) {
+ Gson gson = new Gson();
+ Type type = new TypeToken>() {}.getType();
+ ArrayList savedMessages = gson.fromJson(existingMsgJSON, type);
+ for (Message msg : savedMessages) {
+ mMessageAdapter.add(msg);
+ }
+ mMessageAdapter.notifyDataSetChanged();
+ }
+
+ private int setPromptID() {
+
+ return mMessageAdapter.getMaxPromptID() + 1;
+ }
+
+ @Override
+ protected void onCreate(Bundle savedInstanceState) {
+ super.onCreate(savedInstanceState);
+ setContentView(R.layout.activity_main);
+
+ if (Build.VERSION.SDK_INT >= 21) {
+ getWindow().setStatusBarColor(ContextCompat.getColor(this, R.color.status_bar));
+ getWindow().setNavigationBarColor(ContextCompat.getColor(this, R.color.nav_bar));
+ }
+
+ try {
+ Os.setenv("ADSP_LIBRARY_PATH", getApplicationInfo().nativeLibraryDir, true);
+ } catch (ErrnoException e) {
+ finish();
+ }
+
+ mEditTextMessage = requireViewById(R.id.editTextMessage);
+ mSendButton = requireViewById(R.id.sendButton);
+ mSendButton.setEnabled(false);
+ mMessagesView = requireViewById(R.id.messages_view);
+ mMessageAdapter = new MessageAdapter(this, R.layout.sent_message, new ArrayList());
+ mMessagesView.setAdapter(mMessageAdapter);
+ mDemoSharedPreferences = new DemoSharedPreferences(this.getApplicationContext());
+ String existingMsgJSON = mDemoSharedPreferences.getSavedMessages();
+ if (!existingMsgJSON.isEmpty()) {
+ populateExistingMessages(existingMsgJSON);
+ promptID = setPromptID();
+ }
+ mSettingsButton = requireViewById(R.id.settings);
+ mSettingsButton.setOnClickListener(
+ view -> {
+ Intent myIntent = new Intent(MainActivity.this, SettingsActivity.class);
+ MainActivity.this.startActivity(myIntent);
+ });
+
+ mCurrentSettingsFields = new SettingsFields();
+ mMemoryUpdateHandler = new Handler(Looper.getMainLooper());
+ onModelRunStopped();
+ setupMediaButton();
+ setupGalleryPicker();
+ setupCameraRoll();
+ startMemoryUpdate();
+ setupShowLogsButton();
+ executor = Executors.newSingleThreadExecutor();
+ }
- long loadDuration = System.currentTimeMillis() - runStartTime;
- String modelInfo =
- "Model path: "
- + modelPath
- + "\nTokenizer path: "
- + tokenizerPath
- + "\nModel loaded time: "
- + loadDuration
- + " ms";
- Message modelLoadedMessage = new Message(modelInfo, false);
- runOnUiThread(
- () -> {
- mSendButton.setEnabled(true);
- mMessageAdapter.remove(modelLoadingMessage);
- mMessageAdapter.add(modelLoadedMessage);
- mMessageAdapter.notifyDataSetChanged();
- });
- }
-
- private String memoryInfo() {
- final ActivityManager am = (ActivityManager) getSystemService(Context.ACTIVITY_SERVICE);
- ActivityManager.MemoryInfo memInfo = new ActivityManager.MemoryInfo();
- am.getMemoryInfo(memInfo);
- return "Total RAM: "
- + Math.floorDiv(memInfo.totalMem, 1000000)
- + " MB. Available RAM: "
- + Math.floorDiv(memInfo.availMem, 1000000)
- + " MB.";
- }
-
- private void modelDialog() {
- String[] pteFiles = listLocalFile("/data/local/tmp/llama/", ".pte");
- String[] binFiles = listLocalFile("/data/local/tmp/llama/", ".bin");
- String[] modelFiles = listLocalFile("/data/local/tmp/llama/", ".model");
- if (pteFiles == null || binFiles == null || modelFiles == null) {
- Toast.makeText(this,
- "Please create directory /data/local/tmp/llama/ first",
- Toast.LENGTH_LONG).show();
- return;
+ @Override
+ protected void onPause() {
+ super.onPause();
+ mDemoSharedPreferences.addMessages(mMessageAdapter);
+ }
+
+ @Override
+ protected void onResume() {
+ super.onResume();
+ // Check for if settings parameters have changed
+ Gson gson = new Gson();
+ String settingsFieldsJSON = mDemoSharedPreferences.getSettings();
+ if (!settingsFieldsJSON.isEmpty()) {
+ SettingsFields updatedSettingsFields =
+ gson.fromJson(settingsFieldsJSON, SettingsFields.class);
+ if (updatedSettingsFields == null) {
+ // Added this check, because gson.fromJson can return null
+ askUserToSelectModel();
+ return;
+ }
+ boolean isUpdated = !mCurrentSettingsFields.equals(updatedSettingsFields);
+ boolean isLoadModel = updatedSettingsFields.getIsLoadModel();
+ if (isUpdated) {
+ if (isLoadModel) {
+ // If users change the model file, but not pressing loadModelButton, we won't load the new
+ // model
+ checkForUpdateAndReloadModel(updatedSettingsFields);
+ } else {
+ askUserToSelectModel();
}
- String[] tokenizerFiles = new String[binFiles.length + modelFiles.length];
- System.arraycopy(binFiles, 0, tokenizerFiles, 0, binFiles.length);
- System.arraycopy(modelFiles, 0, tokenizerFiles, binFiles.length, modelFiles.length);
- AlertDialog.Builder modelPathBuilder = new AlertDialog.Builder(this);
- modelPathBuilder.setTitle("Select model path");
- AlertDialog.Builder tokenizerPathBuilder = new AlertDialog.Builder(this);
- tokenizerPathBuilder.setTitle("Select tokenizer path");
- modelPathBuilder.setSingleChoiceItems(
- pteFiles,
- -1,
- (dialog, item) -> {
- mModelFilePath = pteFiles[item];
- mEditTextMessage.setText("");
- dialog.dismiss();
- tokenizerPathBuilder.create().show();
- });
-
- tokenizerPathBuilder.setSingleChoiceItems(
- tokenizerFiles,
- -1,
- (dialog, item) -> {
- mTokenizerFilePath = tokenizerFiles[item];
- Runnable runnable =
- new Runnable() {
- @Override
- public void run() {
- setLocalModel(mModelFilePath, mTokenizerFilePath);
- }
- };
- new Thread(runnable).start();
- dialog.dismiss();
- });
-
- modelPathBuilder.create().show();
- }
-
- @Override
- protected void onCreate(Bundle savedInstanceState) {
- super.onCreate(savedInstanceState);
- setContentView(R.layout.activity_main);
-
- mEditTextMessage = findViewById(R.id.editTextMessage);
- mSendButton = findViewById(R.id.sendButton);
- mSendButton.setEnabled(false);
- mModelButton = findViewById(R.id.modelButton);
- mMessagesView = findViewById(R.id.messages_view);
- mMessageAdapter = new MessageAdapter(this, R.layout.sent_message);
- mMessagesView.setAdapter(mMessageAdapter);
- mModelButton.setOnClickListener(
- view -> {
- mModule.stop();
- mMessageAdapter.clear();
- mMessageAdapter.notifyDataSetChanged();
- modelDialog();
- });
-
- onModelRunStopped();
- modelDialog();
- }
-
- private void onModelRunStarted() {
- mSendButton.setText("Stop");
- mSendButton.setOnClickListener(
- view -> {
- mModule.stop();
- });
- }
-
- private void onModelRunStopped() {
- setTitle(memoryInfo());
- mSendButton.setText("Generate");
- mSendButton.setOnClickListener(
- view -> {
- String prompt = mEditTextMessage.getText().toString();
- mMessageAdapter.add(new Message(prompt, true));
- mMessageAdapter.notifyDataSetChanged();
- mEditTextMessage.setText("");
- mResultMessage = new Message("", false);
- mMessageAdapter.add(mResultMessage);
- Runnable runnable =
- new Runnable() {
- @Override
- public void run() {
- runOnUiThread(
- new Runnable() {
- @Override
- public void run() {
- onModelRunStarted();
- }
- });
-
- mModule.generate(prompt, MainActivity.this);
-
- runOnUiThread(
- new Runnable() {
- @Override
- public void run() {
- onModelRunStopped();
- }
- });
- }
- };
- new Thread(runnable).start();
- });
+ checkForClearChatHistory(updatedSettingsFields);
+ // Update current to point to the latest
+ mCurrentSettingsFields = new SettingsFields(updatedSettingsFields);
+ }
+ } else {
+ askUserToSelectModel();
+ }
+ }
+
+ private void checkForClearChatHistory(SettingsFields updatedSettingsFields) {
+ if (updatedSettingsFields.getIsClearChatHistory()) {
+ mMessageAdapter.clear();
+ mMessageAdapter.notifyDataSetChanged();
+ mDemoSharedPreferences.removeExistingMessages();
+ // changing to false since chat history has been cleared.
+ updatedSettingsFields.saveIsClearChatHistory(false);
+ mDemoSharedPreferences.addSettings(updatedSettingsFields);
+ }
+ }
+
+ private void checkForUpdateAndReloadModel(SettingsFields updatedSettingsFields) {
+ // TODO need to add 'load model' in settings and queue loading based on that
+ String modelPath = updatedSettingsFields.getModelFilePath();
+ String tokenizerPath = updatedSettingsFields.getTokenizerFilePath();
+ double temperature = updatedSettingsFields.getTemperature();
+ if (!modelPath.isEmpty() && !tokenizerPath.isEmpty()) {
+ if (updatedSettingsFields.getIsLoadModel()
+ || !modelPath.equals(mCurrentSettingsFields.getModelFilePath())
+ || !tokenizerPath.equals(mCurrentSettingsFields.getTokenizerFilePath())
+ || temperature != mCurrentSettingsFields.getTemperature()) {
+ loadLocalModelAndParameters(
+ updatedSettingsFields.getModelFilePath(),
+ updatedSettingsFields.getTokenizerFilePath(),
+ (float) updatedSettingsFields.getTemperature());
+ updatedSettingsFields.saveLoadModelAction(false);
+ mDemoSharedPreferences.addSettings(updatedSettingsFields);
+ }
+ } else {
+ askUserToSelectModel();
+ }
+ }
+
+ private void askUserToSelectModel() {
+ String askLoadModel =
+ "To get started, select your desired model and tokenizer " + "from the top right corner";
+ Message askLoadModelMessage = new Message(askLoadModel, false, MessageType.SYSTEM, 0);
+ ETLogging.getInstance().log(askLoadModel);
+ runOnUiThread(
+ () -> {
+ mMessageAdapter.add(askLoadModelMessage);
+ mMessageAdapter.notifyDataSetChanged();
+ });
+ }
+
+ private void setupShowLogsButton() {
+ ImageButton showLogsButton = requireViewById(R.id.showLogsButton);
+ showLogsButton.setOnClickListener(
+ view -> {
+ Intent myIntent = new Intent(MainActivity.this, LogsActivity.class);
+ MainActivity.this.startActivity(myIntent);
+ });
+ }
+
+ private void setupMediaButton() {
+ mAddMediaLayout = requireViewById(R.id.addMediaLayout);
+ mAddMediaLayout.setVisibility(View.GONE); // We hide this initially
+
+ ImageButton addMediaButton = requireViewById(R.id.addMediaButton);
+ addMediaButton.setOnClickListener(
+ view -> {
+ mAddMediaLayout.setVisibility(View.VISIBLE);
+ });
+
+ mGalleryButton = requireViewById(R.id.galleryButton);
+ mGalleryButton.setOnClickListener(
+ view -> {
+ // Launch the photo picker and let the user choose only images.
+ mPickGallery.launch(
+ new PickVisualMediaRequest.Builder()
+ .setMediaType(ActivityResultContracts.PickVisualMedia.ImageOnly.INSTANCE)
+ .build());
+ });
+ mCameraButton = requireViewById(R.id.cameraButton);
+ mCameraButton.setOnClickListener(
+ view -> {
+ Log.d("CameraRoll", "Check permission");
+ if (ContextCompat.checkSelfPermission(MainActivity.this, Manifest.permission.CAMERA)
+ != PackageManager.PERMISSION_GRANTED) {
+ ActivityCompat.requestPermissions(
+ MainActivity.this,
+ new String[] {Manifest.permission.CAMERA},
+ REQUEST_IMAGE_CAPTURE);
+ } else {
+ launchCamera();
+ }
+ });
+ }
+
+ private void setupCameraRoll() {
+ // Registers a camera roll activity launcher.
+ mCameraRoll =
+ registerForActivityResult(
+ new ActivityResultContracts.TakePicture(),
+ result -> {
+ if (result && cameraImageUri != null) {
+ Log.d("CameraRoll", "Photo saved to uri: " + cameraImageUri);
+ mAddMediaLayout.setVisibility(View.GONE);
+ List uris = new ArrayList<>();
+ uris.add(cameraImageUri);
+ showMediaPreview(uris);
+ } else {
+ // Delete the temp image file based on the url since the photo is not successfully
+ // taken
+ if (cameraImageUri != null) {
+ ContentResolver contentResolver = MainActivity.this.getContentResolver();
+ contentResolver.delete(cameraImageUri, null, null);
+ Log.d("CameraRoll", "No photo taken. Delete temp uri");
+ }
+ }
+ });
+ mMediaPreviewConstraintLayout = requireViewById(R.id.mediaPreviewConstraintLayout);
+ ImageButton mediaPreviewCloseButton = requireViewById(R.id.mediaPreviewCloseButton);
+ mediaPreviewCloseButton.setOnClickListener(
+ view -> {
+ mMediaPreviewConstraintLayout.setVisibility(View.GONE);
+ mSelectedImageUri = null;
+ });
+
+ ImageButton addMoreImageButton = requireViewById(R.id.addMoreImageButton);
+ addMoreImageButton.setOnClickListener(
+ view -> {
+ Log.d("addMore", "clicked");
+ mMediaPreviewConstraintLayout.setVisibility(View.GONE);
+ // Direct user to select type of input
+ mCameraButton.callOnClick();
+ });
+ }
+
+ private String updateMemoryUsage() {
+ ActivityManager.MemoryInfo memoryInfo = new ActivityManager.MemoryInfo();
+ ActivityManager activityManager = (ActivityManager) getSystemService(ACTIVITY_SERVICE);
+ if (activityManager == null) {
+ return "---";
+ }
+ activityManager.getMemoryInfo(memoryInfo);
+ long totalMem = memoryInfo.totalMem / (1024 * 1024);
+ long availableMem = memoryInfo.availMem / (1024 * 1024);
+ long usedMem = totalMem - availableMem;
+ return usedMem + "MB";
+ }
+
+ private void startMemoryUpdate() {
+ mMemoryView = requireViewById(R.id.ram_usage_live);
+ memoryUpdater =
+ new Runnable() {
+ @Override
+ public void run() {
+ mMemoryView.setText(updateMemoryUsage());
+ mMemoryUpdateHandler.postDelayed(this, 1000);
+ }
+ };
+ mMemoryUpdateHandler.post(memoryUpdater);
+ }
+
+ @Override
+ public void onRequestPermissionsResult(
+ int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) {
+ super.onRequestPermissionsResult(requestCode, permissions, grantResults);
+ if (requestCode == REQUEST_IMAGE_CAPTURE && grantResults.length != 0) {
+ if (grantResults[0] == PackageManager.PERMISSION_GRANTED) {
+ launchCamera();
+ } else if (grantResults[0] == PackageManager.PERMISSION_DENIED) {
+ Log.d("CameraRoll", "Permission denied");
+ }
+ }
+ }
+
+ private void launchCamera() {
+ ContentValues values = new ContentValues();
+ values.put(MediaStore.Images.Media.TITLE, "New Picture");
+ values.put(MediaStore.Images.Media.DESCRIPTION, "From Camera");
+ values.put(MediaStore.Images.Media.RELATIVE_PATH, "DCIM/Camera/");
+ cameraImageUri =
+ MainActivity.this
+ .getContentResolver()
+ .insert(MediaStore.Images.Media.EXTERNAL_CONTENT_URI, values);
+ mCameraRoll.launch(cameraImageUri);
+ }
+
+ private void setupGalleryPicker() {
+ // Registers a photo picker activity launcher in single-select mode.
+ mPickGallery =
+ registerForActivityResult(
+ new ActivityResultContracts.PickMultipleVisualMedia(MAX_NUM_OF_IMAGES),
+ uris -> {
+ if (!uris.isEmpty()) {
+ Log.d("PhotoPicker", "Selected URIs: " + uris);
+ mAddMediaLayout.setVisibility(View.GONE);
+ for (Uri uri : uris) {
+ MainActivity.this
+ .getContentResolver()
+ .takePersistableUriPermission(uri, Intent.FLAG_GRANT_READ_URI_PERMISSION);
+ }
+ showMediaPreview(uris);
+ } else {
+ Log.d("PhotoPicker", "No media selected");
+ }
+ });
+
+ mMediaPreviewConstraintLayout = requireViewById(R.id.mediaPreviewConstraintLayout);
+ ImageButton mediaPreviewCloseButton = requireViewById(R.id.mediaPreviewCloseButton);
+ mediaPreviewCloseButton.setOnClickListener(
+ view -> {
+ mMediaPreviewConstraintLayout.setVisibility(View.GONE);
+ mSelectedImageUri = null;
+ });
+
+ ImageButton addMoreImageButton = requireViewById(R.id.addMoreImageButton);
+ addMoreImageButton.setOnClickListener(
+ view -> {
+ Log.d("addMore", "clicked");
+ mMediaPreviewConstraintLayout.setVisibility(View.GONE);
+ mGalleryButton.callOnClick();
+ });
+ }
+
+ private List getProcessedImagesForModel(List uris) {
+ List imageList = new ArrayList<>();
+ if (uris != null) {
+ uris.forEach(
+ (uri) -> {
+ imageList.add(new ETImage(this.getContentResolver(), uri));
+ });
+ }
+ return imageList;
+ }
+
+ private void showMediaPreview(List uris) {
+ if (mSelectedImageUri == null) {
+ mSelectedImageUri = uris;
+ } else {
+ mSelectedImageUri.addAll(uris);
+ }
+
+ if (mSelectedImageUri.size() > MAX_NUM_OF_IMAGES) {
+ mSelectedImageUri = mSelectedImageUri.subList(0, MAX_NUM_OF_IMAGES);
+ Toast.makeText(
+ this, "Only max " + MAX_NUM_OF_IMAGES + " images are allowed", Toast.LENGTH_SHORT)
+ .show();
+ }
+ Log.d("mSelectedImageUri", mSelectedImageUri.size() + " " + mSelectedImageUri);
+
+ mMediaPreviewConstraintLayout.setVisibility(View.VISIBLE);
+
+ List imageViews = new ArrayList();
+
+ // Pre-populate all the image views that are available from the layout (currently max 5)
+ imageViews.add(requireViewById(R.id.mediaPreviewImageView1));
+ imageViews.add(requireViewById(R.id.mediaPreviewImageView2));
+ imageViews.add(requireViewById(R.id.mediaPreviewImageView3));
+ imageViews.add(requireViewById(R.id.mediaPreviewImageView4));
+ imageViews.add(requireViewById(R.id.mediaPreviewImageView5));
+
+ // Hide all the image views (reset state)
+ for (int i = 0; i < imageViews.size(); i++) {
+ imageViews.get(i).setVisibility(View.GONE);
+ }
+
+ // Only show/render those that have proper Image URIs
+ for (int i = 0; i < mSelectedImageUri.size(); i++) {
+ imageViews.get(i).setVisibility(View.VISIBLE);
+ imageViews.get(i).setImageURI(mSelectedImageUri.get(i));
+ }
+
+ // For LLava, we want to call prefill_image as soon as an image is selected
+ // Llava only support 1 image for now
+ if (mCurrentSettingsFields.getModelType() == ModelType.LLAVA_1_5) {
+ List processedImageList = getProcessedImagesForModel(mSelectedImageUri);
+ if (!processedImageList.isEmpty()) {
+ mMessageAdapter.add(
+ new Message("Llava - Starting image Prefill.", false, MessageType.SYSTEM, 0));
mMessageAdapter.notifyDataSetChanged();
+ Runnable runnable =
+ () -> {
+ Process.setThreadPriority(Process.THREAD_PRIORITY_MORE_FAVORABLE);
+ ETLogging.getInstance().log("Starting runnable prefill image");
+ ETImage img = processedImageList.get(0);
+ ETLogging.getInstance().log("Llava start prefill image");
+ startPos =
+ mModule.prefillImages(
+ img.getInts(),
+ img.getWidth(),
+ img.getHeight(),
+ ModelUtils.VISION_MODEL_IMAGE_CHANNELS,
+ startPos);
+ };
+ executor.execute(runnable);
+ }
+ }
+ }
+
+ private void addSelectedImagesToChatThread(List selectedImageUri) {
+ if (selectedImageUri == null) {
+ return;
+ }
+ mMediaPreviewConstraintLayout.setVisibility(View.GONE);
+ for (int i = 0; i < selectedImageUri.size(); i++) {
+ Uri imageURI = selectedImageUri.get(i);
+ Log.d("image uri ", "test " + imageURI.getPath());
+ mMessageAdapter.add(new Message(imageURI.toString(), true, MessageType.IMAGE, 0));
+ }
+ mMessageAdapter.notifyDataSetChanged();
+ }
+
+ private String getConversationHistory() {
+ String conversationHistory = "";
+
+ ArrayList conversations =
+ mMessageAdapter.getRecentSavedTextMessages(CONVERSATION_HISTORY_MESSAGE_LOOKBACK);
+ if (conversations.isEmpty()) {
+ return conversationHistory;
+ }
+
+ int prevPromptID = conversations.get(0).getPromptID();
+ String conversationFormat =
+ PromptFormat.getConversationFormat(mCurrentSettingsFields.getModelType());
+ String format = conversationFormat;
+ for (int i = 0; i < conversations.size(); i++) {
+ Message conversation = conversations.get(i);
+ int currentPromptID = conversation.getPromptID();
+ if (currentPromptID != prevPromptID) {
+ conversationHistory = conversationHistory + format;
+ format = conversationFormat;
+ prevPromptID = currentPromptID;
+ }
+ if (conversation.getIsSent()) {
+ format = format.replace(PromptFormat.USER_PLACEHOLDER, conversation.getText());
+ } else {
+ format = format.replace(PromptFormat.ASSISTANT_PLACEHOLDER, conversation.getText());
+ }
}
+ conversationHistory = conversationHistory + format;
- @Override
- public void run() {
- runOnUiThread(
- new Runnable() {
- @Override
- public void run() {
- mMessageAdapter.notifyDataSetChanged();
- setTitle(memoryInfo());
- }
- });
+ return conversationHistory;
+ }
+
+ private String getTotalFormattedPrompt(String conversationHistory, String rawPrompt) {
+ if (conversationHistory.isEmpty()) {
+ return mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt);
+ }
+
+ return mCurrentSettingsFields.getFormattedSystemPrompt()
+ + conversationHistory
+ + mCurrentSettingsFields.getFormattedUserPrompt(rawPrompt);
+ }
+
+ private void onModelRunStarted() {
+ mSendButton.setClickable(false);
+ mSendButton.setImageResource(R.drawable.baseline_stop_24);
+ mSendButton.setOnClickListener(
+ view -> {
+ mModule.stop();
+ });
+ }
+
+ private void onModelRunStopped() {
+ mSendButton.setClickable(true);
+ mSendButton.setImageResource(R.drawable.baseline_send_24);
+ mSendButton.setOnClickListener(
+ view -> {
+ addSelectedImagesToChatThread(mSelectedImageUri);
+ String finalPrompt;
+ String rawPrompt = mEditTextMessage.getText().toString();
+ if (ModelUtils.getModelCategory(mCurrentSettingsFields.getModelType())
+ == ModelUtils.VISION_MODEL) {
+ finalPrompt = mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt);
+ } else {
+ finalPrompt = getTotalFormattedPrompt(getConversationHistory(), rawPrompt);
+ }
+ // We store raw prompt into message adapter, because we don't want to show the extra
+ // tokens from system prompt
+ mMessageAdapter.add(new Message(rawPrompt, true, MessageType.TEXT, promptID));
+ mMessageAdapter.notifyDataSetChanged();
+ mEditTextMessage.setText("");
+ mResultMessage = new Message("", false, MessageType.TEXT, promptID);
+ mMessageAdapter.add(mResultMessage);
+ // Scroll to bottom of the list
+ mMessagesView.smoothScrollToPosition(mMessageAdapter.getCount() - 1);
+ // After images are added to prompt and chat thread, we clear the imageURI list
+ // Note: This has to be done after imageURIs are no longer needed by LlamaModule
+ mSelectedImageUri = null;
+ promptID++;
+ Runnable runnable =
+ new Runnable() {
+ @Override
+ public void run() {
+ Process.setThreadPriority(Process.THREAD_PRIORITY_MORE_FAVORABLE);
+ ETLogging.getInstance().log("starting runnable generate()");
+ runOnUiThread(
+ new Runnable() {
+ @Override
+ public void run() {
+ onModelRunStarted();
+ }
+ });
+ long generateStartTime = System.currentTimeMillis();
+ if (ModelUtils.getModelCategory(mCurrentSettingsFields.getModelType())
+ == ModelUtils.VISION_MODEL) {
+ mModule.generateFromPos(
+ finalPrompt,
+ ModelUtils.VISION_MODEL_SEQ_LEN,
+ startPos,
+ MainActivity.this,
+ false);
+ } else if (mCurrentSettingsFields.getModelType() == ModelType.LLAMA_GUARD_3) {
+ String llamaGuardPromptForClassification =
+ PromptFormat.getFormattedLlamaGuardPrompt(rawPrompt);
+ ETLogging.getInstance()
+ .log("Running inference.. prompt=" + llamaGuardPromptForClassification);
+ mModule.generate(
+ llamaGuardPromptForClassification,
+ llamaGuardPromptForClassification.length() + 64,
+ MainActivity.this,
+ false);
+ } else {
+ ETLogging.getInstance().log("Running inference.. prompt=" + finalPrompt);
+ mModule.generate(
+ finalPrompt,
+ (int) (finalPrompt.length() * 0.75) + 64,
+ MainActivity.this,
+ false);
+ }
+
+ long generateDuration = System.currentTimeMillis() - generateStartTime;
+ mResultMessage.setTotalGenerationTime(generateDuration);
+ runOnUiThread(
+ new Runnable() {
+ @Override
+ public void run() {
+ onModelRunStopped();
+ }
+ });
+ ETLogging.getInstance().log("Inference completed");
+ }
+ };
+ executor.execute(runnable);
+ });
+ mMessageAdapter.notifyDataSetChanged();
+ }
+
+ @Override
+ public void run() {
+ runOnUiThread(
+ new Runnable() {
+ @Override
+ public void run() {
+ mMessageAdapter.notifyDataSetChanged();
+ }
+ });
+ }
+
+ @Override
+ public void onBackPressed() {
+ super.onBackPressed();
+ if (mAddMediaLayout != null && mAddMediaLayout.getVisibility() == View.VISIBLE) {
+ mAddMediaLayout.setVisibility(View.GONE);
+ } else {
+ // Default behavior of back button
+ finish();
}
+ }
+
+ @Override
+ protected void onDestroy() {
+ super.onDestroy();
+ mMemoryUpdateHandler.removeCallbacks(memoryUpdater);
+ // This is to cover the case where the app is shutdown when user is on MainActivity but
+ // never clicked on the logsActivity
+ ETLogging.getInstance().saveLogs();
+ }
}
diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/Message.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/Message.java
index f42c6afb3..031d052fd 100644
--- a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/Message.java
+++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/Message.java
@@ -8,33 +8,87 @@
package org.pytorch.torchchat;
+import java.text.SimpleDateFormat;
+import java.util.Date;
+import java.util.Locale;
+
public class Message {
- private String text;
- private boolean isSent;
- private float tokensPerSecond;
+ private String text;
+ private final boolean isSent;
+ private float tokensPerSecond;
+ private long totalGenerationTime;
+ private final long timestamp;
+ private final MessageType messageType;
+ private String imagePath;
+ private final int promptID;
- public Message(String text, boolean isSent) {
- this.text = text;
- this.isSent = isSent;
- }
+ private static final String TIMESTAMP_FORMAT = "hh:mm a"; // example: 2:23 PM
- public String getText() {
- return text;
- }
+ public Message(String text, boolean isSent, MessageType messageType, int promptID) {
+ this.isSent = isSent;
+ this.messageType = messageType;
+ this.promptID = promptID;
- public void appendText(String text) {
- this.text += text;
+ if (messageType == MessageType.IMAGE) {
+ this.imagePath = text;
+ } else {
+ this.text = text;
}
- public boolean getIsSent() {
- return isSent;
+ if (messageType != MessageType.SYSTEM) {
+ this.timestamp = System.currentTimeMillis();
+ } else {
+ this.timestamp = (long) 0;
}
+ }
- public void setTokensPerSecond(float tokensPerSecond) {
- this.tokensPerSecond = tokensPerSecond;
- }
+ public int getPromptID() {
+ return promptID;
+ }
- public float getTokensPerSecond() {
- return tokensPerSecond;
- }
+ public MessageType getMessageType() {
+ return messageType;
+ }
+
+ public String getImagePath() {
+ return imagePath;
+ }
+
+ public String getText() {
+ return text;
+ }
+
+ public void appendText(String text) {
+ this.text += text;
+ }
+
+ public boolean getIsSent() {
+ return isSent;
+ }
+
+ public void setTokensPerSecond(float tokensPerSecond) {
+ this.tokensPerSecond = tokensPerSecond;
+ }
+
+ public void setTotalGenerationTime(long totalGenerationTime) {
+ this.totalGenerationTime = totalGenerationTime;
+ }
+
+ public float getTokensPerSecond() {
+ return tokensPerSecond;
+ }
+
+ public long getTotalGenerationTime() {
+ return totalGenerationTime;
+ }
+
+ public long getTimestamp() {
+ return timestamp;
+ }
+
+ public String getFormattedTimestamp() {
+ SimpleDateFormat formatter = new SimpleDateFormat(TIMESTAMP_FORMAT, Locale.getDefault());
+ Date date = new Date(timestamp);
+ return formatter.format(date);
+ }
}
diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/MessageAdapter.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/MessageAdapter.java
index 619e15ab8..bb69c3cd7 100644
--- a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/MessageAdapter.java
+++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/MessageAdapter.java
@@ -1,4 +1,3 @@
-
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
@@ -9,33 +8,128 @@
package org.pytorch.torchchat;
+import android.net.Uri;
import android.view.LayoutInflater;
import android.view.View;
import android.view.ViewGroup;
import android.widget.ArrayAdapter;
+import android.widget.ImageView;
import android.widget.TextView;
+import java.util.ArrayList;
+import java.util.Collections;
public class MessageAdapter extends ArrayAdapter {
- public MessageAdapter(android.content.Context context, int resource) {
- super(context, resource);
+
+ private final ArrayList savedMessages;
+
+ public MessageAdapter(
+ android.content.Context context, int resource, ArrayList savedMessages) {
+ super(context, resource);
+ this.savedMessages = savedMessages;
+ }
+
+ @Override
+ public View getView(int position, View convertView, ViewGroup parent) {
+ Message currentMessage = getItem(position);
+ int layoutIdForListItem;
+
+ if (currentMessage.getMessageType() == MessageType.SYSTEM) {
+ layoutIdForListItem = R.layout.system_message;
+ } else {
+ layoutIdForListItem =
+ currentMessage.getIsSent() ? R.layout.sent_message : R.layout.received_message;
+ }
+ View listItemView =
+ LayoutInflater.from(getContext()).inflate(layoutIdForListItem, parent, false);
+ if (currentMessage.getMessageType() == MessageType.IMAGE) {
+ ImageView messageImageView = listItemView.requireViewById(R.id.message_image);
+ messageImageView.setImageURI(Uri.parse(currentMessage.getImagePath()));
+ TextView messageTextView = listItemView.requireViewById(R.id.message_text);
+ messageTextView.setVisibility(View.GONE);
+ } else {
+ TextView messageTextView = listItemView.requireViewById(R.id.message_text);
+ messageTextView.setText(currentMessage.getText());
+ }
+
+ String metrics = "";
+ TextView tokensView;
+ if (currentMessage.getTokensPerSecond() > 0) {
+ metrics = String.format("%.2f", currentMessage.getTokensPerSecond()) + "t/s ";
+ }
+
+ if (currentMessage.getTotalGenerationTime() > 0) {
+ metrics = metrics + (float) currentMessage.getTotalGenerationTime() / 1000 + "s ";
}
- @Override
- public View getView(int position, View convertView, ViewGroup parent) {
- Message currentMessage = getItem(position);
+ if (currentMessage.getTokensPerSecond() > 0 || currentMessage.getTotalGenerationTime() > 0) {
+ tokensView = listItemView.requireViewById(R.id.generation_metrics);
+ tokensView.setText(metrics);
+ TextView separatorView = listItemView.requireViewById(R.id.bar);
+ separatorView.setVisibility(View.VISIBLE);
+ }
+
+ if (currentMessage.getTimestamp() > 0) {
+ TextView timestampView = listItemView.requireViewById(R.id.timestamp);
+ timestampView.setText(currentMessage.getFormattedTimestamp());
+ }
+
+ return listItemView;
+ }
+
+ @Override
+ public void add(Message msg) {
+ super.add(msg);
+ savedMessages.add(msg);
+ }
- int layoutIdForListItem =
- currentMessage.getIsSent() ? R.layout.sent_message : R.layout.received_message;
- View listItemView =
- LayoutInflater.from(getContext()).inflate(layoutIdForListItem, parent, false);
- TextView messageTextView = listItemView.findViewById(R.id.message_text);
- messageTextView.setText(currentMessage.getText());
+ @Override
+ public void clear() {
+ super.clear();
+ savedMessages.clear();
+ }
- if (currentMessage.getTokensPerSecond() > 0) {
- TextView tokensView = listItemView.findViewById(R.id.tokens_per_second);
- tokensView.setText("" + currentMessage.getTokensPerSecond() + " t/s");
+ public ArrayList getSavedMessages() {
+ return savedMessages;
+ }
+
+ public ArrayList getRecentSavedTextMessages(int numOfLatestPromptMessages) {
+ ArrayList recentMessages = new ArrayList();
+ int lastIndex = savedMessages.size() - 1;
+ // In most cases lastIndex >=0 .
+ // A situation where the user clears chat history and enters prompt. Causes lastIndex=-1 .
+ if (lastIndex >= 0) {
+ Message messageToAdd = savedMessages.get(lastIndex);
+ int oldPromptID = messageToAdd.getPromptID();
+
+ for (int i = 0; i < savedMessages.size(); i++) {
+ messageToAdd = savedMessages.get(lastIndex - i);
+ if (messageToAdd.getMessageType() != MessageType.SYSTEM) {
+ if (messageToAdd.getPromptID() != oldPromptID) {
+ numOfLatestPromptMessages--;
+ oldPromptID = messageToAdd.getPromptID();
+ }
+ if (numOfLatestPromptMessages > 0) {
+ if (messageToAdd.getMessageType() == MessageType.TEXT) {
+ recentMessages.add(messageToAdd);
+ }
+ } else {
+ break;
+ }
}
+ }
+ // To place the order in [input1, output1, input2, output2...]
+ Collections.reverse(recentMessages);
+ }
+
+ return recentMessages;
+ }
+
+ public int getMaxPromptID() {
+ int maxPromptID = -1;
+ for (Message msg : savedMessages) {
- return listItemView;
+ maxPromptID = Math.max(msg.getPromptID(), maxPromptID);
}
+ return maxPromptID;
+ }
}
diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/MessageType.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/MessageType.java
new file mode 100644
index 000000000..d10871310
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/MessageType.java
@@ -0,0 +1,15 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package org.pytorch.torchchat;
+
+public enum MessageType {
+ TEXT,
+ IMAGE,
+ SYSTEM
+}
diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ModelRunner.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ModelRunner.java
new file mode 100644
index 000000000..ff5b22364
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ModelRunner.java
@@ -0,0 +1,98 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package org.pytorch.torchchat;
+
+import android.os.Handler;
+import android.os.HandlerThread;
+import android.os.Looper;
+import android.os.Message;
+import androidx.annotation.NonNull;
+import org.pytorch.executorch.LlamaCallback;
+import org.pytorch.executorch.LlamaModule;
+
+/** A helper class to handle all model running logic within this class. */
+public class ModelRunner implements LlamaCallback {
+ LlamaModule mModule = null;
+
+ String mModelFilePath = "";
+ String mTokenizerFilePath = "";
+
+ ModelRunnerCallback mCallback = null;
+
+ HandlerThread mHandlerThread = null;
+ Handler mHandler = null;
+
+ /**
+ * ] Helper class to separate between UI logic and model runner logic. Automatically handle
+ * generate() request on worker thread.
+ *
+ * @param modelFilePath
+ * @param tokenizerFilePath
+ * @param callback
+ */
+ ModelRunner(
+ String modelFilePath,
+ String tokenizerFilePath,
+ float temperature,
+ ModelRunnerCallback callback) {
+ mModelFilePath = modelFilePath;
+ mTokenizerFilePath = tokenizerFilePath;
+ mCallback = callback;
+
+ mModule = new LlamaModule(mModelFilePath, mTokenizerFilePath, 0.8f);
+ mHandlerThread = new HandlerThread("ModelRunner");
+ mHandlerThread.start();
+ mHandler = new ModelRunnerHandler(mHandlerThread.getLooper(), this);
+
+ mHandler.sendEmptyMessage(ModelRunnerHandler.MESSAGE_LOAD_MODEL);
+ }
+
+ int generate(String prompt) {
+ Message msg = Message.obtain(mHandler, ModelRunnerHandler.MESSAGE_GENERATE, prompt);
+ msg.sendToTarget();
+ return 0;
+ }
+
+ void stop() {
+ mModule.stop();
+ }
+
+ @Override
+ public void onResult(String result) {
+ mCallback.onTokenGenerated(result);
+ }
+
+ @Override
+ public void onStats(float tps) {
+ mCallback.onStats("tokens/second: " + tps);
+ }
+}
+
+class ModelRunnerHandler extends Handler {
+ public static int MESSAGE_LOAD_MODEL = 1;
+ public static int MESSAGE_GENERATE = 2;
+
+ private final ModelRunner mModelRunner;
+
+ public ModelRunnerHandler(Looper looper, ModelRunner modelRunner) {
+ super(looper);
+ mModelRunner = modelRunner;
+ }
+
+ @Override
+ public void handleMessage(@NonNull android.os.Message msg) {
+ if (msg.what == MESSAGE_LOAD_MODEL) {
+ int status = mModelRunner.mModule.load();
+ mModelRunner.mCallback.onModelLoaded(status);
+ } else if (msg.what == MESSAGE_GENERATE) {
+ mModelRunner.mModule.generate((String) msg.obj, mModelRunner);
+ mModelRunner.mCallback.onGenerationStopped();
+ }
+ }
+}
diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ModelRunnerCallback.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ModelRunnerCallback.java
new file mode 100644
index 000000000..a04c660f8
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ModelRunnerCallback.java
@@ -0,0 +1,24 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package org.pytorch.torchchat;
+
+/**
+ * A helper interface within the app for MainActivity and Benchmarking to handle callback from
+ * ModelRunner.
+ */
+public interface ModelRunnerCallback {
+
+ void onModelLoaded(int status);
+
+ void onTokenGenerated(String token);
+
+ void onStats(String token);
+
+ void onGenerationStopped();
+}
diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ModelType.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ModelType.java
new file mode 100644
index 000000000..307609152
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ModelType.java
@@ -0,0 +1,17 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package org.pytorch.torchchat;
+
+public enum ModelType {
+ LLAMA_3,
+ LLAMA_3_1,
+ LLAMA_3_2,
+ LLAVA_1_5,
+ LLAMA_GUARD_3,
+}
diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ModelUtils.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ModelUtils.java
new file mode 100644
index 000000000..056180318
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ModelUtils.java
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package org.pytorch.torchchat;
+
+public class ModelUtils {
+ static final int TEXT_MODEL = 1;
+ static final int VISION_MODEL = 2;
+ static final int VISION_MODEL_IMAGE_CHANNELS = 3;
+ static final int VISION_MODEL_SEQ_LEN = 768;
+ static final int TEXT_MODEL_SEQ_LEN = 256;
+
+ public static int getModelCategory(ModelType modelType) {
+ switch (modelType) {
+ case LLAVA_1_5:
+ return VISION_MODEL;
+ case LLAMA_3:
+ case LLAMA_3_1:
+ case LLAMA_3_2:
+ default:
+ return TEXT_MODEL;
+ }
+ }
+}
diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/PromptFormat.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/PromptFormat.java
new file mode 100644
index 000000000..2f256783f
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/PromptFormat.java
@@ -0,0 +1,121 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package org.pytorch.torchchat;
+
+public class PromptFormat {
+
+ public static final String SYSTEM_PLACEHOLDER = "{{ system_prompt }}";
+ public static final String USER_PLACEHOLDER = "{{ user_prompt }}";
+ public static final String ASSISTANT_PLACEHOLDER = "{{ assistant_response }}";
+ public static final String DEFAULT_SYSTEM_PROMPT = "Answer the questions in a few sentences";
+
+ public static String getSystemPromptTemplate(ModelType modelType) {
+ switch (modelType) {
+ case LLAMA_3:
+ case LLAMA_3_1:
+ case LLAMA_3_2:
+ return "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
+ + SYSTEM_PLACEHOLDER
+ + "<|eot_id|>";
+ case LLAVA_1_5:
+ return "USER: ";
+ default:
+ return SYSTEM_PLACEHOLDER;
+ }
+ }
+
+ public static String getUserPromptTemplate(ModelType modelType) {
+ switch (modelType) {
+ case LLAMA_3:
+ case LLAMA_3_1:
+ case LLAMA_3_2:
+ case LLAMA_GUARD_3:
+ return "<|start_header_id|>user<|end_header_id|>\n"
+ + USER_PLACEHOLDER
+ + "<|eot_id|>"
+ + "<|start_header_id|>assistant<|end_header_id|>";
+
+ case LLAVA_1_5:
+ default:
+ return USER_PLACEHOLDER;
+ }
+ }
+
+ public static String getConversationFormat(ModelType modelType) {
+ switch (modelType) {
+ case LLAMA_3:
+ case LLAMA_3_1:
+ case LLAMA_3_2:
+ return getUserPromptTemplate(modelType) + "\n" + ASSISTANT_PLACEHOLDER + "<|eot_id|>";
+ case LLAVA_1_5:
+ return USER_PLACEHOLDER + " ASSISTANT:";
+ default:
+ return USER_PLACEHOLDER;
+ }
+ }
+
+ public static String getStopToken(ModelType modelType) {
+ switch (modelType) {
+ case LLAMA_3:
+ case LLAMA_3_1:
+ case LLAMA_3_2:
+ case LLAMA_GUARD_3:
+ return "<|eot_id|>";
+ case LLAVA_1_5:
+ return "";
+ default:
+ return "";
+ }
+ }
+
+ public static String getLlavaPresetPrompt() {
+ return "A chat between a curious human and an artificial intelligence assistant. The assistant"
+ + " gives helpful, detailed, and polite answers to the human's questions. USER: ";
+ }
+
+ public static String getFormattedLlamaGuardPrompt(String userPrompt) {
+ return getUserPromptTemplate(ModelType.LLAMA_GUARD_3)
+ .replace(
+ USER_PLACEHOLDER, getLlamaGuardPresetPrompt().replace(USER_PLACEHOLDER, userPrompt));
+ }
+
+ public static String getLlamaGuardPresetPrompt() {
+ return "Task: Check if there is unsafe content in 'User' messages in conversations"
+ + " according our safety policy with the below categories.\n"
+ + "\n"
+ + "\n"
+ + "S1: Violent Crimes. \n"
+ + "S2: Non-Violent Crimes. \n"
+ + "S3: Sex Crimes. \n"
+ + "S4: Child Exploitation. \n"
+ + "S5: Defamation. \n"
+ + "S6: Specialized Advice. \n"
+ + "S7: Privacy. \n"
+ + "S8: Intellectual Property. \n"
+ + "S9: Indiscriminate Weapons. \n"
+ + "S10: Hate. \n"
+ + "S11: Self-Harm. \n"
+ + "S12: Sexual Content. \n"
+ + "S13: Elections. \n"
+ + "S14: Code Interpreter Abuse.\n"
+ + "\n"
+ + "\n"
+ + "\n"
+ + "\n"
+ + "User: "
+ + USER_PLACEHOLDER
+ + "\n"
+ + "\n"
+ + "Provide your safety assessment for ONLY THE LAST User message in the above"
+ + " conversation:\n"
+ + " - First line must read 'safe' or 'unsafe'.\n"
+ + " - If unsafe, a second line must include a comma-separated list of violated"
+ + " categories.";
+ }
+}
diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/SettingsActivity.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/SettingsActivity.java
new file mode 100644
index 000000000..c040bbd53
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/SettingsActivity.java
@@ -0,0 +1,395 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package org.pytorch.torchchat;
+
+import android.app.AlertDialog;
+import android.content.DialogInterface;
+import android.os.Build;
+import android.os.Bundle;
+import android.text.Editable;
+import android.text.TextWatcher;
+import android.widget.Button;
+import android.widget.EditText;
+import android.widget.ImageButton;
+import android.widget.TextView;
+import androidx.appcompat.app.AppCompatActivity;
+import androidx.core.content.ContextCompat;
+import androidx.core.graphics.Insets;
+import androidx.core.view.ViewCompat;
+import androidx.core.view.WindowInsetsCompat;
+import com.google.gson.Gson;
+import java.io.File;
+import java.util.ArrayList;
+import java.util.List;
+
+public class SettingsActivity extends AppCompatActivity {
+
+ private String mModelFilePath = "";
+ private String mTokenizerFilePath = "";
+ private TextView mModelTextView;
+ private TextView mTokenizerTextView;
+ private TextView mModelTypeTextView;
+ private EditText mSystemPromptEditText;
+ private EditText mUserPromptEditText;
+ private Button mLoadModelButton;
+ private double mSetTemperature;
+ private String mSystemPrompt;
+ private String mUserPrompt;
+ private ModelType mModelType;
+ public SettingsFields mSettingsFields;
+
+ private DemoSharedPreferences mDemoSharedPreferences;
+ public static double TEMPERATURE_MIN_VALUE = 0.0;
+
+ @Override
+ protected void onCreate(Bundle savedInstanceState) {
+ super.onCreate(savedInstanceState);
+ setContentView(R.layout.activity_settings);
+ if (Build.VERSION.SDK_INT >= 21) {
+ getWindow().setStatusBarColor(ContextCompat.getColor(this, R.color.status_bar));
+ getWindow().setNavigationBarColor(ContextCompat.getColor(this, R.color.nav_bar));
+ }
+ ViewCompat.setOnApplyWindowInsetsListener(
+ requireViewById(R.id.main),
+ (v, insets) -> {
+ Insets systemBars = insets.getInsets(WindowInsetsCompat.Type.systemBars());
+ v.setPadding(systemBars.left, systemBars.top, systemBars.right, systemBars.bottom);
+ return insets;
+ });
+ mDemoSharedPreferences = new DemoSharedPreferences(getBaseContext());
+ mSettingsFields = new SettingsFields();
+ setupSettings();
+ }
+
+ private void setupSettings() {
+ mModelTextView = requireViewById(R.id.modelTextView);
+ mTokenizerTextView = requireViewById(R.id.tokenizerTextView);
+ mModelTypeTextView = requireViewById(R.id.modelTypeTextView);
+ ImageButton modelImageButton = requireViewById(R.id.modelImageButton);
+ ImageButton tokenizerImageButton = requireViewById(R.id.tokenizerImageButton);
+ ImageButton modelTypeImageButton = requireViewById(R.id.modelTypeImageButton);
+ mSystemPromptEditText = requireViewById(R.id.systemPromptText);
+ mUserPromptEditText = requireViewById(R.id.userPromptText);
+ loadSettings();
+
+ // TODO: The two setOnClickListeners will be removed after file path issue is resolved
+ modelImageButton.setOnClickListener(
+ view -> {
+ setupModelSelectorDialog();
+ });
+ tokenizerImageButton.setOnClickListener(
+ view -> {
+ setupTokenizerSelectorDialog();
+ });
+ modelTypeImageButton.setOnClickListener(
+ view -> {
+ setupModelTypeSelectorDialog();
+ });
+ mModelFilePath = mSettingsFields.getModelFilePath();
+ if (!mModelFilePath.isEmpty()) {
+ mModelTextView.setText(getFilenameFromPath(mModelFilePath));
+ }
+ mTokenizerFilePath = mSettingsFields.getTokenizerFilePath();
+ if (!mTokenizerFilePath.isEmpty()) {
+ mTokenizerTextView.setText(getFilenameFromPath(mTokenizerFilePath));
+ }
+ mModelType = mSettingsFields.getModelType();
+ ETLogging.getInstance().log("mModelType from settings " + mModelType);
+ if (mModelType != null) {
+ mModelTypeTextView.setText(mModelType.toString());
+ }
+
+ setupParameterSettings();
+ setupPromptSettings();
+ setupClearChatHistoryButton();
+ setupLoadModelButton();
+ }
+
+ private void setupLoadModelButton() {
+ mLoadModelButton = requireViewById(R.id.loadModelButton);
+ mLoadModelButton.setEnabled(true);
+ mLoadModelButton.setOnClickListener(
+ view -> {
+ new AlertDialog.Builder(this)
+ .setTitle("Load Model")
+ .setMessage("Do you really want to load the new model?")
+ .setIcon(android.R.drawable.ic_dialog_alert)
+ .setPositiveButton(
+ android.R.string.yes,
+ new DialogInterface.OnClickListener() {
+ public void onClick(DialogInterface dialog, int whichButton) {
+ mSettingsFields.saveLoadModelAction(true);
+ mLoadModelButton.setEnabled(false);
+ onBackPressed();
+ }
+ })
+ .setNegativeButton(android.R.string.no, null)
+ .show();
+ });
+ }
+
+ private void setupClearChatHistoryButton() {
+ Button clearChatButton = requireViewById(R.id.clearChatButton);
+ clearChatButton.setOnClickListener(
+ view -> {
+ new AlertDialog.Builder(this)
+ .setTitle("Delete Chat History")
+ .setMessage("Do you really want to delete chat history?")
+ .setIcon(android.R.drawable.ic_dialog_alert)
+ .setPositiveButton(
+ android.R.string.yes,
+ new DialogInterface.OnClickListener() {
+ public void onClick(DialogInterface dialog, int whichButton) {
+ mSettingsFields.saveIsClearChatHistory(true);
+ }
+ })
+ .setNegativeButton(android.R.string.no, null)
+ .show();
+ });
+ }
+
+ private void setupParameterSettings() {
+ setupTemperatureSettings();
+ }
+
+ private void setupTemperatureSettings() {
+ mSetTemperature = mSettingsFields.getTemperature();
+ EditText temperatureEditText = requireViewById(R.id.temperatureEditText);
+ temperatureEditText.setText(String.valueOf(mSetTemperature));
+ temperatureEditText.addTextChangedListener(
+ new TextWatcher() {
+ @Override
+ public void beforeTextChanged(CharSequence s, int start, int count, int after) {}
+
+ @Override
+ public void onTextChanged(CharSequence s, int start, int before, int count) {}
+
+ @Override
+ public void afterTextChanged(Editable s) {
+ mSetTemperature = Double.parseDouble(s.toString());
+ // This is needed because temperature is changed together with model loading
+ // Once temperature is no longer in LlamaModule constructor, we can remove this
+ mSettingsFields.saveLoadModelAction(true);
+ saveSettings();
+ }
+ });
+ }
+
+ private void setupPromptSettings() {
+ setupSystemPromptSettings();
+ setupUserPromptSettings();
+ }
+
+ private void setupSystemPromptSettings() {
+ mSystemPrompt = mSettingsFields.getSystemPrompt();
+ mSystemPromptEditText.setText(mSystemPrompt);
+ mSystemPromptEditText.addTextChangedListener(
+ new TextWatcher() {
+ @Override
+ public void beforeTextChanged(CharSequence s, int start, int count, int after) {}
+
+ @Override
+ public void onTextChanged(CharSequence s, int start, int before, int count) {}
+
+ @Override
+ public void afterTextChanged(Editable s) {
+ mSystemPrompt = s.toString();
+ }
+ });
+
+ ImageButton resetSystemPrompt = requireViewById(R.id.resetSystemPrompt);
+ resetSystemPrompt.setOnClickListener(
+ view -> {
+ new AlertDialog.Builder(this)
+ .setTitle("Reset System Prompt")
+ .setMessage("Do you really want to reset system prompt?")
+ .setIcon(android.R.drawable.ic_dialog_alert)
+ .setPositiveButton(
+ android.R.string.yes,
+ new DialogInterface.OnClickListener() {
+ public void onClick(DialogInterface dialog, int whichButton) {
+ // Clear the messageAdapter and sharedPreference
+ mSystemPromptEditText.setText(PromptFormat.DEFAULT_SYSTEM_PROMPT);
+ }
+ })
+ .setNegativeButton(android.R.string.no, null)
+ .show();
+ });
+ }
+
+ private void setupUserPromptSettings() {
+ mUserPrompt = mSettingsFields.getUserPrompt();
+ mUserPromptEditText.setText(mUserPrompt);
+ mUserPromptEditText.addTextChangedListener(
+ new TextWatcher() {
+ @Override
+ public void beforeTextChanged(CharSequence s, int start, int count, int after) {}
+
+ @Override
+ public void onTextChanged(CharSequence s, int start, int before, int count) {}
+
+ @Override
+ public void afterTextChanged(Editable s) {
+ if (isValidUserPrompt(s.toString())) {
+ mUserPrompt = s.toString();
+ } else {
+ showInvalidPromptDialog();
+ }
+ }
+ });
+
+ ImageButton resetUserPrompt = requireViewById(R.id.resetUserPrompt);
+ resetUserPrompt.setOnClickListener(
+ view -> {
+ new AlertDialog.Builder(this)
+ .setTitle("Reset Prompt Template")
+ .setMessage("Do you really want to reset the prompt template?")
+ .setIcon(android.R.drawable.ic_dialog_alert)
+ .setPositiveButton(
+ android.R.string.yes,
+ new DialogInterface.OnClickListener() {
+ public void onClick(DialogInterface dialog, int whichButton) {
+ // Clear the messageAdapter and sharedPreference
+ mUserPromptEditText.setText(PromptFormat.getUserPromptTemplate(mModelType));
+ }
+ })
+ .setNegativeButton(android.R.string.no, null)
+ .show();
+ });
+ }
+
+ private boolean isValidUserPrompt(String userPrompt) {
+ return userPrompt.contains(PromptFormat.USER_PLACEHOLDER);
+ }
+
+ private void showInvalidPromptDialog() {
+ new AlertDialog.Builder(this)
+ .setTitle("Invalid Prompt Format")
+ .setMessage(
+ "Prompt format must contain "
+ + PromptFormat.USER_PLACEHOLDER
+ + ". Do you want to reset prompt format?")
+ .setIcon(android.R.drawable.ic_dialog_alert)
+ .setPositiveButton(
+ android.R.string.yes,
+ (dialog, whichButton) -> {
+ mUserPromptEditText.setText(PromptFormat.getUserPromptTemplate(mModelType));
+ })
+ .setNegativeButton(android.R.string.no, null)
+ .show();
+ }
+
+ private void setupModelSelectorDialog() {
+ String[] pteFiles = listLocalFile("/data/local/tmp/llama/", ".pte");
+ AlertDialog.Builder modelPathBuilder = new AlertDialog.Builder(this);
+ modelPathBuilder.setTitle("Select model path");
+
+ modelPathBuilder.setSingleChoiceItems(
+ pteFiles,
+ -1,
+ (dialog, item) -> {
+ mModelFilePath = pteFiles[item];
+ mModelTextView.setText(getFilenameFromPath(mModelFilePath));
+ mLoadModelButton.setEnabled(true);
+ dialog.dismiss();
+ });
+
+ modelPathBuilder.create().show();
+ }
+
+ private static String[] listLocalFile(String path, String suffix) {
+ File directory = new File(path);
+ if (directory.exists() && directory.isDirectory()) {
+ File[] files = directory.listFiles((dir, name) -> name.toLowerCase().endsWith(suffix));
+ String[] result = new String[files.length];
+ for (int i = 0; i < files.length; i++) {
+ if (files[i].isFile() && files[i].getName().endsWith(suffix)) {
+ result[i] = files[i].getAbsolutePath();
+ }
+ }
+ return result;
+ }
+ return new String[] {};
+ }
+
+ private void setupModelTypeSelectorDialog() {
+ // Convert enum to list
+ List modelTypesList = new ArrayList<>();
+ for (ModelType modelType : ModelType.values()) {
+ modelTypesList.add(modelType.toString());
+ }
+ // Alert dialog builder takes in arr of string instead of list
+ String[] modelTypes = modelTypesList.toArray(new String[0]);
+ AlertDialog.Builder modelTypeBuilder = new AlertDialog.Builder(this);
+ modelTypeBuilder.setTitle("Select model type");
+ modelTypeBuilder.setSingleChoiceItems(
+ modelTypes,
+ -1,
+ (dialog, item) -> {
+ mModelTypeTextView.setText(modelTypes[item]);
+ mModelType = ModelType.valueOf(modelTypes[item]);
+ mUserPromptEditText.setText(PromptFormat.getUserPromptTemplate(mModelType));
+ dialog.dismiss();
+ });
+
+ modelTypeBuilder.create().show();
+ }
+
+ private void setupTokenizerSelectorDialog() {
+ String[] binFiles = listLocalFile("/data/local/tmp/llama/", ".bin");
+ String[] modelFiles = listLocalFile("/data/local/tmp/llama/", ".model");
+ String[] tokenizerFiles = new String[binFiles.length + modelFiles.length];
+ System.arraycopy(binFiles, 0, tokenizerFiles, 0, binFiles.length);
+ System.arraycopy(modelFiles, 0, tokenizerFiles, binFiles.length, modelFiles.length);
+ AlertDialog.Builder tokenizerPathBuilder = new AlertDialog.Builder(this);
+ tokenizerPathBuilder.setTitle("Select tokenizer path");
+ tokenizerPathBuilder.setSingleChoiceItems(
+ tokenizerFiles,
+ -1,
+ (dialog, item) -> {
+ mTokenizerFilePath = tokenizerFiles[item];
+ mTokenizerTextView.setText(getFilenameFromPath(mTokenizerFilePath));
+ mLoadModelButton.setEnabled(true);
+ dialog.dismiss();
+ });
+
+ tokenizerPathBuilder.create().show();
+ }
+
+ private String getFilenameFromPath(String uriFilePath) {
+ String[] segments = uriFilePath.split("/");
+ if (segments.length > 0) {
+ return segments[segments.length - 1]; // get last element (aka filename)
+ }
+ return "";
+ }
+
+ private void loadSettings() {
+ Gson gson = new Gson();
+ String settingsFieldsJSON = mDemoSharedPreferences.getSettings();
+ if (!settingsFieldsJSON.isEmpty()) {
+ mSettingsFields = gson.fromJson(settingsFieldsJSON, SettingsFields.class);
+ }
+ }
+
+ private void saveSettings() {
+ mSettingsFields.saveModelPath(mModelFilePath);
+ mSettingsFields.saveTokenizerPath(mTokenizerFilePath);
+ mSettingsFields.saveParameters(mSetTemperature);
+ mSettingsFields.savePrompts(mSystemPrompt, mUserPrompt);
+ mSettingsFields.saveModelType(mModelType);
+ mDemoSharedPreferences.addSettings(mSettingsFields);
+ }
+
+ @Override
+ public void onBackPressed() {
+ super.onBackPressed();
+ saveSettings();
+ }
+}
diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/SettingsFields.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/SettingsFields.java
new file mode 100644
index 000000000..eb054cb9a
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/SettingsFields.java
@@ -0,0 +1,131 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package org.pytorch.torchchat;
+
+public class SettingsFields {
+
+ public String getModelFilePath() {
+ return modelFilePath;
+ }
+
+ public String getTokenizerFilePath() {
+ return tokenizerFilePath;
+ }
+
+ public double getTemperature() {
+ return temperature;
+ }
+
+ public String getSystemPrompt() {
+ return systemPrompt;
+ }
+
+ public ModelType getModelType() {
+ return modelType;
+ }
+
+ public String getUserPrompt() {
+ return userPrompt;
+ }
+
+ public String getFormattedSystemAndUserPrompt(String prompt) {
+ return getFormattedSystemPrompt() + getFormattedUserPrompt(prompt);
+ }
+
+ public String getFormattedSystemPrompt() {
+ return PromptFormat.getSystemPromptTemplate(modelType)
+ .replace(PromptFormat.SYSTEM_PLACEHOLDER, systemPrompt);
+ }
+
+ public String getFormattedUserPrompt(String prompt) {
+ return userPrompt.replace(PromptFormat.USER_PLACEHOLDER, prompt);
+ }
+
+ public boolean getIsClearChatHistory() {
+ return isClearChatHistory;
+ }
+
+ public boolean getIsLoadModel() {
+ return isLoadModel;
+ }
+
+ private String modelFilePath;
+ private String tokenizerFilePath;
+ private double temperature;
+ private String systemPrompt;
+ private String userPrompt;
+ private boolean isClearChatHistory;
+ private boolean isLoadModel;
+ private ModelType modelType;
+
+ public SettingsFields() {
+ ModelType DEFAULT_MODEL = ModelType.LLAMA_3;
+
+ modelFilePath = "";
+ tokenizerFilePath = "";
+ temperature = SettingsActivity.TEMPERATURE_MIN_VALUE;
+ systemPrompt = "";
+ userPrompt = PromptFormat.getUserPromptTemplate(DEFAULT_MODEL);
+ isClearChatHistory = false;
+ isLoadModel = false;
+ modelType = DEFAULT_MODEL;
+ }
+
+ public SettingsFields(SettingsFields settingsFields) {
+ this.modelFilePath = settingsFields.modelFilePath;
+ this.tokenizerFilePath = settingsFields.tokenizerFilePath;
+ this.temperature = settingsFields.temperature;
+ this.systemPrompt = settingsFields.getSystemPrompt();
+ this.userPrompt = settingsFields.getUserPrompt();
+ this.isClearChatHistory = settingsFields.getIsClearChatHistory();
+ this.isLoadModel = settingsFields.getIsLoadModel();
+ this.modelType = settingsFields.modelType;
+ }
+
+ public void saveModelPath(String modelFilePath) {
+ this.modelFilePath = modelFilePath;
+ }
+
+ public void saveTokenizerPath(String tokenizerFilePath) {
+ this.tokenizerFilePath = tokenizerFilePath;
+ }
+
+ public void saveModelType(ModelType modelType) {
+ this.modelType = modelType;
+ }
+
+ public void saveParameters(Double temperature) {
+ this.temperature = temperature;
+ }
+
+ public void savePrompts(String systemPrompt, String userPrompt) {
+ this.systemPrompt = systemPrompt;
+ this.userPrompt = userPrompt;
+ }
+
+ public void saveIsClearChatHistory(boolean needToClear) {
+ this.isClearChatHistory = needToClear;
+ }
+
+ public void saveLoadModelAction(boolean shouldLoadModel) {
+ this.isLoadModel = shouldLoadModel;
+ }
+
+ public boolean equals(SettingsFields anotherSettingsFields) {
+ if (this == anotherSettingsFields) return true;
+ return modelFilePath.equals(anotherSettingsFields.modelFilePath)
+ && tokenizerFilePath.equals(anotherSettingsFields.tokenizerFilePath)
+ && temperature == anotherSettingsFields.temperature
+ && systemPrompt.equals(anotherSettingsFields.systemPrompt)
+ && userPrompt.equals(anotherSettingsFields.userPrompt)
+ && isClearChatHistory == anotherSettingsFields.isClearChatHistory
+ && isLoadModel == anotherSettingsFields.isLoadModel
+ && modelType == anotherSettingsFields.modelType;
+ }
+}
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/banner_shape.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/banner_shape.xml
new file mode 100644
index 000000000..0868ffffa
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/banner_shape.xml
@@ -0,0 +1,5 @@
+
+
+
+
\ No newline at end of file
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_add_24.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_add_24.xml
new file mode 100644
index 000000000..2ae27b840
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_add_24.xml
@@ -0,0 +1,5 @@
+
+
+
+
+
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_add_photo_alternate_24.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_add_photo_alternate_24.xml
new file mode 100644
index 000000000..7077fedd4
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_add_photo_alternate_24.xml
@@ -0,0 +1,5 @@
+
+
+
+
+
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_article_24.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_article_24.xml
new file mode 100644
index 000000000..a6837b9c6
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_article_24.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_close_24.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_close_24.xml
new file mode 100644
index 000000000..fb902d433
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_close_24.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_delete_forever_24.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_delete_forever_24.xml
new file mode 100644
index 000000000..4680bc662
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_delete_forever_24.xml
@@ -0,0 +1,5 @@
+
+
+
+
+
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_restart_alt_24.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_restart_alt_24.xml
new file mode 100644
index 000000000..860470ab1
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_restart_alt_24.xml
@@ -0,0 +1,6 @@
+
+
+
+
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_send_24.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_send_24.xml
new file mode 100644
index 000000000..2de1f6420
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_send_24.xml
@@ -0,0 +1,6 @@
+
+
+
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_settings_24.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_settings_24.xml
new file mode 100644
index 000000000..c51d84b9f
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_settings_24.xml
@@ -0,0 +1,11 @@
+
+
+
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_stop_24.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_stop_24.xml
new file mode 100644
index 000000000..832e25859
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_stop_24.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/btn.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/btn.xml
new file mode 100644
index 000000000..ceb3ac56c
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/btn.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/chat_background.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/chat_background.xml
new file mode 100644
index 000000000..eb8b9d1f1
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/chat_background.xml
@@ -0,0 +1,21 @@
+
+
+
+
+
+
+
+
+
+
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/custom_button_round.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/custom_button_round.xml
new file mode 100644
index 000000000..87c82d2a3
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/custom_button_round.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/expand_circle_down.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/expand_circle_down.xml
new file mode 100644
index 000000000..0a7a71f07
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/expand_circle_down.xml
@@ -0,0 +1,9 @@
+
+
+
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/input_text_shape.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/input_text_shape.xml
new file mode 100644
index 000000000..35c778a43
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/input_text_shape.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/logo.png b/torchchat/edge/android/torchchat/app/src/main/res/drawable/logo.png
new file mode 100644
index 000000000..60e3e5174
Binary files /dev/null and b/torchchat/edge/android/torchchat/app/src/main/res/drawable/logo.png differ
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/outline_add_box_48.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/outline_add_box_48.xml
new file mode 100644
index 000000000..bb45d63d8
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/outline_add_box_48.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/outline_camera_alt_48.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/outline_camera_alt_48.xml
new file mode 100644
index 000000000..c7b4b2e4a
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/outline_camera_alt_48.xml
@@ -0,0 +1,5 @@
+
+
+
+
+
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/outline_image_48.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/outline_image_48.xml
new file mode 100644
index 000000000..a8bb4b2f6
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/outline_image_48.xml
@@ -0,0 +1,5 @@
+
+
+
+
+
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/prompt_shape.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/prompt_shape.xml
new file mode 100644
index 000000000..5f81396e3
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/prompt_shape.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
\ No newline at end of file
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/received_message.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/received_message.xml
index ea2d1bbfa..c2288b5bf 100644
--- a/torchchat/edge/android/torchchat/app/src/main/res/drawable/received_message.xml
+++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/received_message.xml
@@ -1,6 +1,6 @@
-
+
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/layout/activity_benchmarking.xml b/torchchat/edge/android/torchchat/app/src/main/res/layout/activity_benchmarking.xml
new file mode 100644
index 000000000..6e48b5de8
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/res/layout/activity_benchmarking.xml
@@ -0,0 +1,16 @@
+
+
+
+
+
+
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/layout/activity_logs.xml b/torchchat/edge/android/torchchat/app/src/main/res/layout/activity_logs.xml
new file mode 100644
index 000000000..b327a544f
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/res/layout/activity_logs.xml
@@ -0,0 +1,55 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/layout/activity_main.xml b/torchchat/edge/android/torchchat/app/src/main/res/layout/activity_main.xml
index 089acb572..7b8b8d176 100644
--- a/torchchat/edge/android/torchchat/app/src/main/res/layout/activity_main.xml
+++ b/torchchat/edge/android/torchchat/app/src/main/res/layout/activity_main.xml
@@ -1,44 +1,233 @@
-
-
+
+
+
+
+
+
+
-
+
+
+
+
+
+
-
+
+
+ android:background="#16293D"
+ android:visibility="gone">
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/layout/activity_settings.xml b/torchchat/edge/android/torchchat/app/src/main/res/layout/activity_settings.xml
new file mode 100644
index 000000000..7d5c3b1b6
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/res/layout/activity_settings.xml
@@ -0,0 +1,295 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/layout/logs_message.xml b/torchchat/edge/android/torchchat/app/src/main/res/layout/logs_message.xml
new file mode 100644
index 000000000..3f80f58db
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/res/layout/logs_message.xml
@@ -0,0 +1,16 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/layout/received_message.xml b/torchchat/edge/android/torchchat/app/src/main/res/layout/received_message.xml
index ebcdf5e01..bffedf30c 100644
--- a/torchchat/edge/android/torchchat/app/src/main/res/layout/received_message.xml
+++ b/torchchat/edge/android/torchchat/app/src/main/res/layout/received_message.xml
@@ -9,33 +9,62 @@
+ android:text="Llama"
+ android:textColor="#FFFFFF" />
+ android:textColor="#FFFFFF"
+ android:textSize="16sp" />
-
+ android:layout_below="@+id/message_text">
+
+
+
+
+
+
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/layout/sent_message.xml b/torchchat/edge/android/torchchat/app/src/main/res/layout/sent_message.xml
index b8121e973..a04254e38 100644
--- a/torchchat/edge/android/torchchat/app/src/main/res/layout/sent_message.xml
+++ b/torchchat/edge/android/torchchat/app/src/main/res/layout/sent_message.xml
@@ -1,5 +1,7 @@
-
+ android:orientation="vertical">
+
+
+
+
+
+
+
+
+ android:layout_marginRight="10dp"
+ android:paddingBottom="4dp"
+ android:text=""
+ android:textColor="#FFFFFF" />
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/layout/system_message.xml b/torchchat/edge/android/torchchat/app/src/main/res/layout/system_message.xml
new file mode 100644
index 000000000..bd3cfef22
--- /dev/null
+++ b/torchchat/edge/android/torchchat/app/src/main/res/layout/system_message.xml
@@ -0,0 +1,23 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/values-land/dimens.xml b/torchchat/edge/android/torchchat/app/src/main/res/values-land/dimens.xml
deleted file mode 100644
index ec4deb847..000000000
--- a/torchchat/edge/android/torchchat/app/src/main/res/values-land/dimens.xml
+++ /dev/null
@@ -1,3 +0,0 @@
-
- 48dp
-
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/values-v23/themes.xml b/torchchat/edge/android/torchchat/app/src/main/res/values-v23/themes.xml
deleted file mode 100644
index 013331686..000000000
--- a/torchchat/edge/android/torchchat/app/src/main/res/values-v23/themes.xml
+++ /dev/null
@@ -1,9 +0,0 @@
-
-
-
-
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/values-w1240dp/dimens.xml b/torchchat/edge/android/torchchat/app/src/main/res/values-w1240dp/dimens.xml
deleted file mode 100644
index 2ecead29e..000000000
--- a/torchchat/edge/android/torchchat/app/src/main/res/values-w1240dp/dimens.xml
+++ /dev/null
@@ -1,3 +0,0 @@
-
- 200dp
-
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/values-w600dp/dimens.xml b/torchchat/edge/android/torchchat/app/src/main/res/values-w600dp/dimens.xml
deleted file mode 100644
index ec4deb847..000000000
--- a/torchchat/edge/android/torchchat/app/src/main/res/values-w600dp/dimens.xml
+++ /dev/null
@@ -1,3 +0,0 @@
-
- 48dp
-
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/values/colors.xml b/torchchat/edge/android/torchchat/app/src/main/res/values/colors.xml
index 4faecfa80..069727f3e 100644
--- a/torchchat/edge/android/torchchat/app/src/main/res/values/colors.xml
+++ b/torchchat/edge/android/torchchat/app/src/main/res/values/colors.xml
@@ -1,6 +1,10 @@
- #6200EE
+ #4294F0
#3700B3
#03DAC5
-
\ No newline at end of file
+ #007CBA
+ #A2A4B6
+ #16293D
+ #16293D
+
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/values/dimens.xml b/torchchat/edge/android/torchchat/app/src/main/res/values/dimens.xml
deleted file mode 100644
index 59a0b0c4f..000000000
--- a/torchchat/edge/android/torchchat/app/src/main/res/values/dimens.xml
+++ /dev/null
@@ -1,3 +0,0 @@
-
- 16dp
-
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/values/strings.xml b/torchchat/edge/android/torchchat/app/src/main/res/values/strings.xml
index 7bdbef9f5..f603a59b2 100644
--- a/torchchat/edge/android/torchchat/app/src/main/res/values/strings.xml
+++ b/torchchat/edge/android/torchchat/app/src/main/res/values/strings.xml
@@ -1,3 +1,7 @@
torchchat
+ DemoPrefFileKey
+ SavedMessagesJsonKey
+ SettingsJsonKey
+ LogsJsonKey
diff --git a/torchchat/edge/android/torchchat/app/src/main/res/values/styles.xml b/torchchat/edge/android/torchchat/app/src/main/res/values/styles.xml
index 391ec9ae3..387804aa1 100644
--- a/torchchat/edge/android/torchchat/app/src/main/res/values/styles.xml
+++ b/torchchat/edge/android/torchchat/app/src/main/res/values/styles.xml
@@ -7,4 +7,8 @@
- @color/colorAccent
+
diff --git a/torchchat/edge/android/torchchat/app/src/test/java/org/pytorch/torchchat/ExampleUnitTest.java b/torchchat/edge/android/torchchat/app/src/test/java/org/pytorch/torchchat/ExampleUnitTest.java
deleted file mode 100644
index 8fbffcf35..000000000
--- a/torchchat/edge/android/torchchat/app/src/test/java/org/pytorch/torchchat/ExampleUnitTest.java
+++ /dev/null
@@ -1,25 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-package org.pytorch.torchchat;
-
-import org.junit.Test;
-
-import static org.junit.Assert.*;
-
-/**
- * Example local unit test, which will execute on the development machine (host).
- *
- * @see Testing documentation
- */
-public class ExampleUnitTest {
- @Test
- public void addition_isCorrect() {
- assertEquals(4, 2 + 2);
- }
-}
diff --git a/torchchat/edge/android/torchchat/build.gradle.kts b/torchchat/edge/android/torchchat/build.gradle.kts
index cc9db8a5c..568efa281 100644
--- a/torchchat/edge/android/torchchat/build.gradle.kts
+++ b/torchchat/edge/android/torchchat/build.gradle.kts
@@ -1,4 +1,13 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
// Top-level build file where you can add configuration options common to all sub-projects/modules.
plugins {
- id("com.android.application") version "8.1.0" apply false
+ id("com.android.application") version "8.1.0" apply false
+ id("org.jetbrains.kotlin.android") version "1.8.10" apply false
}
diff --git a/torchchat/edge/android/torchchat/gradle.properties b/torchchat/edge/android/torchchat/gradle.properties
index 9440e7d54..2cbd6d19d 100644
--- a/torchchat/edge/android/torchchat/gradle.properties
+++ b/torchchat/edge/android/torchchat/gradle.properties
@@ -1,9 +1,3 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
# Project-wide Gradle settings.
# IDE (e.g. Android Studio) users:
# Gradle settings configured through the IDE *will override*
@@ -21,6 +15,8 @@ org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8
# Android operating system, and which are packaged with your app's APK
# https://developer.android.com/topic/libraries/support-library/androidx-rn
android.useAndroidX=true
+# Kotlin code style for this project: "official" or "obsolete":
+kotlin.code.style=official
# Enables namespacing of each library's R class so that its R class includes only the
# resources declared in the library itself and none from the library's dependencies,
# thereby reducing the size of the R class for that library
diff --git a/torchchat/edge/android/torchchat/gradle/wrapper/gradle-wrapper.properties b/torchchat/edge/android/torchchat/gradle/wrapper/gradle-wrapper.properties
index 5ac70c8c5..2a7f77d2f 100644
--- a/torchchat/edge/android/torchchat/gradle/wrapper/gradle-wrapper.properties
+++ b/torchchat/edge/android/torchchat/gradle/wrapper/gradle-wrapper.properties
@@ -1,10 +1,4 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-#Thu Apr 25 21:54:24 PDT 2024
+#Mon Sep 25 11:23:11 PDT 2023
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-8.0-bin.zip
diff --git a/torchchat/edge/android/torchchat/gradlew.bat b/torchchat/edge/android/torchchat/gradlew.bat
index ac1b06f93..b4fb785a6 100644
--- a/torchchat/edge/android/torchchat/gradlew.bat
+++ b/torchchat/edge/android/torchchat/gradlew.bat
@@ -1,3 +1,9 @@
+@REM Copyright (c) Meta Platforms, Inc. and affiliates.
+@REM All rights reserved.
+@REM
+@REM This source code is licensed under the BSD-style license found in the
+@REM LICENSE file in the root directory of this source tree.
+
@rem
@rem Copyright 2015 the original author or authors.
@rem
diff --git a/torchchat/edge/android/torchchat/settings.gradle.kts b/torchchat/edge/android/torchchat/settings.gradle.kts
index 8ed189255..ba0e809fd 100644
--- a/torchchat/edge/android/torchchat/settings.gradle.kts
+++ b/torchchat/edge/android/torchchat/settings.gradle.kts
@@ -1,17 +1,27 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
pluginManagement {
- repositories {
- google()
- mavenCentral()
- gradlePluginPortal()
- }
+ repositories {
+ google()
+ mavenCentral()
+ gradlePluginPortal()
+ }
}
+
dependencyResolutionManagement {
- repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
- repositories {
- google()
- mavenCentral()
- }
+ repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
+ repositories {
+ google()
+ mavenCentral()
+ }
}
-rootProject.name = "torchchat"
+rootProject.name = "ExecuTorch Demo"
+
include(":app")
diff --git a/torchchat/utils/scripts/android_example.sh b/torchchat/utils/scripts/android_example.sh
index 06875543f..6ef6824c0 100755
--- a/torchchat/utils/scripts/android_example.sh
+++ b/torchchat/utils/scripts/android_example.sh
@@ -95,7 +95,7 @@ download_aar_library() {
}
build_app() {
- pushd android/torchchat
+ pushd torchchat/edge/android/torchchat
./gradlew :app:build
popd
}
@@ -138,7 +138,7 @@ push_files_to_android() {
}
run_android_instrumented_test() {
- pushd android/torchchat
+ pushd torchchat/edge/android/torchchat
./gradlew connectedAndroidTest
popd
}
@@ -155,7 +155,7 @@ if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then
run_android_instrumented_test
fi
-adb install -t android/torchchat/app/build/outputs/apk/debug/app-debug.apk
+adb install -t torchchat/edge/android/torchchat/app/build/outputs/apk/debug/app-debug.apk
if [ -z "${CI_ENV:-}" ]; then
read -p "Press enter to exit emulator and finish"