From 6000e80acf931b05b4911c3fa7f3ba33e2b80e64 Mon Sep 17 00:00:00 2001 From: David Lord Date: Thu, 14 Dec 2023 08:58:13 -0800 Subject: [PATCH] address mypy strict findings --- CHANGES.rst | 1 + pyproject.toml | 16 +--- requirements/typing.in | 2 + requirements/typing.txt | 10 ++- src/flask/app.py | 45 ++++++----- src/flask/blueprints.py | 2 +- src/flask/cli.py | 137 +++++++++++++++++++++------------ src/flask/config.py | 45 ++++++++--- src/flask/ctx.py | 29 ++++--- src/flask/debughelpers.py | 34 ++++++-- src/flask/helpers.py | 21 +++-- src/flask/json/__init__.py | 2 +- src/flask/json/provider.py | 5 +- src/flask/json/tag.py | 10 +-- src/flask/logging.py | 5 +- src/flask/sansio/app.py | 31 ++++---- src/flask/sansio/blueprints.py | 34 ++++---- src/flask/sansio/scaffold.py | 70 +++++++++-------- src/flask/sessions.py | 25 +++--- src/flask/templating.py | 18 ++--- src/flask/testing.py | 13 ++-- src/flask/typing.py | 6 +- src/flask/views.py | 10 ++- src/flask/wrappers.py | 9 ++- 24 files changed, 346 insertions(+), 234 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 27c04f7f3e..6e4b545dbb 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -8,6 +8,7 @@ Unreleased - Session data is untagged without relying on the built-in ``json.loads`` ``object_hook``. This allows other JSON providers that don't implement that. :issue:`5381` +- Address more type findings when using mypy strict mode. :pr:`5383` Version 3.0.0 diff --git a/pyproject.toml b/pyproject.toml index 9031e0596a..537aae46e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,21 +82,7 @@ python_version = "3.8" files = ["src/flask", "tests/typing"] show_error_codes = true pretty = true -#strict = true -allow_redefinition = true -disallow_subclassing_any = true -#disallow_untyped_calls = true -#disallow_untyped_defs = true -#disallow_incomplete_defs = true -no_implicit_optional = true -local_partial_types = true -#no_implicit_reexport = true -strict_equality = true -warn_redundant_casts = true -warn_unused_configs = true -warn_unused_ignores = true -#warn_return_any = true -#warn_unreachable = true +strict = true [[tool.mypy.overrides]] module = [ diff --git a/requirements/typing.in b/requirements/typing.in index d231307988..211e0bd735 100644 --- a/requirements/typing.in +++ b/requirements/typing.in @@ -1,4 +1,6 @@ mypy types-contextvars types-dataclasses +asgiref cryptography +python-dotenv diff --git a/requirements/typing.txt b/requirements/typing.txt index 990ff6bb79..adbef1ab16 100644 --- a/requirements/typing.txt +++ b/requirements/typing.txt @@ -4,19 +4,23 @@ # # pip-compile typing.in # +asgiref==3.7.2 + # via -r typing.in cffi==1.16.0 # via cryptography -cryptography==41.0.5 +cryptography==41.0.7 # via -r typing.in -mypy==1.6.1 +mypy==1.8.0 # via -r typing.in mypy-extensions==1.0.0 # via mypy pycparser==2.21 # via cffi +python-dotenv==1.0.0 + # via -r typing.in types-contextvars==2.4.7.3 # via -r typing.in types-dataclasses==0.6.6 # via -r typing.in -typing-extensions==4.8.0 +typing-extensions==4.9.0 # via mypy diff --git a/src/flask/app.py b/src/flask/app.py index 7a1cf4d48b..12ac50d49b 100644 --- a/src/flask/app.py +++ b/src/flask/app.py @@ -1,10 +1,10 @@ from __future__ import annotations +import collections.abc as cabc import os import sys import typing as t import weakref -from collections.abc import Iterator as _abc_Iterator from datetime import timedelta from inspect import iscoroutinefunction from itertools import chain @@ -54,6 +54,9 @@ from .wrappers import Response if t.TYPE_CHECKING: # pragma: no cover + from _typeshed.wsgi import StartResponse + from _typeshed.wsgi import WSGIEnvironment + from .testing import FlaskClient from .testing import FlaskCliRunner @@ -200,11 +203,11 @@ class Flask(App): #: The class that is used for request objects. See :class:`~flask.Request` #: for more information. - request_class = Request + request_class: type[Request] = Request #: The class that is used for response objects. See #: :class:`~flask.Response` for more information. - response_class = Response + response_class: type[Response] = Response #: the session interface to use. By default an instance of #: :class:`~flask.sessions.SecureCookieSessionInterface` is used here. @@ -216,11 +219,11 @@ def __init__( self, import_name: str, static_url_path: str | None = None, - static_folder: str | os.PathLike | None = "static", + static_folder: str | os.PathLike[str] | None = "static", static_host: str | None = None, host_matching: bool = False, subdomain_matching: bool = False, - template_folder: str | os.PathLike | None = "templates", + template_folder: str | os.PathLike[str] | None = "templates", instance_path: str | None = None, instance_relative_config: bool = False, root_path: str | None = None, @@ -282,7 +285,7 @@ def get_send_file_max_age(self, filename: str | None) -> int | None: if isinstance(value, timedelta): return int(value.total_seconds()) - return value + return value # type: ignore[no-any-return] def send_static_file(self, filename: str) -> Response: """The view function used to serve files from @@ -447,13 +450,13 @@ def raise_routing_exception(self, request: Request) -> t.NoReturn: or request.routing_exception.code in {307, 308} or request.method in {"GET", "HEAD", "OPTIONS"} ): - raise request.routing_exception # type: ignore + raise request.routing_exception # type: ignore[misc] from .debughelpers import FormDataRoutingRedirect raise FormDataRoutingRedirect(request) - def update_template_context(self, context: dict) -> None: + def update_template_context(self, context: dict[str, t.Any]) -> None: """Update the template context with some commonly used variables. This injects request, session, config and g into the template context as well as everything template context processors want @@ -481,7 +484,7 @@ def update_template_context(self, context: dict) -> None: context.update(orig_ctx) - def make_shell_context(self) -> dict: + def make_shell_context(self) -> dict[str, t.Any]: """Returns the shell context for an interactive shell for this application. This runs all the registered shell context processors. @@ -724,7 +727,7 @@ def handle_http_exception( handler = self._find_error_handler(e, request.blueprints) if handler is None: return e - return self.ensure_sync(handler)(e) + return self.ensure_sync(handler)(e) # type: ignore[no-any-return] def handle_user_exception( self, e: Exception @@ -756,7 +759,7 @@ def handle_user_exception( if handler is None: raise - return self.ensure_sync(handler)(e) + return self.ensure_sync(handler)(e) # type: ignore[no-any-return] def handle_exception(self, e: Exception) -> Response: """Handle an exception that did not have an error handler @@ -849,7 +852,7 @@ def dispatch_request(self) -> ft.ResponseReturnValue: return self.make_default_options_response() # otherwise dispatch to the handler for that endpoint view_args: dict[str, t.Any] = req.view_args # type: ignore[assignment] - return self.ensure_sync(self.view_functions[rule.endpoint])(**view_args) + return self.ensure_sync(self.view_functions[rule.endpoint])(**view_args) # type: ignore[no-any-return] def full_dispatch_request(self) -> Response: """Dispatches the request and on top of that performs request @@ -913,7 +916,7 @@ def make_default_options_response(self) -> Response: rv.allow.update(methods) return rv - def ensure_sync(self, func: t.Callable) -> t.Callable: + def ensure_sync(self, func: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]: """Ensure that the function is synchronous for WSGI workers. Plain ``def`` functions are returned as-is. ``async def`` functions are wrapped to run and wait for the response. @@ -928,7 +931,7 @@ def ensure_sync(self, func: t.Callable) -> t.Callable: return func def async_to_sync( - self, func: t.Callable[..., t.Coroutine] + self, func: t.Callable[..., t.Coroutine[t.Any, t.Any, t.Any]] ) -> t.Callable[..., t.Any]: """Return a sync function that will run the coroutine function. @@ -1166,7 +1169,7 @@ def make_response(self, rv: ft.ResponseReturnValue) -> Response: # make sure the body is an instance of the response class if not isinstance(rv, self.response_class): - if isinstance(rv, (str, bytes, bytearray)) or isinstance(rv, _abc_Iterator): + if isinstance(rv, (str, bytes, bytearray)) or isinstance(rv, cabc.Iterator): # let the response class set the status and headers instead of # waiting to do it manually, so that the class can handle any # special logic @@ -1240,7 +1243,7 @@ def preprocess_request(self) -> ft.ResponseReturnValue | None: rv = self.ensure_sync(before_func)() if rv is not None: - return rv + return rv # type: ignore[no-any-return] return None @@ -1353,7 +1356,7 @@ def app_context(self) -> AppContext: """ return AppContext(self) - def request_context(self, environ: dict) -> RequestContext: + def request_context(self, environ: WSGIEnvironment) -> RequestContext: """Create a :class:`~flask.ctx.RequestContext` representing a WSGI environment. Use a ``with`` block to push the context, which will make :data:`request` point at this request. @@ -1425,7 +1428,9 @@ def test_request_context(self, *args: t.Any, **kwargs: t.Any) -> RequestContext: finally: builder.close() - def wsgi_app(self, environ: dict, start_response: t.Callable) -> t.Any: + def wsgi_app( + self, environ: WSGIEnvironment, start_response: StartResponse + ) -> cabc.Iterable[bytes]: """The actual WSGI application. This is not implemented in :meth:`__call__` so that middlewares can be applied without losing a reference to the app object. Instead of doing this:: @@ -1473,7 +1478,9 @@ def wsgi_app(self, environ: dict, start_response: t.Callable) -> t.Any: ctx.pop(error) - def __call__(self, environ: dict, start_response: t.Callable) -> t.Any: + def __call__( + self, environ: WSGIEnvironment, start_response: StartResponse + ) -> cabc.Iterable[bytes]: """The WSGI server calls the Flask application object as the WSGI application. This calls :meth:`wsgi_app`, which can be wrapped to apply middleware. diff --git a/src/flask/blueprints.py b/src/flask/blueprints.py index 3a37a2c422..52859b855a 100644 --- a/src/flask/blueprints.py +++ b/src/flask/blueprints.py @@ -39,7 +39,7 @@ def get_send_file_max_age(self, filename: str | None) -> int | None: if isinstance(value, timedelta): return int(value.total_seconds()) - return value + return value # type: ignore[no-any-return] def send_static_file(self, filename: str) -> Response: """The view function used to serve files from diff --git a/src/flask/cli.py b/src/flask/cli.py index 751dfd1f60..ffdcb182a7 100644 --- a/src/flask/cli.py +++ b/src/flask/cli.py @@ -1,6 +1,7 @@ from __future__ import annotations import ast +import collections.abc as cabc import importlib.metadata import inspect import os @@ -11,6 +12,7 @@ import typing as t from functools import update_wrapper from operator import itemgetter +from types import ModuleType import click from click.core import ParameterSource @@ -23,6 +25,12 @@ from .helpers import get_load_dotenv if t.TYPE_CHECKING: + import ssl + + from _typeshed.wsgi import StartResponse + from _typeshed.wsgi import WSGIApplication + from _typeshed.wsgi import WSGIEnvironment + from .app import Flask @@ -30,7 +38,7 @@ class NoAppException(click.UsageError): """Raised if an application cannot be found or loaded.""" -def find_best_app(module): +def find_best_app(module: ModuleType) -> Flask: """Given a module instance this tries to find the best possible application in the module or raises an exception. """ @@ -83,7 +91,7 @@ def find_best_app(module): ) -def _called_with_wrong_args(f): +def _called_with_wrong_args(f: t.Callable[..., Flask]) -> bool: """Check whether calling a function raised a ``TypeError`` because the call failed or because something in the factory raised the error. @@ -109,7 +117,7 @@ def _called_with_wrong_args(f): del tb -def find_app_by_string(module, app_name): +def find_app_by_string(module: ModuleType, app_name: str) -> Flask: """Check if the given string is a variable name or a function. Call a function to get the app instance, or return the variable directly. """ @@ -140,7 +148,11 @@ def find_app_by_string(module, app_name): # Parse the positional and keyword arguments as literals. try: args = [ast.literal_eval(arg) for arg in expr.args] - kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in expr.keywords} + kwargs = { + kw.arg: ast.literal_eval(kw.value) + for kw in expr.keywords + if kw.arg is not None + } except ValueError: # literal_eval gives cryptic error messages, show a generic # message with the full expression instead. @@ -185,7 +197,7 @@ def find_app_by_string(module, app_name): ) -def prepare_import(path): +def prepare_import(path: str) -> str: """Given a filename this will try to calculate the python path, add it to the search path and return the actual module name that is expected. """ @@ -214,13 +226,29 @@ def prepare_import(path): return ".".join(module_name[::-1]) -def locate_app(module_name, app_name, raise_if_not_found=True): +@t.overload +def locate_app( + module_name: str, app_name: str | None, raise_if_not_found: t.Literal[True] = True +) -> Flask: + ... + + +@t.overload +def locate_app( + module_name: str, app_name: str | None, raise_if_not_found: t.Literal[False] = ... +) -> Flask | None: + ... + + +def locate_app( + module_name: str, app_name: str | None, raise_if_not_found: bool = True +) -> Flask | None: try: __import__(module_name) except ImportError: # Reraise the ImportError if it occurred within the imported module. # Determine this by checking whether the trace has a depth > 1. - if sys.exc_info()[2].tb_next: + if sys.exc_info()[2].tb_next: # type: ignore[union-attr] raise NoAppException( f"While importing {module_name!r}, an ImportError was" f" raised:\n\n{traceback.format_exc()}" @@ -228,7 +256,7 @@ def locate_app(module_name, app_name, raise_if_not_found=True): elif raise_if_not_found: raise NoAppException(f"Could not import {module_name!r}.") from None else: - return + return None module = sys.modules[module_name] @@ -238,7 +266,7 @@ def locate_app(module_name, app_name, raise_if_not_found=True): return find_app_by_string(module, app_name) -def get_version(ctx, param, value): +def get_version(ctx: click.Context, param: click.Parameter, value: t.Any) -> None: if not value or ctx.resilient_parsing: return @@ -299,7 +327,7 @@ def load_app(self) -> Flask: return self._loaded_app if self.create_app is not None: - app = self.create_app() + app: Flask | None = self.create_app() else: if self.app_import_path: path, name = ( @@ -312,10 +340,10 @@ def load_app(self) -> Flask: import_name = prepare_import(path) app = locate_app(import_name, None, raise_if_not_found=False) - if app: + if app is not None: break - if not app: + if app is None: raise NoAppException( "Could not locate a Flask application. Use the" " 'flask --app' option, 'FLASK_APP' environment" @@ -334,8 +362,10 @@ def load_app(self) -> Flask: pass_script_info = click.make_pass_decorator(ScriptInfo, ensure=True) +F = t.TypeVar("F", bound=t.Callable[..., t.Any]) -def with_appcontext(f): + +def with_appcontext(f: F) -> F: """Wraps a callback so that it's guaranteed to be executed with the script's application context. @@ -350,14 +380,14 @@ def with_appcontext(f): """ @click.pass_context - def decorator(__ctx, *args, **kwargs): + def decorator(ctx: click.Context, /, *args: t.Any, **kwargs: t.Any) -> t.Any: if not current_app: - app = __ctx.ensure_object(ScriptInfo).load_app() - __ctx.with_resource(app.app_context()) + app = ctx.ensure_object(ScriptInfo).load_app() + ctx.with_resource(app.app_context()) - return __ctx.invoke(f, *args, **kwargs) + return ctx.invoke(f, *args, **kwargs) - return update_wrapper(decorator, f) + return update_wrapper(decorator, f) # type: ignore[return-value] class AppGroup(click.Group): @@ -368,27 +398,31 @@ class AppGroup(click.Group): Not to be confused with :class:`FlaskGroup`. """ - def command(self, *args, **kwargs): + def command( # type: ignore[override] + self, *args: t.Any, **kwargs: t.Any + ) -> t.Callable[[t.Callable[..., t.Any]], click.Command]: """This works exactly like the method of the same name on a regular :class:`click.Group` but it wraps callbacks in :func:`with_appcontext` unless it's disabled by passing ``with_appcontext=False``. """ wrap_for_ctx = kwargs.pop("with_appcontext", True) - def decorator(f): + def decorator(f: t.Callable[..., t.Any]) -> click.Command: if wrap_for_ctx: f = with_appcontext(f) - return click.Group.command(self, *args, **kwargs)(f) + return super(AppGroup, self).command(*args, **kwargs)(f) # type: ignore[no-any-return] return decorator - def group(self, *args, **kwargs): + def group( # type: ignore[override] + self, *args: t.Any, **kwargs: t.Any + ) -> t.Callable[[t.Callable[..., t.Any]], click.Group]: """This works exactly like the method of the same name on a regular :class:`click.Group` but it defaults the group class to :class:`AppGroup`. """ kwargs.setdefault("cls", AppGroup) - return click.Group.group(self, *args, **kwargs) + return super().group(*args, **kwargs) # type: ignore[no-any-return] def _set_app(ctx: click.Context, param: click.Option, value: str | None) -> str | None: @@ -545,7 +579,7 @@ def __init__( self._loaded_plugin_commands = False - def _load_plugin_commands(self): + def _load_plugin_commands(self) -> None: if self._loaded_plugin_commands: return @@ -562,7 +596,7 @@ def _load_plugin_commands(self): self._loaded_plugin_commands = True - def get_command(self, ctx, name): + def get_command(self, ctx: click.Context, name: str) -> click.Command | None: self._load_plugin_commands() # Look up built-in and plugin commands, which should be # available even if the app fails to load. @@ -584,12 +618,12 @@ def get_command(self, ctx, name): # Push an app context for the loaded app unless it is already # active somehow. This makes the context available to parameter # and command callbacks without needing @with_appcontext. - if not current_app or current_app._get_current_object() is not app: + if not current_app or current_app._get_current_object() is not app: # type: ignore[attr-defined] ctx.with_resource(app.app_context()) return app.cli.get_command(ctx, name) - def list_commands(self, ctx): + def list_commands(self, ctx: click.Context) -> list[str]: self._load_plugin_commands() # Start with the built-in and plugin commands. rv = set(super().list_commands(ctx)) @@ -645,14 +679,14 @@ def parse_args(self, ctx: click.Context, args: list[str]) -> list[str]: return super().parse_args(ctx, args) -def _path_is_ancestor(path, other): +def _path_is_ancestor(path: str, other: str) -> bool: """Take ``other`` and remove the length of ``path`` from it. Then join it to ``path``. If it is the original value, ``path`` is an ancestor of ``other``.""" return os.path.join(path, other[len(path) :].lstrip(os.sep)) == other -def load_dotenv(path: str | os.PathLike | None = None) -> bool: +def load_dotenv(path: str | os.PathLike[str] | None = None) -> bool: """Load "dotenv" files in order of precedence to set environment variables. If an env var is already set it is not overwritten, so earlier files in the @@ -713,7 +747,7 @@ def load_dotenv(path: str | os.PathLike | None = None) -> bool: return loaded # True if at least one file was located and loaded. -def show_server_banner(debug, app_import_path): +def show_server_banner(debug: bool, app_import_path: str | None) -> None: """Show extra startup messages the first time the server is run, ignoring the reloader. """ @@ -735,10 +769,12 @@ class CertParamType(click.ParamType): name = "path" - def __init__(self): + def __init__(self) -> None: self.path_type = click.Path(exists=True, dir_okay=False, resolve_path=True) - def convert(self, value, param, ctx): + def convert( + self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None + ) -> t.Any: try: import ssl except ImportError: @@ -773,7 +809,7 @@ def convert(self, value, param, ctx): raise -def _validate_key(ctx, param, value): +def _validate_key(ctx: click.Context, param: click.Parameter, value: t.Any) -> t.Any: """The ``--key`` option must be specified when ``--cert`` is a file. Modifies the ``cert`` param to be a ``(cert, key)`` pair if needed. """ @@ -818,10 +854,11 @@ class SeparatedPathType(click.Path): validated as a :class:`click.Path` type. """ - def convert(self, value, param, ctx): + def convert( + self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None + ) -> t.Any: items = self.split_envvar_value(value) - super_convert = super().convert - return [super_convert(item, param, ctx) for item in items] + return [super().convert(item, param, ctx) for item in items] @click.command("run", short_help="Run a development server.") @@ -878,16 +915,16 @@ def convert(self, value, param, ctx): ) @pass_script_info def run_command( - info, - host, - port, - reload, - debugger, - with_threads, - cert, - extra_files, - exclude_patterns, -): + info: ScriptInfo, + host: str, + port: int, + reload: bool, + debugger: bool, + with_threads: bool, + cert: ssl.SSLContext | tuple[str, str | None] | t.Literal["adhoc"] | None, + extra_files: list[str] | None, + exclude_patterns: list[str] | None, +) -> None: """Run a local development server. This server is for development purposes only. It does not provide @@ -897,7 +934,7 @@ def run_command( option. """ try: - app = info.load_app() + app: WSGIApplication = info.load_app() except Exception as e: if is_running_from_reloader(): # When reloading, print out the error immediately, but raise @@ -905,7 +942,9 @@ def run_command( traceback.print_exc() err = e - def app(environ, start_response): + def app( + environ: WSGIEnvironment, start_response: StartResponse + ) -> cabc.Iterable[bytes]: raise err from None else: @@ -956,7 +995,7 @@ def shell_command() -> None: f"App: {current_app.import_name}\n" f"Instance: {current_app.instance_path}" ) - ctx: dict = {} + ctx: dict[str, t.Any] = {} # Support the regular Python interpreter startup script if someone # is using it. diff --git a/src/flask/config.py b/src/flask/config.py index 5f921b4dff..f2f4147806 100644 --- a/src/flask/config.py +++ b/src/flask/config.py @@ -8,27 +8,48 @@ from werkzeug.utils import import_string +if t.TYPE_CHECKING: + import typing_extensions as te -class ConfigAttribute: + from .sansio.app import App + + +T = t.TypeVar("T") + + +class ConfigAttribute(t.Generic[T]): """Makes an attribute forward to the config""" - def __init__(self, name: str, get_converter: t.Callable | None = None) -> None: + def __init__( + self, name: str, get_converter: t.Callable[[t.Any], T] | None = None + ) -> None: self.__name__ = name self.get_converter = get_converter - def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any: + @t.overload + def __get__(self, obj: None, owner: None) -> te.Self: + ... + + @t.overload + def __get__(self, obj: App, owner: type[App]) -> T: + ... + + def __get__(self, obj: App | None, owner: type[App] | None = None) -> T | te.Self: if obj is None: return self + rv = obj.config[self.__name__] + if self.get_converter is not None: rv = self.get_converter(rv) - return rv - def __set__(self, obj: t.Any, value: t.Any) -> None: + return rv # type: ignore[no-any-return] + + def __set__(self, obj: App, value: t.Any) -> None: obj.config[self.__name__] = value -class Config(dict): +class Config(dict): # type: ignore[type-arg] """Works exactly like a dict but provides ways to fill it from files or special dictionaries. There are two common patterns to populate the config. @@ -73,7 +94,9 @@ class Config(dict): """ def __init__( - self, root_path: str | os.PathLike, defaults: dict | None = None + self, + root_path: str | os.PathLike[str], + defaults: dict[str, t.Any] | None = None, ) -> None: super().__init__(defaults or {}) self.root_path = root_path @@ -166,7 +189,9 @@ def from_prefixed_env( return True - def from_pyfile(self, filename: str | os.PathLike, silent: bool = False) -> bool: + def from_pyfile( + self, filename: str | os.PathLike[str], silent: bool = False + ) -> bool: """Updates the values in the config from a Python file. This function behaves as if the file was imported as module with the :meth:`from_object` function. @@ -235,8 +260,8 @@ class and has ``@property`` attributes, it needs to be def from_file( self, - filename: str | os.PathLike, - load: t.Callable[[t.IO[t.Any]], t.Mapping], + filename: str | os.PathLike[str], + load: t.Callable[[t.IO[t.Any]], t.Mapping[str, t.Any]], silent: bool = False, text: bool = True, ) -> bool: diff --git a/src/flask/ctx.py b/src/flask/ctx.py index b37e4e04a6..4640052142 100644 --- a/src/flask/ctx.py +++ b/src/flask/ctx.py @@ -15,6 +15,8 @@ from .signals import appcontext_pushed if t.TYPE_CHECKING: # pragma: no cover + from _typeshed.wsgi import WSGIEnvironment + from .app import Flask from .sessions import SessionMixin from .wrappers import Request @@ -112,7 +114,9 @@ def __repr__(self) -> str: return object.__repr__(self) -def after_this_request(f: ft.AfterRequestCallable) -> ft.AfterRequestCallable: +def after_this_request( + f: ft.AfterRequestCallable[t.Any] +) -> ft.AfterRequestCallable[t.Any]: """Executes a function after this request. This is useful to modify response objects. The function is passed the response object and has to return the same or a new one. @@ -145,7 +149,10 @@ def add_header(response): return f -def copy_current_request_context(f: t.Callable) -> t.Callable: +F = t.TypeVar("F", bound=t.Callable[..., t.Any]) + + +def copy_current_request_context(f: F) -> F: """A helper function that decorates a function to retain the current request context. This is useful when working with greenlets. The moment the function is decorated a copy of the request context is created and @@ -179,11 +186,11 @@ def do_some_work(): ctx = ctx.copy() - def wrapper(*args, **kwargs): - with ctx: - return ctx.app.ensure_sync(f)(*args, **kwargs) + def wrapper(*args: t.Any, **kwargs: t.Any) -> t.Any: + with ctx: # type: ignore[union-attr] + return ctx.app.ensure_sync(f)(*args, **kwargs) # type: ignore[union-attr] - return update_wrapper(wrapper, f) + return update_wrapper(wrapper, f) # type: ignore[return-value] def has_request_context() -> bool: @@ -239,7 +246,7 @@ def __init__(self, app: Flask) -> None: self.app = app self.url_adapter = app.create_url_adapter(None) self.g: _AppCtxGlobals = app.app_ctx_globals_class() - self._cv_tokens: list[contextvars.Token] = [] + self._cv_tokens: list[contextvars.Token[AppContext]] = [] def push(self) -> None: """Binds the app context to the current context.""" @@ -302,7 +309,7 @@ class RequestContext: def __init__( self, app: Flask, - environ: dict, + environ: WSGIEnvironment, request: Request | None = None, session: SessionMixin | None = None, ) -> None: @@ -321,9 +328,11 @@ def __init__( # Functions that should be executed after the request on the response # object. These will be called before the regular "after_request" # functions. - self._after_request_functions: list[ft.AfterRequestCallable] = [] + self._after_request_functions: list[ft.AfterRequestCallable[t.Any]] = [] - self._cv_tokens: list[tuple[contextvars.Token, AppContext | None]] = [] + self._cv_tokens: list[ + tuple[contextvars.Token[RequestContext], AppContext | None] + ] = [] def copy(self) -> RequestContext: """Creates a copy of this request context with the same request object. diff --git a/src/flask/debughelpers.py b/src/flask/debughelpers.py index e8360043d1..2c8c4c4836 100644 --- a/src/flask/debughelpers.py +++ b/src/flask/debughelpers.py @@ -2,10 +2,17 @@ import typing as t +from jinja2.loaders import BaseLoader +from werkzeug.routing import RequestRedirect + from .blueprints import Blueprint from .globals import request_ctx from .sansio.app import App +if t.TYPE_CHECKING: + from .sansio.scaffold import Scaffold + from .wrappers import Request + class UnexpectedUnicodeError(AssertionError, UnicodeError): """Raised in places where we want some better error reporting for @@ -18,7 +25,7 @@ class DebugFilesKeyError(KeyError, AssertionError): provide a better error message than just a generic KeyError/BadRequest. """ - def __init__(self, request, key): + def __init__(self, request: Request, key: str) -> None: form_matches = request.form.getlist(key) buf = [ f"You tried to access the file {key!r} in the request.files" @@ -36,7 +43,7 @@ def __init__(self, request, key): ) self.msg = "".join(buf) - def __str__(self): + def __str__(self) -> str: return self.msg @@ -47,8 +54,9 @@ class FormDataRoutingRedirect(AssertionError): 307 or 308. """ - def __init__(self, request): + def __init__(self, request: Request) -> None: exc = request.routing_exception + assert isinstance(exc, RequestRedirect) buf = [ f"A request was sent to '{request.url}', but routing issued" f" a redirect to the canonical URL '{exc.new_url}'." @@ -70,7 +78,7 @@ def __init__(self, request): super().__init__("".join(buf)) -def attach_enctype_error_multidict(request): +def attach_enctype_error_multidict(request: Request) -> None: """Patch ``request.files.__getitem__`` to raise a descriptive error about ``enctype=multipart/form-data``. @@ -79,8 +87,8 @@ def attach_enctype_error_multidict(request): """ oldcls = request.files.__class__ - class newcls(oldcls): - def __getitem__(self, key): + class newcls(oldcls): # type: ignore[valid-type, misc] + def __getitem__(self, key: str) -> t.Any: try: return super().__getitem__(key) except KeyError as e: @@ -96,7 +104,7 @@ def __getitem__(self, key): request.files.__class__ = newcls -def _dump_loader_info(loader) -> t.Generator: +def _dump_loader_info(loader: BaseLoader) -> t.Iterator[str]: yield f"class: {type(loader).__module__}.{type(loader).__name__}" for key, value in sorted(loader.__dict__.items()): if key.startswith("_"): @@ -113,7 +121,17 @@ def _dump_loader_info(loader) -> t.Generator: yield f"{key}: {value!r}" -def explain_template_loading_attempts(app: App, template, attempts) -> None: +def explain_template_loading_attempts( + app: App, + template: str, + attempts: list[ + tuple[ + BaseLoader, + Scaffold, + tuple[str, str | None, t.Callable[[], bool] | None] | None, + ] + ], +) -> None: """This should help developers understand what failed""" info = [f"Locating template {template!r}:"] total_found = 0 diff --git a/src/flask/helpers.py b/src/flask/helpers.py index 8460891215..ae83a25b77 100644 --- a/src/flask/helpers.py +++ b/src/flask/helpers.py @@ -11,6 +11,7 @@ import werkzeug.utils from werkzeug.exceptions import abort as _wz_abort from werkzeug.utils import redirect as _wz_redirect +from werkzeug.wrappers import Response as BaseResponse from .globals import _cv_request from .globals import current_app @@ -20,8 +21,6 @@ from .signals import message_flashed if t.TYPE_CHECKING: # pragma: no cover - from werkzeug.wrappers import Response as BaseResponse - from .wrappers import Response @@ -85,16 +84,16 @@ def generate(): .. versionadded:: 0.9 """ try: - gen = iter(generator_or_function) # type: ignore + gen = iter(generator_or_function) # type: ignore[arg-type] except TypeError: def decorator(*args: t.Any, **kwargs: t.Any) -> t.Any: - gen = generator_or_function(*args, **kwargs) # type: ignore + gen = generator_or_function(*args, **kwargs) # type: ignore[operator] return stream_with_context(gen) - return update_wrapper(decorator, generator_or_function) # type: ignore + return update_wrapper(decorator, generator_or_function) # type: ignore[arg-type] - def generator() -> t.Generator: + def generator() -> t.Iterator[t.AnyStr | None]: ctx = _cv_request.get(None) if ctx is None: raise RuntimeError( @@ -122,7 +121,7 @@ def generator() -> t.Generator: # real generator is executed. wrapped_g = generator() next(wrapped_g) - return wrapped_g + return wrapped_g # type: ignore[return-value] def make_response(*args: t.Any) -> Response: @@ -171,7 +170,7 @@ def index(): return current_app.response_class() if len(args) == 1: args = args[0] - return current_app.make_response(args) # type: ignore + return current_app.make_response(args) def url_for( @@ -513,8 +512,8 @@ def send_file( def send_from_directory( - directory: os.PathLike | str, - path: os.PathLike | str, + directory: os.PathLike[str] | str, + path: os.PathLike[str] | str, **kwargs: t.Any, ) -> Response: """Send a file from within a directory using :func:`send_file`. @@ -609,7 +608,7 @@ def get_root_path(import_name: str) -> str: ) # filepath is import_name.py for a module, or __init__.py for a package. - return os.path.dirname(os.path.abspath(filepath)) + return os.path.dirname(os.path.abspath(filepath)) # type: ignore[no-any-return] @lru_cache(maxsize=None) diff --git a/src/flask/json/__init__.py b/src/flask/json/__init__.py index f15296fed0..c0941d049e 100644 --- a/src/flask/json/__init__.py +++ b/src/flask/json/__init__.py @@ -167,4 +167,4 @@ def jsonify(*args: t.Any, **kwargs: t.Any) -> Response: .. versionadded:: 0.2 """ - return current_app.json.response(*args, **kwargs) + return current_app.json.response(*args, **kwargs) # type: ignore[return-value] diff --git a/src/flask/json/provider.py b/src/flask/json/provider.py index 46d64d6a34..f9b2e8ff55 100644 --- a/src/flask/json/provider.py +++ b/src/flask/json/provider.py @@ -11,8 +11,9 @@ from werkzeug.http import http_date if t.TYPE_CHECKING: # pragma: no cover + from werkzeug.sansio.response import Response + from ..sansio.app import App - from ..wrappers import Response class JSONProvider: @@ -35,7 +36,7 @@ class and implement at least :meth:`dumps` and :meth:`loads`. All """ def __init__(self, app: App) -> None: - self._app = weakref.proxy(app) + self._app: App = weakref.proxy(app) def dumps(self, obj: t.Any, **kwargs: t.Any) -> str: """Serialize data as JSON. diff --git a/src/flask/json/tag.py b/src/flask/json/tag.py index 069739f264..2bb986bc8f 100644 --- a/src/flask/json/tag.py +++ b/src/flask/json/tag.py @@ -61,9 +61,9 @@ class JSONTag: __slots__ = ("serializer",) - #: The tag to mark the serialized object with. If ``None``, this tag is + #: The tag to mark the serialized object with. If empty, this tag is #: only used as an intermediate step during tagging. - key: str | None = None + key: str = "" def __init__(self, serializer: TaggedJSONSerializer) -> None: """Create a tagger for the given serializer.""" @@ -83,7 +83,7 @@ def to_python(self, value: t.Any) -> t.Any: will already be removed.""" raise NotImplementedError - def tag(self, value: t.Any) -> t.Any: + def tag(self, value: t.Any) -> dict[str, t.Any]: """Convert the value to a valid JSON type and add the tag structure around it.""" return {self.key: self.to_json(value)} @@ -274,7 +274,7 @@ def register( tag = tag_class(self) key = tag.key - if key is not None: + if key: if not force and key in self.tags: raise KeyError(f"Tag '{key}' is already registered.") @@ -285,7 +285,7 @@ def register( else: self.order.insert(index, tag) - def tag(self, value: t.Any) -> dict[str, t.Any]: + def tag(self, value: t.Any) -> t.Any: """Convert a value to a tagged representation if necessary.""" for tag in self.order: if tag.check(value): diff --git a/src/flask/logging.py b/src/flask/logging.py index b452f71fd5..0cb8f43746 100644 --- a/src/flask/logging.py +++ b/src/flask/logging.py @@ -22,7 +22,10 @@ def wsgi_errors_stream() -> t.TextIO: can't import this directly, you can refer to it as ``ext://flask.logging.wsgi_errors_stream``. """ - return request.environ["wsgi.errors"] if request else sys.stderr + if request: + return request.environ["wsgi.errors"] # type: ignore[no-any-return] + + return sys.stderr def has_level_handler(logger: logging.Logger) -> bool: diff --git a/src/flask/sansio/app.py b/src/flask/sansio/app.py index fdc4714e4c..21a79ba460 100644 --- a/src/flask/sansio/app.py +++ b/src/flask/sansio/app.py @@ -205,7 +205,7 @@ class App(Scaffold): #: #: This attribute can also be configured from the config with the #: ``TESTING`` configuration key. Defaults to ``False``. - testing = ConfigAttribute("TESTING") + testing = ConfigAttribute[bool]("TESTING") #: If a secret key is set, cryptographic components can use this to #: sign cookies and other things. Set this to a complex random value @@ -213,7 +213,7 @@ class App(Scaffold): #: #: This attribute can also be configured from the config with the #: :data:`SECRET_KEY` configuration key. Defaults to ``None``. - secret_key = ConfigAttribute("SECRET_KEY") + secret_key = ConfigAttribute[t.Union[str, bytes, None]]("SECRET_KEY") #: A :class:`~datetime.timedelta` which is used to set the expiration #: date of a permanent session. The default is 31 days which makes a @@ -222,8 +222,9 @@ class App(Scaffold): #: This attribute can also be configured from the config with the #: ``PERMANENT_SESSION_LIFETIME`` configuration key. Defaults to #: ``timedelta(days=31)`` - permanent_session_lifetime = ConfigAttribute( - "PERMANENT_SESSION_LIFETIME", get_converter=_make_timedelta + permanent_session_lifetime = ConfigAttribute[timedelta]( + "PERMANENT_SESSION_LIFETIME", + get_converter=_make_timedelta, # type: ignore[arg-type] ) json_provider_class: type[JSONProvider] = DefaultJSONProvider @@ -247,7 +248,7 @@ class App(Scaffold): #: This is a ``dict`` instead of an ``ImmutableDict`` to allow #: easier configuration. #: - jinja_options: dict = {} + jinja_options: dict[str, t.Any] = {} #: The rule object to use for URL rules created. This is used by #: :meth:`add_url_rule`. Defaults to :class:`werkzeug.routing.Rule`. @@ -275,18 +276,18 @@ class App(Scaffold): #: .. versionadded:: 1.0 test_cli_runner_class: type[FlaskCliRunner] | None = None - default_config: dict + default_config: dict[str, t.Any] response_class: type[Response] def __init__( self, import_name: str, static_url_path: str | None = None, - static_folder: str | os.PathLike | None = "static", + static_folder: str | os.PathLike[str] | None = "static", static_host: str | None = None, host_matching: bool = False, subdomain_matching: bool = False, - template_folder: str | os.PathLike | None = "templates", + template_folder: str | os.PathLike[str] | None = "templates", instance_path: str | None = None, instance_relative_config: bool = False, root_path: str | None = None, @@ -384,7 +385,7 @@ def __init__( #: ``'foo'``. #: #: .. versionadded:: 0.7 - self.extensions: dict = {} + self.extensions: dict[str, t.Any] = {} #: The :class:`~werkzeug.routing.Map` for this instance. You can use #: this to change the routing converters after the class was created @@ -436,7 +437,7 @@ def name(self) -> str: # type: ignore .. versionadded:: 0.8 """ if self.import_name == "__main__": - fn = getattr(sys.modules["__main__"], "__file__", None) + fn: str | None = getattr(sys.modules["__main__"], "__file__", None) if fn is None: return "__main__" return os.path.splitext(os.path.basename(fn))[0] @@ -560,7 +561,7 @@ def debug(self) -> bool: Default: ``False`` """ - return self.config["DEBUG"] + return self.config["DEBUG"] # type: ignore[no-any-return] @debug.setter def debug(self, value: bool) -> None: @@ -650,10 +651,10 @@ def add_url_rule( # Add the required methods now. methods |= required_methods - rule = self.url_rule_class(rule, methods=methods, **options) - rule.provide_automatic_options = provide_automatic_options # type: ignore + rule_obj = self.url_rule_class(rule, methods=methods, **options) + rule_obj.provide_automatic_options = provide_automatic_options # type: ignore[attr-defined] - self.url_map.add(rule) + self.url_map.add(rule_obj) if view_func is not None: old_func = self.view_functions.get(endpoint) if old_func is not None and old_func != view_func: @@ -911,7 +912,7 @@ def redirect(self, location: str, code: int = 302) -> BaseResponse: Response=self.response_class, # type: ignore[arg-type] ) - def inject_url_defaults(self, endpoint: str, values: dict) -> None: + def inject_url_defaults(self, endpoint: str, values: dict[str, t.Any]) -> None: """Injects the URL defaults for the given endpoint directly into the values dictionary passed. This is used internally and automatically called on URL building. diff --git a/src/flask/sansio/blueprints.py b/src/flask/sansio/blueprints.py index 38c92f450b..4f912cca05 100644 --- a/src/flask/sansio/blueprints.py +++ b/src/flask/sansio/blueprints.py @@ -14,8 +14,8 @@ if t.TYPE_CHECKING: # pragma: no cover from .app import App -DeferredSetupFunction = t.Callable[["BlueprintSetupState"], t.Callable] -T_after_request = t.TypeVar("T_after_request", bound=ft.AfterRequestCallable) +DeferredSetupFunction = t.Callable[["BlueprintSetupState"], None] +T_after_request = t.TypeVar("T_after_request", bound=ft.AfterRequestCallable[t.Any]) T_before_request = t.TypeVar("T_before_request", bound=ft.BeforeRequestCallable) T_error_handler = t.TypeVar("T_error_handler", bound=ft.ErrorHandlerCallable) T_teardown = t.TypeVar("T_teardown", bound=ft.TeardownCallable) @@ -88,7 +88,7 @@ def add_url_rule( self, rule: str, endpoint: str | None = None, - view_func: t.Callable | None = None, + view_func: ft.RouteCallable | None = None, **options: t.Any, ) -> None: """A helper method to register a rule (and optionally a view function) @@ -175,14 +175,14 @@ def __init__( self, name: str, import_name: str, - static_folder: str | os.PathLike | None = None, + static_folder: str | os.PathLike[str] | None = None, static_url_path: str | None = None, - template_folder: str | os.PathLike | None = None, + template_folder: str | os.PathLike[str] | None = None, url_prefix: str | None = None, subdomain: str | None = None, - url_defaults: dict | None = None, + url_defaults: dict[str, t.Any] | None = None, root_path: str | None = None, - cli_group: str | None = _sentinel, # type: ignore + cli_group: str | None = _sentinel, # type: ignore[assignment] ): super().__init__( import_name=import_name, @@ -208,7 +208,7 @@ def __init__( self.url_values_defaults = url_defaults self.cli_group = cli_group - self._blueprints: list[tuple[Blueprint, dict]] = [] + self._blueprints: list[tuple[Blueprint, dict[str, t.Any]]] = [] def _check_setup_finished(self, f_name: str) -> None: if self._got_registered_once: @@ -221,7 +221,7 @@ def _check_setup_finished(self, f_name: str) -> None: ) @setupmethod - def record(self, func: t.Callable) -> None: + def record(self, func: DeferredSetupFunction) -> None: """Registers a function that is called when the blueprint is registered on the application. This function is called with the state as argument as returned by the :meth:`make_setup_state` @@ -230,7 +230,7 @@ def record(self, func: t.Callable) -> None: self.deferred_functions.append(func) @setupmethod - def record_once(self, func: t.Callable) -> None: + def record_once(self, func: DeferredSetupFunction) -> None: """Works like :meth:`record` but wraps the function in another function that will ensure the function is only called once. If the blueprint is registered a second time on the application, the @@ -244,7 +244,7 @@ def wrapper(state: BlueprintSetupState) -> None: self.record(update_wrapper(wrapper, func)) def make_setup_state( - self, app: App, options: dict, first_registration: bool = False + self, app: App, options: dict[str, t.Any], first_registration: bool = False ) -> BlueprintSetupState: """Creates an instance of :meth:`~flask.blueprints.BlueprintSetupState` object that is later passed to the register callback functions. @@ -270,7 +270,7 @@ def register_blueprint(self, blueprint: Blueprint, **options: t.Any) -> None: raise ValueError("Cannot register a blueprint on itself") self._blueprints.append((blueprint, options)) - def register(self, app: App, options: dict) -> None: + def register(self, app: App, options: dict[str, t.Any]) -> None: """Called by :meth:`Flask.register_blueprint` to register all views and callbacks registered on the blueprint with the application. Creates a :class:`.BlueprintSetupState` and calls @@ -377,7 +377,10 @@ def register(self, app: App, options: dict) -> None: blueprint.register(app, bp_options) def _merge_blueprint_funcs(self, app: App, name: str) -> None: - def extend(bp_dict, parent_dict): + def extend( + bp_dict: dict[ft.AppOrBlueprintKey, list[t.Any]], + parent_dict: dict[ft.AppOrBlueprintKey, list[t.Any]], + ) -> None: for key, values in bp_dict.items(): key = name if key is None else f"{name}.{key}" parent_dict[key].extend(values) @@ -598,7 +601,10 @@ def app_errorhandler( """ def decorator(f: T_error_handler) -> T_error_handler: - self.record_once(lambda s: s.app.errorhandler(code)(f)) + def from_blueprint(state: BlueprintSetupState) -> None: + state.app.errorhandler(code)(f) + + self.record_once(from_blueprint) return f return decorator diff --git a/src/flask/sansio/scaffold.py b/src/flask/sansio/scaffold.py index a43f6fd794..40534f53c8 100644 --- a/src/flask/sansio/scaffold.py +++ b/src/flask/sansio/scaffold.py @@ -8,6 +8,7 @@ from collections import defaultdict from functools import update_wrapper +import click from jinja2 import FileSystemLoader from werkzeug.exceptions import default_exceptions from werkzeug.exceptions import HTTPException @@ -22,7 +23,7 @@ _sentinel = object() F = t.TypeVar("F", bound=t.Callable[..., t.Any]) -T_after_request = t.TypeVar("T_after_request", bound=ft.AfterRequestCallable) +T_after_request = t.TypeVar("T_after_request", bound=ft.AfterRequestCallable[t.Any]) T_before_request = t.TypeVar("T_before_request", bound=ft.BeforeRequestCallable) T_error_handler = t.TypeVar("T_error_handler", bound=ft.ErrorHandlerCallable) T_teardown = t.TypeVar("T_teardown", bound=ft.TeardownCallable) @@ -39,7 +40,7 @@ def setupmethod(f: F) -> F: f_name = f.__name__ - def wrapper_func(self, *args: t.Any, **kwargs: t.Any) -> t.Any: + def wrapper_func(self: Scaffold, *args: t.Any, **kwargs: t.Any) -> t.Any: self._check_setup_finished(f_name) return f(self, *args, **kwargs) @@ -71,9 +72,9 @@ class Scaffold: def __init__( self, import_name: str, - static_folder: str | os.PathLike | None = None, + static_folder: str | os.PathLike[str] | None = None, static_url_path: str | None = None, - template_folder: str | os.PathLike | None = None, + template_folder: str | os.PathLike[str] | None = None, root_path: str | None = None, ): #: The name of the package or module that this object belongs @@ -99,7 +100,7 @@ def __init__( #: object. The commands are available from the ``flask`` command #: once the application has been discovered and blueprints have #: been registered. - self.cli = AppGroup() + self.cli: click.Group = AppGroup() #: A dictionary mapping endpoint names to view functions. #: @@ -107,7 +108,7 @@ def __init__( #: #: This data structure is internal. It should not be modified #: directly and its format may change at any time. - self.view_functions: dict[str, t.Callable] = {} + self.view_functions: dict[str, ft.RouteCallable] = {} #: A data structure of registered error handlers, in the format #: ``{scope: {code: {class: handler}}}``. The ``scope`` key is @@ -152,7 +153,7 @@ def __init__( #: This data structure is internal. It should not be modified #: directly and its format may change at any time. self.after_request_funcs: dict[ - ft.AppOrBlueprintKey, list[ft.AfterRequestCallable] + ft.AppOrBlueprintKey, list[ft.AfterRequestCallable[t.Any]] ] = defaultdict(list) #: A data structure of functions to call at the end of each @@ -233,7 +234,7 @@ def static_folder(self) -> str | None: return None @static_folder.setter - def static_folder(self, value: str | os.PathLike | None) -> None: + def static_folder(self, value: str | os.PathLike[str] | None) -> None: if value is not None: value = os.fspath(value).rstrip(r"\/") @@ -287,7 +288,7 @@ def _method_route( self, method: str, rule: str, - options: dict, + options: dict[str, t.Any], ) -> t.Callable[[T_route], T_route]: if "methods" in options: raise TypeError("Use the 'route' decorator to use the 'methods' argument.") @@ -700,7 +701,7 @@ def _get_exc_class_and_code( return exc_class, None -def _endpoint_from_view_func(view_func: t.Callable) -> str: +def _endpoint_from_view_func(view_func: ft.RouteCallable) -> str: """Internal helper that returns the default endpoint for a given function. This always is the function name. """ @@ -717,7 +718,7 @@ def _path_is_relative_to(path: pathlib.PurePath, base: str) -> bool: return False -def _find_package_path(import_name): +def _find_package_path(import_name: str) -> str: """Find the path that contains the package or module.""" root_mod_name, _, _ = import_name.partition(".") @@ -734,34 +735,35 @@ def _find_package_path(import_name): # - we raised `ValueError` due to `root_spec` being `None` return os.getcwd() - if root_spec.origin in {"namespace", None}: - # namespace package - package_spec = importlib.util.find_spec(import_name) - - if package_spec is not None and package_spec.submodule_search_locations: - # Pick the path in the namespace that contains the submodule. - package_path = pathlib.Path( - os.path.commonpath(package_spec.submodule_search_locations) - ) - search_location = next( - location - for location in root_spec.submodule_search_locations - if _path_is_relative_to(package_path, location) - ) + if root_spec.submodule_search_locations: + if root_spec.origin is None or root_spec.origin == "namespace": + # namespace package + package_spec = importlib.util.find_spec(import_name) + + if package_spec is not None and package_spec.submodule_search_locations: + # Pick the path in the namespace that contains the submodule. + package_path = pathlib.Path( + os.path.commonpath(package_spec.submodule_search_locations) + ) + search_location = next( + location + for location in root_spec.submodule_search_locations + if _path_is_relative_to(package_path, location) + ) + else: + # Pick the first path. + search_location = root_spec.submodule_search_locations[0] + + return os.path.dirname(search_location) else: - # Pick the first path. - search_location = root_spec.submodule_search_locations[0] - - return os.path.dirname(search_location) - elif root_spec.submodule_search_locations: - # package with __init__.py - return os.path.dirname(os.path.dirname(root_spec.origin)) + # package with __init__.py + return os.path.dirname(os.path.dirname(root_spec.origin)) else: # module - return os.path.dirname(root_spec.origin) + return os.path.dirname(root_spec.origin) # type: ignore[type-var, return-value] -def find_package(import_name: str): +def find_package(import_name: str) -> tuple[str | None, str]: """Find the prefix that a package is installed under, and the path that it would be imported from. diff --git a/src/flask/sessions.py b/src/flask/sessions.py index 34f71d8c05..bb753eb814 100644 --- a/src/flask/sessions.py +++ b/src/flask/sessions.py @@ -13,12 +13,15 @@ from .json.tag import TaggedJSONSerializer if t.TYPE_CHECKING: # pragma: no cover + import typing_extensions as te + from .app import Flask from .wrappers import Request from .wrappers import Response -class SessionMixin(MutableMapping): +# TODO generic when Python > 3.8 +class SessionMixin(MutableMapping): # type: ignore[type-arg] """Expands a basic dictionary with session attributes.""" @property @@ -46,7 +49,8 @@ def permanent(self, value: bool) -> None: accessed = True -class SecureCookieSession(CallbackDict, SessionMixin): +# TODO generic when Python > 3.8 +class SecureCookieSession(CallbackDict, SessionMixin): # type: ignore[type-arg] """Base class for sessions based on signed cookies. This session backend will set the :attr:`modified` and @@ -69,7 +73,7 @@ class SecureCookieSession(CallbackDict, SessionMixin): accessed = False def __init__(self, initial: t.Any = None) -> None: - def on_update(self) -> None: + def on_update(self: te.Self) -> None: self.modified = True self.accessed = True @@ -178,7 +182,7 @@ def is_null_session(self, obj: object) -> bool: def get_cookie_name(self, app: Flask) -> str: """The name of the session cookie. Uses``app.config["SESSION_COOKIE_NAME"]``.""" - return app.config["SESSION_COOKIE_NAME"] + return app.config["SESSION_COOKIE_NAME"] # type: ignore[no-any-return] def get_cookie_domain(self, app: Flask) -> str | None: """The value of the ``Domain`` parameter on the session cookie. If not set, @@ -190,8 +194,7 @@ def get_cookie_domain(self, app: Flask) -> str | None: .. versionchanged:: 2.3 Not set by default, does not fall back to ``SERVER_NAME``. """ - rv = app.config["SESSION_COOKIE_DOMAIN"] - return rv if rv else None + return app.config["SESSION_COOKIE_DOMAIN"] # type: ignore[no-any-return] def get_cookie_path(self, app: Flask) -> str: """Returns the path for which the cookie should be valid. The @@ -199,27 +202,27 @@ def get_cookie_path(self, app: Flask) -> str: config var if it's set, and falls back to ``APPLICATION_ROOT`` or uses ``/`` if it's ``None``. """ - return app.config["SESSION_COOKIE_PATH"] or app.config["APPLICATION_ROOT"] + return app.config["SESSION_COOKIE_PATH"] or app.config["APPLICATION_ROOT"] # type: ignore[no-any-return] def get_cookie_httponly(self, app: Flask) -> bool: """Returns True if the session cookie should be httponly. This currently just returns the value of the ``SESSION_COOKIE_HTTPONLY`` config var. """ - return app.config["SESSION_COOKIE_HTTPONLY"] + return app.config["SESSION_COOKIE_HTTPONLY"] # type: ignore[no-any-return] def get_cookie_secure(self, app: Flask) -> bool: """Returns True if the cookie should be secure. This currently just returns the value of the ``SESSION_COOKIE_SECURE`` setting. """ - return app.config["SESSION_COOKIE_SECURE"] + return app.config["SESSION_COOKIE_SECURE"] # type: ignore[no-any-return] - def get_cookie_samesite(self, app: Flask) -> str: + def get_cookie_samesite(self, app: Flask) -> str | None: """Return ``'Strict'`` or ``'Lax'`` if the cookie should use the ``SameSite`` attribute. This currently just returns the value of the :data:`SESSION_COOKIE_SAMESITE` setting. """ - return app.config["SESSION_COOKIE_SAMESITE"] + return app.config["SESSION_COOKIE_SAMESITE"] # type: ignore[no-any-return] def get_expiration_time(self, app: Flask, session: SessionMixin) -> datetime | None: """A helper method that returns an expiration date for the session diff --git a/src/flask/templating.py b/src/flask/templating.py index 8dff8bacde..618a3b35d0 100644 --- a/src/flask/templating.py +++ b/src/flask/templating.py @@ -57,16 +57,16 @@ class DispatchingJinjaLoader(BaseLoader): def __init__(self, app: App) -> None: self.app = app - def get_source( # type: ignore - self, environment: Environment, template: str - ) -> tuple[str, str | None, t.Callable | None]: + def get_source( + self, environment: BaseEnvironment, template: str + ) -> tuple[str, str | None, t.Callable[[], bool] | None]: if self.app.config["EXPLAIN_TEMPLATE_LOADING"]: return self._get_source_explained(environment, template) return self._get_source_fast(environment, template) def _get_source_explained( - self, environment: Environment, template: str - ) -> tuple[str, str | None, t.Callable | None]: + self, environment: BaseEnvironment, template: str + ) -> tuple[str, str | None, t.Callable[[], bool] | None]: attempts = [] rv: tuple[str, str | None, t.Callable[[], bool] | None] | None trv: None | (tuple[str, str | None, t.Callable[[], bool] | None]) = None @@ -89,8 +89,8 @@ def _get_source_explained( raise TemplateNotFound(template) def _get_source_fast( - self, environment: Environment, template: str - ) -> tuple[str, str | None, t.Callable | None]: + self, environment: BaseEnvironment, template: str + ) -> tuple[str, str | None, t.Callable[[], bool] | None]: for _srcobj, loader in self._iter_loaders(template): try: return loader.get_source(environment, template) @@ -98,9 +98,7 @@ def _get_source_fast( continue raise TemplateNotFound(template) - def _iter_loaders( - self, template: str - ) -> t.Generator[tuple[Scaffold, BaseLoader], None, None]: + def _iter_loaders(self, template: str) -> t.Iterator[tuple[Scaffold, BaseLoader]]: loader = self.app.jinja_loader if loader is not None: yield self.app, loader diff --git a/src/flask/testing.py b/src/flask/testing.py index 69aa785188..a27b7c8fe9 100644 --- a/src/flask/testing.py +++ b/src/flask/testing.py @@ -17,6 +17,7 @@ from .sessions import SessionMixin if t.TYPE_CHECKING: # pragma: no cover + from _typeshed.wsgi import WSGIEnvironment from werkzeug.test import TestResponse from .app import Flask @@ -134,7 +135,7 @@ def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: @contextmanager def session_transaction( self, *args: t.Any, **kwargs: t.Any - ) -> t.Generator[SessionMixin, None, None]: + ) -> t.Iterator[SessionMixin]: """When used in combination with a ``with`` statement this opens a session transaction. This can be used to modify the session that the test client uses. Once the ``with`` block is left the session is @@ -181,7 +182,7 @@ def session_transaction( resp.headers.getlist("Set-Cookie"), ) - def _copy_environ(self, other): + def _copy_environ(self, other: WSGIEnvironment) -> WSGIEnvironment: out = {**self.environ_base, **other} if self.preserve_context: @@ -189,7 +190,9 @@ def _copy_environ(self, other): return out - def _request_from_builder_args(self, args, kwargs): + def _request_from_builder_args( + self, args: tuple[t.Any, ...], kwargs: dict[str, t.Any] + ) -> BaseRequest: kwargs["environ_base"] = self._copy_environ(kwargs.get("environ_base", {})) builder = EnvironBuilder(self.application, *args, **kwargs) @@ -210,7 +213,7 @@ def open( ): if isinstance(args[0], werkzeug.test.EnvironBuilder): builder = copy(args[0]) - builder.environ_base = self._copy_environ(builder.environ_base or {}) + builder.environ_base = self._copy_environ(builder.environ_base or {}) # type: ignore[arg-type] request = builder.get_request() elif isinstance(args[0], dict): request = EnvironBuilder.from_environ( @@ -287,7 +290,7 @@ def invoke( # type: ignore :return: a :class:`~click.testing.Result` object. """ if cli is None: - cli = self.app.cli # type: ignore + cli = self.app.cli if "obj" not in kwargs: kwargs["obj"] = ScriptInfo(create_app=lambda: self.app) diff --git a/src/flask/typing.py b/src/flask/typing.py index a8c9ba0425..cf6d4ae6dd 100644 --- a/src/flask/typing.py +++ b/src/flask/typing.py @@ -68,8 +68,10 @@ TemplateFilterCallable = t.Callable[..., t.Any] TemplateGlobalCallable = t.Callable[..., t.Any] TemplateTestCallable = t.Callable[..., bool] -URLDefaultCallable = t.Callable[[str, dict], None] -URLValuePreprocessorCallable = t.Callable[[t.Optional[str], t.Optional[dict]], None] +URLDefaultCallable = t.Callable[[str, t.Dict[str, t.Any]], None] +URLValuePreprocessorCallable = t.Callable[ + [t.Optional[str], t.Optional[t.Dict[str, t.Any]]], None +] # This should take Exception, but that either breaks typing the argument # with a specific exception, or decorating multiple times with different diff --git a/src/flask/views.py b/src/flask/views.py index bfc18af3b8..794fdc06cf 100644 --- a/src/flask/views.py +++ b/src/flask/views.py @@ -6,6 +6,8 @@ from .globals import current_app from .globals import request +F = t.TypeVar("F", bound=t.Callable[..., t.Any]) + http_method_funcs = frozenset( ["get", "post", "head", "options", "delete", "put", "trace", "patch"] ) @@ -59,7 +61,7 @@ def dispatch_request(self, name): #: decorator. #: #: .. versionadded:: 0.8 - decorators: t.ClassVar[list[t.Callable]] = [] + decorators: t.ClassVar[list[t.Callable[[F], F]]] = [] #: Create a new instance of this view class for every request by #: default. If a view subclass sets this to ``False``, the same @@ -105,13 +107,13 @@ def view(**kwargs: t.Any) -> ft.ResponseReturnValue: self = view.view_class( # type: ignore[attr-defined] *class_args, **class_kwargs ) - return current_app.ensure_sync(self.dispatch_request)(**kwargs) + return current_app.ensure_sync(self.dispatch_request)(**kwargs) # type: ignore[no-any-return] else: self = cls(*class_args, **class_kwargs) def view(**kwargs: t.Any) -> ft.ResponseReturnValue: - return current_app.ensure_sync(self.dispatch_request)(**kwargs) + return current_app.ensure_sync(self.dispatch_request)(**kwargs) # type: ignore[no-any-return] if cls.decorators: view.__name__ = name @@ -186,4 +188,4 @@ def dispatch_request(self, **kwargs: t.Any) -> ft.ResponseReturnValue: meth = getattr(self, "get", None) assert meth is not None, f"Unimplemented method {request.method!r}" - return current_app.ensure_sync(meth)(**kwargs) + return current_app.ensure_sync(meth)(**kwargs) # type: ignore[no-any-return] diff --git a/src/flask/wrappers.py b/src/flask/wrappers.py index ef7aa38c0e..c1eca80783 100644 --- a/src/flask/wrappers.py +++ b/src/flask/wrappers.py @@ -3,6 +3,7 @@ import typing as t from werkzeug.exceptions import BadRequest +from werkzeug.exceptions import HTTPException from werkzeug.wrappers import Request as RequestBase from werkzeug.wrappers import Response as ResponseBase @@ -49,13 +50,13 @@ class Request(RequestBase): #: raised / was raised as part of the request handling. This is #: usually a :exc:`~werkzeug.exceptions.NotFound` exception or #: something similar. - routing_exception: Exception | None = None + routing_exception: HTTPException | None = None @property - def max_content_length(self) -> int | None: # type: ignore + def max_content_length(self) -> int | None: # type: ignore[override] """Read-only view of the ``MAX_CONTENT_LENGTH`` config key.""" if current_app: - return current_app.config["MAX_CONTENT_LENGTH"] + return current_app.config["MAX_CONTENT_LENGTH"] # type: ignore[no-any-return] else: return None @@ -167,7 +168,7 @@ def max_cookie_size(self) -> int: # type: ignore Werkzeug's docs. """ if current_app: - return current_app.config["MAX_COOKIE_SIZE"] + return current_app.config["MAX_COOKIE_SIZE"] # type: ignore[no-any-return] # return Werkzeug's default when not in an app context return super().max_cookie_size