Skip to content

Commit

Permalink
Fix ForeignKey queryset filters on un-swapped models (#1495)
Browse files Browse the repository at this point in the history
* Add failing test

* Fix related_model lookups for un-swapped models

* Handle possible unregistered models

* Add `get_models_foreign_key` helper

* Refactor `get_field_related_model_cls` to raise `UnregisteredModelError`

* Remove unreachable statements

* Add missing check
  • Loading branch information
UnknownPlatypus committed Jun 16, 2023
1 parent c1cb879 commit e778561
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 66 deletions.
44 changes: 29 additions & 15 deletions mypy_django_plugin/django/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from mypy.types import AnyType, Instance, TypeOfAny, UnionType
from mypy.types import Type as MypyType

from mypy_django_plugin.exceptions import UnregisteredModelError
from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.lib.fullnames import WITH_ANNOTATIONS_FULLNAME

Expand Down Expand Up @@ -118,7 +119,19 @@ def get_model_fields(self, model_cls: Type[Model]) -> Iterator["Field[Any, Any]"
if isinstance(field, Field):
yield field

def get_model_foreign_keys(self, model_cls: Type[Model]) -> Iterator["ForeignKey[Any, Any]"]:
for field in model_cls._meta.get_fields():
if isinstance(field, ForeignKey):
yield field

def get_model_related_fields(self, model_cls: Type[Model]) -> Iterator["RelatedField[Any, Any]"]:
"""Get model forward relations"""
for field in model_cls._meta.get_fields():
if isinstance(field, RelatedField):
yield field

def get_model_relations(self, model_cls: Type[Model]) -> Iterator[ForeignObjectRel]:
"""Get model reverse relations"""
for field in model_cls._meta.get_fields():
if isinstance(field, ForeignObjectRel):
yield field
Expand All @@ -127,7 +140,7 @@ def get_field_lookup_exact_type(
self, api: TypeChecker, field: Union["Field[Any, Any]", ForeignObjectRel]
) -> MypyType:
if isinstance(field, (RelatedField, ForeignObjectRel)):
related_model_cls = field.related_model
related_model_cls = self.get_field_related_model_cls(field)
primary_key_field = self.get_primary_key_field(related_model_cls)
primary_key_type = self.get_field_get_type(api, primary_key_field, method="init")

Expand Down Expand Up @@ -210,9 +223,6 @@ def get_field_set_type_from_model_type_info(info: Optional[TypeInfo], field_name
continue

related_model = self.get_field_related_model_cls(field)
if related_model is None:
expected_types[field_name] = AnyType(TypeOfAny.from_error)
continue

if related_model._meta.proxy_for_model is not None:
related_model = related_model._meta.proxy_for_model
Expand Down Expand Up @@ -312,8 +322,6 @@ def get_field_get_type(
is_nullable = self.get_field_nullability(field, method)
if isinstance(field, RelatedField):
related_model_cls = self.get_field_related_model_cls(field)
if related_model_cls is None:
return AnyType(TypeOfAny.from_error)

if method in ("values", "values_list"):
primary_key_field = self.get_primary_key_field(related_model_cls)
Expand All @@ -327,9 +335,7 @@ def get_field_get_type(
else:
return helpers.get_private_descriptor_type(field_info, "_pyi_private_get_type", is_nullable=is_nullable)

def get_field_related_model_cls(
self, field: Union["RelatedField[Any, Any]", ForeignObjectRel]
) -> Optional[Type[Model]]:
def get_field_related_model_cls(self, field: Union["RelatedField[Any, Any]", ForeignObjectRel]) -> Type[Model]:
if isinstance(field, RelatedField):
related_model_cls = field.remote_field.model
else:
Expand All @@ -341,13 +347,15 @@ def get_field_related_model_cls(
related_model_cls = field.model
elif "." not in related_model_cls:
# same file model
related_model_fullname = field.model.__module__ + "." + related_model_cls
related_model_fullname = f"{field.model.__module__}.{related_model_cls}"
related_model_cls = self.get_model_class_by_fullname(related_model_fullname)
if related_model_cls is None:
raise UnregisteredModelError
else:
try:
related_model_cls = self.apps_registry.get_model(related_model_cls)
except LookupError:
return None
except LookupError as e:
raise UnregisteredModelError from e

return related_model_cls

Expand All @@ -363,13 +371,13 @@ def _resolve_field_from_parts(

field = currently_observed_model._meta.get_field(field_part)
if isinstance(field, RelatedField):
currently_observed_model = field.related_model
currently_observed_model = self.get_field_related_model_cls(field)
model_name = currently_observed_model._meta.model_name
if model_name is not None and field_part == (model_name + "_id"):
field = self.get_primary_key_field(currently_observed_model)

if isinstance(field, ForeignObjectRel):
currently_observed_model = field.related_model
currently_observed_model = self.get_field_related_model_cls(field)

# Guaranteed by `query.solve_lookup_type` before.
assert isinstance(field, (Field, ForeignObjectRel))
Expand Down Expand Up @@ -397,9 +405,15 @@ def solve_lookup_type(
field = query.get_meta().get_field(query_parts[0])
except FieldDoesNotExist:
return None

if len(query_parts) == 1:
return [], [query_parts[0]], False
sub_query = Query(field.related_model).solve_lookup_type("__")

if not isinstance(field, (RelatedField, ForeignObjectRel)):
return None

related_model = self.get_field_related_model_cls(field)
sub_query = Query(related_model).solve_lookup_type("__".join(query_parts[1:]))
entire_query_parts = [query_parts[0], *sub_query[1]]
return sub_query[0], entire_query_parts, sub_query[2]

Expand Down
2 changes: 2 additions & 0 deletions mypy_django_plugin/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class UnregisteredModelError(Exception):
"""The requested model is not registered"""
26 changes: 13 additions & 13 deletions mypy_django_plugin/main.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import itertools
import sys
from functools import partial
from typing import Callable, Dict, List, Optional, Tuple, Type

from django.db.models.fields.related import RelatedField
from mypy.modulefinder import mypy_path
from mypy.nodes import MypyFile, TypeInfo
from mypy.options import Options
Expand All @@ -20,6 +20,7 @@
import mypy_django_plugin.transformers.orm_lookups
from mypy_django_plugin.config import DjangoPluginConfig
from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.exceptions import UnregisteredModelError
from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.transformers import fields, forms, init_create, meta, querysets, request, settings
from mypy_django_plugin.transformers.functional import resolve_str_promise_attribute
Expand Down Expand Up @@ -147,23 +148,22 @@ def get_additional_deps(self, file: MypyFile) -> List[Tuple[int, str, int]]:
if not defined_model_classes:
return []
deps = set()

for model_class in defined_model_classes:
# forward relations
for field in self.django_context.get_model_fields(model_class):
if isinstance(field, RelatedField):
for field in itertools.chain(
# forward relations
self.django_context.get_model_related_fields(model_class),
# reverse relations - `related_objects` is private API (according to docstring)
model_class._meta.related_objects, # type: ignore[attr-defined]
):
try:
related_model_cls = self.django_context.get_field_related_model_cls(field)
if related_model_cls is None:
continue
related_model_module = related_model_cls.__module__
if related_model_module != file.fullname:
deps.add(self._new_dependency(related_model_module))
# reverse relations
# `related_objects` is private API (according to docstring)
for relation in model_class._meta.related_objects: # type: ignore[attr-defined]
related_model_cls = self.django_context.get_field_related_model_cls(relation)
except UnregisteredModelError:
continue
related_model_module = related_model_cls.__module__
if related_model_module != file.fullname:
deps.add(self._new_dependency(related_model_module))

return list(deps) + [
# for QuerySet.annotate
self._new_dependency("django_stubs_ext"),
Expand Down
6 changes: 4 additions & 2 deletions mypy_django_plugin/transformers/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from mypy.types import Type as MypyType

from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.exceptions import UnregisteredModelError
from mypy_django_plugin.lib import fullnames, helpers

if TYPE_CHECKING:
Expand Down Expand Up @@ -59,8 +60,9 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context

assert isinstance(current_field, RelatedField)

related_model_cls = django_context.get_field_related_model_cls(current_field)
if related_model_cls is None:
try:
related_model_cls = django_context.get_field_related_model_cls(current_field)
except UnregisteredModelError:
return AnyType(TypeOfAny.from_error)

default_related_field_type = set_descriptor_types_for_field(ctx)
Expand Down
64 changes: 31 additions & 33 deletions mypy_django_plugin/transformers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from django.db.models import Manager, Model
from django.db.models.fields import DateField, DateTimeField, Field
from django.db.models.fields.related import ForeignKey
from django.db.models.fields.reverse_related import ForeignObjectRel, OneToOneRel
from mypy.checker import TypeChecker
from mypy.nodes import ARG_STAR2, Argument, AssignmentStmt, CallExpr, Context, NameExpr, TypeInfo, Var
Expand All @@ -15,6 +14,7 @@

from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.errorcodes import MANAGER_MISSING
from mypy_django_plugin.exceptions import UnregisteredModelError
from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.lib.fullnames import ANNOTATIONS_FULLNAME, ANY_ATTR_ALLOWED_CLASS_FULLNAME, MODEL_CLASS_FULLNAME
from mypy_django_plugin.transformers import fields
Expand Down Expand Up @@ -234,41 +234,41 @@ def run_with_model_cls(self, model_cls: Type[Model]) -> None:

class AddRelatedModelsId(ModelClassInitializer):
def run_with_model_cls(self, model_cls: Type[Model]) -> None:
for field in model_cls._meta.get_fields():
if isinstance(field, ForeignKey):
for field in self.django_context.get_model_foreign_keys(model_cls):
try:
related_model_cls = self.django_context.get_field_related_model_cls(field)
if related_model_cls is None:
error_context: Context = self.ctx.cls
field_sym = self.ctx.cls.info.get(field.name)
if field_sym is not None and field_sym.node is not None:
error_context = field_sym.node
self.api.fail(
f"Cannot find model {field.related_model!r} referenced in field {field.name!r}",
ctx=error_context,
)
self.add_new_node_to_model_class(field.attname, AnyType(TypeOfAny.explicit))
continue
except UnregisteredModelError:
error_context: Context = self.ctx.cls
field_sym = self.ctx.cls.info.get(field.name)
if field_sym is not None and field_sym.node is not None:
error_context = field_sym.node
self.api.fail(
f"Cannot find model {field.related_model!r} referenced in field {field.name!r}",
ctx=error_context,
)
self.add_new_node_to_model_class(field.attname, AnyType(TypeOfAny.explicit))
continue

if related_model_cls._meta.abstract:
continue
if related_model_cls._meta.abstract:
continue

rel_target_field = self.django_context.get_related_target_field(related_model_cls, field)
if not rel_target_field:
continue
rel_target_field = self.django_context.get_related_target_field(related_model_cls, field)
if not rel_target_field:
continue

try:
field_info = self.lookup_class_typeinfo_or_incomplete_defn_error(rel_target_field.__class__)
except helpers.IncompleteDefnException as exc:
if not self.api.final_iteration:
raise exc
else:
continue
try:
field_info = self.lookup_class_typeinfo_or_incomplete_defn_error(rel_target_field.__class__)
except helpers.IncompleteDefnException as exc:
if not self.api.final_iteration:
raise exc
else:
continue

is_nullable = self.django_context.get_field_nullability(field, None)
set_type, get_type = get_field_descriptor_types(
field_info, is_set_nullable=is_nullable, is_get_nullable=is_nullable
)
self.add_new_node_to_model_class(field.attname, Instance(field_info, [set_type, get_type]))
is_nullable = self.django_context.get_field_nullability(field, None)
set_type, get_type = get_field_descriptor_types(
field_info, is_set_nullable=is_nullable, is_get_nullable=is_nullable
)
self.add_new_node_to_model_class(field.attname, Instance(field_info, [set_type, get_type]))


class AddManagers(ModelClassInitializer):
Expand Down Expand Up @@ -448,8 +448,6 @@ def run_with_model_cls(self, model_cls: Type[Model]) -> None:
continue

related_model_cls = self.django_context.get_field_related_model_cls(relation)
if related_model_cls is None:
continue

try:
related_model_info = self.lookup_class_typeinfo_or_incomplete_defn_error(related_model_cls)
Expand Down
6 changes: 5 additions & 1 deletion mypy_django_plugin/transformers/orm_lookups.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from mypy.types import Type as MypyType

from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.exceptions import UnregisteredModelError
from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.lib.helpers import is_annotated_model_fullname

Expand Down Expand Up @@ -36,7 +37,10 @@ def typecheck_queryset_filter(ctx: MethodContext, django_context: DjangoContext)
if is_annotated_model_fullname(model_cls_fullname):
lookup_type = AnyType(TypeOfAny.implementation_artifact)
else:
lookup_type = django_context.resolve_lookup_expected_type(ctx, model_cls, lookup_kwarg)
try:
lookup_type = django_context.resolve_lookup_expected_type(ctx, model_cls, lookup_kwarg)
except UnregisteredModelError:
lookup_type = AnyType(TypeOfAny.from_error)
# Managers as provided_type is not supported yet
if isinstance(provided_type, Instance) and helpers.has_any_of_bases(
provided_type.type, (fullnames.MANAGER_CLASS_FULLNAME, fullnames.QUERYSET_CLASS_FULLNAME)
Expand Down
2 changes: 0 additions & 2 deletions mypy_django_plugin/transformers/querysets.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@ def get_field_type_from_lookup(
lookup_field, ForeignObjectRel
):
related_model_cls = django_context.get_field_related_model_cls(lookup_field)
if related_model_cls is None:
return AnyType(TypeOfAny.from_error)
lookup_field = django_context.get_primary_key_field(related_model_cls)

field_get_type = django_context.get_field_get_type(helpers.get_typechecker_api(ctx), lookup_field, method=method)
Expand Down
32 changes: 32 additions & 0 deletions tests/typecheck/fields/test_related.yml
Original file line number Diff line number Diff line change
Expand Up @@ -938,6 +938,7 @@
- case: test_fails_if_app_label_is_unknown_in_relation_field
main: |
from installed.models import InstalledModel
InstalledModel.objects.filter(non_installed__isnull=True)
installed_apps:
- installed
files:
Expand Down Expand Up @@ -975,3 +976,34 @@
class Book(PrintedGood):
name = models.CharField()
- case: test_foreign_key_to_as_string_filter_on_abstract
main: |
from myapp.models import Book, Publisher
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models/__init__.py
content: |
from django.db import models
from django.db.models.query import QuerySet
class Publisher(models.Model):
name = models.CharField()
class MyModel(models.Model):
pass
class PrintedGood(MyModel):
publisher = models.ForeignKey(to="myapp.Publisher", on_delete=models.CASCADE)
@property
def siblings(self) -> QuerySet['PrintedGood']:
return self.__class__.objects.filter(publisher=self.publisher)
class Meta:
abstract = True
class Book(PrintedGood):
name = models.CharField()

0 comments on commit e778561

Please sign in to comment.