generated from JetBrains/intellij-platform-plugin-template
-
Notifications
You must be signed in to change notification settings - Fork 295
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor(database): split GenSqlFlow and GenFlowContext classes
- Loading branch information
Showing
4 changed files
with
157 additions
and
142 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
10 changes: 10 additions & 0 deletions
10
exts/database/src/main/kotlin/cc/unitmesh/database/flow/GenFlowContext.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
package cc.unitmesh.database.flow | ||
|
||
data class GenFlowContext( | ||
val requirement: String, | ||
val databaseVersion: String, | ||
val schemaName: String, | ||
val tableNames: List<String>, | ||
// for step 2 | ||
var tableInfos: List<String> = emptyList(), | ||
) |
89 changes: 89 additions & 0 deletions
89
exts/database/src/main/kotlin/cc/unitmesh/database/flow/GenSqlFlow.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
package cc.unitmesh.database.flow | ||
|
||
import cc.unitmesh.database.DbContextActionProvider | ||
import cc.unitmesh.devti.AutoDevBundle | ||
import cc.unitmesh.devti.gui.chat.ChatCodingPanel | ||
import cc.unitmesh.devti.llms.LLMProvider | ||
import cc.unitmesh.devti.template.TemplateRender | ||
import cc.unitmesh.devti.util.LLMCoroutineScope | ||
import com.intellij.openapi.diagnostic.logger | ||
import com.intellij.openapi.project.Project | ||
import kotlinx.coroutines.runBlocking | ||
|
||
class GenSqlFlow( | ||
val genFlowContext: GenFlowContext, | ||
val actions: DbContextActionProvider, | ||
val ui: ChatCodingPanel, | ||
val llm: LLMProvider, | ||
val project: Project | ||
) { | ||
private val logger = logger<GenSqlFlow>() | ||
|
||
fun clarify(): String { | ||
val stepOnePrompt = generateStepOnePrompt(genFlowContext, actions) | ||
|
||
LLMCoroutineScope.scope(project).runCatching { | ||
ui.addMessage(stepOnePrompt, true, stepOnePrompt) | ||
ui.addMessage(AutoDevBundle.message("autodev.loading")) | ||
}.onFailure { | ||
logger.warn("Error: $it") | ||
} | ||
|
||
return runBlocking { | ||
val prompt = llm.stream(stepOnePrompt, "") | ||
return@runBlocking ui.updateMessage(prompt) | ||
} | ||
} | ||
|
||
fun generate(tableNames: List<String>): String { | ||
val stepTwoPrompt = generateStepTwoPrompt(genFlowContext, actions, tableNames) | ||
|
||
LLMCoroutineScope.scope(project).runCatching { | ||
ui.addMessage(stepTwoPrompt, true, stepTwoPrompt) | ||
ui.addMessage(AutoDevBundle.message("autodev.loading")) | ||
}.onFailure { | ||
logger.warn("Error: $it") | ||
} | ||
|
||
return runBlocking { | ||
val prompt = llm.stream(stepTwoPrompt, "") | ||
return@runBlocking ui.updateMessage(prompt) | ||
} | ||
} | ||
|
||
private fun generateStepOnePrompt(context: GenFlowContext, actions: DbContextActionProvider): String { | ||
val templateRender = TemplateRender("genius/sql") | ||
val template = templateRender.getTemplate("sql-gen-clarify.vm") | ||
|
||
templateRender.context = context | ||
templateRender.actions = actions | ||
|
||
val prompter = templateRender.renderTemplate(template) | ||
|
||
logger.info("Prompt: $prompter") | ||
return prompter | ||
} | ||
|
||
private fun generateStepTwoPrompt( | ||
genFlowContext: GenFlowContext, | ||
actions: DbContextActionProvider, | ||
tableInfos: List<String> | ||
): String { | ||
val templateRender = TemplateRender("genius/sql") | ||
val template = templateRender.getTemplate("sql-gen-design.vm") | ||
|
||
genFlowContext.tableInfos = actions.getTableColumns(tableInfos) | ||
|
||
templateRender.context = genFlowContext | ||
templateRender.actions = actions | ||
|
||
val prompter = templateRender.renderTemplate(template) | ||
|
||
logger.info("Prompt: $prompter") | ||
return prompter | ||
} | ||
|
||
fun getAllTables(): List<String> { | ||
return actions.dasTables.map { it.name } | ||
} | ||
} |
52 changes: 52 additions & 0 deletions
52
exts/database/src/main/kotlin/cc/unitmesh/database/flow/GenSqlTask.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
package cc.unitmesh.database.flow | ||
|
||
import cc.unitmesh.devti.AutoDevBundle | ||
import cc.unitmesh.devti.util.parser.parseCodeFromString | ||
import com.intellij.openapi.command.WriteCommandAction | ||
import com.intellij.openapi.diagnostic.logger | ||
import com.intellij.openapi.editor.Editor | ||
import com.intellij.openapi.progress.ProgressIndicator | ||
import com.intellij.openapi.progress.Task | ||
import com.intellij.openapi.project.Project | ||
|
||
class GenSqlTask( | ||
private val project: Project, | ||
private val flow: GenSqlFlow, | ||
private val editor: Editor | ||
) : Task.Backgroundable(project, "Gen SQL", true) { | ||
private val logger = logger<GenSqlTask>() | ||
|
||
override fun run(indicator: ProgressIndicator) { | ||
indicator.fraction = 0.2 | ||
|
||
indicator.text = AutoDevBundle.message("migration.database.sql.generate.clarify") | ||
val tables = flow.clarify() | ||
|
||
logger.info("Tables: $tables") | ||
// tables will be list in string format, like: `[table1, table2]`, we need to parse to Lists | ||
val tableNames = tables.substringAfter("[").substringBefore("]") | ||
.split(", ").map { it.trim() } | ||
|
||
if (tableNames.isEmpty()) { | ||
indicator.fraction = 1.0 | ||
val allTables = flow.getAllTables() | ||
logger.warn("no table related: $allTables") | ||
return | ||
} | ||
|
||
indicator.fraction = 0.6 | ||
indicator.text = AutoDevBundle.message("migration.database.sql.generate.generate") | ||
val sqlScript = flow.generate(tableNames) | ||
|
||
logger.info("SQL Script: $sqlScript") | ||
WriteCommandAction.runWriteCommandAction(project, "Gen SQL", "cc.unitmesh.livingDoc", { | ||
// new line | ||
editor.document.insertString(editor.caretModel.offset, "\n") | ||
// insert sql script | ||
val code = parseCodeFromString(sqlScript).first() | ||
editor.document.insertString(editor.caretModel.offset + "\n".length, code) | ||
}) | ||
|
||
indicator.fraction = 1.0 | ||
} | ||
} |