diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java index 7ed9c9ec979..308f5fac50a 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java @@ -70,6 +70,9 @@ public class MainActivity extends AppCompatActivity implements Runnable, LlamaCa private SettingsFields mCurrentSettingsFields; private Handler mMemoryUpdateHandler; private Runnable memoryUpdater; + private int promptID = 0; + + private static final int CONVERSATION_HISTORY_MESSAGE_LOOKBACK = 2; @Override public void onResult(String result) { @@ -195,6 +198,11 @@ private void populateExistingMessages(String existingMsgJSON) { mMessageAdapter.notifyDataSetChanged(); } + private int setPromptID() { + + return mMessageAdapter.getMaxPromptID() + 1; + } + @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); @@ -216,6 +224,7 @@ protected void onCreate(Bundle savedInstanceState) { String existingMsgJSON = mDemoSharedPreferences.getSavedMessages(); if (!existingMsgJSON.isEmpty()) { populateExistingMessages(existingMsgJSON); + promptID = setPromptID(); } mSettingsButton = requireViewById(R.id.settings); mSettingsButton.setOnClickListener( @@ -552,6 +561,48 @@ private void addSelectedImagesToChatThread(List selectedImageUri) { mMessageAdapter.notifyDataSetChanged(); } + private String getConversationHistory() { + String conversationHistory = ""; + + ArrayList conversations = + mMessageAdapter.getRecentSavedTextMessages(CONVERSATION_HISTORY_MESSAGE_LOOKBACK); + if (conversations.isEmpty()) { + return conversationHistory; + } + + int prevPromptID = conversations.get(0).getPromptID(); + String conversationFormat = + PromptFormat.getConversationFormat(mCurrentSettingsFields.getModelType()); + String format = conversationFormat; + for (int i = 0; i < conversations.size(); i++) { + Message conversation = conversations.get(i); + int currentPromptID = conversation.getPromptID(); + if (currentPromptID != prevPromptID) { + conversationHistory = conversationHistory + format; + format = conversationFormat; + prevPromptID = currentPromptID; + } + if (conversation.getIsSent()) { + format = format.replace(PromptFormat.USER_PLACEHOLDER, conversation.getText()); + } else { + format = format.replace(PromptFormat.ASSISTANT_PLACEHOLDER, conversation.getText()); + } + } + conversationHistory = conversationHistory + format; + + return conversationHistory; + } + + private String getTotalFormattedPrompt(String conversationHistory, String rawPrompt) { + if (conversationHistory.isEmpty()) { + return mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt); + } + + return mCurrentSettingsFields.getFormattedSystemPrompt() + + conversationHistory + + mCurrentSettingsFields.getFormattedUserPrompt(rawPrompt); + } + private void onModelRunStarted() { mSendButton.setClickable(false); mSendButton.setImageResource(R.drawable.baseline_stop_24); @@ -586,19 +637,19 @@ private void onModelRunStopped() { + image.getBytes().length); }); String rawPrompt = mEditTextMessage.getText().toString(); - String prompt = mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt); // We store raw prompt into message adapter, because we don't want to show the extra // tokens from system prompt - mMessageAdapter.add(new Message(rawPrompt, true, MessageType.TEXT, 0)); + mMessageAdapter.add(new Message(rawPrompt, true, MessageType.TEXT, promptID)); mMessageAdapter.notifyDataSetChanged(); mEditTextMessage.setText(""); - mResultMessage = new Message("", false, MessageType.TEXT, 0); + mResultMessage = new Message("", false, MessageType.TEXT, promptID); mMessageAdapter.add(mResultMessage); // Scroll to bottom of the list mMessagesView.smoothScrollToPosition(mMessageAdapter.getCount() - 1); // After images are added to prompt and chat thread, we clear the imageURI list // Note: This has to be done after imageURIs are no longer needed by LlamaModule mSelectedImageUri = null; + promptID++; Runnable runnable = new Runnable() { @Override @@ -610,10 +661,10 @@ public void run() { onModelRunStarted(); } }); - ETLogging.getInstance().log("Running inference.. prompt=" + prompt); long generateStartTime = System.currentTimeMillis(); if (ModelUtils.getModelCategory(mCurrentSettingsFields.getModelType()) == ModelUtils.VISION_MODEL) { + ETLogging.getInstance().log("Running inference.. prompt=" + rawPrompt); if (!processedImageList.isEmpty()) { // For now, Llava only support 1 image. ETImage img = processedImageList.get(0); @@ -622,7 +673,7 @@ public void run() { img.getWidth(), img.getHeight(), ModelUtils.VISION_MODEL_IMAGE_CHANNELS, - prompt, + rawPrompt, ModelUtils.VISION_MODEL_SEQ_LEN, false, MainActivity.this); @@ -633,14 +684,20 @@ public void run() { 0, 0, ModelUtils.VISION_MODEL_IMAGE_CHANNELS, - prompt, + rawPrompt, ModelUtils.VISION_MODEL_SEQ_LEN, false, MainActivity.this); } } else { + String finalPrompt = + getTotalFormattedPrompt(getConversationHistory(), rawPrompt); + ETLogging.getInstance().log("Running inference.. prompt=" + finalPrompt); mModule.generate( - prompt, ModelUtils.TEXT_MODEL_SEQ_LEN, false, MainActivity.this); + finalPrompt, + (int) (finalPrompt.length() * 0.75) + 64, + false, + MainActivity.this); } long generateDuration = System.currentTimeMillis() - generateStartTime; diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageAdapter.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageAdapter.java index d9cbd95a1a7..2538c852e48 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageAdapter.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageAdapter.java @@ -16,6 +16,7 @@ import android.widget.ImageView; import android.widget.TextView; import java.util.ArrayList; +import java.util.Collections; public class MessageAdapter extends ArrayAdapter { @@ -90,4 +91,41 @@ public void clear() { public ArrayList getSavedMessages() { return savedMessages; } + + public ArrayList getRecentSavedTextMessages(int numOfLatestPromptMessages) { + ArrayList recentMessages = new ArrayList(); + int lastIndex = savedMessages.size() - 1; + Message messageToAdd = savedMessages.get(lastIndex); + int oldPromptID = messageToAdd.getPromptID(); + + for (int i = 0; i < savedMessages.size(); i++) { + messageToAdd = savedMessages.get(lastIndex - i); + if (messageToAdd.getMessageType() != MessageType.SYSTEM) { + if (messageToAdd.getPromptID() != oldPromptID) { + numOfLatestPromptMessages--; + oldPromptID = messageToAdd.getPromptID(); + } + if (numOfLatestPromptMessages > 0) { + if (messageToAdd.getMessageType() == MessageType.TEXT) { + recentMessages.add(messageToAdd); + } + } else { + break; + } + } + } + + // To place the order in [input1, output1, input2, output2...] + Collections.reverse(recentMessages); + return recentMessages; + } + + public int getMaxPromptID() { + int maxPromptID = -1; + for (Message msg : savedMessages) { + + maxPromptID = Math.max(msg.getPromptID(), maxPromptID); + } + return maxPromptID; + } } diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java index 7342b4ab00c..4b450553236 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java @@ -12,6 +12,7 @@ public class PromptFormat { public static final String SYSTEM_PLACEHOLDER = "{{ system_prompt }}"; public static final String USER_PLACEHOLDER = "{{ user_prompt }}"; + public static final String ASSISTANT_PLACEHOLDER = "{{ assistant_response }}"; public static String getSystemPromptTemplate(ModelType modelType) { switch (modelType) { @@ -33,8 +34,20 @@ public static String getUserPromptTemplate(ModelType modelType) { case LLAMA_3_1: return "<|start_header_id|>user<|end_header_id|>\n" + USER_PLACEHOLDER - + "<|eot_id|>\n" + + "<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>"; + + case LLAVA_1_5: + default: + return USER_PLACEHOLDER; + } + } + + public static String getConversationFormat(ModelType modelType) { + switch (modelType) { + case LLAMA_3: + case LLAMA_3_1: + return getUserPromptTemplate(modelType) + "\n" + ASSISTANT_PLACEHOLDER + "<|eot_id|>"; case LLAVA_1_5: return USER_PLACEHOLDER + " ASSISTANT:"; default: diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsFields.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsFields.java index 466d3303e28..b71799981b2 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsFields.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsFields.java @@ -38,12 +38,12 @@ public String getFormattedSystemAndUserPrompt(String prompt) { return getFormattedSystemPrompt() + getFormattedUserPrompt(prompt); } - private String getFormattedSystemPrompt() { + public String getFormattedSystemPrompt() { return PromptFormat.getSystemPromptTemplate(modelType) .replace(PromptFormat.SYSTEM_PLACEHOLDER, systemPrompt); } - private String getFormattedUserPrompt(String prompt) { + public String getFormattedUserPrompt(String prompt) { return userPrompt.replace(PromptFormat.USER_PLACEHOLDER, prompt); }