diff --git a/django-stubs/template/defaultfilters.pyi b/django-stubs/template/defaultfilters.pyi index 7eacd3fa7..f6a40d4ff 100644 --- a/django-stubs/template/defaultfilters.pyi +++ b/django-stubs/template/defaultfilters.pyi @@ -2,13 +2,18 @@ from collections.abc import Callable from datetime import date as _date from datetime import datetime from datetime import time as _time -from typing import Any +from typing import Any, TypeVar +from django.template.library import Library from django.utils.safestring import SafeString +from typing_extensions import Concatenate, ParamSpec -register: Any +_P = ParamSpec("_P") +_R = TypeVar("_R") -def stringfilter(func: Callable) -> Callable: ... +register: Library + +def stringfilter(func: Callable[Concatenate[str, _P], _R]) -> Callable[Concatenate[object, _P], _R]: ... def addslashes(value: str) -> str: ... def capfirst(value: str) -> str: ... def escapejs_filter(value: str) -> SafeString: ... diff --git a/django-stubs/template/library.pyi b/django-stubs/template/library.pyi index 927918e2c..6327ef084 100644 --- a/django-stubs/template/library.pyi +++ b/django-stubs/template/library.pyi @@ -1,42 +1,90 @@ from collections.abc import Callable, Collection, Iterable, Mapping, Sequence, Sized -from typing import Any, TypeVar, overload +from typing import Any, Literal, TypeVar, overload from django.template.base import FilterExpression, Origin, Parser, Token from django.template.context import Context from django.utils.safestring import SafeString +from typing_extensions import Concatenate from .base import Node, Template class InvalidTemplateLibrary(Exception): ... _C = TypeVar("_C", bound=Callable[..., Any]) +_CompileC = TypeVar("_CompileC", bound=Callable[[Parser, Token], Node]) +_FilterC = TypeVar("_FilterC", bound=Callable[[Any], Any] | Callable[[Any, Any], Any]) +_TakesContextC = TypeVar("_TakesContextC", bound=Callable[Concatenate[Context, ...], Any]) class Library: - filters: dict[str, Callable] - tags: dict[str, Callable] + filters: dict[str, Callable[[Any], Any] | Callable[[Any, Any], Any]] + tags: dict[str, Callable[[Parser, Token], Node]] def __init__(self) -> None: ... + # @register.tag @overload - def tag(self, name: _C) -> _C: ... + def tag(self, name: _CompileC, /) -> _CompileC: ... + # register.tag("somename", somefunc) @overload - def tag(self, name: str, compile_function: _C) -> _C: ... + def tag(self, name: str, compile_function: _CompileC) -> _CompileC: ... + # @register.tag() + # @register.tag("somename") or @register.tag(name="somename") @overload - def tag(self, name: str | None = ..., compile_function: None = ...) -> Callable[[_C], _C]: ... - def tag_function(self, func: _C) -> _C: ... + def tag(self, name: str | None = ..., compile_function: None = ...) -> Callable[[_CompileC], _CompileC]: ... + def tag_function(self, func: _CompileC) -> _CompileC: ... + # @register.filter @overload - def filter(self, name: _C, filter_func: None = ..., **flags: Any) -> _C: ... + def filter(self, name: _FilterC, /) -> _FilterC: ... + # @register.filter() + # @register.filter("somename") or @register.filter(name='somename') @overload - def filter(self, name: str | None, filter_func: _C, **flags: Any) -> _C: ... + def filter( + self, + *, + name: str | None = ..., + filter_func: None = ..., + is_safe: bool = ..., + needs_autoescape: bool = ..., + expects_localtime: bool = ..., + ) -> Callable[[_FilterC], _FilterC]: ... + # register.filter("somename", somefunc) + @overload + def filter( + self, + name: str, + filter_func: _FilterC, + *, + is_safe: bool = ..., + needs_autoescape: bool = ..., + expects_localtime: bool = ..., + ) -> _FilterC: ... + # @register.simple_tag + @overload + def simple_tag(self, func: _C, /) -> _C: ... + # @register.simple_tag(takes_context=True) @overload - def filter(self, name: str | None = ..., filter_func: None = ..., **flags: Any) -> Callable[[_C], _C]: ... + def simple_tag( + self, *, takes_context: Literal[True], name: str | None = ... + ) -> Callable[[_TakesContextC], _TakesContextC]: ... + # @register.simple_tag(takes_context=False) + # @register.simple_tag(...) @overload - def simple_tag(self, func: _C) -> _C: ... + def simple_tag( + self, *, takes_context: Literal[False] | None = ..., name: str | None = ... + ) -> Callable[[_C], _C]: ... + @overload + def inclusion_tag( + self, + filename: Template | str, + func: Callable[..., Any] | None = ..., + *, + takes_context: Literal[True], + name: str | None = ..., + ) -> Callable[[_TakesContextC], _TakesContextC]: ... @overload - def simple_tag(self, takes_context: bool | None = ..., name: str | None = ...) -> Callable[[_C], _C]: ... def inclusion_tag( self, filename: Template | str, - func: Callable | None = ..., - takes_context: bool | None = ..., + func: Callable[..., Any] | None = ..., + takes_context: Literal[False] | None = ..., name: str | None = ..., ) -> Callable[[_C], _C]: ... diff --git a/tests/typecheck/template/test_library.yml b/tests/typecheck/template/test_library.yml index b3fd4dde0..0759099c1 100644 --- a/tests/typecheck/template/test_library.yml +++ b/tests/typecheck/template/test_library.yml @@ -20,6 +20,34 @@ reveal_type(lower) # N: Revealed type is "def (value: builtins.str) -> builtins.str" +- case: register_filter_no_decorator + main: | + from django import template + register = template.Library() + + def lower(value: str) -> str: + return value.lower() + + registered = register.filter("tolower", lower) + + reveal_type(registered) # N: Revealed type is "def (value: builtins.str) -> builtins.str" + +- case: register_bad_filters + main: | + from django import template + register = template.Library() + + @register.filter + def lower() -> str: + return "" + + @register.filter(name="toomanyargs") + def toomanyargs(arg1: str, arg2: str, arg3: str) -> str: + return "" + out: | + main:4: error: Value of type variable "_FilterC" of "filter" of "Library" cannot be "Callable[[], str]" [type-var] + main:8: error: Value of type variable "_FilterC" of function cannot be "Callable[[str, str, str], str]" [type-var] + - case: register_simple_tag_no_args main: | import datetime @@ -35,15 +63,16 @@ - case: register_simple_tag_context main: | from django import template + from django.template.context import Context from typing import Dict, Any register = template.Library() @register.simple_tag(takes_context=True) - def current_time(context: Dict[str, Any], format_string: str) -> str: + def current_time(context: Context, format_string: str) -> str: timezone = context['timezone'] return "test" - reveal_type(current_time) # N: Revealed type is "def (context: builtins.dict[builtins.str, Any], format_string: builtins.str) -> builtins.str" + reveal_type(current_time) # N: Revealed type is "def (context: django.template.context.Context, format_string: builtins.str) -> builtins.str" - case: register_simple_tag_named main: | @@ -94,3 +123,27 @@ return ', '.join(results) reveal_type(format_results) # N: Revealed type is "def (results: builtins.list[builtins.str]) -> builtins.str" + +- case: register_inclusion_tag_takes_context + main: | + from django import template + from django.template.context import Context + + from typing import List + register = template.Library() + + @register.inclusion_tag('results.html', takes_context=True) + def format_results(context: Context, results: List[str]) -> str: + return ', '.join(results) + + reveal_type(format_results) # N: Revealed type is "def (context: django.template.context.Context, results: builtins.list[builtins.str]) -> builtins.str" + +- case: stringfilter + main: | + from django.template.defaultfilters import stringfilter + + @stringfilter + def lower(value: str) -> str: + return value.lower() + + reveal_type(lower) # N: Revealed type is "def (builtins.object) -> builtins.str"