Skip to content

Commit

Permalink
feat: update for test code
Browse files Browse the repository at this point in the history
  • Loading branch information
phodal committed Jul 27, 2023
1 parent 5f38256 commit fdd3943
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 23 deletions.
Expand Up @@ -28,6 +28,7 @@ class JavaTestContextProvider : TestContextProvider() {
val sourceDir = sourceFilePath.parent

val packageName = (sourceFile as PsiJavaFile).packageName
var isNewFile = false

// Check if the source file is in the src/main/java directory
if (!sourceDir?.path?.contains("/src/main/java/")!!) {
Expand All @@ -40,6 +41,7 @@ class JavaTestContextProvider : TestContextProvider() {
val testDir = LocalFileSystem.getInstance().findFileByPath(testDirPath)

if (testDir == null || !testDir.isDirectory) {
isNewFile = true
// Create the test directory if it doesn't exist
val testDirFile = File(testDirPath)
if (!testDirFile.exists()) {
Expand All @@ -55,7 +57,6 @@ class JavaTestContextProvider : TestContextProvider() {
return null
}


// Test directory already exists, find the corresponding test file
val testFilePath = testDirPath + "/" + sourceFile.name.replace(".java", "Test.java")
val testFile = LocalFileSystem.getInstance().findFileByPath(testFilePath)
Expand All @@ -66,10 +67,10 @@ class JavaTestContextProvider : TestContextProvider() {
val relatedModels = lookupRelevantClass(project, element)

return if (testFile != null) {
TestFileContext(false, testFile, relatedModels)
TestFileContext(isNewFile, testFile, relatedModels)
} else {
val targetFile = createTestFile(sourceFile, testDir!!, packageName)
TestFileContext(true, targetFile, relatedModels)
TestFileContext(isNewFile = true, targetFile, relatedModels)
}
}

Expand Down Expand Up @@ -127,7 +128,7 @@ class JavaTestContextProvider : TestContextProvider() {
return resolvedClasses
}

override fun insertTestMethod(sourceFile: PsiFile, project: Project, code: String): Boolean {
override fun insertTestCode(sourceFile: PsiFile, project: Project, code: String): Boolean {
// Check if the provided methodCode contains @Test annotation
log.info("methodCode: $code")
if (!code.contains("@Test")) {
Expand Down Expand Up @@ -166,13 +167,12 @@ class JavaTestContextProvider : TestContextProvider() {
}

override fun insertClassCode(sourceFile: PsiFile, project: Project, code: String): Boolean {
log.info("start insertClassCode: $code")
val psiTestFile = PsiManager.getInstance(project).findFile(sourceFile.virtualFile) ?: return false

ApplicationManager.getApplication().invokeLater {
WriteCommandAction.runWriteCommandAction(project) {
val document = psiTestFile.viewProvider.document!!
document.insertString(document.textLength, code)
}
WriteCommandAction.runWriteCommandAction(project) {
val document = psiTestFile.viewProvider.document!!
document.insertString(document.textLength, code)
}

return true
Expand Down
36 changes: 23 additions & 13 deletions src/main/kotlin/cc/unitmesh/devti/intentions/WriteTestIntention.kt
Expand Up @@ -6,6 +6,7 @@ import cc.unitmesh.devti.gui.chat.ChatActionType
import cc.unitmesh.devti.llms.ConnectorFactory
import cc.unitmesh.devti.parser.parseCodeFromString
import cc.unitmesh.devti.provider.TestContextProvider
import cc.unitmesh.devti.provider.TestFileContext
import cc.unitmesh.devti.provider.context.ChatContextProvider
import cc.unitmesh.devti.provider.context.ChatCreationContext
import cc.unitmesh.devti.provider.context.ChatOrigin
Expand Down Expand Up @@ -93,19 +94,28 @@ class WriteTestIntention : AbstractChatIntention() {

val flow: Flow<String> = ConnectorFactory.getInstance().connector(project).stream(prompter, "")
logger<WriteTestIntention>().warn("Prompt: $prompter")
LLMCoroutineScopeService.scope(project).launch {
val suggestion = StringBuilder()
flow.collect {
suggestion.append(it)
}

runReadAction {
parseCodeFromString(suggestion.toString()).forEach {
val testFile: PsiFile = PsiManager.getInstance(project).findFile(testContext.file)!!
testContextProvider.insertTestMethod(testFile, project, it)
}
}
}
writeTestToFile(project, flow, testContext, testContextProvider)
}
}
}
}

private fun writeTestToFile(
project: Project,
flow: Flow<String>,
context: TestFileContext,
contextProvider: TestContextProvider
) {
LLMCoroutineScopeService.scope(project).launch {
val suggestion = StringBuilder()
flow.collect {
suggestion.append(it)
}

runReadAction {
parseCodeFromString(suggestion.toString()).forEach {
val testFile: PsiFile = PsiManager.getInstance(project).findFile(context.file)!!
contextProvider.insertTestCode(testFile, project, it)
}
}
}
Expand Down
Expand Up @@ -32,7 +32,7 @@ abstract class TestContextProvider : LazyExtensionInstance<TestContextProvider>(

abstract fun lookupRelevantClass(project: Project, element: PsiElement): List<ClassContext>

abstract fun insertTestMethod(sourceFile: PsiFile, project: Project, methodCode: String): Boolean
abstract fun insertTestCode(sourceFile: PsiFile, project: Project, methodCode: String): Boolean
abstract fun insertClassCode(sourceFile: PsiFile, project: Project, code: String): Boolean

companion object {
Expand Down

0 comments on commit fdd3943

Please sign in to comment.