Skip to content

Commit

Permalink
feat: add request format complete
Browse files Browse the repository at this point in the history
  • Loading branch information
ippan authored and iptton committed Oct 25, 2023
1 parent 054b587 commit 07bfced
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 30 deletions.
6 changes: 2 additions & 4 deletions docs/custom-llm-server.md
Expand Up @@ -83,9 +83,7 @@ And custom the origin keys for `role`, `messsage`
{
"customHeaders": { "my header": "my value" },
"customFields": {"user": "userid", "date": "2012"},
"messageKeys": [
{"role": "role", "content": "message"}
]
"messageKeys": {"role": "role", "content": "message"}
}
```

Expand All @@ -94,6 +92,6 @@ and the request body will be:
```json
{
"user": "userid",
"messages": [{"role": "user", "content": "..."}]
"messages": [{"role": "user", "message": "..."}]
}
```
Expand Up @@ -35,6 +35,9 @@ import java.io.InputStreamReader
import java.nio.charset.StandardCharsets

class AutoDevHttpException(error: String, val statusCode: Int) : RuntimeException(error) {
override fun toString(): String {
return "AutoDevHttpException(statusCode=$statusCode, message=$message)"
}
}

/**
Expand Down Expand Up @@ -80,6 +83,11 @@ class ResponseBodyCallback(private val emitter: FlowableEmitter<SSE>, private va
null
}

line!!.startsWith("{") -> {
logger<ResponseBodyCallback>().warn("msg starts with { $line")
emitter.onNext(SSE(line!!))
null
}
else -> {
throw SSEFormatException("Invalid sse format! $line")
}
Expand Down
34 changes: 21 additions & 13 deletions src/main/kotlin/cc/unitmesh/devti/llms/custom/CustomLLMProvider.kt
Expand Up @@ -28,7 +28,6 @@ import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.RequestBody
import org.jetbrains.annotations.VisibleForTesting
import org.jetbrains.kotlin.idea.gradleTooling.get
import java.time.Duration

@Serializable
Expand All @@ -42,9 +41,7 @@ class CustomLLMProvider(val project: Project) : LLMProvider {
private val autoDevSettingsState = AutoDevSettingsState.getInstance()
private val url get() = autoDevSettingsState.customEngineServer
private val key get() = autoDevSettingsState.customEngineToken

private val requestHeaderFormat: String get() = autoDevSettingsState.customEngineRequestHeaderFormat
private val requestBodyFormat: String get() = autoDevSettingsState.customEngineRequestBodyFormat
private val requestFormat: String get() = autoDevSettingsState.customEngineRequestFormat
private val responseFormat get() = autoDevSettingsState.customEngineResponseFormat
private val customPromptConfig: CustomPromptConfig
get() {
Expand Down Expand Up @@ -78,18 +75,24 @@ class CustomLLMProvider(val project: Project) : LLMProvider {
messages += Message("user", promptText)

val customRequest = CustomRequest(messages)
val requestContent = Json.encodeToString<CustomRequest>(customRequest)
val requestContent = customRequest.updateCustomFormat(requestFormat)

val body = RequestBody.create("application/json; charset=utf-8".toMediaTypeOrNull(), requestContent)
logger.info("Requesting from $body")

val builder = Request.Builder()
if (key.isNotEmpty()) {
builder.addHeader("Authorization", "Bearer $key")
builder.addHeader("Content-Type", "application/json")
}
builder.appendCustomHeaders(requestFormat)

client = client.newBuilder().readTimeout(timeout).build()
val request = builder.url(url).post(body).build()
client = client.newBuilder()
.readTimeout(timeout)
.build()
val request = builder
.url(url)
.post(body)
.build()

val call = client.newCall(request)
val emitDone = false
Expand All @@ -100,7 +103,6 @@ class CustomLLMProvider(val project: Project) : LLMProvider {
}, BackpressureStrategy.BUFFER)

try {
logger.info("Starting to stream:")
return callbackFlow {
withContext(Dispatchers.IO) {
sseFlowable
Expand Down Expand Up @@ -139,7 +141,7 @@ class CustomLLMProvider(val project: Project) : LLMProvider {

val body = RequestBody.create("application/json; charset=utf-8".toMediaTypeOrNull(), requestContent)

logger.info("Requesting from $body")
logger.info("Requesting form: $requestContent ${body.toString()}")
val builder = Request.Builder()
if (key.isNotEmpty()) {
builder.addHeader("Authorization", "Bearer $key")
Expand Down Expand Up @@ -174,7 +176,8 @@ fun Request.Builder.appendCustomHeaders(customRequestHeader: String): Request.Bu
}
}
}.onFailure {
logger<CustomLLMProvider>().error("Failed to parse custom request header", it)
// should I warn user?
println("Failed to parse custom request header ${it.message}")
}
}

Expand All @@ -194,10 +197,9 @@ fun JsonObject.updateCustomBody(customRequest: String): JsonObject {
}



// TODO clean code with magic literals
var roleKey = "role"
var contentKey = "message"
var contentKey = "message"
customRequestJson.jsonObject["messageKeys"]?.let {
roleKey = it.jsonObject["role"]?.jsonPrimitive?.content ?: "role"
contentKey = it.jsonObject["content"]?.jsonPrimitive?.content ?: "message"
Expand All @@ -222,3 +224,9 @@ fun JsonObject.updateCustomBody(customRequest: String): JsonObject {
this
}
}

fun CustomRequest.updateCustomFormat(format: String): String {
val requestContentOri = Json.encodeToString<CustomRequest>(this)
return Json.parseToJsonElement(requestContentOri)
.jsonObject.updateCustomBody(format).toString()
}
13 changes: 11 additions & 2 deletions src/main/kotlin/cc/unitmesh/devti/settings/AutoDevSettingsState.kt
Expand Up @@ -30,8 +30,17 @@ class AutoDevSettingsState : PersistentStateComponent<AutoDevSettingsState> {
* should be a json path
*/
var customEngineResponseFormat = ""
var customEngineRequestHeaderFormat = ""
var customEngineRequestBodyFormat = ""
/**
* should be a json
* {
* 'customHeaders': { 'headerName': 'headerValue', 'headerName2': 'headerValue2' ... },
* 'customFields' : { 'bodyFieldName': 'bodyFieldValue', 'bodyFieldName2': 'bodyFieldValue2' ... }
* 'messageKey': {'role': 'roleKeyName', 'content': 'contentKeyName'}
* }
*
* @see docs/custom-llm-server.md
*/
var customEngineRequestFormat = ""


var language = DEFAULT_HUMAN_LANGUAGE
Expand Down
Expand Up @@ -36,8 +36,7 @@ class LLMSettingComponent(private val settings: AutoDevSettingsState) {
private val xingHuoApiKeyParam by LLMParam.creating { Password(settings.xingHuoApiKey) }
private val xingHuoApiSecretParam by LLMParam.creating { Password(settings.xingHuoApiSecrect) }
private val customEngineResponseFormatParam by LLMParam.creating { Editable(settings.customEngineResponseFormat) }
private val customEngineRequestBodyFormatParam by LLMParam.creating { Editable(settings.customEngineRequestBodyFormat) }
private val customEngineRequestHeaderFormatParam by LLMParam.creating { Editable(settings.customEngineRequestHeaderFormat) }
private val customEngineRequestBodyFormatParam by LLMParam.creating { Editable(settings.customEngineRequestFormat) }


val project = ProjectManager.getInstance().openProjects.firstOrNull()
Expand Down Expand Up @@ -80,7 +79,6 @@ class LLMSettingComponent(private val settings: AutoDevSettingsState) {
customEngineTokenParam,
customEngineResponseFormatParam,
customEngineRequestBodyFormatParam,
customEngineRequestHeaderFormatParam,
),
AIEngines.XingHuo to listOf(
xingHuoAppIDParam,
Expand Down Expand Up @@ -191,8 +189,7 @@ class LLMSettingComponent(private val settings: AutoDevSettingsState) {
aiEngineParam.value = aiEngine
customEnginePrompt.text = customPrompts
customEngineResponseFormatParam.value = customEngineResponseFormat
customEngineRequestBodyFormatParam.value = customEngineRequestBodyFormat
customEngineRequestHeaderFormatParam.value = customEngineRequestHeaderFormat
customEngineRequestBodyFormatParam.value = customEngineRequestFormat
delaySecondsParam.value = delaySeconds
}
}
Expand All @@ -216,8 +213,7 @@ class LLMSettingComponent(private val settings: AutoDevSettingsState) {
customPrompts = customEnginePrompt.text
openAiModel = openAIModelsParam.value
customEngineResponseFormat = customEngineResponseFormatParam.value
customEngineRequestBodyFormat = customEngineRequestBodyFormatParam.value
customEngineRequestHeaderFormat = customEngineRequestHeaderFormatParam.value
customEngineRequestFormat = customEngineRequestBodyFormatParam.value
delaySeconds = delaySecondsParam.value
}
}
Expand All @@ -240,8 +236,7 @@ class LLMSettingComponent(private val settings: AutoDevSettingsState) {
settings.openAiModel != openAIModelsParam.value ||
settings.customOpenAiHost != customOpenAIHostParam.value ||
settings.customEngineResponseFormat != customEngineResponseFormatParam.value ||
settings.customEngineRequestBodyFormat != customEngineRequestBodyFormatParam.value ||
settings.customEngineRequestHeaderFormat != customEngineRequestHeaderFormatParam.value ||
settings.customEngineRequestFormat != customEngineRequestBodyFormatParam.value ||
settings.delaySeconds != delaySecondsParam.value
}

Expand Down
Expand Up @@ -45,7 +45,7 @@ class CustomLLMProviderTest {
val customRequest = """
{
"messageKeys":
{"content": "message"}
{"content": "content"}
}
""".trimIndent()
val request = CustomRequest(listOf(
Expand All @@ -58,8 +58,28 @@ class CustomLLMProviderTest {

val messageObj = newObj.jsonObject["messages"]!!.jsonArray
assertEquals(1, messageObj.size)
assertEquals("this is message", messageObj[0].jsonObject["message"]!!.jsonPrimitive.content)
assertEquals("this is message", messageObj[0].jsonObject["content"]!!.jsonPrimitive.content)
}

@Test
fun testCustomRequestUpdate() {
val customRequestFormat = """
{
"customFields":
{"user": "userid", "date": "2012"},
"messageKeys":
{"content": "anyContentKey", "role": "anyRoleKey"}
}
""".trimIndent()

val customRequest = CustomRequest(listOf(
Message("robot", "hello")
))

val request = customRequest.updateCustomFormat(customRequestFormat)
assertEquals("""
{"messages":[{"anyRoleKey":"robot","anyContentKey":"hello"}],"user":"userid","date":"2012"}
""".trimIndent(), request)
}

}

0 comments on commit 07bfced

Please sign in to comment.