diff --git a/src/main/kotlin/cc/unitmesh/devti/llms/custom/CustomLLMProvider.kt b/src/main/kotlin/cc/unitmesh/devti/llms/custom/CustomLLMProvider.kt index 79eb9b0ae7..f69e1a1204 100644 --- a/src/main/kotlin/cc/unitmesh/devti/llms/custom/CustomLLMProvider.kt +++ b/src/main/kotlin/cc/unitmesh/devti/llms/custom/CustomLLMProvider.kt @@ -1,15 +1,24 @@ package cc.unitmesh.devti.llms.custom import cc.unitmesh.devti.llms.CodeCopilotProvider +import cc.unitmesh.devti.llms.azure.SimpleOpenAIBody import cc.unitmesh.devti.prompting.model.CustomPromptConfig import cc.unitmesh.devti.settings.AutoDevSettingsState +import com.fasterxml.jackson.databind.ObjectMapper import com.intellij.openapi.components.Service import com.intellij.openapi.diagnostic.logger import com.intellij.openapi.project.Project +import com.theokanning.openai.completion.chat.ChatCompletionResult +import com.theokanning.openai.service.SSE +import io.reactivex.BackpressureStrategy +import io.reactivex.Flowable +import io.reactivex.FlowableEmitter +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.channels.awaitClose import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.callbackFlow +import kotlinx.coroutines.withContext import kotlinx.serialization.Serializable import kotlinx.serialization.encodeToString import kotlinx.serialization.json.Json @@ -63,30 +72,35 @@ class CustomLLMProvider(val project: Project) : CodeCopilotProvider { .post(body) .build() - return callbackFlow { - val listener = object : EventSourceListener() { - override fun onOpen(eventSource: EventSource, response: Response) { - println("onOpen") - } - - override fun onEvent(eventSource: EventSource, id: String?, type: String?, data: String) { - println(data) - trySend(data) - } + val call = client.newCall(request) + val emitDone = false - override fun onClosed(eventSource: EventSource) { + val sseFlowable = Flowable + .create({ emitter: FlowableEmitter -> + call.enqueue(cc.unitmesh.devti.llms.azure.ResponseBodyCallback(emitter, emitDone)) + }, BackpressureStrategy.BUFFER) - } + try { + return callbackFlow { + withContext(Dispatchers.IO) { + sseFlowable + .doOnError(Throwable::printStackTrace) + .blockingForEach { sse -> + val result: ChatCompletionResult = + ObjectMapper().readValue(sse!!.data, ChatCompletionResult::class.java) + val completion = result.choices[0].message + if (completion != null && completion.content != null) { + trySend(completion.content) + } + } - override fun onFailure(eventSource: EventSource, t: Throwable?, response: Response?) { close() } } - - val eventSource = EventSources.createFactory(client).newEventSource(request, listener) - - awaitClose { - eventSource.cancel() + } catch (e: Exception) { + logger.error("Failed to stream", e) + return callbackFlow { + close() } } }