diff --git a/exts/database/src/main/kotlin/cc/unitmesh/database/actions/GenSqlScriptBySelection.kt b/exts/database/src/main/kotlin/cc/unitmesh/database/actions/GenSqlScriptBySelection.kt index b9c13d78ad..77f36bb039 100644 --- a/exts/database/src/main/kotlin/cc/unitmesh/database/actions/GenSqlScriptBySelection.kt +++ b/exts/database/src/main/kotlin/cc/unitmesh/database/actions/GenSqlScriptBySelection.kt @@ -1,28 +1,22 @@ package cc.unitmesh.database.actions import cc.unitmesh.database.DbContextActionProvider +import cc.unitmesh.database.flow.GenFlowContext +import cc.unitmesh.database.flow.GenSqlFlow +import cc.unitmesh.database.flow.GenSqlTask import cc.unitmesh.devti.AutoDevBundle -import cc.unitmesh.devti.gui.chat.ChatCodingPanel import cc.unitmesh.devti.gui.sendToChatPanel import cc.unitmesh.devti.intentions.action.base.AbstractChatIntention -import cc.unitmesh.devti.llms.LLMProvider import cc.unitmesh.devti.llms.LlmFactory -import cc.unitmesh.devti.template.TemplateRender -import cc.unitmesh.devti.util.LLMCoroutineScope -import cc.unitmesh.devti.util.parser.parseCodeFromString import com.intellij.database.model.ObjectKind import com.intellij.database.psi.DbPsiFacade import com.intellij.database.util.DasUtil -import com.intellij.openapi.command.WriteCommandAction import com.intellij.openapi.diagnostic.logger import com.intellij.openapi.editor.Editor -import com.intellij.openapi.progress.ProgressIndicator import com.intellij.openapi.progress.ProgressManager -import com.intellij.openapi.progress.Task import com.intellij.openapi.progress.impl.BackgroundableProcessIndicator import com.intellij.openapi.project.Project import com.intellij.psi.PsiFile -import kotlinx.coroutines.runBlocking class GenSqlScriptBySelection : AbstractChatIntention() { @@ -54,7 +48,7 @@ class GenSqlScriptBySelection : AbstractChatIntention() { tables.filter { table -> table.kind == ObjectKind.TABLE && table.dasParent?.name == schemaName } }.toList() - val dbContext = DbContext( + val genFlowContext = GenFlowContext( requirement = selectedText ?: "", databaseVersion = databaseVersion.let { "name: ${it.name}, version: ${it.version}" @@ -67,142 +61,12 @@ class GenSqlScriptBySelection : AbstractChatIntention() { sendToChatPanel(project) { contentPanel, _ -> val llmProvider = LlmFactory().create(project) - val prompter = GenSqlFlow(dbContext, actions, contentPanel, llmProvider, project) + val prompter = GenSqlFlow(genFlowContext, actions, contentPanel, llmProvider, project) - val task = generateSqlWorkflow(project, prompter, editor) + val task = GenSqlTask(project, prompter, editor) ProgressManager.getInstance() .runProcessWithProgressAsynchronously(task, BackgroundableProcessIndicator(task)) } } - - private fun generateSqlWorkflow( - project: Project, - flow: GenSqlFlow, - editor: Editor, - ): Task.Backgroundable { - return object : Task.Backgroundable(project, "Gen SQL", true) { - override fun run(indicator: ProgressIndicator) { - indicator.fraction = 0.2 - - indicator.text = AutoDevBundle.message("migration.database.sql.generate.clarify") - val tables = flow.clarify() - - logger.info("Tables: $tables") - // tables will be list in string format, like: `[table1, table2]`, we need to parse to Lists - val tableNames = tables.substringAfter("[").substringBefore("]") - .split(", ").map { it.trim() } - - if (tableNames.isEmpty()) { - indicator.fraction = 1.0 - val allTables = flow.getAllTables() - logger.warn("no table related: $allTables") - return - } - - indicator.fraction = 0.6 - indicator.text = AutoDevBundle.message("migration.database.sql.generate.generate") - val sqlScript = flow.generate(tableNames) - - logger.info("SQL Script: $sqlScript") - WriteCommandAction.runWriteCommandAction(project, "Gen SQL", "cc.unitmesh.livingDoc", { - // new line - editor.document.insertString(editor.caretModel.offset, "\n") - // insert sql script - val code = parseCodeFromString(sqlScript).first() - editor.document.insertString(editor.caretModel.offset + "\n".length, code) - }) - - indicator.fraction = 1.0 - } - } - } -} - -class GenSqlFlow( - val dbContext: DbContext, - val actions: DbContextActionProvider, - val ui: ChatCodingPanel, - val llm: LLMProvider, - val project: Project -) { - private val logger = logger() - - fun clarify(): String { - val stepOnePrompt = generateStepOnePrompt(dbContext, actions) - - LLMCoroutineScope.scope(project).runCatching { - ui.addMessage(stepOnePrompt, true, stepOnePrompt) - ui.addMessage(AutoDevBundle.message("autodev.loading")) - }.onFailure { - logger.warn("Error: $it") - } - - return runBlocking { - val prompt = llm.stream(stepOnePrompt, "") - return@runBlocking ui.updateMessage(prompt) - } - } - - fun generate(tableNames: List): String { - val stepTwoPrompt = generateStepTwoPrompt(dbContext, actions, tableNames) - - LLMCoroutineScope.scope(project).runCatching { - ui.addMessage(stepTwoPrompt, true, stepTwoPrompt) - ui.addMessage(AutoDevBundle.message("autodev.loading")) - }.onFailure { - logger.warn("Error: $it") - } - - return runBlocking { - val prompt = llm.stream(stepTwoPrompt, "") - return@runBlocking ui.updateMessage(prompt) - } - } - - private fun generateStepOnePrompt(context: DbContext, actions: DbContextActionProvider): String { - val templateRender = TemplateRender("genius/sql") - val template = templateRender.getTemplate("sql-gen-clarify.vm") - - templateRender.context = context - templateRender.actions = actions - - val prompter = templateRender.renderTemplate(template) - - logger.info("Prompt: $prompter") - return prompter - } - - private fun generateStepTwoPrompt( - dbContext: DbContext, - actions: DbContextActionProvider, - tableInfos: List - ): String { - val templateRender = TemplateRender("genius/sql") - val template = templateRender.getTemplate("sql-gen-design.vm") - - dbContext.tableInfos = actions.getTableColumns(tableInfos) - - templateRender.context = dbContext - templateRender.actions = actions - - val prompter = templateRender.renderTemplate(template) - - logger.info("Prompt: $prompter") - return prompter - } - - fun getAllTables(): List { - return actions.dasTables.map { it.name } - } } - -data class DbContext( - val requirement: String, - val databaseVersion: String, - val schemaName: String, - val tableNames: List, - // for step 2 - var tableInfos: List = emptyList(), -) - diff --git a/exts/database/src/main/kotlin/cc/unitmesh/database/flow/GenFlowContext.kt b/exts/database/src/main/kotlin/cc/unitmesh/database/flow/GenFlowContext.kt new file mode 100644 index 0000000000..3973817eb1 --- /dev/null +++ b/exts/database/src/main/kotlin/cc/unitmesh/database/flow/GenFlowContext.kt @@ -0,0 +1,10 @@ +package cc.unitmesh.database.flow + +data class GenFlowContext( + val requirement: String, + val databaseVersion: String, + val schemaName: String, + val tableNames: List, + // for step 2 + var tableInfos: List = emptyList(), +) \ No newline at end of file diff --git a/exts/database/src/main/kotlin/cc/unitmesh/database/flow/GenSqlFlow.kt b/exts/database/src/main/kotlin/cc/unitmesh/database/flow/GenSqlFlow.kt new file mode 100644 index 0000000000..35f22621dd --- /dev/null +++ b/exts/database/src/main/kotlin/cc/unitmesh/database/flow/GenSqlFlow.kt @@ -0,0 +1,89 @@ +package cc.unitmesh.database.flow + +import cc.unitmesh.database.DbContextActionProvider +import cc.unitmesh.devti.AutoDevBundle +import cc.unitmesh.devti.gui.chat.ChatCodingPanel +import cc.unitmesh.devti.llms.LLMProvider +import cc.unitmesh.devti.template.TemplateRender +import cc.unitmesh.devti.util.LLMCoroutineScope +import com.intellij.openapi.diagnostic.logger +import com.intellij.openapi.project.Project +import kotlinx.coroutines.runBlocking + +class GenSqlFlow( + val genFlowContext: GenFlowContext, + val actions: DbContextActionProvider, + val ui: ChatCodingPanel, + val llm: LLMProvider, + val project: Project +) { + private val logger = logger() + + fun clarify(): String { + val stepOnePrompt = generateStepOnePrompt(genFlowContext, actions) + + LLMCoroutineScope.scope(project).runCatching { + ui.addMessage(stepOnePrompt, true, stepOnePrompt) + ui.addMessage(AutoDevBundle.message("autodev.loading")) + }.onFailure { + logger.warn("Error: $it") + } + + return runBlocking { + val prompt = llm.stream(stepOnePrompt, "") + return@runBlocking ui.updateMessage(prompt) + } + } + + fun generate(tableNames: List): String { + val stepTwoPrompt = generateStepTwoPrompt(genFlowContext, actions, tableNames) + + LLMCoroutineScope.scope(project).runCatching { + ui.addMessage(stepTwoPrompt, true, stepTwoPrompt) + ui.addMessage(AutoDevBundle.message("autodev.loading")) + }.onFailure { + logger.warn("Error: $it") + } + + return runBlocking { + val prompt = llm.stream(stepTwoPrompt, "") + return@runBlocking ui.updateMessage(prompt) + } + } + + private fun generateStepOnePrompt(context: GenFlowContext, actions: DbContextActionProvider): String { + val templateRender = TemplateRender("genius/sql") + val template = templateRender.getTemplate("sql-gen-clarify.vm") + + templateRender.context = context + templateRender.actions = actions + + val prompter = templateRender.renderTemplate(template) + + logger.info("Prompt: $prompter") + return prompter + } + + private fun generateStepTwoPrompt( + genFlowContext: GenFlowContext, + actions: DbContextActionProvider, + tableInfos: List + ): String { + val templateRender = TemplateRender("genius/sql") + val template = templateRender.getTemplate("sql-gen-design.vm") + + genFlowContext.tableInfos = actions.getTableColumns(tableInfos) + + templateRender.context = genFlowContext + templateRender.actions = actions + + val prompter = templateRender.renderTemplate(template) + + logger.info("Prompt: $prompter") + return prompter + } + + fun getAllTables(): List { + return actions.dasTables.map { it.name } + } +} \ No newline at end of file diff --git a/exts/database/src/main/kotlin/cc/unitmesh/database/flow/GenSqlTask.kt b/exts/database/src/main/kotlin/cc/unitmesh/database/flow/GenSqlTask.kt new file mode 100644 index 0000000000..f3c12e0239 --- /dev/null +++ b/exts/database/src/main/kotlin/cc/unitmesh/database/flow/GenSqlTask.kt @@ -0,0 +1,52 @@ +package cc.unitmesh.database.flow + +import cc.unitmesh.devti.AutoDevBundle +import cc.unitmesh.devti.util.parser.parseCodeFromString +import com.intellij.openapi.command.WriteCommandAction +import com.intellij.openapi.diagnostic.logger +import com.intellij.openapi.editor.Editor +import com.intellij.openapi.progress.ProgressIndicator +import com.intellij.openapi.progress.Task +import com.intellij.openapi.project.Project + +class GenSqlTask( + private val project: Project, + private val flow: GenSqlFlow, + private val editor: Editor +) : Task.Backgroundable(project, "Gen SQL", true) { + private val logger = logger() + + override fun run(indicator: ProgressIndicator) { + indicator.fraction = 0.2 + + indicator.text = AutoDevBundle.message("migration.database.sql.generate.clarify") + val tables = flow.clarify() + + logger.info("Tables: $tables") + // tables will be list in string format, like: `[table1, table2]`, we need to parse to Lists + val tableNames = tables.substringAfter("[").substringBefore("]") + .split(", ").map { it.trim() } + + if (tableNames.isEmpty()) { + indicator.fraction = 1.0 + val allTables = flow.getAllTables() + logger.warn("no table related: $allTables") + return + } + + indicator.fraction = 0.6 + indicator.text = AutoDevBundle.message("migration.database.sql.generate.generate") + val sqlScript = flow.generate(tableNames) + + logger.info("SQL Script: $sqlScript") + WriteCommandAction.runWriteCommandAction(project, "Gen SQL", "cc.unitmesh.livingDoc", { + // new line + editor.document.insertString(editor.caretModel.offset, "\n") + // insert sql script + val code = parseCodeFromString(sqlScript).first() + editor.document.insertString(editor.caretModel.offset + "\n".length, code) + }) + + indicator.fraction = 1.0 + } +} \ No newline at end of file