New strategy for collecting messages and improvements to ReAct agent and message passing #305

Merged · 14 commits · Aug 10, 2023
@@ -27,6 +27,11 @@ sealed class AIError @JvmOverloads constructor(message: String, cause: Throwable
"Prompt exceeds max token length: $promptTokens + $maxTokens = ${promptTokens + maxTokens}"
)

data class PromptExceedsMaxRemainingTokenLength(val promptTokens: Int, val maxTokens: Int) :
AIError(
"Prompt exceeds max remaining token length: $promptTokens + $maxTokens = ${promptTokens + maxTokens}"
)

data class JsonParsing(val result: String, val maxAttempts: Int, override val cause: Throwable) :
AIError("Failed to parse the JSON response after $maxAttempts attempts: $result", cause)

@@ -12,7 +12,7 @@ class PromptConfiguration(
val docsInContext: Int = 5,
val memoryLimit: Int = 5,
val minResponseTokens: Int = 500,
val streamToStandardOut: Boolean = false
val messagePolicy: MessagePolicy = MessagePolicy(),
) {
companion object {

@@ -23,17 +23,13 @@
private var numberOfPredictions: Int = 1
private var docsInContext: Int = 20
private var minResponseTokens: Int = 500
private var streamToStandardOut: Boolean = false
private var memoryLimit: Int = 5
private var messagePolicy: MessagePolicy = MessagePolicy()

fun maxDeserializationAttempts(maxDeserializationAttempts: Int) = apply {
this.maxDeserializationAttempts = maxDeserializationAttempts
}

fun streamToStandardOut(streamToStandardOut: Boolean) = apply {
this.streamToStandardOut = streamToStandardOut
}

fun user(user: String) = apply { this.user = user }

fun temperature(temperature: Double) = apply { this.temperature = temperature }
@@ -50,6 +46,8 @@

fun memoryLimit(memoryLimit: Int) = apply { this.memoryLimit = memoryLimit }

fun messagePolicy(messagePolicy: MessagePolicy) = apply { this.messagePolicy = messagePolicy }

fun build() =
PromptConfiguration(
maxDeserializationAttempts = maxDeserializationAttempts,
@@ -59,7 +57,7 @@
docsInContext = docsInContext,
memoryLimit = memoryLimit,
minResponseTokens = minResponseTokens,
streamToStandardOut = streamToStandardOut,
messagePolicy = messagePolicy,
)
}

@@ -69,3 +67,15 @@
@JvmField val DEFAULTS = PromptConfiguration()
}
}

/**
* The [MessagePolicy] encapsulates the message selection policy for sending to the server. Allows
* defining the percentages of historical and contextual messages to include in the final list.
*
* @property historyPercent Percentage of historical messages
* @property contextPercent Percentage of context messages
*/
class MessagePolicy(
val historyPercent: Int = 50,
val contextPercent: Int = 50,
)
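
For orientation (not part of the diff): with the new configuration surface, callers can choose how the remaining token budget is split between conversation history and vector-store context. A minimal sketch, assuming the public constructors shown above; the 70/30 split is an arbitrary illustration:

val config = PromptConfiguration(
    minResponseTokens = 500,
    messagePolicy = MessagePolicy(historyPercent = 70, contextPercent = 30)
)
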
219 changes: 113 additions & 106 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt
@@ -46,35 +46,23 @@ interface Chat : LLM {
): Flow<String> = flow {
val memories: List<Memory> = memories(conversationId, context, promptConfiguration)

val promptWithContext: List<Message> =
createPromptWithContextAwareOfTokens(
memories = memories,
ctxInfo = context.similaritySearch(prompt.message, promptConfiguration.docsInContext),
modelType = modelType,
prompt = prompt.message,
minResponseTokens = promptConfiguration.minResponseTokens
val messagesForRequest =
fitMessagesByTokens(
messagesFromMemory(memories),
prompt.toMessages(),
context,
modelType,
promptConfiguration
)

val messages: List<Message> = messagesFromMemory(memories) + promptWithContext

fun checkTotalLeftChatTokens(): Int {
val maxContextLength: Int = modelType.maxContextLength
val messagesTokens: Int = tokensFromMessages(messages)
val totalLeftTokens: Int = maxContextLength - messagesTokens
if (totalLeftTokens < 0) {
throw AIError.MessagesExceedMaxTokenLength(messages, messagesTokens, maxContextLength)
}
return totalLeftTokens
}

val request: ChatCompletionRequest =
val request =
ChatCompletionRequest(
model = name,
user = promptConfiguration.user,
messages = messages,
messages = messagesForRequest,
n = promptConfiguration.numberOfPredictions,
temperature = promptConfiguration.temperature,
maxTokens = checkTotalLeftChatTokens(),
maxTokens = promptConfiguration.minResponseTokens,
streamToStandardOut = true
)

@@ -90,31 +78,6 @@
.collect { emit(it.choices.mapNotNull { it.delta?.content }.joinToString("")) }
}

private suspend fun addMemoriesAfterStream(
request: ChatCompletionRequest,
conversationId: ConversationId?,
buffer: StringBuilder,
context: VectorStore
) {
val lastRequestMessage = request.messages.lastOrNull()
if (conversationId != null && lastRequestMessage != null) {
val requestMemory =
Memory(
conversationId = conversationId,
content = lastRequestMessage,
timestamp = getTimeMillis()
)
val responseMemory =
Memory(
conversationId = conversationId,
content =
Message(role = Role.ASSISTANT, content = buffer.toString(), name = Role.ASSISTANT.name),
timestamp = getTimeMillis(),
)
context.addMemories(listOf(requestMemory, responseMemory))
}
}

@AiDsl
suspend fun promptMessage(
question: String,
@@ -136,6 +99,23 @@ interface Chat : LLM {
): List<String> =
promptMessages(Prompt(question), context, conversationId, functions, promptConfiguration)

@AiDsl
suspend fun promptMessages(
prompt: Prompt,
context: VectorStore,
conversationId: ConversationId? = null,
functions: List<CFunction> = emptyList(),
promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS
): List<String> {
return promptMessages(
prompt.toMessages(),
context,
conversationId,
functions,
promptConfiguration
)
}

@AiDsl
suspend fun promptMessages(
messages: List<Message>,
@@ -146,37 +126,33 @@
): List<String> {

val memories: List<Memory> = memories(conversationId, context, promptConfiguration)
val allMessages = messagesFromMemory(memories) + messages

fun checkTotalLeftChatTokens(): Int {
val maxContextLength: Int = modelType.maxContextLength
val messagesTokens: Int = tokensFromMessages(allMessages)
val totalLeftTokens: Int = maxContextLength - messagesTokens
if (totalLeftTokens < 0) {
throw AIError.MessagesExceedMaxTokenLength(allMessages, messagesTokens, maxContextLength)
}
return totalLeftTokens
}
val messagesForRequest =
fitMessagesByTokens(
messagesFromMemory(memories),
messages,
context,
modelType,
promptConfiguration
)

fun chatRequest(): ChatCompletionRequest =
ChatCompletionRequest(
model = name,
user = promptConfiguration.user,
messages = messages,
messages = messagesForRequest,
n = promptConfiguration.numberOfPredictions,
temperature = promptConfiguration.temperature,
maxTokens = checkTotalLeftChatTokens(),
streamToStandardOut = promptConfiguration.streamToStandardOut
maxTokens = promptConfiguration.minResponseTokens,
)

fun withFunctionsRequest(): ChatCompletionRequestWithFunctions =
ChatCompletionRequestWithFunctions(
model = name,
user = promptConfiguration.user,
messages = messages,
messages = messagesForRequest,
n = promptConfiguration.numberOfPredictions,
temperature = promptConfiguration.temperature,
maxTokens = checkTotalLeftChatTokens(),
maxTokens = promptConfiguration.minResponseTokens,
functions = functions,
functionCall = mapOf("name" to (functions.firstOrNull()?.name ?: ""))
)
@@ -207,33 +183,33 @@ interface Chat : LLM {
}
}

@AiDsl
suspend fun promptMessages(
prompt: Prompt,
context: VectorStore,
conversationId: ConversationId? = null,
functions: List<CFunction> = emptyList(),
promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS
): List<String> {

val memories: List<Memory> = memories(conversationId, context, promptConfiguration)
suspend fun String.toMessages(): List<Message> = Prompt(this).toMessages()

val promptWithContext: List<Message> =
createPromptWithContextAwareOfTokens(
memories = memories,
ctxInfo = context.similaritySearch(prompt.message, promptConfiguration.docsInContext),
modelType = modelType,
prompt = prompt.message,
minResponseTokens = promptConfiguration.minResponseTokens
)
suspend fun Prompt.toMessages(): List<Message> = listOf(Message.userMessage { message })

return promptMessages(
promptWithContext,
context,
conversationId,
functions,
promptConfiguration
)
private suspend fun addMemoriesAfterStream(
request: ChatCompletionRequest,
conversationId: ConversationId?,
buffer: StringBuilder,
context: VectorStore
) {
val lastRequestMessage = request.messages.lastOrNull()
if (conversationId != null && lastRequestMessage != null) {
val requestMemory =
Memory(
conversationId = conversationId,
content = lastRequestMessage,
timestamp = getTimeMillis()
)
val responseMemory =
Memory(
conversationId = conversationId,
content =
Message(role = Role.ASSISTANT, content = buffer.toString(), name = Role.ASSISTANT.name),
timestamp = getTimeMillis(),
)
context.addMemories(listOf(requestMemory, responseMemory))
}
}

private suspend fun List<ChoiceWithFunctions>.addChoiceWithFunctionsToMemory(
@@ -307,30 +283,61 @@ interface Chat : LLM {
emptyList()
}

private suspend fun createPromptWithContextAwareOfTokens(
memories: List<Memory>,
ctxInfo: List<String>,
private suspend fun fitMessagesByTokens(
history: List<Message>,
messages: List<Message>,
context: VectorStore,
modelType: ModelType,
prompt: String,
minResponseTokens: Int,
promptConfiguration: PromptConfiguration,
): List<Message> {
val maxContextLength: Int = modelType.maxContextLength
val promptTokens: Int = modelType.encoding.countTokens(prompt)
val memoryTokens = tokensFromMessages(memories.map { it.content })
val remainingTokens: Int = maxContextLength - promptTokens - memoryTokens - minResponseTokens
val remainingTokens: Int = maxContextLength - promptConfiguration.minResponseTokens

val messagesTokens = tokensFromMessages(messages)

if (messagesTokens >= remainingTokens) {
throw AIError.PromptExceedsMaxRemainingTokenLength(messagesTokens, remainingTokens)
}

val remainingTokensForContexts = remainingTokens - messagesTokens

val historyPercent = promptConfiguration.messagePolicy.historyPercent
val contextPercent = promptConfiguration.messagePolicy.contextPercent

val maxHistoryTokens = (remainingTokensForContexts * historyPercent) / 100

val historyMessagesWithTokens = history.map { Pair(it, tokensFromMessages(listOf(it))) }

val totalTokenWithMessages =
historyMessagesWithTokens.foldRight(Pair(0, emptyList<Message>())) { pair, acc ->
if (acc.first + pair.second > maxHistoryTokens) {
acc
} else {
Pair(acc.first + pair.second, acc.second + pair.first)
}
}

val historyAllowed = totalTokenWithMessages.second.reversed()

val maxContextTokens = (remainingTokensForContexts * contextPercent) / 100

val ctxInfo =
context.similaritySearch(
messages.joinToString("\n") { it.content },
promptConfiguration.docsInContext,
)

val contextAllowed =
if (ctxInfo.isNotEmpty()) {
val ctx: String = ctxInfo.joinToString("\n")

return if (ctxInfo.isNotEmpty() && remainingTokens > minResponseTokens) {
val ctx: String = ctxInfo.joinToString("\n")
val ctxTruncated: String = modelType.encoding.truncateText(ctx, maxContextTokens)

if (promptTokens >= maxContextLength) {
throw AIError.PromptExceedsMaxTokenLength(prompt, promptTokens, maxContextLength)
listOf(Message.assistantMessage { ctxTruncated })
} else {
emptyList()
}
// truncate the context if it's too long based on the max tokens calculated considering the
// existing prompt tokens
// alternatively we could summarize the context, but that's not implemented yet
val ctxTruncated: String = modelType.encoding.truncateText(ctx, remainingTokens)

listOf(Message.assistantMessage { "Context: $ctxTruncated" }, Message.userMessage { prompt })
} else listOf(Message.userMessage { prompt })
return contextAllowed + historyAllowed + messages
}
}
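
For orientation (again, not part of the diff), the token budgeting in fitMessagesByTokens can be traced with made-up numbers; the context-window size and token counts below are purely illustrative assumptions:

val maxContextLength = 4096        // illustrative model context window (assumption)
val minResponseTokens = 500        // reserved for the model's reply (PromptConfiguration default)
val messagesTokens = 600           // assumed token cost of the new prompt messages

val remainingTokens = maxContextLength - minResponseTokens          // 3596; PromptExceedsMaxRemainingTokenLength is thrown if messagesTokens >= this
val remainingTokensForContexts = remainingTokens - messagesTokens   // 2996
val maxHistoryTokens = remainingTokensForContexts * 50 / 100        // 1498 with the default 50% history share
val maxContextTokens = remainingTokensForContexts * 50 / 100        // 1498 with the default 50% context share

History messages are then accumulated from the most recent backwards until maxHistoryTokens is exhausted, the similarity-search context is truncated to maxContextTokens, and the request is built as contextAllowed + historyAllowed + messages with maxTokens = minResponseTokens.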