Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.13.2
hooks:
- id: ruff
- id: ruff-check
args: ["--fix", "--exit-non-zero-on-fix"]
- id: ruff-format
- repo: https://github.com/codespell-project/codespell
Expand Down
17 changes: 6 additions & 11 deletions mypy_django_plugin/django/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,26 +88,23 @@ def get_field_type_from_model_type_info(info: TypeInfo | None, field_name: str)
if not isinstance(field_type, Instance):
return None
# Field declares a set and a get type arg. Fallback to `None` when we can't find any args
elif len(field_type.args) != 2:
if len(field_type.args) != 2:
return None
else:
return field_type
return field_type


def _get_field_set_type_from_model_type_info(info: TypeInfo | None, field_name: str) -> MypyType | None:
field_type = get_field_type_from_model_type_info(info, field_name)
if field_type is not None:
return field_type.args[0]
else:
return None
return None


def _get_field_get_type_from_model_type_info(info: TypeInfo | None, field_name: str) -> MypyType | None:
field_type = get_field_type_from_model_type_info(info, field_name)
if field_type is not None:
return field_type.args[1]
else:
return None
return None


class DjangoContext:
Expand Down Expand Up @@ -188,8 +185,7 @@ def get_related_target_field(
if not isinstance(rel_field, Field):
return None # Not supported
return rel_field
else:
return self.get_primary_key_field(related_model_cls)
return self.get_primary_key_field(related_model_cls)

def get_primary_key_field(self, model_cls: type[Model]) -> "Field[Any, Any]":
for field in model_cls._meta.get_fields():
Expand Down Expand Up @@ -365,8 +361,7 @@ def get_field_get_type(
return AnyType(TypeOfAny.unannotated)

return Instance(model_info, [])
else:
return helpers.get_private_descriptor_type(field_info, "_pyi_private_get_type", is_nullable=is_nullable)
return helpers.get_private_descriptor_type(field_info, "_pyi_private_get_type", is_nullable=is_nullable)

def get_field_related_model_cls(self, field: Union["RelatedField[Any, Any]", ForeignObjectRel]) -> type[Model]:
if isinstance(field, RelatedField):
Expand Down
8 changes: 3 additions & 5 deletions mypy_django_plugin/lib/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,7 @@ def lookup_class_typeinfo(api: TypeChecker, klass: type | None) -> TypeInfo | No
return None

fullname = get_class_fullname(klass)
field_info = lookup_fully_qualified_typeinfo(api, fullname)
return field_info
return lookup_fully_qualified_typeinfo(api, fullname)


def get_class_fullname(klass: type) -> str:
Expand Down Expand Up @@ -536,13 +535,12 @@ def make_typeddict(
fallback_type = api.named_generic_type("typing._TypedDict", [])
else:
fallback_type = api.named_type("typing._TypedDict", [])
typed_dict_type = TypedDictType(
return TypedDictType(
fields,
required_keys=required_keys,
readonly_keys=readonly_keys,
fallback=fallback_type,
)
return typed_dict_type


def resolve_string_attribute_value(attr_expr: Expression, django_context: "DjangoContext") -> str | None:
Expand Down Expand Up @@ -644,7 +642,7 @@ def resolve_lazy_reference(
model_info = lookup_fully_qualified_typeinfo(api, fullname)
if model_info is not None:
return model_info
elif isinstance(api, SemanticAnalyzer) and not api.final_iteration:
if isinstance(api, SemanticAnalyzer) and not api.final_iteration:
# Getting this far, where Django matched the reference but we still can't
# find it, we want to defer
api.defer()
Expand Down
8 changes: 3 additions & 5 deletions mypy_django_plugin/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@ def _get_current_form_bases(self) -> dict[str, int]:
bases[fullnames.FORM_CLASS_FULLNAME] = 1
bases[fullnames.MODELFORM_CLASS_FULLNAME] = 1
return bases
else:
return {}
return {}

def _get_typeinfo_or_none(self, class_name: str) -> TypeInfo | None:
sym = self.lookup_fully_qualified(class_name)
Expand Down Expand Up @@ -207,8 +206,7 @@ def get_customize_class_mro_hook(self, fullname: str) -> Callable[[ClassDefConte
info = self._get_typeinfo_or_none(fullname)
if info and info.has_base(fullnames.BASE_MANAGER_CLASS_FULLNAME):
return reparametrize_any_manager_hook
else:
return None
return None

def get_metaclass_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None:
if fullname == fullnames.MODEL_METACLASS_FULLNAME:
Expand Down Expand Up @@ -277,7 +275,7 @@ def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Mypy
def get_type_analyze_hook(self, fullname: str) -> Callable[[AnalyzeTypeContext], MypyType] | None:
if fullname in fullnames.ANNOTATED_TYPES_FULLNAMES:
return partial(handle_annotated_type, fullname=fullname)
elif fullname == "django.contrib.auth.models._User":
if fullname == "django.contrib.auth.models._User":
return partial(get_user_model, django_context=self.django_context)
return None

Expand Down
7 changes: 3 additions & 4 deletions mypy_django_plugin/transformers/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def drop_combinable(_type: MypyType) -> MypyType | None:
_type = get_proper_type(_type)
if isinstance(_type, Instance) and _type.type.has_base(fullnames.COMBINABLE_EXPRESSION_FULLNAME):
return None
elif isinstance(_type, UnionType):
if isinstance(_type, UnionType):
items_without_combinable = []
for item in _type.items:
reduced = drop_combinable(item)
Expand All @@ -202,10 +202,9 @@ def drop_combinable(_type: MypyType) -> MypyType | None:
is_evaluated=_type.is_evaluated,
uses_pep604_syntax=_type.uses_pep604_syntax,
)
elif len(items_without_combinable) == 1:
if len(items_without_combinable) == 1:
return items_without_combinable[0]
else:
return None
return None

return _type

Expand Down
21 changes: 9 additions & 12 deletions mypy_django_plugin/transformers/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def _process_dynamic_method(
def _get_funcdef_type(definition: Node | None) -> ProperType | None:
if isinstance(definition, FuncBase):
return definition.type
elif isinstance(definition, Decorator):
if isinstance(definition, Decorator):
return definition.func.type
return None

Expand Down Expand Up @@ -254,17 +254,17 @@ def _replace_type_var(ret_type: MypyType, to_replace: str, replace_by: MypyType)
return ret_type.copy_modified(
args=tuple(_replace_type_var(item, to_replace, replace_by) for item in ret_type.args)
)
elif isinstance(ret_type, TypeType):
if isinstance(ret_type, TypeType):
return TypeType.make_normalized(
_replace_type_var(ret_type.item, to_replace, replace_by),
line=ret_type.line,
column=ret_type.column,
)
elif isinstance(ret_type, TupleType):
if isinstance(ret_type, TupleType):
return ret_type.copy_modified(
items=[_replace_type_var(item, to_replace, replace_by) for item in ret_type.items]
)
elif isinstance(ret_type, UnionType):
if isinstance(ret_type, UnionType):
return UnionType.make_union(
items=[_replace_type_var(item, to_replace, replace_by) for item in ret_type.items],
line=ret_type.line,
Expand All @@ -288,7 +288,7 @@ def resolve_manager_method(ctx: AttributeContext) -> MypyType:
default_attr_type = get_proper_type(ctx.default_attr_type)
if not isinstance(default_attr_type, AnyType):
return ctx.default_attr_type
elif default_attr_type.type_of_any != TypeOfAny.implementation_artifact:
if default_attr_type.type_of_any != TypeOfAny.implementation_artifact:
return ctx.default_attr_type

# (Current state is:) We wouldn't end up here when looking up a method from a custom _manager_.
Expand All @@ -303,18 +303,15 @@ def resolve_manager_method(ctx: AttributeContext) -> MypyType:

if isinstance(ctx.type, Instance):
return resolve_manager_method_from_instance(instance=ctx.type, method_name=method_name, ctx=ctx)
elif isinstance(ctx.type, UnionType) and all(
isinstance(get_proper_type(item), Instance) for item in ctx.type.items
):
if isinstance(ctx.type, UnionType) and all(isinstance(get_proper_type(item), Instance) for item in ctx.type.items):
resolved = tuple(
resolve_manager_method_from_instance(instance=instance, method_name=method_name, ctx=ctx)
for item in ctx.type.items
if isinstance((instance := get_proper_type(item)), Instance)
)
return UnionType(resolved)
else:
ctx.api.fail(f'Unable to resolve return type of queryset/manager method "{method_name}"', ctx.context)
return AnyType(TypeOfAny.from_error)
ctx.api.fail(f'Unable to resolve return type of queryset/manager method "{method_name}"', ctx.context)
return AnyType(TypeOfAny.from_error)


def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefContext) -> None:
Expand Down Expand Up @@ -476,7 +473,7 @@ def populate_manager_from_queryset(manager_info: TypeInfo, queryset_info: TypeIn
continue
# private, magic methods are not copied
# https://github.com/django/django/blob/5.0.4/django/db/models/manager.py#L101
elif name.startswith("_"):
if name.startswith("_"):
continue
# Insert the queryset method name as a class member. Note that the type of
# the method is set as Any. Figuring out the type is the job of the
Expand Down
18 changes: 7 additions & 11 deletions mypy_django_plugin/transformers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,7 @@ def lookup_typeinfo_or_incomplete_defn_error(self, fullname: str) -> TypeInfo:

def lookup_class_typeinfo_or_incomplete_defn_error(self, klass: type) -> TypeInfo:
fullname = helpers.get_class_fullname(klass)
field_info = self.lookup_typeinfo_or_incomplete_defn_error(fullname)
return field_info
return self.lookup_typeinfo_or_incomplete_defn_error(fullname)

def create_new_var(self, name: str, typ: MypyType) -> Var:
# type=: type of the variable itself
Expand All @@ -101,8 +100,7 @@ def add_new_node_to_model_class(

def add_new_class_for_current_module(self, name: str, bases: list[Instance]) -> TypeInfo:
current_module = self.api.modules[self.model_classdef.info.module_name]
new_class_info = helpers.add_new_class_for_module(current_module, name=name, bases=bases)
return new_class_info
return helpers.add_new_class_for_module(current_module, name=name, bases=bases)

def run(self) -> None:
model_cls = self.django_context.get_model_class_by_fullname(self.model_classdef.fullname)
Expand Down Expand Up @@ -355,8 +353,7 @@ def run_with_model_cls(self, model_cls: type[Model]) -> None:
except helpers.IncompleteDefnException as exc:
if not self.api.final_iteration:
raise exc
else:
continue
continue

is_nullable = self.django_context.get_field_nullability(field, None)
set_type, get_type = get_field_descriptor_types(
Expand Down Expand Up @@ -503,8 +500,7 @@ def run_with_model_cls(self, model_cls: type[Model]) -> None:
# see if another round could help figuring out the default manager type
if not self.api.final_iteration:
raise exc
else:
return None
return None
default_manager_info = generated_manager_info

default_manager = helpers.fill_manager(default_manager_info, Instance(self.model_classdef.info, []))
Expand Down Expand Up @@ -544,7 +540,7 @@ def process_relation(self, relation: ForeignObjectRel) -> None:
),
)
return
elif isinstance(relation, ManyToManyRel):
if isinstance(relation, ManyToManyRel):
if not reverse_lookup_declared:
# TODO: 'relation' should be based on `TypeInfo` instead of Django runtime.
assert relation.through is not None
Expand All @@ -558,7 +554,7 @@ def process_relation(self, relation: ForeignObjectRel) -> None:
is_classvar=True,
)
return
elif not reverse_lookup_declared:
if not reverse_lookup_declared:
# ManyToOneRel
self.add_new_node_to_model_class(
attname, Instance(self.reverse_many_to_one_descriptor, [Instance(to_model_info, [])]), is_classvar=True
Expand Down Expand Up @@ -838,7 +834,7 @@ def create_through_table_class(
) -> TypeInfo | None:
if not isinstance(m2m_args.to.model, Instance):
return None
elif m2m_args.through is not None:
if m2m_args.through is not None:
# Call has explicit 'through=', no need to create any implicit through table
return m2m_args.through.model.type if isinstance(m2m_args.through.model, Instance) else None

Expand Down
23 changes: 10 additions & 13 deletions mypy_django_plugin/transformers/querysets.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,15 @@ def get_field_type_from_lookup(

if lookup_field is None:
return AnyType(TypeOfAny.implementation_artifact)
elif (isinstance(lookup_field, RelatedField) and lookup_field.column == lookup) or isinstance(
if (isinstance(lookup_field, RelatedField) and lookup_field.column == lookup) or isinstance(
lookup_field, ForeignObjectRel
):
model_cls = django_context.get_field_related_model_cls(lookup_field)
lookup_field = django_context.get_primary_key_field(model_cls)

api = helpers.get_typechecker_api(ctx)
model_info = helpers.lookup_class_typeinfo(api, model_cls)
field_get_type = django_context.get_field_get_type(api, model_info, lookup_field, method=method)
return field_get_type
return django_context.get_field_get_type(api, model_info, lookup_field, method=method)


def get_values_list_row_type(
Expand All @@ -94,7 +93,7 @@ def get_values_list_row_type(
)
assert lookup_type is not None
return lookup_type
elif named:
if named:
column_types: dict[str, MypyType] = {}
for field in django_context.get_model_fields(model_cls):
column_type = django_context.get_field_get_type(
Expand All @@ -109,15 +108,13 @@ def get_values_list_row_type(
column_types,
extra_bases=[typechecker_api.named_generic_type(fullnames.ANY_ATTR_ALLOWED_CLASS_FULLNAME, [])],
)
else:
return helpers.make_oneoff_named_tuple(typechecker_api, "Row", column_types)
else:
# flat=False, named=False, all fields
if is_annotated:
return typechecker_api.named_generic_type("builtins.tuple", [AnyType(TypeOfAny.special_form)])
field_lookups = []
for field in django_context.get_model_fields(model_cls):
field_lookups.append(field.attname)
return helpers.make_oneoff_named_tuple(typechecker_api, "Row", column_types)
# flat=False, named=False, all fields
if is_annotated:
return typechecker_api.named_generic_type("builtins.tuple", [AnyType(TypeOfAny.special_form)])
field_lookups = []
for field in django_context.get_model_fields(model_cls):
field_lookups.append(field.attname)

if len(field_lookups) > 1 and flat:
typechecker_api.fail("'flat' is not valid when 'values_list' is called with more than one field", ctx.context)
Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ select = [
"RUF100", # Equivalent to flake8-noqa NQA103
"PGH004", # Equivalent to flake8-noqa NQA104
"PGH003", # Disallowed blanket `type: ignore` annotations.
"RET504", # Unnecessary assignment to {name} before return statement
"RET505", # Unnecessary {branch} after return statement
"RET506", # Unnecessary {branch} after raise statement
"RET507", # Unnecessary {branch} after continue statement
"RET508", # Unnecessary {branch} after break statement
]
ignore = [
"PYI021", # We have a few meaningful docstrings for stubs only constructs/utilities.
Expand Down
6 changes: 2 additions & 4 deletions tests/assert_type/db/models/test_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,16 +104,14 @@ class VoidChoices(BaseEmptyChoices):
def get_suit_with_color(suit: Suit) -> str:
if suit == Suit.DIAMOND or suit == Suit.HEART:
return f"{suit.label} is red."
else:
return f"{suit.label} is black."
return f"{suit.label} is black."


# Checks a single enum literal to test that the plugin resolves types correctly.
def is_suit_a_diamond(suit: Suit) -> str:
if suit == Suit.DIAMOND:
return f"{suit.label}: Yes!"
else:
return f"{suit.label}: No!"
return f"{suit.label}: No!"


# Choice type that overrides a property and uses `super()` to test the plugin resolve types correctly.
Expand Down
5 changes: 2 additions & 3 deletions tests/test_generic_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ def _get_need_generic() -> list[MPGeneric[Any]]:
from django.contrib.auth.forms import SetPasswordMixin, SetUnusablePasswordMixin

return [MPGeneric(SetPasswordMixin), MPGeneric(SetUnusablePasswordMixin), *django_stubs_ext.patch._need_generic]
else:
from django.contrib.auth.forms import AdminPasswordChangeForm, SetPasswordForm
from django.contrib.auth.forms import AdminPasswordChangeForm, SetPasswordForm

return [MPGeneric(SetPasswordForm), MPGeneric(AdminPasswordChangeForm), *django_stubs_ext.patch._need_generic]
return [MPGeneric(SetPasswordForm), MPGeneric(AdminPasswordChangeForm), *django_stubs_ext.patch._need_generic]
Loading