From 9e57b34e6ccac0ecb4906a2df6253eca7222ba41 Mon Sep 17 00:00:00 2001
From: raulraja
Date: Sat, 5 Aug 2023 01:00:54 +0200
Subject: [PATCH 1/9] Expression Language for LLM driven template replacements

---
 core/build.gradle.kts                         | 33 +++++--
 .../com/xebia/functional/xef/llm/Chat.kt      | 41 +++++---
 .../functional/xef/llm/ChatWithFunctions.kt   | 24 +++++
 .../xef/prompt/expressions/Expression.kt      | 98 +++++++++++++++++++
 .../prompt/expressions/ExpressionResult.kt    |  9 ++
 .../xef/prompt/expressions/ReplacedValues.kt  | 10 ++
 .../xef/prompt/expressions/Replacement.kt     | 12 +++
 .../jdk21/reasoning/ToolSelectionExample.java | 42 --------
 .../xef/java/auto/jdk21/tot/Problems.java     |  2 +-
 .../jdk8/reasoning/ToolSelectionExample.java  | 42 --------
 .../auto/expressions/WorkoutPlanProgram.kt    | 55 +++++++++++
 gradle/libs.versions.toml                     |  4 +-
 reasoning/build.gradle.kts                    |  7 --
 .../xef/reasoning/tools/ToolSelection.kt      |  4 -
 14 files changed, 263 insertions(+), 120 deletions(-)
 create mode 100644 core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/Expression.kt
 create mode 100644 core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/ExpressionResult.kt
 create mode 100644 core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/ReplacedValues.kt
 create mode 100644 core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/Replacement.kt
 delete mode 100644 examples/java/src/main/java/com/xebia/functional/xef/java/auto/jdk21/reasoning/ToolSelectionExample.java
 delete mode 100644 examples/java/src/main/java/com/xebia/functional/xef/java/auto/jdk8/reasoning/ToolSelectionExample.java
 create mode 100644 examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/expressions/WorkoutPlanProgram.kt

diff --git a/core/build.gradle.kts b/core/build.gradle.kts
index a1bf05ac9..924d68bfa 100644
--- a/core/build.gradle.kts
+++ b/core/build.gradle.kts
@@ -67,7 +67,7 @@ kotlin {
         api(libs.kotlinx.serialization.json)
         api(libs.ktor.utils)
         api(projects.xefTokenizer)
-
+        implementation(libs.bundles.ktor.client)
         implementation(libs.klogging)
         implementation(libs.uuid)
       }
@@ -87,10 +87,15 @@ kotlin {
         implementation(libs.logback)
         implementation(libs.skrape)
         implementation(libs.rss.reader)
+        api(libs.ktor.client.cio)
       }
     }

-    val jsMain by getting
+    val jsMain by getting {
+      dependencies {
+        api(libs.ktor.client.js)
+      }
+    }

     val jvmTest by getting {
       dependencies {
@@ -98,10 +103,26 @@ kotlin {
       }
     }

-    val linuxX64Main by getting
-    val macosX64Main by getting
-    val macosArm64Main by getting
-    val mingwX64Main by getting
+    val linuxX64Main by getting {
+      dependencies {
+        implementation(libs.ktor.client.cio)
+      }
+    }
+    val macosX64Main by getting {
+      dependencies {
+        implementation(libs.ktor.client.cio)
+      }
+    }
+    val macosArm64Main by getting {
+      dependencies {
+        implementation(libs.ktor.client.cio)
+      }
+    }
+    val mingwX64Main by getting {
+      dependencies {
+        implementation(libs.ktor.client.winhttp)
+      }
+    }

     val linuxX64Test by getting
     val macosX64Test by getting
     val macosArm64Test by getting
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt
index aa84df05f..5347fcaac 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt
@@ -138,26 +138,12 @@ interface Chat : LLM {

   @AiDsl
   suspend fun promptMessages(
-    prompt: Prompt,
+    messages: List<Message>,
     context: VectorStore,
     conversationId: ConversationId? = null,
    functions: List<CFunction> = emptyList(),
    promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS
  ): List<String> {
-
-    val memories: List<Memory> = memories(conversationId, context, promptConfiguration)
-
-    val promptWithContext: String =
-      createPromptWithContextAwareOfTokens(
-        memories = memories,
-        ctxInfo = context.similaritySearch(prompt.message, promptConfiguration.docsInContext),
-        modelType = modelType,
-        prompt = prompt.message,
-        minResponseTokens = promptConfiguration.minResponseTokens
-      )
-
-    val messages: List<Message> = messages(memories, promptWithContext)
-
     fun checkTotalLeftChatTokens(): Int {
       val maxContextLength: Int = modelType.maxContextLength
       val messagesTokens: Int = tokensFromMessages(messages)
       val totalLeftTokens: Int = maxContextLength - messagesTokens
       if (totalLeftTokens < 0) {
         throw AIError.MessagesExceedMaxTokenLength(messages, messagesTokens, maxContextLength)
       }
       return totalLeftTokens
     }
@@ -217,6 +203,31 @@ interface Chat : LLM {
     }
   }

+  @AiDsl
+  suspend fun promptMessages(
+    prompt: Prompt,
+    context: VectorStore,
+    conversationId: ConversationId? = null,
+    functions: List<CFunction> = emptyList(),
+    promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS
+  ): List<String> {
+
+    val memories: List<Memory> = memories(conversationId, context, promptConfiguration)
+
+    val promptWithContext: String =
+      createPromptWithContextAwareOfTokens(
+        memories = memories,
+        ctxInfo = context.similaritySearch(prompt.message, promptConfiguration.docsInContext),
+        modelType = modelType,
+        prompt = prompt.message,
+        minResponseTokens = promptConfiguration.minResponseTokens
+      )
+
+    val messages: List<Message> = messages(memories, promptWithContext)
+
+    return promptMessages(messages, context, conversationId, functions, promptConfiguration)
+  }
+
   private suspend fun List<ChoiceWithFunctions>.addChoiceWithFunctionsToMemory(
     request: ChatCompletionRequestWithFunctions,
     context: VectorStore,
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt
index 160a42b6e..7284d053f 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt
@@ -7,6 +7,7 @@ import com.xebia.functional.xef.auto.AiDsl
 import com.xebia.functional.xef.auto.PromptConfiguration
 import com.xebia.functional.xef.llm.models.chat.ChatCompletionRequestWithFunctions
 import com.xebia.functional.xef.llm.models.chat.ChatCompletionResponseWithFunctions
+import com.xebia.functional.xef.llm.models.chat.Message
 import com.xebia.functional.xef.llm.models.functions.CFunction
 import com.xebia.functional.xef.llm.models.functions.encodeJsonSchema
 import com.xebia.functional.xef.prompt.Prompt
@@ -45,6 +46,29 @@ interface ChatWithFunctions : Chat {
     promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS,
   ): A = prompt(prompt, context, conversationId, functions, serializer, promptConfiguration)

+  @AiDsl
+  suspend fun <A> prompt(
+    messages: List<Message>,
+    context: VectorStore,
+    serializer: KSerializer<A>,
+    conversationId: ConversationId? = null,
+    functions: List<CFunction> = generateCFunction(serializer.descriptor),
+    promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS,
+  ): A {
+    return tryDeserialize(
+      { json -> Json.decodeFromString(serializer, json) },
+      promptConfiguration.maxDeserializationAttempts
+    ) {
+      promptMessages(
+        messages = messages,
+        context = context,
+        conversationId = conversationId,
+        functions = functions,
+        promptConfiguration
+      )
+    }
+  }
+
   @AiDsl
   suspend fun <A> prompt(
     prompt: Prompt,
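
[Annotation] A minimal sketch of how the new message-based `prompt` overload above might be called; `model`, `store`, and `Recipe` are illustrative names, not part of this patch:

    @Serializable data class Recipe(val title: String, val steps: List<String>)

    suspend fun example(model: ChatWithFunctions, store: VectorStore): Recipe =
      model.prompt(
        messages = listOf(Message(Role.USER, "Invent a pancake recipe", Role.USER.name)),
        context = store,
        serializer = Recipe.serializer(),
      )
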
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/Expression.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/Expression.kt
new file mode 100644
index 000000000..e6263a927
--- /dev/null
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/Expression.kt
@@ -0,0 +1,98 @@
+package com.xebia.functional.xef.prompt.expressions
+
+import com.xebia.functional.xef.auto.CoreAIScope
+import com.xebia.functional.xef.auto.PromptConfiguration
+import com.xebia.functional.xef.llm.ChatWithFunctions
+import com.xebia.functional.xef.llm.models.chat.Message
+import com.xebia.functional.xef.llm.models.chat.Role
+import com.xebia.functional.xef.prompt.experts.ExpertSystem
+import io.github.oshai.kotlinlogging.KLogger
+import io.github.oshai.kotlinlogging.KotlinLogging
+
+class Expression(
+  private val scope: CoreAIScope,
+  private val model: ChatWithFunctions,
+  val block: suspend Expression.() -> Unit
+) {
+
+  private val logger: KLogger = KotlinLogging.logger {}
+
+  private val messages: MutableList<Message> = mutableListOf()
+
+  private val generationKeys: MutableList<String> = mutableListOf()
+
+  suspend fun system(message: suspend () -> String) {
+    messages.add(Message(role = Role.SYSTEM, content = message(), name = Role.SYSTEM.name))
+  }
+
+  suspend fun user(message: suspend () -> String) {
+    messages.add(Message(role = Role.USER, content = message(), name = Role.USER.name))
+  }
+
+  suspend fun assistant(message: suspend () -> String) {
+    messages.add(Message(role = Role.ASSISTANT, content = message(), name = Role.ASSISTANT.name))
+  }
+
+  fun prompt(key: String): String {
+    generationKeys.add(key)
+    return "{{$key}}"
+  }
+
+  suspend fun run(
+    promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS
+  ): ExpressionResult {
+    block()
+    val instructionMessage =
+      Message(
+        role = Role.USER,
+        content =
+          ExpertSystem(
+              system = "You are an expert in replacing variables in templates",
+              query =
+                """
+                |I want to replace the following variables in the following template:
+                |
+                |The variables are:
+                |${generationKeys.joinToString("\n") { it }}
+                """
+                  .trimMargin(),
+              instructions =
+                listOf(
+                  "Create a `ReplaceKeys` object `replacements` property of type `Map<String, String>` where the keys are the variable names and the values are the values to replace them with.",
+                )
+            )
+            .message,
+        name = Role.USER.name
+      )
+    val values: ReplacedValues =
+      model.prompt(
+        messages = messages + instructionMessage,
+        context = scope.context,
+        serializer = ReplacedValues.serializer(),
+        conversationId = scope.conversationId,
+        promptConfiguration = promptConfiguration
+      )
+    val replacedTemplate =
+      messages.fold("") { acc, message ->
+        val replacedMessage =
+          generationKeys.fold(message.content) { acc, key ->
+            acc.replace(
+              "{{$key}}",
+              values.replacements.firstOrNull { it.key == key }?.value ?: "{{$key}}"
+            )
+          }
+        acc + replacedMessage + "\n"
+      }
+    return ExpressionResult(messages = messages, result = replacedTemplate, values = values)
+  }
+
+  companion object {
+    suspend fun run(
+      scope: CoreAIScope,
+      model: ChatWithFunctions,
+      block: suspend Expression.() -> Unit
+    ): ExpressionResult = Expression(scope, model, block).run()
+  }
+}
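
[Annotation] A minimal usage sketch of the `Expression` DSL added above, assuming an existing `scope` and function-calling `model` (names illustrative, not part of this patch):

    val result: ExpressionResult =
      Expression.run(scope, model) {
        system { "You are a haiku assistant" }
        user { "Write a haiku about the sea" }
        assistant { "Here is your haiku: ${prompt("haiku")}" }
      }
    // result.result is the rendered template with {{haiku}} replaced by the generated value.
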
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/ExpressionResult.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/ExpressionResult.kt
new file mode 100644
index 000000000..cc8dbad2f
--- /dev/null
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/ExpressionResult.kt
@@ -0,0 +1,9 @@
+package com.xebia.functional.xef.prompt.expressions
+
+import com.xebia.functional.xef.llm.models.chat.Message
+
+data class ExpressionResult(
+  val messages: List<Message>,
+  val result: String,
+  val values: ReplacedValues,
+)
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/ReplacedValues.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/ReplacedValues.kt
new file mode 100644
index 000000000..9a73ac3b5
--- /dev/null
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/ReplacedValues.kt
@@ -0,0 +1,10 @@
+package com.xebia.functional.xef.prompt.expressions
+
+import com.xebia.functional.xef.auto.Description
+import kotlinx.serialization.Serializable
+
+@Serializable
+data class ReplacedValues(
+  @Description(["The values that are generated for the template"])
+  val replacements: List<Replacement>
+)
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/Replacement.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/Replacement.kt
new file mode 100644
index 000000000..c341a1e95
--- /dev/null
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/Replacement.kt
@@ -0,0 +1,12 @@
+package com.xebia.functional.xef.prompt.expressions
+
+import com.xebia.functional.xef.auto.Description
+import kotlinx.serialization.Serializable
+
+@Serializable
+data class Replacement(
+  @Description(["The key originally in {{key}} format that was going to get replaced"])
+  val key: String,
+  @Description(["The Assistant generated value that the `key` should be replaced with"])
+  val value: String
+)
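
[Annotation] For the template above, the model is expected to answer with a `ReplacedValues` payload; an illustrative value (not captured from a real run):

    ReplacedValues(
      replacements = listOf(Replacement(key = "haiku", value = "Waves fold into foam..."))
    )
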
diff --git a/examples/java/src/main/java/com/xebia/functional/xef/java/auto/jdk21/reasoning/ToolSelectionExample.java b/examples/java/src/main/java/com/xebia/functional/xef/java/auto/jdk21/reasoning/ToolSelectionExample.java
deleted file mode 100644
index 2899b31ca..000000000
--- a/examples/java/src/main/java/com/xebia/functional/xef/java/auto/jdk21/reasoning/ToolSelectionExample.java
+++ /dev/null
@@ -1,42 +0,0 @@
-package com.xebia.functional.xef.java.auto.jdk21.reasoning;
-
-import com.xebia.functional.xef.auto.llm.openai.OpenAI;
-import com.xebia.functional.xef.java.auto.AIScope;
-import com.xebia.functional.xef.java.auto.ExecutionContext;
-import com.xebia.functional.xef.reasoning.filesystem.Files;
-import com.xebia.functional.xef.reasoning.pdf.PDF;
-import com.xebia.functional.xef.reasoning.text.Text;
-import com.xebia.functional.xef.reasoning.tools.ToolSelection;
-import java.util.Collections;
-import java.util.List;
-import java.util.concurrent.Executors;
-
-public class ToolSelectionExample {
-
-    public static void main(String[] args) {
-        try (var scope = new AIScope(new ExecutionContext(Executors.newVirtualThreadPerTaskExecutor()))) {
-            var model = OpenAI.DEFAULT_CHAT;
-            var serialization = OpenAI.DEFAULT_SERIALIZATION;
-            var text = Text.create(model, scope.getScope());
-            var files = Files.create(serialization, scope.getScope(), Collections.emptyList());
-            var pdf = PDF.create(model, serialization, scope.getScope());
-
-            var toolSelection = new ToolSelection(
-                    serialization,
-                    scope.getScope(),
-                    List.of(
-                            text.summarize,
-                            pdf.readPDFFromUrl,
-                            files.readFile,
-                            files.writeToTextFile
-                    ),
-                    Collections.emptyList()
-            );
-
-            var inputText = "Extract information from https://arxiv.org/pdf/2305.10601.pdf";
-            var result = toolSelection.applyInferredToolsBlocking(inputText);
-            System.out.println(result);
-        }
-    }
-}
-
diff --git a/examples/java/src/main/java/com/xebia/functional/xef/java/auto/jdk21/tot/Problems.java b/examples/java/src/main/java/com/xebia/functional/xef/java/auto/jdk21/tot/Problems.java
index 370a3b489..44ca246e9 100644
--- a/examples/java/src/main/java/com/xebia/functional/xef/java/auto/jdk21/tot/Problems.java
+++ b/examples/java/src/main/java/com/xebia/functional/xef/java/auto/jdk21/tot/Problems.java
@@ -98,7 +98,7 @@ public Memory addResult(Solutions.Solution result) {

     private static void checkAIScope() {
         if(aiScope == null){
-            aiScope = new AIScope(new ExecutionContext(Executors.newVirtualThreadPerTaskExecutor()));
+            aiScope = new AIScope(new ExecutionContext(Executors.newSingleThreadExecutor()));
         }
     }

diff --git a/examples/java/src/main/java/com/xebia/functional/xef/java/auto/jdk8/reasoning/ToolSelectionExample.java b/examples/java/src/main/java/com/xebia/functional/xef/java/auto/jdk8/reasoning/ToolSelectionExample.java
deleted file mode 100644
index 4db9edce6..000000000
--- a/examples/java/src/main/java/com/xebia/functional/xef/java/auto/jdk8/reasoning/ToolSelectionExample.java
+++ /dev/null
@@ -1,42 +0,0 @@
-package com.xebia.functional.xef.java.auto.jdk8.reasoning;
-
-import com.xebia.functional.xef.auto.CoreAIScope;
-import com.xebia.functional.xef.auto.llm.openai.OpenAI;
-import com.xebia.functional.xef.auto.llm.openai.OpenAIEmbeddings;
-import com.xebia.functional.xef.auto.llm.openai.OpenAIModel;
-import com.xebia.functional.xef.reasoning.filesystem.Files;
-import com.xebia.functional.xef.reasoning.pdf.PDF;
-import com.xebia.functional.xef.reasoning.text.Text;
-import com.xebia.functional.xef.reasoning.tools.ToolSelection;
-import java.util.Collections;
-import java.util.List;
-
-public class ToolSelectionExample {
-
-    public static void main(String[] args) {
-        try (CoreAIScope scope = new CoreAIScope(new OpenAIEmbeddings(OpenAI.DEFAULT_EMBEDDING))) {
-            OpenAIModel model = OpenAI.DEFAULT_CHAT;
-            OpenAIModel serialization = OpenAI.DEFAULT_SERIALIZATION;
-            Text text = Text.create(model, scope);
-            Files files = Files.create(serialization, scope, Collections.emptyList());
-            PDF pdf = PDF.create(model, serialization, scope);
-
-            ToolSelection toolSelection = new ToolSelection(
-                    serialization,
-                    scope,
-                    List.of(
-                            text.summarize,
-                            pdf.readPDFFromUrl,
-                            files.readFile,
-                            files.writeToTextFile
-                    ),
-                    Collections.emptyList()
-            );
-
-            String inputText = "Extract information from https://arxiv.org/pdf/2305.10601.pdf";
-            var result = toolSelection.applyInferredToolsBlocking(inputText);
-            System.out.println(result);
-        }
-    }
-}
-
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/expressions/WorkoutPlanProgram.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/expressions/WorkoutPlanProgram.kt
new file mode 100644
index 000000000..311a280d4
--- /dev/null
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/expressions/WorkoutPlanProgram.kt
@@ -0,0 +1,55 @@
+package com.xebia.functional.xef.auto.expressions
+
+import com.xebia.functional.xef.auto.CoreAIScope
+import com.xebia.functional.xef.auto.ai
+import com.xebia.functional.xef.auto.llm.openai.OpenAI
+import com.xebia.functional.xef.auto.llm.openai.getOrThrow
+import com.xebia.functional.xef.llm.ChatWithFunctions
+import com.xebia.functional.xef.prompt.expressions.Expression
+import com.xebia.functional.xef.prompt.expressions.ExpressionResult
+
+suspend fun workoutPlan(
+  scope: CoreAIScope,
+  model: ChatWithFunctions,
+  goal: String,
+  experienceLevel: String,
+  equipment: String,
+  timeAvailable: Int
+): ExpressionResult = Expression.run(scope = scope, model = model, block = {
+  system { "You are a personal fitness trainer" }
+  user {
+    """
+    |I want to achieve $goal.
+    |My experience level is $experienceLevel, and I have access to the following equipment: $equipment.
+    |I can dedicate $timeAvailable minutes per day.
+    |Can you create a workout plan for me?
+    """.trimMargin()
+  }
+  assistant {
+    """
+    |Sure! Based on your goal, experience level, equipment available, and time commitment, here's a customized workout plan:
+    |${prompt("workout_plan")}
+    """.trimMargin()
+  }
+})
+
+suspend fun main() {
+  val model = OpenAI.DEFAULT_SERIALIZATION
+  ai {
+    val plan = workoutPlan(
+      scope = this,
+      model = model,
+      goal = "building muscle",
+      experienceLevel = "intermediate",
+      equipment = "dumbbells, bench, resistance bands",
+      timeAvailable = 45
+    )
+    println("--------------------")
+    println("Workout Plan")
+    println("--------------------")
+    println("šŸ¤– replaced: ${plan.values.replacements.joinToString { it.key }}")
+    println("--------------------")
+    println(plan.result)
+    println("--------------------")
+  }.getOrThrow()
+}
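
[Annotation] Each `prompt("workout_plan")` call in the program above expands to a `{{workout_plan}}` placeholder in the assistant message; `Expression.run` then folds the model-generated `ReplacedValues` back into the template, roughly like this sketch (values illustrative):

    val rendered =
      template.replace(
        "{{workout_plan}}",
        values.replacements.firstOrNull { it.key == "workout_plan" }?.value ?: "{{workout_plan}}"
      )
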
diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml
index 2ebfb0e1c..c5cff0a41 100644
--- a/gradle/libs.versions.toml
+++ b/gradle/libs.versions.toml
@@ -2,7 +2,6 @@
 arrow = "1.2.0"
 arrowGradle = "0.12.0-rc.5"
 kotlin = "1.8.22"
-openai = "0.14.0"
 kotlinx-json = "1.5.1"
 ktor = "2.3.2"
 spotless = "6.20.0"
@@ -33,7 +32,7 @@ pdfbox = "2.0.29"
 mysql = "8.0.33"
 semverGradle = "0.5.0-rc.1"
 scala = "3.3.0"
-openai-client-version = "3.3.2"
+openai-client-version = "3.3.1"
 gpt4all-java = "1.1.5"
 ai-djl = "0.23.0"
 jackson = "2.15.2"
@@ -44,7 +43,6 @@ suspend-transform = "0.3.1"
 [libraries]
 arrow-core = { module = "io.arrow-kt:arrow-core", version.ref = "arrow" }
 arrow-fx-coroutines = { module = "io.arrow-kt:arrow-fx-coroutines", version.ref = "arrow" }
-open-ai = { module = "com.theokanning.openai-gpt3-java:service", version.ref = "openai" }
 kotlinx-serialization-json = { module = "org.jetbrains.kotlinx:kotlinx-serialization-json", version.ref = "kotlinx-json" }
 kotlinx-coroutines = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-core", version.ref="kotlinx-coroutines" }
 kotlinx-coroutines-reactive = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-reactive", version.ref="kotlinx-coroutines-reactive" }
diff --git a/reasoning/build.gradle.kts b/reasoning/build.gradle.kts
index 452dc534f..64db97aed 100644
--- a/reasoning/build.gradle.kts
+++ b/reasoning/build.gradle.kts
@@ -15,7 +15,6 @@ plugins {
   alias(libs.plugins.dokka)
   alias(libs.plugins.arrow.gradle.publish)
   alias(libs.plugins.semver.gradle)
-  alias(libs.plugins.suspend.transform.plugin)
   //id("com.xebia.asfuture").version("0.0.1")
 }

@@ -174,12 +173,6 @@ tasks {
   }
 }

-suspendTransform {
-  enabled = true // default: true
-  includeRuntime = true // default: true
-  useJvmDefault()
-}
-
 tasks.withType<AbstractPublishToMaven> {
   dependsOn(tasks.withType<Sign>())
 }
diff --git a/reasoning/src/commonMain/kotlin/com/xebia/functional/xef/reasoning/tools/ToolSelection.kt b/reasoning/src/commonMain/kotlin/com/xebia/functional/xef/reasoning/tools/ToolSelection.kt
index 3d006c190..31f6b7674 100644
--- a/reasoning/src/commonMain/kotlin/com/xebia/functional/xef/reasoning/tools/ToolSelection.kt
+++ b/reasoning/src/commonMain/kotlin/com/xebia/functional/xef/reasoning/tools/ToolSelection.kt
@@ -4,8 +4,6 @@ import com.xebia.functional.xef.auto.CoreAIScope
 import com.xebia.functional.xef.llm.ChatWithFunctions
 import com.xebia.functional.xef.prompt.experts.ExpertSystem
 import io.github.oshai.kotlinlogging.KotlinLogging
-import love.forte.plugin.suspendtrans.annotation.JvmAsync
-import love.forte.plugin.suspendtrans.annotation.JvmBlocking

 class ToolSelection(
   private val model: ChatWithFunctions,
@@ -27,8 +25,6 @@ class ToolSelection(
     }
   }

-  @JvmBlocking
-  @JvmAsync
   suspend fun applyInferredTools(task: String): ToolsExecutionTrace {
     logger.info { "šŸ” Applying inferred tools for task: $task" }
     val plan = createExecutionPlan(task)

From 0e904dab83139ad24f1fee09adf68457cf8acd0c Mon Sep 17 00:00:00 2001
From: raulraja
Date: Sun, 6 Aug 2023 23:16:30 +0200
Subject: [PATCH 2/9] Prompt adjustments

---
 .../com/xebia/functional/xef/prompt/expressions/Expression.kt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/Expression.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/Expression.kt
index e6263a927..05e63c782 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/Expression.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/Expression.kt
@@ -60,7 +60,7 @@ class Expression(
                   .trimMargin(),
               instructions =
                 listOf(
-                  "Create a `ReplaceKeys` object `replacements` property of type `Map<String, String>` where the keys are the variable names and the values are the values to replace them with.",
+                  "Create a `ReplacedValues` object with the `replacements` where the keys are the variable names and the values are the values to replace them with.",
                 )
             )
             .message,

From 3bdf66165bbb09a7b85b8a5552442f0a7a2f21fd Mon Sep 17 00:00:00 2001
From: raulraja
Date: Mon, 7 Aug 2023 14:06:32 +0200
Subject: [PATCH 3/9] Prompt adjustments and better structure for building
 final prompt based on messages

---
 .../com/xebia/functional/xef/llm/Chat.kt      | 43 ++++++++-----------
 .../functional/xef/llm/models/chat/Message.kt | 13 +++++-
 .../xef/prompt/expressions/Expression.kt      |  8 ++--
 3 files changed, 36 insertions(+), 28 deletions(-)

diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt
index 5347fcaac..c6611f85c 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt
@@ -46,7 +46,7 @@ interface Chat : LLM {
   ): Flow<String> = flow {
     val memories: List<Memory> = memories(conversationId, context, promptConfiguration)

-    val promptWithContext: String =
+    val promptWithContext: List<Message> =
       createPromptWithContextAwareOfTokens(
         memories = memories,
         ctxInfo = context.similaritySearch(prompt.message, promptConfiguration.docsInContext),
@@ -55,7 +55,7 @@ interface Chat : LLM {
         minResponseTokens = promptConfiguration.minResponseTokens
       )

-    val messages: List<Message> = messages(memories, promptWithContext)
+    val messages: List<Message> = messagesFromMemory(memories) + promptWithContext

     fun checkTotalLeftChatTokens(): Int {
      val maxContextLength: Int = modelType.maxContextLength
@@ -144,12 +144,16 @@ interface Chat : LLM {
     functions: List<CFunction> = emptyList(),
     promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS
   ): List<String> {
+
+    val memories: List<Memory> = memories(conversationId, context, promptConfiguration)
+    val allMessages = messagesFromMemory(memories) + messages
+
     fun checkTotalLeftChatTokens(): Int {
       val maxContextLength: Int = modelType.maxContextLength
-      val messagesTokens: Int = tokensFromMessages(messages)
+      val messagesTokens: Int = tokensFromMessages(allMessages)
       val totalLeftTokens: Int = maxContextLength - messagesTokens
       if (totalLeftTokens < 0) {
-        throw AIError.MessagesExceedMaxTokenLength(messages, messagesTokens, maxContextLength)
+        throw AIError.MessagesExceedMaxTokenLength(allMessages, messagesTokens, maxContextLength)
       }
       return totalLeftTokens
     }
@@ -214,7 +218,7 @@ interface Chat : LLM {

     val memories: List<Memory> = memories(conversationId, context, promptConfiguration)

-    val promptWithContext: String =
+    val promptWithContext: List<Message> =
       createPromptWithContextAwareOfTokens(
         memories = memories,
         ctxInfo = context.similaritySearch(prompt.message, promptConfiguration.docsInContext),
@@ -223,9 +227,7 @@ interface Chat : LLM {
         minResponseTokens = promptConfiguration.minResponseTokens
       )

-    val messages: List<Message> = messages(memories, promptWithContext)
-
-    return promptMessages(messages, context, conversationId, functions, promptConfiguration)
+    return promptMessages(promptWithContext, context, conversationId, functions, promptConfiguration)
   }

   private suspend fun List<ChoiceWithFunctions>.addChoiceWithFunctionsToMemory(
@@ -285,8 +287,8 @@ interface Chat : LLM {
     }
   }

-  private fun messages(memories: List<Memory>, promptWithContext: String): List<Message> =
-    memories.map { it.content } + listOf(Message(Role.USER, promptWithContext, Role.USER.name))
+  private fun messagesFromMemory(memories: List<Memory>): List<Message> =
+    memories.map { it.content }

   private suspend fun memories(
     conversationId: ConversationId?,
@@ -299,13 +301,13 @@ interface Chat : LLM {
       emptyList()
     }

-  private fun createPromptWithContextAwareOfTokens(
+  private suspend fun createPromptWithContextAwareOfTokens(
     memories: List<Memory>,
     ctxInfo: List<String>,
     modelType: ModelType,
     prompt: String,
     minResponseTokens: Int,
-  ): String {
+  ): List<Message> {
     val maxContextLength: Int = modelType.maxContextLength
     val promptTokens: Int = modelType.encoding.countTokens(prompt)
     val memoryTokens = tokensFromMessages(memories.map { it.content })
@@ -322,17 +324,10 @@ interface Chat : LLM {
       // alternatively we could summarize the context, but that's not implemented yet
       val ctxTruncated: String = modelType.encoding.truncateText(ctx, remainingTokens)

-      """|```Context
-         |${ctxTruncated}
-         |```
-         |The context is related to the question try to answer the `goal` as best as you can
-         |or provide information about the found content
-         |```goal
-         |${prompt}
-         |```
-         |ANSWER:
-         |"""
-        .trimMargin()
-    } else prompt
+      listOf(
+        Message.assistantMessage { "Context: $ctxTruncated" },
+        Message.userMessage { prompt }
+      )
+    } else listOf(Message.userMessage { prompt })
   }
 }
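
[Annotation] After this change the truncated similarity-search context travels as its own assistant message instead of being inlined into a single user string; the resulting prompt is roughly (contents illustrative):

    listOf(
      Message.assistantMessage { "Context: <truncated documents>" },
      Message.userMessage { "<original prompt>" },
    )
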
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/chat/Message.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/chat/Message.kt
index 9f6ef6ed6..2a78143bf 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/chat/Message.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/chat/Message.kt
@@ -1,3 +1,14 @@
 package com.xebia.functional.xef.llm.models.chat

-data class Message(val role: Role, val content: String, val name: String)
+data class Message(val role: Role, val content: String, val name: String) {
+  companion object {
+    suspend fun systemMessage(message: suspend () -> String) =
+      Message(role = Role.SYSTEM, content = message(), name = Role.SYSTEM.name)
+
+    suspend fun userMessage(message: suspend () -> String) =
+      Message(role = Role.USER, content = message(), name = Role.USER.name)
+
+    suspend fun assistantMessage(message: suspend () -> String) =
+      Message(role = Role.ASSISTANT, content = message(), name = Role.ASSISTANT.name)
+  }
+}
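
[Annotation] A quick sketch of the new builders (they are suspend because the content lambdas are suspend):

    suspend fun greeting(): List<Message> =
      listOf(
        Message.systemMessage { "You are terse" },
        Message.userMessage { "Say hi" },
      )
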
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/Expression.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/Expression.kt
index 05e63c782..fce7eec71 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/Expression.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/Expression.kt
@@ -22,15 +22,15 @@ class Expression(
   private val generationKeys: MutableList<String> = mutableListOf()

   suspend fun system(message: suspend () -> String) {
-    messages.add(Message(role = Role.SYSTEM, content = message(), name = Role.SYSTEM.name))
+    messages.add(Message.systemMessage(message))
   }

   suspend fun user(message: suspend () -> String) {
-    messages.add(Message(role = Role.USER, content = message(), name = Role.USER.name))
+    messages.add(Message.userMessage(message))
   }

   suspend fun assistant(message: suspend () -> String) {
-    messages.add(Message(role = Role.ASSISTANT, content = message(), name = Role.ASSISTANT.name))
+    messages.add(Message.assistantMessage(message))
   }

   fun prompt(key: String): String {
@@ -74,6 +74,7 @@ class Expression(
         conversationId = scope.conversationId,
         promptConfiguration = promptConfiguration
       )
+    logger.info { "replaced: ${values.replacements.joinToString { it.key }}" }
     val replacedTemplate =
       messages.fold("") { acc, message ->
         val replacedMessage =
@@ -94,5 +95,6 @@ class Expression(
       model: ChatWithFunctions,
       block: suspend Expression.() -> Unit
     ): ExpressionResult = Expression(scope, model, block).run()
+
   }
 }

From f674f871793a3d3dbb54f2fce4ca3ad9ce01ac22 Mon Sep 17 00:00:00 2001
From: raulraja
Date: Tue, 8 Aug 2023 11:40:34 +0200
Subject: [PATCH 4/9] Prompt adjustments and better structure for building
 final prompt based on messages

---
 .../com/xebia/functional/xef/llm/Chat.kt      |  13 +-
 .../xef/prompt/expressions/Expression.kt      |   1 -
 .../xef/prompt/lang/PromptLanguage.kt         | 138 +++++++++
 .../auto/expressions/WorkoutPlanProgram.kt    |  59 ++--
 .../xef/auto/reasoning/ReActExample.kt        |  10 +-
 .../xef/reasoning/internals/CallModel.kt      |  17 +-
 .../functional/xef/reasoning/tools/LLMTool.kt |  27 +-
 .../xef/reasoning/tools/ReActAgent.kt         | 266 ++++++++----------
 .../filesystem/CreatePythonScript.kt          |  83 ------
 .../functional/xef/reasoning/search/Search.kt |  40 ++-
 10 files changed, 366 insertions(+), 288 deletions(-)
 create mode 100644 core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/lang/PromptLanguage.kt
 delete mode 100644 reasoning/src/jvmMain/kotlin/com/xebia/functional/xef/reasoning/filesystem/CreatePythonScript.kt

diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt
index c6611f85c..8bd613f13 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt
@@ -227,7 +227,13 @@ interface Chat : LLM {
         minResponseTokens = promptConfiguration.minResponseTokens
       )

-    return promptMessages(promptWithContext, context, conversationId, functions, promptConfiguration)
+    return promptMessages(
+      promptWithContext,
+      context,
+      conversationId,
+      functions,
+      promptConfiguration
+    )
   }

   private suspend fun List<ChoiceWithFunctions>.addChoiceWithFunctionsToMemory(
@@ -324,10 +330,7 @@ interface Chat : LLM {
       // alternatively we could summarize the context, but that's not implemented yet
       val ctxTruncated: String = modelType.encoding.truncateText(ctx, remainingTokens)

-      listOf(
-        Message.assistantMessage { "Context: $ctxTruncated" },
-        Message.userMessage { prompt }
-      )
+      listOf(Message.assistantMessage { "Context: $ctxTruncated" }, Message.userMessage { prompt })
     } else listOf(Message.userMessage { prompt })
   }
 }
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/Expression.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/Expression.kt
index fce7eec71..c0c2eb25b 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/Expression.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/Expression.kt
@@ -95,6 +95,5 @@ class Expression(
       model: ChatWithFunctions,
       block: suspend Expression.() -> Unit
     ): ExpressionResult = Expression(scope, model, block).run()
-
   }
 }
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/lang/PromptLanguage.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/lang/PromptLanguage.kt
new file mode 100644
index 000000000..faf1ec903
--- /dev/null
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/lang/PromptLanguage.kt
@@ -0,0 +1,138 @@
+package com.xebia.functional.xef.prompt.lang
+
+import kotlinx.serialization.Serializable
+
+/*
+# Codebot
+
+Roleplay as a world-class senior software engineer pair programmer.
+
+DevProcess {
+  State {
+    Target Language: JavaScript
+  }
+  WriteTestsFIRST {
+    Use Riteway ({ given, should, actual, expected }) {
+      Define given, should, actual, and expected inline in the `assert` call.
+      "Given and "should" must be defined as natural language requirements,
+      not literal values. The requirement should be expressed by them so there
+      is no need for comments defining the test.
+    }
+    Tests must be {
+      Readable
+      Isolated from each other in separate scopes. Test units of code in
+      isolation from the rest of the program.
+      Thorough: Test all likely edge cases.
+      Explicit: Tests should have strong locality. Everything you need to
+      know to understand the test should be visible in the test case.
+    }
+    Each test must answer {
+      What is the unit under test?
+      What is the natural language requirement being tested?
+      What is the actual output?
+      What is the expected output?
+      On failure, identify and fix the bug.
+    }
+  }
+  Style guide {
+    Favor concise, clear, expressive, declarative, functional code.
+    Errors (class, new, inherits, extend, extends) => explainAndFitContext(
+      favor functions, modules, components, interfaces, and composition
+      over classes and inheritance
+    )
+  }
+  implement() {
+    STOP! Write tests FIRST.
+    Implement the code such that unit tests pass. Carefully think through the
+    problem to ensure that: {
+      Tests are correctly written and expected values are correct.
+      Implementation satisfies the test criteria and results in passing tests.
+    }
+  }
+  /implement - Implement code in the target language from a SudoLang function
+  or natural language description
+  /l | lang - Set the target language
+  /h | help
+}
+
+
+When asked to implement a function, please carefully follow the
+instructions above. šŸ™
+
+welcome()
+ */
+
+@Serializable data class Node(val name: String, val children: List<Node>) {}
+
+class PromptLanguage(private val children: MutableList<Node> = mutableListOf()) {
+
+  operator fun String.invoke(): Node {
+    val node = Node(this, emptyList())
+    children.add(node)
+    return node
+  }
+
+  operator fun String.invoke(f: PromptLanguage.() -> Unit): Node {
+    val node = Node(this, PromptLanguage().apply(f).children)
+    children.add(node)
+    return node
+  }
+
+  operator fun String.div(other: String): Node {
+    val node = Node(this, listOf(Node(other, emptyList())))
+    children.add(node)
+    return node
+  }
+
+  companion object {
+    operator fun invoke(f: PromptLanguage.() -> Node): Node = f(PromptLanguage())
+  }
+}
+
+fun infer(): Nothing = TODO()
+
+fun summarize(text: String): String = infer()
+
+val devProcess = PromptLanguage {
+  "DevProcess" {
+    "State" { "Target Language" { "JavaScript"() } }
+    "WriteTestsFIRST" {
+      "Use Riteway({ given, should, actual, expected })" {
+        "Define given, should, actual, and expected inline in the `assert` call."()
+        "Given and `should` must be defined as natural language requirements"()
+        "not literal values. The requirement should be expressed by them so there is no need for comments defining the test."()
+      }
+      "Tests must be" {
+        "Readable"()
+        "Isolated from each other in separate scopes. Test units of code in isolation from the rest of the program."()
+        "Thorough: Test all likely edge cases."()
+        "Explicit: Tests should have strong locality. Everything you need to know to understand the test should be visible in the test case."()
+      }
+      "Each test must answer" {
+        "What is the unit under test?"()
+        "What is the natural language requirement being tested?"()
+        "What is the actual output?"()
+        "What is the expected output?"()
+        "On failure, identify and fix the bug."()
+      }
+    }
+    "Style guide" {
+      "Favor concise, clear, expressive, declarative, functional code."()
+      "Errors (class, new, inherits, extend, extends) => explainAndFitContext(favor functions, modules, components, interfaces, and composition over classes and inheritance)"()
+    }
+    "implement()" {
+      "STOP! Write tests FIRST."()
+      "Implement the code such that unit tests pass. Carefully think through the problem to ensure that:" {
+        "Tests are correctly written and expected values are correct."()
+        "Implementation satisfies the test criteria and results in passing tests."()
+      }
+    }
+    "/implement - Implement code in the target language from a SudoLang function or natural language description"()
+    "/l | lang - Set the target language"()
+    "/h | help"()
+    "When asked to implement a function, please carefully follow the instructions above. šŸ™"()
+
+    "welcome()"()
+  }
+}
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/expressions/WorkoutPlanProgram.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/expressions/WorkoutPlanProgram.kt
index 311a280d4..a70ddd228 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/expressions/WorkoutPlanProgram.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/expressions/WorkoutPlanProgram.kt
@@ -7,45 +7,66 @@ import com.xebia.functional.xef.auto.llm.openai.getOrThrow
 import com.xebia.functional.xef.llm.ChatWithFunctions
 import com.xebia.functional.xef.prompt.expressions.Expression
 import com.xebia.functional.xef.prompt.expressions.ExpressionResult
+import com.xebia.functional.xef.reasoning.search.Search
+import com.xebia.functional.xef.reasoning.tools.LLMTool
+import com.xebia.functional.xef.reasoning.tools.Tool

-suspend fun workoutPlan(
+suspend fun taskSplitter(
   scope: CoreAIScope,
   model: ChatWithFunctions,
-  goal: String,
-  experienceLevel: String,
-  equipment: String,
-  timeAvailable: Int
+  prompt: String,
+  tools: List<Tool>
 ): ExpressionResult = Expression.run(scope = scope, model = model, block = {
-  system { "You are a personal fitness trainer" }
+  system { "You are a professional task planner" }
   user {
     """
-    |I want to achieve $goal.
-    |My experience level is $experienceLevel, and I have access to the following equipment: $equipment.
-    |I can dedicate $timeAvailable minutes per day.
-    |Can you create a workout plan for me?
+    |I want to achieve:
     """.trimMargin()
   }
+  user {
+    prompt
+  }
+  assistant {
+    "I have access to all these tool"
+  }
+  tools.forEach {
+    assistant {
+      "${it.name}: ${it.description}"
+    }
+  }
   assistant {
     """
-    |Sure! Based on your goal, experience level, equipment available, and time commitment, here's a customized workout plan:
-    |${prompt("workout_plan")}
+    |I will break down your task into 3 tasks to make progress and help you accomplish this goal
+    |using the tools that I have available.
+    |1: ${prompt("task1")}
+    |2: ${prompt("task2")}
+    |3: ${prompt("task3")}
     """.trimMargin()
   }
 })

+
 suspend fun main() {
-  val model = OpenAI.DEFAULT_SERIALIZATION
+
   ai {
+    val model = OpenAI.DEFAULT_SERIALIZATION
+    val math = LLMTool.create(
+      name = "Calculator",
+      description = "Perform math operations and calculations processing them with an LLM model. The tool input is a simple string containing the operation to solve expressed in numbers and math symbols.",
+      model = model,
+      scope = this
+    )
+    val search = Search(model = model, scope = this)
-    val plan = workoutPlan(
+    val plan = taskSplitter(
       scope = this,
       model = model,
-      goal = "building muscle",
-      experienceLevel = "intermediate",
-      equipment = "dumbbells, bench, resistance bands",
-      timeAvailable = 45
+      prompt = "Find and multiply the number of Leonardo di Caprio's girlfriends by the number of Metallica albums",
+      tools = listOf(
+        search, math
+      )
     )
     println("--------------------")
-    println("Workout Plan")
+    println("Plan")
     println("--------------------")
     println("šŸ¤– replaced: ${plan.values.replacements.joinToString { it.key }}")
     println("--------------------")
     println(plan.result)
     println("--------------------")
   }.getOrThrow()
 }
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/reasoning/ReActExample.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/reasoning/ReActExample.kt
index f25bfce6c..ee4ed6a19 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/reasoning/ReActExample.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/reasoning/ReActExample.kt
@@ -3,6 +3,7 @@ package com.xebia.functional.xef.auto.reasoning
 import com.xebia.functional.xef.auto.ai
 import com.xebia.functional.xef.auto.llm.openai.OpenAI
 import com.xebia.functional.xef.auto.llm.openai.getOrThrow
+import com.xebia.functional.xef.llm.models.chat.Message
 import com.xebia.functional.xef.reasoning.search.Search
 import com.xebia.functional.xef.reasoning.tools.LLMTool
 import com.xebia.functional.xef.reasoning.tools.ReActAgent
@@ -23,12 +24,17 @@ suspend fun main() {
       model = serialization,
       scope = this,
       tools = listOf(
-        math,
         search,
+        math,
       ),
     )
     val result =
-      reActAgent.run("Multiply the number of Leonardo di Caprio's girlfriends by the number of Metallica albums")
+      reActAgent.run(
+        listOf(
+          Message.userMessage {
+            "Find and multiply the number of Leonardo di Caprio's girlfriends by the number of Metallica albums"
+          })
+      )
     println(result)
   }.getOrThrow()
 }
diff --git a/reasoning/src/commonMain/kotlin/com/xebia/functional/xef/reasoning/internals/CallModel.kt b/reasoning/src/commonMain/kotlin/com/xebia/functional/xef/reasoning/internals/CallModel.kt
index dfad04a57..d515e433b 100644
--- a/reasoning/src/commonMain/kotlin/com/xebia/functional/xef/reasoning/internals/CallModel.kt
+++ b/reasoning/src/commonMain/kotlin/com/xebia/functional/xef/reasoning/internals/CallModel.kt
@@ -2,16 +2,19 @@ package com.xebia.functional.xef.reasoning.internals

 import com.xebia.functional.xef.auto.CoreAIScope
 import com.xebia.functional.xef.llm.Chat
-import com.xebia.functional.xef.prompt.Prompt
+import com.xebia.functional.xef.llm.models.chat.Message

 internal suspend fun callModel(
   model: Chat,
   scope: CoreAIScope,
-  prompt: Prompt,
+  prompt: List<Message>,
 ): String {
-  return model.promptMessage(
-    question = prompt.message,
-    context = scope.context,
-    conversationId = scope.conversationId,
-  )
+  return model
+    .promptMessages(
+      messages = prompt,
+      context = scope.context,
+      conversationId = scope.conversationId,
+    )
+    .firstOrNull()
+    ?: error("No results found")
 }
diff --git a/reasoning/src/commonMain/kotlin/com/xebia/functional/xef/reasoning/tools/LLMTool.kt b/reasoning/src/commonMain/kotlin/com/xebia/functional/xef/reasoning/tools/LLMTool.kt
index ce8d5ac3d..044bd7daf 100644
--- a/reasoning/src/commonMain/kotlin/com/xebia/functional/xef/reasoning/tools/LLMTool.kt
+++ b/reasoning/src/commonMain/kotlin/com/xebia/functional/xef/reasoning/tools/LLMTool.kt
@@ -2,7 +2,7 @@ package com.xebia.functional.xef.reasoning.tools

 import com.xebia.functional.xef.auto.CoreAIScope
 import com.xebia.functional.xef.llm.Chat
-import com.xebia.functional.xef.prompt.experts.ExpertSystem
+import com.xebia.functional.xef.llm.models.chat.Message
 import com.xebia.functional.xef.reasoning.internals.callModel
 import io.github.oshai.kotlinlogging.KotlinLogging
 import kotlin.jvm.JvmOverloads
@@ -18,25 +18,22 @@ abstract class LLMTool(
   private val logger = KotlinLogging.logger {}

   override suspend operator fun invoke(input: String): String {
-    logger.info { "šŸ”§ Running $name - $description" }
+    logger.info { "šŸ”§ $name[$input]" }
     return callModel(
       model,
       scope,
       prompt =
-        ExpertSystem(
-          system = "You are an expert in `$name` ($description)",
-          query =
-            """|
-                |Given the following input:
-                |```input
-                |${input}
-                |```
-                |Produce an output that satisfies the tool `$name` ($description) operation.
-            """
-              .trimMargin(),
-          instructions = instructions
-        )
+        listOf(
+          Message.systemMessage { "You are an expert in executing tool:" },
+          Message.systemMessage { "Tool: $name" },
+          Message.systemMessage { "Description: $description" },
+        ) +
+          instructions.map { Message.systemMessage { it } } +
+          listOf(
+            Message.userMessage { "input: $input" },
+            Message.assistantMessage { "output:" },
+          )
     )
   }
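
[Annotation] With this rework a tool invocation is prompted as a flat message list rather than an ExpertSystem template; for a hypothetical Calculator tool the prompt is roughly:

    listOf(
      Message.systemMessage { "You are an expert in executing tool:" },
      Message.systemMessage { "Tool: Calculator" },
      Message.systemMessage { "Description: Perform math operations" },
      Message.userMessage { "input: 2 + 2" },
      Message.assistantMessage { "output:" },
    )
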
diff --git a/reasoning/src/commonMain/kotlin/com/xebia/functional/xef/reasoning/tools/ReActAgent.kt b/reasoning/src/commonMain/kotlin/com/xebia/functional/xef/reasoning/tools/ReActAgent.kt
index cfd41d4da..ec15f9c05 100644
--- a/reasoning/src/commonMain/kotlin/com/xebia/functional/xef/reasoning/tools/ReActAgent.kt
+++ b/reasoning/src/commonMain/kotlin/com/xebia/functional/xef/reasoning/tools/ReActAgent.kt
@@ -4,7 +4,10 @@ import com.xebia.functional.xef.auto.CoreAIScope
 import com.xebia.functional.xef.auto.Description
 import com.xebia.functional.xef.auto.PromptConfiguration
 import com.xebia.functional.xef.llm.ChatWithFunctions
-import com.xebia.functional.xef.prompt.experts.ExpertSystem
+import com.xebia.functional.xef.llm.models.chat.Message
+import com.xebia.functional.xef.llm.models.chat.Message.Companion.assistantMessage
+import com.xebia.functional.xef.llm.models.chat.Message.Companion.systemMessage
+import com.xebia.functional.xef.llm.models.chat.Message.Companion.userMessage
 import io.github.oshai.kotlinlogging.KotlinLogging
 import kotlinx.serialization.Serializable

@@ -15,10 +18,12 @@ class ReActAgent(
   private val maxIterations: Int = 10,
 ) {

+  val conversationId = scope.conversationId
+
   private val logger = KotlinLogging.logger {}

   private suspend fun createExecutionPlan(
-    input: String,
+    input: List<Message>,
     chain: List<ThoughtObservation>,
     promptConfiguration: PromptConfiguration
   ): AgentPlan {
@@ -32,172 +37,130 @@ class ReActAgent(
     }
   }

-  private suspend fun agentFinish(input: String, chain: List<ThoughtObservation>): AgentFinish =
+  private suspend fun agentFinish(
+    input: List<Message>,
+    chain: List<ThoughtObservation>
+  ): AgentFinish =
     model.prompt(
       context = scope.context,
-      conversationId = scope.conversationId,
+      conversationId = conversationId,
       serializer = AgentFinish.serializer(),
-      prompt =
-        ExpertSystem(
-          system = "You are an expert in providing answers",
-          query =
-            """|
-                |Given the following input:
-                |```input
-                |${input}
-                |```
-                |And the following chain of thoughts and observations:
-                |```chain
-                |${
-                  chain.map { (k, v) ->
-                    """
-                    |Thought: $k
-                    |Observation: $v
-                    """.trimMargin()
-                  }.joinToString("\n")
-                }
-                |```
-            """
-              .trimMargin(),
-          instructions =
-            listOf(
-              "Provide the final answer to the `input` in a sentence or paragraph",
-            )
-        )
+      messages =
+        listOf(
+          systemMessage { "You are an expert in providing answers" },
+        ) +
+          chain.chainToMessages() +
+          listOf(
+            userMessage { "Provide the final answer to the `input` in a sentence or paragraph" },
+            userMessage { "input: $input" },
+            assistantMessage {
+              "I should create a AgentFinish object with the final answer based on the thoughts and observations"
+            }
+          )
     )

-  private suspend fun agentAction(input: String, chain: List<ThoughtObservation>): AgentAction =
+  private suspend fun agentAction(
+    input: List<Message>,
+    chain: List<ThoughtObservation>
+  ): AgentAction =
     model.prompt(
       context = scope.context,
-      conversationId = scope.conversationId,
+      conversationId = conversationId,
       serializer = AgentAction.serializer(),
-      prompt =
-        ExpertSystem(
-          system =
-            "You are an expert in tool selection. You are given a `input` and a `chain` of thoughts and observations.",
-          query =
-            """|
-                |Given the following input:
-                |```input
-                |${input}
-                |```
-                |And the following tools:
-                |```tools
-                |${
-                  (tools.map {
-                    ToolMetadata(
-                      it.name,
-                      it.description
-                    )
-                  }).joinToString("\n") { "${it.name}: ${it.description}" }
-                }
-                |```
-                |And the following chain of thoughts and observations:
-                |```chain
-                |${
-                  chain.map { (k, v) ->
-                    """
-                    |Thought: $k
-                    |Observation: $v
-                    """.trimMargin()
-                  }.joinToString("\n")
-                }
-                |```
-            """
-              .trimMargin(),
-          instructions =
-            listOf(
-              "The `tool` and `toolInput` MUST be provided for the next step",
-            )
-        )
+      messages =
+        listOf(
+          systemMessage {
+            "You are an expert in tool selection. You are given a `input` and a `chain` of thoughts and observations."
+          },
+          userMessage { "input:" },
+        ) +
+          input +
+          listOf(
+            assistantMessage { "chain:" },
+          ) +
+          chain.chainToMessages() +
+          listOf(
+            assistantMessage { "I can only use this tools:" },
+          ) +
+          tools.toolsToMessages() +
+          listOf(
+            assistantMessage {
+              "I will not repeat the `toolInput` if the same one produced no satisfactory results in the observations"
+            },
+            userMessage { "Provide the next tool to use and the `toolInput` for the tool" },
          )
     )

+  private suspend fun List<Tool>.toolsToMessages(): List<Message> = flatMap {
+    listOf(
+      assistantMessage { "${it.name}: ${it.description}" },
+    )
+  }
+
+  private suspend fun List<ThoughtObservation>.chainToMessages(): List<Message> = flatMap {
+    listOf(
+      assistantMessage { "Thought: ${it.thought}" },
+      assistantMessage { "Observation: ${it.observation}" },
+    )
+  }
+
   private suspend fun agentChoice(
     promptConfiguration: PromptConfiguration,
-    input: String,
+    input: List<Message>,
     chain: List<ThoughtObservation>
   ): AgentChoice =
     model.prompt(
       context = scope.context,
-      conversationId = scope.conversationId,
+      conversationId = conversationId,
       serializer = AgentChoice.serializer(),
       promptConfiguration = promptConfiguration,
-      prompt =
-        ExpertSystem(
-          system =
-            "You will reflect on the `input` and `chain` and decide whether you are done or not",
-          query =
-            """|
-                |Given the following input:
-                |```input
-                |${input}
-                |```
-                |And the following chain of thoughts and observations:
-                |```chain
-                |${
-                  chain.joinToString("\n") { c ->
-                    """
-                    |Thought: ${c.thought}
-                    |Observation: ${c.observation}
-                    """.trimMargin()
-                  }
-                }
-                |```
-            """
-              .trimMargin(),
-          instructions =
-            listOf(
-              "Choose `CONTINUE` if you are not 100% certain that all elements in the original `input` question are answered completely by the info found in the `chain`",
-              "Choose `CONTINUE` if the `chain` needs more information to be able to completely answer all elements in the `input` question",
-              "Choose `FINISH` if you are 100% certain that all elements in the `input` question are answered by the info found in the `chain` and are not a list of steps to achieve the goal.",
-            )
-        )
+      messages =
+        input +
+          listOf(
+            assistantMessage { "chain:" },
+          ) +
+          chain.chainToMessages() +
+          listOf(
+            assistantMessage {
+              "`CONTINUE` if the `input` has not been answered by the observations in the `chain`"
+            },
+            assistantMessage {
+              "`FINISH` if the `input` has been answered by the observations in the `chain`"
+            },
+          )
     )

   private suspend fun createInitialThought(
-    input: String,
+    input: List<Message>,
     promptConfiguration: PromptConfiguration
   ): Thought {
-    logger.info { "šŸ¤” $input" }
     return model.prompt(
       context = scope.context,
-      conversationId = scope.conversationId,
+      conversationId = conversationId,
       serializer = Thought.serializer(),
       promptConfiguration = promptConfiguration,
-      prompt =
-        ExpertSystem(
-          system =
-            "You are an expert in providing more descriptive inputs for tasks that a user wants to execute",
-          query =
-            """|
-                |Given the following input:
-                |```input
-                |${input}
-                |```
-                |And the following tools:
-                |```tools
-                |${
-                  (tools.map {
-                    ToolMetadata(
-                      it.name,
-                      it.description
-                    )
-                  }).joinToString("\n") { "${it.name}: ${it.description}" }
-                }
-                |```
-            """
              .trimMargin(),
-          instructions =
-            listOf(
-              "Create a prompt that serves as 'thought' of what to do next in order to accurately describe what the user wants to do",
-              "Your `RESPONSE` MUST be a `Thought` object, where the `thought` determines what the user should do next"
-            )
-        )
+      messages =
+        listOf(
+          systemMessage { "You are an expert in providing next steps to solve a problem" },
+          systemMessage { "You are given a `input` provided by the user" },
+          userMessage { "input:" },
+        ) +
+          input +
+          listOf(
+            assistantMessage { "I have access to tools:" },
+          ) +
+          tools.toolsToMessages() +
+          listOf(
+            assistantMessage {
+              "I should create a Thought object with the next thought based on the `input`"
+            },
+            userMessage { "Provide the next thought based on the `input`" },
+          )
     )
   }

   private tailrec suspend fun runRec(
-    input: String,
+    input: List<Message>,
     chain: List<ThoughtObservation>,
     currentIteration: Int,
     promptConfiguration: PromptConfiguration
@@ -214,15 +177,30 @@ class ReActAgent(
       is AgentAction -> {
         logger.info { "šŸ¤” ${plan.thought}" }
         logger.info { "šŸ›  ${plan.tool}[${plan.toolInput}]" }
-        val observation: String? = tools.find { it.name == plan.tool }?.invoke(plan.toolInput)
+        val observation: String? =
+          tools.find { it.name.equals(plan.tool, ignoreCase = true) }?.invoke(plan.toolInput)
         if (observation == null) {
           logger.info { "šŸ¤·ā€ Could not find ${plan.tool}" }
-          runRec(input, chain, currentIteration + 1, promptConfiguration)
+          runRec(
+            input,
+            chain +
+              ThoughtObservation(
+                plan.thought,
+                "Result of running ${plan.tool}[${plan.toolInput}]: " +
+                  "šŸ¤·ā€ Could not find ${plan.tool}, will not try this tool again"
+              ),
+            currentIteration + 1,
+            promptConfiguration
+          )
         } else {
           logger.info { "šŸ‘€ $observation" }
           runRec(
             input,
-            chain + ThoughtObservation(plan.thought, observation),
+            chain +
+              ThoughtObservation(
+                plan.thought,
+                "Result of running ${plan.tool}[${plan.toolInput}]: " + observation
+              ),
             currentIteration + 1,
             promptConfiguration
           )
         }
       }
     }
   }

   suspend fun run(
-    input: String,
+    input: List<Message>,
     promptConfiguration: PromptConfiguration = PromptConfiguration { temperature(0.0) }
   ): String {
     val thought = createInitialThought(input, promptConfiguration)
-    return runRec(input, listOf(ThoughtObservation(input, thought.thought)), 0, promptConfiguration)
+    logger.info { "šŸ¤” ${thought.thought}" }
+    return runRec(
+      input,
+      listOf(ThoughtObservation("I should get started", thought.thought)),
+      0,
+      promptConfiguration
+    )
   }
 }
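
[Annotation] Note that failed tool lookups are now recorded in the chain so the agent will not retry a missing tool; the appended observation looks roughly like this (values illustrative):

    chain +
      ThoughtObservation(
        thought = plan.thought,
        observation = "Result of running Calculator[2+2]: šŸ¤·ā€ Could not find Calculator, will not try this tool again"
      )
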
diff --git a/reasoning/src/jvmMain/kotlin/com/xebia/functional/xef/reasoning/filesystem/CreatePythonScript.kt b/reasoning/src/jvmMain/kotlin/com/xebia/functional/xef/reasoning/filesystem/CreatePythonScript.kt
deleted file mode 100644
index 6d778abd5..000000000
--- a/reasoning/src/jvmMain/kotlin/com/xebia/functional/xef/reasoning/filesystem/CreatePythonScript.kt
+++ /dev/null
@@ -1,83 +0,0 @@
-package com.xebia.functional.xef.reasoning.filesystem
-
-import com.xebia.functional.xef.auto.CoreAIScope
-import com.xebia.functional.xef.io.CommandExecutor
-import com.xebia.functional.xef.io.DEFAULT
-import com.xebia.functional.xef.io.ExecuteCommandOptions
-import com.xebia.functional.xef.llm.Chat
-import com.xebia.functional.xef.prompt.experts.ExpertSystem
-import com.xebia.functional.xef.reasoning.internals.callModel
-import com.xebia.functional.xef.reasoning.tools.Tool
-import okio.FileSystem
-
-class CreatePythonScript
-@JvmOverloads
-constructor(
-  private val model: Chat,
-  private val scope: CoreAIScope,
-  private val instructions: List<String> = emptyList()
-) : Tool {
-  override val name: String = "Create Python Script"
-
-  override val description: String =
-    "Creates a Python script from the input, run it and capture its output"
-
-  override suspend fun invoke(input: String): String {
-    val script: String =
-      callModel(
-        model = model,
-        scope = scope,
-        ExpertSystem(
-          system = "Create a Python script from the input",
-          query = input,
-          instructions = instructions,
-        )
-      )
-    val requirements: String =
-      callModel(
-        model = model,
-        scope = scope,
-        ExpertSystem(
-          system = "Create a requirements.txt for the input",
-          query = script,
-          instructions = instructions,
-        )
-      )
-    val tempFile = FileSystem.SYSTEM_TEMPORARY_DIRECTORY.resolve("data").resolve("script.py")
-    FileSystem.DEFAULT.write(tempFile, mustCreate = true) { writeUtf8(script) }
-    val tempRequirements =
-      FileSystem.SYSTEM_TEMPORARY_DIRECTORY.resolve("data").resolve("requirements.txt")
-    FileSystem.DEFAULT.write(tempRequirements, mustCreate = true) { writeUtf8(requirements) }
-    val output =
-      CommandExecutor.DEFAULT.executeCommandAndCaptureOutput(
-        command =
-          listOf(
-            "pip",
-            "install",
-            "-r",
-            tempRequirements.toString(),
-            "-t",
-            tempRequirements.parent.toString()
-          ),
-        options =
-          ExecuteCommandOptions(
-            directory = FileSystem.SYSTEM_TEMPORARY_DIRECTORY.toString(),
-            abortOnError = true,
-            redirectStderr = true,
-            trim = true
-          )
-      )
-    val runOutput =
-      CommandExecutor.DEFAULT.executeCommandAndCaptureOutput(
-        command = listOf("python", tempFile.toString()),
-        options =
-          ExecuteCommandOptions(
-            directory = FileSystem.SYSTEM_TEMPORARY_DIRECTORY.toString(),
-            abortOnError = true,
-            redirectStderr = true,
-            trim = true
-          )
-      )
-    return output + "\n" + runOutput
-  }
-}
diff --git a/reasoning/src/jvmMain/kotlin/com/xebia/functional/xef/reasoning/search/Search.kt b/reasoning/src/jvmMain/kotlin/com/xebia/functional/xef/reasoning/search/Search.kt
index d048426d2..eb7e0a4f4 100644
--- a/reasoning/src/jvmMain/kotlin/com/xebia/functional/xef/reasoning/search/Search.kt
+++ b/reasoning/src/jvmMain/kotlin/com/xebia/functional/xef/reasoning/search/Search.kt
@@ -4,6 +4,7 @@ import com.xebia.functional.xef.auto.AutoClose
 import com.xebia.functional.xef.auto.CoreAIScope
 import com.xebia.functional.xef.auto.autoClose
 import com.xebia.functional.xef.llm.Chat
+import com.xebia.functional.xef.llm.models.chat.Message
 import com.xebia.functional.xef.reasoning.serpapi.SerpApiClient
 import com.xebia.functional.xef.reasoning.tools.Tool

@@ -12,29 +13,38 @@ class Search
 constructor(
   private val model: Chat,
   private val scope: CoreAIScope,
+  private val maxResultsInContext: Int = 3,
   private val client: SerpApiClient = SerpApiClient()
 ) : Tool, AutoCloseable, AutoClose by autoClose() {

   override val name: String = "Search"

   override val description: String =
-    "Search the web for the best answer. The tool input is a simple string"
+    "Search the web for information. The tool input is a simple one line string"

   override suspend fun invoke(input: String): String {
     val docs = client.search(SerpApiClient.SearchData(input))
-    val innerDocs = docs.searchResults.mapNotNull { it.document }
-    scope.extendContext(*innerDocs.toTypedArray())
-    return model.promptMessage(
-      question =
-        """|
-            |Given the following input:
-            |```input
-            |${input}
-            |```
-            |Provide information that helps with the `input`.
-        """
-          .trimMargin(),
-      context = scope.context,
-    )
+    return model
+      .promptMessages(
+        messages =
+          listOf(Message.systemMessage { "Search results:" }) +
+            docs.searchResults.take(maxResultsInContext).flatMap {
+              listOf(
+                Message.systemMessage { "Title: ${it.title}" },
+                Message.systemMessage { "Source: ${it.source}" },
+                Message.systemMessage { "Content: ${it.document}" },
+              )
+            } +
+            listOf(
+              Message.userMessage { "input: $input" },
+              Message.assistantMessage {
+                "I will select the best search results and reply with information relevant to the `input`"
+              }
+            ),
+        context = scope.context,
+        conversationId = scope.conversationId,
+      )
+      .firstOrNull()
+      ?: "No results found"
   }

   override fun close() {
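
[Annotation] A minimal sketch of the reworked Search tool in use (SERP API credentials and scope wiring omitted; names illustrative):

    val search = Search(model = OpenAI.DEFAULT_CHAT, scope = scope, maxResultsInContext = 2)
    val answer = search.invoke("number of Metallica studio albums")
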
- """ - .trimMargin(), - context = scope.context, - ) + return model + .promptMessages( + messages = + listOf(Message.systemMessage { "Search results:" }) + + docs.searchResults.take(maxResultsInContext).flatMap { + listOf( + Message.systemMessage { "Title: ${it.title}" }, + Message.systemMessage { "Source: ${it.source}" }, + Message.systemMessage { "Content: ${it.document}" }, + ) + } + + listOf( + Message.userMessage { "input: $input" }, + Message.assistantMessage { + "I will select the best search results and reply with information relevant to the `input`" + } + ), + context = scope.context, + conversationId = scope.conversationId, + ) + .firstOrNull() + ?: "No results found" } override fun close() { From e53e6f79976d8fed5e1e0693fa808339417f4676 Mon Sep 17 00:00:00 2001 From: raulraja Date: Tue, 8 Aug 2023 11:41:41 +0200 Subject: [PATCH 5/9] spotless and clean up --- .../xef/prompt/lang/PromptLanguage.kt | 48 ------------------- 1 file changed, 48 deletions(-) diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/lang/PromptLanguage.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/lang/PromptLanguage.kt index faf1ec903..f66fb8485 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/lang/PromptLanguage.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/lang/PromptLanguage.kt @@ -88,51 +88,3 @@ class PromptLanguage(private val children: MutableList = mutableListOf()) operator fun invoke(f: PromptLanguage.() -> Node): Node = f(PromptLanguage()) } } - -fun infer(): Nothing = TODO() - -fun summarize(text: String): String = infer() - -val devProcess = PromptLanguage { - "DevProcess" { - "State" { "Target Language" { "JavaScript"() } } - "WriteTestsFIRST" { - "Use Riteway({ given, should, actual, expected })" { - "Define given, should, actual, and expected inline in the `assert` call."() - "Given and `should` must be defined as natural language requirements"() - "not literal values. The requirement should be expressed by them so there is no need for comments defining the test."() - } - "Tests must be" { - "Readable"() - "Isolated from each other in separate scopes. Test units of code in isolation from the rest of the program."() - "Thorough: Test all likely edge cases."() - "Explicit: Tests should have strong locality. Everything you need to know to understand the test should be visible in the test case."() - } - "Each test must answer" { - "What is the unit under test?"() - "What is the natural language requirement being tested?"() - "What is the actual output?"() - "What is the expected output?"() - "On failure, identify and fix the bug."() - } - } - "Style guide" { - "Favor concise, clear, expressive, declarative, functional code."() - "Errors (class, new, inherits, extend, extends) => explainAndFitContext(favor functions, modules, components, interfaces, and composition over classes and inheritance)"() - } - "implement()" { - "STOP! Write tests FIRST."() - "Implement the code such that unit tests pass. Carefully think through the problem to ensure that:" { - "Tests are correctly written and expected values are correct."() - "Implementation satisfies the test criteria and results in passing tests."() - } - } - "/implement - Implement code in the target language from a SudoLang function or natural language description"() - "/l | lang - Set the target language"() - "/h | help"() - - "When asked to implement a function, please carefully follow the instructions above. 
šŸ™"() - - "welcome()"() - } -} From ae6ab02c2edfab28ae88ea8169ac3a76c90a4f52 Mon Sep 17 00:00:00 2001 From: raulraja Date: Tue, 8 Aug 2023 11:48:34 +0200 Subject: [PATCH 6/9] clean up removing one use of the ExpertSystem --- .../xef/prompt/expressions/Expression.kt | 34 +++++-------------- 1 file changed, 8 insertions(+), 26 deletions(-) diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/Expression.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/Expression.kt index c0c2eb25b..39a280960 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/Expression.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/expressions/Expression.kt @@ -4,8 +4,6 @@ import com.xebia.functional.xef.auto.CoreAIScope import com.xebia.functional.xef.auto.PromptConfiguration import com.xebia.functional.xef.llm.ChatWithFunctions import com.xebia.functional.xef.llm.models.chat.Message -import com.xebia.functional.xef.llm.models.chat.Role -import com.xebia.functional.xef.prompt.experts.ExpertSystem import io.github.oshai.kotlinlogging.KLogger import io.github.oshai.kotlinlogging.KotlinLogging @@ -42,33 +40,17 @@ class Expression( promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS ): ExpressionResult { block() - val instructionMessage = - Message( - role = Role.USER, - content = - ExpertSystem( - system = "You are an expert in replacing variables in templates", - query = - """ - |I want to replace the following variables in the following template: - | - |The variables are: - |${generationKeys.joinToString("\n") { it }} - """ - .trimMargin(), - instructions = - listOf( - "Create a `ReplacedValues` object with the `replacements` where the keys are the variable names and the values are the values to replace them with.", - ) - ) - .message, - name = Role.USER.name + val prelude = + listOf( + Message.systemMessage { "You are an expert in replacing variables in templates" }, + ) + val instructionMessages = + listOf( + Message.assistantMessage { "I will replace all placeholders in the message" }, ) val values: ReplacedValues = model.prompt( - messages = messages + instructionMessage, + messages = prelude + messages + instructionMessages, context = scope.context, serializer = ReplacedValues.serializer(), conversationId = scope.conversationId, From b68cc8e1552f8115339f5d22f0fd7098304a8d5d Mon Sep 17 00:00:00 2001 From: Javi Pacheco Date: Wed, 9 Aug 2023 15:02:33 +0200 Subject: [PATCH 7/9] New strategy for collecting messages by tokens --- .../com/xebia/functional/xef/AIError.kt | 5 + .../xef/auto/PromptConfiguration.kt | 7 - .../com/xebia/functional/xef/llm/Chat.kt | 220 +++++++++--------- .../functional/xef/llm/ChatWithFunctions.kt | 80 +++---- .../xebia/functional/xef/auto/gpt4all/Chat.kt | 1 - .../auto/llm/openai/DeserializerLLMAgent.kt | 2 +- .../xef/reasoning/tools/ToolSelection.kt | 49 ++-- 7 files changed, 183 insertions(+), 181 deletions(-) diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/AIError.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/AIError.kt index 771e6d53a..f7eaa549f 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/AIError.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/AIError.kt @@ -27,6 +27,11 @@ sealed class AIError @JvmOverloads constructor(message: String, cause: Throwable "Prompt exceeds max token length: $promptTokens + $maxTokens = ${promptTokens + maxTokens}" ) + data class PromptExceedsMaxRemainingTokenLength(val 
promptTokens: Int, val maxTokens: Int) : + AIError( + "Prompt exceeds max remaining token length: $promptTokens + $maxTokens = ${promptTokens + maxTokens}" + ) + data class JsonParsing(val result: String, val maxAttempts: Int, override val cause: Throwable) : AIError("Failed to parse the JSON response after $maxAttempts attempts: $result", cause) diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/PromptConfiguration.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/PromptConfiguration.kt index 879bd5aa9..e9e4d0fa8 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/PromptConfiguration.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/PromptConfiguration.kt @@ -12,7 +12,6 @@ class PromptConfiguration( val docsInContext: Int = 5, val memoryLimit: Int = 5, val minResponseTokens: Int = 500, - val streamToStandardOut: Boolean = false ) { companion object { @@ -23,17 +22,12 @@ class PromptConfiguration( private var numberOfPredictions: Int = 1 private var docsInContext: Int = 20 private var minResponseTokens: Int = 500 - private var streamToStandardOut: Boolean = false private var memoryLimit: Int = 5 fun maxDeserializationAttempts(maxDeserializationAttempts: Int) = apply { this.maxDeserializationAttempts = maxDeserializationAttempts } - fun streamToStandardOut(streamToStandardOut: Boolean) = apply { - this.streamToStandardOut = streamToStandardOut - } - fun user(user: String) = apply { this.user = user } fun temperature(temperature: Double) = apply { this.temperature = temperature } @@ -59,7 +53,6 @@ class PromptConfiguration( docsInContext = docsInContext, memoryLimit = memoryLimit, minResponseTokens = minResponseTokens, - streamToStandardOut = streamToStandardOut, ) } diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt index 8bd613f13..fddeac2a4 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt @@ -46,35 +46,23 @@ interface Chat : LLM { ): Flow = flow { val memories: List = memories(conversationId, context, promptConfiguration) - val promptWithContext: List = - createPromptWithContextAwareOfTokens( - memories = memories, - ctxInfo = context.similaritySearch(prompt.message, promptConfiguration.docsInContext), - modelType = modelType, - prompt = prompt.message, - minResponseTokens = promptConfiguration.minResponseTokens + val messagesForRequest = + fitMessagesByTokens( + messagesFromMemory(memories), + prompt.toMessages(), + context, + modelType, + promptConfiguration ) - val messages: List = messagesFromMemory(memories) + promptWithContext - - fun checkTotalLeftChatTokens(): Int { - val maxContextLength: Int = modelType.maxContextLength - val messagesTokens: Int = tokensFromMessages(messages) - val totalLeftTokens: Int = maxContextLength - messagesTokens - if (totalLeftTokens < 0) { - throw AIError.MessagesExceedMaxTokenLength(messages, messagesTokens, maxContextLength) - } - return totalLeftTokens - } - - val request: ChatCompletionRequest = + val request = ChatCompletionRequest( model = name, user = promptConfiguration.user, - messages = messages, + messages = messagesForRequest, n = promptConfiguration.numberOfPredictions, temperature = promptConfiguration.temperature, - maxTokens = checkTotalLeftChatTokens(), + maxTokens = promptConfiguration.minResponseTokens, streamToStandardOut = true ) @@ -90,31 +78,6 @@ interface Chat : LLM { .collect { 
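Note the strategy shift in the streaming path above: `maxTokens` is now pinned to `promptConfiguration.minResponseTokens` and the messages are trimmed to fit beforehand, instead of handing the model whatever happens to be left of the context window. The fail-fast guard that `fitMessagesByTokens` applies first can be sketched as a standalone function (a sketch, not the actual member):

// Reserve minResponseTokens for the completion, and reject the call when the
// caller's messages alone already exceed what remains of the context window;
// the returned budget is what memory and context injection may consume.
fun remainingTokensForContexts(
  maxContextLength: Int,
  minResponseTokens: Int,
  messagesTokens: Int,
): Int {
  val remainingTokens = maxContextLength - minResponseTokens
  check(messagesTokens < remainingTokens) {
    "Prompt exceeds max remaining token length: $messagesTokens of $remainingTokens"
  }
  return remainingTokens - messagesTokens
}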
emit(it.choices.mapNotNull { it.delta?.content }.joinToString("")) } } - private suspend fun addMemoriesAfterStream( - request: ChatCompletionRequest, - conversationId: ConversationId?, - buffer: StringBuilder, - context: VectorStore - ) { - val lastRequestMessage = request.messages.lastOrNull() - if (conversationId != null && lastRequestMessage != null) { - val requestMemory = - Memory( - conversationId = conversationId, - content = lastRequestMessage, - timestamp = getTimeMillis() - ) - val responseMemory = - Memory( - conversationId = conversationId, - content = - Message(role = Role.ASSISTANT, content = buffer.toString(), name = Role.ASSISTANT.name), - timestamp = getTimeMillis(), - ) - context.addMemories(listOf(requestMemory, responseMemory)) - } - } - @AiDsl suspend fun promptMessage( question: String, @@ -136,6 +99,23 @@ interface Chat : LLM { ): List = promptMessages(Prompt(question), context, conversationId, functions, promptConfiguration) + @AiDsl + suspend fun promptMessages( + prompt: Prompt, + context: VectorStore, + conversationId: ConversationId? = null, + functions: List = emptyList(), + promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS + ): List { + return promptMessages( + prompt.toMessages(), + context, + conversationId, + functions, + promptConfiguration + ) + } + @AiDsl suspend fun promptMessages( messages: List, @@ -146,37 +126,33 @@ interface Chat : LLM { ): List { val memories: List = memories(conversationId, context, promptConfiguration) - val allMessages = messagesFromMemory(memories) + messages - - fun checkTotalLeftChatTokens(): Int { - val maxContextLength: Int = modelType.maxContextLength - val messagesTokens: Int = tokensFromMessages(allMessages) - val totalLeftTokens: Int = maxContextLength - messagesTokens - if (totalLeftTokens < 0) { - throw AIError.MessagesExceedMaxTokenLength(allMessages, messagesTokens, maxContextLength) - } - return totalLeftTokens - } + val messagesForRequest = + fitMessagesByTokens( + messagesFromMemory(memories), + messages, + context, + modelType, + promptConfiguration + ) fun chatRequest(): ChatCompletionRequest = ChatCompletionRequest( model = name, user = promptConfiguration.user, - messages = messages, + messages = messagesForRequest, n = promptConfiguration.numberOfPredictions, temperature = promptConfiguration.temperature, - maxTokens = checkTotalLeftChatTokens(), - streamToStandardOut = promptConfiguration.streamToStandardOut + maxTokens = promptConfiguration.minResponseTokens, ) fun withFunctionsRequest(): ChatCompletionRequestWithFunctions = ChatCompletionRequestWithFunctions( model = name, user = promptConfiguration.user, - messages = messages, + messages = messagesForRequest, n = promptConfiguration.numberOfPredictions, temperature = promptConfiguration.temperature, - maxTokens = checkTotalLeftChatTokens(), + maxTokens = promptConfiguration.minResponseTokens, functions = functions, functionCall = mapOf("name" to (functions.firstOrNull()?.name ?: "")) ) @@ -207,33 +183,33 @@ interface Chat : LLM { } } - @AiDsl - suspend fun promptMessages( - prompt: Prompt, - context: VectorStore, - conversationId: ConversationId? 
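All of the `promptMessage`/`promptMessages` overloads above now funnel into the single `List<Message>` implementation, so memory lookup and token fitting live in exactly one place. The shape of that funnel, reduced to a hypothetical interface over plain strings:

data class Prompt(val message: String)

interface Chatty {
  // Single real implementation: every other entry point normalizes to here.
  suspend fun promptMessages(messages: List<String>): List<String>

  suspend fun promptMessages(prompt: Prompt): List<String> =
    promptMessages(listOf("user: ${prompt.message}"))

  suspend fun promptMessage(question: String): List<String> =
    promptMessages(Prompt(question))
}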
= null, - functions: List = emptyList(), - promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS - ): List { - - val memories: List = memories(conversationId, context, promptConfiguration) + suspend fun String.toMessages(): List = Prompt(this).toMessages() - val promptWithContext: List = - createPromptWithContextAwareOfTokens( - memories = memories, - ctxInfo = context.similaritySearch(prompt.message, promptConfiguration.docsInContext), - modelType = modelType, - prompt = prompt.message, - minResponseTokens = promptConfiguration.minResponseTokens - ) + suspend fun Prompt.toMessages(): List = listOf(Message.userMessage { message }) - return promptMessages( - promptWithContext, - context, - conversationId, - functions, - promptConfiguration - ) + private suspend fun addMemoriesAfterStream( + request: ChatCompletionRequest, + conversationId: ConversationId?, + buffer: StringBuilder, + context: VectorStore + ) { + val lastRequestMessage = request.messages.lastOrNull() + if (conversationId != null && lastRequestMessage != null) { + val requestMemory = + Memory( + conversationId = conversationId, + content = lastRequestMessage, + timestamp = getTimeMillis() + ) + val responseMemory = + Memory( + conversationId = conversationId, + content = + Message(role = Role.ASSISTANT, content = buffer.toString(), name = Role.ASSISTANT.name), + timestamp = getTimeMillis(), + ) + context.addMemories(listOf(requestMemory, responseMemory)) + } } private suspend fun List.addChoiceWithFunctionsToMemory( @@ -307,30 +283,62 @@ interface Chat : LLM { emptyList() } - private suspend fun createPromptWithContextAwareOfTokens( - memories: List, - ctxInfo: List, + private suspend fun fitMessagesByTokens( + history: List, + messages: List, + context: VectorStore, modelType: ModelType, - prompt: String, - minResponseTokens: Int, + promptConfiguration: PromptConfiguration, ): List { val maxContextLength: Int = modelType.maxContextLength - val promptTokens: Int = modelType.encoding.countTokens(prompt) - val memoryTokens = tokensFromMessages(memories.map { it.content }) - val remainingTokens: Int = maxContextLength - promptTokens - memoryTokens - minResponseTokens + val remainingTokens: Int = maxContextLength - promptConfiguration.minResponseTokens + + val messagesTokens = tokensFromMessages(messages) + + if (messagesTokens >= remainingTokens) { + throw AIError.PromptExceedsMaxRemainingTokenLength(messagesTokens, remainingTokens) + } + + val remainingTokensForContexts = remainingTokens - messagesTokens + + // TODO we should move this to PromptConfiguration + val historyPercent = 50 + val contextPercent = 50 + + val maxHistoryTokens = (remainingTokensForContexts * historyPercent) / 100 + + val historyMessagesWithTokens = history.map { Pair(it, tokensFromMessages(listOf(it))) } + + val totalTokenWithMessages = + historyMessagesWithTokens.reversed().fold(Pair(0, emptyList())) { acc, pair -> + if (acc.first + pair.second > maxHistoryTokens) { + acc + } else { + Pair(acc.first + pair.second, acc.second + pair.first) + } + } + + val historyAllowed = totalTokenWithMessages.second.reversed() + + val maxContextTokens = (remainingTokensForContexts * contextPercent) / 100 + + val ctxInfo = + context.similaritySearch( + messages.joinToString("\n") { it.content }, + promptConfiguration.docsInContext, + ) + + val contextAllowed = + if (ctxInfo.isNotEmpty()) { + val ctx: String = ctxInfo.joinToString("\n") - return if (ctxInfo.isNotEmpty() && remainingTokens > minResponseTokens) { - val ctx: String = 
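The `foldRight` above is the heart of the new collection strategy: history is scanned from the newest message backwards, messages are kept while they fit the history share of the budget, and the kept slice is reversed back into chronological order. As a standalone sketch, with `tokensOf` standing in for `tokensFromMessages`:

fun fitHistory(
  history: List<String>, // chronological, oldest first
  maxHistoryTokens: Int,
  tokensOf: (String) -> Int,
): List<String> {
  val withTokens = history.map { it to tokensOf(it) }
  // foldRight starts at the newest message; like the fold above, it skips a
  // message that does not fit but keeps scanning older ones.
  val (_, keptNewestFirst) =
    withTokens.foldRight(0 to emptyList<String>()) { (msg, tokens), (used, kept) ->
      if (used + tokens > maxHistoryTokens) used to kept
      else (used + tokens) to (kept + msg)
    }
  return keptNewestFirst.reversed()
}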
ctxInfo.joinToString("\n") + val ctxTruncated: String = modelType.encoding.truncateText(ctx, maxContextTokens) - if (promptTokens >= maxContextLength) { - throw AIError.PromptExceedsMaxTokenLength(prompt, promptTokens, maxContextLength) + listOf(Message.assistantMessage { ctxTruncated }) + } else { + emptyList() } - // truncate the context if it's too long based on the max tokens calculated considering the - // existing prompt tokens - // alternatively we could summarize the context, but that's not implemented yet - val ctxTruncated: String = modelType.encoding.truncateText(ctx, remainingTokens) - listOf(Message.assistantMessage { "Context: $ctxTruncated" }, Message.userMessage { prompt }) - } else listOf(Message.userMessage { prompt }) + return contextAllowed + historyAllowed + messages } } diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt index 7284d053f..f654a54cf 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt @@ -34,6 +34,17 @@ interface ChatWithFunctions : Chat { fun generateCFunction(fnName: String, schema: String): List = listOf(CFunction(fnName, "Generated function for $fnName", schema)) + @AiDsl + suspend fun prompt( + prompt: String, + context: VectorStore, + conversationId: ConversationId? = null, + functions: List = emptyList(), + serializer: (json: String) -> A, + promptConfiguration: PromptConfiguration, + ): A = + prompt(prompt.toMessages(), context, conversationId, functions, serializer, promptConfiguration) + @AiDsl suspend fun prompt( prompt: Prompt, @@ -44,30 +55,19 @@ interface ChatWithFunctions : Chat { serializer: (json: String) -> A, functions: List = generateCFunction(serializerName, jsonSchema), promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS, - ): A = prompt(prompt, context, conversationId, functions, serializer, promptConfiguration) + ): A = + prompt(prompt.toMessages(), context, conversationId, functions, serializer, promptConfiguration) @AiDsl suspend fun prompt( - messages: List, + prompt: Prompt, context: VectorStore, - serializer: KSerializer, + serializer: (json: String) -> A, conversationId: ConversationId? = null, - functions: List = generateCFunction(serializer.descriptor), + functions: List = emptyList(), promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS, - ): A { - return tryDeserialize( - { json -> Json.decodeFromString(serializer, json) }, - promptConfiguration.maxDeserializationAttempts - ) { - promptMessages( - messages = messages, - context = context, - conversationId = conversationId, - functions = functions, - promptConfiguration - ) - } - } + ): A = + prompt(prompt.toMessages(), context, conversationId, functions, serializer, promptConfiguration) @AiDsl suspend fun prompt( @@ -77,56 +77,52 @@ interface ChatWithFunctions : Chat { conversationId: ConversationId? = null, functions: List = generateCFunction(serializer.descriptor), promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS, - ): A { - return prompt( - prompt, + ): A = + prompt( + prompt.toMessages(), context, conversationId, functions, { json -> Json.decodeFromString(serializer, json) }, promptConfiguration ) - } @AiDsl suspend fun prompt( - prompt: String, + messages: List, context: VectorStore, + serializer: KSerializer, conversationId: ConversationId? 
= null, - functions: List = emptyList(), - serializer: (json: String) -> A, - promptConfiguration: PromptConfiguration, - ): A { - return tryDeserialize(serializer, promptConfiguration.maxDeserializationAttempts) { - promptMessages( - prompt = Prompt(prompt), - context = context, - conversationId = conversationId, - functions = functions, - promptConfiguration - ) - } - } + functions: List = generateCFunction(serializer.descriptor), + promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS, + ): A = + prompt( + messages, + context, + conversationId, + functions, + { json -> Json.decodeFromString(serializer, json) }, + promptConfiguration + ) @AiDsl suspend fun prompt( - prompt: Prompt, + messages: List, context: VectorStore, conversationId: ConversationId? = null, functions: List = emptyList(), serializer: (json: String) -> A, promptConfiguration: PromptConfiguration, - ): A { - return tryDeserialize(serializer, promptConfiguration.maxDeserializationAttempts) { + ): A = + tryDeserialize(serializer, promptConfiguration.maxDeserializationAttempts) { promptMessages( - prompt = prompt, + messages = messages, context = context, conversationId = conversationId, functions = functions, promptConfiguration ) } - } private suspend fun tryDeserialize( serializer: (json: String) -> A, diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/gpt4all/Chat.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/gpt4all/Chat.kt index 255decff9..abe25ed0b 100644 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/gpt4all/Chat.kt +++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/gpt4all/Chat.kt @@ -39,7 +39,6 @@ suspend fun main() { context, promptConfiguration = PromptConfiguration { docsInContext(2) - streamToStandardOut(true) }).onCompletion { println("\nšŸ¤– Done") }.collect { diff --git a/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/DeserializerLLMAgent.kt b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/DeserializerLLMAgent.kt index 5616b8c61..5ce2042d7 100644 --- a/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/DeserializerLLMAgent.kt +++ b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/DeserializerLLMAgent.kt @@ -63,9 +63,9 @@ suspend fun CoreAIScope.prompt( return model.prompt( prompt, context, + { json.decodeFromString(serializer, it) }, conversationId, functions, - { json.decodeFromString(serializer, it) }, promptConfiguration ) } diff --git a/reasoning/src/commonMain/kotlin/com/xebia/functional/xef/reasoning/tools/ToolSelection.kt b/reasoning/src/commonMain/kotlin/com/xebia/functional/xef/reasoning/tools/ToolSelection.kt index 31f6b7674..86aa102cb 100644 --- a/reasoning/src/commonMain/kotlin/com/xebia/functional/xef/reasoning/tools/ToolSelection.kt +++ b/reasoning/src/commonMain/kotlin/com/xebia/functional/xef/reasoning/tools/ToolSelection.kt @@ -2,7 +2,7 @@ package com.xebia.functional.xef.reasoning.tools import com.xebia.functional.xef.auto.CoreAIScope import com.xebia.functional.xef.llm.ChatWithFunctions -import com.xebia.functional.xef.prompt.experts.ExpertSystem +import com.xebia.functional.xef.llm.models.chat.Message import io.github.oshai.kotlinlogging.KotlinLogging class ToolSelection( @@ -79,32 +79,33 @@ class ToolSelection( suspend fun createExecutionPlan(task: String): ToolsExecutionPlan { logger.info { "šŸ” Creating execution plan for task: $task" } + + val messages: List = + listOf( + Message.systemMessage { + 
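Every typed `prompt` overload above now terminates in the same `tryDeserialize` retry loop. Its essential shape is sketched below, with a generic exception standing in for `AIError.JsonParsing` and `agent` standing in for the `promptMessages` call:

suspend fun <A> tryDeserialize(
  serializer: (json: String) -> A,
  maxAttempts: Int,
  agent: suspend () -> List<String>,
): A {
  var lastError: Throwable? = null
  repeat(maxAttempts) {
    // Each attempt re-prompts and tries every returned choice in order.
    for (result in agent()) {
      try {
        return serializer(result)
      } catch (e: Throwable) {
        lastError = e
      }
    }
  }
  throw IllegalStateException("Failed to parse the response after $maxAttempts attempts", lastError)
}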
"You are an expert in tool selection that can choose the best tools for a specific task based on the tools descriptions" + }, + Message.assistantMessage { "Given the following task:" }, + Message.assistantMessage { task }, + Message.assistantMessage { "Given the following tools:" }, + ) + + tools.map { Message.assistantMessage { "${it.name}: ${it.description}" } } + + listOf( + Message.userMessage { "Follow the next instructions" }, + Message.userMessage { + "Select the best execution plan with tools for the `task` based on the `tools`" + }, + Message.userMessage { + "Your `RESPONSE` MUST be a `ToolsExecutionPlan` object, where the `steps` determine how the execution plan will run the tools" + }, + ) + + instructions.map { Message.userMessage { it } } + return model.prompt( context = scope.context, - conversationId = scope.conversationId, + conversationId = null, serializer = ToolsExecutionPlan.serializer(), - prompt = - ExpertSystem( - system = - "You are an expert in tool selection that can choose the best tools for a specific task based on the tools descriptions", - query = - """| - |Given the following task: - |```task - |${task} - |``` - |And the following tools: - |```tools - |${(tools.map { ToolMetadata(it.name, it.description) }).joinToString("\n") { "${it.name}: ${it.description}" }} - |``` - """ - .trimMargin(), - instructions = - listOf( - "Select the best execution plan with tools for the `task` based on the `tools`", - "Your `RESPONSE` MUST be a `ToolsExecutionPlan` object, where the `steps` determine how the execution plan will run the tools" - ) + instructions - ) + messages = messages ) } } From 158acee79aa56851180dc1972858a1f13260c300 Mon Sep 17 00:00:00 2001 From: Javi Pacheco Date: Wed, 9 Aug 2023 15:18:02 +0200 Subject: [PATCH 8/9] Java files fixed --- .../com/xebia/functional/xef/java/auto/jdk21/gpt4all/Chat.java | 2 +- .../com/xebia/functional/xef/java/auto/jdk8/gpt4all/Chat.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/java/src/main/java/com/xebia/functional/xef/java/auto/jdk21/gpt4all/Chat.java b/examples/java/src/main/java/com/xebia/functional/xef/java/auto/jdk21/gpt4all/Chat.java index 769aff757..a63d2641c 100644 --- a/examples/java/src/main/java/com/xebia/functional/xef/java/auto/jdk21/gpt4all/Chat.java +++ b/examples/java/src/main/java/com/xebia/functional/xef/java/auto/jdk21/gpt4all/Chat.java @@ -48,7 +48,7 @@ public static void main(String[] args) throws ExecutionException, InterruptedExc String line = br.readLine(); if (line.equals("exit")) break; - var promptConfiguration = new PromptConfiguration.Companion.Builder().docsInContext(2).streamToStandardOut(true).build(); + var promptConfiguration = new PromptConfiguration.Companion.Builder().docsInContext(2).build(); var answer = scope.promptStreaming(gpt4all, line, promptConfiguration); answer.subscribe(new Subscriber() { diff --git a/examples/java/src/main/java/com/xebia/functional/xef/java/auto/jdk8/gpt4all/Chat.java b/examples/java/src/main/java/com/xebia/functional/xef/java/auto/jdk8/gpt4all/Chat.java index ec6b11565..42c84b03a 100644 --- a/examples/java/src/main/java/com/xebia/functional/xef/java/auto/jdk8/gpt4all/Chat.java +++ b/examples/java/src/main/java/com/xebia/functional/xef/java/auto/jdk8/gpt4all/Chat.java @@ -48,7 +48,7 @@ public static void main(String[] args) throws ExecutionException, InterruptedExc String line = br.readLine(); if (line.equals("exit")) break; - PromptConfiguration promptConfiguration = new 
PromptConfiguration.Companion.Builder().docsInContext(2).streamToStandardOut(true).build(); + PromptConfiguration promptConfiguration = new PromptConfiguration.Companion.Builder().docsInContext(2).build(); Publisher answer = scope.promptStreaming(gpt4all, line, promptConfiguration); answer.subscribe(new Subscriber() { From 97d0252021ba1d16c2b91844e6e10393594f3a13 Mon Sep 17 00:00:00 2001 From: Javi Pacheco Date: Thu, 10 Aug 2023 08:48:11 +0200 Subject: [PATCH 9/9] Comments addressed --- .../functional/xef/auto/PromptConfiguration.kt | 17 +++++++++++++++++ .../kotlin/com/xebia/functional/xef/llm/Chat.kt | 7 +++---- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/PromptConfiguration.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/PromptConfiguration.kt index e9e4d0fa8..42f000a68 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/PromptConfiguration.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/PromptConfiguration.kt @@ -12,6 +12,7 @@ class PromptConfiguration( val docsInContext: Int = 5, val memoryLimit: Int = 5, val minResponseTokens: Int = 500, + val messagePolicy: MessagePolicy = MessagePolicy(), ) { companion object { @@ -23,6 +24,7 @@ class PromptConfiguration( private var docsInContext: Int = 20 private var minResponseTokens: Int = 500 private var memoryLimit: Int = 5 + private var messagePolicy: MessagePolicy = MessagePolicy() fun maxDeserializationAttempts(maxDeserializationAttempts: Int) = apply { this.maxDeserializationAttempts = maxDeserializationAttempts @@ -44,6 +46,8 @@ class PromptConfiguration( fun memoryLimit(memoryLimit: Int) = apply { this.memoryLimit = memoryLimit } + fun messagePolicy(messagePolicy: MessagePolicy) = apply { this.messagePolicy = messagePolicy } + fun build() = PromptConfiguration( maxDeserializationAttempts = maxDeserializationAttempts, @@ -53,6 +57,7 @@ class PromptConfiguration( docsInContext = docsInContext, memoryLimit = memoryLimit, minResponseTokens = minResponseTokens, + messagePolicy = messagePolicy, ) } @@ -62,3 +67,15 @@ class PromptConfiguration( @JvmField val DEFAULTS = PromptConfiguration() } } + +/** + * The [MessagePolicy] encapsulates the message selection policy for sending to the server. Allows + * defining the percentages of historical and contextual messages to include in the final list. 
+ * + * @property historyPercent Percentage of historical messages + * @property contextPercent Percentage of context messages + */ +class MessagePolicy( + val historyPercent: Int = 50, + val contextPercent: Int = 50, +) diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt index fddeac2a4..5e18d60bb 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt @@ -301,16 +301,15 @@ interface Chat : LLM { val remainingTokensForContexts = remainingTokens - messagesTokens - // TODO we should move this to PromptConfiguration - val historyPercent = 50 - val contextPercent = 50 + val historyPercent = promptConfiguration.messagePolicy.historyPercent + val contextPercent = promptConfiguration.messagePolicy.contextPercent val maxHistoryTokens = (remainingTokensForContexts * historyPercent) / 100 val historyMessagesWithTokens = history.map { Pair(it, tokensFromMessages(listOf(it))) } val totalTokenWithMessages = - historyMessagesWithTokens.reversed().fold(Pair(0, emptyList())) { acc, pair -> + historyMessagesWithTokens.foldRight(Pair(0, emptyList())) { pair, acc -> if (acc.first + pair.second > maxHistoryTokens) { acc } else {