diff --git a/README.md b/README.md index 1a18b809fc..d59fd3e8d2 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,9 @@ features: - AI assistant. AutoDev will help you find bug, explain code, trace exception, generate commits, and more. - Custom prompt. You can customize your prompt in `Settings` -> `Tools` -> `AutoDev` - Custom LLM Server. You can customize your LLM Server in `Settings` -> `Tools` -> `AutoDev` +- Auto Testing + - [ ] auto create unit test. + - [ ] auto run unit test and try to fix test. - [ ] Smart architecture. With ArchGuard Co-mate DSL, AutoDev will help you design your architecture. ## Usage diff --git a/java/src/main/kotlin/cc/unitmesh/idea/provider/JavaTestContextProvider.kt b/java/src/main/kotlin/cc/unitmesh/idea/provider/JavaTestContextProvider.kt index 22554c5936..91b01de2ef 100644 --- a/java/src/main/kotlin/cc/unitmesh/idea/provider/JavaTestContextProvider.kt +++ b/java/src/main/kotlin/cc/unitmesh/idea/provider/JavaTestContextProvider.kt @@ -4,15 +4,23 @@ import cc.unitmesh.devti.context.ClassContext import cc.unitmesh.devti.context.ClassContextProvider import cc.unitmesh.devti.provider.TestContextProvider import cc.unitmesh.devti.provider.TestFileContext +import com.intellij.openapi.application.runReadAction +import com.intellij.openapi.application.runWriteAction +import com.intellij.openapi.command.WriteCommandAction import com.intellij.openapi.diagnostic.logger import com.intellij.openapi.project.Project import com.intellij.openapi.project.guessProjectDir import com.intellij.openapi.vfs.LocalFileSystem import com.intellij.openapi.vfs.VirtualFile +import com.intellij.openapi.vfs.VirtualFileManager import com.intellij.psi.* import com.intellij.psi.impl.source.PsiClassReferenceType +import java.io.File class JavaTestContextProvider : TestContextProvider() { + companion object { + val log = logger() + } override fun findOrCreateTestFile(sourceFile: PsiFile, project: Project, element: PsiElement): TestFileContext? { val sourceFilePath = sourceFile.virtualFile @@ -20,11 +28,9 @@ class JavaTestContextProvider : TestContextProvider() { val packageName = (sourceFile as PsiJavaFile).packageName - val relatedModels = lookupRelevantClass(project, element) - // Check if the source file is in the src/main/java directory if (!sourceDir?.path?.contains("/src/main/java/")!!) { - // Not in the src/main/java directory, return null (cannot find test directory) + log.error("Source file is not in the src/main/java directory: $sourceDir") return null } @@ -32,27 +38,36 @@ class JavaTestContextProvider : TestContextProvider() { val testDirPath = sourceDir.path.replace("/src/main/java/", "/src/test/java/") val testDir = LocalFileSystem.getInstance().findFileByPath(testDirPath) - // Check if the test directory exists, if not, create it if (testDir == null || !testDir.isDirectory) { - val testDirCreated = LocalFileSystem.getInstance().refreshAndFindFileByPath(testDirPath) - return if (testDirCreated != null) { - // Successfully created the test directory - val targetFile = createTestFile(sourceFile, testDirCreated, packageName) - TestFileContext(true, targetFile, relatedModels) - } else { - // Failed to create the test directory, return null - null + // Create the test directory if it doesn't exist + val testDirFile = File(testDirPath) + if (!testDirFile.exists()) { + testDirFile.mkdirs() + // Refresh the VirtualFileManager to make sure the newly created directory is visible in IntelliJ + VirtualFileManager.getInstance().refreshWithoutFileWatcher(false) } } + val testDirCreated = LocalFileSystem.getInstance().refreshAndFindFileByPath(testDirPath) + if (testDirCreated == null) { + log.error("Failed to create test directory: $testDirPath") + return null + } + + // Test directory already exists, find the corresponding test file val testFilePath = testDirPath + "/" + sourceFile.name.replace(".java", "Test.java") val testFile = LocalFileSystem.getInstance().findFileByPath(testFilePath) + // update file index + VirtualFileManager.getInstance().syncRefresh() + + val relatedModels = lookupRelevantClass(project, element) + return if (testFile != null) { TestFileContext(false, testFile, relatedModels) } else { - val targetFile = createTestFile(sourceFile, testDir, packageName) + val targetFile = createTestFile(sourceFile, testDir!!, packageName) TestFileContext(true, targetFile, relatedModels) } } @@ -61,7 +76,12 @@ class JavaTestContextProvider : TestContextProvider() { val models = mutableListOf() val projectPath = project.guessProjectDir()?.path - val resolvedClasses = resolveByMethod(element) + val resolvedClasses = try { + resolveByMethod(element) + } catch (e: Exception) { + log.error("Failed to resolve class by method: ${e.message}") + mutableMapOf() + } if (element is PsiClass) { val methods = element.methods @@ -77,10 +97,10 @@ class JavaTestContextProvider : TestContextProvider() { } } - logger().warn("models: $models") return models } + // TODO: handle generic type private fun resolveByMethod(element: PsiElement): MutableMap { val resolvedClasses = mutableMapOf() if (element is PsiMethod) { @@ -90,6 +110,14 @@ class JavaTestContextProvider : TestContextProvider() { val outputType = element.returnTypeElement?.type if (outputType is PsiClassReferenceType) { + if (outputType.parameters.isNotEmpty()) { + outputType.parameters.forEach { + if (it is PsiClassReferenceType) { + resolvedClasses[it.canonicalText] = it.resolve() + } + } + } + val canonicalText = outputType.canonicalText resolvedClasses[canonicalText] = outputType.resolve() } @@ -98,50 +126,69 @@ class JavaTestContextProvider : TestContextProvider() { return resolvedClasses } - override fun insertTestCode(sourceFile: PsiFile, project: Project, methodName: String, code: String): Boolean { - // Get the root element (usually a class) of the source file - val rootElement = sourceFile.children.find { it is PsiClass } as? PsiClass - - if (rootElement != null) { - // Check if a method with the same name already exists - val existingMethod = rootElement.methods.find { it.name == methodName } + override fun insertTestMethod(sourceFile: PsiFile, project: Project, code: String): Boolean { + // Check if the provided methodCode contains @Test annotation + log.info("methodCode: $code") + if (!code.contains("@Test")) { + log.error("methodCode does not contain @Test annotation: $code") + return false + } - if (existingMethod != null) { - // Method with the same name already exists, return an error message -// return "Error: Method with name '$methodName' already exists." - return false - } else { - // Create the new test method - val psiElementFactory = PsiElementFactory.getInstance(project) - val newTestMethod = - psiElementFactory.createMethodFromText("void $methodName() {\n$code\n}", rootElement) + // if code is a class code, we need to insert + if (code.contains("public class ")) { + return insertClassCode(sourceFile, project, code) + } - // Add the new method to the class - val addedMethod = rootElement.add(newTestMethod) as PsiMethod + return runReadAction { + // Check if the root element (usually a class) of the source file is PsiClass + val rootElement = sourceFile.children.find { it is PsiClass } as? PsiClass ?: return@runReadAction false - // Format the newly inserted code - addedMethod.navigate(true) + // Create the new test method + val psiElementFactory = PsiElementFactory.getInstance(project) - // Refresh the project to make the changes visible - project.guessProjectDir()?.refresh(true, true) + val newTestMethod = psiElementFactory.createMethodFromText(code, rootElement) -// return "Success: Method '$methodName' successfully added." - return true + // Check if the method already exists in the class + if (rootElement.findMethodsByName(newTestMethod.name, false).isNotEmpty()) { + log.error("Method already exists in the class: ${newTestMethod.name}") + return@runReadAction false } - } else { -// return "Error: Failed to find the class to insert the method." - return false + + // Add the @Test annotation if it's missing + val modifierList: PsiModifierList = newTestMethod.modifierList + val testAnnotation: PsiAnnotation = psiElementFactory.createAnnotationFromText("@Test", newTestMethod) + modifierList.add(testAnnotation) + + // Insert the new test method into the class + val addedMethod: PsiMethod = rootElement.add(newTestMethod) as PsiMethod + + // Format the newly inserted code + addedMethod.navigate(true) + + // Refresh the project to make the changes visible + project.guessProjectDir()?.refresh(true, true) + + return@runReadAction true } } + override fun insertClassCode(sourceFile: PsiFile, project: Project, code: String): Boolean { + val psiTestFile = PsiManager.getInstance(project).findFile(sourceFile.virtualFile) ?: return false + + WriteCommandAction.runWriteCommandAction(project) { + // add code to test file by string + val document = psiTestFile.viewProvider.document!! + document.insertString(document.textLength, code) + } + + return true + } + private fun createTestFile(sourceFile: PsiFile, testDir: VirtualFile, packageName: String): VirtualFile { - // Create the test file content based on the source file val sourceFileName = sourceFile.name val testFileName = sourceFileName.replace(".java", "Test.java") - val testFileContent = """package $packageName; - |$AUTO_DEV_PLACEHOLDER""".trimMargin() + val testFileContent = "package $packageName;\n\n" - // Create the test file in the test directory val testFile = testDir.createChildData(this, testFileName) testFile.setBinaryContent(testFileContent.toByteArray()) diff --git a/src/main/kotlin/cc/unitmesh/devti/context/ClassContext.kt b/src/main/kotlin/cc/unitmesh/devti/context/ClassContext.kt index 992d77c8f3..9febf9f735 100644 --- a/src/main/kotlin/cc/unitmesh/devti/context/ClassContext.kt +++ b/src/main/kotlin/cc/unitmesh/devti/context/ClassContext.kt @@ -36,7 +36,9 @@ class ClassContext( override fun toQuery(): String { val className = name ?: "_" val classFields = getFieldNames().joinToString(separator = " ") - val classMethods = getMethodSignatures().joinToString(separator = "\n") + val classMethods = getMethodSignatures() + .filter { it.isNotBlank() } + .joinToString(separator = "\n") return "class name: $className\nclass fields: $classFields\nclass methods: $classMethods\nsuper classes: $superClasses\n" } diff --git a/src/main/kotlin/cc/unitmesh/devti/intentions/WriteTestIntention.kt b/src/main/kotlin/cc/unitmesh/devti/intentions/WriteTestIntention.kt index d0e3a2a6d5..2c623438e4 100644 --- a/src/main/kotlin/cc/unitmesh/devti/intentions/WriteTestIntention.kt +++ b/src/main/kotlin/cc/unitmesh/devti/intentions/WriteTestIntention.kt @@ -3,8 +3,8 @@ package cc.unitmesh.devti.intentions import cc.unitmesh.devti.AutoDevBundle import cc.unitmesh.devti.editor.LLMCoroutineScopeService import cc.unitmesh.devti.gui.chat.ChatActionType -import cc.unitmesh.devti.intentions.editor.sendToChat -import cc.unitmesh.devti.provider.ContextPrompter +import cc.unitmesh.devti.llms.ConnectorFactory +import cc.unitmesh.devti.parser.parseCodeFromString import cc.unitmesh.devti.provider.TestContextProvider import cc.unitmesh.devti.provider.context.ChatContextProvider import cc.unitmesh.devti.provider.context.ChatCreationContext @@ -12,10 +12,16 @@ import cc.unitmesh.devti.provider.context.ChatOrigin import com.intellij.openapi.application.WriteAction import com.intellij.openapi.diagnostic.logger import com.intellij.openapi.editor.Editor +import com.intellij.openapi.fileEditor.FileEditorManager import com.intellij.openapi.project.Project +import com.intellij.openapi.vfs.VirtualFile import com.intellij.psi.PsiFile +import com.intellij.psi.PsiManager +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.collect import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking +import org.jetbrains.kotlin.idea.core.util.toPsiFile class WriteTestIntention : AbstractChatIntention() { override fun getText(): String = AutoDevBundle.message("intentions.chat.code.test.name") @@ -35,14 +41,15 @@ class WriteTestIntention : AbstractChatIntention() { LLMCoroutineScopeService.scope(project).launch { WriteAction.runAndWait { - val testContext = TestContextProvider.context(lang)?.findOrCreateTestFile(file, project, element) + val testContextProvider = TestContextProvider.context(lang) + val testContext = testContextProvider?.findOrCreateTestFile(file, project, element) if (testContext == null) { logger().error("Failed to create test file for: $file") return@runAndWait } runBlocking { - var prompter = "Write test for ${testContext.file.name}, ${testContext.file.path}." + var prompter = "Write unit test for following code. You MUST return code only, not explain.\n\n" val creationContext = ChatCreationContext(ChatOrigin.Intention, actionType, file) val chatContextItems = ChatContextProvider.collectChatContextList(project, creationContext) @@ -50,11 +57,11 @@ class WriteTestIntention : AbstractChatIntention() { prompter += it.text } - val additionContext = testContext.relatedClass.map { + val additionContext = testContext.relatedClass.joinToString("\n") { it.toQuery() - }.joinToString("\n").lines().map { + }.lines().joinToString("\n") { "// $it" - }.joinToString("\n") + } prompter += additionContext @@ -63,18 +70,34 @@ class WriteTestIntention : AbstractChatIntention() { |``` |""" + // navigate to the test file + navigateTestFile(testContext.file, editor, project) - sendToChat(project, actionType, object : ContextPrompter() { - override fun displayPrompt(): String { - return prompter + val flow: Flow = ConnectorFactory.getInstance().connector(project).stream(prompter, "") + logger().warn("Prompt: $prompter") + LLMCoroutineScopeService.scope(project).launch { + val suggestion = StringBuilder() + flow.collect { + suggestion.append(it) } - override fun requestPrompt(): String { - return prompter + parseCodeFromString(suggestion.toString()).forEach { + val testFile: PsiFile = PsiManager.getInstance(project).findFile(testContext.file)!! + testContextProvider.insertTestMethod(testFile, project, it) } - }) + } } } } } + + private fun navigateTestFile(testFile: VirtualFile, editor: Editor, project: Project) { + val fileEditorManager = FileEditorManager.getInstance(project) + val editors = fileEditorManager.openFile(testFile, true) + + // If the file is already open in the editor, focus on the editor tab + if (editors.isNotEmpty()) { + fileEditorManager.setSelectedEditor(testFile, "text-editor") + } + } } diff --git a/src/main/kotlin/cc/unitmesh/devti/intentions/task/CodeCompletionTask.kt b/src/main/kotlin/cc/unitmesh/devti/intentions/task/CodeCompletionTask.kt index 28dd317753..cfb42c2e90 100644 --- a/src/main/kotlin/cc/unitmesh/devti/intentions/task/CodeCompletionTask.kt +++ b/src/main/kotlin/cc/unitmesh/devti/intentions/task/CodeCompletionTask.kt @@ -6,22 +6,16 @@ import cc.unitmesh.devti.llms.ConnectorFactory import cc.unitmesh.devti.editor.LLMCoroutineScopeService import cc.unitmesh.devti.intentions.CodeCompletionIntention import com.intellij.lang.LanguageCommenters -import com.intellij.openapi.Disposable import com.intellij.openapi.application.invokeLater import com.intellij.openapi.command.WriteCommandAction import com.intellij.openapi.diagnostic.logger import com.intellij.openapi.editor.Document -import com.intellij.openapi.editor.Editor import com.intellij.openapi.editor.ScrollType -import com.intellij.openapi.editor.ex.DocumentEx import com.intellij.openapi.progress.ProgressIndicator import com.intellij.openapi.progress.Task import com.intellij.openapi.project.Project -import com.intellij.openapi.util.Disposer import com.intellij.openapi.util.TextRange -import com.intellij.openapi.vfs.VirtualFile import com.intellij.psi.PsiDocumentManager -import com.intellij.psi.PsiElement import com.intellij.psi.codeStyle.CodeStyleManager import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.collect @@ -30,64 +24,6 @@ import java.util.function.Consumer import kotlin.jvm.internal.Ref -class CompletionTaskRequest( - val project: Project, - val useTabIndents: Boolean, - val tabWidth: Int, - val fileUri: VirtualFile, - val documentContent: String, - val offset: Int, - val documentVersion: Long, - val element: PsiElement, - val editor: Editor -) : Disposable { - companion object { - fun create(editor: Editor, offset: Int, element: PsiElement, prefix: String?): CompletionTaskRequest? { - val project = editor.project ?: return null - - val document = editor.document - val file = PsiDocumentManager.getInstance(project).getPsiFile(document) ?: return null - - val useTabs = editor.settings.isUseTabCharacter(project) - val tabWidth = editor.settings.getTabSize(project) - val uri = file.virtualFile - val documentVersion = if (document is DocumentEx) { - document.modificationSequence.toLong() - } else { - document.modificationStamp - } - - return CompletionTaskRequest( - project, - useTabs, - tabWidth, - uri, - prefix ?: document.text, - offset, - documentVersion, - element, - editor - ) - - } - } - - @Volatile - var isCancelled = false - - fun cancel() { - if (isCancelled) { - return - } - isCancelled = true - Disposer.dispose(this) - } - - override fun dispose() { - isCancelled = true - } -} - class CodeCompletionTask( private val request: CompletionTaskRequest, ) : Task.Backgroundable(request.project, AutoDevBundle.message("intentions.chat.code.complete.name")) { @@ -107,6 +43,7 @@ class CodeCompletionTask( val flow: Flow = connectorFactory.connector(request.project).stream(prompt, "") logger.warn("Prompt: $prompt") + val editor = request.editor LLMCoroutineScopeService.scope(request.project).launch { val currentOffset = Ref.IntRef() currentOffset.element = request.offset @@ -125,7 +62,7 @@ class CodeCompletionTask( insertStringAndSaveChange( project, it, - request.editor.document, + editor.document, currentOffset.element, false ) @@ -133,8 +70,8 @@ class CodeCompletionTask( ) currentOffset.element += it.length - request.editor.caretModel.moveToOffset(currentOffset.element) - request.editor.scrollingModel.scrollToCaret(ScrollType.MAKE_VISIBLE) + editor.caretModel.moveToOffset(currentOffset.element) + editor.scrollingModel.scrollToCaret(ScrollType.MAKE_VISIBLE) } } diff --git a/src/main/kotlin/cc/unitmesh/devti/intentions/task/CompletionTaskRequest.kt b/src/main/kotlin/cc/unitmesh/devti/intentions/task/CompletionTaskRequest.kt new file mode 100644 index 0000000000..db4d09b3fb --- /dev/null +++ b/src/main/kotlin/cc/unitmesh/devti/intentions/task/CompletionTaskRequest.kt @@ -0,0 +1,68 @@ +package cc.unitmesh.devti.intentions.task + +import com.intellij.openapi.Disposable +import com.intellij.openapi.editor.Editor +import com.intellij.openapi.editor.ex.DocumentEx +import com.intellij.openapi.project.Project +import com.intellij.openapi.util.Disposer +import com.intellij.openapi.vfs.VirtualFile +import com.intellij.psi.PsiDocumentManager +import com.intellij.psi.PsiElement + +class CompletionTaskRequest( + val project: Project, + val useTabIndents: Boolean, + val tabWidth: Int, + val fileUri: VirtualFile, + val documentContent: String, + val offset: Int, + val documentVersion: Long, + val element: PsiElement, + val editor: Editor +) : Disposable { + companion object { + fun create(editor: Editor, offset: Int, element: PsiElement, prefix: String?): CompletionTaskRequest? { + val project = editor.project ?: return null + + val document = editor.document + val file = PsiDocumentManager.getInstance(project).getPsiFile(document) ?: return null + + val useTabs = editor.settings.isUseTabCharacter(project) + val tabWidth = editor.settings.getTabSize(project) + val uri = file.virtualFile + val documentVersion = if (document is DocumentEx) { + document.modificationSequence.toLong() + } else { + document.modificationStamp + } + + return CompletionTaskRequest( + project, + useTabs, + tabWidth, + uri, + prefix ?: document.text, + offset, + documentVersion, + element, + editor + ) + + } + } + + @Volatile + var isCancelled = false + + fun cancel() { + if (isCancelled) { + return + } + isCancelled = true + Disposer.dispose(this) + } + + override fun dispose() { + isCancelled = true + } +} \ No newline at end of file diff --git a/src/main/kotlin/cc/unitmesh/devti/provider/TestContextProvider.kt b/src/main/kotlin/cc/unitmesh/devti/provider/TestContextProvider.kt index 3242803f79..f4609ffa2c 100644 --- a/src/main/kotlin/cc/unitmesh/devti/provider/TestContextProvider.kt +++ b/src/main/kotlin/cc/unitmesh/devti/provider/TestContextProvider.kt @@ -32,7 +32,8 @@ abstract class TestContextProvider : LazyExtensionInstance( abstract fun lookupRelevantClass(project: Project, element: PsiElement): List - abstract fun insertTestCode(sourceFile: PsiFile, project: Project, methodName: String, code: String): Boolean + abstract fun insertTestMethod(sourceFile: PsiFile, project: Project, methodCode: String): Boolean + abstract fun insertClassCode(sourceFile: PsiFile, project: Project, code: String): Boolean companion object { private val EP_NAME: ExtensionPointName =