diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelType.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelType.java index a241ca3d52d..b1074ee2cc6 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelType.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelType.java @@ -11,6 +11,7 @@ public enum ModelType { LLAMA_3, LLAMA_3_1, + LLAMA_3_2, LLAVA_1_5, LLAMA_GUARD_3, } diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java index ab1f1bc92fc..28e14cdac01 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java @@ -21,6 +21,7 @@ public static int getModelCategory(ModelType modelType) { return VISION_MODEL; case LLAMA_3: case LLAMA_3_1: + case LLAMA_3_2: default: return TEXT_MODEL; } diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java index 14cf38e669d..1d794733d27 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java @@ -19,6 +19,7 @@ public static String getSystemPromptTemplate(ModelType modelType) { switch (modelType) { case LLAMA_3: case LLAMA_3_1: + case LLAMA_3_2: return "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n" + SYSTEM_PLACEHOLDER + "<|eot_id|>"; @@ -33,6 +34,7 @@ public static String getUserPromptTemplate(ModelType modelType) { switch (modelType) { case LLAMA_3: case LLAMA_3_1: + case LLAMA_3_2: case LLAMA_GUARD_3: return "<|start_header_id|>user<|end_header_id|>\n" + USER_PLACEHOLDER @@ -49,6 +51,7 @@ public static String getConversationFormat(ModelType modelType) { switch (modelType) { case LLAMA_3: case LLAMA_3_1: + case LLAMA_3_2: return getUserPromptTemplate(modelType) + "\n" + ASSISTANT_PLACEHOLDER + "<|eot_id|>"; case LLAVA_1_5: return USER_PLACEHOLDER + " ASSISTANT:"; @@ -61,6 +64,7 @@ public static String getStopToken(ModelType modelType) { switch (modelType) { case LLAMA_3: case LLAMA_3_1: + case LLAMA_3_2: case LLAMA_GUARD_3: return "<|eot_id|>"; case LLAVA_1_5: