Skip to content

Commit

Permalink
Update Open AI client to assistants v2 (#742)
Browse files Browse the repository at this point in the history
* Progress updating to OpenAI v2. Fails to generate `CreateAssistantRequestToolResourcesFileSearch`

* Support assistants v2 (only function calls tested which is our most immediate need) (file search and code interpreter untested)

* fix path in streaming endpoints
  • Loading branch information
raulraja committed May 20, 2024
1 parent 1aef223 commit 9dcf81e
Show file tree
Hide file tree
Showing 18 changed files with 8,637 additions and 4,098 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,12 @@ class Assistant(
@Serializable data class ToolOutput(val schema: JsonObject, val result: JsonElement)

suspend operator fun invoke(
model: String,
model: CreateAssistantRequestModel,
name: String? = null,
description: String? = null,
instructions: String? = null,
tools: List<AssistantObjectToolsInner> = arrayListOf(),
fileIds: List<String> = arrayListOf(),
toolResources: CreateAssistantRequestToolResources? = null,
metadata: JsonObject? = null,
toolsConfig: List<Tool.Companion.ToolConfig<*, *>> = emptyList(),
config: Config = Config(),
Expand All @@ -92,7 +92,7 @@ class Assistant(
description = description,
instructions = instructions,
tools = tools,
fileIds = fileIds,
toolResources = toolResources,
metadata = metadata
),
toolsConfig,
Expand All @@ -117,10 +117,23 @@ class Assistant(
assistantsApi: Assistants = OpenAI(config, logRequests = true).assistants,
): Assistant {
val parsed = Yaml.Default.decodeYamlMapFromString(request)
val fileIds = parsed["file_ids"]?.let { (it as List<*>).map { it.toString() } }
val vectorStoreIds = parsed["vector_store_ids"]?.let { (it as List<*>).map { it.toString() } }
val toolResourcesRequest =
CreateAssistantRequestToolResources(
codeInterpreter =
fileIds?.let { CreateAssistantRequestToolResourcesCodeInterpreter(fileIds = it) },
fileSearch =
vectorStoreIds?.let {
CreateAssistantRequestToolResourcesFileSearch(vectorStoreIds = it)
}
)
val assistantRequest =
AssistantRequest(
assistantId = parsed["assistant_id"]?.literalContentOrNull,
model = parsed["model"]?.literalContentOrNull ?: error("model is required"),
model =
parsed["model"]?.literalContentOrNull?.let { CreateAssistantRequestModel.Custom(it) }
?: error("model is required"),
name = parsed["name"]?.literalContentOrNull,
description = parsed["description"]?.literalContentOrNull,
instructions =
Expand Down Expand Up @@ -164,8 +177,7 @@ class Assistant(
}
}
},
fileIds =
parsed["file_ids"]?.let { (it as List<*>).map { it.toString() } } ?: emptyList(),
toolResources = toolResourcesRequest,
)
return if (assistantRequest.assistantId != null) {
val assistant =
Expand All @@ -178,12 +190,24 @@ class Assistant(

assistant.modify(
ModifyAssistantRequest(
model = assistantRequest.model,
model = assistantRequest.model.value,
name = assistantRequest.name,
description = assistantRequest.description,
instructions = assistantRequest.instructions,
tools = assistantTools(assistantRequest),
fileIds = assistantRequest.fileIds,
toolResources =
assistantRequest.toolResources?.let {
ModifyAssistantRequestToolResources(
codeInterpreter =
it.codeInterpreter?.let {
ModifyAssistantRequestToolResourcesCodeInterpreter(fileIds = it.fileIds)
},
fileSearch =
ModifyAssistantRequestToolResourcesFileSearch(
vectorStoreIds = it.fileSearch?.vectorStoreIds
)
)
},
metadata = null // assistantRequest.metadata
)
)
Expand All @@ -196,8 +220,11 @@ class Assistant(
description = assistantRequest.description,
instructions = assistantRequest.instructions,
tools = assistantTools(assistantRequest),
fileIds = assistantRequest.fileIds,
metadata = null // assistantRequest.metadata
toolResources = assistantRequest.toolResources,
metadata =
assistantRequest.metadata
?.map { (k, v) -> k to JsonPrimitive(v) }
?.let { JsonObject(it.toMap()) }
),
toolsConfig = toolsConfig,
config = config,
Expand All @@ -215,8 +242,8 @@ class Assistant(
AssistantToolsCode(type = AssistantToolsCode.Type.code_interpreter)
)
is AssistantTool.Retrieval ->
AssistantObjectToolsInner.CaseAssistantToolsRetrieval(
AssistantToolsRetrieval(type = AssistantToolsRetrieval.Type.retrieval)
AssistantObjectToolsInner.CaseAssistantToolsFileSearch(
AssistantToolsFileSearch(type = AssistantToolsFileSearch.Type.file_search)
)
is AssistantTool.Function ->
AssistantObjectToolsInner.CaseAssistantToolsFunction(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package com.xebia.functional.xef.llm.assistants

import com.xebia.functional.openai.generated.model.CreateAssistantRequestModel
import com.xebia.functional.openai.generated.model.CreateAssistantRequestToolResources
import kotlinx.serialization.Required
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

@Serializable
data class AssistantRequest(
@SerialName(value = "assistant_id") val assistantId: String? = null,
@SerialName(value = "model") @Required val model: String,
@SerialName(value = "model") @Required val model: CreateAssistantRequestModel,

/* The name of the assistant. The maximum length is 256 characters. */
@SerialName(value = "name") val name: String? = null,
Expand All @@ -20,9 +22,8 @@ data class AssistantRequest(

/* A list of tool enabled on the assistant. There can be a maximum of 128 tools per assistant. Tools can be of types `code_interpreter`, `retrieval`, or `function`. */
@SerialName(value = "tools") val tools: List<AssistantTool>? = arrayListOf(),

/* A list of [file](/docs/api-reference/files) IDs attached to this assistant. There can be a maximum of 20 files attached to the assistant. Files are ordered by their creation date in ascending order. */
@SerialName(value = "file_ids") val fileIds: List<String>? = arrayListOf(),
@SerialName(value = "tool_resources")
val toolResources: CreateAssistantRequestToolResources? = null,

/* Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long. */
@SerialName(value = "metadata") val metadata: Map<String, String>? = null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,18 @@ class AssistantThread(
request =
CreateMessageRequest(
role = CreateMessageRequest.Role.user,
content = message.content,
fileIds = message.fileIds
content = CreateMessageRequestContent.CaseString(message.content),
attachments = message.fileIds.map { MessageObjectAttachmentsInner(fileId = it) }
),
)

suspend fun createMessage(content: String): MessageObject =
createMessage(CreateMessageRequest(role = CreateMessageRequest.Role.user, content = content))
createMessage(
CreateMessageRequest(
role = CreateMessageRequest.Role.user,
content = CreateMessageRequestContent.CaseString(content)
)
)

suspend fun createMessage(request: CreateMessageRequest): MessageObject =
api.createMessage(threadId, request, configure = ::defaultConfig)
Expand Down Expand Up @@ -144,6 +149,7 @@ class AssistantThread(
RunObject.Status.failed -> RunDelta.RunFailed(run)
RunObject.Status.completed -> RunDelta.RunCompleted(run)
RunObject.Status.expired -> RunDelta.RunExpired(run)
RunObject.Status.incomplete -> RunDelta.RunIncomplete(run)
}
flowCollector.emit(finalEvent)
run
Expand Down Expand Up @@ -206,9 +212,9 @@ class AssistantThread(

companion object {

/** Support for OpenAI-Beta: assistants=v1 */
/** Support for OpenAI-Beta: assistants=v2 */
fun defaultConfig(httpRequestBuilder: HttpRequestBuilder): Unit {
httpRequestBuilder.header("OpenAI-Beta", "assistants=v1")
httpRequestBuilder.header("OpenAI-Beta", "assistants=v2")
}

@JvmName("createWithMessagesAndFiles")
Expand All @@ -225,14 +231,15 @@ class AssistantThread(
.createThread(
createThreadRequest =
CreateThreadRequest(
messages.map {
CreateMessageRequest(
role = CreateMessageRequest.Role.user,
content = it.content,
fileIds = it.fileIds
)
},
metadata
messages =
messages.map {
CreateMessageRequest(
role = CreateMessageRequest.Role.user,
content = CreateMessageRequestContent.CaseString(it.content),
attachments = it.fileIds.map { MessageObjectAttachmentsInner(fileId = it) }
)
},
metadata = metadata
),
configure = ::defaultConfig
)
Expand All @@ -255,10 +262,14 @@ class AssistantThread(
.createThread(
createThreadRequest =
CreateThreadRequest(
messages.map {
CreateMessageRequest(role = CreateMessageRequest.Role.user, content = it)
},
metadata
messages =
messages.map {
CreateMessageRequest(
role = CreateMessageRequest.Role.user,
content = CreateMessageRequestContent.CaseString(it)
)
},
metadata = metadata
),
configure = ::defaultConfig
)
Expand All @@ -277,7 +288,12 @@ class AssistantThread(
api: Assistants = OpenAI(config).assistants
): AssistantThread =
AssistantThread(
api.createThread(CreateThreadRequest(messages, metadata), configure = ::defaultConfig).id,
api
.createThread(
CreateThreadRequest(messages = messages, metadata = metadata),
configure = ::defaultConfig
)
.id,
metric,
config,
api
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ sealed interface RunDelta {
/** [RunDeltaEvent.thread_run_expired] */
@JvmInline @Serializable value class RunExpired(val run: RunObject) : RunDelta

/** [RunDeltaEvent.thread_run_incomplete] */
@JvmInline @Serializable value class RunIncomplete(val run: RunObject) : RunDelta

/** [RunDeltaEvent.thread_run_step_created] */
@JvmInline @Serializable value class RunStepCreated(val runStep: RunStepObject) : RunDelta

Expand Down Expand Up @@ -246,8 +249,8 @@ sealed interface RunDelta {
val function = it.value.function
"Function: ${function.name}(${function.arguments})"
}
is RunStepDetailsToolCallsObjectToolCallsInner.CaseRunStepDetailsToolCallsRetrievalObject -> {
val retrieval = it.value.retrieval
is RunStepDetailsToolCallsObjectToolCallsInner.CaseRunStepDetailsToolCallsFileSearchObject -> {
val retrieval = it.value.fileSearch
"Retrieval: $retrieval"
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@ data class ModelsPricing(
// The pricing for the models was updated the May 2st, 2024
// Be sure to update the pricing for each model

// GPT-4o Input token price: $5.00, Output token price: $15.00 per 1M Tokens.
val gpt4o =
ModelsPricing(
modelName = "gpt-4o",
currency = "USD",
input = ModelsPricingItem(5.0, oneMillion),
output = ModelsPricingItem(15.0, oneMillion)
)

val gpt4Turbo =
ModelsPricing(
modelName = "gpt-4-turbo",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.xebia.functional.xef.assistants

import com.xebia.functional.openai.generated.model.CreateRunRequest
import com.xebia.functional.xef.Config
import com.xebia.functional.xef.OpenAI
import com.xebia.functional.xef.llm.assistants.Assistant
import com.xebia.functional.xef.llm.assistants.AssistantThread
Expand All @@ -14,10 +15,12 @@ suspend fun main() {

val assistant =
Assistant(
assistantId = "asst_UxczzpJkysC0l424ood87DAk",
assistantId = "asst_BwQvmWIbGUMDvCuXOtAFH8B6",
toolsConfig = listOf(Tool.toolOf(SumTool()))
)
val thread = AssistantThread(api = OpenAI().assistants, metric = metric)
val config = Config(org = null)
val api = OpenAI(config = config, logRequests = true).assistants
val thread = AssistantThread(api = api, metric = metric)
println("Welcome to the Math tutor, ask me anything about math:")
val userInput = "What is 1+1, explain all the steps and tools you used to solve it."
thread.createMessage(userInput)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ suspend fun main() {

val assistant =
Assistant(
assistantId = "asst_UxczzpJkysC0l424ood87DAk",
assistantId = "asst_BwQvmWIbGUMDvCuXOtAFH8B6",
toolsConfig = listOf(Tool.toolOf(SumTool()))
)
val thread = AssistantThread(api = OpenAI(logRequests = false).assistants, metric = metric)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@ suspend fun main() {
// FilesApi.PurposeCreateFile.assistants
// )
// .body()
val fileId = "file-q77cZu6e6sC2TsYbUs8UX5Dj"
// val fileId = "file-q77cZu6e6sC2TsYbUs8UX5Dj"
// remove assistant id to create a new one
// assistant_id: "asst_ImCKap37lLHBqV1awa0kzjZ3"
// language=yaml
val yamlConfig =
"""
assistant_id: "asst_ImCKap37lLHBqV1awa0kzjZ3"
model: "gpt-4-1106-preview"
model: "gpt-4o"
name: "My Custom Test Assistant"
description: "A versatile AI assistant capable of conversational and informational tasks."
instructions: "This assistant is designed to provide informative and engaging conversations, answer queries, and execute code when necessary."
Expand All @@ -27,8 +28,6 @@ suspend fun main() {
- type: "retrieval"
- type: "function"
name: "SumTool"
file_ids:
- "$fileId"
metadata:
version: "1.0"
created_by: "OpenAI"
Expand All @@ -42,6 +41,6 @@ suspend fun main() {
val assistantInfo = assistant.get()
println("assistant: $assistantInfo")
val thread = AssistantThread()
thread.createMessage(MessageWithFiles("What does this file say?", listOf(fileId)))
thread.createMessage(MessageWithFiles("What is 1 + 1? Use SumTool", emptyList()))
thread.run(assistant).collect(RunDelta::printEvent)
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package com.xebia.functional.xef.assistants

import com.xebia.functional.xef.conversation.Description
import com.xebia.functional.xef.llm.assistants.Assistant
import com.xebia.functional.xef.llm.assistants.AssistantThread
import com.xebia.functional.xef.llm.assistants.RunDelta
import com.xebia.functional.xef.llm.assistants.Tool
import kotlinx.serialization.Serializable

Expand All @@ -30,7 +27,3 @@ suspend fun main() {
val toolConfig = Tool.toolOf(SumToolWithDescription()).functionObject
println(toolConfig.parameters)
}

private suspend fun runAssistantAndDisplayResults(thread: AssistantThread, assistant: Assistant) {
thread.run(assistant).collect(RunDelta::printEvent)
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@ sealed class Response {
}

suspend fun main() {
val response = AI<Response>("Capital of France?")
val response = AI<Response>("What is the capital of France?")
println(response) // City(name=Paris)
}

0 comments on commit 9dcf81e

Please sign in to comment.