Skip to content

Commit

Permalink
refactor(database): split GenSqlFlow and GenFlowContext classes
Browse files Browse the repository at this point in the history
  • Loading branch information
phodal committed Jan 24, 2024
1 parent 927939f commit 95b3272
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 142 deletions.
Original file line number Diff line number Diff line change
@@ -1,28 +1,22 @@
package cc.unitmesh.database.actions

import cc.unitmesh.database.DbContextActionProvider
import cc.unitmesh.database.flow.GenFlowContext
import cc.unitmesh.database.flow.GenSqlFlow
import cc.unitmesh.database.flow.GenSqlTask
import cc.unitmesh.devti.AutoDevBundle
import cc.unitmesh.devti.gui.chat.ChatCodingPanel
import cc.unitmesh.devti.gui.sendToChatPanel
import cc.unitmesh.devti.intentions.action.base.AbstractChatIntention
import cc.unitmesh.devti.llms.LLMProvider
import cc.unitmesh.devti.llms.LlmFactory
import cc.unitmesh.devti.template.TemplateRender
import cc.unitmesh.devti.util.LLMCoroutineScope
import cc.unitmesh.devti.util.parser.parseCodeFromString
import com.intellij.database.model.ObjectKind
import com.intellij.database.psi.DbPsiFacade
import com.intellij.database.util.DasUtil
import com.intellij.openapi.command.WriteCommandAction
import com.intellij.openapi.diagnostic.logger
import com.intellij.openapi.editor.Editor
import com.intellij.openapi.progress.ProgressIndicator
import com.intellij.openapi.progress.ProgressManager
import com.intellij.openapi.progress.Task
import com.intellij.openapi.progress.impl.BackgroundableProcessIndicator
import com.intellij.openapi.project.Project
import com.intellij.psi.PsiFile
import kotlinx.coroutines.runBlocking


class GenSqlScriptBySelection : AbstractChatIntention() {
Expand Down Expand Up @@ -54,7 +48,7 @@ class GenSqlScriptBySelection : AbstractChatIntention() {
tables.filter { table -> table.kind == ObjectKind.TABLE && table.dasParent?.name == schemaName }
}.toList()

val dbContext = DbContext(
val genFlowContext = GenFlowContext(
requirement = selectedText ?: "",
databaseVersion = databaseVersion.let {
"name: ${it.name}, version: ${it.version}"
Expand All @@ -67,142 +61,12 @@ class GenSqlScriptBySelection : AbstractChatIntention() {

sendToChatPanel(project) { contentPanel, _ ->
val llmProvider = LlmFactory().create(project)
val prompter = GenSqlFlow(dbContext, actions, contentPanel, llmProvider, project)
val prompter = GenSqlFlow(genFlowContext, actions, contentPanel, llmProvider, project)

val task = generateSqlWorkflow(project, prompter, editor)
val task = GenSqlTask(project, prompter, editor)
ProgressManager.getInstance()
.runProcessWithProgressAsynchronously(task, BackgroundableProcessIndicator(task))
}
}

private fun generateSqlWorkflow(
project: Project,
flow: GenSqlFlow,
editor: Editor,
): Task.Backgroundable {
return object : Task.Backgroundable(project, "Gen SQL", true) {
override fun run(indicator: ProgressIndicator) {
indicator.fraction = 0.2

indicator.text = AutoDevBundle.message("migration.database.sql.generate.clarify")
val tables = flow.clarify()

logger.info("Tables: $tables")
// tables will be list in string format, like: `[table1, table2]`, we need to parse to Lists
val tableNames = tables.substringAfter("[").substringBefore("]")
.split(", ").map { it.trim() }

if (tableNames.isEmpty()) {
indicator.fraction = 1.0
val allTables = flow.getAllTables()
logger.warn("no table related: $allTables")
return
}

indicator.fraction = 0.6
indicator.text = AutoDevBundle.message("migration.database.sql.generate.generate")
val sqlScript = flow.generate(tableNames)

logger.info("SQL Script: $sqlScript")
WriteCommandAction.runWriteCommandAction(project, "Gen SQL", "cc.unitmesh.livingDoc", {
// new line
editor.document.insertString(editor.caretModel.offset, "\n")
// insert sql script
val code = parseCodeFromString(sqlScript).first()
editor.document.insertString(editor.caretModel.offset + "\n".length, code)
})

indicator.fraction = 1.0
}
}
}
}

class GenSqlFlow(
val dbContext: DbContext,
val actions: DbContextActionProvider,
val ui: ChatCodingPanel,
val llm: LLMProvider,
val project: Project
) {
private val logger = logger<GenSqlFlow>()

fun clarify(): String {
val stepOnePrompt = generateStepOnePrompt(dbContext, actions)

LLMCoroutineScope.scope(project).runCatching {
ui.addMessage(stepOnePrompt, true, stepOnePrompt)
ui.addMessage(AutoDevBundle.message("autodev.loading"))
}.onFailure {
logger.warn("Error: $it")
}

return runBlocking {
val prompt = llm.stream(stepOnePrompt, "")
return@runBlocking ui.updateMessage(prompt)
}
}

fun generate(tableNames: List<String>): String {
val stepTwoPrompt = generateStepTwoPrompt(dbContext, actions, tableNames)

LLMCoroutineScope.scope(project).runCatching {
ui.addMessage(stepTwoPrompt, true, stepTwoPrompt)
ui.addMessage(AutoDevBundle.message("autodev.loading"))
}.onFailure {
logger.warn("Error: $it")
}

return runBlocking {
val prompt = llm.stream(stepTwoPrompt, "")
return@runBlocking ui.updateMessage(prompt)
}
}

private fun generateStepOnePrompt(context: DbContext, actions: DbContextActionProvider): String {
val templateRender = TemplateRender("genius/sql")
val template = templateRender.getTemplate("sql-gen-clarify.vm")

templateRender.context = context
templateRender.actions = actions

val prompter = templateRender.renderTemplate(template)

logger.info("Prompt: $prompter")
return prompter
}

private fun generateStepTwoPrompt(
dbContext: DbContext,
actions: DbContextActionProvider,
tableInfos: List<String>
): String {
val templateRender = TemplateRender("genius/sql")
val template = templateRender.getTemplate("sql-gen-design.vm")

dbContext.tableInfos = actions.getTableColumns(tableInfos)

templateRender.context = dbContext
templateRender.actions = actions

val prompter = templateRender.renderTemplate(template)

logger.info("Prompt: $prompter")
return prompter
}

fun getAllTables(): List<String> {
return actions.dasTables.map { it.name }
}
}


data class DbContext(
val requirement: String,
val databaseVersion: String,
val schemaName: String,
val tableNames: List<String>,
// for step 2
var tableInfos: List<String> = emptyList(),
)

Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package cc.unitmesh.database.flow

data class GenFlowContext(
val requirement: String,
val databaseVersion: String,
val schemaName: String,
val tableNames: List<String>,
// for step 2
var tableInfos: List<String> = emptyList(),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package cc.unitmesh.database.flow

import cc.unitmesh.database.DbContextActionProvider
import cc.unitmesh.devti.AutoDevBundle
import cc.unitmesh.devti.gui.chat.ChatCodingPanel
import cc.unitmesh.devti.llms.LLMProvider
import cc.unitmesh.devti.template.TemplateRender
import cc.unitmesh.devti.util.LLMCoroutineScope
import com.intellij.openapi.diagnostic.logger
import com.intellij.openapi.project.Project
import kotlinx.coroutines.runBlocking

class GenSqlFlow(
val genFlowContext: GenFlowContext,
val actions: DbContextActionProvider,
val ui: ChatCodingPanel,
val llm: LLMProvider,
val project: Project
) {
private val logger = logger<GenSqlFlow>()

fun clarify(): String {
val stepOnePrompt = generateStepOnePrompt(genFlowContext, actions)

LLMCoroutineScope.scope(project).runCatching {
ui.addMessage(stepOnePrompt, true, stepOnePrompt)
ui.addMessage(AutoDevBundle.message("autodev.loading"))
}.onFailure {
logger.warn("Error: $it")
}

return runBlocking {
val prompt = llm.stream(stepOnePrompt, "")
return@runBlocking ui.updateMessage(prompt)
}
}

fun generate(tableNames: List<String>): String {
val stepTwoPrompt = generateStepTwoPrompt(genFlowContext, actions, tableNames)

LLMCoroutineScope.scope(project).runCatching {
ui.addMessage(stepTwoPrompt, true, stepTwoPrompt)
ui.addMessage(AutoDevBundle.message("autodev.loading"))
}.onFailure {
logger.warn("Error: $it")
}

return runBlocking {
val prompt = llm.stream(stepTwoPrompt, "")
return@runBlocking ui.updateMessage(prompt)
}
}

private fun generateStepOnePrompt(context: GenFlowContext, actions: DbContextActionProvider): String {
val templateRender = TemplateRender("genius/sql")
val template = templateRender.getTemplate("sql-gen-clarify.vm")

templateRender.context = context
templateRender.actions = actions

val prompter = templateRender.renderTemplate(template)

logger.info("Prompt: $prompter")
return prompter
}

private fun generateStepTwoPrompt(
genFlowContext: GenFlowContext,
actions: DbContextActionProvider,
tableInfos: List<String>
): String {
val templateRender = TemplateRender("genius/sql")
val template = templateRender.getTemplate("sql-gen-design.vm")

genFlowContext.tableInfos = actions.getTableColumns(tableInfos)

templateRender.context = genFlowContext
templateRender.actions = actions

val prompter = templateRender.renderTemplate(template)

logger.info("Prompt: $prompter")
return prompter
}

fun getAllTables(): List<String> {
return actions.dasTables.map { it.name }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package cc.unitmesh.database.flow

import cc.unitmesh.devti.AutoDevBundle
import cc.unitmesh.devti.util.parser.parseCodeFromString
import com.intellij.openapi.command.WriteCommandAction
import com.intellij.openapi.diagnostic.logger
import com.intellij.openapi.editor.Editor
import com.intellij.openapi.progress.ProgressIndicator
import com.intellij.openapi.progress.Task
import com.intellij.openapi.project.Project

class GenSqlTask(
private val project: Project,
private val flow: GenSqlFlow,
private val editor: Editor
) : Task.Backgroundable(project, "Gen SQL", true) {
private val logger = logger<GenSqlTask>()

override fun run(indicator: ProgressIndicator) {
indicator.fraction = 0.2

indicator.text = AutoDevBundle.message("migration.database.sql.generate.clarify")
val tables = flow.clarify()

logger.info("Tables: $tables")
// tables will be list in string format, like: `[table1, table2]`, we need to parse to Lists
val tableNames = tables.substringAfter("[").substringBefore("]")
.split(", ").map { it.trim() }

if (tableNames.isEmpty()) {
indicator.fraction = 1.0
val allTables = flow.getAllTables()
logger.warn("no table related: $allTables")
return
}

indicator.fraction = 0.6
indicator.text = AutoDevBundle.message("migration.database.sql.generate.generate")
val sqlScript = flow.generate(tableNames)

logger.info("SQL Script: $sqlScript")
WriteCommandAction.runWriteCommandAction(project, "Gen SQL", "cc.unitmesh.livingDoc", {
// new line
editor.document.insertString(editor.caretModel.offset, "\n")
// insert sql script
val code = parseCodeFromString(sqlScript).first()
editor.document.insertString(editor.caretModel.offset + "\n".length, code)
})

indicator.fraction = 1.0
}
}

0 comments on commit 95b3272

Please sign in to comment.