Skip to content

Commit

Permalink
feat(test): auto insert to code
Browse files Browse the repository at this point in the history
  • Loading branch information
phodal committed Jul 27, 2023
1 parent 62963e1 commit d48e704
Show file tree
Hide file tree
Showing 7 changed files with 209 additions and 128 deletions.
3 changes: 3 additions & 0 deletions README.md
Expand Up @@ -31,6 +31,9 @@ features:
- AI assistant. AutoDev will help you find bug, explain code, trace exception, generate commits, and more.
- Custom prompt. You can customize your prompt in `Settings` -> `Tools` -> `AutoDev`
- Custom LLM Server. You can customize your LLM Server in `Settings` -> `Tools` -> `AutoDev`
- Auto Testing
- [ ] auto create unit test.
- [ ] auto run unit test and try to fix test.
- [ ] Smart architecture. With ArchGuard Co-mate DSL, AutoDev will help you design your architecture.

## Usage
Expand Down
Expand Up @@ -4,55 +4,70 @@ import cc.unitmesh.devti.context.ClassContext
import cc.unitmesh.devti.context.ClassContextProvider
import cc.unitmesh.devti.provider.TestContextProvider
import cc.unitmesh.devti.provider.TestFileContext
import com.intellij.openapi.application.runReadAction
import com.intellij.openapi.application.runWriteAction
import com.intellij.openapi.command.WriteCommandAction
import com.intellij.openapi.diagnostic.logger
import com.intellij.openapi.project.Project
import com.intellij.openapi.project.guessProjectDir
import com.intellij.openapi.vfs.LocalFileSystem
import com.intellij.openapi.vfs.VirtualFile
import com.intellij.openapi.vfs.VirtualFileManager
import com.intellij.psi.*
import com.intellij.psi.impl.source.PsiClassReferenceType
import java.io.File

class JavaTestContextProvider : TestContextProvider() {
companion object {
val log = logger<JavaTestContextProvider>()
}

override fun findOrCreateTestFile(sourceFile: PsiFile, project: Project, element: PsiElement): TestFileContext? {
val sourceFilePath = sourceFile.virtualFile
val sourceDir = sourceFilePath.parent

val packageName = (sourceFile as PsiJavaFile).packageName

val relatedModels = lookupRelevantClass(project, element)

// Check if the source file is in the src/main/java directory
if (!sourceDir?.path?.contains("/src/main/java/")!!) {
// Not in the src/main/java directory, return null (cannot find test directory)
log.error("Source file is not in the src/main/java directory: $sourceDir")
return null
}

// Find the test directory
val testDirPath = sourceDir.path.replace("/src/main/java/", "/src/test/java/")
val testDir = LocalFileSystem.getInstance().findFileByPath(testDirPath)

// Check if the test directory exists, if not, create it
if (testDir == null || !testDir.isDirectory) {
val testDirCreated = LocalFileSystem.getInstance().refreshAndFindFileByPath(testDirPath)
return if (testDirCreated != null) {
// Successfully created the test directory
val targetFile = createTestFile(sourceFile, testDirCreated, packageName)
TestFileContext(true, targetFile, relatedModels)
} else {
// Failed to create the test directory, return null
null
// Create the test directory if it doesn't exist
val testDirFile = File(testDirPath)
if (!testDirFile.exists()) {
testDirFile.mkdirs()
// Refresh the VirtualFileManager to make sure the newly created directory is visible in IntelliJ
VirtualFileManager.getInstance().refreshWithoutFileWatcher(false)
}
}

val testDirCreated = LocalFileSystem.getInstance().refreshAndFindFileByPath(testDirPath)
if (testDirCreated == null) {
log.error("Failed to create test directory: $testDirPath")
return null
}


// Test directory already exists, find the corresponding test file
val testFilePath = testDirPath + "/" + sourceFile.name.replace(".java", "Test.java")
val testFile = LocalFileSystem.getInstance().findFileByPath(testFilePath)

// update file index
VirtualFileManager.getInstance().syncRefresh()

val relatedModels = lookupRelevantClass(project, element)

return if (testFile != null) {
TestFileContext(false, testFile, relatedModels)
} else {
val targetFile = createTestFile(sourceFile, testDir, packageName)
val targetFile = createTestFile(sourceFile, testDir!!, packageName)
TestFileContext(true, targetFile, relatedModels)
}
}
Expand All @@ -61,7 +76,12 @@ class JavaTestContextProvider : TestContextProvider() {
val models = mutableListOf<ClassContext>()
val projectPath = project.guessProjectDir()?.path

val resolvedClasses = resolveByMethod(element)
val resolvedClasses = try {
resolveByMethod(element)
} catch (e: Exception) {
log.error("Failed to resolve class by method: ${e.message}")
mutableMapOf()
}

if (element is PsiClass) {
val methods = element.methods
Expand All @@ -77,10 +97,10 @@ class JavaTestContextProvider : TestContextProvider() {
}
}

logger<JavaTestContextProvider>().warn("models: $models")
return models
}

// TODO: handle generic type
private fun resolveByMethod(element: PsiElement): MutableMap<String, PsiClass?> {
val resolvedClasses = mutableMapOf<String, PsiClass?>()
if (element is PsiMethod) {
Expand All @@ -90,6 +110,14 @@ class JavaTestContextProvider : TestContextProvider() {

val outputType = element.returnTypeElement?.type
if (outputType is PsiClassReferenceType) {
if (outputType.parameters.isNotEmpty()) {
outputType.parameters.forEach {
if (it is PsiClassReferenceType) {
resolvedClasses[it.canonicalText] = it.resolve()
}
}
}

val canonicalText = outputType.canonicalText
resolvedClasses[canonicalText] = outputType.resolve()
}
Expand All @@ -98,50 +126,69 @@ class JavaTestContextProvider : TestContextProvider() {
return resolvedClasses
}

override fun insertTestCode(sourceFile: PsiFile, project: Project, methodName: String, code: String): Boolean {
// Get the root element (usually a class) of the source file
val rootElement = sourceFile.children.find { it is PsiClass } as? PsiClass

if (rootElement != null) {
// Check if a method with the same name already exists
val existingMethod = rootElement.methods.find { it.name == methodName }
override fun insertTestMethod(sourceFile: PsiFile, project: Project, code: String): Boolean {
// Check if the provided methodCode contains @Test annotation
log.info("methodCode: $code")
if (!code.contains("@Test")) {
log.error("methodCode does not contain @Test annotation: $code")
return false
}

if (existingMethod != null) {
// Method with the same name already exists, return an error message
// return "Error: Method with name '$methodName' already exists."
return false
} else {
// Create the new test method
val psiElementFactory = PsiElementFactory.getInstance(project)
val newTestMethod =
psiElementFactory.createMethodFromText("void $methodName() {\n$code\n}", rootElement)
// if code is a class code, we need to insert
if (code.contains("public class ")) {
return insertClassCode(sourceFile, project, code)
}

// Add the new method to the class
val addedMethod = rootElement.add(newTestMethod) as PsiMethod
return runReadAction {
// Check if the root element (usually a class) of the source file is PsiClass
val rootElement = sourceFile.children.find { it is PsiClass } as? PsiClass ?: return@runReadAction false

// Format the newly inserted code
addedMethod.navigate(true)
// Create the new test method
val psiElementFactory = PsiElementFactory.getInstance(project)

// Refresh the project to make the changes visible
project.guessProjectDir()?.refresh(true, true)
val newTestMethod = psiElementFactory.createMethodFromText(code, rootElement)

// return "Success: Method '$methodName' successfully added."
return true
// Check if the method already exists in the class
if (rootElement.findMethodsByName(newTestMethod.name, false).isNotEmpty()) {
log.error("Method already exists in the class: ${newTestMethod.name}")
return@runReadAction false
}
} else {
// return "Error: Failed to find the class to insert the method."
return false

// Add the @Test annotation if it's missing
val modifierList: PsiModifierList = newTestMethod.modifierList
val testAnnotation: PsiAnnotation = psiElementFactory.createAnnotationFromText("@Test", newTestMethod)
modifierList.add(testAnnotation)

// Insert the new test method into the class
val addedMethod: PsiMethod = rootElement.add(newTestMethod) as PsiMethod

// Format the newly inserted code
addedMethod.navigate(true)

// Refresh the project to make the changes visible
project.guessProjectDir()?.refresh(true, true)

return@runReadAction true
}
}

override fun insertClassCode(sourceFile: PsiFile, project: Project, code: String): Boolean {
val psiTestFile = PsiManager.getInstance(project).findFile(sourceFile.virtualFile) ?: return false

WriteCommandAction.runWriteCommandAction(project) {
// add code to test file by string
val document = psiTestFile.viewProvider.document!!
document.insertString(document.textLength, code)
}

return true
}

private fun createTestFile(sourceFile: PsiFile, testDir: VirtualFile, packageName: String): VirtualFile {
// Create the test file content based on the source file
val sourceFileName = sourceFile.name
val testFileName = sourceFileName.replace(".java", "Test.java")
val testFileContent = """package $packageName;
|$AUTO_DEV_PLACEHOLDER""".trimMargin()
val testFileContent = "package $packageName;\n\n"

// Create the test file in the test directory
val testFile = testDir.createChildData(this, testFileName)
testFile.setBinaryContent(testFileContent.toByteArray())

Expand Down
4 changes: 3 additions & 1 deletion src/main/kotlin/cc/unitmesh/devti/context/ClassContext.kt
Expand Up @@ -36,7 +36,9 @@ class ClassContext(
override fun toQuery(): String {
val className = name ?: "_"
val classFields = getFieldNames().joinToString(separator = " ")
val classMethods = getMethodSignatures().joinToString(separator = "\n")
val classMethods = getMethodSignatures()
.filter { it.isNotBlank() }
.joinToString(separator = "\n")
return "class name: $className\nclass fields: $classFields\nclass methods: $classMethods\nsuper classes: $superClasses\n"
}

Expand Down
49 changes: 36 additions & 13 deletions src/main/kotlin/cc/unitmesh/devti/intentions/WriteTestIntention.kt
Expand Up @@ -3,19 +3,25 @@ package cc.unitmesh.devti.intentions
import cc.unitmesh.devti.AutoDevBundle
import cc.unitmesh.devti.editor.LLMCoroutineScopeService
import cc.unitmesh.devti.gui.chat.ChatActionType
import cc.unitmesh.devti.intentions.editor.sendToChat
import cc.unitmesh.devti.provider.ContextPrompter
import cc.unitmesh.devti.llms.ConnectorFactory
import cc.unitmesh.devti.parser.parseCodeFromString
import cc.unitmesh.devti.provider.TestContextProvider
import cc.unitmesh.devti.provider.context.ChatContextProvider
import cc.unitmesh.devti.provider.context.ChatCreationContext
import cc.unitmesh.devti.provider.context.ChatOrigin
import com.intellij.openapi.application.WriteAction
import com.intellij.openapi.diagnostic.logger
import com.intellij.openapi.editor.Editor
import com.intellij.openapi.fileEditor.FileEditorManager
import com.intellij.openapi.project.Project
import com.intellij.openapi.vfs.VirtualFile
import com.intellij.psi.PsiFile
import com.intellij.psi.PsiManager
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.jetbrains.kotlin.idea.core.util.toPsiFile

class WriteTestIntention : AbstractChatIntention() {
override fun getText(): String = AutoDevBundle.message("intentions.chat.code.test.name")
Expand All @@ -35,26 +41,27 @@ class WriteTestIntention : AbstractChatIntention() {

LLMCoroutineScopeService.scope(project).launch {
WriteAction.runAndWait<Throwable> {
val testContext = TestContextProvider.context(lang)?.findOrCreateTestFile(file, project, element)
val testContextProvider = TestContextProvider.context(lang)
val testContext = testContextProvider?.findOrCreateTestFile(file, project, element)
if (testContext == null) {
logger<WriteTestIntention>().error("Failed to create test file for: $file")
return@runAndWait
}

runBlocking {
var prompter = "Write test for ${testContext.file.name}, ${testContext.file.path}."
var prompter = "Write unit test for following code. You MUST return code only, not explain.\n\n"

val creationContext = ChatCreationContext(ChatOrigin.Intention, actionType, file)
val chatContextItems = ChatContextProvider.collectChatContextList(project, creationContext)
chatContextItems.forEach {
prompter += it.text
}

val additionContext = testContext.relatedClass.map {
val additionContext = testContext.relatedClass.joinToString("\n") {
it.toQuery()
}.joinToString("\n").lines().map {
}.lines().joinToString("\n") {
"// $it"
}.joinToString("\n")
}

prompter += additionContext

Expand All @@ -63,18 +70,34 @@ class WriteTestIntention : AbstractChatIntention() {
|```
|"""

// navigate to the test file
navigateTestFile(testContext.file, editor, project)

sendToChat(project, actionType, object : ContextPrompter() {
override fun displayPrompt(): String {
return prompter
val flow: Flow<String> = ConnectorFactory.getInstance().connector(project).stream(prompter, "")
logger<WriteTestIntention>().warn("Prompt: $prompter")
LLMCoroutineScopeService.scope(project).launch {
val suggestion = StringBuilder()
flow.collect {
suggestion.append(it)
}

override fun requestPrompt(): String {
return prompter
parseCodeFromString(suggestion.toString()).forEach {
val testFile: PsiFile = PsiManager.getInstance(project).findFile(testContext.file)!!
testContextProvider.insertTestMethod(testFile, project, it)
}
})
}
}
}
}
}

private fun navigateTestFile(testFile: VirtualFile, editor: Editor, project: Project) {
val fileEditorManager = FileEditorManager.getInstance(project)
val editors = fileEditorManager.openFile(testFile, true)

// If the file is already open in the editor, focus on the editor tab
if (editors.isNotEmpty()) {
fileEditorManager.setSelectedEditor(testFile, "text-editor")
}
}
}

0 comments on commit d48e704

Please sign in to comment.