Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions mypy_django_plugin/lib/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from mypy.types import (
AnyType,
CallableType,
ExtraAttrs,
Instance,
LiteralType,
NoneTyp,
Expand Down Expand Up @@ -694,3 +695,31 @@ def get_model_from_expression(

def fill_manager(manager: TypeInfo, typ: MypyType) -> Instance:
return Instance(manager, [typ] if manager.is_generic() else [])


def merge_extra_attrs(
base_extra_attrs: ExtraAttrs | None,
*,
new_attrs: dict[str, MypyType] | None = None,
new_immutable: set[str] | None = None,
) -> ExtraAttrs:
"""
Create a new `ExtraAttrs` by merging new attributes/immutable fields into a base.

If base_extra_attrs is None, creates a fresh ExtraAttrs with only the new values.
"""
if base_extra_attrs:
return ExtraAttrs(
attrs={**base_extra_attrs.attrs, **new_attrs} if new_attrs is not None else base_extra_attrs.attrs.copy(),
immutable=(
base_extra_attrs.immutable | new_immutable
if new_immutable is not None
else base_extra_attrs.immutable.copy()
),
mod_name=None,
)
return ExtraAttrs(
attrs=new_attrs or {},
immutable=new_immutable,
mod_name=None,
)
25 changes: 2 additions & 23 deletions mypy_django_plugin/transformers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,7 @@
from mypy.plugins import common
from mypy.semanal import SemanticAnalyzer
from mypy.typeanal import TypeAnalyser
from mypy.types import (
AnyType,
ExtraAttrs,
Instance,
ProperType,
TypedDictType,
TypeOfAny,
TypeType,
TypeVarType,
get_proper_type,
)
from mypy.types import AnyType, Instance, ProperType, TypedDictType, TypeOfAny, TypeType, TypeVarType, get_proper_type
from mypy.types import Type as MypyType
from mypy.typevars import fill_typevars, fill_typevars_with_any

Expand Down Expand Up @@ -1166,18 +1156,7 @@ def get_annotated_type(
"""
Get a model type that can be used to represent an annotated model
"""
if model_type.extra_attrs:
extra_attrs = ExtraAttrs(
attrs={**model_type.extra_attrs.attrs, **fields_dict.items},
immutable=model_type.extra_attrs.immutable.copy(),
mod_name=None,
)
else:
extra_attrs = ExtraAttrs(
attrs=fields_dict.items,
immutable=None,
mod_name=None,
)
extra_attrs = helpers.merge_extra_attrs(model_type.extra_attrs, new_attrs=fields_dict.items)

annotated_model: TypeInfo | None
if helpers.is_annotated_model(model_type.type):
Expand Down
67 changes: 60 additions & 7 deletions mypy_django_plugin/transformers/querysets.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,13 @@ def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context:
row_type = get_values_list_row_type(
ctx, django_context, django_model.cls, is_annotated=django_model.is_annotated, flat=flat, named=named
)
return default_return_type.copy_modified(args=[django_model.typ, row_type])
ret = default_return_type.copy_modified(args=[django_model.typ, row_type])
if not named and (field_lookups := resolve_field_lookups(ctx.args[0], django_context)):
# For non-named values_list, the row type does not encode column names.
# Attach selected field names to the returned QuerySet instance so that
# subsequent annotate() can make an informed decision about name conflicts.
ret.extra_attrs = helpers.merge_extra_attrs(ret.extra_attrs, new_immutable=set(field_lookups))
return ret


def gather_kwargs(ctx: MethodContext) -> dict[str, MypyType] | None:
Expand Down Expand Up @@ -233,10 +239,11 @@ def extract_proper_type_queryset_annotate(ctx: MethodContext, django_context: Dj
return ctx.default_return_type

api = helpers.get_typechecker_api(ctx)

expression_types = {
attr_name: typ
for attr_name, typ in gather_expression_types(ctx).items()
if check_valid_attr_value(ctx, django_model, attr_name)
if check_valid_attr_value(ctx, django_context, django_model, attr_name)
}

annotated_type: ProperType = django_model.typ
Expand Down Expand Up @@ -426,22 +433,66 @@ def gather_flat_args(ctx: MethodContext) -> list[tuple[Expression | None, Proper
return lookups


def _get_selected_fields_from_queryset_type(qs_type: Instance) -> set[str] | None:
"""
Derive selected field names from a QuerySet type.

Sources:
- values(): encoded in the row TypedDict keys
- values_list(named=True): row is a NamedTuple; extract field names from fallback TypeInfo
- values_list(named=False): stored in qs_type.extra_attrs.immutable
"""
if len(qs_type.args) > 1:
row_type = get_proper_type(qs_type.args[1])
if isinstance(row_type, Instance) and helpers.is_model_type(row_type.type):
return None
if isinstance(row_type, TypedDictType):
return set(row_type.items.keys())
if isinstance(row_type, TupleType):
if row_type.partial_fallback.type.has_base("typing.NamedTuple"):
return {name for name, sym in row_type.partial_fallback.type.names.items() if sym.plugin_generated}
else:
return set()
return set()

# Fallback to explicit metadata attached to the QuerySet Instance
if qs_type.extra_attrs and qs_type.extra_attrs.immutable and isinstance(qs_type.extra_attrs.immutable, set):
return qs_type.extra_attrs.immutable

return None


def check_valid_attr_value(
ctx: MethodContext, model: DjangoModel, attr_name: str, new_attrs: dict[str, MypyType] | None = None
ctx: MethodContext,
django_context: DjangoContext,
model: DjangoModel,
attr_name: str,
*,
new_attr_names: set[str] | None = None,
) -> bool:
"""
Check if adding `attr_name` would conflict with existing symbols on `model`.

Args:
- model: The Django model being analyzed
- attr_name: The name of the attribute to be added
- new_attrs: A mapping of field names to types currently being added to the model
- new_attr_names: A mapping of field names to types currently being added to the model
"""
deselected_fields: set[str] | None = None
if isinstance(ctx.type, Instance):
selected_fields = _get_selected_fields_from_queryset_type(ctx.type)
if selected_fields is not None:
model_field_names = {f.name for f in django_context.get_model_fields(model.cls)}
deselected_fields = model_field_names - selected_fields
new_attr_names = new_attr_names or set()
new_attr_names.update(selected_fields - model_field_names)

is_conflicting_attr_value = bool(
# 1. Conflict with another symbol on the model.
# 1. Conflict with another symbol on the model (If not de-selected via a prior .values/.values_list call).
# Ex:
# User.objects.prefetch_related(Prefetch(..., to_attr="id"))
model.typ.type.get(attr_name)
and (deselected_fields is None or attr_name not in deselected_fields)
# 2. Conflict with a previous annotation.
# Ex:
# User.objects.annotate(foo=...).prefetch_related(Prefetch(...,to_attr="foo"))
Expand All @@ -453,7 +504,7 @@ def check_valid_attr_value(
# Prefetch("groups", Group.objects.filter(name="test"), to_attr="new_attr"),
# Prefetch("groups", Group.objects.all(), to_attr="new_attr"), # E: Not OK!
# )
or (new_attrs is not None and attr_name in new_attrs)
or (new_attr_names is not None and attr_name in new_attr_names)
)
if is_conflicting_attr_value:
ctx.api.fail(
Expand Down Expand Up @@ -585,7 +636,9 @@ def extract_prefetch_related_annotations(ctx: MethodContext, django_context: Dja
except (FieldError, LookupsAreUnsupported):
pass

if to_attr and check_valid_attr_value(ctx, qs_model, to_attr, new_attrs):
if to_attr and check_valid_attr_value(
ctx, django_context, qs_model, to_attr, new_attr_names=set(new_attrs.keys())
):
new_attrs[to_attr] = api.named_generic_type(
"builtins.list",
[elem_model if elem_model is not None else AnyType(TypeOfAny.special_form)],
Expand Down
69 changes: 69 additions & 0 deletions tests/typecheck/managers/querysets/test_annotate.yml
Original file line number Diff line number Diff line change
Expand Up @@ -425,3 +425,72 @@
content: |
from django.db import models
class Blog(models.Model): pass

- case: test_annotate_existing_field
installed_apps:
- django.contrib.auth
main: |
from typing import Any
from typing_extensions import TypedDict
from django.db import models
from django.db.models import Prefetch, F, QuerySet
from django.contrib.auth.models import User, Group
from django_stubs_ext.annotations import WithAnnotations

# Error on existing field / remapped field / previous annotation
User.objects.annotate(username=F("username")) # E: Attribute "username" already defined on "django.contrib.auth.models.User" [no-redef]
User.objects.values(foo=F("id")).annotate(foo=F("username")) # E: Attribute "foo" already defined on "django.contrib.auth.models.User" [no-redef]
User.objects.annotate(computed=F('id')).annotate(computed=F('username')) # E: Attribute "computed" already defined on "django.contrib.auth.models.User@AnnotatedWith[TypedDict({'computed': Any})]" [no-redef]

# Should be ok if filtered with `values` / `values_list`
User.objects.values_list("id").annotate(username=F("username"))
User.objects.values("id").annotate(username=F("username"))
User.objects.values_list("id", named=True).annotate(username=F("username"))

def get_with_values_list() -> QuerySet[User, tuple[int]]:
return User.objects.values_list("id")

get_with_values_list().annotate(username=F("username"))

class OnlyIdDict(TypedDict):
id: int

def get_with_values() -> QuerySet[User, OnlyIdDict]:
return User.objects.values("id")

get_with_values().annotate(username=F("username"))

# But still cause issue if overlapping with other symbols (methods, ...)
User.objects.values_list("id").annotate(get_full_name=F("username")) # E: Attribute "get_full_name" already defined on "django.contrib.auth.models.User" [no-redef]
User.objects.values("id").annotate(get_full_name=F("username")) # E: Attribute "get_full_name" already defined on "django.contrib.auth.models.User" [no-redef]

# No false positive on approximative row types
tuple_any_row: models.QuerySet[User, tuple[Any, ...]]
tuple_any_row.annotate(username=F("username"))

any_row: models.QuerySet[User, Any]
any_row.annotate(username=F("username"))

dict_row: models.QuerySet[User, dict]
dict_row.annotate(username=F("username"))

# Ensure collisions with methods are still errors in approximate contexts
get_with_values_list().annotate(get_full_name=F("username")) # E: Attribute "get_full_name" already defined on "django.contrib.auth.models.User" [no-redef]
tuple_any_row.annotate(get_full_name=F("username")) # E: Attribute "get_full_name" already defined on "django.contrib.auth.models.User" [no-redef]
any_row.annotate(get_full_name=F("username")) # E: Attribute "get_full_name" already defined on "django.contrib.auth.models.User" [no-redef]
dict_row.annotate(get_full_name=F("username")) # E: Attribute "get_full_name" already defined on "django.contrib.auth.models.User" [no-redef]

# Named values_list

# Test name collision with model methods in named values_list - should still error
User.objects.values_list("id", named=True).annotate(get_full_name=F("username")) # E: Attribute "get_full_name" already defined on "django.contrib.auth.models.User" [no-redef]
User.objects.values_list("username", named=True).annotate(get_full_name=F("first_name")) # E: Attribute "get_full_name" already defined on "django.contrib.auth.models.User" [no-redef]
User.objects.values_list("id", named=True).annotate(is_anonymous=F("username")) # E: Attribute "is_anonymous" already defined on "django.contrib.auth.models.User" [no-redef]

# Test name collision with model fields when field is NOT in values_list - should be OK
User.objects.values_list("id", named=True).annotate(username=F("first_name"))
User.objects.values_list("first_name", named=True).annotate(username=F("last_name"))

# Test name collision with model fields when field IS in values_list - should error
User.objects.values_list("username", named=True).annotate(username=F("first_name")) # E: Attribute "username" already defined on "django.contrib.auth.models.User" [no-redef]
User.objects.values_list("id", "username", named=True).annotate(username=F("first_name")) # E: Attribute "username" already defined on "django.contrib.auth.models.User" [no-redef]
4 changes: 2 additions & 2 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading