diff --git a/docs/custom-llm-server.md b/docs/custom-llm-server.md index bd278a3aa6..aa7198a5e1 100644 --- a/docs/custom-llm-server.md +++ b/docs/custom-llm-server.md @@ -83,9 +83,7 @@ And custom the origin keys for `role`, `messsage` { "customHeaders": { "my header": "my value" }, "customFields": {"user": "userid", "date": "2012"}, - "messageKeys": [ - {"role": "role", "content": "message"} - ] + "messageKeys": {"role": "role", "content": "message"} } ``` @@ -94,6 +92,6 @@ and the request body will be: ```json { "user": "userid", - "messages": [{"role": "user", "content": "..."}] + "messages": [{"role": "user", "message": "..."}] } ``` diff --git a/src/main/kotlin/cc/unitmesh/devti/llms/azure/ResponseBodyCallback.kt b/src/main/kotlin/cc/unitmesh/devti/llms/azure/ResponseBodyCallback.kt index 952391d60f..05e4e3776a 100644 --- a/src/main/kotlin/cc/unitmesh/devti/llms/azure/ResponseBodyCallback.kt +++ b/src/main/kotlin/cc/unitmesh/devti/llms/azure/ResponseBodyCallback.kt @@ -35,6 +35,9 @@ import java.io.InputStreamReader import java.nio.charset.StandardCharsets class AutoDevHttpException(error: String, val statusCode: Int) : RuntimeException(error) { + override fun toString(): String { + return "AutoDevHttpException(statusCode=$statusCode, message=$message)" + } } /** @@ -80,6 +83,11 @@ class ResponseBodyCallback(private val emitter: FlowableEmitter, private va null } + line!!.startsWith("{") -> { + logger().warn("msg starts with { $line") + emitter.onNext(SSE(line!!)) + null + } else -> { throw SSEFormatException("Invalid sse format! $line") } diff --git a/src/main/kotlin/cc/unitmesh/devti/llms/custom/CustomLLMProvider.kt b/src/main/kotlin/cc/unitmesh/devti/llms/custom/CustomLLMProvider.kt index 23f216806e..1c95462a3e 100644 --- a/src/main/kotlin/cc/unitmesh/devti/llms/custom/CustomLLMProvider.kt +++ b/src/main/kotlin/cc/unitmesh/devti/llms/custom/CustomLLMProvider.kt @@ -28,7 +28,6 @@ import okhttp3.OkHttpClient import okhttp3.Request import okhttp3.RequestBody import org.jetbrains.annotations.VisibleForTesting -import org.jetbrains.kotlin.idea.gradleTooling.get import java.time.Duration @Serializable @@ -42,9 +41,7 @@ class CustomLLMProvider(val project: Project) : LLMProvider { private val autoDevSettingsState = AutoDevSettingsState.getInstance() private val url get() = autoDevSettingsState.customEngineServer private val key get() = autoDevSettingsState.customEngineToken - - private val requestHeaderFormat: String get() = autoDevSettingsState.customEngineRequestHeaderFormat - private val requestBodyFormat: String get() = autoDevSettingsState.customEngineRequestBodyFormat + private val requestFormat: String get() = autoDevSettingsState.customEngineRequestFormat private val responseFormat get() = autoDevSettingsState.customEngineResponseFormat private val customPromptConfig: CustomPromptConfig get() { @@ -78,18 +75,24 @@ class CustomLLMProvider(val project: Project) : LLMProvider { messages += Message("user", promptText) val customRequest = CustomRequest(messages) - val requestContent = Json.encodeToString(customRequest) + val requestContent = customRequest.updateCustomFormat(requestFormat) val body = RequestBody.create("application/json; charset=utf-8".toMediaTypeOrNull(), requestContent) - logger.info("Requesting from $body") val builder = Request.Builder() if (key.isNotEmpty()) { builder.addHeader("Authorization", "Bearer $key") + builder.addHeader("Content-Type", "application/json") } + builder.appendCustomHeaders(requestFormat) - client = client.newBuilder().readTimeout(timeout).build() - val request = builder.url(url).post(body).build() + client = client.newBuilder() + .readTimeout(timeout) + .build() + val request = builder + .url(url) + .post(body) + .build() val call = client.newCall(request) val emitDone = false @@ -100,7 +103,6 @@ class CustomLLMProvider(val project: Project) : LLMProvider { }, BackpressureStrategy.BUFFER) try { - logger.info("Starting to stream:") return callbackFlow { withContext(Dispatchers.IO) { sseFlowable @@ -139,7 +141,7 @@ class CustomLLMProvider(val project: Project) : LLMProvider { val body = RequestBody.create("application/json; charset=utf-8".toMediaTypeOrNull(), requestContent) - logger.info("Requesting from $body") + logger.info("Requesting form: $requestContent ${body.toString()}") val builder = Request.Builder() if (key.isNotEmpty()) { builder.addHeader("Authorization", "Bearer $key") @@ -174,7 +176,8 @@ fun Request.Builder.appendCustomHeaders(customRequestHeader: String): Request.Bu } } }.onFailure { - logger().error("Failed to parse custom request header", it) + // should I warn user? + println("Failed to parse custom request header ${it.message}") } } @@ -194,10 +197,9 @@ fun JsonObject.updateCustomBody(customRequest: String): JsonObject { } - // TODO clean code with magic literals var roleKey = "role" - var contentKey = "message" + var contentKey = "message" customRequestJson.jsonObject["messageKeys"]?.let { roleKey = it.jsonObject["role"]?.jsonPrimitive?.content ?: "role" contentKey = it.jsonObject["content"]?.jsonPrimitive?.content ?: "message" @@ -222,3 +224,9 @@ fun JsonObject.updateCustomBody(customRequest: String): JsonObject { this } } + +fun CustomRequest.updateCustomFormat(format: String): String { + val requestContentOri = Json.encodeToString(this) + return Json.parseToJsonElement(requestContentOri) + .jsonObject.updateCustomBody(format).toString() +} diff --git a/src/main/kotlin/cc/unitmesh/devti/settings/AutoDevSettingsState.kt b/src/main/kotlin/cc/unitmesh/devti/settings/AutoDevSettingsState.kt index 2bbe9a5deb..cf2e45ea58 100644 --- a/src/main/kotlin/cc/unitmesh/devti/settings/AutoDevSettingsState.kt +++ b/src/main/kotlin/cc/unitmesh/devti/settings/AutoDevSettingsState.kt @@ -30,8 +30,17 @@ class AutoDevSettingsState : PersistentStateComponent { * should be a json path */ var customEngineResponseFormat = "" - var customEngineRequestHeaderFormat = "" - var customEngineRequestBodyFormat = "" + /** + * should be a json + * { + * 'customHeaders': { 'headerName': 'headerValue', 'headerName2': 'headerValue2' ... }, + * 'customFields' : { 'bodyFieldName': 'bodyFieldValue', 'bodyFieldName2': 'bodyFieldValue2' ... } + * 'messageKey': {'role': 'roleKeyName', 'content': 'contentKeyName'} + * } + * + * @see docs/custom-llm-server.md + */ + var customEngineRequestFormat = "" var language = DEFAULT_HUMAN_LANGUAGE diff --git a/src/main/kotlin/cc/unitmesh/devti/settings/LLMSettingComponent.kt b/src/main/kotlin/cc/unitmesh/devti/settings/LLMSettingComponent.kt index b5d391b538..0177724ea6 100644 --- a/src/main/kotlin/cc/unitmesh/devti/settings/LLMSettingComponent.kt +++ b/src/main/kotlin/cc/unitmesh/devti/settings/LLMSettingComponent.kt @@ -36,8 +36,7 @@ class LLMSettingComponent(private val settings: AutoDevSettingsState) { private val xingHuoApiKeyParam by LLMParam.creating { Password(settings.xingHuoApiKey) } private val xingHuoApiSecretParam by LLMParam.creating { Password(settings.xingHuoApiSecrect) } private val customEngineResponseFormatParam by LLMParam.creating { Editable(settings.customEngineResponseFormat) } - private val customEngineRequestBodyFormatParam by LLMParam.creating { Editable(settings.customEngineRequestBodyFormat) } - private val customEngineRequestHeaderFormatParam by LLMParam.creating { Editable(settings.customEngineRequestHeaderFormat) } + private val customEngineRequestBodyFormatParam by LLMParam.creating { Editable(settings.customEngineRequestFormat) } val project = ProjectManager.getInstance().openProjects.firstOrNull() @@ -80,7 +79,6 @@ class LLMSettingComponent(private val settings: AutoDevSettingsState) { customEngineTokenParam, customEngineResponseFormatParam, customEngineRequestBodyFormatParam, - customEngineRequestHeaderFormatParam, ), AIEngines.XingHuo to listOf( xingHuoAppIDParam, @@ -191,8 +189,7 @@ class LLMSettingComponent(private val settings: AutoDevSettingsState) { aiEngineParam.value = aiEngine customEnginePrompt.text = customPrompts customEngineResponseFormatParam.value = customEngineResponseFormat - customEngineRequestBodyFormatParam.value = customEngineRequestBodyFormat - customEngineRequestHeaderFormatParam.value = customEngineRequestHeaderFormat + customEngineRequestBodyFormatParam.value = customEngineRequestFormat delaySecondsParam.value = delaySeconds } } @@ -216,8 +213,7 @@ class LLMSettingComponent(private val settings: AutoDevSettingsState) { customPrompts = customEnginePrompt.text openAiModel = openAIModelsParam.value customEngineResponseFormat = customEngineResponseFormatParam.value - customEngineRequestBodyFormat = customEngineRequestBodyFormatParam.value - customEngineRequestHeaderFormat = customEngineRequestHeaderFormatParam.value + customEngineRequestFormat = customEngineRequestBodyFormatParam.value delaySeconds = delaySecondsParam.value } } @@ -240,8 +236,7 @@ class LLMSettingComponent(private val settings: AutoDevSettingsState) { settings.openAiModel != openAIModelsParam.value || settings.customOpenAiHost != customOpenAIHostParam.value || settings.customEngineResponseFormat != customEngineResponseFormatParam.value || - settings.customEngineRequestBodyFormat != customEngineRequestBodyFormatParam.value || - settings.customEngineRequestHeaderFormat != customEngineRequestHeaderFormatParam.value || + settings.customEngineRequestFormat != customEngineRequestBodyFormatParam.value || settings.delaySeconds != delaySecondsParam.value } diff --git a/src/test/kotlin/cc/unitmesh/devti/llms/custom/CustomLLMProviderTest.kt b/src/test/kotlin/cc/unitmesh/devti/llms/custom/CustomLLMProviderTest.kt index 41b9fbbc89..b0e5eff40e 100644 --- a/src/test/kotlin/cc/unitmesh/devti/llms/custom/CustomLLMProviderTest.kt +++ b/src/test/kotlin/cc/unitmesh/devti/llms/custom/CustomLLMProviderTest.kt @@ -45,7 +45,7 @@ class CustomLLMProviderTest { val customRequest = """ { "messageKeys": - {"content": "message"} + {"content": "content"} } """.trimIndent() val request = CustomRequest(listOf( @@ -58,8 +58,28 @@ class CustomLLMProviderTest { val messageObj = newObj.jsonObject["messages"]!!.jsonArray assertEquals(1, messageObj.size) - assertEquals("this is message", messageObj[0].jsonObject["message"]!!.jsonPrimitive.content) + assertEquals("this is message", messageObj[0].jsonObject["content"]!!.jsonPrimitive.content) } + @Test + fun testCustomRequestUpdate() { + val customRequestFormat = """ + { + "customFields": + {"user": "userid", "date": "2012"}, + "messageKeys": + {"content": "anyContentKey", "role": "anyRoleKey"} + } + """.trimIndent() + + val customRequest = CustomRequest(listOf( + Message("robot", "hello") + )) + + val request = customRequest.updateCustomFormat(customRequestFormat) + assertEquals(""" + {"messages":[{"anyRoleKey":"robot","anyContentKey":"hello"}],"user":"userid","date":"2012"} + """.trimIndent(), request) + } }