Skip to content

Commit

Permalink
feat(doc): add basdic prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
phodal committed Aug 9, 2023
1 parent f85a28d commit 4fddd6d
Show file tree
Hide file tree
Showing 9 changed files with 171 additions and 44 deletions.
Expand Up @@ -93,25 +93,12 @@ open class JavaContextPrompter : ContextPrompter() {

when (action!!) {
ChatActionType.EXPLAIN -> {
val autoComment = customPromptConfig?.autoComment
if (autoComment?.instruction?.isNotEmpty() == true) {
prompt = autoComment.instruction
}
}

ChatActionType.REFACTOR -> {
val refactor = customPromptConfig?.refactor
if (refactor?.instruction?.isNotEmpty() == true) {
prompt = refactor.instruction
}
}

ChatActionType.CODE_COMPLETE -> {
val codeComplete = customPromptConfig?.autoComplete
if (codeComplete?.instruction?.isNotEmpty() == true) {
prompt = codeComplete.instruction
}

when {
MvcUtil.isController(fileName, lang) -> {
val spec = CustomPromptConfig.load().spec["controller"]
Expand All @@ -135,12 +122,7 @@ open class JavaContextPrompter : ContextPrompter() {
}
}

ChatActionType.GENERATE_TEST -> {
val writeTest = customPromptConfig?.writeTest
if (writeTest?.instruction?.isNotEmpty() == true) {
prompt = writeTest.instruction
}
}
ChatActionType.GENERATE_TEST -> {}

ChatActionType.FIX_ISSUE -> {
addFixIssueContext(selectedText)
Expand All @@ -158,6 +140,7 @@ open class JavaContextPrompter : ContextPrompter() {

ChatActionType.CUSTOM_COMPLETE -> {
}

ChatActionType.CUSTOM_ACTION -> TODO()
}

Expand Down
Expand Up @@ -2,12 +2,11 @@ package cc.unitmesh.idea.provider

import cc.unitmesh.devti.provider.LivingDocumentation
import cc.unitmesh.devti.provider.LivingDocumentationType
import com.intellij.codeInsight.daemon.impl.CollectHighlightsUtil
import com.intellij.openapi.editor.SelectionModel
import com.intellij.psi.PsiClass
import com.intellij.psi.PsiElement
import com.intellij.psi.PsiMethod
import com.intellij.psi.PsiNameIdentifierOwner
import com.intellij.psi.*
import com.intellij.psi.util.PsiTreeUtil
import com.intellij.util.IncorrectOperationException

class JavaLivingDocumentation : LivingDocumentation {
override fun startEndString(type: LivingDocumentationType): Pair<String, String> {
Expand All @@ -19,9 +18,23 @@ class JavaLivingDocumentation : LivingDocumentation {
}

override fun updateDoc(psiElement: PsiElement, str: String) {
TODO("Not yet implemented")
val project = psiElement.project
val psiElementFactory = JavaPsiFacade.getElementFactory(project)
val newDocComment = psiElementFactory.createDocCommentFromText(str)

if (psiElement is PsiDocCommentOwner) {
try {
psiElement.docComment?.replace(newDocComment)
} catch (e: IncorrectOperationException) {
val firstChild = psiElement.firstChild
psiElement.addBefore(newDocComment, firstChild)
}
} else {
throw IncorrectOperationException("Unable to update documentation")
}
}


override fun findExampleDoc(psiNameIdentifierOwner: PsiNameIdentifierOwner): String {
return ""
}
Expand All @@ -38,11 +51,51 @@ class JavaLivingDocumentation : LivingDocumentation {

}

fun containsElement(selectionModel: SelectionModel, element: PsiElement): Boolean {
return selectionModel.selectionStart <= element.textRange.startOffset && element.textRange.endOffset <= selectionModel.selectionEnd
}

override fun findDocTargetsInSelection(
psiElement: PsiElement,
selectionModel: SelectionModel,
): List<PsiNameIdentifierOwner> {
TODO("Not yet implemented")
val findCommonParent = CollectHighlightsUtil.findCommonParent(
psiElement,
selectionModel.selectionStart,
selectionModel.selectionEnd
) ?: return emptyList()

if (findCommonParent is PsiJavaFile) {
val classAndFieldMethods = mutableListOf<PsiNameIdentifierOwner>()
val classes = findCommonParent.classes
for (psiClass in classes) {
if (containsElement(selectionModel, psiClass)) {
classAndFieldMethods.add(psiClass)
}
}

return classAndFieldMethods
}

val target = findNearestDocumentationTarget(findCommonParent) ?: return emptyList()

if (target !is PsiClass || containsElement(selectionModel, target)) {
return listOf(target)
}

val methodsAndFieldsInRange = mutableListOf<PsiNameIdentifierOwner>()
for (psiField in target.fields) {
if (containsElement(selectionModel, psiField)) {
methodsAndFieldsInRange.add(psiField)
}
}
for (psiMethod in target.methods) {
if (containsElement(selectionModel, psiMethod)) {
methodsAndFieldsInRange.add(psiMethod)
}
}

return methodsAndFieldsInRange
}

}
@@ -1,12 +1,13 @@
package cc.unitmesh.devti.context

import cc.unitmesh.devti.context.base.LLMQueryContextProvider
import cc.unitmesh.devti.context.builder.ClassContextBuilder
import com.intellij.lang.Language
import com.intellij.lang.LanguageExtension
import com.intellij.openapi.diagnostic.logger
import com.intellij.psi.PsiElement

class ClassContextProvider(private val gatherUsages: Boolean) {
class ClassContextProvider(private val gatherUsages: Boolean) : LLMQueryContextProvider<PsiElement> {
private val languageExtension = LanguageExtension<ClassContextBuilder>("cc.unitmesh.classContextBuilder")
private val providers: List<ClassContextBuilder>

Expand All @@ -19,7 +20,7 @@ class ClassContextProvider(private val gatherUsages: Boolean) {
val logger = logger<ClassContextProvider>()
}

fun from(psiElement: PsiElement): ClassContext {
override fun from(psiElement: PsiElement): ClassContext {
for (provider in providers) {
provider.getClassContext(psiElement, gatherUsages)?.let {
return it
Expand Down
@@ -1,11 +1,12 @@
package cc.unitmesh.devti.context

import cc.unitmesh.devti.context.base.LLMQueryContextProvider
import cc.unitmesh.devti.context.builder.FileContextBuilder
import com.intellij.lang.Language
import com.intellij.lang.LanguageExtension
import com.intellij.psi.PsiFile

class FileContextProvider {
class FileContextProvider: LLMQueryContextProvider<PsiFile> {
private val languageExtension: LanguageExtension<FileContextBuilder> =
LanguageExtension("cc.unitmesh.fileContextBuilder")

Expand All @@ -16,14 +17,14 @@ class FileContextProvider {
providers = registeredLanguages.mapNotNull { languageExtension.forLanguage(it) }
}

fun from(psiFile: PsiFile): FileContext {
override fun from(psiElement: PsiFile): FileContext {
for (provider in providers) {
val fileContext = provider.getFileContext(psiFile)
val fileContext = provider.getFileContext(psiElement)
if (fileContext != null) {
return fileContext
}
}

return FileContext(psiFile, psiFile.name, psiFile.virtualFile?.path!!)
return FileContext(psiElement, psiElement.name, psiElement.virtualFile?.path!!)
}
}
@@ -1,12 +1,14 @@
package cc.unitmesh.devti.context

import cc.unitmesh.devti.context.base.LLMQueryContextProvider
import cc.unitmesh.devti.context.builder.MethodContextBuilder
import com.intellij.lang.Language
import com.intellij.lang.LanguageExtension
import com.intellij.psi.PsiElement
import org.jetbrains.annotations.NotNull

class MethodContextProvider(private val includeClassContext: Boolean, private val gatherUsages: Boolean) {
class MethodContextProvider(private val includeClassContext: Boolean, private val gatherUsages: Boolean):
LLMQueryContextProvider<PsiElement> {
@NotNull
private val languageExtension: LanguageExtension<MethodContextBuilder> =
LanguageExtension("cc.unitmesh.methodContextBuilder")
Expand All @@ -20,7 +22,7 @@ class MethodContextProvider(private val includeClassContext: Boolean, private va
}

@NotNull
fun from(@NotNull psiElement: PsiElement): MethodContext {
override fun from(@NotNull psiElement: PsiElement): MethodContext {
val iterator = providers.iterator()
while (iterator.hasNext()) {
val provider = iterator.next()
Expand Down
@@ -1,5 +1,6 @@
package cc.unitmesh.devti.context

import cc.unitmesh.devti.context.base.LLMQueryContextProvider
import cc.unitmesh.devti.context.builder.VariableContextBuilder
import com.intellij.lang.Language
import com.intellij.lang.LanguageExtension
Expand All @@ -9,7 +10,7 @@ class VariableContextProvider(
private val includeMethodContext: Boolean,
private val includeClassContext: Boolean,
private val gatherUsages: Boolean
) {
): LLMQueryContextProvider<PsiElement> {
private val languageExtension: LanguageExtension<VariableContextBuilder> =
LanguageExtension("cc.unitmesh.variableContextBuilder")

Expand All @@ -20,7 +21,7 @@ class VariableContextProvider(
providers = registeredLanguages.mapNotNull(languageExtension::forLanguage)
}

fun from(psiElement: PsiElement): VariableContext {
override fun from(psiElement: PsiElement): VariableContext {
for (provider in providers) {
val variableContext =
provider.getVariableContext(psiElement, includeMethodContext, includeClassContext, gatherUsages)
Expand Down
@@ -0,0 +1,8 @@
package cc.unitmesh.devti.context.base;

import com.intellij.psi.PsiElement

interface LLMQueryContextProvider<T : PsiElement?> {
fun from(psiElement: T): LLMQueryContext
}

Expand Up @@ -10,6 +10,7 @@ import com.intellij.openapi.progress.ProgressManager
import com.intellij.openapi.progress.Task
import com.intellij.openapi.progress.impl.BackgroundableProcessIndicator
import com.intellij.openapi.project.Project
import com.intellij.psi.PsiElement
import com.intellij.psi.PsiFile
import com.intellij.psi.PsiManager
import com.intellij.psi.PsiNameIdentifierOwner
Expand Down Expand Up @@ -39,13 +40,30 @@ class LivingDocumentationIntention : AbstractChatIntention() {
if (selectedText != null) {
val owners: List<PsiNameIdentifierOwner> = findSelectedElementToDocument(editor, project, selectionModel)
for (identifierOwner in owners) {
val task: Task.Backgroundable = LivingDocumentationTask(editor, identifierOwner)

ProgressManager.getInstance()
.runProcessWithProgressAsynchronously(task, BackgroundableProcessIndicator(task))
writeForDocument(editor, identifierOwner)
}
}

val closestToCaretNamedElement: PsiNameIdentifierOwner? = getClosestToCaretNamedElement(editor)
if (closestToCaretNamedElement != null) {
writeForDocument(editor, closestToCaretNamedElement)
}
}

private fun writeForDocument(editor: Editor, element: PsiNameIdentifierOwner) {
val task: Task.Backgroundable = LivingDocumentationTask(editor, element)
ProgressManager.getInstance()
.runProcessWithProgressAsynchronously(task, BackgroundableProcessIndicator(task))
}

private fun getClosestToCaretNamedElement(editor: Editor): PsiNameIdentifierOwner? {
val element = PsiUtilBase.getElementAtCaret(editor) ?: return null
return getClosestNamedElement(element)
}

private fun getClosestNamedElement(element: PsiElement): PsiNameIdentifierOwner? {
val support = LivingDocumentation.forLanguage(element.language) ?: return null
return support.findNearestDocumentationTarget(element)
}

private fun findSelectedElementToDocument(
Expand Down

0 comments on commit 4fddd6d

Please sign in to comment.