Skip to content

Commit

Permalink
feat(rust): add relevant classes to TestFileContext
Browse files Browse the repository at this point in the history
This commit adds the ability to include relevant classes in the TestFileContext of the RustTestService. The relevant classes are looked up based on the element passed to the lookupRelevantClass function. If the element is a RsFunction, the return type and input parameters are extracted and resolved to obtain the corresponding RustClassContext. These relevant classes are then included in the TestFileContext. Additionally, a new private function resolveReferenceTypes is added to resolve the reference types of RsTypeReference.
  • Loading branch information
phodal committed Jan 19, 2024
1 parent 8e92028 commit 3e0c992
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 12 deletions.
Expand Up @@ -4,6 +4,7 @@ import cc.unitmesh.devti.context.MethodContext
import cc.unitmesh.devti.context.builder.MethodContextBuilder
import com.intellij.openapi.application.runReadAction
import com.intellij.psi.PsiElement
import com.intellij.psi.util.PsiTreeUtil
import org.rust.ide.presentation.presentationInfo
import org.rust.lang.core.psi.*

Expand All @@ -20,15 +21,21 @@ class RustMethodContextBuilder : MethodContextBuilder {
val language = psiElement.language.displayName

val signature = psiElement.presentationInfo?.signatureText
val paramsName = psiElement.valueParameterList?.valueParameterList?.map {
it.text
} ?: emptyList()

val enclosingClass = PsiTreeUtil.getParentOfType(psiElement, RsImplItem::class.java)

return MethodContext(
psiElement,
text,
psiElement.name,
signature.toString(),
null,
enclosingClass,
language,
returnType,
emptyList(),
paramsName,
includeClassContext,
emptyList()
)
Expand Down
Expand Up @@ -3,6 +3,7 @@ package cc.unitmesh.rust.provider
import cc.unitmesh.devti.context.ClassContext
import cc.unitmesh.devti.provider.WriteTestService
import cc.unitmesh.devti.provider.context.TestFileContext
import cc.unitmesh.rust.context.RustClassContextBuilder
import cc.unitmesh.rust.context.RustMethodContextBuilder
import com.intellij.execution.configurations.RunProfile
import com.intellij.openapi.application.runReadAction
Expand All @@ -13,6 +14,7 @@ import com.intellij.psi.util.PsiTreeUtil
import org.rust.cargo.runconfig.command.CargoCommandConfiguration
import org.rust.lang.RsLanguage
import org.rust.lang.core.psi.RsFunction
import org.rust.lang.core.psi.RsTypeReference
import org.rust.lang.core.psi.RsUseItem

class RustTestService : WriteTestService() {
Expand All @@ -37,10 +39,12 @@ class RustTestService : WriteTestService() {
it.text
}

val relevantClasses = lookupRelevantClass(project, element)

return TestFileContext(
false,
sourceFile.virtualFile,
listOf(),
relevantClasses,
"",
RsLanguage,
currentObject,
Expand All @@ -49,7 +53,30 @@ class RustTestService : WriteTestService() {
}

override fun lookupRelevantClass(project: Project, element: PsiElement): List<ClassContext> {
when (element) {
is RsFunction -> {
val returnType = element.retType?.typeReference
val input = element.valueParameterList?.valueParameterList?.map {
it.typeReference
} ?: emptyList()

val refs = (listOf(returnType) + input).filterNotNull()
val types = resolveReferenceTypes(project, refs)

return types.mapNotNull {
RustClassContextBuilder().getClassContext(it, false)
}
}
}

return listOf()
}

private fun resolveReferenceTypes(project: Project, rsTypeReferences: List<RsTypeReference>): List<PsiElement> {
val mapNotNull = rsTypeReferences.mapNotNull {
it.reference?.resolve()
}

return mapNotNull
}
}
Expand Up @@ -22,7 +22,6 @@ import com.intellij.openapi.progress.Task
import com.intellij.openapi.project.DumbService
import com.intellij.openapi.project.Project
import com.intellij.openapi.vfs.VirtualFile
import kotlinx.coroutines.InternalCoroutinesApi
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.runBlocking

Expand Down Expand Up @@ -132,7 +131,6 @@ class TestCodeGenTask(val request: TestCodeGenRequest) :
}
}

@OptIn(InternalCoroutinesApi::class)
private suspend fun writeTestToFile(
project: Project,
flow: Flow<String>,
Expand All @@ -148,7 +146,8 @@ class TestCodeGenTask(val request: TestCodeGenRequest) :
val modifier = CodeModifierProvider().modifier(context.language)
?: throw IllegalStateException("Unsupported language: ${context.language}")

parseCodeFromString(suggestion.toString()).forEach {
val codeBlocks = parseCodeFromString(suggestion.toString())
codeBlocks.forEach {
modifier.insertTestCode(context.outputFile, project, it)
}
}
Expand Down
6 changes: 0 additions & 6 deletions src/main/kotlin/cc/unitmesh/devti/util/parser/Markdown.kt
Expand Up @@ -19,12 +19,6 @@ fun parseCodeFromString(markdown: String): List<String> {
node.accept(visitor)

if (visitor.code.isEmpty()) {
// TODO: we need to add multiple code blocks support
val isJavaMethod = markdown.contains("public ") || markdown.contains("private ") || markdown.contains("protected ")
if (isJavaMethod) {
return listOf(markdown)
}

return listOf(markdown)
}

Expand Down

0 comments on commit 3e0c992

Please sign in to comment.