Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 58 additions & 2 deletions mypy_django_plugin/lib/helpers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from collections import OrderedDict
from typing import (
TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union,
)
TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union)

from typing_extensions import Protocol

from django.db.models.fields import Field
from django.db.models.fields.related import RelatedField
Expand Down Expand Up @@ -324,6 +325,59 @@ def _prepare_new_method_arguments(node: FuncDef) -> Tuple[List[Argument], MypyTy
return arguments, return_type


class SupportsNamedType(Protocol):
def named_type(self, qualified_name: str, args: Optional[List[MypyType]] = None) -> Instance:
pass


class FakeClassDefContext:
def __init__(self, cls: ClassDef, api: SupportsNamedType):
self.cls = cls
self.api = api


def _add_method(
info: TypeInfo,
api: SupportsNamedType,
old_method_node: FuncDef,
name: str,
self_type: Optional[MypyType] = None
) -> None:
fake_ctx = FakeClassDefContext(info.defn, api=api)
arguments, return_type = _prepare_new_method_arguments(old_method_node)
add_method(fake_ctx,
name,
args=arguments,
return_type=return_type,
self_type=self_type)


def basic_new_typeinfo(self, name: str, basetype_or_fallback: Instance) -> TypeInfo:
class_def = ClassDef(name, Block([]))
if self.is_func_scope() and not self.type:
# Full names of generated classes should always be prefixed with the module names
# even if they are nested in a function, since these classes will be (de-)serialized.
# (Note that the caller should append @line to the name to avoid collisions.)
# TODO: clean this up, see #6422.
class_def.fullname = self.cur_mod_id + '.' + self.qualified_name(name)
else:
class_def.fullname = self.qualified_name(name)

info = TypeInfo(SymbolTable(), class_def, self.cur_mod_id)
class_def.info = info
mro = basetype_or_fallback.type.mro
if not mro:
# Forward reference, MRO should be recalculated in third pass.
mro = [basetype_or_fallback.type, self.object_type().type]
info.mro = [info] + mro
info.bases = [basetype_or_fallback]
return info


# def copy_methods_to_another_class(source_info: TypeInfo, dest_type: Instance):
# pass


def copy_method_to_another_class(ctx: ClassDefContext, self_type: Instance,
new_method_name: str, method_node: FuncDef) -> None:
arguments, return_type = _prepare_new_method_arguments(method_node)
Expand All @@ -332,3 +386,5 @@ def copy_method_to_another_class(ctx: ClassDefContext, self_type: Instance,
args=arguments,
return_type=return_type,
self_type=self_type)


3 changes: 3 additions & 0 deletions mypy_django_plugin/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,9 @@ def get_method_hook(self, fullname: str
if info and info.has_base(fullnames.OPTIONS_CLASS_FULLNAME):
return partial(meta.return_proper_field_type_from_get_field, django_context=self.django_context)

if method_name == 'as_manager':
return partial(querysets.create_new_class_from_as_manager_method, django_context=self.django_context)

manager_classes = self._get_current_manager_bases()
if class_fullname in manager_classes and method_name == 'create':
return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context)
Expand Down
43 changes: 42 additions & 1 deletion mypy_django_plugin/transformers/querysets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from django.db.models.fields.reverse_related import ForeignObjectRel
from mypy.nodes import Expression, NameExpr
from mypy.plugin import FunctionContext, MethodContext
from mypy.types import AnyType, Instance
from mypy.types import AnyType, Instance, CallableType
from mypy.types import Type as MypyType
from mypy.types import TypeOfAny

Expand Down Expand Up @@ -190,3 +190,44 @@ def extract_proper_type_queryset_values(ctx: MethodContext, django_context: Djan

row_type = helpers.make_typeddict(ctx.api, column_types, set(column_types.keys()))
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type])


def create_new_class_from_as_manager_method(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
assert isinstance(ctx.type, CallableType)
callee_type = ctx.type.bound_args[0]
if (not isinstance(callee_type, Instance)
or not callee_type.type.has_base(fullnames.QUERYSET_CLASS_FULLNAME)):
return ctx.default_return_type

callee_type_info = callee_type.type
manager_info = helpers.lookup_fully_qualified_typeinfo(ctx.api,
fullnames.MANAGER_CLASS_FULLNAME)
if manager_info is None:
return ctx.default_return_type

new_manager_name = f'{callee_type_info.name}_AsManager'
base_class = Instance(manager_info, [AnyType(TypeOfAny.unannotated)])
helpers.get_typechecker_api(ctx.api).mod
helpers.add_new_class_for_module()
new_manager_info = helpers.basic_new_typeinfo(new_manager_name,
basetype_or_fallback=base_class)
new_manager_info.line = ctx.context.line
new_manager_info.defn.line = ctx.context.line
new_manager_info.metaclass_type = new_manager_info.calculate_metaclass_type()

current_module = semanal_api.cur_mod_node
current_module.names[ctx.name] = SymbolTableNode(GDEF, new_manager_info,
plugin_generated=True)

# helpers._add_method()
# class_def_context = ClassDefContext(cls=new_manager_info.defn,
# reason=ctx.call, api=semanal_api)
# self_type = Instance(new_manager_info, [])
# for name, sym in derived_queryset_info.names.items():
# if isinstance(sym.node, FuncDef):
# helpers.copy_method_to_another_class(class_def_context,
# self_type,
# new_method_name=name,
# method_node=sym.node)

print()
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
- case: query_as_manager_returns_manager_with_copied_methods
main: |
from myapp.models import MyModel
reveal_type(MyModel().objects) # N: Revealed type is 'myapp.models.MyModel_ModelQuerySet_AsManager[myapp.models.MyModel]'
reveal_type(MyModel().objects.get()) # N: Revealed type is 'myapp.models.MyModel*'
reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is 'builtins.str'
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models

class ModelQuerySet(models.QuerySet):
def queryset_method(self) -> str:
return 'hello'
class MyModel(models.Model):
objects = ModelQuerySet.as_manager()