diff --git a/src/main/kotlin/cc/unitmesh/devti/intentions/action/task/TestCodeGenTask.kt b/src/main/kotlin/cc/unitmesh/devti/intentions/action/task/TestCodeGenTask.kt index 318c199288..ff59c3bea0 100644 --- a/src/main/kotlin/cc/unitmesh/devti/intentions/action/task/TestCodeGenTask.kt +++ b/src/main/kotlin/cc/unitmesh/devti/intentions/action/task/TestCodeGenTask.kt @@ -10,6 +10,7 @@ import cc.unitmesh.devti.provider.WriteTestService import cc.unitmesh.devti.provider.context.* import cc.unitmesh.devti.statusbar.AutoDevStatus import cc.unitmesh.devti.statusbar.AutoDevStatusService +import cc.unitmesh.devti.template.TemplateRender import com.intellij.lang.LanguageCommenters import com.intellij.openapi.application.ApplicationManager import com.intellij.openapi.application.ReadAction @@ -23,9 +24,18 @@ import com.intellij.openapi.project.Project import com.intellij.openapi.vfs.VirtualFile import kotlinx.coroutines.InternalCoroutinesApi import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.collect import kotlinx.coroutines.runBlocking +data class TestGenPromptContext( + var language: String = "", + var imports: String = "", + var frameworkedContext: String = "", + var currentClass: String = "", + var relatedClasses: String = "", + var testClassName: String = "", + var isNewFile: Boolean = true, +) + class TestCodeGenTask(val request: TestCodeGenRequest) : Task.Backgroundable(request.project, AutoDevBundle.message("intentions.chat.code.test.name")) { @@ -36,6 +46,9 @@ class TestCodeGenTask(val request: TestCodeGenRequest) : val commenter = LanguageCommenters.INSTANCE.forLanguage(request.file.language) ?: null val comment = commenter?.lineCommentPrefix ?: "//" + val templateRender = TemplateRender("genius/code") + val template = templateRender.getTemplate("test-gen.vm") + override fun run(indicator: ProgressIndicator) { indicator.isIndeterminate = true indicator.fraction = 0.1 @@ -55,11 +68,11 @@ class TestCodeGenTask(val request: TestCodeGenRequest) : return } - var prompter = "Write unit test for following $lang code." - indicator.text = AutoDevBundle.message("intentions.chat.code.test.step.collect-context") indicator.fraction = 0.3 + val testPromptContext = TestGenPromptContext() + val creationContext = ChatCreationContext(ChatOrigin.Intention, actionType, request.file, listOf(), element = request.element) @@ -67,47 +80,30 @@ class TestCodeGenTask(val request: TestCodeGenRequest) : return@runBlocking ChatContextProvider.collectChatContextList(request.project, creationContext) } - contextItems.forEach { - prompter += it.text + "\n" - } - - prompter += "\n" - prompter += ReadAction.compute { - if (testContext.relatedClasses.isEmpty()) { - return@compute "" - } - - val relatedClasses = testContext.relatedClasses.joinToString("\n") { + testPromptContext.frameworkedContext = contextItems.joinToString("\n", transform = ChatContextItem::text) + ReadAction.compute { + testPromptContext.relatedClasses = testContext.relatedClasses.joinToString("\n") { it.format() }.lines().joinToString("\n") { "$comment $it" } - "$comment here are related classes:\n$relatedClasses\n" + testPromptContext.currentClass = + runReadAction { testContext.currentClass?.format() }?.lines()?.joinToString("\n") { + "$comment $it" + } ?: "" } - if (testContext.currentClass != null) { - val currentClassInfo = runReadAction { testContext.currentClass.format() }.lines().joinToString("\n") { - "$comment $it" - } - prompter += "\n$comment here is current class information:\n$currentClassInfo\n" - } - - val importString = testContext.imports.joinToString("\n") { + testPromptContext.imports = testContext.imports.joinToString("\n") { "$comment $it" } + testPromptContext.isNewFile = testContext.isNewFile - prompter += "\nCode:\n$importString\n```${lang.lowercase()}\n${request.selectText}\n```\n" - - prompter += if (!testContext.isNewFile) { - "\nStart test code with `@Test` syntax here: \n" - } else { - "\nStart ${testContext.testClassName} with `import` syntax here: \n" - } + templateRender.context = testPromptContext + val prompter = templateRender.renderTemplate(template) logger().info("Prompt: $prompter") - indicator.fraction = 0.8 indicator.text = AutoDevBundle.message("intentions.request.background.process.title") diff --git a/src/main/kotlin/cc/unitmesh/devti/template/TemplateRender.kt b/src/main/kotlin/cc/unitmesh/devti/template/TemplateRender.kt index 2dcfbfb57e..c8cb8f8e28 100644 --- a/src/main/kotlin/cc/unitmesh/devti/template/TemplateRender.kt +++ b/src/main/kotlin/cc/unitmesh/devti/template/TemplateRender.kt @@ -45,6 +45,20 @@ class TemplateRender(pathPrefix: String) { return messages } + + fun renderTemplate(template: String): String { + val oldContextClassLoader = Thread.currentThread().getContextClassLoader() + Thread.currentThread().setContextClassLoader(TemplateRender::class.java.getClassLoader()) + + velocityContext.put("context", context) + val sw = StringWriter() + Velocity.evaluate(velocityContext, sw, "#" + this.javaClass.name, template) + val result = sw.toString() + + Thread.currentThread().setContextClassLoader(oldContextClassLoader) + + return result + } } class TemplateNotFoundError(path: String) : Exception("Prompt not found at path: $path") diff --git a/src/main/kotlin/cc/unitmesh/genius/actions/GenerateGitHubActionsAction.kt b/src/main/kotlin/cc/unitmesh/genius/actions/GenerateGitHubActionsAction.kt index 968cdcf004..797caaacfd 100644 --- a/src/main/kotlin/cc/unitmesh/genius/actions/GenerateGitHubActionsAction.kt +++ b/src/main/kotlin/cc/unitmesh/genius/actions/GenerateGitHubActionsAction.kt @@ -44,7 +44,7 @@ class GenerateGitHubActionsAction : AnAction(AutoDevBundle.message("action.new.g templateRender.context = DevOpsContext.from(githubActions) val template = templateRender.getTemplate("generate-github-action.vm") - val dir = project.guessProjectDir()!!.toNioPath().resolve(".github").resolve("workflows") + project.guessProjectDir()!!.toNioPath().resolve(".github").resolve("workflows") .createDirectories() val msgs = templateRender.buildMsgs(template) diff --git a/src/main/resources/genius/code/test-gen.vm b/src/main/resources/genius/code/test-gen.vm index e69de29bb2..548960633b 100644 --- a/src/main/resources/genius/code/test-gen.vm +++ b/src/main/resources/genius/code/test-gen.vm @@ -0,0 +1,21 @@ +Write unit test for following ${context.lang} code. + +${context.frameworkedContext} + +${context.relatedClasses} + +#if( $context.currentClass.length() > 0 ) +Here is current class information: +${context.currentClass} +#end + +${context.imports} + +## if newFile +#if( $context.isNewFile ) +Start method test code here: +#else +Start ${context.testClassName} with `import` syntax here: +#end + +