Skip to content

Commit

Permalink
feat(database): add SQL generation functionality #80
Browse files Browse the repository at this point in the history
This commit adds functionality to generate SQL scripts based on selected tables. It introduces new classes and methods to handle the SQL generation process. The `GenSqlScriptBySelection` class now includes a `generateSqlWorkflow` method that runs the SQL generation process in the background. It also introduces the `GenSqlFlow` class, which handles the step-by-step process of generating SQL scripts. The `clarify` method prompts the user to clarify the requirements, while the `generate` method generates the SQL script based on the clarified requirements and selected tables. This new functionality enhances the database migration capabilities of the application.
  • Loading branch information
phodal committed Jan 24, 2024
1 parent 26227ba commit 0d2e6ce
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 14 deletions.
@@ -1,27 +1,32 @@
package cc.unitmesh.database.actions

import cc.unitmesh.devti.AutoDevBundle
import cc.unitmesh.devti.gui.sendToChatWindow
import cc.unitmesh.devti.gui.chat.ChatCodingPanel
import cc.unitmesh.devti.gui.sendToChatPanel
import cc.unitmesh.devti.intentions.action.base.AbstractChatIntention
import cc.unitmesh.devti.provider.ContextPrompter
import cc.unitmesh.devti.llms.LLMProvider
import cc.unitmesh.devti.llms.LlmFactory
import cc.unitmesh.devti.template.TemplateRender
import com.intellij.database.model.DasTable
import com.intellij.database.model.ObjectKind
import com.intellij.database.psi.DbPsiFacade
import com.intellij.database.util.DasUtil
import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.application.ReadAction
import com.intellij.openapi.diagnostic.logger
import com.intellij.openapi.editor.Editor
import com.intellij.openapi.progress.ProgressIndicator
import com.intellij.openapi.progress.ProgressManager
import com.intellij.openapi.progress.Task
import com.intellij.openapi.project.Project
import com.intellij.psi.PsiFile
import kotlinx.coroutines.runBlocking


class GenSqlScriptBySelection : AbstractChatIntention() {
override fun priority(): Int = 1001

override fun startInWriteAction(): Boolean = false

override fun getFamilyName(): String = AutoDevBundle.message("migration.database.plsql")

override fun getText(): String = AutoDevBundle.message("migration.database.sql.generate")

override fun isAvailable(project: Project, editor: Editor?, file: PsiFile?): Boolean {
Expand All @@ -32,10 +37,12 @@ class GenSqlScriptBySelection : AbstractChatIntention() {
private val logger = logger<GenSqlScriptBySelection>()

override fun invoke(project: Project, editor: Editor?, file: PsiFile?) {
if (editor == null || file == null) return

val dbPsiFacade = DbPsiFacade.getInstance(project)
val dataSource = dbPsiFacade.dataSources.firstOrNull() ?: return

val selectedText = editor?.selectionModel?.selectedText
val selectedText = editor.selectionModel.selectedText

val rawDataSource = dbPsiFacade.getDataSourceManager(dataSource).dataSources.firstOrNull() ?: return
val databaseVersion = rawDataSource.databaseVersion
Expand All @@ -55,16 +62,80 @@ class GenSqlScriptBySelection : AbstractChatIntention() {
)

val actions = DbContextActionProvider(dasTables)
val prompter = generateStepOnePrompt(dbContext, actions)

sendToChatWindow(project, getActionType()) { panel, service ->
service.handlePromptAndResponse(panel, object : ContextPrompter() {
override fun displayPrompt(): String = prompter
override fun requestPrompt(): String = prompter
}, null, false)
sendToChatPanel(project) { contentPanel, _ ->
val llmProvider = LlmFactory().create(project)
val prompter = GenSqlFlow(dbContext, actions, contentPanel, llmProvider)
ApplicationManager.getApplication().invokeLater {

ProgressManager.getInstance()
.run(generateSqlWorkflow(project, contentPanel, prompter))
}
}
}

private fun generateSqlWorkflow(
project: Project,
ui: ChatCodingPanel,
flow: GenSqlFlow,
) =
object : Task.Backgroundable(project, "Loading retained test failure", true) {
override fun run(indicator: ProgressIndicator) {
indicator.fraction = 0.2


indicator.text = AutoDevBundle.message("migration.database.sql.generate.clarify")
val tables = ReadAction.compute<String, Throwable> {
flow.clarify()
}

// tables will be list in string format, like: `[table1, table2]`, we need to parse to Lists
val tableNames = tables.substringAfter("[").substringBefore("]")
.split(", ").map { it.trim() }

indicator.fraction = 0.6
val sqlScript = flow.generate(tableNames)

logger.info("SQL Script: $sqlScript")

indicator.fraction = 1.0
}
}
}

class GenSqlFlow(
val dbContext: DbContext,
val actions: DbContextActionProvider,
val ui: ChatCodingPanel,
val llm: LLMProvider
) {
private val logger = logger<GenSqlFlow>()

fun clarify(): String {
val stepOnePrompt = generateStepOnePrompt(dbContext, actions)
ui.addMessage(stepOnePrompt, true, stepOnePrompt)
// for answer
ui.addMessage(AutoDevBundle.message("autodev.loading"))

return runBlocking {
val prompt = llm.stream(stepOnePrompt, "")
return@runBlocking ui.updateMessage(prompt)
}
}

fun generate(tableNames: List<String>): String {
val stepTwoPrompt = generateStepTwoPrompt(dbContext, actions, tableNames)
ui.addMessage(stepTwoPrompt, true, stepTwoPrompt)
// for answer
ui.addMessage(AutoDevBundle.message("autodev.loading"))

return runBlocking {
val prompt = llm.stream(stepTwoPrompt, "")
return@runBlocking ui.updateMessage(prompt)
}
}


private fun generateStepOnePrompt(context: DbContext, actions: DbContextActionProvider): String {
val templateRender = TemplateRender("genius/sql")
val template = templateRender.getTemplate("sql-gen-clarify.vm")
Expand All @@ -77,15 +148,35 @@ class GenSqlScriptBySelection : AbstractChatIntention() {
logger.info("Prompt: $prompter")
return prompter
}

private fun generateStepTwoPrompt(
dbContext: DbContext,
actions: DbContextActionProvider,
tableInfos: List<String>
): String {
val templateRender = TemplateRender("genius/sql")
val template = templateRender.getTemplate("sql-gen-generate.vm")

dbContext.tableInfos = actions.getTableColumns(tableInfos)

templateRender.context = dbContext
templateRender.actions = actions

val prompter = templateRender.renderTemplate(template)

logger.info("Prompt: $prompter")
return prompter
}
}


data class DbContext(
val requirement: String,
val databaseVersion: String,
val schemaName: String,
val tableNames: List<String>,
// for step 2
val tableInfos: List<String> = emptyList(),
var tableInfos: List<String> = emptyList(),
)

data class DbContextActionProvider(val dasTables: List<DasTable>) {
Expand Down
Expand Up @@ -105,7 +105,9 @@ class CodeBlockView(
document: Document,
disposable: Disposable
): EditorEx {
val editor: Editor = EditorFactory.getInstance().createViewer(document, project, EditorKind.PREVIEW)
val editor: Editor = EditorFactory.getInstance()
.createViewer(document, project, EditorKind.PREVIEW)

disposable.whenDisposed(disposable) {
EditorFactory.getInstance().releaseEditor(editor)
}
Expand Down
2 changes: 2 additions & 0 deletions src/main/resources/messages/AutoDevBundle.properties
Expand Up @@ -120,3 +120,5 @@ migration.database.plsql.generate.entity=Generate Entity
migration.database.plsql.visual=Visualize PL/SQL
migration.database.plsql.modular.design=Modular Code
migration.database.sql.generate=Generate SQL (by selection)
migration.database.sql.generate.clarify=Clarify Requiements
migration.database.sql.generate.generate=Generate SQL

0 comments on commit 0d2e6ce

Please sign in to comment.