Skip to content

Commit

Permalink
feat: make azure works with stream
Browse files Browse the repository at this point in the history
  • Loading branch information
phodal committed Aug 7, 2023
1 parent f83b70f commit 962b599
Show file tree
Hide file tree
Showing 2 changed files with 185 additions and 15 deletions.
97 changes: 82 additions & 15 deletions src/main/kotlin/cc/unitmesh/devti/llms/azure/AzureOpenAIProvider.kt
Expand Up @@ -10,12 +10,24 @@ import com.intellij.openapi.project.Project
import com.theokanning.openai.completion.chat.ChatCompletionResult
import com.theokanning.openai.completion.chat.ChatMessage
import com.theokanning.openai.completion.chat.ChatMessageRole
import com.theokanning.openai.service.ResponseBodyCallback
import com.theokanning.openai.service.SSE
import io.reactivex.BackpressureStrategy
import io.reactivex.Flowable
import io.reactivex.FlowableEmitter
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.channels.awaitClose
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.callbackFlow
import kotlinx.coroutines.withContext
import kotlinx.serialization.Serializable
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import okhttp3.*
import okhttp3.MediaType.Companion.toMediaTypeOrNull
import okhttp3.OkHttpClient
import okhttp3.Request
import java.io.IOException


@Serializable
data class SimpleOpenAIFormat(val role: String, val content: String) {
Expand All @@ -26,6 +38,9 @@ data class SimpleOpenAIFormat(val role: String, val content: String) {
}
}

@Serializable
data class SimpleOpenAIBody(val messages: List<SimpleOpenAIFormat>, val temperature: Double, val stream: Boolean)

@Service(Service.Level.PROJECT)
class AzureOpenAIProvider(val project: Project) : CodeCopilotProvider {
private val logger = logger<AzureOpenAIProvider>()
Expand All @@ -49,47 +64,99 @@ class AzureOpenAIProvider(val project: Project) : CodeCopilotProvider {
private val messages: MutableList<SimpleOpenAIFormat> = ArrayList()
private var historyMessageLength: Int = 0


fun prompt(instruction: String, input: String): String {
val promptText = "$instruction\n$input"
val systemMessage = ChatMessage(ChatMessageRole.USER.value(), promptText)

if (historyMessageLength > 8192) {
messages.clear()
}

messages.add(SimpleOpenAIFormat.fromChatMessage(systemMessage))
val requestText = Json.encodeToString<SimpleOpenAIBody>(
SimpleOpenAIBody(
messages,
0.0,
false
)
)

val builder = Request.Builder()
val requestText = """{
|"messages": ${Json.encodeToString(messages)},
|"temperature": 0.0
}""".trimMargin()

val body = okhttp3.RequestBody.create(
val body = RequestBody.create(
"application/json; charset=utf-8".toMediaTypeOrNull(),
requestText
)

val builder = Request.Builder()
val request = builder
.url(url)
.post(body)
.build()

val response = client.newCall(request).execute()

if (!response.isSuccessful) {
logger.error("$response")
return ""
}

val objectMapper = ObjectMapper()
val completion: ChatCompletionResult =
objectMapper.readValue(response.body?.string(), ChatCompletionResult::class.java)
ObjectMapper().readValue(response.body?.string(), ChatCompletionResult::class.java)

return completion.choices[0].message.content
}


// fun stream(apiCall: Call, emitDone: Boolean): Flowable<SSE?> {
// return Flowable.create({ emitter: FlowableEmitter<SSE?>? ->
// apiCall.enqueue(
//
// )
// }, BackpressureStrategy.BUFFER)
// }

@OptIn(ExperimentalCoroutinesApi::class)
override fun stream(promptText: String, systemPrompt: String): Flow<String> = callbackFlow {
val promptText1 = "$promptText\n${""}"
val systemMessage = ChatMessage(ChatMessageRole.USER.value(), promptText1)
if (historyMessageLength > 8192) {
messages.clear()
}
messages.add(SimpleOpenAIFormat.fromChatMessage(systemMessage))
val openAIBody = SimpleOpenAIBody(
messages,
0.0,
true
)

val requestText = Json.encodeToString<SimpleOpenAIBody>(openAIBody)
val body = RequestBody.create(
"application/json; charset=utf-8".toMediaTypeOrNull(),
requestText
)

val builder = Request.Builder()
val request = builder
.url(url)
.post(body)
.build()
val call = client.newCall(request)
val emitDone = false;
withContext(Dispatchers.IO) {
Flowable.create({ emitter: FlowableEmitter<SSE> ->
call.enqueue(cc.unitmesh.devti.llms.azure.ResponseBodyCallback(emitter, emitDone))
}, BackpressureStrategy.BUFFER)
.doOnError(Throwable::printStackTrace)
.blockingForEach { sse ->
val result: ChatCompletionResult =
ObjectMapper().readValue(sse!!.data, ChatCompletionResult::class.java)
val completion = result.choices[0].message
if (completion != null && completion.content != null) {
trySend(completion.content)
}
}

close()
}
}


override fun autoComment(text: String): String {
val comment = customPromptConfig!!.autoComment
return prompt(comment.instruction, comment.input.replace("{code}", text))
Expand Down
103 changes: 103 additions & 0 deletions src/main/kotlin/cc/unitmesh/devti/llms/azure/ResponseBodyCallback.kt
@@ -0,0 +1,103 @@
// MIT License
//
//Copyright (c) [year] [fullname]
//
//Permission is hereby granted, free of charge, to any person obtaining a copy
//of this software and associated documentation files (the "Software"), to deal
//in the Software without restriction, including without limitation the rights
//to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
//copies of the Software, and to permit persons to whom the Software is
//furnished to do so, subject to the following conditions:
//
//The above copyright notice and this permission notice shall be included in all
//copies or substantial portions of the Software.
//
//THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
//IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
//FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
//AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
//LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
//OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
//SOFTWARE.
package cc.unitmesh.devti.llms.azure

import com.theokanning.openai.OpenAiError
import com.theokanning.openai.OpenAiHttpException
import com.theokanning.openai.service.OpenAiService
import com.theokanning.openai.service.SSE
import com.theokanning.openai.service.SSEFormatException
import io.reactivex.FlowableEmitter
import okhttp3.Call
import okhttp3.Callback
import okhttp3.Response
import org.apache.commons.httpclient.HttpException
import java.io.BufferedReader
import java.io.IOException
import java.io.InputStreamReader
import java.nio.charset.StandardCharsets

/**
* Callback to parse Server Sent Events (SSE) from raw InputStream and
* emit the events with io.reactivex.FlowableEmitter to allow streaming of
* SSE.
*/
class ResponseBodyCallback(private val emitter: FlowableEmitter<SSE>, private val emitDone: Boolean) : Callback {
override fun onResponse(call: Call, response: Response) {
var reader: BufferedReader? = null
try {
if (!response.isSuccessful) {
val e = HttpException(response.body.toString())
val errorBody = response.body
if (errorBody == null) {
throw e
} else {
val error = mapper.readValue(
errorBody.string(),
OpenAiError::class.java
)
throw OpenAiHttpException(error, e, e.reasonCode)
}
}
val inputStream = response.body!!.byteStream()
reader = BufferedReader(InputStreamReader(inputStream, StandardCharsets.UTF_8))
var line = ""
var sse: SSE? = null
while (!emitter.isCancelled && reader.readLine().also { line = it } != null) {
sse = if (line.startsWith("data:")) {
val data = line.substring(5).trim { it <= ' ' }
SSE(data)
} else if (line == "" && sse != null) {
if (sse.isDone) {
if (emitDone) {
emitter.onNext(sse)
}
break
}
emitter.onNext(sse)
null
} else {
throw SSEFormatException("Invalid sse format! $line")
}
}
emitter.onComplete()
} catch (t: Throwable) {
onFailure(call, t as IOException)
} finally {
if (reader != null) {
try {
reader.close()
} catch (e: IOException) {
// do nothing
}
}
}
}

override fun onFailure(call: Call, e: IOException) {
emitter.onError(e)
}

companion object {
private val mapper = OpenAiService.defaultObjectMapper()
}
}

0 comments on commit 962b599

Please sign in to comment.