From 3831ffbd29dd3eea21ec6cee48f0ba9c6e17cc76 Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Thu, 27 Jul 2023 08:19:29 +0800 Subject: [PATCH] refactor(test): extract relevant class --- .../unitmesh/idea/provider/JavaTestContextProvider.kt | 11 ++++++----- .../unitmesh/devti/intentions/WriteTestIntention.kt | 2 +- .../cc/unitmesh/devti/provider/TestContextProvider.kt | 4 +++- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/java/src/main/kotlin/cc/unitmesh/idea/provider/JavaTestContextProvider.kt b/java/src/main/kotlin/cc/unitmesh/idea/provider/JavaTestContextProvider.kt index 6aff58803b..533bb9d45b 100644 --- a/java/src/main/kotlin/cc/unitmesh/idea/provider/JavaTestContextProvider.kt +++ b/java/src/main/kotlin/cc/unitmesh/idea/provider/JavaTestContextProvider.kt @@ -5,7 +5,6 @@ import cc.unitmesh.devti.context.ClassContextProvider import cc.unitmesh.devti.provider.TestContextProvider import cc.unitmesh.devti.provider.TestFileContext import com.intellij.openapi.diagnostic.logger -import com.intellij.openapi.externalSystem.service.project.ProjectDataManager import com.intellij.openapi.project.Project import com.intellij.openapi.project.guessProjectDir import com.intellij.openapi.vfs.LocalFileSystem @@ -14,13 +13,13 @@ import com.intellij.psi.* class JavaTestContextProvider : TestContextProvider() { - override fun prepareTestFile(sourceFile: PsiFile, project: Project, element: PsiElement): TestFileContext? { + override fun findOrCreateTestFile(sourceFile: PsiFile, project: Project, element: PsiElement): TestFileContext? { val sourceFilePath = sourceFile.virtualFile val sourceDir = sourceFilePath.parent val packageName = (sourceFile as PsiJavaFile).packageName - val relatedModels = prepareModels(sourceFile, project, element) + val relatedModels = lookupRelevantClass(project, element) // Check if the source file is in the src/main/java directory if (!sourceDir?.path?.contains("/src/main/java/")!!) { @@ -57,13 +56,15 @@ class JavaTestContextProvider : TestContextProvider() { } } - private fun prepareModels(sourceFile: PsiJavaFile, project: Project, element: PsiElement): List { + override fun lookupRelevantClass(project: Project, element: PsiElement): List { val models = mutableListOf() + val projectBash = project.guessProjectDir()?.path + if (element is PsiMethod) { val inputTypes = element.parameterList.parameters.map { it.type is PsiClassType }.filterIsInstance().filter { - it.resolve()?.containingFile?.virtualFile?.path?.contains(project.guessProjectDir()?.path!!)!! + it.resolve()?.containingFile?.virtualFile?.path?.contains(projectBash!!)!! }.map { it.resolve()!! } // find input class from inputTypes diff --git a/src/main/kotlin/cc/unitmesh/devti/intentions/WriteTestIntention.kt b/src/main/kotlin/cc/unitmesh/devti/intentions/WriteTestIntention.kt index 08e9e6492c..d0e3a2a6d5 100644 --- a/src/main/kotlin/cc/unitmesh/devti/intentions/WriteTestIntention.kt +++ b/src/main/kotlin/cc/unitmesh/devti/intentions/WriteTestIntention.kt @@ -35,7 +35,7 @@ class WriteTestIntention : AbstractChatIntention() { LLMCoroutineScopeService.scope(project).launch { WriteAction.runAndWait { - val testContext = TestContextProvider.context(lang)?.prepareTestFile(file, project, element) + val testContext = TestContextProvider.context(lang)?.findOrCreateTestFile(file, project, element) if (testContext == null) { logger().error("Failed to create test file for: $file") return@runAndWait diff --git a/src/main/kotlin/cc/unitmesh/devti/provider/TestContextProvider.kt b/src/main/kotlin/cc/unitmesh/devti/provider/TestContextProvider.kt index d6ce68be65..3c1c93e972 100644 --- a/src/main/kotlin/cc/unitmesh/devti/provider/TestContextProvider.kt +++ b/src/main/kotlin/cc/unitmesh/devti/provider/TestContextProvider.kt @@ -28,7 +28,9 @@ abstract class TestContextProvider : LazyExtensionInstance( return implementationClass } - abstract fun prepareTestFile(sourceFile: PsiFile, project: Project, element: PsiElement): TestFileContext? + abstract fun findOrCreateTestFile(sourceFile: PsiFile, project: Project, element: PsiElement): TestFileContext? + + abstract fun lookupRelevantClass(project: Project, element: PsiElement): List abstract fun insertTestMethod(sourceFile: PsiFile, project: Project, methodName: String, code: String): Boolean