Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce unnecessary resolve in type providers #11

Merged
merged 2 commits into from
Mar 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion resources/META-INF/plugin.xml
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
<idea-plugin url="https://github.com/strawberry-graphql/strawberry-pycharm-plugin" require-restart="true">
<id>rocks.strawberry</id>
<name>Strawberry GraphQL</name>
<version>0.0.10</version>
<version>0.0.11</version>
<vendor email="hello@strawberry.rocks">Strawberry GraphQL</vendor>
<change-notes><![CDATA[
<h2>version 0.0.11</h2>
<p>Bugfix</p>
<ul>
<li>Reduce unnecessary resolve in type providers</li>
</ul>
<h2>version 0.0.10</h2>
<p>Features</p>
<ul>
Expand Down
54 changes: 9 additions & 45 deletions src/rocks/strawberry/Strawberry.kt
Original file line number Diff line number Diff line change
@@ -1,51 +1,15 @@
package rocks.strawberry

import com.intellij.psi.PsiElement
import com.intellij.psi.ResolveResult
import com.intellij.psi.util.QualifiedName
import com.jetbrains.python.psi.PyClass
import com.jetbrains.python.psi.PyDecoratable
import com.jetbrains.python.psi.PyReferenceExpression
import com.jetbrains.python.psi.PyUtil
import com.jetbrains.python.psi.resolve.PyResolveContext
import com.jetbrains.python.psi.resolve.PyResolveUtil
import com.jetbrains.python.psi.types.PyClassType
import com.jetbrains.python.psi.types.PyType
import com.jetbrains.python.psi.types.PyUnionType
import com.jetbrains.python.psi.types.TypeEvalContext
import com.jetbrains.python.psi.*

val DATACLASS_QUALIFIED_NAME = QualifiedName.fromDottedString("strawberry.type")
const val DATACLASS_SHORT_NAME = "strawberry.type"
const val DATACLASS_LONG_NAME = "strawberry.object_type.type"

val DATACLASS_NAMES = listOf(
DATACLASS_SHORT_NAME,
DATACLASS_LONG_NAME
)

fun hasDecorator(pyDecoratable: PyDecoratable, refName: QualifiedName): Boolean {
return pyDecoratable.decoratorList?.decorators?.mapNotNull { it.callee as? PyReferenceExpression }?.any {
PyResolveUtil.resolveImportedElementQNameLocally(it).any { decoratorQualifiedName ->
decoratorQualifiedName == refName
}
} ?: false
}


fun isDataclass(pyClass: PyClass): Boolean {
return hasDecorator(pyClass, DATACLASS_QUALIFIED_NAME)
}

fun getResolveElements(referenceExpression: PyReferenceExpression, context: TypeEvalContext): Array<ResolveResult> {
return PyResolveContext.defaultContext(context).let {
referenceExpression.getReference(it).multiResolve(false)
}

}

fun getResolvedPsiElements(referenceExpression: PyReferenceExpression, context: TypeEvalContext): List<PsiElement> {
return getResolveElements(referenceExpression, context).let { PyUtil.filterTopPriorityResults(it) }
}


fun getPyClassTypeByPyTypes(pyType: PyType): List<PyClassType> {
return when (pyType) {
is PyUnionType -> pyType.members.mapNotNull { it }.flatMap { getPyClassTypeByPyTypes(it) }
is PyClassType -> listOf(pyType)
else -> listOf()
}
}
val DATACLASS_QUALIFIED_NAME = QualifiedName.fromDottedString(DATACLASS_SHORT_NAME)
internal val PyFunction.isDataclass: Boolean get() = qualifiedName in DATACLASS_NAMES
70 changes: 5 additions & 65 deletions src/rocks/strawberry/StrawberryDataclassTypeProvider.kt
Original file line number Diff line number Diff line change
@@ -1,74 +1,14 @@
package rocks.strawberry

import com.intellij.psi.PsiElement
import com.jetbrains.python.codeInsight.stdlib.PyDataclassTypeProvider
import com.jetbrains.python.psi.*
import com.jetbrains.python.psi.impl.PyCallExpressionImpl
import com.jetbrains.python.psi.impl.PyCallExpressionNavigator
import com.jetbrains.python.psi.types.*

class StrawberryDataclassTypeProvider : PyTypeProviderBase() {
private val pyDataclassTypeProvider = PyDataclassTypeProvider()

override fun getReferenceExpressionType(
referenceExpression: PyReferenceExpression,
context: TypeEvalContext
): PyType? {
return getDataclass(referenceExpression, context)
}


private fun getDataclassCallableType(
referenceTarget: PsiElement,
context: TypeEvalContext,
callSite: PyCallExpression?
): PyCallableType? {
return pyDataclassTypeProvider.getReferenceType(
referenceTarget,
context,
callSite ?: PyCallExpressionImpl(referenceTarget.node)
)?.get() as? PyCallableType
}

private fun getDataclassType(
referenceTarget: PsiElement,
context: TypeEvalContext,
pyReferenceExpression: PyReferenceExpression,
definition: Boolean
): PyType? {
val callSite = PyCallExpressionNavigator.getPyCallExpressionByCallee(pyReferenceExpression)
val dataclassCallableType = getDataclassCallableType(referenceTarget, context, callSite) ?: return null
val dataclassType = (dataclassCallableType).getReturnType(context) as? PyClassType ?: return null
if (!isDataclass(dataclassType.pyClass)) return null

return when {
callSite is PyCallExpression && definition -> dataclassCallableType
definition -> dataclassType.toClass()
else -> dataclassType
override fun getCallableType(callable: PyCallable, context: TypeEvalContext): PyType? {
if (callable is PyFunction && callable.isDataclass) {
// Drop fake dataclass return type
return PyCallableTypeImpl(callable.getParameters(context), null)
}
}


private fun getDataclass(referenceExpression: PyReferenceExpression, context: TypeEvalContext): PyType? {
return getResolvedPsiElements(referenceExpression, context)
.asSequence()
.mapNotNull {
when {
it is PyClass && isDataclass(it) ->
getDataclassType(it, context, referenceExpression, true)
it is PyTargetExpression -> (it as? PyTypedElement)
?.let { pyTypedElement -> context.getType(pyTypedElement) }
?.let { pyType -> getPyClassTypeByPyTypes(pyType) }
?.filter { pyClassType -> isDataclass(pyClassType.pyClass) }
?.mapNotNull { pyClassType ->
getDataclassType(pyClassType.pyClass,
context,
referenceExpression,
pyClassType.isDefinition)
}
?.firstOrNull()
else -> null
}
}.firstOrNull()
return super.getCallableType(callable, context)
}
}