feat: support stream mode for chatgpt usage.
plateaukao committed Jun 3, 2023
1 parent ea19dbf commit 5010d5c
Showing 4 changed files with 82 additions and 11 deletions.
1 change: 1 addition & 0 deletions app/build.gradle
@@ -129,5 +129,6 @@ dependencies {

// okhttp
implementation 'com.squareup.okhttp3:okhttp:4.10.0'
implementation 'com.squareup.okhttp3:okhttp-sse:4.11.0'
implementation 'org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.0'
}
ConfigManager.kt
@@ -103,6 +103,8 @@ class ConfigManager(

var hideStatusbar by BooleanPreference(sp, K_HIDE_STATUSBAR, false)

var enableOpenAiStream by BooleanPreference(sp, K_ENABLE_OPEN_AI_STREAM, true)

var isIncognitoMode: Boolean
get() = sp.getBoolean(K_IS_INCOGNITO_MODE, false)
set(value) {
@@ -535,6 +537,8 @@ class ConfigManager(
const val K_ENABLE_SAVE_DATA = "sp_enable_save_data"
const val K_HIDE_STATUSBAR = "sp_hide_statusbar"

const val K_ENABLE_OPEN_AI_STREAM = "sp_enable_open_ai_stream"

const val K_SHOW_TRANSLATED_IMAGE_TO_SECOND_PANEL =
"sp_show_translated_image_to_second_panel"

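The new enableOpenAiStream setting defaults to true and is persisted through the same BooleanPreference delegate the existing flags use. For reference, a SharedPreferences-backed Boolean delegate of this kind usually looks like the sketch below; the repository's actual BooleanPreference may differ in detail.

import android.content.SharedPreferences
import kotlin.properties.ReadWriteProperty
import kotlin.reflect.KProperty

// Illustrative sketch only: a SharedPreferences-backed Boolean property delegate,
// assumed to be roughly what ConfigManager's BooleanPreference does.
class BooleanPreference(
    private val sp: SharedPreferences,
    private val key: String,
    private val defaultValue: Boolean,
) : ReadWriteProperty<Any, Boolean> {

    override fun getValue(thisRef: Any, property: KProperty<*>): Boolean =
        sp.getBoolean(key, defaultValue)

    override fun setValue(thisRef: Any, property: KProperty<*>, value: Boolean) {
        sp.edit().putBoolean(key, value).apply()
    }
}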
OpenAiRepository.kt
@@ -10,6 +10,8 @@ import okhttp3.MediaType.Companion.toMediaType
import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.RequestBody.Companion.toRequestBody
import okhttp3.sse.EventSource
import okhttp3.sse.EventSources
import org.koin.core.component.KoinComponent
import java.util.concurrent.TimeUnit
import kotlin.coroutines.resume
@@ -23,21 +25,36 @@ class OpenAiRepository(
.readTimeout(30, TimeUnit.SECONDS)
.writeTimeout(60, TimeUnit.SECONDS)
.build()
private val factory by lazy { EventSources.createFactory(client) }

private val json = Json { ignoreUnknownKeys = true }

fun chatStream(
messages: List<ChatMessage>,
appendResponseAction: (String) -> Unit,
failureAction: () -> Unit,
) {
val request = createRequest(messages, true)

factory.newEventSource(request, object : okhttp3.sse.EventSourceListener() {
override fun onEvent(
eventSource: EventSource, id: String?, type: String?, data: String
) {
if (data.isEmpty() || data == "[DONE]") return
try {
val chatCompletion = json.decodeFromString<ChatCompletionDelta>(data)
appendResponseAction(chatCompletion.choices.first().delta.content ?: "")
} catch (e: Exception) {
failureAction()
}
}
})
}

suspend fun chatCompletion(
messages: List<ChatMessage>
): ChatCompletion? = suspendCoroutine { continuation ->
- val request = Request.Builder()
- .url(endpoint)
- .post(
- json.encodeToString(ChatRequest("gpt-3.5-turbo", messages))
- .toRequestBody(mediaType)
- )
- .header("Authorization", "Bearer $apiKey")
- .build()

+ val request = createRequest(messages)
client.newCall(request).execute().use { response ->
if (response.code != 200 || response.body == null) {
return@use continuation.resume(null)
@@ -54,6 +71,18 @@ }
}
}

private fun createRequest(
messages: List<ChatMessage>,
stream: Boolean = false,
): Request = Request.Builder()
.url(endpoint)
.post(
json.encodeToString(ChatRequest("gpt-3.5-turbo", messages, stream))
.toRequestBody(mediaType)
)
.header("Authorization", "Bearer $apiKey")
.build()

companion object {
private const val endpoint = "https://api.openai.com/v1/chat/completions"
private val mediaType = "application/json; charset=utf-8".toMediaType()
@@ -66,7 +95,15 @@ data class ChatCompletion(
val created: Int,
val model: String,
val choices: List<ChatChoice>,
- val usage: ChatUsage,
+ val usage: ChatUsage = ChatUsage(0, 0, 0)
)

@Serializable
data class ChatCompletionDelta(
val id: String,
val created: Int,
val model: String,
val choices: List<ChatChoiceDelta>,
)

@Serializable
@@ -83,13 +120,26 @@ data class ChatUsage(
data class ChatRequest(
val model: String,
val messages: List<ChatMessage>,
val stream: Boolean = false,
val temperature: Double = 0.5,
)

@Serializable
data class ChatChoiceDelta(
val index: Int,
val delta: ChatDelta,
@kotlinx.serialization.Transient
@SerialName("finish_reason")
val finishReason: String? = null,
)

@Serializable
data class ChatChoice(
val index: Int,
- val message: ChatMessage
+ val message: ChatMessage,
@kotlinx.serialization.Transient
@SerialName("finish_reason")
val finishReason: String? = null,
)

enum class ChatRole {
@@ -103,6 +153,11 @@
Assistant
}

@Serializable
data class ChatDelta(
val content: String? = null,
)

@Serializable
data class ChatMessage(
val content: String,
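With stream = true, the completions endpoint responds as a server-sent event stream: okhttp-sse invokes onEvent once per data: payload, each payload is a chat.completion.chunk JSON object, and a literal [DONE] payload marks the end of the stream, which is why chatStream skips it. The listener in chatStream only overrides onEvent, so transport errors never reach failureAction; EventSourceListener also exposes onOpen, onClosed and onFailure. A sketch of a fuller listener follows; the okhttp-sse method signatures are real, while the class name and callback wiring are illustrative.

import okhttp3.Response
import okhttp3.sse.EventSource
import okhttp3.sse.EventSourceListener

// Sketch of a listener that also surfaces transport-level failures and stream completion.
class StreamingChatListener(
    private val onDelta: (String) -> Unit,   // called for each decoded content fragment
    private val onDone: () -> Unit,          // called when the server closes the stream
    private val onError: () -> Unit,         // called on a parse or transport failure
    private val decode: (String) -> String?, // e.g. { json.decodeFromString<ChatCompletionDelta>(it).choices.first().delta.content }
) : EventSourceListener() {

    override fun onEvent(eventSource: EventSource, id: String?, type: String?, data: String) {
        if (data.isEmpty() || data == "[DONE]") return  // "[DONE]" is OpenAI's end-of-stream sentinel
        try {
            decode(data)?.let(onDelta)
        } catch (e: Exception) {
            onError()
        }
    }

    override fun onClosed(eventSource: EventSource) = onDone()

    override fun onFailure(eventSource: EventSource, t: Throwable?, response: Response?) = onError()
}

Such a listener would be handed to factory.newEventSource(createRequest(messages, true), listener) in place of the anonymous object used in chatStream.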
GptViewModel.kt
@@ -46,6 +46,17 @@ class GptViewModel : ViewModel(), KoinComponent {
messages.add("${config.gptUserPromptPrefix}${_inputMessage.value}".toUserMessage())


// stream case
if (config.enableOpenAiStream) {
openaiRepository.chatStream(
messages,
appendResponseAction = { _responseMessage.value += it },
failureAction = { _responseMessage.value = "Something went wrong." }
)
return
}

// normal case: too slow!!!
viewModelScope.launch(Dispatchers.IO) {
val chatCompletion = openaiRepository.chatCompletion(messages)
if (chatCompletion == null || chatCompletion.choices.isEmpty()) {
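One detail worth noting: onEvent runs on OkHttp's SSE reader thread, so appendResponseAction executes off the main thread. MutableStateFlow may be written from any thread, but _responseMessage.value += it is a separate read and write; if more than one producer could ever append, the update extension from kotlinx.coroutines makes the append atomic. A minimal sketch, assuming the ViewModel's state is a MutableStateFlow of String:

import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.update

// Minimal sketch of the streaming append, independent of the real GptViewModel.
class StreamingResponseState {
    private val _responseMessage = MutableStateFlow("")
    val responseMessage: StateFlow<String> = _responseMessage

    // Safe to call from OkHttp's SSE callback thread: update {} retries on contention,
    // so no appended fragment is lost even if another writer races with it.
    fun append(delta: String) {
        _responseMessage.update { it + delta }
    }

    fun reset() {
        _responseMessage.value = ""
    }
}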
