diff --git a/torchchat/edge/android/torchchat/app/.gitignore b/torchchat/edge/android/torchchat/app/.gitignore index 42afabfd2..796b96d1c 100644 --- a/torchchat/edge/android/torchchat/app/.gitignore +++ b/torchchat/edge/android/torchchat/app/.gitignore @@ -1 +1 @@ -/build \ No newline at end of file +/build diff --git a/torchchat/edge/android/torchchat/app/build.gradle.kts b/torchchat/edge/android/torchchat/app/build.gradle.kts index 1001c9f81..e0c9c196b 100644 --- a/torchchat/edge/android/torchchat/app/build.gradle.kts +++ b/torchchat/edge/android/torchchat/app/build.gradle.kts @@ -1,45 +1,97 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + plugins { - id("com.android.application") + id("com.android.application") + id("org.jetbrains.kotlin.android") } android { - namespace = "org.pytorch.torchchat" - compileSdk = 33 + namespace = "org.pytorch.torchchat" + compileSdk = 34 - defaultConfig { - applicationId = "org.pytorch.torchchat" - minSdk = 24 - targetSdk = 33 - versionCode = 1 - versionName = "1.0" + defaultConfig { + applicationId = "org.pytorch.torchchat" + minSdk = 28 + targetSdk = 33 + versionCode = 1 + versionName = "1.0" - testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner" - } + testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner" + vectorDrawables { useSupportLibrary = true } + externalNativeBuild { cmake { cppFlags += "" } } + } - buildTypes { - release { - isMinifyEnabled = false - proguardFiles( - getDefaultProguardFile("proguard-android-optimize.txt"), - "proguard-rules.pro" - ) - } + buildTypes { + release { + isMinifyEnabled = false + proguardFiles(getDefaultProguardFile("proguard-android-optimize.txt"), "proguard-rules.pro") } - compileOptions { - sourceCompatibility = JavaVersion.VERSION_1_8 - targetCompatibility = JavaVersion.VERSION_1_8 + } + compileOptions { + sourceCompatibility = JavaVersion.VERSION_1_8 + targetCompatibility = JavaVersion.VERSION_1_8 + } + kotlinOptions { jvmTarget = "1.8" } + buildFeatures { compose = true } + composeOptions { kotlinCompilerExtensionVersion = "1.4.3" } + packaging { resources { excludes += "/META-INF/{AL2.0,LGPL2.1}" } } +} + +dependencies { + implementation("androidx.core:core-ktx:1.9.0") + implementation("androidx.lifecycle:lifecycle-runtime-ktx:2.6.1") + implementation("androidx.activity:activity-compose:1.7.0") + implementation(platform("androidx.compose:compose-bom:2023.03.00")) + implementation("androidx.compose.ui:ui") + implementation("androidx.compose.ui:ui-graphics") + implementation("androidx.compose.ui:ui-tooling-preview") + implementation("androidx.compose.material3:material3") + implementation("androidx.appcompat:appcompat:1.6.1") + implementation("androidx.camera:camera-core:1.3.0-rc02") + implementation("androidx.constraintlayout:constraintlayout:2.2.0-alpha12") + implementation("com.facebook.fbjni:fbjni:0.5.1") + implementation("com.google.code.gson:gson:2.8.6") + implementation(files("libs/executorch-llama.aar")) + implementation("com.google.android.material:material:1.12.0") + implementation("androidx.activity:activity:1.9.0") + testImplementation("junit:junit:4.13.2") + androidTestImplementation("androidx.test.ext:junit:1.1.5") + androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1") + androidTestImplementation(platform("androidx.compose:compose-bom:2023.03.00")) + androidTestImplementation("androidx.compose.ui:ui-test-junit4") + debugImplementation("androidx.compose.ui:ui-tooling") + debugImplementation("androidx.compose.ui:ui-test-manifest") +} + +tasks.register("setup") { + doFirst { + exec { + commandLine("sh", "examples/demo-apps/android/LlamaDemo/setup.sh") + workingDir("../../../../../") } - buildFeatures { - viewBinding = true + } +} + +tasks.register("setupQnn") { + doFirst { + exec { + commandLine("sh", "examples/demo-apps/android/LlamaDemo/setup-with-qnn.sh") + workingDir("../../../../../") } + } } -dependencies { - implementation("androidx.appcompat:appcompat:1.6.1") - implementation("androidx.constraintlayout:constraintlayout:2.1.4") - implementation("com.facebook.fbjni:fbjni:0.5.1") - implementation(files("libs/executorch.aar")) - testImplementation("junit:junit:4.13.2") - androidTestImplementation("androidx.test.ext:junit:1.1.5") - androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1") +tasks.register("download_prebuilt_lib") { + doFirst { + exec { + commandLine("sh", "examples/demo-apps/android/LlamaDemo/download_prebuilt_lib.sh") + workingDir("../../../../../") + } + } } diff --git a/torchchat/edge/android/torchchat/app/proguard-rules.pro b/torchchat/edge/android/torchchat/app/proguard-rules.pro index f1b424510..481bb4348 100644 --- a/torchchat/edge/android/torchchat/app/proguard-rules.pro +++ b/torchchat/edge/android/torchchat/app/proguard-rules.pro @@ -18,4 +18,4 @@ # If you keep the line number information, uncomment this to # hide the original source file name. -#-renamesourcefileattribute SourceFile +#-renamesourcefileattribute SourceFile \ No newline at end of file diff --git a/torchchat/edge/android/torchchat/app/src/androidTest/java/org/pytorch/torchchat/LlamaModuleTest.java b/torchchat/edge/android/torchchat/app/src/androidTest/java/org/pytorch/torchchat/LlamaModuleTest.java deleted file mode 100644 index e3b4e5019..000000000 --- a/torchchat/edge/android/torchchat/app/src/androidTest/java/org/pytorch/torchchat/LlamaModuleTest.java +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ -package org.pytorch.torchchat; - -import androidx.test.ext.junit.runners.AndroidJUnit4; - -import org.junit.Test; -import org.junit.runner.RunWith; -import org.pytorch.executorch.LlamaCallback; -import org.pytorch.executorch.LlamaModule; - -import static org.junit.Assert.*; - -/** - * Instrumented test, which will execute on an Android device. - * - * @see Testing documentation - */ -@RunWith(AndroidJUnit4.class) -public class LlamaModuleTest { - @Test - public void LlamaModule() { - LlamaModule module = new LlamaModule("/data/local/tmp/llama/model.pte", "/data/local/tmp/llama/tokenizer.bin", 0.8f); - assertEquals(module.load(), 0); - MyLlamaCallback callback = new MyLlamaCallback(); - // Note: module.generate() is synchronous. Callback happens within the same thread as - // generate() so when generate() returns, all callbacks are invoked. - assertEquals(module.generate("Hey", callback), 0); - assertNotEquals("", callback.result); - } -} - -/** - * LlamaCallback for testing. - * - * Note: onResult() and onStats() are invoked within the same thread as LlamaModule.generate() - * - * @see LlamaCallback interface guide - */ -class MyLlamaCallback implements LlamaCallback { - String result = ""; - @Override - public void onResult(String s) { - result += s; - } - - @Override - public void onStats(float v) { - - } -} diff --git a/torchchat/edge/android/torchchat/app/src/androidTest/java/org/pytorch/torchchat/PerfTest.java b/torchchat/edge/android/torchchat/app/src/androidTest/java/org/pytorch/torchchat/PerfTest.java new file mode 100644 index 000000000..e31eb7168 --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/androidTest/java/org/pytorch/torchchat/PerfTest.java @@ -0,0 +1,82 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.torchchat; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; + +import android.os.Bundle; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import androidx.test.platform.app.InstrumentationRegistry; +import java.io.File; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.pytorch.executorch.LlamaCallback; +import org.pytorch.executorch.LlamaModule; + +@RunWith(AndroidJUnit4.class) +public class PerfTest implements LlamaCallback { + + private static final String RESOURCE_PATH = "/data/local/tmp/llama/"; + private static final String TOKENIZER_BIN = "tokenizer.bin"; + + private final List results = new ArrayList<>(); + private final List tokensPerSecond = new ArrayList<>(); + + @Test + public void testTokensPerSecond() { + String tokenizerPath = RESOURCE_PATH + TOKENIZER_BIN; + // Find out the model name + File directory = new File(RESOURCE_PATH); + Arrays.stream(directory.listFiles()) + .filter(file -> file.getName().endsWith(".pte")) + .forEach( + model -> { + LlamaModule mModule = new LlamaModule(model.getPath(), tokenizerPath, 0.8f); + // Print the model name because there might be more than one of them + report("ModelName", model.getName()); + + int loadResult = mModule.load(); + // Check that the model can be load successfully + assertEquals(0, loadResult); + + // Run a testing prompt + mModule.generate("How do you do! I'm testing llama2 on mobile device", PerfTest.this); + assertFalse(tokensPerSecond.isEmpty()); + + final Float tps = tokensPerSecond.get(tokensPerSecond.size() - 1); + report("TPS", tps); + }); + } + + @Override + public void onResult(String result) { + results.add(result); + } + + @Override + public void onStats(float tps) { + tokensPerSecond.add(tps); + } + + private void report(final String metric, final Float value) { + Bundle bundle = new Bundle(); + bundle.putFloat(metric, value); + InstrumentationRegistry.getInstrumentation().sendStatus(0, bundle); + } + + private void report(final String key, final String value) { + Bundle bundle = new Bundle(); + bundle.putString(key, value); + InstrumentationRegistry.getInstrumentation().sendStatus(0, bundle); + } +} diff --git a/torchchat/edge/android/torchchat/app/src/main/AndroidManifest.xml b/torchchat/edge/android/torchchat/app/src/main/AndroidManifest.xml index 399bd1cbe..9fe80977a 100644 --- a/torchchat/edge/android/torchchat/app/src/main/AndroidManifest.xml +++ b/torchchat/edge/android/torchchat/app/src/main/AndroidManifest.xml @@ -1,35 +1,61 @@ - + xmlns:tools="http://schemas.android.com/tools" + package="org.pytorch.torchchat"> + + + + + + + + + android:theme="@style/Theme.AppCompat.Light.NoActionBar" + tools:targetApi="34"> + + + + + + android:theme="@style/Theme.AppCompat.Light.NoActionBar"> + + + + + + + diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/AppLog.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/AppLog.java new file mode 100644 index 000000000..cd8462aab --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/AppLog.java @@ -0,0 +1,49 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.torchchat; + +import java.text.SimpleDateFormat; +import java.util.Date; +import java.util.Locale; + +public class AppLog { + private final Long timestamp; + private final String message; + + public AppLog(String message) { + this.timestamp = getCurrentTimeStamp(); + this.message = message; + } + + public Long getTimestamp() { + return timestamp; + } + + public String getMessage() { + return message; + } + + public String getFormattedLog() { + return "[" + getFormattedTimeStamp() + "] " + message; + } + + private Long getCurrentTimeStamp() { + return System.currentTimeMillis(); + } + + private String getFormattedTimeStamp() { + return formatDate(timestamp); + } + + private String formatDate(long milliseconds) { + SimpleDateFormat formatter = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.getDefault()); + Date date = new Date(milliseconds); + return formatter.format(date); + } +} diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/DemoSharedPreferences.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/DemoSharedPreferences.java new file mode 100644 index 000000000..6d672a5b7 --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/DemoSharedPreferences.java @@ -0,0 +1,90 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.torchchat; + +import android.content.Context; +import android.content.SharedPreferences; +import com.google.gson.Gson; +import com.google.gson.reflect.TypeToken; +import java.lang.reflect.Type; +import java.util.ArrayList; + +public class DemoSharedPreferences { + Context context; + SharedPreferences sharedPreferences; + + public DemoSharedPreferences(Context context) { + this.context = context; + this.sharedPreferences = getSharedPrefs(); + } + + private SharedPreferences getSharedPrefs() { + return context.getSharedPreferences( + context.getString(R.string.demo_pref_file_key), Context.MODE_PRIVATE); + } + + public String getSavedMessages() { + return sharedPreferences.getString(context.getString(R.string.saved_messages_json_key), ""); + } + + public void addMessages(MessageAdapter messageAdapter) { + SharedPreferences.Editor editor = sharedPreferences.edit(); + Gson gson = new Gson(); + String msgJSON = gson.toJson(messageAdapter.getSavedMessages()); + editor.putString(context.getString(R.string.saved_messages_json_key), msgJSON); + editor.apply(); + } + + public void removeExistingMessages() { + SharedPreferences.Editor editor = sharedPreferences.edit(); + editor.remove(context.getString(R.string.saved_messages_json_key)); + editor.apply(); + } + + public void addSettings(SettingsFields settingsFields) { + SharedPreferences.Editor editor = sharedPreferences.edit(); + Gson gson = new Gson(); + String settingsJSON = gson.toJson(settingsFields); + editor.putString(context.getString(R.string.settings_json_key), settingsJSON); + editor.apply(); + } + + public String getSettings() { + return sharedPreferences.getString(context.getString(R.string.settings_json_key), ""); + } + + public void saveLogs() { + SharedPreferences.Editor editor = sharedPreferences.edit(); + Gson gson = new Gson(); + String msgJSON = gson.toJson(ETLogging.getInstance().getLogs()); + editor.putString(context.getString(R.string.logs_json_key), msgJSON); + editor.apply(); + } + + public void removeExistingLogs() { + SharedPreferences.Editor editor = sharedPreferences.edit(); + editor.remove(context.getString(R.string.logs_json_key)); + editor.apply(); + } + + public ArrayList getSavedLogs() { + String logsJSONString = + sharedPreferences.getString(context.getString(R.string.logs_json_key), null); + if (logsJSONString == null || logsJSONString.isEmpty()) { + return new ArrayList<>(); + } + Gson gson = new Gson(); + Type type = new TypeToken>() {}.getType(); + ArrayList appLogs = gson.fromJson(logsJSONString, type); + if (appLogs == null) { + return new ArrayList<>(); + } + return appLogs; + } +} diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ETImage.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ETImage.java new file mode 100644 index 000000000..28f0752a4 --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ETImage.java @@ -0,0 +1,126 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.torchchat; + +import android.content.ContentResolver; +import android.graphics.Bitmap; +import android.graphics.BitmapFactory; +import android.graphics.Color; +import android.net.Uri; +import androidx.annotation.Nullable; +import java.io.FileNotFoundException; +import java.io.InputStream; + +public class ETImage { + private int width; + private int height; + private final byte[] bytes; + private final Uri uri; + private final ContentResolver contentResolver; + + ETImage(ContentResolver contentResolver, Uri uri) { + this.contentResolver = contentResolver; + this.uri = uri; + bytes = getBytesFromImageURI(uri); + } + + public int getWidth() { + return width; + } + + public int getHeight() { + return height; + } + + public Uri getUri() { + return uri; + } + + public byte[] getBytes() { + return bytes; + } + + public int[] getInts() { + // We need to convert the byte array to an int array because + // the runner expects an int array as input. + int[] intArray = new int[bytes.length]; + for (int i = 0; i < bytes.length; i++) { + intArray[i] = (bytes[i++] & 0xFF); + } + return intArray; + } + + private byte[] getBytesFromImageURI(Uri uri) { + try { + int RESIZED_IMAGE_WIDTH = 336; + Bitmap bitmap = resizeImage(uri, RESIZED_IMAGE_WIDTH); + + if (bitmap == null) { + ETLogging.getInstance().log("Unable to get bytes from Image URI. Bitmap is null"); + return new byte[0]; + } + + width = bitmap.getWidth(); + height = bitmap.getHeight(); + + byte[] rgbValues = new byte[width * height * 3]; + + for (int y = 0; y < height; y++) { + for (int x = 0; x < width; x++) { + // Get the color of the current pixel + int color = bitmap.getPixel(x, y); + + // Extract the RGB values from the color + int red = Color.red(color); + int green = Color.green(color); + int blue = Color.blue(color); + + // Store the RGB values in the byte array + rgbValues[y * width + x] = (byte) red; + rgbValues[(y * width + x) + height * width] = (byte) green; + rgbValues[(y * width + x) + 2 * height * width] = (byte) blue; + } + } + return rgbValues; + } catch (FileNotFoundException e) { + throw new RuntimeException(e); + } + } + + @Nullable + private Bitmap resizeImage(Uri uri, int maxLength) throws FileNotFoundException { + InputStream inputStream = contentResolver.openInputStream(uri); + if (inputStream == null) { + ETLogging.getInstance().log("Unable to resize image, input streams is null"); + return null; + } + Bitmap bitmap = BitmapFactory.decodeStream(inputStream); + if (bitmap == null) { + ETLogging.getInstance().log("Unable to resize image, bitmap during decode stream is null"); + return null; + } + + float aspectRatio; + int finalWidth, finalHeight; + + if (bitmap.getWidth() > bitmap.getHeight()) { + // width > height --> width = maxLength, height scale with aspect ratio + aspectRatio = bitmap.getWidth() / (float) bitmap.getHeight(); + finalWidth = maxLength; + finalHeight = Math.round(maxLength / aspectRatio); + } else { + // height >= width --> height = maxLength, width scale with aspect ratio + aspectRatio = bitmap.getHeight() / (float) bitmap.getWidth(); + finalHeight = maxLength; + finalWidth = Math.round(maxLength / aspectRatio); + } + + return Bitmap.createScaledBitmap(bitmap, finalWidth, finalHeight, false); + } +} diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ETLogging.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ETLogging.java new file mode 100644 index 000000000..c13abd02a --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ETLogging.java @@ -0,0 +1,54 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.torchchat; + +import android.app.Application; +import android.util.Log; +import java.util.ArrayList; + +public class ETLogging extends Application { + private static ETLogging singleton; + + private ArrayList logs; + private DemoSharedPreferences mDemoSharedPreferences; + + @Override + public void onCreate() { + super.onCreate(); + singleton = this; + mDemoSharedPreferences = new DemoSharedPreferences(this.getApplicationContext()); + logs = mDemoSharedPreferences.getSavedLogs(); + if (logs == null) { // We don't have existing sharedPreference stored + logs = new ArrayList<>(); + } + } + + public static ETLogging getInstance() { + return singleton; + } + + public void log(String message) { + AppLog appLog = new AppLog(message); + logs.add(appLog); + Log.d("ETLogging", appLog.getMessage()); + } + + public ArrayList getLogs() { + return logs; + } + + public void clearLogs() { + logs.clear(); + mDemoSharedPreferences.removeExistingLogs(); + } + + public void saveLogs() { + mDemoSharedPreferences.saveLogs(); + } +} diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/LlmBenchmarkRunner.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/LlmBenchmarkRunner.java new file mode 100644 index 000000000..d0a7cc677 --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/LlmBenchmarkRunner.java @@ -0,0 +1,223 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.torchchat; + +import android.app.Activity; +import android.app.ActivityManager; +import android.content.Intent; +import android.os.Build; +import android.os.Bundle; +import android.util.Log; +import android.widget.TextView; +import androidx.annotation.NonNull; +import com.google.gson.Gson; +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +public class LlmBenchmarkRunner extends Activity implements ModelRunnerCallback { + ModelRunner mModelRunner; + + String mPrompt; + TextView mTextView; + StatsDump mStatsDump; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_benchmarking); + mTextView = findViewById(R.id.log_view); + + Intent intent = getIntent(); + + File modelDir = new File(intent.getStringExtra("model_dir")); + File model = + Arrays.stream(modelDir.listFiles()) + .filter(file -> file.getName().endsWith(".pte")) + .findFirst() + .get(); + String tokenizerPath = intent.getStringExtra("tokenizer_path"); + + float temperature = intent.getFloatExtra("temperature", 0.8f); + mPrompt = intent.getStringExtra("prompt"); + if (mPrompt == null) { + mPrompt = "The ultimate answer"; + } + + mStatsDump = new StatsDump(); + mStatsDump.modelName = model.getName().replace(".pte", ""); + mModelRunner = new ModelRunner(model.getPath(), tokenizerPath, temperature, this); + mStatsDump.loadStart = System.nanoTime(); + } + + @Override + public void onModelLoaded(int status) { + mStatsDump.loadEnd = System.nanoTime(); + mStatsDump.loadStatus = status; + if (status != 0) { + Log.e("LlmBenchmarkRunner", "Loaded failed: " + status); + onGenerationStopped(); + return; + } + mStatsDump.generateStart = System.nanoTime(); + mModelRunner.generate(mPrompt); + } + + @Override + public void onTokenGenerated(String token) { + runOnUiThread( + () -> { + mTextView.append(token); + }); + } + + @Override + public void onStats(String stats) { + mStatsDump.tokens = stats; + } + + @Override + public void onGenerationStopped() { + mStatsDump.generateEnd = System.nanoTime(); + runOnUiThread( + () -> { + mTextView.append(mStatsDump.toString()); + }); + + final BenchmarkMetric.BenchmarkModel benchmarkModel = + BenchmarkMetric.extractBackendAndQuantization(mStatsDump.modelName); + final List results = new ArrayList<>(); + // The list of metrics we have atm includes: + // Load status + results.add(new BenchmarkMetric(benchmarkModel, "load_status", mStatsDump.loadStatus, 0)); + // Model load time + results.add( + new BenchmarkMetric( + benchmarkModel, + "model_load_time(ms)", + (mStatsDump.loadEnd - mStatsDump.loadStart) * 1e-6, + 0.0f)); + // LLM generate time + results.add( + new BenchmarkMetric( + benchmarkModel, + "generate_time(ms)", + (mStatsDump.generateEnd - mStatsDump.generateStart) * 1e-6, + 0.0f)); + // Token per second + results.add( + new BenchmarkMetric(benchmarkModel, "token_per_sec", extractTPS(mStatsDump.tokens), 0.0f)); + + try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) { + Gson gson = new Gson(); + writer.write(gson.toJson(results)); + } catch (IOException e) { + e.printStackTrace(); + } + } + + private double extractTPS(final String tokens) { + final Matcher m = Pattern.compile("\\d+\\.?\\d*").matcher(tokens); + if (m.find()) { + return Double.parseDouble(m.group()); + } else { + return 0.0f; + } + } +} + +class BenchmarkMetric { + public static class BenchmarkModel { + // The model name, i.e. stories110M + String name; + String backend; + String quantization; + + public BenchmarkModel(final String name, final String backend, final String quantization) { + this.name = name; + this.backend = backend; + this.quantization = quantization; + } + } + + BenchmarkModel benchmarkModel; + + // The metric name, i.e. TPS + String metric; + + // The actual value and the option target value + double actualValue; + double targetValue; + + public static class DeviceInfo { + // Let's see which information we want to include here + final String device = Build.BRAND; + // The phone model and Android release version + final String arch = Build.MODEL; + final String os = "Android " + Build.VERSION.RELEASE; + final long totalMem = new ActivityManager.MemoryInfo().totalMem; + final long availMem = new ActivityManager.MemoryInfo().availMem; + } + + DeviceInfo deviceInfo = new DeviceInfo(); + + public BenchmarkMetric( + final BenchmarkModel benchmarkModel, + final String metric, + final double actualValue, + final double targetValue) { + this.benchmarkModel = benchmarkModel; + this.metric = metric; + this.actualValue = actualValue; + this.targetValue = targetValue; + } + + // TODO (huydhn): Figure out a way to extract the backend and quantization information from + // the .pte model itself instead of parsing its name + public static BenchmarkMetric.BenchmarkModel extractBackendAndQuantization(final String model) { + final Matcher m = + Pattern.compile("(?\\w+)_(?\\w+)_(?\\w+)").matcher(model); + if (m.matches()) { + return new BenchmarkMetric.BenchmarkModel( + m.group("name"), m.group("backend"), m.group("quantization")); + } else { + return new BenchmarkMetric.BenchmarkModel(model, "", ""); + } + } +} + +class StatsDump { + int loadStatus; + long loadStart; + long loadEnd; + long generateStart; + long generateEnd; + String tokens; + String modelName; + + @NonNull + @Override + public String toString() { + return "loadStart: " + + loadStart + + "\nloadEnd: " + + loadEnd + + "\ngenerateStart: " + + generateStart + + "\ngenerateEnd: " + + generateEnd + + "\n" + + tokens; + } +} diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/LogsActivity.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/LogsActivity.java new file mode 100644 index 000000000..f83d691c8 --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/LogsActivity.java @@ -0,0 +1,92 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.torchchat; + +import android.app.AlertDialog; +import android.content.DialogInterface; +import android.os.Build; +import android.os.Bundle; +import android.widget.ImageButton; +import android.widget.ListView; +import androidx.appcompat.app.AppCompatActivity; +import androidx.core.content.ContextCompat; +import androidx.core.graphics.Insets; +import androidx.core.view.ViewCompat; +import androidx.core.view.WindowInsetsCompat; + +public class LogsActivity extends AppCompatActivity { + + private LogsAdapter mLogsAdapter; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_logs); + if (Build.VERSION.SDK_INT >= 21) { + getWindow().setStatusBarColor(ContextCompat.getColor(this, R.color.status_bar)); + getWindow().setNavigationBarColor(ContextCompat.getColor(this, R.color.nav_bar)); + } + ViewCompat.setOnApplyWindowInsetsListener( + requireViewById(R.id.main), + (v, insets) -> { + Insets systemBars = insets.getInsets(WindowInsetsCompat.Type.systemBars()); + v.setPadding(systemBars.left, systemBars.top, systemBars.right, systemBars.bottom); + return insets; + }); + + setupLogs(); + setupClearLogsButton(); + } + + @Override + public void onResume() { + super.onResume(); + mLogsAdapter.clear(); + mLogsAdapter.addAll(ETLogging.getInstance().getLogs()); + mLogsAdapter.notifyDataSetChanged(); + } + + private void setupLogs() { + ListView mLogsListView = requireViewById(R.id.logsListView); + mLogsAdapter = new LogsAdapter(this, R.layout.logs_message); + + mLogsListView.setAdapter(mLogsAdapter); + mLogsAdapter.addAll(ETLogging.getInstance().getLogs()); + mLogsAdapter.notifyDataSetChanged(); + } + + private void setupClearLogsButton() { + ImageButton clearLogsButton = requireViewById(R.id.clearLogsButton); + clearLogsButton.setOnClickListener( + view -> { + new AlertDialog.Builder(this) + .setTitle("Delete Logs History") + .setMessage("Do you really want to delete logs history?") + .setIcon(android.R.drawable.ic_dialog_alert) + .setPositiveButton( + android.R.string.yes, + new DialogInterface.OnClickListener() { + public void onClick(DialogInterface dialog, int whichButton) { + // Clear the messageAdapter and sharedPreference + ETLogging.getInstance().clearLogs(); + mLogsAdapter.clear(); + mLogsAdapter.notifyDataSetChanged(); + } + }) + .setNegativeButton(android.R.string.no, null) + .show(); + }); + } + + @Override + protected void onDestroy() { + super.onDestroy(); + ETLogging.getInstance().saveLogs(); + } +} diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/LogsAdapter.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/LogsAdapter.java new file mode 100644 index 000000000..a793644ed --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/LogsAdapter.java @@ -0,0 +1,45 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.torchchat; + +import android.view.LayoutInflater; +import android.view.View; +import android.view.ViewGroup; +import android.widget.ArrayAdapter; +import android.widget.TextView; +import androidx.annotation.NonNull; +import java.util.Objects; + +public class LogsAdapter extends ArrayAdapter { + public LogsAdapter(android.content.Context context, int resource) { + super(context, resource); + } + + static class ViewHolder { + private TextView logTextView; + } + + @NonNull + @Override + public View getView(int position, View convertView, @NonNull ViewGroup parent) { + ViewHolder mViewHolder = null; + + String logMessage = Objects.requireNonNull(getItem(position)).getFormattedLog(); + + if (convertView == null || convertView.getTag() == null) { + mViewHolder = new ViewHolder(); + convertView = LayoutInflater.from(getContext()).inflate(R.layout.logs_message, parent, false); + mViewHolder.logTextView = convertView.requireViewById(R.id.logsTextView); + } else { + mViewHolder = (ViewHolder) convertView.getTag(); + } + mViewHolder.logTextView.setText(logMessage); + return convertView; + } +} diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/MainActivity.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/MainActivity.java index 4be7303c3..3c7dd55c7 100644 --- a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/MainActivity.java +++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/MainActivity.java @@ -8,245 +8,768 @@ package org.pytorch.torchchat; -import android.app.Activity; +import android.Manifest; import android.app.ActivityManager; import android.app.AlertDialog; -import android.content.Context; +import android.content.ContentResolver; +import android.content.ContentValues; +import android.content.Intent; +import android.content.pm.PackageManager; +import android.net.Uri; +import android.os.Build; import android.os.Bundle; -import android.widget.Button; +import android.os.Handler; +import android.os.Looper; +import android.os.Process; +import android.provider.MediaStore; +import android.system.ErrnoException; +import android.system.Os; +import android.util.Log; +import android.view.View; import android.widget.EditText; import android.widget.ImageButton; +import android.widget.ImageView; +import android.widget.LinearLayout; import android.widget.ListView; +import android.widget.TextView; import android.widget.Toast; - -import java.io.File; +import androidx.activity.result.ActivityResultLauncher; +import androidx.activity.result.PickVisualMediaRequest; +import androidx.activity.result.contract.ActivityResultContracts; +import androidx.annotation.NonNull; +import androidx.appcompat.app.AppCompatActivity; +import androidx.constraintlayout.widget.ConstraintLayout; +import androidx.core.app.ActivityCompat; +import androidx.core.content.ContextCompat; +import com.google.gson.Gson; +import com.google.gson.reflect.TypeToken; +import java.lang.reflect.Type; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; import org.pytorch.executorch.LlamaCallback; import org.pytorch.executorch.LlamaModule; -public class MainActivity extends Activity implements Runnable, LlamaCallback { - private EditText mEditTextMessage; - private Button mSendButton; - private ImageButton mModelButton; - private ListView mMessagesView; - private MessageAdapter mMessageAdapter; - private LlamaModule mModule = null; - private Message mResultMessage = null; - - private String mModelFilePath = ""; - private String mTokenizerFilePath = ""; - - @Override - public void onResult(String result) { - if (result.startsWith("<") && result.endsWith(">")) { - return; - } +public class MainActivity extends AppCompatActivity implements Runnable, LlamaCallback { + private EditText mEditTextMessage; + private ImageButton mSendButton; + private ImageButton mGalleryButton; + private ImageButton mCameraButton; + private ListView mMessagesView; + private MessageAdapter mMessageAdapter; + private LlamaModule mModule = null; + private Message mResultMessage = null; + private ImageButton mSettingsButton; + private TextView mMemoryView; + private ActivityResultLauncher mPickGallery; + private ActivityResultLauncher mCameraRoll; + private List mSelectedImageUri; + private ConstraintLayout mMediaPreviewConstraintLayout; + private LinearLayout mAddMediaLayout; + private static final int MAX_NUM_OF_IMAGES = 5; + private static final int REQUEST_IMAGE_CAPTURE = 1; + private Uri cameraImageUri; + private DemoSharedPreferences mDemoSharedPreferences; + private SettingsFields mCurrentSettingsFields; + private Handler mMemoryUpdateHandler; + private Runnable memoryUpdater; + private int promptID = 0; + private long startPos = 0; + private static final int CONVERSATION_HISTORY_MESSAGE_LOOKBACK = 2; + private Executor executor; + + @Override + public void onResult(String result) { + if (result.equals(PromptFormat.getStopToken(mCurrentSettingsFields.getModelType()))) { + return; + } + if (result.equals("\n\n") || result.equals("\n")) { + if (!mResultMessage.getText().isEmpty()) { mResultMessage.appendText(result); run(); + } + } else { + mResultMessage.appendText(result); + run(); } + } - @Override - public void onStats(float tps) { - runOnUiThread( - () -> { - if (mResultMessage != null) { - mResultMessage.setTokensPerSecond(tps); - mMessageAdapter.notifyDataSetChanged(); - } - }); - } - - private static String[] listLocalFile(String path, String suffix) { - File directory = new File(path); - if (directory.exists() && directory.isDirectory()) { - File[] files = directory.listFiles((dir, name) -> name.toLowerCase().endsWith(suffix)); - String[] result = new String[files.length]; - for (int i = 0; i < files.length; i++) { - if (files[i].isFile() && files[i].getName().endsWith(suffix)) { - result[i] = files[i].getAbsolutePath(); - } - } - return result; - } - return null; - } - - private void setLocalModel(String modelPath, String tokenizerPath) { - Message modelLoadingMessage = new Message("Loading model...", false); - runOnUiThread( - () -> { - mSendButton.setEnabled(false); - mMessageAdapter.add(modelLoadingMessage); - mMessageAdapter.notifyDataSetChanged(); - }); - long runStartTime = System.currentTimeMillis(); - mModule = new LlamaModule(modelPath, tokenizerPath, 0.8f); - int loadResult = mModule.load(); - if (loadResult != 0) { - AlertDialog.Builder builder = new AlertDialog.Builder(this); - builder.setTitle("Load failed: " + loadResult); - runOnUiThread( - () -> { - AlertDialog alert = builder.create(); - alert.show(); - }); - } + @Override + public void onStats(float tps) { + runOnUiThread( + () -> { + if (mResultMessage != null) { + mResultMessage.setTokensPerSecond(tps); + mMessageAdapter.notifyDataSetChanged(); + } + }); + } + + private void setLocalModel(String modelPath, String tokenizerPath, float temperature) { + Message modelLoadingMessage = new Message("Loading model...", false, MessageType.SYSTEM, 0); + ETLogging.getInstance().log("Loading model " + modelPath + " with tokenizer " + tokenizerPath); + runOnUiThread( + () -> { + mSendButton.setEnabled(false); + mMessageAdapter.add(modelLoadingMessage); + mMessageAdapter.notifyDataSetChanged(); + }); + if (mModule != null) { + ETLogging.getInstance().log("Start deallocating existing module instance"); + mModule.resetNative(); + mModule = null; + ETLogging.getInstance().log("Completed deallocating existing module instance"); + } + long runStartTime = System.currentTimeMillis(); + mModule = + new LlamaModule( + ModelUtils.getModelCategory(mCurrentSettingsFields.getModelType()), + modelPath, + tokenizerPath, + temperature); + int loadResult = mModule.load(); + long loadDuration = System.currentTimeMillis() - runStartTime; + String modelLoadError = ""; + String modelInfo = ""; + if (loadResult != 0) { + // TODO: Map the error code to a reason to let the user know why model loading failed + modelInfo = "*Model could not load (Error Code: " + loadResult + ")*" + "\n"; + loadDuration = 0; + AlertDialog.Builder builder = new AlertDialog.Builder(this); + builder.setTitle("Load failed: " + loadResult); + runOnUiThread( + () -> { + AlertDialog alert = builder.create(); + alert.show(); + }); + } else { + String[] segments = modelPath.split("/"); + String pteName = segments[segments.length - 1]; + segments = tokenizerPath.split("/"); + String tokenizerName = segments[segments.length - 1]; + modelInfo = + "Successfully loaded model. " + + pteName + + " and tokenizer " + + tokenizerName + + " in " + + (float) loadDuration / 1000 + + " sec." + + " You can send text or image for inference"; + + if (mCurrentSettingsFields.getModelType() == ModelType.LLAVA_1_5) { + ETLogging.getInstance().log("Llava start prefill prompt"); + startPos = mModule.prefillPrompt(PromptFormat.getLlavaPresetPrompt(), 0, 1, 0); + ETLogging.getInstance().log("Llava completes prefill prompt"); + } + } + + Message modelLoadedMessage = new Message(modelInfo, false, MessageType.SYSTEM, 0); + + String modelLoggingInfo = + modelLoadError + + "Model path: " + + modelPath + + "\nTokenizer path: " + + tokenizerPath + + "\nTemperature: " + + temperature + + "\nModel loaded time: " + + loadDuration + + " ms"; + ETLogging.getInstance().log("Load complete. " + modelLoggingInfo); + + runOnUiThread( + () -> { + mSendButton.setEnabled(true); + mMessageAdapter.remove(modelLoadingMessage); + mMessageAdapter.add(modelLoadedMessage); + mMessageAdapter.notifyDataSetChanged(); + }); + } + + private void loadLocalModelAndParameters( + String modelFilePath, String tokenizerFilePath, float temperature) { + Runnable runnable = + new Runnable() { + @Override + public void run() { + setLocalModel(modelFilePath, tokenizerFilePath, temperature); + } + }; + new Thread(runnable).start(); + } + + private void populateExistingMessages(String existingMsgJSON) { + Gson gson = new Gson(); + Type type = new TypeToken>() {}.getType(); + ArrayList savedMessages = gson.fromJson(existingMsgJSON, type); + for (Message msg : savedMessages) { + mMessageAdapter.add(msg); + } + mMessageAdapter.notifyDataSetChanged(); + } + + private int setPromptID() { + + return mMessageAdapter.getMaxPromptID() + 1; + } + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_main); + + if (Build.VERSION.SDK_INT >= 21) { + getWindow().setStatusBarColor(ContextCompat.getColor(this, R.color.status_bar)); + getWindow().setNavigationBarColor(ContextCompat.getColor(this, R.color.nav_bar)); + } + + try { + Os.setenv("ADSP_LIBRARY_PATH", getApplicationInfo().nativeLibraryDir, true); + } catch (ErrnoException e) { + finish(); + } + + mEditTextMessage = requireViewById(R.id.editTextMessage); + mSendButton = requireViewById(R.id.sendButton); + mSendButton.setEnabled(false); + mMessagesView = requireViewById(R.id.messages_view); + mMessageAdapter = new MessageAdapter(this, R.layout.sent_message, new ArrayList()); + mMessagesView.setAdapter(mMessageAdapter); + mDemoSharedPreferences = new DemoSharedPreferences(this.getApplicationContext()); + String existingMsgJSON = mDemoSharedPreferences.getSavedMessages(); + if (!existingMsgJSON.isEmpty()) { + populateExistingMessages(existingMsgJSON); + promptID = setPromptID(); + } + mSettingsButton = requireViewById(R.id.settings); + mSettingsButton.setOnClickListener( + view -> { + Intent myIntent = new Intent(MainActivity.this, SettingsActivity.class); + MainActivity.this.startActivity(myIntent); + }); + + mCurrentSettingsFields = new SettingsFields(); + mMemoryUpdateHandler = new Handler(Looper.getMainLooper()); + onModelRunStopped(); + setupMediaButton(); + setupGalleryPicker(); + setupCameraRoll(); + startMemoryUpdate(); + setupShowLogsButton(); + executor = Executors.newSingleThreadExecutor(); + } - long loadDuration = System.currentTimeMillis() - runStartTime; - String modelInfo = - "Model path: " - + modelPath - + "\nTokenizer path: " - + tokenizerPath - + "\nModel loaded time: " - + loadDuration - + " ms"; - Message modelLoadedMessage = new Message(modelInfo, false); - runOnUiThread( - () -> { - mSendButton.setEnabled(true); - mMessageAdapter.remove(modelLoadingMessage); - mMessageAdapter.add(modelLoadedMessage); - mMessageAdapter.notifyDataSetChanged(); - }); - } - - private String memoryInfo() { - final ActivityManager am = (ActivityManager) getSystemService(Context.ACTIVITY_SERVICE); - ActivityManager.MemoryInfo memInfo = new ActivityManager.MemoryInfo(); - am.getMemoryInfo(memInfo); - return "Total RAM: " - + Math.floorDiv(memInfo.totalMem, 1000000) - + " MB. Available RAM: " - + Math.floorDiv(memInfo.availMem, 1000000) - + " MB."; - } - - private void modelDialog() { - String[] pteFiles = listLocalFile("/data/local/tmp/llama/", ".pte"); - String[] binFiles = listLocalFile("/data/local/tmp/llama/", ".bin"); - String[] modelFiles = listLocalFile("/data/local/tmp/llama/", ".model"); - if (pteFiles == null || binFiles == null || modelFiles == null) { - Toast.makeText(this, - "Please create directory /data/local/tmp/llama/ first", - Toast.LENGTH_LONG).show(); - return; + @Override + protected void onPause() { + super.onPause(); + mDemoSharedPreferences.addMessages(mMessageAdapter); + } + + @Override + protected void onResume() { + super.onResume(); + // Check for if settings parameters have changed + Gson gson = new Gson(); + String settingsFieldsJSON = mDemoSharedPreferences.getSettings(); + if (!settingsFieldsJSON.isEmpty()) { + SettingsFields updatedSettingsFields = + gson.fromJson(settingsFieldsJSON, SettingsFields.class); + if (updatedSettingsFields == null) { + // Added this check, because gson.fromJson can return null + askUserToSelectModel(); + return; + } + boolean isUpdated = !mCurrentSettingsFields.equals(updatedSettingsFields); + boolean isLoadModel = updatedSettingsFields.getIsLoadModel(); + if (isUpdated) { + if (isLoadModel) { + // If users change the model file, but not pressing loadModelButton, we won't load the new + // model + checkForUpdateAndReloadModel(updatedSettingsFields); + } else { + askUserToSelectModel(); } - String[] tokenizerFiles = new String[binFiles.length + modelFiles.length]; - System.arraycopy(binFiles, 0, tokenizerFiles, 0, binFiles.length); - System.arraycopy(modelFiles, 0, tokenizerFiles, binFiles.length, modelFiles.length); - AlertDialog.Builder modelPathBuilder = new AlertDialog.Builder(this); - modelPathBuilder.setTitle("Select model path"); - AlertDialog.Builder tokenizerPathBuilder = new AlertDialog.Builder(this); - tokenizerPathBuilder.setTitle("Select tokenizer path"); - modelPathBuilder.setSingleChoiceItems( - pteFiles, - -1, - (dialog, item) -> { - mModelFilePath = pteFiles[item]; - mEditTextMessage.setText(""); - dialog.dismiss(); - tokenizerPathBuilder.create().show(); - }); - - tokenizerPathBuilder.setSingleChoiceItems( - tokenizerFiles, - -1, - (dialog, item) -> { - mTokenizerFilePath = tokenizerFiles[item]; - Runnable runnable = - new Runnable() { - @Override - public void run() { - setLocalModel(mModelFilePath, mTokenizerFilePath); - } - }; - new Thread(runnable).start(); - dialog.dismiss(); - }); - - modelPathBuilder.create().show(); - } - - @Override - protected void onCreate(Bundle savedInstanceState) { - super.onCreate(savedInstanceState); - setContentView(R.layout.activity_main); - - mEditTextMessage = findViewById(R.id.editTextMessage); - mSendButton = findViewById(R.id.sendButton); - mSendButton.setEnabled(false); - mModelButton = findViewById(R.id.modelButton); - mMessagesView = findViewById(R.id.messages_view); - mMessageAdapter = new MessageAdapter(this, R.layout.sent_message); - mMessagesView.setAdapter(mMessageAdapter); - mModelButton.setOnClickListener( - view -> { - mModule.stop(); - mMessageAdapter.clear(); - mMessageAdapter.notifyDataSetChanged(); - modelDialog(); - }); - - onModelRunStopped(); - modelDialog(); - } - - private void onModelRunStarted() { - mSendButton.setText("Stop"); - mSendButton.setOnClickListener( - view -> { - mModule.stop(); - }); - } - - private void onModelRunStopped() { - setTitle(memoryInfo()); - mSendButton.setText("Generate"); - mSendButton.setOnClickListener( - view -> { - String prompt = mEditTextMessage.getText().toString(); - mMessageAdapter.add(new Message(prompt, true)); - mMessageAdapter.notifyDataSetChanged(); - mEditTextMessage.setText(""); - mResultMessage = new Message("", false); - mMessageAdapter.add(mResultMessage); - Runnable runnable = - new Runnable() { - @Override - public void run() { - runOnUiThread( - new Runnable() { - @Override - public void run() { - onModelRunStarted(); - } - }); - - mModule.generate(prompt, MainActivity.this); - - runOnUiThread( - new Runnable() { - @Override - public void run() { - onModelRunStopped(); - } - }); - } - }; - new Thread(runnable).start(); - }); + checkForClearChatHistory(updatedSettingsFields); + // Update current to point to the latest + mCurrentSettingsFields = new SettingsFields(updatedSettingsFields); + } + } else { + askUserToSelectModel(); + } + } + + private void checkForClearChatHistory(SettingsFields updatedSettingsFields) { + if (updatedSettingsFields.getIsClearChatHistory()) { + mMessageAdapter.clear(); + mMessageAdapter.notifyDataSetChanged(); + mDemoSharedPreferences.removeExistingMessages(); + // changing to false since chat history has been cleared. + updatedSettingsFields.saveIsClearChatHistory(false); + mDemoSharedPreferences.addSettings(updatedSettingsFields); + } + } + + private void checkForUpdateAndReloadModel(SettingsFields updatedSettingsFields) { + // TODO need to add 'load model' in settings and queue loading based on that + String modelPath = updatedSettingsFields.getModelFilePath(); + String tokenizerPath = updatedSettingsFields.getTokenizerFilePath(); + double temperature = updatedSettingsFields.getTemperature(); + if (!modelPath.isEmpty() && !tokenizerPath.isEmpty()) { + if (updatedSettingsFields.getIsLoadModel() + || !modelPath.equals(mCurrentSettingsFields.getModelFilePath()) + || !tokenizerPath.equals(mCurrentSettingsFields.getTokenizerFilePath()) + || temperature != mCurrentSettingsFields.getTemperature()) { + loadLocalModelAndParameters( + updatedSettingsFields.getModelFilePath(), + updatedSettingsFields.getTokenizerFilePath(), + (float) updatedSettingsFields.getTemperature()); + updatedSettingsFields.saveLoadModelAction(false); + mDemoSharedPreferences.addSettings(updatedSettingsFields); + } + } else { + askUserToSelectModel(); + } + } + + private void askUserToSelectModel() { + String askLoadModel = + "To get started, select your desired model and tokenizer " + "from the top right corner"; + Message askLoadModelMessage = new Message(askLoadModel, false, MessageType.SYSTEM, 0); + ETLogging.getInstance().log(askLoadModel); + runOnUiThread( + () -> { + mMessageAdapter.add(askLoadModelMessage); + mMessageAdapter.notifyDataSetChanged(); + }); + } + + private void setupShowLogsButton() { + ImageButton showLogsButton = requireViewById(R.id.showLogsButton); + showLogsButton.setOnClickListener( + view -> { + Intent myIntent = new Intent(MainActivity.this, LogsActivity.class); + MainActivity.this.startActivity(myIntent); + }); + } + + private void setupMediaButton() { + mAddMediaLayout = requireViewById(R.id.addMediaLayout); + mAddMediaLayout.setVisibility(View.GONE); // We hide this initially + + ImageButton addMediaButton = requireViewById(R.id.addMediaButton); + addMediaButton.setOnClickListener( + view -> { + mAddMediaLayout.setVisibility(View.VISIBLE); + }); + + mGalleryButton = requireViewById(R.id.galleryButton); + mGalleryButton.setOnClickListener( + view -> { + // Launch the photo picker and let the user choose only images. + mPickGallery.launch( + new PickVisualMediaRequest.Builder() + .setMediaType(ActivityResultContracts.PickVisualMedia.ImageOnly.INSTANCE) + .build()); + }); + mCameraButton = requireViewById(R.id.cameraButton); + mCameraButton.setOnClickListener( + view -> { + Log.d("CameraRoll", "Check permission"); + if (ContextCompat.checkSelfPermission(MainActivity.this, Manifest.permission.CAMERA) + != PackageManager.PERMISSION_GRANTED) { + ActivityCompat.requestPermissions( + MainActivity.this, + new String[] {Manifest.permission.CAMERA}, + REQUEST_IMAGE_CAPTURE); + } else { + launchCamera(); + } + }); + } + + private void setupCameraRoll() { + // Registers a camera roll activity launcher. + mCameraRoll = + registerForActivityResult( + new ActivityResultContracts.TakePicture(), + result -> { + if (result && cameraImageUri != null) { + Log.d("CameraRoll", "Photo saved to uri: " + cameraImageUri); + mAddMediaLayout.setVisibility(View.GONE); + List uris = new ArrayList<>(); + uris.add(cameraImageUri); + showMediaPreview(uris); + } else { + // Delete the temp image file based on the url since the photo is not successfully + // taken + if (cameraImageUri != null) { + ContentResolver contentResolver = MainActivity.this.getContentResolver(); + contentResolver.delete(cameraImageUri, null, null); + Log.d("CameraRoll", "No photo taken. Delete temp uri"); + } + } + }); + mMediaPreviewConstraintLayout = requireViewById(R.id.mediaPreviewConstraintLayout); + ImageButton mediaPreviewCloseButton = requireViewById(R.id.mediaPreviewCloseButton); + mediaPreviewCloseButton.setOnClickListener( + view -> { + mMediaPreviewConstraintLayout.setVisibility(View.GONE); + mSelectedImageUri = null; + }); + + ImageButton addMoreImageButton = requireViewById(R.id.addMoreImageButton); + addMoreImageButton.setOnClickListener( + view -> { + Log.d("addMore", "clicked"); + mMediaPreviewConstraintLayout.setVisibility(View.GONE); + // Direct user to select type of input + mCameraButton.callOnClick(); + }); + } + + private String updateMemoryUsage() { + ActivityManager.MemoryInfo memoryInfo = new ActivityManager.MemoryInfo(); + ActivityManager activityManager = (ActivityManager) getSystemService(ACTIVITY_SERVICE); + if (activityManager == null) { + return "---"; + } + activityManager.getMemoryInfo(memoryInfo); + long totalMem = memoryInfo.totalMem / (1024 * 1024); + long availableMem = memoryInfo.availMem / (1024 * 1024); + long usedMem = totalMem - availableMem; + return usedMem + "MB"; + } + + private void startMemoryUpdate() { + mMemoryView = requireViewById(R.id.ram_usage_live); + memoryUpdater = + new Runnable() { + @Override + public void run() { + mMemoryView.setText(updateMemoryUsage()); + mMemoryUpdateHandler.postDelayed(this, 1000); + } + }; + mMemoryUpdateHandler.post(memoryUpdater); + } + + @Override + public void onRequestPermissionsResult( + int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) { + super.onRequestPermissionsResult(requestCode, permissions, grantResults); + if (requestCode == REQUEST_IMAGE_CAPTURE && grantResults.length != 0) { + if (grantResults[0] == PackageManager.PERMISSION_GRANTED) { + launchCamera(); + } else if (grantResults[0] == PackageManager.PERMISSION_DENIED) { + Log.d("CameraRoll", "Permission denied"); + } + } + } + + private void launchCamera() { + ContentValues values = new ContentValues(); + values.put(MediaStore.Images.Media.TITLE, "New Picture"); + values.put(MediaStore.Images.Media.DESCRIPTION, "From Camera"); + values.put(MediaStore.Images.Media.RELATIVE_PATH, "DCIM/Camera/"); + cameraImageUri = + MainActivity.this + .getContentResolver() + .insert(MediaStore.Images.Media.EXTERNAL_CONTENT_URI, values); + mCameraRoll.launch(cameraImageUri); + } + + private void setupGalleryPicker() { + // Registers a photo picker activity launcher in single-select mode. + mPickGallery = + registerForActivityResult( + new ActivityResultContracts.PickMultipleVisualMedia(MAX_NUM_OF_IMAGES), + uris -> { + if (!uris.isEmpty()) { + Log.d("PhotoPicker", "Selected URIs: " + uris); + mAddMediaLayout.setVisibility(View.GONE); + for (Uri uri : uris) { + MainActivity.this + .getContentResolver() + .takePersistableUriPermission(uri, Intent.FLAG_GRANT_READ_URI_PERMISSION); + } + showMediaPreview(uris); + } else { + Log.d("PhotoPicker", "No media selected"); + } + }); + + mMediaPreviewConstraintLayout = requireViewById(R.id.mediaPreviewConstraintLayout); + ImageButton mediaPreviewCloseButton = requireViewById(R.id.mediaPreviewCloseButton); + mediaPreviewCloseButton.setOnClickListener( + view -> { + mMediaPreviewConstraintLayout.setVisibility(View.GONE); + mSelectedImageUri = null; + }); + + ImageButton addMoreImageButton = requireViewById(R.id.addMoreImageButton); + addMoreImageButton.setOnClickListener( + view -> { + Log.d("addMore", "clicked"); + mMediaPreviewConstraintLayout.setVisibility(View.GONE); + mGalleryButton.callOnClick(); + }); + } + + private List getProcessedImagesForModel(List uris) { + List imageList = new ArrayList<>(); + if (uris != null) { + uris.forEach( + (uri) -> { + imageList.add(new ETImage(this.getContentResolver(), uri)); + }); + } + return imageList; + } + + private void showMediaPreview(List uris) { + if (mSelectedImageUri == null) { + mSelectedImageUri = uris; + } else { + mSelectedImageUri.addAll(uris); + } + + if (mSelectedImageUri.size() > MAX_NUM_OF_IMAGES) { + mSelectedImageUri = mSelectedImageUri.subList(0, MAX_NUM_OF_IMAGES); + Toast.makeText( + this, "Only max " + MAX_NUM_OF_IMAGES + " images are allowed", Toast.LENGTH_SHORT) + .show(); + } + Log.d("mSelectedImageUri", mSelectedImageUri.size() + " " + mSelectedImageUri); + + mMediaPreviewConstraintLayout.setVisibility(View.VISIBLE); + + List imageViews = new ArrayList(); + + // Pre-populate all the image views that are available from the layout (currently max 5) + imageViews.add(requireViewById(R.id.mediaPreviewImageView1)); + imageViews.add(requireViewById(R.id.mediaPreviewImageView2)); + imageViews.add(requireViewById(R.id.mediaPreviewImageView3)); + imageViews.add(requireViewById(R.id.mediaPreviewImageView4)); + imageViews.add(requireViewById(R.id.mediaPreviewImageView5)); + + // Hide all the image views (reset state) + for (int i = 0; i < imageViews.size(); i++) { + imageViews.get(i).setVisibility(View.GONE); + } + + // Only show/render those that have proper Image URIs + for (int i = 0; i < mSelectedImageUri.size(); i++) { + imageViews.get(i).setVisibility(View.VISIBLE); + imageViews.get(i).setImageURI(mSelectedImageUri.get(i)); + } + + // For LLava, we want to call prefill_image as soon as an image is selected + // Llava only support 1 image for now + if (mCurrentSettingsFields.getModelType() == ModelType.LLAVA_1_5) { + List processedImageList = getProcessedImagesForModel(mSelectedImageUri); + if (!processedImageList.isEmpty()) { + mMessageAdapter.add( + new Message("Llava - Starting image Prefill.", false, MessageType.SYSTEM, 0)); mMessageAdapter.notifyDataSetChanged(); + Runnable runnable = + () -> { + Process.setThreadPriority(Process.THREAD_PRIORITY_MORE_FAVORABLE); + ETLogging.getInstance().log("Starting runnable prefill image"); + ETImage img = processedImageList.get(0); + ETLogging.getInstance().log("Llava start prefill image"); + startPos = + mModule.prefillImages( + img.getInts(), + img.getWidth(), + img.getHeight(), + ModelUtils.VISION_MODEL_IMAGE_CHANNELS, + startPos); + }; + executor.execute(runnable); + } + } + } + + private void addSelectedImagesToChatThread(List selectedImageUri) { + if (selectedImageUri == null) { + return; + } + mMediaPreviewConstraintLayout.setVisibility(View.GONE); + for (int i = 0; i < selectedImageUri.size(); i++) { + Uri imageURI = selectedImageUri.get(i); + Log.d("image uri ", "test " + imageURI.getPath()); + mMessageAdapter.add(new Message(imageURI.toString(), true, MessageType.IMAGE, 0)); + } + mMessageAdapter.notifyDataSetChanged(); + } + + private String getConversationHistory() { + String conversationHistory = ""; + + ArrayList conversations = + mMessageAdapter.getRecentSavedTextMessages(CONVERSATION_HISTORY_MESSAGE_LOOKBACK); + if (conversations.isEmpty()) { + return conversationHistory; + } + + int prevPromptID = conversations.get(0).getPromptID(); + String conversationFormat = + PromptFormat.getConversationFormat(mCurrentSettingsFields.getModelType()); + String format = conversationFormat; + for (int i = 0; i < conversations.size(); i++) { + Message conversation = conversations.get(i); + int currentPromptID = conversation.getPromptID(); + if (currentPromptID != prevPromptID) { + conversationHistory = conversationHistory + format; + format = conversationFormat; + prevPromptID = currentPromptID; + } + if (conversation.getIsSent()) { + format = format.replace(PromptFormat.USER_PLACEHOLDER, conversation.getText()); + } else { + format = format.replace(PromptFormat.ASSISTANT_PLACEHOLDER, conversation.getText()); + } } + conversationHistory = conversationHistory + format; - @Override - public void run() { - runOnUiThread( - new Runnable() { - @Override - public void run() { - mMessageAdapter.notifyDataSetChanged(); - setTitle(memoryInfo()); - } - }); + return conversationHistory; + } + + private String getTotalFormattedPrompt(String conversationHistory, String rawPrompt) { + if (conversationHistory.isEmpty()) { + return mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt); + } + + return mCurrentSettingsFields.getFormattedSystemPrompt() + + conversationHistory + + mCurrentSettingsFields.getFormattedUserPrompt(rawPrompt); + } + + private void onModelRunStarted() { + mSendButton.setClickable(false); + mSendButton.setImageResource(R.drawable.baseline_stop_24); + mSendButton.setOnClickListener( + view -> { + mModule.stop(); + }); + } + + private void onModelRunStopped() { + mSendButton.setClickable(true); + mSendButton.setImageResource(R.drawable.baseline_send_24); + mSendButton.setOnClickListener( + view -> { + addSelectedImagesToChatThread(mSelectedImageUri); + String finalPrompt; + String rawPrompt = mEditTextMessage.getText().toString(); + if (ModelUtils.getModelCategory(mCurrentSettingsFields.getModelType()) + == ModelUtils.VISION_MODEL) { + finalPrompt = mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt); + } else { + finalPrompt = getTotalFormattedPrompt(getConversationHistory(), rawPrompt); + } + // We store raw prompt into message adapter, because we don't want to show the extra + // tokens from system prompt + mMessageAdapter.add(new Message(rawPrompt, true, MessageType.TEXT, promptID)); + mMessageAdapter.notifyDataSetChanged(); + mEditTextMessage.setText(""); + mResultMessage = new Message("", false, MessageType.TEXT, promptID); + mMessageAdapter.add(mResultMessage); + // Scroll to bottom of the list + mMessagesView.smoothScrollToPosition(mMessageAdapter.getCount() - 1); + // After images are added to prompt and chat thread, we clear the imageURI list + // Note: This has to be done after imageURIs are no longer needed by LlamaModule + mSelectedImageUri = null; + promptID++; + Runnable runnable = + new Runnable() { + @Override + public void run() { + Process.setThreadPriority(Process.THREAD_PRIORITY_MORE_FAVORABLE); + ETLogging.getInstance().log("starting runnable generate()"); + runOnUiThread( + new Runnable() { + @Override + public void run() { + onModelRunStarted(); + } + }); + long generateStartTime = System.currentTimeMillis(); + if (ModelUtils.getModelCategory(mCurrentSettingsFields.getModelType()) + == ModelUtils.VISION_MODEL) { + mModule.generateFromPos( + finalPrompt, + ModelUtils.VISION_MODEL_SEQ_LEN, + startPos, + MainActivity.this, + false); + } else if (mCurrentSettingsFields.getModelType() == ModelType.LLAMA_GUARD_3) { + String llamaGuardPromptForClassification = + PromptFormat.getFormattedLlamaGuardPrompt(rawPrompt); + ETLogging.getInstance() + .log("Running inference.. prompt=" + llamaGuardPromptForClassification); + mModule.generate( + llamaGuardPromptForClassification, + llamaGuardPromptForClassification.length() + 64, + MainActivity.this, + false); + } else { + ETLogging.getInstance().log("Running inference.. prompt=" + finalPrompt); + mModule.generate( + finalPrompt, + (int) (finalPrompt.length() * 0.75) + 64, + MainActivity.this, + false); + } + + long generateDuration = System.currentTimeMillis() - generateStartTime; + mResultMessage.setTotalGenerationTime(generateDuration); + runOnUiThread( + new Runnable() { + @Override + public void run() { + onModelRunStopped(); + } + }); + ETLogging.getInstance().log("Inference completed"); + } + }; + executor.execute(runnable); + }); + mMessageAdapter.notifyDataSetChanged(); + } + + @Override + public void run() { + runOnUiThread( + new Runnable() { + @Override + public void run() { + mMessageAdapter.notifyDataSetChanged(); + } + }); + } + + @Override + public void onBackPressed() { + super.onBackPressed(); + if (mAddMediaLayout != null && mAddMediaLayout.getVisibility() == View.VISIBLE) { + mAddMediaLayout.setVisibility(View.GONE); + } else { + // Default behavior of back button + finish(); } + } + + @Override + protected void onDestroy() { + super.onDestroy(); + mMemoryUpdateHandler.removeCallbacks(memoryUpdater); + // This is to cover the case where the app is shutdown when user is on MainActivity but + // never clicked on the logsActivity + ETLogging.getInstance().saveLogs(); + } } diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/Message.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/Message.java index f42c6afb3..031d052fd 100644 --- a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/Message.java +++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/Message.java @@ -8,33 +8,87 @@ package org.pytorch.torchchat; +import java.text.SimpleDateFormat; +import java.util.Date; +import java.util.Locale; + public class Message { - private String text; - private boolean isSent; - private float tokensPerSecond; + private String text; + private final boolean isSent; + private float tokensPerSecond; + private long totalGenerationTime; + private final long timestamp; + private final MessageType messageType; + private String imagePath; + private final int promptID; - public Message(String text, boolean isSent) { - this.text = text; - this.isSent = isSent; - } + private static final String TIMESTAMP_FORMAT = "hh:mm a"; // example: 2:23 PM - public String getText() { - return text; - } + public Message(String text, boolean isSent, MessageType messageType, int promptID) { + this.isSent = isSent; + this.messageType = messageType; + this.promptID = promptID; - public void appendText(String text) { - this.text += text; + if (messageType == MessageType.IMAGE) { + this.imagePath = text; + } else { + this.text = text; } - public boolean getIsSent() { - return isSent; + if (messageType != MessageType.SYSTEM) { + this.timestamp = System.currentTimeMillis(); + } else { + this.timestamp = (long) 0; } + } - public void setTokensPerSecond(float tokensPerSecond) { - this.tokensPerSecond = tokensPerSecond; - } + public int getPromptID() { + return promptID; + } - public float getTokensPerSecond() { - return tokensPerSecond; - } + public MessageType getMessageType() { + return messageType; + } + + public String getImagePath() { + return imagePath; + } + + public String getText() { + return text; + } + + public void appendText(String text) { + this.text += text; + } + + public boolean getIsSent() { + return isSent; + } + + public void setTokensPerSecond(float tokensPerSecond) { + this.tokensPerSecond = tokensPerSecond; + } + + public void setTotalGenerationTime(long totalGenerationTime) { + this.totalGenerationTime = totalGenerationTime; + } + + public float getTokensPerSecond() { + return tokensPerSecond; + } + + public long getTotalGenerationTime() { + return totalGenerationTime; + } + + public long getTimestamp() { + return timestamp; + } + + public String getFormattedTimestamp() { + SimpleDateFormat formatter = new SimpleDateFormat(TIMESTAMP_FORMAT, Locale.getDefault()); + Date date = new Date(timestamp); + return formatter.format(date); + } } diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/MessageAdapter.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/MessageAdapter.java index 619e15ab8..bb69c3cd7 100644 --- a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/MessageAdapter.java +++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/MessageAdapter.java @@ -1,4 +1,3 @@ - /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. @@ -9,33 +8,128 @@ package org.pytorch.torchchat; +import android.net.Uri; import android.view.LayoutInflater; import android.view.View; import android.view.ViewGroup; import android.widget.ArrayAdapter; +import android.widget.ImageView; import android.widget.TextView; +import java.util.ArrayList; +import java.util.Collections; public class MessageAdapter extends ArrayAdapter { - public MessageAdapter(android.content.Context context, int resource) { - super(context, resource); + + private final ArrayList savedMessages; + + public MessageAdapter( + android.content.Context context, int resource, ArrayList savedMessages) { + super(context, resource); + this.savedMessages = savedMessages; + } + + @Override + public View getView(int position, View convertView, ViewGroup parent) { + Message currentMessage = getItem(position); + int layoutIdForListItem; + + if (currentMessage.getMessageType() == MessageType.SYSTEM) { + layoutIdForListItem = R.layout.system_message; + } else { + layoutIdForListItem = + currentMessage.getIsSent() ? R.layout.sent_message : R.layout.received_message; + } + View listItemView = + LayoutInflater.from(getContext()).inflate(layoutIdForListItem, parent, false); + if (currentMessage.getMessageType() == MessageType.IMAGE) { + ImageView messageImageView = listItemView.requireViewById(R.id.message_image); + messageImageView.setImageURI(Uri.parse(currentMessage.getImagePath())); + TextView messageTextView = listItemView.requireViewById(R.id.message_text); + messageTextView.setVisibility(View.GONE); + } else { + TextView messageTextView = listItemView.requireViewById(R.id.message_text); + messageTextView.setText(currentMessage.getText()); + } + + String metrics = ""; + TextView tokensView; + if (currentMessage.getTokensPerSecond() > 0) { + metrics = String.format("%.2f", currentMessage.getTokensPerSecond()) + "t/s "; + } + + if (currentMessage.getTotalGenerationTime() > 0) { + metrics = metrics + (float) currentMessage.getTotalGenerationTime() / 1000 + "s "; } - @Override - public View getView(int position, View convertView, ViewGroup parent) { - Message currentMessage = getItem(position); + if (currentMessage.getTokensPerSecond() > 0 || currentMessage.getTotalGenerationTime() > 0) { + tokensView = listItemView.requireViewById(R.id.generation_metrics); + tokensView.setText(metrics); + TextView separatorView = listItemView.requireViewById(R.id.bar); + separatorView.setVisibility(View.VISIBLE); + } + + if (currentMessage.getTimestamp() > 0) { + TextView timestampView = listItemView.requireViewById(R.id.timestamp); + timestampView.setText(currentMessage.getFormattedTimestamp()); + } + + return listItemView; + } + + @Override + public void add(Message msg) { + super.add(msg); + savedMessages.add(msg); + } - int layoutIdForListItem = - currentMessage.getIsSent() ? R.layout.sent_message : R.layout.received_message; - View listItemView = - LayoutInflater.from(getContext()).inflate(layoutIdForListItem, parent, false); - TextView messageTextView = listItemView.findViewById(R.id.message_text); - messageTextView.setText(currentMessage.getText()); + @Override + public void clear() { + super.clear(); + savedMessages.clear(); + } - if (currentMessage.getTokensPerSecond() > 0) { - TextView tokensView = listItemView.findViewById(R.id.tokens_per_second); - tokensView.setText("" + currentMessage.getTokensPerSecond() + " t/s"); + public ArrayList getSavedMessages() { + return savedMessages; + } + + public ArrayList getRecentSavedTextMessages(int numOfLatestPromptMessages) { + ArrayList recentMessages = new ArrayList(); + int lastIndex = savedMessages.size() - 1; + // In most cases lastIndex >=0 . + // A situation where the user clears chat history and enters prompt. Causes lastIndex=-1 . + if (lastIndex >= 0) { + Message messageToAdd = savedMessages.get(lastIndex); + int oldPromptID = messageToAdd.getPromptID(); + + for (int i = 0; i < savedMessages.size(); i++) { + messageToAdd = savedMessages.get(lastIndex - i); + if (messageToAdd.getMessageType() != MessageType.SYSTEM) { + if (messageToAdd.getPromptID() != oldPromptID) { + numOfLatestPromptMessages--; + oldPromptID = messageToAdd.getPromptID(); + } + if (numOfLatestPromptMessages > 0) { + if (messageToAdd.getMessageType() == MessageType.TEXT) { + recentMessages.add(messageToAdd); + } + } else { + break; + } } + } + // To place the order in [input1, output1, input2, output2...] + Collections.reverse(recentMessages); + } + + return recentMessages; + } + + public int getMaxPromptID() { + int maxPromptID = -1; + for (Message msg : savedMessages) { - return listItemView; + maxPromptID = Math.max(msg.getPromptID(), maxPromptID); } + return maxPromptID; + } } diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/MessageType.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/MessageType.java new file mode 100644 index 000000000..d10871310 --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/MessageType.java @@ -0,0 +1,15 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.torchchat; + +public enum MessageType { + TEXT, + IMAGE, + SYSTEM +} diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ModelRunner.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ModelRunner.java new file mode 100644 index 000000000..ff5b22364 --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ModelRunner.java @@ -0,0 +1,98 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.torchchat; + +import android.os.Handler; +import android.os.HandlerThread; +import android.os.Looper; +import android.os.Message; +import androidx.annotation.NonNull; +import org.pytorch.executorch.LlamaCallback; +import org.pytorch.executorch.LlamaModule; + +/** A helper class to handle all model running logic within this class. */ +public class ModelRunner implements LlamaCallback { + LlamaModule mModule = null; + + String mModelFilePath = ""; + String mTokenizerFilePath = ""; + + ModelRunnerCallback mCallback = null; + + HandlerThread mHandlerThread = null; + Handler mHandler = null; + + /** + * ] Helper class to separate between UI logic and model runner logic. Automatically handle + * generate() request on worker thread. + * + * @param modelFilePath + * @param tokenizerFilePath + * @param callback + */ + ModelRunner( + String modelFilePath, + String tokenizerFilePath, + float temperature, + ModelRunnerCallback callback) { + mModelFilePath = modelFilePath; + mTokenizerFilePath = tokenizerFilePath; + mCallback = callback; + + mModule = new LlamaModule(mModelFilePath, mTokenizerFilePath, 0.8f); + mHandlerThread = new HandlerThread("ModelRunner"); + mHandlerThread.start(); + mHandler = new ModelRunnerHandler(mHandlerThread.getLooper(), this); + + mHandler.sendEmptyMessage(ModelRunnerHandler.MESSAGE_LOAD_MODEL); + } + + int generate(String prompt) { + Message msg = Message.obtain(mHandler, ModelRunnerHandler.MESSAGE_GENERATE, prompt); + msg.sendToTarget(); + return 0; + } + + void stop() { + mModule.stop(); + } + + @Override + public void onResult(String result) { + mCallback.onTokenGenerated(result); + } + + @Override + public void onStats(float tps) { + mCallback.onStats("tokens/second: " + tps); + } +} + +class ModelRunnerHandler extends Handler { + public static int MESSAGE_LOAD_MODEL = 1; + public static int MESSAGE_GENERATE = 2; + + private final ModelRunner mModelRunner; + + public ModelRunnerHandler(Looper looper, ModelRunner modelRunner) { + super(looper); + mModelRunner = modelRunner; + } + + @Override + public void handleMessage(@NonNull android.os.Message msg) { + if (msg.what == MESSAGE_LOAD_MODEL) { + int status = mModelRunner.mModule.load(); + mModelRunner.mCallback.onModelLoaded(status); + } else if (msg.what == MESSAGE_GENERATE) { + mModelRunner.mModule.generate((String) msg.obj, mModelRunner); + mModelRunner.mCallback.onGenerationStopped(); + } + } +} diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ModelRunnerCallback.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ModelRunnerCallback.java new file mode 100644 index 000000000..a04c660f8 --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ModelRunnerCallback.java @@ -0,0 +1,24 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.torchchat; + +/** + * A helper interface within the app for MainActivity and Benchmarking to handle callback from + * ModelRunner. + */ +public interface ModelRunnerCallback { + + void onModelLoaded(int status); + + void onTokenGenerated(String token); + + void onStats(String token); + + void onGenerationStopped(); +} diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ModelType.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ModelType.java new file mode 100644 index 000000000..307609152 --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ModelType.java @@ -0,0 +1,17 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.torchchat; + +public enum ModelType { + LLAMA_3, + LLAMA_3_1, + LLAMA_3_2, + LLAVA_1_5, + LLAMA_GUARD_3, +} diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ModelUtils.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ModelUtils.java new file mode 100644 index 000000000..056180318 --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ModelUtils.java @@ -0,0 +1,29 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.torchchat; + +public class ModelUtils { + static final int TEXT_MODEL = 1; + static final int VISION_MODEL = 2; + static final int VISION_MODEL_IMAGE_CHANNELS = 3; + static final int VISION_MODEL_SEQ_LEN = 768; + static final int TEXT_MODEL_SEQ_LEN = 256; + + public static int getModelCategory(ModelType modelType) { + switch (modelType) { + case LLAVA_1_5: + return VISION_MODEL; + case LLAMA_3: + case LLAMA_3_1: + case LLAMA_3_2: + default: + return TEXT_MODEL; + } + } +} diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/PromptFormat.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/PromptFormat.java new file mode 100644 index 000000000..2f256783f --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/PromptFormat.java @@ -0,0 +1,121 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.torchchat; + +public class PromptFormat { + + public static final String SYSTEM_PLACEHOLDER = "{{ system_prompt }}"; + public static final String USER_PLACEHOLDER = "{{ user_prompt }}"; + public static final String ASSISTANT_PLACEHOLDER = "{{ assistant_response }}"; + public static final String DEFAULT_SYSTEM_PROMPT = "Answer the questions in a few sentences"; + + public static String getSystemPromptTemplate(ModelType modelType) { + switch (modelType) { + case LLAMA_3: + case LLAMA_3_1: + case LLAMA_3_2: + return "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n" + + SYSTEM_PLACEHOLDER + + "<|eot_id|>"; + case LLAVA_1_5: + return "USER: "; + default: + return SYSTEM_PLACEHOLDER; + } + } + + public static String getUserPromptTemplate(ModelType modelType) { + switch (modelType) { + case LLAMA_3: + case LLAMA_3_1: + case LLAMA_3_2: + case LLAMA_GUARD_3: + return "<|start_header_id|>user<|end_header_id|>\n" + + USER_PLACEHOLDER + + "<|eot_id|>" + + "<|start_header_id|>assistant<|end_header_id|>"; + + case LLAVA_1_5: + default: + return USER_PLACEHOLDER; + } + } + + public static String getConversationFormat(ModelType modelType) { + switch (modelType) { + case LLAMA_3: + case LLAMA_3_1: + case LLAMA_3_2: + return getUserPromptTemplate(modelType) + "\n" + ASSISTANT_PLACEHOLDER + "<|eot_id|>"; + case LLAVA_1_5: + return USER_PLACEHOLDER + " ASSISTANT:"; + default: + return USER_PLACEHOLDER; + } + } + + public static String getStopToken(ModelType modelType) { + switch (modelType) { + case LLAMA_3: + case LLAMA_3_1: + case LLAMA_3_2: + case LLAMA_GUARD_3: + return "<|eot_id|>"; + case LLAVA_1_5: + return ""; + default: + return ""; + } + } + + public static String getLlavaPresetPrompt() { + return "A chat between a curious human and an artificial intelligence assistant. The assistant" + + " gives helpful, detailed, and polite answers to the human's questions. USER: "; + } + + public static String getFormattedLlamaGuardPrompt(String userPrompt) { + return getUserPromptTemplate(ModelType.LLAMA_GUARD_3) + .replace( + USER_PLACEHOLDER, getLlamaGuardPresetPrompt().replace(USER_PLACEHOLDER, userPrompt)); + } + + public static String getLlamaGuardPresetPrompt() { + return "Task: Check if there is unsafe content in 'User' messages in conversations" + + " according our safety policy with the below categories.\n" + + "\n" + + "\n" + + "S1: Violent Crimes. \n" + + "S2: Non-Violent Crimes. \n" + + "S3: Sex Crimes. \n" + + "S4: Child Exploitation. \n" + + "S5: Defamation. \n" + + "S6: Specialized Advice. \n" + + "S7: Privacy. \n" + + "S8: Intellectual Property. \n" + + "S9: Indiscriminate Weapons. \n" + + "S10: Hate. \n" + + "S11: Self-Harm. \n" + + "S12: Sexual Content. \n" + + "S13: Elections. \n" + + "S14: Code Interpreter Abuse.\n" + + "\n" + + "\n" + + "\n" + + "\n" + + "User: " + + USER_PLACEHOLDER + + "\n" + + "\n" + + "Provide your safety assessment for ONLY THE LAST User message in the above" + + " conversation:\n" + + " - First line must read 'safe' or 'unsafe'.\n" + + " - If unsafe, a second line must include a comma-separated list of violated" + + " categories."; + } +} diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/SettingsActivity.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/SettingsActivity.java new file mode 100644 index 000000000..c040bbd53 --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/SettingsActivity.java @@ -0,0 +1,395 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.torchchat; + +import android.app.AlertDialog; +import android.content.DialogInterface; +import android.os.Build; +import android.os.Bundle; +import android.text.Editable; +import android.text.TextWatcher; +import android.widget.Button; +import android.widget.EditText; +import android.widget.ImageButton; +import android.widget.TextView; +import androidx.appcompat.app.AppCompatActivity; +import androidx.core.content.ContextCompat; +import androidx.core.graphics.Insets; +import androidx.core.view.ViewCompat; +import androidx.core.view.WindowInsetsCompat; +import com.google.gson.Gson; +import java.io.File; +import java.util.ArrayList; +import java.util.List; + +public class SettingsActivity extends AppCompatActivity { + + private String mModelFilePath = ""; + private String mTokenizerFilePath = ""; + private TextView mModelTextView; + private TextView mTokenizerTextView; + private TextView mModelTypeTextView; + private EditText mSystemPromptEditText; + private EditText mUserPromptEditText; + private Button mLoadModelButton; + private double mSetTemperature; + private String mSystemPrompt; + private String mUserPrompt; + private ModelType mModelType; + public SettingsFields mSettingsFields; + + private DemoSharedPreferences mDemoSharedPreferences; + public static double TEMPERATURE_MIN_VALUE = 0.0; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_settings); + if (Build.VERSION.SDK_INT >= 21) { + getWindow().setStatusBarColor(ContextCompat.getColor(this, R.color.status_bar)); + getWindow().setNavigationBarColor(ContextCompat.getColor(this, R.color.nav_bar)); + } + ViewCompat.setOnApplyWindowInsetsListener( + requireViewById(R.id.main), + (v, insets) -> { + Insets systemBars = insets.getInsets(WindowInsetsCompat.Type.systemBars()); + v.setPadding(systemBars.left, systemBars.top, systemBars.right, systemBars.bottom); + return insets; + }); + mDemoSharedPreferences = new DemoSharedPreferences(getBaseContext()); + mSettingsFields = new SettingsFields(); + setupSettings(); + } + + private void setupSettings() { + mModelTextView = requireViewById(R.id.modelTextView); + mTokenizerTextView = requireViewById(R.id.tokenizerTextView); + mModelTypeTextView = requireViewById(R.id.modelTypeTextView); + ImageButton modelImageButton = requireViewById(R.id.modelImageButton); + ImageButton tokenizerImageButton = requireViewById(R.id.tokenizerImageButton); + ImageButton modelTypeImageButton = requireViewById(R.id.modelTypeImageButton); + mSystemPromptEditText = requireViewById(R.id.systemPromptText); + mUserPromptEditText = requireViewById(R.id.userPromptText); + loadSettings(); + + // TODO: The two setOnClickListeners will be removed after file path issue is resolved + modelImageButton.setOnClickListener( + view -> { + setupModelSelectorDialog(); + }); + tokenizerImageButton.setOnClickListener( + view -> { + setupTokenizerSelectorDialog(); + }); + modelTypeImageButton.setOnClickListener( + view -> { + setupModelTypeSelectorDialog(); + }); + mModelFilePath = mSettingsFields.getModelFilePath(); + if (!mModelFilePath.isEmpty()) { + mModelTextView.setText(getFilenameFromPath(mModelFilePath)); + } + mTokenizerFilePath = mSettingsFields.getTokenizerFilePath(); + if (!mTokenizerFilePath.isEmpty()) { + mTokenizerTextView.setText(getFilenameFromPath(mTokenizerFilePath)); + } + mModelType = mSettingsFields.getModelType(); + ETLogging.getInstance().log("mModelType from settings " + mModelType); + if (mModelType != null) { + mModelTypeTextView.setText(mModelType.toString()); + } + + setupParameterSettings(); + setupPromptSettings(); + setupClearChatHistoryButton(); + setupLoadModelButton(); + } + + private void setupLoadModelButton() { + mLoadModelButton = requireViewById(R.id.loadModelButton); + mLoadModelButton.setEnabled(true); + mLoadModelButton.setOnClickListener( + view -> { + new AlertDialog.Builder(this) + .setTitle("Load Model") + .setMessage("Do you really want to load the new model?") + .setIcon(android.R.drawable.ic_dialog_alert) + .setPositiveButton( + android.R.string.yes, + new DialogInterface.OnClickListener() { + public void onClick(DialogInterface dialog, int whichButton) { + mSettingsFields.saveLoadModelAction(true); + mLoadModelButton.setEnabled(false); + onBackPressed(); + } + }) + .setNegativeButton(android.R.string.no, null) + .show(); + }); + } + + private void setupClearChatHistoryButton() { + Button clearChatButton = requireViewById(R.id.clearChatButton); + clearChatButton.setOnClickListener( + view -> { + new AlertDialog.Builder(this) + .setTitle("Delete Chat History") + .setMessage("Do you really want to delete chat history?") + .setIcon(android.R.drawable.ic_dialog_alert) + .setPositiveButton( + android.R.string.yes, + new DialogInterface.OnClickListener() { + public void onClick(DialogInterface dialog, int whichButton) { + mSettingsFields.saveIsClearChatHistory(true); + } + }) + .setNegativeButton(android.R.string.no, null) + .show(); + }); + } + + private void setupParameterSettings() { + setupTemperatureSettings(); + } + + private void setupTemperatureSettings() { + mSetTemperature = mSettingsFields.getTemperature(); + EditText temperatureEditText = requireViewById(R.id.temperatureEditText); + temperatureEditText.setText(String.valueOf(mSetTemperature)); + temperatureEditText.addTextChangedListener( + new TextWatcher() { + @Override + public void beforeTextChanged(CharSequence s, int start, int count, int after) {} + + @Override + public void onTextChanged(CharSequence s, int start, int before, int count) {} + + @Override + public void afterTextChanged(Editable s) { + mSetTemperature = Double.parseDouble(s.toString()); + // This is needed because temperature is changed together with model loading + // Once temperature is no longer in LlamaModule constructor, we can remove this + mSettingsFields.saveLoadModelAction(true); + saveSettings(); + } + }); + } + + private void setupPromptSettings() { + setupSystemPromptSettings(); + setupUserPromptSettings(); + } + + private void setupSystemPromptSettings() { + mSystemPrompt = mSettingsFields.getSystemPrompt(); + mSystemPromptEditText.setText(mSystemPrompt); + mSystemPromptEditText.addTextChangedListener( + new TextWatcher() { + @Override + public void beforeTextChanged(CharSequence s, int start, int count, int after) {} + + @Override + public void onTextChanged(CharSequence s, int start, int before, int count) {} + + @Override + public void afterTextChanged(Editable s) { + mSystemPrompt = s.toString(); + } + }); + + ImageButton resetSystemPrompt = requireViewById(R.id.resetSystemPrompt); + resetSystemPrompt.setOnClickListener( + view -> { + new AlertDialog.Builder(this) + .setTitle("Reset System Prompt") + .setMessage("Do you really want to reset system prompt?") + .setIcon(android.R.drawable.ic_dialog_alert) + .setPositiveButton( + android.R.string.yes, + new DialogInterface.OnClickListener() { + public void onClick(DialogInterface dialog, int whichButton) { + // Clear the messageAdapter and sharedPreference + mSystemPromptEditText.setText(PromptFormat.DEFAULT_SYSTEM_PROMPT); + } + }) + .setNegativeButton(android.R.string.no, null) + .show(); + }); + } + + private void setupUserPromptSettings() { + mUserPrompt = mSettingsFields.getUserPrompt(); + mUserPromptEditText.setText(mUserPrompt); + mUserPromptEditText.addTextChangedListener( + new TextWatcher() { + @Override + public void beforeTextChanged(CharSequence s, int start, int count, int after) {} + + @Override + public void onTextChanged(CharSequence s, int start, int before, int count) {} + + @Override + public void afterTextChanged(Editable s) { + if (isValidUserPrompt(s.toString())) { + mUserPrompt = s.toString(); + } else { + showInvalidPromptDialog(); + } + } + }); + + ImageButton resetUserPrompt = requireViewById(R.id.resetUserPrompt); + resetUserPrompt.setOnClickListener( + view -> { + new AlertDialog.Builder(this) + .setTitle("Reset Prompt Template") + .setMessage("Do you really want to reset the prompt template?") + .setIcon(android.R.drawable.ic_dialog_alert) + .setPositiveButton( + android.R.string.yes, + new DialogInterface.OnClickListener() { + public void onClick(DialogInterface dialog, int whichButton) { + // Clear the messageAdapter and sharedPreference + mUserPromptEditText.setText(PromptFormat.getUserPromptTemplate(mModelType)); + } + }) + .setNegativeButton(android.R.string.no, null) + .show(); + }); + } + + private boolean isValidUserPrompt(String userPrompt) { + return userPrompt.contains(PromptFormat.USER_PLACEHOLDER); + } + + private void showInvalidPromptDialog() { + new AlertDialog.Builder(this) + .setTitle("Invalid Prompt Format") + .setMessage( + "Prompt format must contain " + + PromptFormat.USER_PLACEHOLDER + + ". Do you want to reset prompt format?") + .setIcon(android.R.drawable.ic_dialog_alert) + .setPositiveButton( + android.R.string.yes, + (dialog, whichButton) -> { + mUserPromptEditText.setText(PromptFormat.getUserPromptTemplate(mModelType)); + }) + .setNegativeButton(android.R.string.no, null) + .show(); + } + + private void setupModelSelectorDialog() { + String[] pteFiles = listLocalFile("/data/local/tmp/llama/", ".pte"); + AlertDialog.Builder modelPathBuilder = new AlertDialog.Builder(this); + modelPathBuilder.setTitle("Select model path"); + + modelPathBuilder.setSingleChoiceItems( + pteFiles, + -1, + (dialog, item) -> { + mModelFilePath = pteFiles[item]; + mModelTextView.setText(getFilenameFromPath(mModelFilePath)); + mLoadModelButton.setEnabled(true); + dialog.dismiss(); + }); + + modelPathBuilder.create().show(); + } + + private static String[] listLocalFile(String path, String suffix) { + File directory = new File(path); + if (directory.exists() && directory.isDirectory()) { + File[] files = directory.listFiles((dir, name) -> name.toLowerCase().endsWith(suffix)); + String[] result = new String[files.length]; + for (int i = 0; i < files.length; i++) { + if (files[i].isFile() && files[i].getName().endsWith(suffix)) { + result[i] = files[i].getAbsolutePath(); + } + } + return result; + } + return new String[] {}; + } + + private void setupModelTypeSelectorDialog() { + // Convert enum to list + List modelTypesList = new ArrayList<>(); + for (ModelType modelType : ModelType.values()) { + modelTypesList.add(modelType.toString()); + } + // Alert dialog builder takes in arr of string instead of list + String[] modelTypes = modelTypesList.toArray(new String[0]); + AlertDialog.Builder modelTypeBuilder = new AlertDialog.Builder(this); + modelTypeBuilder.setTitle("Select model type"); + modelTypeBuilder.setSingleChoiceItems( + modelTypes, + -1, + (dialog, item) -> { + mModelTypeTextView.setText(modelTypes[item]); + mModelType = ModelType.valueOf(modelTypes[item]); + mUserPromptEditText.setText(PromptFormat.getUserPromptTemplate(mModelType)); + dialog.dismiss(); + }); + + modelTypeBuilder.create().show(); + } + + private void setupTokenizerSelectorDialog() { + String[] binFiles = listLocalFile("/data/local/tmp/llama/", ".bin"); + String[] modelFiles = listLocalFile("/data/local/tmp/llama/", ".model"); + String[] tokenizerFiles = new String[binFiles.length + modelFiles.length]; + System.arraycopy(binFiles, 0, tokenizerFiles, 0, binFiles.length); + System.arraycopy(modelFiles, 0, tokenizerFiles, binFiles.length, modelFiles.length); + AlertDialog.Builder tokenizerPathBuilder = new AlertDialog.Builder(this); + tokenizerPathBuilder.setTitle("Select tokenizer path"); + tokenizerPathBuilder.setSingleChoiceItems( + tokenizerFiles, + -1, + (dialog, item) -> { + mTokenizerFilePath = tokenizerFiles[item]; + mTokenizerTextView.setText(getFilenameFromPath(mTokenizerFilePath)); + mLoadModelButton.setEnabled(true); + dialog.dismiss(); + }); + + tokenizerPathBuilder.create().show(); + } + + private String getFilenameFromPath(String uriFilePath) { + String[] segments = uriFilePath.split("/"); + if (segments.length > 0) { + return segments[segments.length - 1]; // get last element (aka filename) + } + return ""; + } + + private void loadSettings() { + Gson gson = new Gson(); + String settingsFieldsJSON = mDemoSharedPreferences.getSettings(); + if (!settingsFieldsJSON.isEmpty()) { + mSettingsFields = gson.fromJson(settingsFieldsJSON, SettingsFields.class); + } + } + + private void saveSettings() { + mSettingsFields.saveModelPath(mModelFilePath); + mSettingsFields.saveTokenizerPath(mTokenizerFilePath); + mSettingsFields.saveParameters(mSetTemperature); + mSettingsFields.savePrompts(mSystemPrompt, mUserPrompt); + mSettingsFields.saveModelType(mModelType); + mDemoSharedPreferences.addSettings(mSettingsFields); + } + + @Override + public void onBackPressed() { + super.onBackPressed(); + saveSettings(); + } +} diff --git a/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/SettingsFields.java b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/SettingsFields.java new file mode 100644 index 000000000..eb054cb9a --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/SettingsFields.java @@ -0,0 +1,131 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.torchchat; + +public class SettingsFields { + + public String getModelFilePath() { + return modelFilePath; + } + + public String getTokenizerFilePath() { + return tokenizerFilePath; + } + + public double getTemperature() { + return temperature; + } + + public String getSystemPrompt() { + return systemPrompt; + } + + public ModelType getModelType() { + return modelType; + } + + public String getUserPrompt() { + return userPrompt; + } + + public String getFormattedSystemAndUserPrompt(String prompt) { + return getFormattedSystemPrompt() + getFormattedUserPrompt(prompt); + } + + public String getFormattedSystemPrompt() { + return PromptFormat.getSystemPromptTemplate(modelType) + .replace(PromptFormat.SYSTEM_PLACEHOLDER, systemPrompt); + } + + public String getFormattedUserPrompt(String prompt) { + return userPrompt.replace(PromptFormat.USER_PLACEHOLDER, prompt); + } + + public boolean getIsClearChatHistory() { + return isClearChatHistory; + } + + public boolean getIsLoadModel() { + return isLoadModel; + } + + private String modelFilePath; + private String tokenizerFilePath; + private double temperature; + private String systemPrompt; + private String userPrompt; + private boolean isClearChatHistory; + private boolean isLoadModel; + private ModelType modelType; + + public SettingsFields() { + ModelType DEFAULT_MODEL = ModelType.LLAMA_3; + + modelFilePath = ""; + tokenizerFilePath = ""; + temperature = SettingsActivity.TEMPERATURE_MIN_VALUE; + systemPrompt = ""; + userPrompt = PromptFormat.getUserPromptTemplate(DEFAULT_MODEL); + isClearChatHistory = false; + isLoadModel = false; + modelType = DEFAULT_MODEL; + } + + public SettingsFields(SettingsFields settingsFields) { + this.modelFilePath = settingsFields.modelFilePath; + this.tokenizerFilePath = settingsFields.tokenizerFilePath; + this.temperature = settingsFields.temperature; + this.systemPrompt = settingsFields.getSystemPrompt(); + this.userPrompt = settingsFields.getUserPrompt(); + this.isClearChatHistory = settingsFields.getIsClearChatHistory(); + this.isLoadModel = settingsFields.getIsLoadModel(); + this.modelType = settingsFields.modelType; + } + + public void saveModelPath(String modelFilePath) { + this.modelFilePath = modelFilePath; + } + + public void saveTokenizerPath(String tokenizerFilePath) { + this.tokenizerFilePath = tokenizerFilePath; + } + + public void saveModelType(ModelType modelType) { + this.modelType = modelType; + } + + public void saveParameters(Double temperature) { + this.temperature = temperature; + } + + public void savePrompts(String systemPrompt, String userPrompt) { + this.systemPrompt = systemPrompt; + this.userPrompt = userPrompt; + } + + public void saveIsClearChatHistory(boolean needToClear) { + this.isClearChatHistory = needToClear; + } + + public void saveLoadModelAction(boolean shouldLoadModel) { + this.isLoadModel = shouldLoadModel; + } + + public boolean equals(SettingsFields anotherSettingsFields) { + if (this == anotherSettingsFields) return true; + return modelFilePath.equals(anotherSettingsFields.modelFilePath) + && tokenizerFilePath.equals(anotherSettingsFields.tokenizerFilePath) + && temperature == anotherSettingsFields.temperature + && systemPrompt.equals(anotherSettingsFields.systemPrompt) + && userPrompt.equals(anotherSettingsFields.userPrompt) + && isClearChatHistory == anotherSettingsFields.isClearChatHistory + && isLoadModel == anotherSettingsFields.isLoadModel + && modelType == anotherSettingsFields.modelType; + } +} diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/banner_shape.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/banner_shape.xml new file mode 100644 index 000000000..0868ffffa --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/banner_shape.xml @@ -0,0 +1,5 @@ + + + + \ No newline at end of file diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_add_24.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_add_24.xml new file mode 100644 index 000000000..2ae27b840 --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_add_24.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_add_photo_alternate_24.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_add_photo_alternate_24.xml new file mode 100644 index 000000000..7077fedd4 --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_add_photo_alternate_24.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_article_24.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_article_24.xml new file mode 100644 index 000000000..a6837b9c6 --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_article_24.xml @@ -0,0 +1,6 @@ + + + + + diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_close_24.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_close_24.xml new file mode 100644 index 000000000..fb902d433 --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_close_24.xml @@ -0,0 +1,6 @@ + + + + + diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_delete_forever_24.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_delete_forever_24.xml new file mode 100644 index 000000000..4680bc662 --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_delete_forever_24.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_restart_alt_24.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_restart_alt_24.xml new file mode 100644 index 000000000..860470ab1 --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_restart_alt_24.xml @@ -0,0 +1,6 @@ + + + + diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_send_24.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_send_24.xml new file mode 100644 index 000000000..2de1f6420 --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_send_24.xml @@ -0,0 +1,6 @@ + + + diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_settings_24.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_settings_24.xml new file mode 100644 index 000000000..c51d84b9f --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_settings_24.xml @@ -0,0 +1,11 @@ + + + diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_stop_24.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_stop_24.xml new file mode 100644 index 000000000..832e25859 --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_stop_24.xml @@ -0,0 +1,6 @@ + + + + + diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/btn.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/btn.xml new file mode 100644 index 000000000..ceb3ac56c --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/btn.xml @@ -0,0 +1,8 @@ + + + + + + + \ No newline at end of file diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/chat_background.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/chat_background.xml new file mode 100644 index 000000000..eb8b9d1f1 --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/chat_background.xml @@ -0,0 +1,21 @@ + + + + + + + + + + diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/custom_button_round.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/custom_button_round.xml new file mode 100644 index 000000000..87c82d2a3 --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/custom_button_round.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/expand_circle_down.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/expand_circle_down.xml new file mode 100644 index 000000000..0a7a71f07 --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/expand_circle_down.xml @@ -0,0 +1,9 @@ + + + diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/input_text_shape.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/input_text_shape.xml new file mode 100644 index 000000000..35c778a43 --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/input_text_shape.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/logo.png b/torchchat/edge/android/torchchat/app/src/main/res/drawable/logo.png new file mode 100644 index 000000000..60e3e5174 Binary files /dev/null and b/torchchat/edge/android/torchchat/app/src/main/res/drawable/logo.png differ diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/outline_add_box_48.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/outline_add_box_48.xml new file mode 100644 index 000000000..bb45d63d8 --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/outline_add_box_48.xml @@ -0,0 +1,6 @@ + + + + + diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/outline_camera_alt_48.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/outline_camera_alt_48.xml new file mode 100644 index 000000000..c7b4b2e4a --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/outline_camera_alt_48.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/outline_image_48.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/outline_image_48.xml new file mode 100644 index 000000000..a8bb4b2f6 --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/outline_image_48.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/prompt_shape.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/prompt_shape.xml new file mode 100644 index 000000000..5f81396e3 --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/prompt_shape.xml @@ -0,0 +1,6 @@ + + + + + \ No newline at end of file diff --git a/torchchat/edge/android/torchchat/app/src/main/res/drawable/received_message.xml b/torchchat/edge/android/torchchat/app/src/main/res/drawable/received_message.xml index ea2d1bbfa..c2288b5bf 100644 --- a/torchchat/edge/android/torchchat/app/src/main/res/drawable/received_message.xml +++ b/torchchat/edge/android/torchchat/app/src/main/res/drawable/received_message.xml @@ -1,6 +1,6 @@ - + diff --git a/torchchat/edge/android/torchchat/app/src/main/res/layout/activity_benchmarking.xml b/torchchat/edge/android/torchchat/app/src/main/res/layout/activity_benchmarking.xml new file mode 100644 index 000000000..6e48b5de8 --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/res/layout/activity_benchmarking.xml @@ -0,0 +1,16 @@ + + + + + + diff --git a/torchchat/edge/android/torchchat/app/src/main/res/layout/activity_logs.xml b/torchchat/edge/android/torchchat/app/src/main/res/layout/activity_logs.xml new file mode 100644 index 000000000..b327a544f --- /dev/null +++ b/torchchat/edge/android/torchchat/app/src/main/res/layout/activity_logs.xml @@ -0,0 +1,55 @@ + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/torchchat/edge/android/torchchat/app/src/main/res/layout/activity_main.xml b/torchchat/edge/android/torchchat/app/src/main/res/layout/activity_main.xml index 089acb572..7b8b8d176 100644 --- a/torchchat/edge/android/torchchat/app/src/main/res/layout/activity_main.xml +++ b/torchchat/edge/android/torchchat/app/src/main/res/layout/activity_main.xml @@ -1,44 +1,233 @@ - - + + + + + + + - + + + + + + -