From 20735eef361ee7b7a5e23b4b5f2917a274e1086d Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Wed, 24 Jan 2024 17:25:33 +0800 Subject: [PATCH] fix(genius): update SQL generation templates The SQL generation templates in the `genius` package were updated to improve clarity and design. - Modified the `sql-gen-clarify.vm` template to include the phrase "User requirements" instead of just "requirements". - Modified the `sql-gen-design.vm` template to use Markdown syntax for SQL and removed the unnecessary explanation. These changes were made to enhance the user experience and make the generated SQL scripts more readable. --- .../actions/GenSqlScriptBySelection.kt | 38 ++++++++++++------- .../resources/genius/sql/sql-gen-clarify.vm | 8 ++-- .../resources/genius/sql/sql-gen-design.vm | 5 +-- 3 files changed, 30 insertions(+), 21 deletions(-) diff --git a/exts/database/src/main/kotlin/cc/unitmesh/database/actions/GenSqlScriptBySelection.kt b/exts/database/src/main/kotlin/cc/unitmesh/database/actions/GenSqlScriptBySelection.kt index ce6e418486..3b4b64c5f0 100644 --- a/exts/database/src/main/kotlin/cc/unitmesh/database/actions/GenSqlScriptBySelection.kt +++ b/exts/database/src/main/kotlin/cc/unitmesh/database/actions/GenSqlScriptBySelection.kt @@ -12,7 +12,7 @@ import com.intellij.database.model.ObjectKind import com.intellij.database.psi.DbPsiFacade import com.intellij.database.util.DasUtil import com.intellij.openapi.application.ApplicationManager -import com.intellij.openapi.application.ReadAction +import com.intellij.openapi.command.WriteCommandAction import com.intellij.openapi.diagnostic.logger import com.intellij.openapi.editor.Editor import com.intellij.openapi.progress.ProgressIndicator @@ -69,15 +69,15 @@ class GenSqlScriptBySelection : AbstractChatIntention() { ApplicationManager.getApplication().invokeLater { ProgressManager.getInstance() - .run(generateSqlWorkflow(project, contentPanel, prompter)) + .run(generateSqlWorkflow(project, prompter, editor)) } } } private fun generateSqlWorkflow( project: Project, - ui: ChatCodingPanel, flow: GenSqlFlow, + editor: Editor, ) = object : Task.Backgroundable(project, "Loading retained test failure", true) { override fun run(indicator: ProgressIndicator) { @@ -85,18 +85,21 @@ class GenSqlScriptBySelection : AbstractChatIntention() { indicator.text = AutoDevBundle.message("migration.database.sql.generate.clarify") - val tables = ReadAction.compute { - flow.clarify() - } + val tables = flow.clarify() + logger.info("Tables: $tables") // tables will be list in string format, like: `[table1, table2]`, we need to parse to Lists val tableNames = tables.substringAfter("[").substringBefore("]") .split(", ").map { it.trim() } indicator.fraction = 0.6 + indicator.text = AutoDevBundle.message("migration.database.sql.generate.generate") val sqlScript = flow.generate(tableNames) logger.info("SQL Script: $sqlScript") + WriteCommandAction.runWriteCommandAction(project, "Gen SQL", "cc.unitmesh.livingDoc", { + editor.document.insertString(editor.caretModel.offset, sqlScript) + }) indicator.fraction = 1.0 } @@ -113,9 +116,13 @@ class GenSqlFlow( fun clarify(): String { val stepOnePrompt = generateStepOnePrompt(dbContext, actions) - ui.addMessage(stepOnePrompt, true, stepOnePrompt) - // for answer - ui.addMessage(AutoDevBundle.message("autodev.loading")) + try { + ui.addMessage(stepOnePrompt, true, stepOnePrompt) + // for answer + ui.addMessage(AutoDevBundle.message("autodev.loading")) + } catch (e: Exception) { + logger.error("Error: $e") + } return runBlocking { val prompt = llm.stream(stepOnePrompt, "") @@ -125,9 +132,13 @@ class GenSqlFlow( fun generate(tableNames: List): String { val stepTwoPrompt = generateStepTwoPrompt(dbContext, actions, tableNames) - ui.addMessage(stepTwoPrompt, true, stepTwoPrompt) - // for answer - ui.addMessage(AutoDevBundle.message("autodev.loading")) + try { + ui.addMessage(stepTwoPrompt, true, stepTwoPrompt) + // for answer + ui.addMessage(AutoDevBundle.message("autodev.loading")) + } catch (e: Exception) { + logger.error("Error: $e") + } return runBlocking { val prompt = llm.stream(stepTwoPrompt, "") @@ -135,7 +146,6 @@ class GenSqlFlow( } } - private fun generateStepOnePrompt(context: DbContext, actions: DbContextActionProvider): String { val templateRender = TemplateRender("genius/sql") val template = templateRender.getTemplate("sql-gen-clarify.vm") @@ -155,7 +165,7 @@ class GenSqlFlow( tableInfos: List ): String { val templateRender = TemplateRender("genius/sql") - val template = templateRender.getTemplate("sql-gen-generate.vm") + val template = templateRender.getTemplate("sql-gen-design.vm") dbContext.tableInfos = actions.getTableColumns(tableInfos) diff --git a/src/main/resources/genius/sql/sql-gen-clarify.vm b/src/main/resources/genius/sql/sql-gen-clarify.vm index 5e97e92447..8b177b882b 100644 --- a/src/main/resources/genius/sql/sql-gen-clarify.vm +++ b/src/main/resources/genius/sql/sql-gen-clarify.vm @@ -8,14 +8,14 @@ According to the user's requirements, you should choose the best Tables for the For example: - Question(requirements): calculate the average trip length by subscriber type.// User tables: trips, users, subscriber_type -- Answer: [trips, subscriber_type] +- You should anwser: [trips, subscriber_type] ---- -Here are the requirements: +Here are the User requirements: -``` +```markdown ${context.requirement} -``` +```markdown Please choose the best Tables for the user, just return the table names in a list, no explain. diff --git a/src/main/resources/genius/sql/sql-gen-design.vm b/src/main/resources/genius/sql/sql-gen-design.vm index d39e0bb503..6ad4bce1e8 100644 --- a/src/main/resources/genius/sql/sql-gen-design.vm +++ b/src/main/resources/genius/sql/sql-gen-design.vm @@ -18,10 +18,9 @@ select average_trip_length from subscriber_type where subscriber_type = 'subscri Here are the requirements: -``` +```markdown ${context.requirement} ``` -Please write your SQL here: +Please write your SQL with Markdown syntax, no explanation is needed. : -```sql