Skip to content

Commit

Permalink
feat: make custom server works with stream
Browse files Browse the repository at this point in the history
  • Loading branch information
phodal committed Aug 7, 2023
1 parent 6b6e91d commit a7a9c1c
Showing 1 changed file with 32 additions and 18 deletions.
50 changes: 32 additions & 18 deletions src/main/kotlin/cc/unitmesh/devti/llms/custom/CustomLLMProvider.kt
@@ -1,15 +1,24 @@
package cc.unitmesh.devti.llms.custom

import cc.unitmesh.devti.llms.CodeCopilotProvider
import cc.unitmesh.devti.llms.azure.SimpleOpenAIBody
import cc.unitmesh.devti.prompting.model.CustomPromptConfig
import cc.unitmesh.devti.settings.AutoDevSettingsState
import com.fasterxml.jackson.databind.ObjectMapper
import com.intellij.openapi.components.Service
import com.intellij.openapi.diagnostic.logger
import com.intellij.openapi.project.Project
import com.theokanning.openai.completion.chat.ChatCompletionResult
import com.theokanning.openai.service.SSE
import io.reactivex.BackpressureStrategy
import io.reactivex.Flowable
import io.reactivex.FlowableEmitter
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.channels.awaitClose
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.callbackFlow
import kotlinx.coroutines.withContext
import kotlinx.serialization.Serializable
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
Expand Down Expand Up @@ -63,30 +72,35 @@ class CustomLLMProvider(val project: Project) : CodeCopilotProvider {
.post(body)
.build()

return callbackFlow {
val listener = object : EventSourceListener() {
override fun onOpen(eventSource: EventSource, response: Response) {
println("onOpen")
}

override fun onEvent(eventSource: EventSource, id: String?, type: String?, data: String) {
println(data)
trySend(data)
}
val call = client.newCall(request)
val emitDone = false

override fun onClosed(eventSource: EventSource) {
val sseFlowable = Flowable
.create({ emitter: FlowableEmitter<SSE> ->
call.enqueue(cc.unitmesh.devti.llms.azure.ResponseBodyCallback(emitter, emitDone))
}, BackpressureStrategy.BUFFER)

}
try {
return callbackFlow {
withContext(Dispatchers.IO) {
sseFlowable
.doOnError(Throwable::printStackTrace)
.blockingForEach { sse ->
val result: ChatCompletionResult =
ObjectMapper().readValue(sse!!.data, ChatCompletionResult::class.java)
val completion = result.choices[0].message
if (completion != null && completion.content != null) {
trySend(completion.content)
}
}

override fun onFailure(eventSource: EventSource, t: Throwable?, response: Response?) {
close()
}
}

val eventSource = EventSources.createFactory(client).newEventSource(request, listener)

awaitClose {
eventSource.cancel()
} catch (e: Exception) {
logger.error("Failed to stream", e)
return callbackFlow {
close()
}
}
}
Expand Down

0 comments on commit a7a9c1c

Please sign in to comment.