From 2fcab1c532933f461627504eafa11de48098d27a Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Sun, 31 Dec 2023 16:23:46 +0800 Subject: [PATCH] feat: enable recording datasets works in local --- .../devti/llms/azure/AzureOpenAIProvider.kt | 18 ++++++++++++++++ .../devti/llms/openai/OpenAIProvider.kt | 19 ++++++++++++++++- .../devti/recording/EmptyRecording.kt | 7 +++++++ .../devti/recording/JsonlRecording.kt | 21 +++++++++++++++++++ .../cc/unitmesh/devti/recording/Recording.kt | 5 +++++ .../devti/recording/RecordingInstruction.kt | 9 ++++++++ 6 files changed, 78 insertions(+), 1 deletion(-) create mode 100644 src/main/kotlin/cc/unitmesh/devti/recording/EmptyRecording.kt create mode 100644 src/main/kotlin/cc/unitmesh/devti/recording/JsonlRecording.kt create mode 100644 src/main/kotlin/cc/unitmesh/devti/recording/Recording.kt create mode 100644 src/main/kotlin/cc/unitmesh/devti/recording/RecordingInstruction.kt diff --git a/src/main/kotlin/cc/unitmesh/devti/llms/azure/AzureOpenAIProvider.kt b/src/main/kotlin/cc/unitmesh/devti/llms/azure/AzureOpenAIProvider.kt index 2201630ac1..587c326f39 100644 --- a/src/main/kotlin/cc/unitmesh/devti/llms/azure/AzureOpenAIProvider.kt +++ b/src/main/kotlin/cc/unitmesh/devti/llms/azure/AzureOpenAIProvider.kt @@ -3,9 +3,15 @@ package cc.unitmesh.devti.llms.azure import cc.unitmesh.devti.custom.action.CustomPromptConfig import cc.unitmesh.devti.gui.chat.ChatRole import cc.unitmesh.devti.llms.LLMProvider +import cc.unitmesh.devti.recording.EmptyRecording +import cc.unitmesh.devti.recording.JsonlRecording +import cc.unitmesh.devti.recording.Recording +import cc.unitmesh.devti.recording.RecordingInstruction import cc.unitmesh.devti.settings.AutoDevSettingsState +import cc.unitmesh.devti.settings.custom.teamPromptsSettings import com.fasterxml.jackson.databind.ObjectMapper import com.intellij.openapi.components.Service +import com.intellij.openapi.components.service import com.intellij.openapi.diagnostic.logger import com.intellij.openapi.project.Project import com.theokanning.openai.completion.chat.ChatCompletionResult @@ -53,6 +59,14 @@ class AzureOpenAIProvider(val project: Project) : LLMProvider { private val maxTokenLength: Int get() = AutoDevSettingsState.getInstance().fetchMaxTokenLength() + private val recording: Recording + get() { + if (project.teamPromptsSettings.state.recordingInLocal) { + return project.service() + } + return EmptyRecording() + } + init { val prompts = autoDevSettingsState.customPrompts @@ -149,6 +163,8 @@ class AzureOpenAIProvider(val project: Project) : LLMProvider { call.enqueue(cc.unitmesh.devti.llms.azure.ResponseBodyCallback(emitter, emitDone)) }, BackpressureStrategy.BUFFER) + var output = "" + return callbackFlow { sseFlowable .doOnError(Throwable::printStackTrace) @@ -161,6 +177,8 @@ class AzureOpenAIProvider(val project: Project) : LLMProvider { } } + recording.write(RecordingInstruction(promptText, output)) + close() } } diff --git a/src/main/kotlin/cc/unitmesh/devti/llms/openai/OpenAIProvider.kt b/src/main/kotlin/cc/unitmesh/devti/llms/openai/OpenAIProvider.kt index b3c929860e..c597e65284 100644 --- a/src/main/kotlin/cc/unitmesh/devti/llms/openai/OpenAIProvider.kt +++ b/src/main/kotlin/cc/unitmesh/devti/llms/openai/OpenAIProvider.kt @@ -2,8 +2,14 @@ package cc.unitmesh.devti.llms.openai import cc.unitmesh.devti.gui.chat.ChatRole import cc.unitmesh.devti.llms.LLMProvider +import cc.unitmesh.devti.recording.EmptyRecording +import cc.unitmesh.devti.recording.JsonlRecording +import cc.unitmesh.devti.recording.Recording +import cc.unitmesh.devti.recording.RecordingInstruction import cc.unitmesh.devti.settings.AutoDevSettingsState +import cc.unitmesh.devti.settings.custom.teamPromptsSettings import com.intellij.openapi.components.Service +import com.intellij.openapi.components.service import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.diagnostic.logger import com.intellij.openapi.project.Project @@ -63,6 +69,14 @@ class OpenAIProvider(val project: Project) : LLMProvider { private val messages: MutableList = ArrayList() private var historyMessageLength: Int = 0 + private val recording: Recording + get() { + if (project.teamPromptsSettings.state.recordingInLocal) { + return project.service() + } + return EmptyRecording() + } + override fun clearMessage() { messages.clear() historyMessageLength = 0 @@ -89,12 +103,13 @@ class OpenAIProvider(val project: Project) : LLMProvider { clearMessage() } + var output = "" val completionRequest = prepareRequest(promptText, systemPrompt) return callbackFlow { withContext(Dispatchers.IO) { service.streamChatCompletion(completionRequest) - .doOnError{ error -> + .doOnError { error -> logger.error("Error in stream", error) trySend(error.message ?: "Error occurs") } @@ -102,11 +117,13 @@ class OpenAIProvider(val project: Project) : LLMProvider { if (response.choices.isNotEmpty()) { val completion = response.choices[0].message if (completion != null && completion.content != null) { + output += completion.content trySend(completion.content) } } } + recording.write(RecordingInstruction(promptText, output)) close() } } diff --git a/src/main/kotlin/cc/unitmesh/devti/recording/EmptyRecording.kt b/src/main/kotlin/cc/unitmesh/devti/recording/EmptyRecording.kt new file mode 100644 index 0000000000..42bf01e4c0 --- /dev/null +++ b/src/main/kotlin/cc/unitmesh/devti/recording/EmptyRecording.kt @@ -0,0 +1,7 @@ +package cc.unitmesh.devti.recording + +class EmptyRecording: Recording { + override fun write(instruction: RecordingInstruction) { + // do nothing + } +} \ No newline at end of file diff --git a/src/main/kotlin/cc/unitmesh/devti/recording/JsonlRecording.kt b/src/main/kotlin/cc/unitmesh/devti/recording/JsonlRecording.kt new file mode 100644 index 0000000000..6afa4fa137 --- /dev/null +++ b/src/main/kotlin/cc/unitmesh/devti/recording/JsonlRecording.kt @@ -0,0 +1,21 @@ +package cc.unitmesh.devti.recording + +import com.intellij.openapi.components.Service +import com.intellij.openapi.project.Project +import com.intellij.openapi.project.guessProjectDir +import kotlinx.serialization.encodeToString +import kotlinx.serialization.json.Json +import java.nio.file.Path + +@Service(Service.Level.PROJECT) +class JsonlRecording(val project: Project) : Recording { + private val recordingPath: Path = Path.of(project.guessProjectDir()!!.path, "recording.jsonl") + override fun write(instruction: RecordingInstruction) { + if (!recordingPath.toFile().exists()) { + recordingPath.toFile().createNewFile() + } + + recordingPath.toFile().appendText(Json.encodeToString(instruction) + "\n") + } +} + diff --git a/src/main/kotlin/cc/unitmesh/devti/recording/Recording.kt b/src/main/kotlin/cc/unitmesh/devti/recording/Recording.kt new file mode 100644 index 0000000000..71ad45bbe5 --- /dev/null +++ b/src/main/kotlin/cc/unitmesh/devti/recording/Recording.kt @@ -0,0 +1,5 @@ +package cc.unitmesh.devti.recording + +interface Recording { + fun write(instruction: RecordingInstruction) +} diff --git a/src/main/kotlin/cc/unitmesh/devti/recording/RecordingInstruction.kt b/src/main/kotlin/cc/unitmesh/devti/recording/RecordingInstruction.kt new file mode 100644 index 0000000000..dee5561d4c --- /dev/null +++ b/src/main/kotlin/cc/unitmesh/devti/recording/RecordingInstruction.kt @@ -0,0 +1,9 @@ +package cc.unitmesh.devti.recording + +import kotlinx.serialization.Serializable + +@Serializable +data class RecordingInstruction( + val instruction: String, + val output: String, +) \ No newline at end of file