diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java index e9f32a927cc..ac14270ed51 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java @@ -19,6 +19,7 @@ import android.os.Bundle; import android.os.Handler; import android.os.Looper; +import android.os.Process; import android.provider.MediaStore; import android.system.ErrnoException; import android.system.Os; @@ -44,6 +45,8 @@ import java.lang.reflect.Type; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; import org.pytorch.executorch.LlamaCallback; import org.pytorch.executorch.LlamaModule; @@ -71,15 +74,16 @@ public class MainActivity extends AppCompatActivity implements Runnable, LlamaCa private Handler mMemoryUpdateHandler; private Runnable memoryUpdater; private int promptID = 0; - + private long startPos = 0; private static final int CONVERSATION_HISTORY_MESSAGE_LOOKBACK = 2; + private Executor executor; @Override public void onResult(String result) { if (result.equals(PromptFormat.getStopToken(mCurrentSettingsFields.getModelType()))) { return; } - if (result.equals("\n\n")) { + if (result.equals("\n\n") || result.equals("\n")) { if (!mResultMessage.getText().isEmpty()) { mResultMessage.appendText(result); run(); @@ -150,6 +154,12 @@ private void setLocalModel(String modelPath, String tokenizerPath, float tempera + (float) loadDuration / 1000 + " sec." 
+ " You can send text or image for inference"; + + if (mCurrentSettingsFields.getModelType() == ModelType.LLAVA_1_5) { + ETLogging.getInstance().log("Llava start prefill prompt"); + startPos = mModule.prefillPrompt(PromptFormat.getLlavaPresetPrompt(), 0, 1, 0); + ETLogging.getInstance().log("Llava completes prefill prompt"); + } } Message modelLoadedMessage = new Message(modelInfo, false, MessageType.SYSTEM, 0); @@ -241,6 +251,7 @@ protected void onCreate(Bundle savedInstanceState) { setupCameraRoll(); startMemoryUpdate(); setupShowLogsButton(); + executor = Executors.newSingleThreadExecutor(); } @Override @@ -546,6 +557,32 @@ private void showMediaPreview(List uris) { imageViews.get(i).setVisibility(View.VISIBLE); imageViews.get(i).setImageURI(mSelectedImageUri.get(i)); } + + // For Llava, we want to call prefill_image as soon as an image is selected + // Llava only supports 1 image for now + if (mCurrentSettingsFields.getModelType() == ModelType.LLAVA_1_5) { + List processedImageList = getProcessedImagesForModel(mSelectedImageUri); + if (!processedImageList.isEmpty()) { + mMessageAdapter.add( + new Message("Llava - Starting image Prefill.", false, MessageType.SYSTEM, 0)); + mMessageAdapter.notifyDataSetChanged(); + Runnable runnable = + () -> { + Process.setThreadPriority(Process.THREAD_PRIORITY_MORE_FAVORABLE); + ETLogging.getInstance().log("Starting runnable prefill image"); + ETImage img = processedImageList.get(0); + ETLogging.getInstance().log("Llava start prefill image"); + startPos = + mModule.prefillImages( + img.getInts(), + img.getWidth(), + img.getHeight(), + ModelUtils.VISION_MODEL_IMAGE_CHANNELS, + startPos); + }; + executor.execute(runnable); + } + } } private void addSelectedImagesToChatThread(List selectedImageUri) { @@ -618,24 +655,6 @@ private void onModelRunStopped() { mSendButton.setOnClickListener( view -> { addSelectedImagesToChatThread(mSelectedImageUri); - List processedImageList = getProcessedImagesForModel(mSelectedImageUri); - 
processedImageList.forEach( - image -> { - ETLogging.getInstance() - .log( - "Image preprocessed:" - + " uri = " - + image.getUri().getLastPathSegment() - + "," - + " width = " - + image.getWidth() - + "," - + " height = " - + image.getHeight() - + "," - + " bytes size = " - + image.getBytes().length); - }); String rawPrompt = mEditTextMessage.getText().toString(); // We store raw prompt into message adapter, because we don't want to show the extra // tokens from system prompt @@ -654,6 +673,8 @@ private void onModelRunStopped() { new Runnable() { @Override public void run() { + Process.setThreadPriority(Process.THREAD_PRIORITY_MORE_FAVORABLE); + ETLogging.getInstance().log("starting runnable generate()"); runOnUiThread( new Runnable() { @Override @@ -664,31 +685,12 @@ public void run() { long generateStartTime = System.currentTimeMillis(); if (ModelUtils.getModelCategory(mCurrentSettingsFields.getModelType()) == ModelUtils.VISION_MODEL) { - ETLogging.getInstance().log("Running inference.. prompt=" + rawPrompt); - if (!processedImageList.isEmpty()) { - // For now, Llava only support 1 image. 
- ETImage img = processedImageList.get(0); - mModule.generate( - processedImageList.get(0).getInts(), - img.getWidth(), - img.getHeight(), - ModelUtils.VISION_MODEL_IMAGE_CHANNELS, - rawPrompt, - ModelUtils.VISION_MODEL_SEQ_LEN, - MainActivity.this, - false); - } else { - // no image selected, we pass in empty int array - mModule.generate( - new int[0], - 0, - 0, - ModelUtils.VISION_MODEL_IMAGE_CHANNELS, - rawPrompt, - ModelUtils.VISION_MODEL_SEQ_LEN, - MainActivity.this, - false); - } + mModule.generateFromPos( + mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt), + ModelUtils.VISION_MODEL_SEQ_LEN, + startPos, + MainActivity.this, + false); } else { String finalPrompt = getTotalFormattedPrompt(getConversationHistory(), rawPrompt); @@ -712,7 +714,7 @@ public void run() { ETLogging.getInstance().log("Inference completed"); } }; - new Thread(runnable).start(); + executor.execute(runnable); }); mMessageAdapter.notifyDataSetChanged(); } diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java index 4b450553236..640d3782128 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java @@ -66,4 +66,9 @@ public static String getStopToken(ModelType modelType) { return ""; } } + + public static String getLlavaPresetPrompt() { + return "A chat between a curious human and an artificial intelligence assistant. The assistant" + + " gives helpful, detailed, and polite answers to the human's questions. USER: "; + } }