Skip to content

Commit

Permalink
feat: enable recording datasets works in local
Browse files Browse the repository at this point in the history
  • Loading branch information
phodal committed Dec 31, 2023
1 parent f24cc6a commit 2fcab1c
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 1 deletion.
Expand Up @@ -3,9 +3,15 @@ package cc.unitmesh.devti.llms.azure
import cc.unitmesh.devti.custom.action.CustomPromptConfig
import cc.unitmesh.devti.gui.chat.ChatRole
import cc.unitmesh.devti.llms.LLMProvider
import cc.unitmesh.devti.recording.EmptyRecording
import cc.unitmesh.devti.recording.JsonlRecording
import cc.unitmesh.devti.recording.Recording
import cc.unitmesh.devti.recording.RecordingInstruction
import cc.unitmesh.devti.settings.AutoDevSettingsState
import cc.unitmesh.devti.settings.custom.teamPromptsSettings
import com.fasterxml.jackson.databind.ObjectMapper
import com.intellij.openapi.components.Service
import com.intellij.openapi.components.service
import com.intellij.openapi.diagnostic.logger
import com.intellij.openapi.project.Project
import com.theokanning.openai.completion.chat.ChatCompletionResult
Expand Down Expand Up @@ -53,6 +59,14 @@ class AzureOpenAIProvider(val project: Project) : LLMProvider {
private val maxTokenLength: Int
get() = AutoDevSettingsState.getInstance().fetchMaxTokenLength()

private val recording: Recording
get() {
if (project.teamPromptsSettings.state.recordingInLocal) {
return project.service<JsonlRecording>()
}
return EmptyRecording()
}


init {
val prompts = autoDevSettingsState.customPrompts
Expand Down Expand Up @@ -149,6 +163,8 @@ class AzureOpenAIProvider(val project: Project) : LLMProvider {
call.enqueue(cc.unitmesh.devti.llms.azure.ResponseBodyCallback(emitter, emitDone))
}, BackpressureStrategy.BUFFER)

var output = ""

return callbackFlow {
sseFlowable
.doOnError(Throwable::printStackTrace)
Expand All @@ -161,6 +177,8 @@ class AzureOpenAIProvider(val project: Project) : LLMProvider {
}
}

recording.write(RecordingInstruction(promptText, output))

close()
}
}
Expand Down
19 changes: 18 additions & 1 deletion src/main/kotlin/cc/unitmesh/devti/llms/openai/OpenAIProvider.kt
Expand Up @@ -2,8 +2,14 @@ package cc.unitmesh.devti.llms.openai

import cc.unitmesh.devti.gui.chat.ChatRole
import cc.unitmesh.devti.llms.LLMProvider
import cc.unitmesh.devti.recording.EmptyRecording
import cc.unitmesh.devti.recording.JsonlRecording
import cc.unitmesh.devti.recording.Recording
import cc.unitmesh.devti.recording.RecordingInstruction
import cc.unitmesh.devti.settings.AutoDevSettingsState
import cc.unitmesh.devti.settings.custom.teamPromptsSettings
import com.intellij.openapi.components.Service
import com.intellij.openapi.components.service
import com.intellij.openapi.diagnostic.Logger
import com.intellij.openapi.diagnostic.logger
import com.intellij.openapi.project.Project
Expand Down Expand Up @@ -63,6 +69,14 @@ class OpenAIProvider(val project: Project) : LLMProvider {
private val messages: MutableList<ChatMessage> = ArrayList()
private var historyMessageLength: Int = 0

private val recording: Recording
get() {
if (project.teamPromptsSettings.state.recordingInLocal) {
return project.service<JsonlRecording>()
}
return EmptyRecording()
}

override fun clearMessage() {
messages.clear()
historyMessageLength = 0
Expand All @@ -89,24 +103,27 @@ class OpenAIProvider(val project: Project) : LLMProvider {
clearMessage()
}

var output = ""
val completionRequest = prepareRequest(promptText, systemPrompt)

return callbackFlow {
withContext(Dispatchers.IO) {
service.streamChatCompletion(completionRequest)
.doOnError{ error ->
.doOnError { error ->
logger.error("Error in stream", error)
trySend(error.message ?: "Error occurs")
}
.blockingForEach { response ->
if (response.choices.isNotEmpty()) {
val completion = response.choices[0].message
if (completion != null && completion.content != null) {
output += completion.content
trySend(completion.content)
}
}
}

recording.write(RecordingInstruction(promptText, output))
close()
}
}
Expand Down
7 changes: 7 additions & 0 deletions src/main/kotlin/cc/unitmesh/devti/recording/EmptyRecording.kt
@@ -0,0 +1,7 @@
package cc.unitmesh.devti.recording

class EmptyRecording: Recording {
override fun write(instruction: RecordingInstruction) {
// do nothing
}
}
21 changes: 21 additions & 0 deletions src/main/kotlin/cc/unitmesh/devti/recording/JsonlRecording.kt
@@ -0,0 +1,21 @@
package cc.unitmesh.devti.recording

import com.intellij.openapi.components.Service
import com.intellij.openapi.project.Project
import com.intellij.openapi.project.guessProjectDir
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import java.nio.file.Path

@Service(Service.Level.PROJECT)
class JsonlRecording(val project: Project) : Recording {
private val recordingPath: Path = Path.of(project.guessProjectDir()!!.path, "recording.jsonl")
override fun write(instruction: RecordingInstruction) {
if (!recordingPath.toFile().exists()) {
recordingPath.toFile().createNewFile()
}

recordingPath.toFile().appendText(Json.encodeToString(instruction) + "\n")
}
}

5 changes: 5 additions & 0 deletions src/main/kotlin/cc/unitmesh/devti/recording/Recording.kt
@@ -0,0 +1,5 @@
package cc.unitmesh.devti.recording

interface Recording {
fun write(instruction: RecordingInstruction)
}
@@ -0,0 +1,9 @@
package cc.unitmesh.devti.recording

import kotlinx.serialization.Serializable

@Serializable
data class RecordingInstruction(
val instruction: String,
val output: String,
)

0 comments on commit 2fcab1c

Please sign in to comment.