Skip to content

Commit

Permalink
Assistants YAML config (#643)
Browse files Browse the repository at this point in the history
* Do not attempt to call functions without arguments

* Creates or updates an Assistant based on YAML config

* Use https://github.com/Him188/yamlkt to hand parse in kmp

* removed println

* Apply spotless formatting

* Update core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/Assistant.kt

Co-authored-by: David Vega Lichacz <7826728+realdavidvega@users.noreply.github.com>

* Update core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/Assistant.kt

Co-authored-by: David Vega Lichacz <7826728+realdavidvega@users.noreply.github.com>

* Apply spotless formatting

* fix fqn

* example

* YAML is able to load tools from the tool config

* use simple name for kmp support

* review comments

---------

Co-authored-by: raulraja <raulraja@users.noreply.github.com>
Co-authored-by: David Vega <7826728+realdavidvega@users.noreply.github.com>
  • Loading branch information
3 people committed Mar 5, 2024
1 parent 4d8c0a4 commit 34ad9d2
Show file tree
Hide file tree
Showing 5 changed files with 222 additions and 0 deletions.
1 change: 1 addition & 0 deletions core/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ kotlin {
dependencies {
api(libs.bundles.arrow)
api(libs.kotlinx.serialization.json)
api(libs.kotlinx.serialization.yaml)
api(libs.ktor.utils)
api(projects.xefTokenizer)
api(projects.xefOpenaiClient)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,21 @@ import com.xebia.functional.openai.apis.AssistantsApi
import com.xebia.functional.openai.infrastructure.ApiClient
import com.xebia.functional.openai.models.AssistantObject
import com.xebia.functional.openai.models.CreateAssistantRequest
import com.xebia.functional.openai.models.FunctionObject
import com.xebia.functional.openai.models.ModifyAssistantRequest
import com.xebia.functional.openai.models.ext.assistant.AssistantTools
import com.xebia.functional.openai.models.ext.assistant.AssistantToolsCode
import com.xebia.functional.openai.models.ext.assistant.AssistantToolsFunction
import com.xebia.functional.openai.models.ext.assistant.AssistantToolsRetrieval
import com.xebia.functional.xef.llm.fromEnvironment
import io.ktor.util.logging.*
import kotlinx.serialization.KSerializer
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonPrimitive
import net.mamoe.yamlkt.Yaml
import net.mamoe.yamlkt.literalContentOrNull
import net.mamoe.yamlkt.toYamlElement

class Assistant(
val assistantId: String,
Expand Down Expand Up @@ -97,5 +104,117 @@ class Assistant(
val response = assistantsApi.createAssistant(request)
return Assistant(response.body(), toolsConfig, assistantsApi, api)
}

suspend fun fromConfig(
request: String,
toolsConfig: List<Tool.Companion.ToolConfig<*, *>> = emptyList(),
assistantsApi: AssistantsApi = fromEnvironment(::AssistantsApi),
api: AssistantApi = fromEnvironment(::AssistantApi)
): Assistant {
val parsed = Yaml.Default.decodeYamlMapFromString(request)
val assistantRequest =
AssistantRequest(
assistantId = parsed["assistant_id"]?.literalContentOrNull,
model = parsed["model"]?.literalContentOrNull ?: error("model is required"),
name = parsed["name"]?.literalContentOrNull,
description = parsed["description"]?.literalContentOrNull,
instructions = parsed["instructions"]?.literalContentOrNull,
tools =
parsed["tools"]?.let { list ->
(list as List<*>).map { element ->
when (element) {
is Map<*, *> -> {
val tool =
element["type".toYamlElement()]?.toString() ?: error("type is required")
when (tool) {
"code_interpreter" -> AssistantTool.CodeInterpreter
"retrieval" -> AssistantTool.Retrieval
"function" -> {
val className =
element["name".toYamlElement()]?.toString()
?: error("simple `name` for `function` is required")
val foundConfig =
toolsConfig.firstOrNull { it.tool::class.simpleName == className }
if (foundConfig != null) {
val functionObject = foundConfig.functionObject
AssistantTool.Function(
functionObject.name,
functionObject.description ?: "",
functionObject.parameters?.let {
ApiClient.JSON_DEFAULT.encodeToString(JsonObject.serializer(), it)
} ?: ""
)
} else {
error("Tool $className not found in toolsConfig")
}
}
else -> error("unknown tool $tool")
}
}
else -> error("unknown tool $element")
}
}
},
fileIds =
parsed["file_ids"]?.let { (it as List<*>).map { it.toString() } } ?: emptyList(),
)
return if (assistantRequest.assistantId != null) {
val assistant =
Assistant(
assistantId = assistantRequest.assistantId,
toolsConfig = toolsConfig,
assistantsApi = assistantsApi,
api = api
)
// list all assistants and get their files
// list all the org files
assistantsApi.listAssistants()

assistant.modify(
ModifyAssistantRequest(
name = assistantRequest.name,
description = assistantRequest.description,
instructions = assistantRequest.instructions,
tools = assistantTools(assistantRequest),
fileIds = assistantRequest.fileIds,
metadata = null // assistantRequest.metadata
)
)
} else
Assistant(
request =
CreateAssistantRequest(
model = assistantRequest.model,
name = assistantRequest.name,
description = assistantRequest.description,
instructions = assistantRequest.instructions,
tools = assistantTools(assistantRequest),
fileIds = assistantRequest.fileIds,
metadata = null // assistantRequest.metadata
),
toolsConfig = toolsConfig,
assistantsApi = assistantsApi,
api = api
)
}

private fun assistantTools(assistantRequest: AssistantRequest) =
assistantRequest.tools?.map {
when (it) {
is AssistantTool.CodeInterpreter -> AssistantToolsCode()
is AssistantTool.Retrieval -> AssistantToolsRetrieval()
is AssistantTool.Function ->
AssistantToolsFunction(
function =
FunctionObject(
name = it.name,
parameters =
ApiClient.JSON_DEFAULT.parseToJsonElement(it.parameters) as? JsonObject
?: JsonObject(emptyMap()),
description = it.description
)
)
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package com.xebia.functional.xef.llm.assistants

import kotlinx.serialization.Required
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

@Serializable
data class AssistantRequest(
@SerialName(value = "assistant_id") val assistantId: String? = null,
@SerialName(value = "model") @Required val model: String,

/* The name of the assistant. The maximum length is 256 characters. */
@SerialName(value = "name") val name: String? = null,

/* The description of the assistant. The maximum length is 512 characters. */
@SerialName(value = "description") val description: String? = null,

/* The system instructions that the assistant uses. The maximum length is 32768 characters. */
@SerialName(value = "instructions") val instructions: String? = null,

/* A list of tool enabled on the assistant. There can be a maximum of 128 tools per assistant. Tools can be of types `code_interpreter`, `retrieval`, or `function`. */
@SerialName(value = "tools") val tools: List<AssistantTool>? = arrayListOf(),

/* A list of [file](/docs/api-reference/files) IDs attached to this assistant. There can be a maximum of 20 files attached to the assistant. Files are ordered by their creation date in ascending order. */
@SerialName(value = "file_ids") val fileIds: List<String>? = arrayListOf(),

/* Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long. */
@SerialName(value = "metadata") val metadata: Map<String, String>? = null
)

@Serializable
sealed class AssistantTool {
@Serializable @SerialName(value = "code_interpreter") object CodeInterpreter : AssistantTool()

@Serializable @SerialName(value = "retrieval") object Retrieval : AssistantTool()

@Serializable
@SerialName(value = "function")
data class Function(val name: String, val description: String, val parameters: String) :
AssistantTool()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package com.xebia.functional.xef.assistants

import com.xebia.functional.xef.llm.assistants.Assistant
import com.xebia.functional.xef.llm.assistants.AssistantThread
import com.xebia.functional.xef.llm.assistants.MessageWithFiles
import com.xebia.functional.xef.llm.assistants.Tool

suspend fun main() {
// val filesApi = fromEnvironment(::FilesApi)
// // should only be created once and then referenced by id
// val file =
// filesApi
// .createFile(
// UploadFile("test.txt") { append("Hello World!") },
// FilesApi.PurposeCreateFile.assistants
// )
// .body()
val fileId = "file-q77cZu6e6sC2TsYbUs8UX5Dj"
// remove assistant id to create a new one
// language=yaml
val yamlConfig =
"""
assistant_id: "asst_ImCKap37lLHBqV1awa0kzjZ3"
model: "gpt-4-1106-preview"
name: "My Custom Test Assistant"
description: "A versatile AI assistant capable of conversational and informational tasks."
instructions: "This assistant is designed to provide informative and engaging conversations, answer queries, and execute code when necessary."
tools:
- type: "code_interpreter"
- type: "retrieval"
- type: "function"
name: "SumTool"
file_ids:
- "$fileId"
metadata:
version: "1.0"
created_by: "OpenAI"
use_case: "Customer support"
language: "English"
additional_info: "This assistant is continuously updated with the latest information."
"""
.trimIndent()
val tools = listOf(Tool.toolOf(SumTool()))
val assistant = Assistant.fromConfig(request = yamlConfig, toolsConfig = tools)
val assistantInfo = assistant.get()
println("assistant: $assistantInfo")
val thread = AssistantThread()
thread.createMessage(MessageWithFiles("What does this file say?", listOf(fileId)))
val stream = thread.run(assistant)
stream.collect {
when (it) {
is AssistantThread.RunDelta.ReceivedMessage ->
println("received message: ${it.message.content.firstOrNull()?.text}")
is AssistantThread.RunDelta.Run -> println("run: ${it.message.status.value}")
is AssistantThread.RunDelta.Step -> println("step: ${it.runStep.type.value}")
}
}
}
3 changes: 3 additions & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ opentelemetry-alpha="1.30.1-alpha"
progressbar = "0.10.0"
jmf = "2.1.1e"
mp3-wav-converter = "1.0.4"
yamlkt="0.13.0"


[libraries]
arrow-core = { module = "io.arrow-kt:arrow-core", version.ref = "arrow" }
Expand All @@ -55,6 +57,7 @@ suspendApp-core = { module = "io.arrow-kt:suspendapp", version.ref = "suspendApp
suspendApp-ktor = { module = "io.arrow-kt:suspendapp-ktor", version.ref = "suspendApp" }
kotlinx-serialization-json = { module = "org.jetbrains.kotlinx:kotlinx-serialization-json", version.ref = "kotlinx-json" }
kotlinx-serialization-hocon = { module = "org.jetbrains.kotlinx:kotlinx-serialization-hocon", version.ref = "kotlinx-json" }
kotlinx-serialization-yaml = { module = "net.mamoe.yamlkt:yamlkt", version.ref = "yamlkt" }
kotlinx-coroutines = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-core", version.ref="kotlinx-coroutines" }
kotlinx-coroutines-reactive = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-reactive", version.ref="kotlinx-coroutines" }
kotlinx-datetime = { module = "org.jetbrains.kotlinx:kotlinx-datetime", version.ref = "kotlinx-datetime" }
Expand Down

0 comments on commit 34ad9d2

Please sign in to comment.