diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ETImage.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ETImage.java index cf3c3e5f0a5..e68c8472626 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ETImage.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ETImage.java @@ -46,6 +46,16 @@ public byte[] getBytes() { return bytes; } + public int[] getInts() { + // We need to convert the byte array to an int array because + // the runner expects an int array as input. + int[] intArray = new int[bytes.length]; + for (int i = 0; i < bytes.length; i++) { + intArray[i] = (bytes[i++] & 0xFF); + } + return intArray; + } + private byte[] getBytesFromImageURI(Uri uri) { try { int RESIZED_IMAGE_WIDTH = 336; @@ -72,9 +82,9 @@ private byte[] getBytesFromImageURI(Uri uri) { int blue = Color.blue(color); // Store the RGB values in the byte array - rgbValues[(y * width + x) * 3] = (byte) red; - rgbValues[(y * width + x) * 3 + 1] = (byte) green; - rgbValues[(y * width + x) * 3 + 2] = (byte) blue; + rgbValues[y * width + x] = (byte) red; + rgbValues[(y * width + x) + height * width] = (byte) green; + rgbValues[(y * width + x) + 2 * height * width] = (byte) blue; } } return rgbValues; diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java index 70936e17d84..f24254efb31 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java @@ -102,7 +102,12 @@ private void setLocalModel(String modelPath, String tokenizerPath, float tempera mMessageAdapter.notifyDataSetChanged(); }); long runStartTime = System.currentTimeMillis(); - mModule = new LlamaModule(modelPath, tokenizerPath, temperature); + mModule = + new LlamaModule( + ModelUtils.getModelCategory(mCurrentSettingsFields.getModelType()), + modelPath, + tokenizerPath, + temperature); int loadResult = mModule.load(); long loadDuration = System.currentTimeMillis() - runStartTime; String modelLoadError = ""; @@ -552,8 +557,6 @@ private void onModelRunStopped() { mSendButton.setOnClickListener( view -> { addSelectedImagesToChatThread(mSelectedImageUri); - // TODO: When ET supports multimodal, this is where we will add the images as part of the - // prompt. List processedImageList = getProcessedImagesForModel(mSelectedImageUri); processedImageList.forEach( image -> { @@ -599,7 +602,34 @@ public void run() { }); ETLogging.getInstance().log("Running inference.. prompt=" + prompt); long generateStartTime = System.currentTimeMillis(); - mModule.generate(prompt, MainActivity.this); + if (ModelUtils.getModelCategory(mCurrentSettingsFields.getModelType()) + == ModelUtils.VISION_MODEL) { + if (!processedImageList.isEmpty()) { + // For now, Llava only support 1 image. + ETImage img = processedImageList.get(0); + mModule.generate( + processedImageList.get(0).getInts(), + img.getWidth(), + img.getHeight(), + ModelUtils.VISION_MODEL_IMAGE_CHANNELS, + prompt, + ModelUtils.VISION_MODEL_SEQ_LEN, + MainActivity.this); + } else { + // no image selected, we pass in empty int array + mModule.generate( + new int[0], + 0, + 0, + ModelUtils.VISION_MODEL_IMAGE_CHANNELS, + prompt, + ModelUtils.VISION_MODEL_SEQ_LEN, + MainActivity.this); + } + } else { + mModule.generate(prompt, ModelUtils.TEXT_MODEL_SEQ_LEN, MainActivity.this); + } + long generateDuration = System.currentTimeMillis() - generateStartTime; mResultMessage.setTotalGenerationTime(generateDuration); runOnUiThread( diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java new file mode 100644 index 00000000000..ab1f1bc92fc --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java @@ -0,0 +1,28 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.example.executorchllamademo; + +public class ModelUtils { + static final int TEXT_MODEL = 1; + static final int VISION_MODEL = 2; + static final int VISION_MODEL_IMAGE_CHANNELS = 3; + static final int VISION_MODEL_SEQ_LEN = 768; + static final int TEXT_MODEL_SEQ_LEN = 256; + + public static int getModelCategory(ModelType modelType) { + switch (modelType) { + case LLAVA_1_5: + return VISION_MODEL; + case LLAMA_3: + case LLAMA_3_1: + default: + return TEXT_MODEL; + } + } +} diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java index 72990f4ea8b..a077f4d677f 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java @@ -21,6 +21,7 @@ public static String getSystemPromptTemplate(ModelType modelType) { + SYSTEM_PLACEHOLDER + "<|eot_id|>"; case LLAVA_1_5: + return "USER: "; default: return SYSTEM_PLACEHOLDER; } @@ -35,6 +36,7 @@ public static String getUserPromptTemplate(ModelType modelType) { + "<|eot_id|>\n" + "<|start_header_id|>assistant<|end_header_id|>"; case LLAVA_1_5: + return USER_PLACEHOLDER + " ASSISTANT:"; default: return USER_PLACEHOLDER; }