New strategy for collecting messages and improvements to ReAct agent and message passing #305

Merged
14 commits merged on Aug 10, 2023
Changes from 11 commits

@@ -27,6 +27,11 @@ sealed class AIError @JvmOverloads constructor(message: String, cause: Throwable
"Prompt exceeds max token length: $promptTokens + $maxTokens = ${promptTokens + maxTokens}"
)

data class PromptExceedsMaxRemainingTokenLength(val promptTokens: Int, val maxTokens: Int) :
AIError(
"Prompt exceeds max remaining token length: $promptTokens + $maxTokens = ${promptTokens + maxTokens}"
)

data class JsonParsing(val result: String, val maxAttempts: Int, override val cause: Throwable) :
AIError("Failed to parse the JSON response after $maxAttempts attempts: $result", cause)

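The new error complements PromptExceedsMaxTokenLength: it is raised when the messages already assembled for a request leave no room for the reserved response budget. A minimal sketch of that guard, mirroring the check added to fitMessagesByTokens in Chat.kt below (the standalone helper name here is hypothetical):

// Minimal sketch; ensureRemainingTokens is a hypothetical helper that mirrors the guard
// added to fitMessagesByTokens in Chat.kt below.
fun ensureRemainingTokens(messagesTokens: Int, maxContextLength: Int, minResponseTokens: Int) {
  // budget left once the minimum response allowance is reserved
  val remainingTokens = maxContextLength - minResponseTokens
  if (messagesTokens >= remainingTokens) {
    throw AIError.PromptExceedsMaxRemainingTokenLength(messagesTokens, remainingTokens)
  }
}
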
@@ -12,7 +12,6 @@ class PromptConfiguration(
val docsInContext: Int = 5,
val memoryLimit: Int = 5,
val minResponseTokens: Int = 500,
val streamToStandardOut: Boolean = false
) {
companion object {

@@ -23,17 +22,12 @@ class PromptConfiguration(
private var numberOfPredictions: Int = 1
private var docsInContext: Int = 20
private var minResponseTokens: Int = 500
private var streamToStandardOut: Boolean = false
private var memoryLimit: Int = 5

fun maxDeserializationAttempts(maxDeserializationAttempts: Int) = apply {
this.maxDeserializationAttempts = maxDeserializationAttempts
}

fun streamToStandardOut(streamToStandardOut: Boolean) = apply {
this.streamToStandardOut = streamToStandardOut
}

fun user(user: String) = apply { this.user = user }

fun temperature(temperature: Double) = apply { this.temperature = temperature }
@@ -59,7 +53,6 @@ class PromptConfiguration(
docsInContext = docsInContext,
memoryLimit = memoryLimit,
minResponseTokens = minResponseTokens,
streamToStandardOut = streamToStandardOut,
)
}

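Since streamToStandardOut is no longer a PromptConfiguration flag, printing a streamed response is left to the caller, who collects the Flow<String> returned by the streaming prompt call (see the Chat.kt changes below). A minimal sketch, assuming you already hold that flow:

import kotlinx.coroutines.flow.Flow

// Collects a streamed completion and writes each chunk to standard out as it arrives.
suspend fun printToStandardOut(response: Flow<String>) {
  response.collect { chunk -> print(chunk) }
}
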
core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt (220 changes: 114 additions, 106 deletions)
@@ -46,35 +46,23 @@ interface Chat : LLM {
): Flow<String> = flow {
val memories: List<Memory> = memories(conversationId, context, promptConfiguration)

val promptWithContext: List<Message> =
createPromptWithContextAwareOfTokens(
memories = memories,
ctxInfo = context.similaritySearch(prompt.message, promptConfiguration.docsInContext),
modelType = modelType,
prompt = prompt.message,
minResponseTokens = promptConfiguration.minResponseTokens
val messagesForRequest =
fitMessagesByTokens(
messagesFromMemory(memories),
prompt.toMessages(),
context,
modelType,
promptConfiguration
)

val messages: List<Message> = messagesFromMemory(memories) + promptWithContext

fun checkTotalLeftChatTokens(): Int {
val maxContextLength: Int = modelType.maxContextLength
val messagesTokens: Int = tokensFromMessages(messages)
val totalLeftTokens: Int = maxContextLength - messagesTokens
if (totalLeftTokens < 0) {
throw AIError.MessagesExceedMaxTokenLength(messages, messagesTokens, maxContextLength)
}
return totalLeftTokens
}

val request: ChatCompletionRequest =
val request =
ChatCompletionRequest(
model = name,
user = promptConfiguration.user,
messages = messages,
messages = messagesForRequest,
n = promptConfiguration.numberOfPredictions,
temperature = promptConfiguration.temperature,
maxTokens = checkTotalLeftChatTokens(),
maxTokens = promptConfiguration.minResponseTokens,
streamToStandardOut = true
)

@@ -90,31 +78,6 @@ interface Chat : LLM {
.collect { emit(it.choices.mapNotNull { it.delta?.content }.joinToString("")) }
}

private suspend fun addMemoriesAfterStream(
request: ChatCompletionRequest,
conversationId: ConversationId?,
buffer: StringBuilder,
context: VectorStore
) {
val lastRequestMessage = request.messages.lastOrNull()
if (conversationId != null && lastRequestMessage != null) {
val requestMemory =
Memory(
conversationId = conversationId,
content = lastRequestMessage,
timestamp = getTimeMillis()
)
val responseMemory =
Memory(
conversationId = conversationId,
content =
Message(role = Role.ASSISTANT, content = buffer.toString(), name = Role.ASSISTANT.name),
timestamp = getTimeMillis(),
)
context.addMemories(listOf(requestMemory, responseMemory))
}
}

@AiDsl
suspend fun promptMessage(
question: String,
@@ -136,6 +99,23 @@ interface Chat : LLM {
): List<String> =
promptMessages(Prompt(question), context, conversationId, functions, promptConfiguration)

@AiDsl
suspend fun promptMessages(
prompt: Prompt,
context: VectorStore,
conversationId: ConversationId? = null,
functions: List<CFunction> = emptyList(),
promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS
): List<String> {
return promptMessages(
prompt.toMessages(),
context,
conversationId,
functions,
promptConfiguration
)
}

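Usage sketch (the wrapper function and prompt text are illustrative): a plain String question, a Prompt, and a prebuilt List<Message> all funnel into the same List<Message> overload, so memory lookup and token fitting behave identically for every entry point.

// Illustrative only: delegates to the new Prompt overload, which in turn calls the
// List<Message> overload via Prompt.toMessages().
suspend fun askOnce(chat: Chat, vectorStore: VectorStore): List<String> =
  chat.promptMessages(
    prompt = Prompt("Summarize the changes in this pull request"),
    context = vectorStore,
    conversationId = null
  )
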
@AiDsl
suspend fun promptMessages(
messages: List<Message>,
@@ -146,37 +126,33 @@ interface Chat : LLM {
): List<String> {

val memories: List<Memory> = memories(conversationId, context, promptConfiguration)
val allMessages = messagesFromMemory(memories) + messages

fun checkTotalLeftChatTokens(): Int {
val maxContextLength: Int = modelType.maxContextLength
val messagesTokens: Int = tokensFromMessages(allMessages)
val totalLeftTokens: Int = maxContextLength - messagesTokens
if (totalLeftTokens < 0) {
throw AIError.MessagesExceedMaxTokenLength(allMessages, messagesTokens, maxContextLength)
}
return totalLeftTokens
}
val messagesForRequest =
fitMessagesByTokens(
messagesFromMemory(memories),
messages,
context,
modelType,
promptConfiguration
)

fun chatRequest(): ChatCompletionRequest =
ChatCompletionRequest(
model = name,
user = promptConfiguration.user,
messages = messages,
messages = messagesForRequest,
n = promptConfiguration.numberOfPredictions,
temperature = promptConfiguration.temperature,
maxTokens = checkTotalLeftChatTokens(),
streamToStandardOut = promptConfiguration.streamToStandardOut
maxTokens = promptConfiguration.minResponseTokens,
)

fun withFunctionsRequest(): ChatCompletionRequestWithFunctions =
ChatCompletionRequestWithFunctions(
model = name,
user = promptConfiguration.user,
messages = messages,
messages = messagesForRequest,
n = promptConfiguration.numberOfPredictions,
temperature = promptConfiguration.temperature,
maxTokens = checkTotalLeftChatTokens(),
maxTokens = promptConfiguration.minResponseTokens,
functions = functions,
functionCall = mapOf("name" to (functions.firstOrNull()?.name ?: ""))
)
@@ -207,33 +183,33 @@ interface Chat : LLM {
}
}

@AiDsl
suspend fun promptMessages(
prompt: Prompt,
context: VectorStore,
conversationId: ConversationId? = null,
functions: List<CFunction> = emptyList(),
promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS
): List<String> {

val memories: List<Memory> = memories(conversationId, context, promptConfiguration)
suspend fun String.toMessages(): List<Message> = Prompt(this).toMessages()

val promptWithContext: List<Message> =
createPromptWithContextAwareOfTokens(
memories = memories,
ctxInfo = context.similaritySearch(prompt.message, promptConfiguration.docsInContext),
modelType = modelType,
prompt = prompt.message,
minResponseTokens = promptConfiguration.minResponseTokens
)
suspend fun Prompt.toMessages(): List<Message> = listOf(Message.userMessage { message })

return promptMessages(
promptWithContext,
context,
conversationId,
functions,
promptConfiguration
)
private suspend fun addMemoriesAfterStream(
request: ChatCompletionRequest,
conversationId: ConversationId?,
buffer: StringBuilder,
context: VectorStore
) {
val lastRequestMessage = request.messages.lastOrNull()
if (conversationId != null && lastRequestMessage != null) {
val requestMemory =
Memory(
conversationId = conversationId,
content = lastRequestMessage,
timestamp = getTimeMillis()
)
val responseMemory =
Memory(
conversationId = conversationId,
content =
Message(role = Role.ASSISTANT, content = buffer.toString(), name = Role.ASSISTANT.name),
timestamp = getTimeMillis(),
)
context.addMemories(listOf(requestMemory, responseMemory))
}
}

private suspend fun List<ChoiceWithFunctions>.addChoiceWithFunctionsToMemory(
@@ -307,30 +283,62 @@ interface Chat : LLM {
emptyList()
}

private suspend fun createPromptWithContextAwareOfTokens(
memories: List<Memory>,
ctxInfo: List<String>,
private suspend fun fitMessagesByTokens(
history: List<Message>,
messages: List<Message>,
context: VectorStore,
modelType: ModelType,
prompt: String,
minResponseTokens: Int,
promptConfiguration: PromptConfiguration,
): List<Message> {
val maxContextLength: Int = modelType.maxContextLength
val promptTokens: Int = modelType.encoding.countTokens(prompt)
val memoryTokens = tokensFromMessages(memories.map { it.content })
val remainingTokens: Int = maxContextLength - promptTokens - memoryTokens - minResponseTokens
val remainingTokens: Int = maxContextLength - promptConfiguration.minResponseTokens

val messagesTokens = tokensFromMessages(messages)

if (messagesTokens >= remainingTokens) {
throw AIError.PromptExceedsMaxRemainingTokenLength(messagesTokens, remainingTokens)
}

val remainingTokensForContexts = remainingTokens - messagesTokens

// TODO we should move this to PromptConfiguration
val historyPercent = 50
val contextPercent = 50

val maxHistoryTokens = (remainingTokensForContexts * historyPercent) / 100

val historyMessagesWithTokens = history.map { Pair(it, tokensFromMessages(listOf(it))) }

val totalTokenWithMessages =
historyMessagesWithTokens.reversed().fold(Pair(0, emptyList<Message>())) { acc, pair ->
if (acc.first + pair.second > maxHistoryTokens) {
acc
} else {
Pair(acc.first + pair.second, acc.second + pair.first)
}
}

val historyAllowed = totalTokenWithMessages.second.reversed()

val maxContextTokens = (remainingTokensForContexts * contextPercent) / 100

val ctxInfo =
context.similaritySearch(
messages.joinToString("\n") { it.content },
promptConfiguration.docsInContext,
)

val contextAllowed =
if (ctxInfo.isNotEmpty()) {
val ctx: String = ctxInfo.joinToString("\n")

return if (ctxInfo.isNotEmpty() && remainingTokens > minResponseTokens) {
val ctx: String = ctxInfo.joinToString("\n")
val ctxTruncated: String = modelType.encoding.truncateText(ctx, maxContextTokens)

if (promptTokens >= maxContextLength) {
throw AIError.PromptExceedsMaxTokenLength(prompt, promptTokens, maxContextLength)
listOf(Message.assistantMessage { ctxTruncated })
} else {
emptyList()
}
// truncate the context if it's too long based on the max tokens calculated considering the
// existing prompt tokens
// alternatively we could summarize the context, but that's not implemented yet
val ctxTruncated: String = modelType.encoding.truncateText(ctx, remainingTokens)

listOf(Message.assistantMessage { "Context: $ctxTruncated" }, Message.userMessage { prompt })
} else listOf(Message.userMessage { prompt })
return contextAllowed + historyAllowed + messages
}
}
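
To summarize the new strategy in one place, here is a compact, self-contained sketch of the fitting logic with token counting stubbed out. The Msg type, the countTokens heuristic, and the character-based truncation are simplifications for illustration; the real implementation uses tokensFromMessages and the model encoding.

// Self-contained sketch of the fitting strategy above; token counting is a rough stand-in.
data class Msg(val role: String, val content: String)

// crude stand-in for tokensFromMessages: ~4 characters per token plus per-message overhead
fun countTokens(msgs: List<Msg>): Int = msgs.sumOf { it.content.length / 4 + 4 }

fun fitByTokens(
  history: List<Msg>,        // memories, oldest first
  messages: List<Msg>,       // the messages for this request
  contextDocs: List<String>, // similarity-search results
  maxContextLength: Int,
  minResponseTokens: Int
): List<Msg> {
  // reserve the minimum response allowance up front
  val remainingTokens = maxContextLength - minResponseTokens
  val messagesTokens = countTokens(messages)
  require(messagesTokens < remainingTokens) { "prompt exceeds max remaining token length" }

  // split what is left 50/50 between conversation history and retrieved context
  val remainingForContexts = remainingTokens - messagesTokens
  val maxHistoryTokens = remainingForContexts * 50 / 100
  val maxContextTokens = remainingForContexts * 50 / 100

  // walk history newest-first, keep messages while they fit, then restore chronological order
  val historyAllowed =
    history.reversed()
      .fold(0 to emptyList<Msg>()) { (tokens, kept), msg ->
        val t = countTokens(listOf(msg))
        if (tokens + t > maxHistoryTokens) tokens to kept else tokens + t to kept + msg
      }
      .second
      .reversed()

  // truncate the retrieved context to its share of the budget (character-based stand-in)
  val ctx = contextDocs.joinToString("\n")
  val contextAllowed =
    if (ctx.isNotEmpty()) listOf(Msg("assistant", ctx.take(maxContextTokens * 4))) else emptyList()

  return contextAllowed + historyAllowed + messages
}

Note that the truncated context is placed ahead of the history and the request messages, matching the contextAllowed + historyAllowed + messages return order above.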