From 77524172a1ac86cd61d2018eb6580754558eddc3 Mon Sep 17 00:00:00 2001 From: Riandy Riandy Date: Wed, 25 Sep 2024 16:49:43 -0700 Subject: [PATCH] Add llama 3.2 model type on Android (#5646) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/5646 Adding 3.2 model type on android app Reviewed By: kirklandsign Differential Revision: D63404583 fbshipit-source-id: 57c79b120b11cd6a814cf56107c76165fdf04f91 (cherry picked from commit dacbba7e14e603d0eb909076fcad7614f0487b00) --- .../main/java/com/example/executorchllamademo/ModelType.java | 1 + .../main/java/com/example/executorchllamademo/ModelUtils.java | 1 + .../java/com/example/executorchllamademo/PromptFormat.java | 4 ++++ 3 files changed, 6 insertions(+) diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelType.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelType.java index a241ca3d52d..b1074ee2cc6 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelType.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelType.java @@ -11,6 +11,7 @@ public enum ModelType { LLAMA_3, LLAMA_3_1, + LLAMA_3_2, LLAVA_1_5, LLAMA_GUARD_3, } diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java index ab1f1bc92fc..28e14cdac01 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java @@ -21,6 +21,7 @@ public static int getModelCategory(ModelType modelType) { return VISION_MODEL; case LLAMA_3: case LLAMA_3_1: + case LLAMA_3_2: default: return TEXT_MODEL; } diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java index 14cf38e669d..1d794733d27 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java @@ -19,6 +19,7 @@ public static String getSystemPromptTemplate(ModelType modelType) { switch (modelType) { case LLAMA_3: case LLAMA_3_1: + case LLAMA_3_2: return "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n" + SYSTEM_PLACEHOLDER + "<|eot_id|>"; @@ -33,6 +34,7 @@ public static String getUserPromptTemplate(ModelType modelType) { switch (modelType) { case LLAMA_3: case LLAMA_3_1: + case LLAMA_3_2: case LLAMA_GUARD_3: return "<|start_header_id|>user<|end_header_id|>\n" + USER_PLACEHOLDER @@ -49,6 +51,7 @@ public static String getConversationFormat(ModelType modelType) { switch (modelType) { case LLAMA_3: case LLAMA_3_1: + case LLAMA_3_2: return getUserPromptTemplate(modelType) + "\n" + ASSISTANT_PLACEHOLDER + "<|eot_id|>"; case LLAVA_1_5: return USER_PLACEHOLDER + " ASSISTANT:"; @@ -61,6 +64,7 @@ public static String getStopToken(ModelType modelType) { switch (modelType) { case LLAMA_3: case LLAMA_3_1: + case LLAMA_3_2: case LLAMA_GUARD_3: return "<|eot_id|>"; case LLAVA_1_5: