From 45636f0da13187eefdbe8051bed5be2eed6c96ff Mon Sep 17 00:00:00 2001 From: Riandy Riandy Date: Wed, 4 Sep 2024 14:49:33 -0700 Subject: [PATCH] Unified Android aar support for llava and llama models (#5086) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/5086 - Previously, we need two separate aar for vision and text models. Since ET core runner side has a combined aar built, I am making changes on the Android app side to support this behavior. - Introducing ModelUtils class so we can get the correct model category to be passed on to generate() - Seq_len is now an exposed parameters, defaulting to 128. For llava models, 128 is not enough, hence we are changing it to 768 when calling generate() - Minor bug fix on ETImage logic. Reviewed By: cmodi-meta, kirklandsign Differential Revision: D61406255 --- .../example/executorchllamademo/ETImage.java | 16 ++++++-- .../executorchllamademo/MainActivity.java | 38 +++++++++++++++++-- .../executorchllamademo/ModelUtils.java | 28 ++++++++++++++ .../executorchllamademo/PromptFormat.java | 2 + 4 files changed, 77 insertions(+), 7 deletions(-) create mode 100644 examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ETImage.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ETImage.java index cf3c3e5f0a5..e68c8472626 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ETImage.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ETImage.java @@ -46,6 +46,16 @@ public byte[] getBytes() { return bytes; } + public int[] getInts() { + // We need to convert the byte array to an int array because + // the runner expects an int array as input. + int[] intArray = new int[bytes.length]; + for (int i = 0; i < bytes.length; i++) { + intArray[i] = (bytes[i++] & 0xFF); + } + return intArray; + } + private byte[] getBytesFromImageURI(Uri uri) { try { int RESIZED_IMAGE_WIDTH = 336; @@ -72,9 +82,9 @@ private byte[] getBytesFromImageURI(Uri uri) { int blue = Color.blue(color); // Store the RGB values in the byte array - rgbValues[(y * width + x) * 3] = (byte) red; - rgbValues[(y * width + x) * 3 + 1] = (byte) green; - rgbValues[(y * width + x) * 3 + 2] = (byte) blue; + rgbValues[y * width + x] = (byte) red; + rgbValues[(y * width + x) + height * width] = (byte) green; + rgbValues[(y * width + x) + 2 * height * width] = (byte) blue; } } return rgbValues; diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java index 70936e17d84..f24254efb31 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java @@ -102,7 +102,12 @@ private void setLocalModel(String modelPath, String tokenizerPath, float tempera mMessageAdapter.notifyDataSetChanged(); }); long runStartTime = System.currentTimeMillis(); - mModule = new LlamaModule(modelPath, tokenizerPath, temperature); + mModule = + new LlamaModule( + ModelUtils.getModelCategory(mCurrentSettingsFields.getModelType()), + modelPath, + tokenizerPath, + temperature); int loadResult = mModule.load(); long loadDuration = System.currentTimeMillis() - runStartTime; String modelLoadError = ""; @@ -552,8 +557,6 @@ private void onModelRunStopped() { mSendButton.setOnClickListener( view -> { addSelectedImagesToChatThread(mSelectedImageUri); - // TODO: When ET supports multimodal, this is where we will add the images as part of the - // prompt. List processedImageList = getProcessedImagesForModel(mSelectedImageUri); processedImageList.forEach( image -> { @@ -599,7 +602,34 @@ public void run() { }); ETLogging.getInstance().log("Running inference.. prompt=" + prompt); long generateStartTime = System.currentTimeMillis(); - mModule.generate(prompt, MainActivity.this); + if (ModelUtils.getModelCategory(mCurrentSettingsFields.getModelType()) + == ModelUtils.VISION_MODEL) { + if (!processedImageList.isEmpty()) { + // For now, Llava only support 1 image. + ETImage img = processedImageList.get(0); + mModule.generate( + processedImageList.get(0).getInts(), + img.getWidth(), + img.getHeight(), + ModelUtils.VISION_MODEL_IMAGE_CHANNELS, + prompt, + ModelUtils.VISION_MODEL_SEQ_LEN, + MainActivity.this); + } else { + // no image selected, we pass in empty int array + mModule.generate( + new int[0], + 0, + 0, + ModelUtils.VISION_MODEL_IMAGE_CHANNELS, + prompt, + ModelUtils.VISION_MODEL_SEQ_LEN, + MainActivity.this); + } + } else { + mModule.generate(prompt, ModelUtils.TEXT_MODEL_SEQ_LEN, MainActivity.this); + } + long generateDuration = System.currentTimeMillis() - generateStartTime; mResultMessage.setTotalGenerationTime(generateDuration); runOnUiThread( diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java new file mode 100644 index 00000000000..ab1f1bc92fc --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java @@ -0,0 +1,28 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.example.executorchllamademo; + +public class ModelUtils { + static final int TEXT_MODEL = 1; + static final int VISION_MODEL = 2; + static final int VISION_MODEL_IMAGE_CHANNELS = 3; + static final int VISION_MODEL_SEQ_LEN = 768; + static final int TEXT_MODEL_SEQ_LEN = 256; + + public static int getModelCategory(ModelType modelType) { + switch (modelType) { + case LLAVA_1_5: + return VISION_MODEL; + case LLAMA_3: + case LLAMA_3_1: + default: + return TEXT_MODEL; + } + } +} diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java index 72990f4ea8b..a077f4d677f 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java @@ -21,6 +21,7 @@ public static String getSystemPromptTemplate(ModelType modelType) { + SYSTEM_PLACEHOLDER + "<|eot_id|>"; case LLAVA_1_5: + return "USER: "; default: return SYSTEM_PLACEHOLDER; } @@ -35,6 +36,7 @@ public static String getUserPromptTemplate(ModelType modelType) { + "<|eot_id|>\n" + "<|start_header_id|>assistant<|end_header_id|>"; case LLAVA_1_5: + return USER_PLACEHOLDER + " ASSISTANT:"; default: return USER_PLACEHOLDER; }