From 9f04d13856a94f85c11bf35dc6a7ee6b890b4648 Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Fri, 14 Jul 2023 13:49:31 +0800 Subject: [PATCH] fix: fix interface error issues --- .../cc/unitmesh/devti/flow/AutoDevFlow.kt | 4 ++++ .../unitmesh/devti/flow/JavaSpringBaseCrud.kt | 11 +++++++--- .../cc/unitmesh/devti/flow/SpringBaseCrud.kt | 20 +++++++++++++++++++ .../openai/create_service_and_repository.txt | 3 ++- 4 files changed, 34 insertions(+), 4 deletions(-) diff --git a/src/main/kotlin/cc/unitmesh/devti/flow/AutoDevFlow.kt b/src/main/kotlin/cc/unitmesh/devti/flow/AutoDevFlow.kt index 3501ce06b8..cfc3eeb12b 100644 --- a/src/main/kotlin/cc/unitmesh/devti/flow/AutoDevFlow.kt +++ b/src/main/kotlin/cc/unitmesh/devti/flow/AutoDevFlow.kt @@ -178,6 +178,10 @@ class AutoDevFlow( processor.createDto(code) } + processor.isRepository(code) -> { + processor.createRepository(code) + } + else -> { processor.createClass(code, null) } diff --git a/src/main/kotlin/cc/unitmesh/devti/flow/JavaSpringBaseCrud.kt b/src/main/kotlin/cc/unitmesh/devti/flow/JavaSpringBaseCrud.kt index c5eb10bb9a..53c89d427c 100644 --- a/src/main/kotlin/cc/unitmesh/devti/flow/JavaSpringBaseCrud.kt +++ b/src/main/kotlin/cc/unitmesh/devti/flow/JavaSpringBaseCrud.kt @@ -39,6 +39,7 @@ class JavaSpringBaseCrud(val project: Project) : SpringBaseCrud { override fun getAllEntityFiles(): List = filterFilesByFunc(::entityFilter) override fun getAllDtoFiles(): List = filterFilesByFunc(::dtoFilter) override fun getAllServiceFiles(): List = filterFilesByFunc(::serviceFilter) + override fun getAllRepositoryFiles(): List = filterFilesByFunc(::repositoryFilter) private fun filterFilesByFunc(filter: KFunction1): List { return runReadAction { @@ -141,6 +142,10 @@ class JavaSpringBaseCrud(val project: Project) : SpringBaseCrud { return createClassByCode(code, getAllDtoFiles()) } + override fun createRepository(code: String): DtClass? { + return createClassByCode(code, getAllRepositoryFiles()) + } + override fun createService(code: String): DtClass? { return createClassByCode(code, getAllServiceFiles()) } @@ -171,11 +176,11 @@ class JavaSpringBaseCrud(val project: Project) : SpringBaseCrud { runWriteAction { val newClass = psiElementFactory.createClassFromText(code, null) - val regex = Regex("public\\s+class\\s+(\\w+)") + val regex = Regex("public\\s+(class|interface)\\s+(\\w+)") val matchResult = regex.find(code) - val className = if (matchResult?.groupValues?.get(1) != null) { - matchResult.groupValues[1] + val className = if (matchResult?.groupValues?.get(2) != null) { + matchResult.groupValues[2] } else if (newClass.identifyingElement?.text != null) { newClass.identifyingElement?.text } else { diff --git a/src/main/kotlin/cc/unitmesh/devti/flow/SpringBaseCrud.kt b/src/main/kotlin/cc/unitmesh/devti/flow/SpringBaseCrud.kt index be2c1c5092..b5fa4eb77e 100644 --- a/src/main/kotlin/cc/unitmesh/devti/flow/SpringBaseCrud.kt +++ b/src/main/kotlin/cc/unitmesh/devti/flow/SpringBaseCrud.kt @@ -17,6 +17,7 @@ interface SpringBaseCrud { fun getAllEntityFiles(): List fun getAllDtoFiles(): List fun getAllServiceFiles(): List + fun getAllRepositoryFiles(): List fun createControllerOrUpdateMethod(targetController: String, code: String, isControllerExist: Boolean) @@ -24,6 +25,7 @@ interface SpringBaseCrud { fun createEntity(code: String): DtClass? fun createService(code: String): DtClass? fun createDto(code: String): DtClass? + fun createRepository(code: String): DtClass? fun createClass(code: String, packageName: String?): DtClass? fun dtoFilter(clazz: PsiClass): Boolean { @@ -46,6 +48,11 @@ interface SpringBaseCrud { it == "org.springframework.stereotype.Service" } + fun repositoryFilter(clazz: PsiClass): Boolean = clazz.annotations + .map { it.qualifiedName }.any { + it == "org.springframework.stereotype.Repository" + } + fun entityFilter(clazz: PsiClass): Boolean = clazz.annotations .map { it.qualifiedName }.any { it == "javax.persistence.Entity" @@ -103,4 +110,17 @@ interface SpringBaseCrud { return regex.containsMatchIn(code) } + fun isRepository(code: String): Boolean { + if (code.contains("@Repository")) { + return true + } + + if (code.contains("import org.springframework.stereotype.Repository")) { + return true + } + + // regex to match `public class xxRepository` + val regex = Regex("public\\s+class\\s+\\w+Repository") + return regex.containsMatchIn(code) + } } diff --git a/src/main/resources/prompts/openai/create_service_and_repository.txt b/src/main/resources/prompts/openai/create_service_and_repository.txt index 1a24623c23..097d32767c 100644 --- a/src/main/resources/prompts/openai/create_service_and_repository.txt +++ b/src/main/resources/prompts/openai/create_service_and_repository.txt @@ -22,7 +22,8 @@ public class {serviceName} { import org.springframework.data.repository.CrudRepository; import org.springframework.stereotype.Repository; -public class {xxxRepository} extends JpaRepository<{xxx}, Long> { +@Repository +public interface {xxxRepository} extends JpaRepository<{xxx}, Long> { } ```