Skip to content

Commit

Permalink
Fix/673/from queryset then custom qs method (#680)
Browse files Browse the repository at this point in the history
* Fix `MyModel.objects.filter(...).my_method()`

* Fix regression: `MyModel.objects.filter(...).my_method()` no longer worked when using from_queryset

This also fixes the self-type of the copied-over methods of the manager generated by from_queryset.
Previously it was not parameterized by the model class, but used Any.

The handling of unbound types is not tested here as I have not been able to
find a way to create a test case for it. It has been manually tested
against an internal codebase.

* Remove unneeded defer.
  • Loading branch information
syastrov committed Jul 29, 2021
1 parent 08a662e commit 8da8ab4
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 33 deletions.
45 changes: 33 additions & 12 deletions mypy_django_plugin/lib/helpers.py
Expand Up @@ -18,7 +18,6 @@
MemberExpr,
MypyFile,
NameExpr,
PlaceholderNode,
StrExpr,
SymbolNode,
SymbolTable,
Expand All @@ -33,12 +32,13 @@
DynamicClassDefContext,
FunctionContext,
MethodContext,
SemanticAnalyzerPluginInterface,
)
from mypy.plugins.common import add_method
from mypy.semanal import SemanticAnalyzer
from mypy.types import AnyType, CallableType, Instance, NoneTyp, TupleType
from mypy.types import Type as MypyType
from mypy.types import TypedDictType, TypeOfAny, UnionType
from mypy.types import TypedDictType, TypeOfAny, UnboundType, UnionType

from mypy_django_plugin.lib import fullnames
from mypy_django_plugin.lib.fullnames import WITH_ANNOTATIONS_FULLNAME
Expand Down Expand Up @@ -355,8 +355,26 @@ def build_unannotated_method_args(method_node: FuncDef) -> Tuple[List[Argument],
return prepared_arguments, return_type


def bind_or_analyze_type(t: MypyType, api: SemanticAnalyzer, module_name: Optional[str] = None) -> Optional[MypyType]:
"""Analyze a type. If an unbound type, try to look it up in the given module name.
That should hopefully give a bound type."""
if isinstance(t, UnboundType) and module_name is not None:
node = api.lookup_fully_qualified_or_none(module_name + "." + t.name)
if node is None:
return None
return node.type
else:
return api.anal_type(t)


def copy_method_to_another_class(
ctx: ClassDefContext, self_type: Instance, new_method_name: str, method_node: FuncDef
ctx: ClassDefContext,
self_type: Instance,
new_method_name: str,
method_node: FuncDef,
return_type: Optional[MypyType] = None,
original_module_name: Optional[str] = None,
) -> None:
semanal_api = get_semanal_api(ctx)
if method_node.type is None:
Expand All @@ -374,23 +392,20 @@ def copy_method_to_another_class(
semanal_api.defer()
return

arguments = []
bound_return_type = semanal_api.anal_type(method_type.ret_type, allow_placeholder=True)

assert bound_return_type is not None

if isinstance(bound_return_type, PlaceholderNode):
if return_type is None:
return_type = bind_or_analyze_type(method_type.ret_type, semanal_api, original_module_name)
if return_type is None:
return

try:
original_arguments = method_node.arguments[1:]
except AttributeError:
original_arguments = []

arguments = []
for arg_name, arg_type, original_argument in zip(
method_type.arg_names[1:], method_type.arg_types[1:], original_arguments
):
bound_arg_type = semanal_api.anal_type(arg_type)
bound_arg_type = bind_or_analyze_type(arg_type, semanal_api, original_module_name)
if bound_arg_type is None:
return

Expand All @@ -406,4 +421,10 @@ def copy_method_to_another_class(
argument.set_line(original_argument)
arguments.append(argument)

add_method(ctx, new_method_name, args=arguments, return_type=bound_return_type, self_type=self_type)
add_method(ctx, new_method_name, args=arguments, return_type=return_type, self_type=self_type)


def add_new_manager_base(api: SemanticAnalyzerPluginInterface, fullname: str) -> None:
sym = api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME)
if sym is not None and isinstance(sym.node, TypeInfo):
get_django_metadata(sym.node)["manager_bases"][fullname] = 1
15 changes: 9 additions & 6 deletions mypy_django_plugin/main.py
Expand Up @@ -53,10 +53,8 @@ def transform_form_class(ctx: ClassDefContext) -> None:
forms.make_meta_nested_class_inherit_from_any(ctx)


def add_new_manager_base(ctx: ClassDefContext) -> None:
sym = ctx.api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME)
if sym is not None and isinstance(sym.node, TypeInfo):
helpers.get_django_metadata(sym.node)["manager_bases"][ctx.cls.fullname] = 1
def add_new_manager_base_hook(ctx: ClassDefContext) -> None:
helpers.add_new_manager_base(ctx.api, ctx.cls.fullname)


def extract_django_settings_module(config_file_path: Optional[str]) -> str:
Expand Down Expand Up @@ -235,7 +233,12 @@ def get_additional_deps(self, file: MypyFile) -> List[Tuple[int, str, int]]:
related_model_module = related_model_cls.__module__
if related_model_module != file.fullname:
deps.add(self._new_dependency(related_model_module))
return list(deps) + [self._new_dependency("django_stubs_ext")] # for annotate
return list(deps) + [
# for QuerySet.annotate
self._new_dependency("django_stubs_ext"),
# For BaseManager.from_queryset
self._new_dependency("django.db.models.query"),
]

def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext], MypyType]]:
if fullname == "django.contrib.auth.get_user_model":
Expand Down Expand Up @@ -305,7 +308,7 @@ def get_base_class_hook(self, fullname: str) -> Optional[Callable[[ClassDefConte
return partial(transform_model_class, django_context=self.django_context)

if fullname in self._get_current_manager_bases():
return add_new_manager_base
return add_new_manager_base_hook

if fullname in self._get_current_form_bases():
return transform_form_class
Expand Down
88 changes: 79 additions & 9 deletions mypy_django_plugin/transformers/managers.py
@@ -1,6 +1,7 @@
from mypy.checker import fill_typevars
from mypy.nodes import GDEF, Decorator, FuncDef, MemberExpr, NameExpr, RefExpr, StrExpr, SymbolTableNode, TypeInfo
from mypy.plugin import ClassDefContext, DynamicClassDefContext
from mypy.types import AnyType, Instance, TypeOfAny
from mypy.types import CallableType, Instance, TypeVarType, UnboundType, get_proper_type

from mypy_django_plugin.lib import fullnames, helpers

Expand Down Expand Up @@ -29,15 +30,11 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte
# But it should be analyzed again, so this isn't a problem.
return

base_manager_instance = fill_typevars(base_manager_info)
assert isinstance(base_manager_instance, Instance)
new_manager_info = semanal_api.basic_new_typeinfo(
ctx.name, basetype_or_fallback=Instance(base_manager_info, [AnyType(TypeOfAny.unannotated)]), line=ctx.call.line
ctx.name, basetype_or_fallback=base_manager_instance, line=ctx.call.line
)
new_manager_info.line = ctx.call.line
new_manager_info.defn.line = ctx.call.line
new_manager_info.metaclass_type = new_manager_info.calculate_metaclass_type()

current_module = semanal_api.cur_mod_node
current_module.names[ctx.name] = SymbolTableNode(GDEF, new_manager_info, plugin_generated=True)

sym = semanal_api.lookup_fully_qualified_or_none(derived_queryset_fullname)
assert sym is not None
Expand All @@ -52,6 +49,15 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte
derived_queryset_info = sym.node
assert isinstance(derived_queryset_info, TypeInfo)

new_manager_info.line = ctx.call.line
new_manager_info.type_vars = base_manager_info.type_vars
new_manager_info.defn.type_vars = base_manager_info.defn.type_vars
new_manager_info.defn.line = ctx.call.line
new_manager_info.metaclass_type = new_manager_info.calculate_metaclass_type()

current_module = semanal_api.cur_mod_node
current_module.names[ctx.name] = SymbolTableNode(GDEF, new_manager_info, plugin_generated=True)

if len(ctx.call.args) > 1:
expr = ctx.call.args[1]
assert isinstance(expr, StrExpr)
Expand All @@ -64,11 +70,19 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte
base_manager_info.metadata["from_queryset_managers"] = {}
base_manager_info.metadata["from_queryset_managers"][custom_manager_generated_fullname] = new_manager_info.fullname

# So that the plugin will reparameterize the manager when it is constructed inside of a Model definition
helpers.add_new_manager_base(semanal_api, new_manager_info.fullname)

class_def_context = ClassDefContext(cls=new_manager_info.defn, reason=ctx.call, api=semanal_api)
self_type = Instance(new_manager_info, [])
self_type = fill_typevars(new_manager_info)
assert isinstance(self_type, Instance)
queryset_method_names = []

# we need to copy all methods in MRO before django.db.models.query.QuerySet
for class_mro_info in derived_queryset_info.mro:
if class_mro_info.fullname == fullnames.QUERYSET_CLASS_FULLNAME:
for name, sym in class_mro_info.names.items():
queryset_method_names.append(name)
break
for name, sym in class_mro_info.names.items():
if isinstance(sym.node, FuncDef):
Expand All @@ -80,3 +94,59 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte
helpers.copy_method_to_another_class(
class_def_context, self_type, new_method_name=name, method_node=func_node
)

# Gather names of all BaseManager methods
manager_method_names = []
for manager_mro_info in new_manager_info.mro:
if manager_mro_info.fullname == fullnames.BASE_MANAGER_CLASS_FULLNAME:
for name, sym in manager_mro_info.names.items():
manager_method_names.append(name)

# Copy/alter all methods in common between BaseManager/QuerySet over to the new manager if their return type is
# the QuerySet's self-type. Alter the return type to be the custom queryset, parameterized by the manager's model
# type variable.
for class_mro_info in derived_queryset_info.mro:
if class_mro_info.fullname != fullnames.QUERYSET_CLASS_FULLNAME:
continue
for name, sym in class_mro_info.names.items():
if name not in manager_method_names:
continue

if isinstance(sym.node, FuncDef):
func_node = sym.node
elif isinstance(sym.node, Decorator):
func_node = sym.node.func
else:
continue

method_type = func_node.type
if not isinstance(method_type, CallableType):
if not semanal_api.final_iteration:
semanal_api.defer()
return None
original_return_type = method_type.ret_type
if original_return_type is None:
continue

# Skip any method that doesn't return _QS
original_return_type = get_proper_type(original_return_type)
if isinstance(original_return_type, UnboundType):
if original_return_type.name != "_QS":
continue
elif isinstance(original_return_type, TypeVarType):
if original_return_type.name != "_QS":
continue
else:
continue

# Return the custom queryset parameterized by the manager's type vars
return_type = Instance(derived_queryset_info, self_type.args)

helpers.copy_method_to_another_class(
class_def_context,
self_type,
new_method_name=name,
method_node=func_node,
return_type=return_type,
original_module_name=class_mro_info.module_name,
)
14 changes: 8 additions & 6 deletions tests/typecheck/managers/querysets/test_from_queryset.yml
@@ -1,9 +1,11 @@
- case: from_queryset_with_base_manager
main: |
from myapp.models import MyModel
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]"
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.NewManager[myapp.models.MyModel]"
reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*"
reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is "builtins.str"
reveal_type(MyModel.objects.filter(id=1).queryset_method()) # N: Revealed type is "builtins.str"
reveal_type(MyModel.objects.filter(id=1)) # N: Revealed type is "myapp.models.ModelQuerySet[myapp.models.MyModel*]"
installed_apps:
- myapp
files:
Expand All @@ -23,7 +25,7 @@
- case: from_queryset_with_manager
main: |
from myapp.models import MyModel
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]"
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.NewManager[myapp.models.MyModel]"
reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*"
reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is "builtins.str"
installed_apps:
Expand Down Expand Up @@ -97,7 +99,7 @@
- case: from_queryset_with_class_inheritance
main: |
from myapp.models import MyModel
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]"
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.NewManager[myapp.models.MyModel]"
reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*"
reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is "builtins.str"
installed_apps:
Expand All @@ -121,7 +123,7 @@
- case: from_queryset_with_manager_in_another_directory_and_imports
main: |
from myapp.models import MyModel
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]"
reveal_type(MyModel().objects) # N: Revealed type is "myapp.managers.NewManager[myapp.models.MyModel]"
reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*"
reveal_type(MyModel().objects.queryset_method) # N: Revealed type is "def (param: Union[builtins.str, None] =) -> Union[builtins.str, None]"
reveal_type(MyModel().objects.queryset_method('str')) # N: Revealed type is "Union[builtins.str, None]"
Expand Down Expand Up @@ -151,7 +153,7 @@
disable_cache: true
main: |
from myapp.models import MyModel
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]"
reveal_type(MyModel().objects) # N: Revealed type is "myapp.managers.NewManager[myapp.models.MyModel]"
reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*"
reveal_type(MyModel().objects.base_queryset_method) # N: Revealed type is "def (param: Union[builtins.int, builtins.str]) -> <nothing>"
reveal_type(MyModel().objects.base_queryset_method(2)) # N: Revealed type is "<nothing>"
Expand Down Expand Up @@ -183,7 +185,7 @@
- case: from_queryset_with_decorated_queryset_methods
main: |
from myapp.models import MyModel
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]"
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.NewManager[myapp.models.MyModel]"
reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is "builtins.str"
installed_apps:
- myapp
Expand Down

0 comments on commit 8da8ab4

Please sign in to comment.