Memory limit by tokens #316

Merged · 8 commits · Aug 14, 2023
@@ -10,7 +10,6 @@ class PromptConfiguration(
val temperature: Double = 0.4,
val numberOfPredictions: Int = 1,
val docsInContext: Int = 5,
val memoryLimit: Int = 5,
val minResponseTokens: Int = 500,
val messagePolicy: MessagePolicy = MessagePolicy(),
) {
@@ -23,7 +22,6 @@ class PromptConfiguration(
private var numberOfPredictions: Int = 1
private var docsInContext: Int = 20
private var minResponseTokens: Int = 500
private var memoryLimit: Int = 5
private var messagePolicy: MessagePolicy = MessagePolicy()

fun maxDeserializationAttempts(maxDeserializationAttempts: Int) = apply {
@@ -44,8 +42,6 @@
this.minResponseTokens = minResponseTokens
}

fun memoryLimit(memoryLimit: Int) = apply { this.memoryLimit = memoryLimit }

fun messagePolicy(messagePolicy: MessagePolicy) = apply { this.messagePolicy = messagePolicy }

fun build() =
@@ -55,7 +51,6 @@
temperature = temperature,
numberOfPredictions = numberOfPredictions,
docsInContext = docsInContext,
memoryLimit = memoryLimit,
minResponseTokens = minResponseTokens,
messagePolicy = messagePolicy,
)
@@ -77,5 +72,6 @@
*/
class MessagePolicy(
val historyPercent: Int = 50,
val historyPaddingTokens: Int = 100,
val contextPercent: Int = 50,
)
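
With memoryLimit removed, how much history reaches the prompt is now governed by token budgets in MessagePolicy rather than a fixed message count. A minimal sketch of what a caller might write against the new configuration (a plain constructor call with named arguments; only names visible in the diff above are used):

    // Sketch only: token budgets replace the old memoryLimit count.
    val configuration = PromptConfiguration(
        temperature = 0.4,
        docsInContext = 5,
        minResponseTokens = 500,
        messagePolicy = MessagePolicy(
            historyPercent = 50,        // share of the remaining window given to conversation history
            historyPaddingTokens = 100, // slack added when asking the store for memories
            contextPercent = 50,        // share given to vector-store context
        )
    )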
119 changes: 62 additions & 57 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt
@@ -10,7 +10,6 @@ import com.xebia.functional.xef.llm.models.chat.*
import com.xebia.functional.xef.llm.models.functions.CFunction
import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.vectorstores.Memory
import com.xebia.functional.xef.vectorstores.VectorStore
import io.ktor.util.date.*
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow
@@ -41,16 +40,8 @@ interface Chat : LLM {
functions: List<CFunction> = emptyList(),
promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS
): Flow<String> = flow {
val memories: List<Memory> = memories(scope, promptConfiguration)

val messagesForRequest =
fitMessagesByTokens(
messagesFromMemory(memories),
prompt.toMessages(),
scope.store,
modelType,
promptConfiguration
)
fitMessagesByTokens(prompt.toMessages(), scope, modelType, promptConfiguration)

val request =
ChatCompletionRequest(
@@ -110,15 +101,7 @@ interface Chat : LLM {
promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS
): List<String> {

val memories: List<Memory> = memories(scope, promptConfiguration)
val messagesForRequest =
fitMessagesByTokens(
messagesFromMemory(memories),
messages,
scope.store,
modelType,
promptConfiguration
)
val messagesForRequest = fitMessagesByTokens(messages, scope, modelType, promptConfiguration)

fun chatRequest(): ChatCompletionRequest =
ChatCompletionRequest(
@@ -181,14 +164,17 @@ interface Chat : LLM {
Memory(
conversationId = scope.conversationId,
content = lastRequestMessage,
timestamp = getTimeMillis()
timestamp = getTimeMillis(),
approxTokens = tokensFromMessages(listOf(lastRequestMessage))
)
val responseMessage =
Message(role = Role.ASSISTANT, content = buffer.toString(), name = Role.ASSISTANT.name)
val responseMemory =
Memory(
conversationId = scope.conversationId,
content =
Message(role = Role.ASSISTANT, content = buffer.toString(), name = Role.ASSISTANT.name),
content = responseMessage,
timestamp = getTimeMillis(),
approxTokens = tokensFromMessages(listOf(responseMessage))
)
scope.store.addMemories(listOf(requestMemory, responseMemory))
}
@@ -206,19 +192,22 @@ interface Chat : LLM {
Memory(
conversationId = scope.conversationId,
content = requestUserMessage,
timestamp = getTimeMillis()
timestamp = getTimeMillis(),
approxTokens = tokensFromMessages(listOf(requestUserMessage))
)
val firstChoiceMessage =
Message(
role = role,
content = firstChoice.message?.content
?: firstChoice.message?.functionCall?.arguments ?: "",
name = role.name
)
val firstChoiceMemory =
Memory(
conversationId = scope.conversationId,
content =
Message(
role = role,
content = firstChoice.message?.content
?: firstChoice.message?.functionCall?.arguments ?: "",
name = role.name
), //
timestamp = getTimeMillis()
content = firstChoiceMessage,
timestamp = getTimeMillis(),
approxTokens = tokensFromMessages(listOf(firstChoiceMessage))
)
scope.store.addMemories(listOf(requestMemory, firstChoiceMemory))
}
@@ -236,14 +225,17 @@ interface Chat : LLM {
Memory(
conversationId = scope.conversationId,
content = requestUserMessage,
timestamp = getTimeMillis()
timestamp = getTimeMillis(),
approxTokens = tokensFromMessages(listOf(requestUserMessage))
)
val firstChoiceMessage =
Message(role = role, content = firstChoice.message?.content ?: "", name = role.name)
val firstChoiceMemory =
Memory(
conversationId = scope.conversationId,
content =
Message(role = role, content = firstChoice.message?.content ?: "", name = role.name),
timestamp = getTimeMillis()
content = firstChoiceMessage,
timestamp = getTimeMillis(),
approxTokens = tokensFromMessages(listOf(firstChoiceMessage))
)
scope.store.addMemories(listOf(requestMemory, firstChoiceMemory))
}
@@ -252,23 +244,21 @@ interface Chat : LLM {
private fun messagesFromMemory(memories: List<Memory>): List<Message> =
memories.map { it.content }

private suspend fun memories(
scope: Conversation,
promptConfiguration: PromptConfiguration
): List<Memory> =
if (scope.conversationId != null) {
scope.store.memories(scope.conversationId, promptConfiguration.memoryLimit)
private suspend fun Conversation.memories(limitTokens: Int): List<Memory> =
if (conversationId != null) {
store.memories(conversationId, limitTokens)
} else {
emptyList()
}

private suspend fun fitMessagesByTokens(
history: List<Message>,
messages: List<Message>,
context: VectorStore,
scope: Conversation,
modelType: ModelType,
promptConfiguration: PromptConfiguration,
): List<Message> {

// calculate tokens for history and context
val maxContextLength: Int = modelType.maxContextLength
val remainingTokens: Int = maxContextLength - promptConfiguration.minResponseTokens

@@ -284,24 +274,39 @@ interface Chat : LLM {
val contextPercent = promptConfiguration.messagePolicy.contextPercent

val maxHistoryTokens = (remainingTokensForContexts * historyPercent) / 100
val maxContextTokens = (remainingTokensForContexts * contextPercent) / 100

val historyMessagesWithTokens = history.map { Pair(it, tokensFromMessages(listOf(it))) }

val totalTokenWithMessages =
historyMessagesWithTokens.foldRight(Pair(0, emptyList<Message>())) { pair, acc ->
if (acc.first + pair.second > maxHistoryTokens) {
acc
} else {
Pair(acc.first + pair.second, acc.second + pair.first)
// calculate messages for history based on tokens

val memories: List<Memory> =
scope.memories(maxHistoryTokens + promptConfiguration.messagePolicy.historyPaddingTokens)

val historyAllowed =
if (memories.isNotEmpty()) {
val history = messagesFromMemory(memories)

// since we have the approximate tokens in memory, we need to fit the messages back to the
// number of tokens if necessary
val historyTokens = tokensFromMessages(history)
if (historyTokens <= maxHistoryTokens) history
else {
val historyMessagesWithTokens = history.map { Pair(it, tokensFromMessages(listOf(it))) }

val totalTokenWithMessages =
historyMessagesWithTokens.foldRight(Pair(0, emptyList<Message>())) { pair, acc ->
if (acc.first + pair.second > maxHistoryTokens) {
acc
} else {
Pair(acc.first + pair.second, acc.second + pair.first)
}
}
totalTokenWithMessages.second.reversed()
}
}

val historyAllowed = totalTokenWithMessages.second.reversed()

val maxContextTokens = (remainingTokensForContexts * contextPercent) / 100
} else emptyList()

// calculate messages for context based on tokens
val ctxInfo =
context.similaritySearch(
scope.store.similaritySearch(
messages.joinToString("\n") { it.content },
promptConfiguration.docsInContext,
)
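
The history fitting above now works in two steps: memories come back from the store already pre-trimmed by approxTokens (plus historyPaddingTokens of slack), and only if the exact recount still exceeds maxHistoryTokens does the foldRight drop messages, keeping the newest first and then reversing the kept list back into chronological order. A standalone sketch of that fold, with a whitespace token counter so it runs outside the project (the String messages and the counter are stand-ins for Message and tokensFromMessages):

    // Sketch of the budget fold used in fitMessagesByTokens above.
    fun fitByTokens(
        history: List<String>,
        maxTokens: Int,
        countTokens: (String) -> Int,
    ): List<String> {
        val withTokens = history.map { it to countTokens(it) }
        val kept =
            withTokens.foldRight(0 to emptyList<String>()) { (message, tokens), (used, selected) ->
                if (used + tokens > maxTokens) used to selected
                else (used + tokens) to (selected + message)
            }.second
        // foldRight visits the newest message first, so reverse to restore chronological order.
        return kept.reversed()
    }

    fun main() {
        val history = listOf("hello", "what is a token budget?", "a cap on history size", "explain padding")
        // Keeps the newest messages that fit into roughly 8 whitespace-separated "tokens".
        println(fitByTokens(history, maxTokens = 8) { it.split(" ").size })
    }

Note that, as in the original fold, an older message can still be kept after a larger one has been skipped, because each message is checked against the remaining budget rather than stopping at the first overflow.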
@@ -11,10 +11,22 @@ import com.xebia.functional.xef.embeddings.Embedding
class CombinedVectorStore(private val top: VectorStore, private val bottom: VectorStore) :
VectorStore by top {

override suspend fun memories(conversationId: ConversationId, limit: Int): List<Memory> {
val bottomResults = bottom.memories(conversationId, limit)
val topResults = top.memories(conversationId, limit)
return (topResults + bottomResults).sortedBy { it.timestamp }.takeLast(limit)
override suspend fun memories(conversationId: ConversationId, limitTokens: Int): List<Memory> {
val bottomResults = bottom.memories(conversationId, limitTokens)
val topResults = top.memories(conversationId, limitTokens)

return (topResults + bottomResults)
.sortedByDescending { it.timestamp }
.fold(Pair(0, emptyList<Memory>())) { (accTokens, list), memory ->
val totalTokens = accTokens + memory.approxTokens
if (totalTokens <= limitTokens) {
Pair(totalTokens, list + memory)
} else {
Pair(accTokens, list)
}
}
.second
.reversed()
}
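
The combined store now merges both layers, walks the result newest-first, and skips any memory that would exceed the approximate token budget before flipping the selection back to oldest-first. A toy illustration of that selection, with (timestamp, approxTokens) pairs standing in for Memory:

    // Toy model of the merge-and-trim: each pair is (timestamp, approxTokens).
    fun selectNewestWithinBudget(
        memories: List<Pair<Long, Int>>,
        limitTokens: Int,
    ): List<Pair<Long, Int>> =
        memories
            .sortedByDescending { it.first }
            .fold(0 to emptyList<Pair<Long, Int>>()) { (used, picked), memory ->
                if (used + memory.second <= limitTokens) (used + memory.second) to (picked + memory)
                else used to picked
            }
            .second
            .reversed() // back to oldest-first, the order callers expect

    fun main() {
        val merged = listOf(1L to 30, 2L to 50, 3L to 40, 4L to 20)
        println(selectNewestWithinBudget(merged, limitTokens = 70)) // [(3, 40), (4, 20)]
    }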

override suspend fun similaritySearch(query: String, limit: Int): List<String> {
@@ -46,9 +46,20 @@ private constructor(private val embeddings: Embeddings, private val state: Atomi
}
}

override suspend fun memories(conversationId: ConversationId, limit: Int): List<Memory> {
override suspend fun memories(conversationId: ConversationId, limitTokens: Int): List<Memory> {
val memories = state.get().orderedMemories[conversationId]
return memories?.takeLast(limit).orEmpty().sortedBy { it.timestamp }
return memories
?.foldRight(Pair(0, emptyList<Memory>())) { memory, (accTokens, list) ->
val totalTokens = accTokens + memory.approxTokens
if (totalTokens <= limitTokens) {
Pair(totalTokens, list + memory)
} else {
Pair(accTokens, list)
}
}
?.second
.orEmpty()
.sortedBy { it.timestamp }
}

override suspend fun addTexts(texts: List<String>) {
@@ -9,4 +9,9 @@ import com.xebia.functional.xef.llm.models.chat.Message
* @property conversationId uniquely identifies the conversation in which the message took place.
* @property timestamp in milliseconds.
*/
data class Memory(val conversationId: ConversationId, val content: Message, val timestamp: Long)
data class Memory(
val conversationId: ConversationId,
val content: Message,
val timestamp: Long,
val approxTokens: Int
)
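
Memory now requires the approximate token count up front, so stores can trim by budget without re-tokenizing old messages. A sketch of how the Chat.kt changes above populate it inside a Chat implementation, where tokensFromMessages is in scope (the ConversationId constructor argument and Role.USER are illustrative assumptions, not shown in this diff):

    // Sketch: tokens are counted once at write time and stored alongside the message.
    val message = Message(role = Role.USER, content = "hello there", name = Role.USER.name)
    val memory =
        Memory(
            conversationId = ConversationId("demo-conversation"), // assumed constructor shape
            content = message,
            timestamp = getTimeMillis(),
            approxTokens = tokensFromMessages(listOf(message)), // reused later by the stores
        )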
@@ -7,7 +7,7 @@ interface VectorStore {

suspend fun addMemories(memories: List<Memory>)

suspend fun memories(conversationId: ConversationId, limit: Int): List<Memory>
suspend fun memories(conversationId: ConversationId, limitTokens: Int): List<Memory>

/**
* Add texts to the vector store after running them through the embeddings
@@ -44,8 +44,10 @@

override suspend fun addMemories(memories: List<Memory>) {}

override suspend fun memories(conversationId: ConversationId, limit: Int): List<Memory> =
emptyList()
override suspend fun memories(
conversationId: ConversationId,
limitTokens: Int
): List<Memory> = emptyList()

override suspend fun addTexts(texts: List<String>) {}
