Skip to content

Commit

Permalink
feat(provider): add isSpringRelated method to KotlinTestContextProvider
Browse files Browse the repository at this point in the history
This commit adds the `isSpringRelated` method to the `KotlinTestContextProvider` class. This method checks if a given element (either a `KtNamedFunction` or a `KtClassOrObject`) has any annotations that start with "org.springframework". If it does, the method returns true; otherwise, it returns false. This method is used to determine if a given element is related to Spring in the context of generating test code.
  • Loading branch information
phodal committed Jan 10, 2024
1 parent c252e13 commit 06173db
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 12 deletions.
Expand Up @@ -20,10 +20,10 @@ open class JavaTestContextProvider : ChatContextProvider {
open fun langFileSuffix() = "java"

override suspend fun collect(project: Project, creationContext: ChatCreationContext): List<ChatContextItem> {
val items = mutableListOf<ChatContextItem>()

val fileName = creationContext.sourceFile?.name

//

val isController = fileName?.let { MvcUtil.isController(it, langFileSuffix()) } ?: false
val isService = fileName?.let { MvcUtil.isService(it, langFileSuffix()) } ?: false

Expand All @@ -39,7 +39,7 @@ open class JavaTestContextProvider : ChatContextProvider {

baseTestPrompt += junitRule(project)

items += when {
val finalPrompt = when {
isController && isSpringRelated -> {
val testControllerPrompt = baseTestPrompt + """
|- Use appropriate Spring test annotations such as `@MockBean`, `@Autowired`, `@WebMvcTest`, `@DataJpaTest`, `@AutoConfigureTestDatabase`, `@AutoConfigureMockMvc`, `@SpringBootTest` etc.
Expand All @@ -61,7 +61,7 @@ open class JavaTestContextProvider : ChatContextProvider {
}
}

return items
return listOf(finalPrompt)
}

private fun junitRule(project: Project): String {
Expand All @@ -77,10 +77,10 @@ open class JavaTestContextProvider : ChatContextProvider {
return ""
}

private fun isSpringRelated(method: PsiElement): Boolean {
when (method) {
open fun isSpringRelated(element: PsiElement): Boolean {
when (element) {
is PsiMethod -> {
val annotations = method.annotations
val annotations = element.annotations
for (annotation in annotations) {
val fqn = annotation.qualifiedName
if (fqn != null && fqn.startsWith("org.springframework")) {
Expand All @@ -90,7 +90,7 @@ open class JavaTestContextProvider : ChatContextProvider {
}

is PsiClass -> {
val annotations = method.annotations
val annotations = element.annotations
for (annotation in annotations) {
val fqn = annotation.qualifiedName
if (fqn != null && fqn.startsWith("org.springframework")) {
Expand Down
Expand Up @@ -4,5 +4,4 @@ import cc.unitmesh.idea.prompting.JavaContextPrompter

class KotlinContextPrompter: JavaContextPrompter() {
override val testDataBuilder = KotlinTestDataBuilder()

}
Expand Up @@ -4,12 +4,42 @@ import cc.unitmesh.devti.gui.chat.ChatActionType
import cc.unitmesh.devti.provider.context.ChatCreationContext
import cc.unitmesh.idea.provider.JavaTestContextProvider
import com.intellij.openapi.project.Project
import com.intellij.psi.PsiElement
import org.jetbrains.kotlin.idea.KotlinLanguage
import org.jetbrains.kotlin.psi.KtClass
import org.jetbrains.kotlin.psi.KtClassOrObject
import org.jetbrains.kotlin.psi.KtNamedFunction

class KotlinTestContextProvider : JavaTestContextProvider() {
override fun langFileSuffix(): String = "kt"

override fun isApplicable(project: Project, creationContext: ChatCreationContext): Boolean {
return creationContext.action == ChatActionType.GENERATE_TEST && creationContext.sourceFile?.language is KotlinLanguage
}

override fun isSpringRelated(element: PsiElement): Boolean {
when (element) {
is KtNamedFunction -> {
val annotations = element.annotationEntries
for (annotation in annotations) {
val fqn = annotation.name
if (fqn != null && fqn.startsWith("org.springframework")) {
return true
}
}
}

is KtClassOrObject -> {
val annotations = element.annotationEntries
for (annotation in annotations) {
val fqn = annotation.name
if (fqn != null && fqn.startsWith("org.springframework")) {
return true
}
}
}
}

return false
}
}
Expand Up @@ -2,6 +2,7 @@ package cc.unitmesh.kotlin.provider

import cc.unitmesh.devti.context.ClassContext
import cc.unitmesh.devti.context.ClassContextProvider
import cc.unitmesh.devti.isInProject
import cc.unitmesh.devti.provider.context.TestFileContext
import cc.unitmesh.devti.provider.WriteTestService
import cc.unitmesh.kotlin.context.KotlinClassContextBuilder
Expand Down Expand Up @@ -53,7 +54,7 @@ class KotlinWriteTestService : WriteTestService() {
val relatedModels = lookupRelevantClass(project, element)

if (!parentDirPath?.contains("/src/main/kotlin/")!!) {
log.error("Source file is not in the src/main/java directory: ${parentDirPath}")
log.error("Source file is not in the src/main/java directory: $parentDirPath")
return null
}

Expand Down
Expand Up @@ -86,9 +86,9 @@ class TestCodeGenTask(val request: TestCodeGenRequest) :
}


prompter += "Code:"
prompter += "\nCode:\n"
prompter += testContext.imports.joinToString("\n") {
"//$it"
"// $it"
}

prompter += "\n```${lang.lowercase()}\n${request.selectText}\n```"
Expand Down

0 comments on commit 06173db

Please sign in to comment.