Skip to content

Commit

Permalink
fix(genius): update SQL generation templates
Browse files Browse the repository at this point in the history
The SQL generation templates in the `genius` package were updated to improve clarity and design.

- Modified the `sql-gen-clarify.vm` template to include the phrase "User requirements" instead of just "requirements".
- Modified the `sql-gen-design.vm` template to use Markdown syntax for SQL and removed the unnecessary explanation.

These changes were made to enhance the user experience and make the generated SQL scripts more readable.
  • Loading branch information
phodal committed Jan 24, 2024
1 parent 0d2e6ce commit 20735ee
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 21 deletions.
Expand Up @@ -12,7 +12,7 @@ import com.intellij.database.model.ObjectKind
import com.intellij.database.psi.DbPsiFacade
import com.intellij.database.util.DasUtil
import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.application.ReadAction
import com.intellij.openapi.command.WriteCommandAction
import com.intellij.openapi.diagnostic.logger
import com.intellij.openapi.editor.Editor
import com.intellij.openapi.progress.ProgressIndicator
Expand Down Expand Up @@ -69,34 +69,37 @@ class GenSqlScriptBySelection : AbstractChatIntention() {
ApplicationManager.getApplication().invokeLater {

ProgressManager.getInstance()
.run(generateSqlWorkflow(project, contentPanel, prompter))
.run(generateSqlWorkflow(project, prompter, editor))
}
}
}

private fun generateSqlWorkflow(
project: Project,
ui: ChatCodingPanel,
flow: GenSqlFlow,
editor: Editor,
) =
object : Task.Backgroundable(project, "Loading retained test failure", true) {
override fun run(indicator: ProgressIndicator) {
indicator.fraction = 0.2


indicator.text = AutoDevBundle.message("migration.database.sql.generate.clarify")
val tables = ReadAction.compute<String, Throwable> {
flow.clarify()
}
val tables = flow.clarify()

logger.info("Tables: $tables")
// tables will be list in string format, like: `[table1, table2]`, we need to parse to Lists
val tableNames = tables.substringAfter("[").substringBefore("]")
.split(", ").map { it.trim() }

indicator.fraction = 0.6
indicator.text = AutoDevBundle.message("migration.database.sql.generate.generate")
val sqlScript = flow.generate(tableNames)

logger.info("SQL Script: $sqlScript")
WriteCommandAction.runWriteCommandAction(project, "Gen SQL", "cc.unitmesh.livingDoc", {
editor.document.insertString(editor.caretModel.offset, sqlScript)
})

indicator.fraction = 1.0
}
Expand All @@ -113,9 +116,13 @@ class GenSqlFlow(

fun clarify(): String {
val stepOnePrompt = generateStepOnePrompt(dbContext, actions)
ui.addMessage(stepOnePrompt, true, stepOnePrompt)
// for answer
ui.addMessage(AutoDevBundle.message("autodev.loading"))
try {
ui.addMessage(stepOnePrompt, true, stepOnePrompt)
// for answer
ui.addMessage(AutoDevBundle.message("autodev.loading"))
} catch (e: Exception) {
logger.error("Error: $e")
}

return runBlocking {
val prompt = llm.stream(stepOnePrompt, "")
Expand All @@ -125,17 +132,20 @@ class GenSqlFlow(

fun generate(tableNames: List<String>): String {
val stepTwoPrompt = generateStepTwoPrompt(dbContext, actions, tableNames)
ui.addMessage(stepTwoPrompt, true, stepTwoPrompt)
// for answer
ui.addMessage(AutoDevBundle.message("autodev.loading"))
try {
ui.addMessage(stepTwoPrompt, true, stepTwoPrompt)
// for answer
ui.addMessage(AutoDevBundle.message("autodev.loading"))
} catch (e: Exception) {
logger.error("Error: $e")
}

return runBlocking {
val prompt = llm.stream(stepTwoPrompt, "")
return@runBlocking ui.updateMessage(prompt)
}
}


private fun generateStepOnePrompt(context: DbContext, actions: DbContextActionProvider): String {
val templateRender = TemplateRender("genius/sql")
val template = templateRender.getTemplate("sql-gen-clarify.vm")
Expand All @@ -155,7 +165,7 @@ class GenSqlFlow(
tableInfos: List<String>
): String {
val templateRender = TemplateRender("genius/sql")
val template = templateRender.getTemplate("sql-gen-generate.vm")
val template = templateRender.getTemplate("sql-gen-design.vm")

dbContext.tableInfos = actions.getTableColumns(tableInfos)

Expand Down
8 changes: 4 additions & 4 deletions src/main/resources/genius/sql/sql-gen-clarify.vm
Expand Up @@ -8,14 +8,14 @@ According to the user's requirements, you should choose the best Tables for the
For example:

- Question(requirements): calculate the average trip length by subscriber type.// User tables: trips, users, subscriber_type
- Answer: [trips, subscriber_type]
- You should anwser: [trips, subscriber_type]

----

Here are the requirements:
Here are the User requirements:

```
```markdown
${context.requirement}
```
```markdown

Please choose the best Tables for the user, just return the table names in a list, no explain.
5 changes: 2 additions & 3 deletions src/main/resources/genius/sql/sql-gen-design.vm
Expand Up @@ -18,10 +18,9 @@ select average_trip_length from subscriber_type where subscriber_type = 'subscri

Here are the requirements:

```
```markdown
${context.requirement}
```

Please write your SQL here:
Please write your SQL with Markdown syntax, no explanation is needed. :

```sql

0 comments on commit 20735ee

Please sign in to comment.