From 793aaeb2f101c2519efa025aba5cb7b82a8079d5 Mon Sep 17 00:00:00 2001 From: Chirag Modi Date: Mon, 23 Sep 2024 11:50:58 -0700 Subject: [PATCH] Fix duplicating latest prompt (#5546) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/5546 The last prompt sent would be included in `getConversationHistory()` + adding it prior to sending it with the generate(). It looks like this got moved during the rebasing. To fix this we now call `getConversationHistory()` prior to adding the rawPrompt to a Message. In regards to model response, I noticed that it did not really change the quality of the response. (tested with Llama 3.1) Reviewed By: Riandy Differential Revision: D62761977 --- .../com/example/executorchllamademo/MainActivity.java | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java index f5e50845eca..4d81ec8ae52 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java @@ -661,7 +661,14 @@ private void onModelRunStopped() { mSendButton.setOnClickListener( view -> { addSelectedImagesToChatThread(mSelectedImageUri); + String finalPrompt; String rawPrompt = mEditTextMessage.getText().toString(); + if (ModelUtils.getModelCategory(mCurrentSettingsFields.getModelType()) + == ModelUtils.VISION_MODEL) { + finalPrompt = mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt); + } else { + finalPrompt = getTotalFormattedPrompt(getConversationHistory(), rawPrompt); + } // We store raw prompt into message adapter, because we don't want to show the extra // tokens from system prompt mMessageAdapter.add(new Message(rawPrompt, true, 
MessageType.TEXT, promptID)); @@ -692,14 +699,12 @@ public void run() { if (ModelUtils.getModelCategory(mCurrentSettingsFields.getModelType()) == ModelUtils.VISION_MODEL) { mModule.generateFromPos( - mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt), + finalPrompt, ModelUtils.VISION_MODEL_SEQ_LEN, startPos, MainActivity.this, false); } else { - String finalPrompt = - getTotalFormattedPrompt(getConversationHistory(), rawPrompt); ETLogging.getInstance().log("Running inference.. prompt=" + finalPrompt); mModule.generate( finalPrompt,