Commit a72f085

feat: add custom requests; the request Header and request body can now be modified
iptton committed Dec 14, 2023
1 parent e0cca51 commit a72f085
Showing 8 changed files with 111 additions and 38 deletions.
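
This change wires a user-supplied JSON snippet into every request sent to a custom LLM server. A rough sketch of the kind of configuration it enables (the "customHeaders" key is read by appendCustomHeaders() below; "customFields" is an assumption about what updateCustomBody() merges, since its body is collapsed in this view):

// Hypothetical value for the customEngineRequestFormat setting.
val requestFormat = """
    {
      "customHeaders": { "Authorization": "Bearer <token>", "X-Client": "AutoDev" },
      "customFields": { "model": "my-model", "stream": true }
    }
"""
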
8 changes: 6 additions & 2 deletions src/main/kotlin/cc/unitmesh/devti/gui/chat/ChatCodingPanel.kt
@@ -27,6 +27,7 @@ import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.withContext
import java.awt.event.ActionListener
import java.awt.event.MouseAdapter
@@ -152,7 +153,7 @@ class ChatCodingPanel(private val chatCodingService: ChatCodingService, val disp

suspend fun updateMessage(content: Flow<String>): String {
if (myList.componentCount > 0) {
myList.remove(myList.componentCount - 1)
}

progressBar.isVisible = true
@@ -200,9 +201,12 @@ class ChatCodingPanel(private val chatCodingService: ChatCodingService, val disp
val startTime = System.currentTimeMillis() // record the time when execution starts

var text = ""
content.onCompletion {
println("onCompletion ${it?.message}")
}.catch {
it.printStackTrace()
}.collect {
println("got message $it")
text += it

// The two APIs below are poorly designed: if they must always be called together, providing a single one would be enough
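
A minimal, runnable sketch of the onCompletion/catch/collect chain the panel now uses to consume streamed chunks; flowOf stands in for the real LLM response flow:

import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.runBlocking

fun main() = runBlocking {
    var text = ""
    flowOf("Hel", "lo")
        .onCompletion { cause -> println("onCompletion ${cause?.message}") } // cause is null on normal completion
        .catch { it.printStackTrace() } // log upstream errors instead of crashing the UI
        .collect { chunk -> text += chunk }
    println(text) // prints: Hello
}
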
89 changes: 54 additions & 35 deletions src/main/kotlin/cc/unitmesh/devti/llms/custom/CustomLLMProvider.kt
@@ -1,9 +1,9 @@
package cc.unitmesh.devti.llms.custom

import cc.unitmesh.devti.gui.chat.ChatRole
import cc.unitmesh.devti.llms.LLMProvider
import cc.unitmesh.devti.settings.AutoDevSettingsState
import cc.unitmesh.devti.settings.ResponseType
import com.fasterxml.jackson.databind.ObjectMapper
import com.intellij.openapi.components.Service
import com.intellij.openapi.diagnostic.logger
@@ -17,12 +17,15 @@ import io.reactivex.Flowable
import io.reactivex.FlowableEmitter
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.channels.awaitClose
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.callbackFlow
import kotlinx.coroutines.withContext
import kotlinx.serialization.Serializable
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.*
import okhttp3.Call
import okhttp3.MediaType.Companion.toMediaTypeOrNull
import okhttp3.OkHttpClient
import okhttp3.Request
@@ -39,15 +42,15 @@ data class CustomRequest(val messages: List<Message>)
@Service(Service.Level.PROJECT)
class CustomLLMProvider(val project: Project) : LLMProvider {
private val autoDevSettingsState = AutoDevSettingsState.getInstance()
private val url
get() = autoDevSettingsState.customEngineServer
private val key
get() = autoDevSettingsState.customEngineToken
private val requestFormat: String
get() = autoDevSettingsState.customEngineRequestFormat
private val responseFormat
get() = autoDevSettingsState.customEngineResponseFormat

private var client = OkHttpClient()
private val timeout = Duration.ofSeconds(600)
private val messages: MutableList<Message> = ArrayList()
@@ -66,7 +69,6 @@ class CustomLLMProvider(val project: Project) : LLMProvider {
return this.prompt(promptText, "")
}

override fun stream(promptText: String, systemPrompt: String, keepHistory: Boolean): Flow<String> {
if (!keepHistory) {
clearMessage()
@@ -86,15 +88,32 @@ class CustomLLMProvider(val project: Project) : LLMProvider {
}
builder.appendCustomHeaders(requestFormat)

client = client.newBuilder().readTimeout(timeout).build()
val call = client.newCall(builder.url(url).post(body).build())

if (autoDevSettingsState.customEngineResponseType == ResponseType.SSE.name) {
return streamSSE(call)
} else {
return streamJson(call)
}
}


private val _responseFlow = MutableSharedFlow<String>()

@OptIn(ExperimentalCoroutinesApi::class)
private fun streamJson(call: Call): Flow<String> = callbackFlow {
call.enqueue(JSONBodyResponseCallback(responseFormat) {
withContext(Dispatchers.IO) {
send(it)
}
close()
})
awaitClose()
}

@OptIn(ExperimentalCoroutinesApi::class)
private fun streamSSE(call: Call): Flow<String> {
val emitDone = false

val sseFlowable = Flowable
@@ -106,29 +125,29 @@ class CustomLLMProvider(val project: Project) : LLMProvider {
return callbackFlow {
withContext(Dispatchers.IO) {
sseFlowable
.doOnError {
it.printStackTrace()
close()
}
.blockingForEach { sse ->
if (responseFormat.isNotEmpty()) {
val chunk: String = JsonPath.parse(sse!!.data)?.read(responseFormat)
?: throw Exception("Failed to parse chunk")
logger.warn("got msg: $chunk")
trySend(chunk)
} else {
val result: ChatCompletionResult =
ObjectMapper().readValue(sse!!.data, ChatCompletionResult::class.java)

val completion = result.choices[0].message
if (completion != null && completion.content != null) {
trySend(completion.content)
}
}
}

close()
}
awaitClose()
}
} catch (e: Exception) {
logger.error("Failed to stream", e)
@@ -174,7 +193,7 @@ class CustomLLMProvider(val project: Project) : LLMProvider {
fun Request.Builder.appendCustomHeaders(customRequestHeader: String): Request.Builder = apply {
runCatching {
Json.parseToJsonElement(customRequestHeader)
.jsonObject["customHeaders"].let { customFields ->
.jsonObject["customHeaders"].let { customFields ->
customFields?.jsonObject?.forEach { (key, value) ->
header(key, value.jsonPrimitive.content)
}
@@ -232,5 +251,5 @@ fun JsonObject.updateCustomBody(customRequest: String): JsonObject {
fun CustomRequest.updateCustomFormat(format: String): String {
val requestContentOri = Json.encodeToString<CustomRequest>(this)
return Json.parseToJsonElement(requestContentOri)
.jsonObject.updateCustomBody(format).toString()
}
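
A hedged sketch (reusing this file's imports) of how stream() above presumably assembles the final request body from these helpers; the Message constructor arguments are an assumption, since its declaration is collapsed in this view:

val request = CustomRequest(listOf(Message("user", "Hello"))) // Message fields assumed
val bodyText = if (requestFormat.isNotEmpty()) {
    request.updateCustomFormat(requestFormat) // merge user-defined fields into the JSON
} else {
    Json.encodeToString(request)              // plain serialization when no format is set
}
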
26 changes: 26 additions & 0 deletions src/main/kotlin/cc/unitmesh/devti/llms/custom/JSONBodyResponseCallback.kt
@@ -0,0 +1,26 @@
package cc.unitmesh.devti.llms.custom

import com.nfeld.jsonpathkt.JsonPath
import com.nfeld.jsonpathkt.extension.read
import io.kotest.common.runBlocking
import okhttp3.Call
import okhttp3.Callback
import okhttp3.Response
import java.io.IOException

class JSONBodyResponseCallback(private val responseFormat: String, private val callback: suspend (String) -> Unit) : Callback {
override fun onFailure(call: Call, e: IOException) {
runBlocking {
callback("error. ${e.message}")
}
}

override fun onResponse(call: Call, response: Response) {
val responseContent: String = JsonPath.parse(response.body?.string())?.read(responseFormat) ?: ""

runBlocking {
callback(responseContent)
}
}
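
A small standalone sketch of the JSON-path extraction onResponse() performs, assuming an OpenAI-style response body and a responseFormat of "$.choices[0].message.content":

import com.nfeld.jsonpathkt.JsonPath
import com.nfeld.jsonpathkt.extension.read

fun main() {
    val body = """{"choices":[{"message":{"content":"hi there"}}]}"""
    val content = JsonPath.parse(body)?.read<String>("$.choices[0].message.content")
    println(content) // prints: hi there
}
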
Empty file.
5 changes: 5 additions & 0 deletions src/main/kotlin/cc/unitmesh/devti/settings/AutoDevSettingsState.kt
@@ -28,6 +28,11 @@ class AutoDevSettingsState : PersistentStateComponent<AutoDevSettingsState> {
var xingHuoApiSecrect = ""
var xingHuoApiKey = ""


/**
* Whether the data returned by the custom engine is in
* [SSE](https://www.ruanyifeng.com/blog/2017/05/server-sent_events.html) format
*/
var customEngineResponseType = ResponseType.SSE.name
/**
* should be a JSON path
*/
12 changes: 12 additions & 0 deletions src/main/kotlin/cc/unitmesh/devti/settings/Constants.kt
@@ -22,6 +22,18 @@ enum class XingHuoApiVersion(val value: Int) {
}
}

enum class ResponseType {
SSE, JSON;

companion object {
fun of(str: String): ResponseType = when (str) {
"SSE" -> SSE
"JSON" -> JSON
else -> JSON
}
}
}
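
Because of() falls back to JSON for any unrecognized input, a stale or hand-edited setting cannot yield an invalid type:

ResponseType.of("SSE")  // -> SSE
ResponseType.of("oops") // -> JSON (fallback)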


val DEFAULT_AI_ENGINE = AI_ENGINES[0]

8 changes: 7 additions & 1 deletion src/main/kotlin/cc/unitmesh/devti/settings/LLMSettingComponent.kt
@@ -39,6 +39,8 @@ class LLMSettingComponent(private val settings: AutoDevSettingsState) {
private val xingHuoAppIDParam by LLMParam.creating { Editable(settings.xingHuoAppId) }
private val xingHuoApiKeyParam by LLMParam.creating { Password(settings.xingHuoApiKey) }
private val xingHuoApiSecretParam by LLMParam.creating { Password(settings.xingHuoApiSecrect) }

private val customEngineResponseTypeParam by LLMParam.creating { ComboBox(ResponseType.of(settings.customEngineResponseType).name, ResponseType.values().map { it.name }.toList()) }
private val customEngineResponseFormatParam by LLMParam.creating { Editable(settings.customEngineResponseFormat) }
private val customEngineRequestBodyFormatParam by LLMParam.creating { Editable(settings.customEngineRequestFormat) }

@@ -79,6 +81,7 @@ class LLMSettingComponent(private val settings: AutoDevSettingsState) {
customOpenAIHostParam,
),
AIEngines.Custom to listOf(
customEngineResponseTypeParam,
customEngineServerParam,
customEngineTokenParam,
customEngineResponseFormatParam,
@@ -185,6 +188,7 @@ class LLMSettingComponent(private val settings: AutoDevSettingsState) {
openAIKeyParam.value = openAiKey
customOpenAIHostParam.value = customOpenAiHost
customEngineServerParam.value = customEngineServer
customEngineResponseTypeParam.value = customEngineResponseType
customEngineTokenParam.value = customEngineToken
openAIModelsParam.value = openAiModel
xingHuoApiVersionParam.value = xingHuoApiVersion.toString()
@@ -216,10 +220,11 @@ class LLMSettingComponent(private val settings: AutoDevSettingsState) {
aiEngine = aiEngineParam.value
language = languageParam.value
customEngineServer = customEngineServerParam.value
customEngineResponseType = customEngineResponseTypeParam.value
customEngineToken = customEngineTokenParam.value
customPrompts = customEnginePrompt.text
openAiModel = openAIModelsParam.value
customEngineResponseFormat = customEngineResponseFormatParam.value
customEngineRequestFormat = customEngineRequestBodyFormatParam.value
delaySeconds = delaySecondsParam.value
}
@@ -239,6 +244,7 @@ class LLMSettingComponent(private val settings: AutoDevSettingsState) {
settings.aiEngine != aiEngineParam.value ||
settings.language != languageParam.value ||
settings.customEngineServer != customEngineServerParam.value ||
settings.customEngineResponseType != customEngineResponseTypeParam.value ||
settings.customEngineToken != customEngineTokenParam.value ||
settings.customPrompts != customEnginePrompt.text ||
settings.openAiModel != openAIModelsParam.value ||
1 change: 1 addition & 0 deletions src/main/resources/messages/AutoDevBundle.properties
@@ -66,6 +66,7 @@ settings.xingHuoApiVersionParam=XingHuo API Version

settings.delaySecondsParam=Quest Delay Seconds
settings.customEngineResponseFormatParam=Custom Response Format (Json Path)
settings.customEngineResponseTypeParam=Custom Response Type
settings.customEngineRequestBodyFormatParam=Custom Request Body Format (Json Path)
settings.customEngineRequestHeaderFormatParam=Custom Request Header Format (Json Path)
settings.external.counit.enable.label=Enable CoUnit (Experimental)