From f405c6f19e002ae13708cb33f6d48257cc1ea37a Mon Sep 17 00:00:00 2001 From: pgjones Date: Fri, 23 Apr 2021 09:44:15 +0100 Subject: [PATCH 1/2] Initial typing support This enables type checking in CI and marks the project as typed. --- .github/workflows/tests.yaml | 1 + MANIFEST.in | 1 + requirements/typing.in | 1 + requirements/typing.txt | 14 ++++++++++++++ setup.cfg | 27 +++++++++++++++++++++++++++ src/flask/py.typed | 0 tox.ini | 5 +++++ 7 files changed, 49 insertions(+) create mode 100644 requirements/typing.in create mode 100644 requirements/typing.txt create mode 100644 src/flask/py.typed diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 7d34bf78c2..e656dcf840 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -31,6 +31,7 @@ jobs: - {name: '3.7', python: '3.7', os: ubuntu-latest, tox: py37} - {name: '3.6', python: '3.6', os: ubuntu-latest, tox: py36} - {name: 'PyPy', python: pypy3, os: ubuntu-latest, tox: pypy3} + - {name: Typing, python: '3.9', os: ubuntu-latest, tox: typing} steps: - uses: actions/checkout@v2 - uses: actions/setup-python@v2 diff --git a/MANIFEST.in b/MANIFEST.in index ddf882369f..65a9774968 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -7,4 +7,5 @@ graft docs prune docs/_build graft examples graft tests +include src/flask/py.typed global-exclude *.pyc diff --git a/requirements/typing.in b/requirements/typing.in new file mode 100644 index 0000000000..f0aa93ac8e --- /dev/null +++ b/requirements/typing.in @@ -0,0 +1 @@ +mypy diff --git a/requirements/typing.txt b/requirements/typing.txt new file mode 100644 index 0000000000..29e12e5e85 --- /dev/null +++ b/requirements/typing.txt @@ -0,0 +1,14 @@ +# +# This file is autogenerated by pip-compile +# To update, run: +# +# pip-compile requirements/typing.in +# +mypy-extensions==0.4.3 + # via mypy +mypy==0.812 + # via -r requirements/typing.in +typed-ast==1.4.3 + # via mypy +typing-extensions==3.7.4.3 + # via mypy diff --git a/setup.cfg b/setup.cfg index 9dee3575ca..7761686623 100644 --- a/setup.cfg +++ b/setup.cfg @@ -85,3 +85,30 @@ max-line-length = 80 per-file-ignores = # __init__ module exports names src/flask/__init__.py: F401 + +[mypy] +files = src/flask +python_version = 3.6 +allow_redefinition = True +disallow_subclassing_any = True +# disallow_untyped_calls = True +# disallow_untyped_defs = True +# disallow_incomplete_defs = True +no_implicit_optional = True +local_partial_types = True +# no_implicit_reexport = True +strict_equality = True +warn_redundant_casts = True +warn_unused_configs = True +warn_unused_ignores = True +# warn_return_any = True +# warn_unreachable = True + +[mypy-asgiref.*] +ignore_missing_imports = True + +[mypy-blinker.*] +ignore_missing_imports = True + +[mypy-dotenv.*] +ignore_missing_imports = True diff --git a/src/flask/py.typed b/src/flask/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tox.ini b/tox.ini index cf12c0ebcc..d4558ad03d 100644 --- a/tox.ini +++ b/tox.ini @@ -2,6 +2,7 @@ envlist = py{39,38,37,36,py3} style + typing docs skip_missing_interpreters = true @@ -24,6 +25,10 @@ deps = pre-commit skip_install = true commands = pre-commit run --all-files --show-diff-on-failure +[testenv:typing] +deps = -r requirements/typing.txt +commands = mypy + [testenv:docs] deps = -r requirements/docs.txt From 77237093da25438c88b0a74c374a397d4bf823eb Mon Sep 17 00:00:00 2001 From: pgjones Date: Sat, 24 Apr 2021 12:22:26 +0100 Subject: [PATCH 2/2] Add initial type hints This should make it easier for users to correctly use Flask. The hints are from Quart. --- CHANGES.rst | 1 + src/flask/app.py | 292 +++++++++++++++++++++++-------------- src/flask/blueprints.py | 138 ++++++++++++------ src/flask/cli.py | 17 ++- src/flask/config.py | 54 +++---- src/flask/ctx.py | 80 ++++++---- src/flask/debughelpers.py | 7 +- src/flask/globals.py | 14 +- src/flask/helpers.py | 74 ++++++---- src/flask/json/__init__.py | 57 +++++--- src/flask/json/tag.py | 78 +++++----- src/flask/logging.py | 14 +- src/flask/scaffold.py | 175 ++++++++++++++-------- src/flask/sessions.py | 70 +++++---- src/flask/signals.py | 12 +- src/flask/templating.py | 39 +++-- src/flask/testing.py | 57 +++++--- src/flask/typing.py | 46 ++++++ src/flask/views.py | 27 ++-- src/flask/wrappers.py | 29 ++-- 20 files changed, 820 insertions(+), 461 deletions(-) create mode 100644 src/flask/typing.py diff --git a/CHANGES.rst b/CHANGES.rst index d62d038335..ce169742a9 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -74,6 +74,7 @@ Unreleased ``python`` shell if ``readline`` is installed. :issue:`3941` - ``helpers.total_seconds()`` is deprecated. Use ``timedelta.total_seconds()`` instead. :pr:`3962` +- Add type hinting. :pr:`3973`. Version 1.1.2 diff --git a/src/flask/app.py b/src/flask/app.py index 98437cba8b..7afb0a1e67 100644 --- a/src/flask/app.py +++ b/src/flask/app.py @@ -1,11 +1,14 @@ import functools import inspect +import logging import os import sys +import typing as t import weakref from datetime import timedelta from itertools import chain from threading import Lock +from types import TracebackType from werkzeug.datastructures import Headers from werkzeug.datastructures import ImmutableDict @@ -15,6 +18,7 @@ from werkzeug.exceptions import InternalServerError from werkzeug.routing import BuildError from werkzeug.routing import Map +from werkzeug.routing import MapAdapter from werkzeug.routing import RequestRedirect from werkzeug.routing import RoutingException from werkzeug.routing import Rule @@ -53,15 +57,30 @@ from .signals import request_tearing_down from .templating import DispatchingJinjaLoader from .templating import Environment +from .typing import AfterRequestCallable +from .typing import BeforeRequestCallable +from .typing import ErrorHandlerCallable +from .typing import ResponseReturnValue +from .typing import TeardownCallable +from .typing import TemplateContextProcessorCallable +from .typing import TemplateFilterCallable +from .typing import TemplateGlobalCallable +from .typing import TemplateTestCallable +from .typing import URLDefaultCallable +from .typing import URLValuePreprocessorCallable from .wrappers import Request from .wrappers import Response +if t.TYPE_CHECKING: + from .blueprints import Blueprint + from .testing import FlaskClient + from .testing import FlaskCliRunner if sys.version_info >= (3, 8): iscoroutinefunction = inspect.iscoroutinefunction else: - def iscoroutinefunction(func): + def iscoroutinefunction(func: t.Any) -> bool: while inspect.ismethod(func): func = func.__func__ @@ -71,7 +90,7 @@ def iscoroutinefunction(func): return inspect.iscoroutinefunction(func) -def _make_timedelta(value): +def _make_timedelta(value: t.Optional[timedelta]) -> t.Optional[timedelta]: if value is None or isinstance(value, timedelta): return value @@ -295,7 +314,7 @@ class Flask(Scaffold): #: This is a ``dict`` instead of an ``ImmutableDict`` to allow #: easier configuration. #: - jinja_options = {} + jinja_options: dict = {} #: Default configuration parameters. default_config = ImmutableDict( @@ -347,7 +366,7 @@ class Flask(Scaffold): #: the test client that is used with when `test_client` is used. #: #: .. versionadded:: 0.7 - test_client_class = None + test_client_class: t.Optional[t.Type["FlaskClient"]] = None #: The :class:`~click.testing.CliRunner` subclass, by default #: :class:`~flask.testing.FlaskCliRunner` that is used by @@ -355,7 +374,7 @@ class Flask(Scaffold): #: Flask app object as the first argument. #: #: .. versionadded:: 1.0 - test_cli_runner_class = None + test_cli_runner_class: t.Optional[t.Type["FlaskCliRunner"]] = None #: the session interface to use. By default an instance of #: :class:`~flask.sessions.SecureCookieSessionInterface` is used here. @@ -365,16 +384,16 @@ class Flask(Scaffold): def __init__( self, - import_name, - static_url_path=None, - static_folder="static", - static_host=None, - host_matching=False, - subdomain_matching=False, - template_folder="templates", - instance_path=None, - instance_relative_config=False, - root_path=None, + import_name: str, + static_url_path: t.Optional[str] = None, + static_folder: t.Optional[str] = "static", + static_host: t.Optional[str] = None, + host_matching: bool = False, + subdomain_matching: bool = False, + template_folder: t.Optional[str] = "templates", + instance_path: t.Optional[str] = None, + instance_relative_config: bool = False, + root_path: t.Optional[str] = None, ): super().__init__( import_name=import_name, @@ -409,14 +428,16 @@ def __init__( #: tried. #: #: .. versionadded:: 0.9 - self.url_build_error_handlers = [] + self.url_build_error_handlers: t.List[ + t.Callable[[Exception, str, dict], str] + ] = [] #: A list of functions that will be called at the beginning of the #: first request to this instance. To register a function, use the #: :meth:`before_first_request` decorator. #: #: .. versionadded:: 0.8 - self.before_first_request_funcs = [] + self.before_first_request_funcs: t.List[BeforeRequestCallable] = [] #: A list of functions that are called when the application context #: is destroyed. Since the application context is also torn down @@ -424,13 +445,13 @@ def __init__( #: from databases. #: #: .. versionadded:: 0.9 - self.teardown_appcontext_funcs = [] + self.teardown_appcontext_funcs: t.List[TeardownCallable] = [] #: A list of shell context processor functions that should be run #: when a shell context is created. #: #: .. versionadded:: 0.11 - self.shell_context_processors = [] + self.shell_context_processors: t.List[t.Callable[[], t.Dict[str, t.Any]]] = [] #: Maps registered blueprint names to blueprint objects. The #: dict retains the order the blueprints were registered in. @@ -438,7 +459,7 @@ def __init__( #: not track how often they were attached. #: #: .. versionadded:: 0.7 - self.blueprints = {} + self.blueprints: t.Dict[str, "Blueprint"] = {} #: a place where extensions can store application specific state. For #: example this is where an extension could store database engines and @@ -449,7 +470,7 @@ def __init__( #: ``'foo'``. #: #: .. versionadded:: 0.7 - self.extensions = {} + self.extensions: dict = {} #: The :class:`~werkzeug.routing.Map` for this instance. You can use #: this to change the routing converters after the class was created @@ -492,18 +513,18 @@ def __init__( f"{self.static_url_path}/", endpoint="static", host=static_host, - view_func=lambda **kw: self_ref().send_static_file(**kw), + view_func=lambda **kw: self_ref().send_static_file(**kw), # type: ignore # noqa: B950 ) # Set the name of the Click group in case someone wants to add # the app's commands to another CLI tool. self.cli.name = self.name - def _is_setup_finished(self): + def _is_setup_finished(self) -> bool: return self.debug and self._got_first_request @locked_cached_property - def name(self): + def name(self) -> str: # type: ignore """The name of the application. This is usually the import name with the difference that it's guessed from the run file if the import name is main. This name is used as a display name when @@ -520,7 +541,7 @@ def name(self): return self.import_name @property - def propagate_exceptions(self): + def propagate_exceptions(self) -> bool: """Returns the value of the ``PROPAGATE_EXCEPTIONS`` configuration value in case it's set, otherwise a sensible default is returned. @@ -532,7 +553,7 @@ def propagate_exceptions(self): return self.testing or self.debug @property - def preserve_context_on_exception(self): + def preserve_context_on_exception(self) -> bool: """Returns the value of the ``PRESERVE_CONTEXT_ON_EXCEPTION`` configuration value in case it's set, otherwise a sensible default is returned. @@ -545,7 +566,7 @@ def preserve_context_on_exception(self): return self.debug @locked_cached_property - def logger(self): + def logger(self) -> logging.Logger: """A standard Python :class:`~logging.Logger` for the app, with the same name as :attr:`name`. @@ -572,7 +593,7 @@ def logger(self): return create_logger(self) @locked_cached_property - def jinja_env(self): + def jinja_env(self) -> Environment: """The Jinja environment used to load templates. The environment is created the first time this property is @@ -582,7 +603,7 @@ def jinja_env(self): return self.create_jinja_environment() @property - def got_first_request(self): + def got_first_request(self) -> bool: """This attribute is set to ``True`` if the application started handling the first request. @@ -590,7 +611,7 @@ def got_first_request(self): """ return self._got_first_request - def make_config(self, instance_relative=False): + def make_config(self, instance_relative: bool = False) -> Config: """Used to create the config attribute by the Flask constructor. The `instance_relative` parameter is passed in from the constructor of Flask (there named `instance_relative_config`) and indicates if @@ -607,7 +628,7 @@ def make_config(self, instance_relative=False): defaults["DEBUG"] = get_debug_flag() return self.config_class(root_path, defaults) - def auto_find_instance_path(self): + def auto_find_instance_path(self) -> str: """Tries to locate the instance path if it was not provided to the constructor of the application class. It will basically calculate the path to a folder named ``instance`` next to your main file or @@ -620,7 +641,7 @@ def auto_find_instance_path(self): return os.path.join(package_path, "instance") return os.path.join(prefix, "var", f"{self.name}-instance") - def open_instance_resource(self, resource, mode="rb"): + def open_instance_resource(self, resource: str, mode: str = "rb") -> t.IO[t.AnyStr]: """Opens a resource from the application's instance folder (:attr:`instance_path`). Otherwise works like :meth:`open_resource`. Instance resources can also be opened for @@ -633,7 +654,7 @@ def open_instance_resource(self, resource, mode="rb"): return open(os.path.join(self.instance_path, resource), mode) @property - def templates_auto_reload(self): + def templates_auto_reload(self) -> bool: """Reload templates when they are changed. Used by :meth:`create_jinja_environment`. @@ -648,10 +669,10 @@ def templates_auto_reload(self): return rv if rv is not None else self.debug @templates_auto_reload.setter - def templates_auto_reload(self, value): + def templates_auto_reload(self, value: bool) -> None: self.config["TEMPLATES_AUTO_RELOAD"] = value - def create_jinja_environment(self): + def create_jinja_environment(self) -> Environment: """Create the Jinja environment based on :attr:`jinja_options` and the various Jinja-related methods of the app. Changing :attr:`jinja_options` after this will have no effect. Also adds @@ -683,10 +704,10 @@ def create_jinja_environment(self): session=session, g=g, ) - rv.policies["json.dumps_function"] = json.dumps + rv.policies["json.dumps_function"] = json.dumps # type: ignore return rv - def create_global_jinja_loader(self): + def create_global_jinja_loader(self) -> DispatchingJinjaLoader: """Creates the loader for the Jinja2 environment. Can be used to override just the loader and keeping the rest unchanged. It's discouraged to override this function. Instead one should override @@ -699,7 +720,7 @@ def create_global_jinja_loader(self): """ return DispatchingJinjaLoader(self) - def select_jinja_autoescape(self, filename): + def select_jinja_autoescape(self, filename: str) -> bool: """Returns ``True`` if autoescaping should be active for the given template name. If no template name is given, returns `True`. @@ -709,7 +730,7 @@ def select_jinja_autoescape(self, filename): return True return filename.endswith((".html", ".htm", ".xml", ".xhtml")) - def update_template_context(self, context): + def update_template_context(self, context: dict) -> None: """Update the template context with some commonly used variables. This injects request, session, config and g into the template context as well as everything template context processors want @@ -720,7 +741,9 @@ def update_template_context(self, context): :param context: the context as a dictionary that is updated in place to add extra variables. """ - funcs = self.template_context_processors[None] + funcs: t.Iterable[ + TemplateContextProcessorCallable + ] = self.template_context_processors[None] reqctx = _request_ctx_stack.top if reqctx is not None: for bp in self._request_blueprints(): @@ -734,7 +757,7 @@ def update_template_context(self, context): # existing views. context.update(orig_ctx) - def make_shell_context(self): + def make_shell_context(self) -> dict: """Returns the shell context for an interactive shell for this application. This runs all the registered shell context processors. @@ -758,7 +781,7 @@ def make_shell_context(self): env = ConfigAttribute("ENV") @property - def debug(self): + def debug(self) -> bool: """Whether debug mode is enabled. When using ``flask run`` to start the development server, an interactive debugger will be shown for unhandled exceptions, and the server will be reloaded when code @@ -775,11 +798,18 @@ def debug(self): return self.config["DEBUG"] @debug.setter - def debug(self, value): + def debug(self, value: bool) -> None: self.config["DEBUG"] = value self.jinja_env.auto_reload = self.templates_auto_reload - def run(self, host=None, port=None, debug=None, load_dotenv=True, **options): + def run( + self, + host: t.Optional[str] = None, + port: t.Optional[int] = None, + debug: t.Optional[bool] = None, + load_dotenv: bool = True, + **options: t.Any, + ) -> None: """Runs the application on a local development server. Do not use ``run()`` in a production setting. It is not intended to @@ -887,14 +917,14 @@ def run(self, host=None, port=None, debug=None, load_dotenv=True, **options): from werkzeug.serving import run_simple try: - run_simple(host, port, self, **options) + run_simple(t.cast(str, host), port, self, **options) finally: # reset the first request information if the development server # reset normally. This makes it possible to restart the server # without reloader and that stuff from an interactive shell. self._got_first_request = False - def test_client(self, use_cookies=True, **kwargs): + def test_client(self, use_cookies: bool = True, **kwargs: t.Any) -> "FlaskClient": """Creates a test client for this application. For information about unit testing head over to :doc:`/testing`. @@ -947,10 +977,12 @@ def __init__(self, *args, **kwargs): """ cls = self.test_client_class if cls is None: - from .testing import FlaskClient as cls - return cls(self, self.response_class, use_cookies=use_cookies, **kwargs) + from .testing import FlaskClient as cls # type: ignore + return cls( # type: ignore + self, self.response_class, use_cookies=use_cookies, **kwargs + ) - def test_cli_runner(self, **kwargs): + def test_cli_runner(self, **kwargs: t.Any) -> "FlaskCliRunner": """Create a CLI runner for testing CLI commands. See :ref:`testing-cli`. @@ -963,12 +995,12 @@ def test_cli_runner(self, **kwargs): cls = self.test_cli_runner_class if cls is None: - from .testing import FlaskCliRunner as cls + from .testing import FlaskCliRunner as cls # type: ignore - return cls(self, **kwargs) + return cls(self, **kwargs) # type: ignore @setupmethod - def register_blueprint(self, blueprint, **options): + def register_blueprint(self, blueprint: "Blueprint", **options: t.Any) -> None: """Register a :class:`~flask.Blueprint` on the application. Keyword arguments passed to this method will override the defaults set on the blueprint. @@ -989,7 +1021,7 @@ def register_blueprint(self, blueprint, **options): """ blueprint.register(self, options) - def iter_blueprints(self): + def iter_blueprints(self) -> t.ValuesView["Blueprint"]: """Iterates over all blueprints by the order they were registered. .. versionadded:: 0.11 @@ -999,14 +1031,14 @@ def iter_blueprints(self): @setupmethod def add_url_rule( self, - rule, - endpoint=None, - view_func=None, - provide_automatic_options=None, - **options, - ): + rule: str, + endpoint: t.Optional[str] = None, + view_func: t.Optional[t.Callable] = None, + provide_automatic_options: t.Optional[bool] = None, + **options: t.Any, + ) -> None: if endpoint is None: - endpoint = _endpoint_from_view_func(view_func) + endpoint = _endpoint_from_view_func(view_func) # type: ignore options["endpoint"] = endpoint methods = options.pop("methods", None) @@ -1043,13 +1075,13 @@ def add_url_rule( methods |= required_methods rule = self.url_rule_class(rule, methods=methods, **options) - rule.provide_automatic_options = provide_automatic_options + rule.provide_automatic_options = provide_automatic_options # type: ignore self.url_map.add(rule) if view_func is not None: old_func = self.view_functions.get(endpoint) if getattr(old_func, "_flask_sync_wrapper", False): - old_func = old_func.__wrapped__ + old_func = old_func.__wrapped__ # type: ignore if old_func is not None and old_func != view_func: raise AssertionError( "View function mapping is overwriting an existing" @@ -1058,7 +1090,7 @@ def add_url_rule( self.view_functions[endpoint] = self.ensure_sync(view_func) @setupmethod - def template_filter(self, name=None): + def template_filter(self, name: t.Optional[str] = None) -> t.Callable: """A decorator that is used to register custom template filter. You can specify a name for the filter, otherwise the function name will be used. Example:: @@ -1071,14 +1103,16 @@ def reverse(s): function name will be used. """ - def decorator(f): + def decorator(f: TemplateFilterCallable) -> TemplateFilterCallable: self.add_template_filter(f, name=name) return f return decorator @setupmethod - def add_template_filter(self, f, name=None): + def add_template_filter( + self, f: TemplateFilterCallable, name: t.Optional[str] = None + ) -> None: """Register a custom template filter. Works exactly like the :meth:`template_filter` decorator. @@ -1088,7 +1122,7 @@ def add_template_filter(self, f, name=None): self.jinja_env.filters[name or f.__name__] = f @setupmethod - def template_test(self, name=None): + def template_test(self, name: t.Optional[str] = None) -> t.Callable: """A decorator that is used to register custom template test. You can specify a name for the test, otherwise the function name will be used. Example:: @@ -1108,14 +1142,16 @@ def is_prime(n): function name will be used. """ - def decorator(f): + def decorator(f: TemplateTestCallable) -> TemplateTestCallable: self.add_template_test(f, name=name) return f return decorator @setupmethod - def add_template_test(self, f, name=None): + def add_template_test( + self, f: TemplateTestCallable, name: t.Optional[str] = None + ) -> None: """Register a custom template test. Works exactly like the :meth:`template_test` decorator. @@ -1127,7 +1163,7 @@ def add_template_test(self, f, name=None): self.jinja_env.tests[name or f.__name__] = f @setupmethod - def template_global(self, name=None): + def template_global(self, name: t.Optional[str] = None) -> t.Callable: """A decorator that is used to register a custom template global function. You can specify a name for the global function, otherwise the function name will be used. Example:: @@ -1142,14 +1178,16 @@ def double(n): function name will be used. """ - def decorator(f): + def decorator(f: TemplateGlobalCallable) -> TemplateGlobalCallable: self.add_template_global(f, name=name) return f return decorator @setupmethod - def add_template_global(self, f, name=None): + def add_template_global( + self, f: TemplateGlobalCallable, name: t.Optional[str] = None + ) -> None: """Register a custom template global function. Works exactly like the :meth:`template_global` decorator. @@ -1161,7 +1199,7 @@ def add_template_global(self, f, name=None): self.jinja_env.globals[name or f.__name__] = f @setupmethod - def before_first_request(self, f): + def before_first_request(self, f: BeforeRequestCallable) -> BeforeRequestCallable: """Registers a function to be run before the first request to this instance of the application. @@ -1174,7 +1212,7 @@ def before_first_request(self, f): return f @setupmethod - def teardown_appcontext(self, f): + def teardown_appcontext(self, f: TeardownCallable) -> TeardownCallable: """Registers a function to be called when the application context ends. These functions are typically also called when the request context is popped. @@ -1207,7 +1245,7 @@ def teardown_appcontext(self, f): return f @setupmethod - def shell_context_processor(self, f): + def shell_context_processor(self, f: t.Callable) -> t.Callable: """Registers a shell context processor function. .. versionadded:: 0.11 @@ -1215,7 +1253,7 @@ def shell_context_processor(self, f): self.shell_context_processors.append(f) return f - def _find_error_handler(self, e): + def _find_error_handler(self, e: Exception) -> t.Optional[ErrorHandlerCallable]: """Return a registered error handler for an exception in this order: blueprint handler for a specific code, app handler for a specific code, blueprint handler for an exception class, app handler for an exception @@ -1235,8 +1273,11 @@ def _find_error_handler(self, e): if handler is not None: return handler + return None - def handle_http_exception(self, e): + def handle_http_exception( + self, e: HTTPException + ) -> t.Union[HTTPException, ResponseReturnValue]: """Handles an HTTP exception. By default this will invoke the registered error handlers and fall back to returning the exception as response. @@ -1269,7 +1310,7 @@ def handle_http_exception(self, e): return e return handler(e) - def trap_http_exception(self, e): + def trap_http_exception(self, e: Exception) -> bool: """Checks if an HTTP exception should be trapped or not. By default this will return ``False`` for all exceptions except for a bad request key error if ``TRAP_BAD_REQUEST_ERRORS`` is set to ``True``. It @@ -1304,7 +1345,9 @@ def trap_http_exception(self, e): return False - def handle_user_exception(self, e): + def handle_user_exception( + self, e: Exception + ) -> t.Union[HTTPException, ResponseReturnValue]: """This method is called whenever an exception occurs that should be handled. A special case is :class:`~werkzeug .exceptions.HTTPException` which is forwarded to the @@ -1334,7 +1377,7 @@ def handle_user_exception(self, e): return handler(e) - def handle_exception(self, e): + def handle_exception(self, e: Exception) -> Response: """Handle an exception that did not have an error handler associated with it, or that was raised from an error handler. This always causes a 500 ``InternalServerError``. @@ -1374,6 +1417,7 @@ def handle_exception(self, e): raise e self.log_exception(exc_info) + server_error: t.Union[InternalServerError, ResponseReturnValue] server_error = InternalServerError(original_exception=e) handler = self._find_error_handler(server_error) @@ -1382,7 +1426,12 @@ def handle_exception(self, e): return self.finalize_request(server_error, from_error_handler=True) - def log_exception(self, exc_info): + def log_exception( + self, + exc_info: t.Union[ + t.Tuple[type, BaseException, TracebackType], t.Tuple[None, None, None] + ], + ) -> None: """Logs an exception. This is called by :meth:`handle_exception` if debugging is disabled and right before the handler is called. The default implementation logs the exception as error on the @@ -1394,7 +1443,7 @@ def log_exception(self, exc_info): f"Exception on {request.path} [{request.method}]", exc_info=exc_info ) - def raise_routing_exception(self, request): + def raise_routing_exception(self, request: Request) -> t.NoReturn: """Exceptions that are recording during routing are reraised with this method. During debug we are not reraising redirect requests for non ``GET``, ``HEAD``, or ``OPTIONS`` requests and we're raising @@ -1407,13 +1456,13 @@ def raise_routing_exception(self, request): or not isinstance(request.routing_exception, RequestRedirect) or request.method in ("GET", "HEAD", "OPTIONS") ): - raise request.routing_exception + raise request.routing_exception # type: ignore from .debughelpers import FormDataRoutingRedirect raise FormDataRoutingRedirect(request) - def dispatch_request(self): + def dispatch_request(self) -> ResponseReturnValue: """Does the request dispatching. Matches the URL and returns the return value of the view or error handler. This does not have to be a response object. In order to convert the return value to a @@ -1437,7 +1486,7 @@ def dispatch_request(self): # otherwise dispatch to the handler for that endpoint return self.view_functions[rule.endpoint](**req.view_args) - def full_dispatch_request(self): + def full_dispatch_request(self) -> Response: """Dispatches the request and on top of that performs request pre and postprocessing as well as HTTP exception catching and error handling. @@ -1454,7 +1503,11 @@ def full_dispatch_request(self): rv = self.handle_user_exception(e) return self.finalize_request(rv) - def finalize_request(self, rv, from_error_handler=False): + def finalize_request( + self, + rv: t.Union[ResponseReturnValue, HTTPException], + from_error_handler: bool = False, + ) -> Response: """Given the return value from a view function this finalizes the request by converting it into a response and invoking the postprocessing functions. This is invoked for both normal @@ -1479,7 +1532,7 @@ def finalize_request(self, rv, from_error_handler=False): ) return response - def try_trigger_before_first_request_functions(self): + def try_trigger_before_first_request_functions(self) -> None: """Called before each request and will ensure that it triggers the :attr:`before_first_request_funcs` and only exactly once per application instance (which means process usually). @@ -1495,7 +1548,7 @@ def try_trigger_before_first_request_functions(self): func() self._got_first_request = True - def make_default_options_response(self): + def make_default_options_response(self) -> Response: """This method is called to create the default ``OPTIONS`` response. This can be changed through subclassing to change the default behavior of ``OPTIONS`` responses. @@ -1508,7 +1561,7 @@ def make_default_options_response(self): rv.allow.update(methods) return rv - def should_ignore_error(self, error): + def should_ignore_error(self, error: t.Optional[BaseException]) -> bool: """This is called to figure out if an error should be ignored or not as far as the teardown system is concerned. If this function returns ``True`` then the teardown handlers will not be @@ -1518,7 +1571,7 @@ def should_ignore_error(self, error): """ return False - def ensure_sync(self, func): + def ensure_sync(self, func: t.Callable) -> t.Callable: """Ensure that the function is synchronous for WSGI workers. Plain ``def`` functions are returned as-is. ``async def`` functions are wrapped to run and wait for the response. @@ -1532,7 +1585,7 @@ def ensure_sync(self, func): return func - def make_response(self, rv): + def make_response(self, rv: ResponseReturnValue) -> Response: """Convert the return value from a view function to an instance of :attr:`response_class`. @@ -1620,7 +1673,7 @@ def make_response(self, rv): # evaluate a WSGI callable, or coerce a different response # class to the correct type try: - rv = self.response_class.force_type(rv, request.environ) + rv = self.response_class.force_type(rv, request.environ) # type: ignore # noqa: B950 except TypeError as e: raise TypeError( f"{e}\nThe view function did not return a valid" @@ -1636,10 +1689,11 @@ def make_response(self, rv): f" callable, but it was a {type(rv).__name__}." ) + rv = t.cast(Response, rv) # prefer the status if it was provided if status is not None: if isinstance(status, (str, bytes, bytearray)): - rv.status = status + rv.status = status # type: ignore else: rv.status_code = status @@ -1649,7 +1703,9 @@ def make_response(self, rv): return rv - def create_url_adapter(self, request): + def create_url_adapter( + self, request: t.Optional[Request] + ) -> t.Optional[MapAdapter]: """Creates a URL adapter for the given request. The URL adapter is created at a point where the request context is not yet set up so the request is passed explicitly. @@ -1687,21 +1743,25 @@ def create_url_adapter(self, request): url_scheme=self.config["PREFERRED_URL_SCHEME"], ) - def inject_url_defaults(self, endpoint, values): + return None + + def inject_url_defaults(self, endpoint: str, values: dict) -> None: """Injects the URL defaults for the given endpoint directly into the values dictionary passed. This is used internally and automatically called on URL building. .. versionadded:: 0.7 """ - funcs = self.url_default_functions[None] + funcs: t.Iterable[URLDefaultCallable] = self.url_default_functions[None] if "." in endpoint: bp = endpoint.rsplit(".", 1)[0] funcs = chain(funcs, self.url_default_functions[bp]) for func in funcs: func(endpoint, values) - def handle_url_build_error(self, error, endpoint, values): + def handle_url_build_error( + self, error: Exception, endpoint: str, values: dict + ) -> str: """Handle :class:`~werkzeug.routing.BuildError` on :meth:`url_for`. """ @@ -1722,7 +1782,7 @@ def handle_url_build_error(self, error, endpoint, values): raise error - def preprocess_request(self): + def preprocess_request(self) -> t.Optional[ResponseReturnValue]: """Called before the request is dispatched. Calls :attr:`url_value_preprocessors` registered with the app and the current blueprint (if any). Then calls :attr:`before_request_funcs` @@ -1733,14 +1793,16 @@ def preprocess_request(self): further request handling is stopped. """ - funcs = self.url_value_preprocessors[None] + funcs: t.Iterable[URLValuePreprocessorCallable] = self.url_value_preprocessors[ + None + ] for bp in self._request_blueprints(): if bp in self.url_value_preprocessors: funcs = chain(funcs, self.url_value_preprocessors[bp]) for func in funcs: func(request.endpoint, request.view_args) - funcs = self.before_request_funcs[None] + funcs: t.Iterable[BeforeRequestCallable] = self.before_request_funcs[None] for bp in self._request_blueprints(): if bp in self.before_request_funcs: funcs = chain(funcs, self.before_request_funcs[bp]) @@ -1749,7 +1811,9 @@ def preprocess_request(self): if rv is not None: return rv - def process_response(self, response): + return None + + def process_response(self, response: Response) -> Response: """Can be overridden in order to modify the response object before it's sent to the WSGI server. By default this will call all the :meth:`after_request` decorated functions. @@ -1763,7 +1827,7 @@ def process_response(self, response): instance of :attr:`response_class`. """ ctx = _request_ctx_stack.top - funcs = ctx._after_request_functions + funcs: t.Iterable[AfterRequestCallable] = ctx._after_request_functions for bp in self._request_blueprints(): if bp in self.after_request_funcs: funcs = chain(funcs, reversed(self.after_request_funcs[bp])) @@ -1775,7 +1839,9 @@ def process_response(self, response): self.session_interface.save_session(self, ctx.session, response) return response - def do_teardown_request(self, exc=_sentinel): + def do_teardown_request( + self, exc: t.Optional[BaseException] = _sentinel # type: ignore + ) -> None: """Called after the request is dispatched and the response is returned, right before the request context is popped. @@ -1798,7 +1864,9 @@ def do_teardown_request(self, exc=_sentinel): """ if exc is _sentinel: exc = sys.exc_info()[1] - funcs = reversed(self.teardown_request_funcs[None]) + funcs: t.Iterable[TeardownCallable] = reversed( + self.teardown_request_funcs[None] + ) for bp in self._request_blueprints(): if bp in self.teardown_request_funcs: funcs = chain(funcs, reversed(self.teardown_request_funcs[bp])) @@ -1806,7 +1874,9 @@ def do_teardown_request(self, exc=_sentinel): func(exc) request_tearing_down.send(self, exc=exc) - def do_teardown_appcontext(self, exc=_sentinel): + def do_teardown_appcontext( + self, exc: t.Optional[BaseException] = _sentinel # type: ignore + ) -> None: """Called right before the application context is popped. When handling a request, the application context is popped @@ -1827,7 +1897,7 @@ def do_teardown_appcontext(self, exc=_sentinel): func(exc) appcontext_tearing_down.send(self, exc=exc) - def app_context(self): + def app_context(self) -> AppContext: """Create an :class:`~flask.ctx.AppContext`. Use as a ``with`` block to push the context, which will make :data:`current_app` point at this application. @@ -1848,7 +1918,7 @@ def app_context(self): """ return AppContext(self) - def request_context(self, environ): + def request_context(self, environ: dict) -> RequestContext: """Create a :class:`~flask.ctx.RequestContext` representing a WSGI environment. Use a ``with`` block to push the context, which will make :data:`request` point at this request. @@ -1864,7 +1934,7 @@ def request_context(self, environ): """ return RequestContext(self, environ) - def test_request_context(self, *args, **kwargs): + def test_request_context(self, *args: t.Any, **kwargs: t.Any) -> RequestContext: """Create a :class:`~flask.ctx.RequestContext` for a WSGI environment created from the given values. This is mostly useful during testing, where you may want to run a function that uses @@ -1920,7 +1990,7 @@ def test_request_context(self, *args, **kwargs): finally: builder.close() - def wsgi_app(self, environ, start_response): + def wsgi_app(self, environ: dict, start_response: t.Callable) -> t.Any: """The actual WSGI application. This is not implemented in :meth:`__call__` so that middlewares can be applied without losing a reference to the app object. Instead of doing this:: @@ -1946,7 +2016,7 @@ def wsgi_app(self, environ, start_response): start the response. """ ctx = self.request_context(environ) - error = None + error: t.Optional[BaseException] = None try: try: ctx.push() @@ -1963,14 +2033,14 @@ def wsgi_app(self, environ, start_response): error = None ctx.auto_pop(error) - def __call__(self, environ, start_response): + def __call__(self, environ: dict, start_response: t.Callable) -> t.Any: """The WSGI server calls the Flask application object as the WSGI application. This calls :meth:`wsgi_app`, which can be wrapped to apply middleware. """ return self.wsgi_app(environ, start_response) - def _request_blueprints(self): + def _request_blueprints(self) -> t.Iterable[str]: if _request_ctx_stack.top.request.blueprint is None: return [] else: diff --git a/src/flask/blueprints.py b/src/flask/blueprints.py index 92345cf2ba..a2b6c0f50f 100644 --- a/src/flask/blueprints.py +++ b/src/flask/blueprints.py @@ -1,9 +1,25 @@ +import typing as t from collections import defaultdict from functools import update_wrapper from .scaffold import _endpoint_from_view_func from .scaffold import _sentinel from .scaffold import Scaffold +from .typing import AfterRequestCallable +from .typing import BeforeRequestCallable +from .typing import ErrorHandlerCallable +from .typing import TeardownCallable +from .typing import TemplateContextProcessorCallable +from .typing import TemplateFilterCallable +from .typing import TemplateGlobalCallable +from .typing import TemplateTestCallable +from .typing import URLDefaultCallable +from .typing import URLValuePreprocessorCallable + +if t.TYPE_CHECKING: + from .app import Flask + +DeferredSetupFunction = t.Callable[["BlueprintSetupState"], t.Callable] class BlueprintSetupState: @@ -13,7 +29,13 @@ class BlueprintSetupState: to all register callback functions. """ - def __init__(self, blueprint, app, options, first_registration): + def __init__( + self, + blueprint: "Blueprint", + app: "Flask", + options: t.Any, + first_registration: bool, + ) -> None: #: a reference to the current application self.app = app @@ -52,7 +74,13 @@ def __init__(self, blueprint, app, options, first_registration): self.url_defaults = dict(self.blueprint.url_values_defaults) self.url_defaults.update(self.options.get("url_defaults", ())) - def add_url_rule(self, rule, endpoint=None, view_func=None, **options): + def add_url_rule( + self, + rule: str, + endpoint: t.Optional[str] = None, + view_func: t.Optional[t.Callable] = None, + **options: t.Any, + ) -> None: """A helper method to register a rule (and optionally a view function) to the application. The endpoint is automatically prefixed with the blueprint's name. @@ -64,7 +92,7 @@ def add_url_rule(self, rule, endpoint=None, view_func=None, **options): rule = self.url_prefix options.setdefault("subdomain", self.subdomain) if endpoint is None: - endpoint = _endpoint_from_view_func(view_func) + endpoint = _endpoint_from_view_func(view_func) # type: ignore defaults = self.url_defaults if "defaults" in options: defaults = dict(defaults, **options.pop("defaults")) @@ -142,16 +170,16 @@ class Blueprint(Scaffold): def __init__( self, - name, - import_name, - static_folder=None, - static_url_path=None, - template_folder=None, - url_prefix=None, - subdomain=None, - url_defaults=None, - root_path=None, - cli_group=_sentinel, + name: str, + import_name: str, + static_folder: t.Optional[str] = None, + static_url_path: t.Optional[str] = None, + template_folder: t.Optional[str] = None, + url_prefix: t.Optional[str] = None, + subdomain: t.Optional[str] = None, + url_defaults: t.Optional[dict] = None, + root_path: t.Optional[str] = None, + cli_group: t.Optional[str] = _sentinel, # type: ignore ): super().__init__( import_name=import_name, @@ -163,19 +191,19 @@ def __init__( self.name = name self.url_prefix = url_prefix self.subdomain = subdomain - self.deferred_functions = [] + self.deferred_functions: t.List[DeferredSetupFunction] = [] if url_defaults is None: url_defaults = {} self.url_values_defaults = url_defaults self.cli_group = cli_group - self._blueprints = [] + self._blueprints: t.List[t.Tuple["Blueprint", dict]] = [] - def _is_setup_finished(self): + def _is_setup_finished(self) -> bool: return self.warn_on_modifications and self._got_registered_once - def record(self, func): + def record(self, func: t.Callable) -> None: """Registers a function that is called when the blueprint is registered on the application. This function is called with the state as argument as returned by the :meth:`make_setup_state` @@ -193,27 +221,29 @@ def record(self, func): ) self.deferred_functions.append(func) - def record_once(self, func): + def record_once(self, func: t.Callable) -> None: """Works like :meth:`record` but wraps the function in another function that will ensure the function is only called once. If the blueprint is registered a second time on the application, the function passed is not called. """ - def wrapper(state): + def wrapper(state: BlueprintSetupState) -> None: if state.first_registration: func(state) return self.record(update_wrapper(wrapper, func)) - def make_setup_state(self, app, options, first_registration=False): + def make_setup_state( + self, app: "Flask", options: dict, first_registration: bool = False + ) -> BlueprintSetupState: """Creates an instance of :meth:`~flask.blueprints.BlueprintSetupState` object that is later passed to the register callback functions. Subclasses can override this to return a subclass of the setup state. """ return BlueprintSetupState(self, app, options, first_registration) - def register_blueprint(self, blueprint, **options): + def register_blueprint(self, blueprint: "Blueprint", **options: t.Any) -> None: """Register a :class:`~flask.Blueprint` on this blueprint. Keyword arguments passed to this method will override the defaults set on the blueprint. @@ -222,7 +252,7 @@ def register_blueprint(self, blueprint, **options): """ self._blueprints.append((blueprint, options)) - def register(self, app, options): + def register(self, app: "Flask", options: dict) -> None: """Called by :meth:`Flask.register_blueprint` to register all views and callbacks registered on the blueprint with the application. Creates a :class:`.BlueprintSetupState` and calls @@ -327,7 +357,13 @@ def extend(bp_dict, parent_dict, ensure_sync=False): bp_options["name_prefix"] = options.get("name_prefix", "") + self.name + "." blueprint.register(app, bp_options) - def add_url_rule(self, rule, endpoint=None, view_func=None, **options): + def add_url_rule( + self, + rule: str, + endpoint: t.Optional[str] = None, + view_func: t.Optional[t.Callable] = None, + **options: t.Any, + ) -> None: """Like :meth:`Flask.add_url_rule` but for a blueprint. The endpoint for the :func:`url_for` function is prefixed with the name of the blueprint. """ @@ -339,7 +375,7 @@ def add_url_rule(self, rule, endpoint=None, view_func=None, **options): ), "Blueprint view function name should not contain dots" self.record(lambda s: s.add_url_rule(rule, endpoint, view_func, **options)) - def app_template_filter(self, name=None): + def app_template_filter(self, name: t.Optional[str] = None) -> t.Callable: """Register a custom template filter, available application wide. Like :meth:`Flask.template_filter` but for a blueprint. @@ -347,13 +383,15 @@ def app_template_filter(self, name=None): function name will be used. """ - def decorator(f): + def decorator(f: TemplateFilterCallable) -> TemplateFilterCallable: self.add_app_template_filter(f, name=name) return f return decorator - def add_app_template_filter(self, f, name=None): + def add_app_template_filter( + self, f: TemplateFilterCallable, name: t.Optional[str] = None + ) -> None: """Register a custom template filter, available application wide. Like :meth:`Flask.add_template_filter` but for a blueprint. Works exactly like the :meth:`app_template_filter` decorator. @@ -362,12 +400,12 @@ def add_app_template_filter(self, f, name=None): function name will be used. """ - def register_template(state): + def register_template(state: BlueprintSetupState) -> None: state.app.jinja_env.filters[name or f.__name__] = f self.record_once(register_template) - def app_template_test(self, name=None): + def app_template_test(self, name: t.Optional[str] = None) -> t.Callable: """Register a custom template test, available application wide. Like :meth:`Flask.template_test` but for a blueprint. @@ -377,13 +415,15 @@ def app_template_test(self, name=None): function name will be used. """ - def decorator(f): + def decorator(f: TemplateTestCallable) -> TemplateTestCallable: self.add_app_template_test(f, name=name) return f return decorator - def add_app_template_test(self, f, name=None): + def add_app_template_test( + self, f: TemplateTestCallable, name: t.Optional[str] = None + ) -> None: """Register a custom template test, available application wide. Like :meth:`Flask.add_template_test` but for a blueprint. Works exactly like the :meth:`app_template_test` decorator. @@ -394,12 +434,12 @@ def add_app_template_test(self, f, name=None): function name will be used. """ - def register_template(state): + def register_template(state: BlueprintSetupState) -> None: state.app.jinja_env.tests[name or f.__name__] = f self.record_once(register_template) - def app_template_global(self, name=None): + def app_template_global(self, name: t.Optional[str] = None) -> t.Callable: """Register a custom template global, available application wide. Like :meth:`Flask.template_global` but for a blueprint. @@ -409,13 +449,15 @@ def app_template_global(self, name=None): function name will be used. """ - def decorator(f): + def decorator(f: TemplateGlobalCallable) -> TemplateGlobalCallable: self.add_app_template_global(f, name=name) return f return decorator - def add_app_template_global(self, f, name=None): + def add_app_template_global( + self, f: TemplateGlobalCallable, name: t.Optional[str] = None + ) -> None: """Register a custom template global, available application wide. Like :meth:`Flask.add_template_global` but for a blueprint. Works exactly like the :meth:`app_template_global` decorator. @@ -426,12 +468,12 @@ def add_app_template_global(self, f, name=None): function name will be used. """ - def register_template(state): + def register_template(state: BlueprintSetupState) -> None: state.app.jinja_env.globals[name or f.__name__] = f self.record_once(register_template) - def before_app_request(self, f): + def before_app_request(self, f: BeforeRequestCallable) -> BeforeRequestCallable: """Like :meth:`Flask.before_request`. Such a function is executed before each request, even if outside of a blueprint. """ @@ -442,7 +484,9 @@ def before_app_request(self, f): ) return f - def before_app_first_request(self, f): + def before_app_first_request( + self, f: BeforeRequestCallable + ) -> BeforeRequestCallable: """Like :meth:`Flask.before_first_request`. Such a function is executed before the first request to the application. """ @@ -451,7 +495,7 @@ def before_app_first_request(self, f): ) return f - def after_app_request(self, f): + def after_app_request(self, f: AfterRequestCallable) -> AfterRequestCallable: """Like :meth:`Flask.after_request` but for a blueprint. Such a function is executed after each request, even if outside of the blueprint. """ @@ -462,7 +506,7 @@ def after_app_request(self, f): ) return f - def teardown_app_request(self, f): + def teardown_app_request(self, f: TeardownCallable) -> TeardownCallable: """Like :meth:`Flask.teardown_request` but for a blueprint. Such a function is executed when tearing down each request, even if outside of the blueprint. @@ -472,7 +516,9 @@ def teardown_app_request(self, f): ) return f - def app_context_processor(self, f): + def app_context_processor( + self, f: TemplateContextProcessorCallable + ) -> TemplateContextProcessorCallable: """Like :meth:`Flask.context_processor` but for a blueprint. Such a function is executed each request, even if outside of the blueprint. """ @@ -481,32 +527,34 @@ def app_context_processor(self, f): ) return f - def app_errorhandler(self, code): + def app_errorhandler(self, code: t.Union[t.Type[Exception], int]) -> t.Callable: """Like :meth:`Flask.errorhandler` but for a blueprint. This handler is used for all requests, even if outside of the blueprint. """ - def decorator(f): + def decorator(f: ErrorHandlerCallable) -> ErrorHandlerCallable: self.record_once(lambda s: s.app.errorhandler(code)(f)) return f return decorator - def app_url_value_preprocessor(self, f): + def app_url_value_preprocessor( + self, f: URLValuePreprocessorCallable + ) -> URLValuePreprocessorCallable: """Same as :meth:`url_value_preprocessor` but application wide.""" self.record_once( lambda s: s.app.url_value_preprocessors.setdefault(None, []).append(f) ) return f - def app_url_defaults(self, f): + def app_url_defaults(self, f: URLDefaultCallable) -> URLDefaultCallable: """Same as :meth:`url_defaults` but application wide.""" self.record_once( lambda s: s.app.url_default_functions.setdefault(None, []).append(f) ) return f - def ensure_sync(self, f): + def ensure_sync(self, f: t.Callable) -> t.Callable: """Ensure the function is synchronous. Override if you would like custom async to sync behaviour in diff --git a/src/flask/cli.py b/src/flask/cli.py index 79a9a7c4d2..987c95cdd7 100644 --- a/src/flask/cli.py +++ b/src/flask/cli.py @@ -27,7 +27,7 @@ try: import ssl except ImportError: - ssl = None + ssl = None # type: ignore class NoAppException(click.UsageError): @@ -860,7 +860,7 @@ def run_command( @click.command("shell", short_help="Run a shell in the app context.") @with_appcontext -def shell_command(): +def shell_command() -> None: """Run an interactive Python shell in the context of a given Flask application. The application will populate the default namespace of this shell according to its configuration. @@ -877,7 +877,7 @@ def shell_command(): f"App: {app.import_name} [{app.env}]\n" f"Instance: {app.instance_path}" ) - ctx = {} + ctx: dict = {} # Support the regular Python interpreter startup script if someone # is using it. @@ -922,7 +922,7 @@ def shell_command(): ) @click.option("--all-methods", is_flag=True, help="Show HEAD and OPTIONS methods.") @with_appcontext -def routes_command(sort, all_methods): +def routes_command(sort: str, all_methods: bool) -> None: """Show all registered routes with endpoints and methods.""" rules = list(current_app.url_map.iter_rules()) @@ -935,9 +935,12 @@ def routes_command(sort, all_methods): if sort in ("endpoint", "rule"): rules = sorted(rules, key=attrgetter(sort)) elif sort == "methods": - rules = sorted(rules, key=lambda rule: sorted(rule.methods)) + rules = sorted(rules, key=lambda rule: sorted(rule.methods)) # type: ignore - rule_methods = [", ".join(sorted(rule.methods - ignored_methods)) for rule in rules] + rule_methods = [ + ", ".join(sorted(rule.methods - ignored_methods)) # type: ignore + for rule in rules + ] headers = ("Endpoint", "Methods", "Rule") widths = ( @@ -975,7 +978,7 @@ def routes_command(sort, all_methods): ) -def main(): +def main() -> None: # TODO omit sys.argv once https://github.com/pallets/click/issues/536 is fixed cli.main(args=sys.argv[1:]) diff --git a/src/flask/config.py b/src/flask/config.py index d2dfec2b7f..86f21dc8ac 100644 --- a/src/flask/config.py +++ b/src/flask/config.py @@ -1,6 +1,7 @@ import errno import os import types +import typing as t from werkzeug.utils import import_string @@ -8,11 +9,11 @@ class ConfigAttribute: """Makes an attribute forward to the config""" - def __init__(self, name, get_converter=None): + def __init__(self, name: str, get_converter: t.Optional[t.Callable] = None) -> None: self.__name__ = name self.get_converter = get_converter - def __get__(self, obj, type=None): + def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any: if obj is None: return self rv = obj.config[self.__name__] @@ -20,7 +21,7 @@ def __get__(self, obj, type=None): rv = self.get_converter(rv) return rv - def __set__(self, obj, value): + def __set__(self, obj: t.Any, value: t.Any) -> None: obj.config[self.__name__] = value @@ -68,11 +69,11 @@ class Config(dict): :param defaults: an optional dictionary of default values """ - def __init__(self, root_path, defaults=None): + def __init__(self, root_path: str, defaults: t.Optional[dict] = None) -> None: dict.__init__(self, defaults or {}) self.root_path = root_path - def from_envvar(self, variable_name, silent=False): + def from_envvar(self, variable_name: str, silent: bool = False) -> bool: """Loads a configuration from an environment variable pointing to a configuration file. This is basically just a shortcut with nicer error messages for this line of code:: @@ -96,7 +97,7 @@ def from_envvar(self, variable_name, silent=False): ) return self.from_pyfile(rv, silent=silent) - def from_pyfile(self, filename, silent=False): + def from_pyfile(self, filename: str, silent: bool = False) -> bool: """Updates the values in the config from a Python file. This function behaves as if the file was imported as module with the :meth:`from_object` function. @@ -124,7 +125,7 @@ def from_pyfile(self, filename, silent=False): self.from_object(d) return True - def from_object(self, obj): + def from_object(self, obj: t.Union[object, str]) -> None: """Updates the values from the given object. An object can be of one of the following two types: @@ -162,7 +163,12 @@ class and has ``@property`` attributes, it needs to be if key.isupper(): self[key] = getattr(obj, key) - def from_file(self, filename, load, silent=False): + def from_file( + self, + filename: str, + load: t.Callable[[t.IO[t.Any]], t.Mapping], + silent: bool = False, + ) -> bool: """Update the values in the config from a file that is loaded using the ``load`` parameter. The loaded data is passed to the :meth:`from_mapping` method. @@ -196,30 +202,26 @@ def from_file(self, filename, load, silent=False): return self.from_mapping(obj) - def from_mapping(self, *mapping, **kwargs): + def from_mapping( + self, mapping: t.Optional[t.Mapping[str, t.Any]] = None, **kwargs: t.Any + ) -> bool: """Updates the config like :meth:`update` ignoring items with non-upper keys. .. versionadded:: 0.11 """ - mappings = [] - if len(mapping) == 1: - if hasattr(mapping[0], "items"): - mappings.append(mapping[0].items()) - else: - mappings.append(mapping[0]) - elif len(mapping) > 1: - raise TypeError( - f"expected at most 1 positional argument, got {len(mapping)}" - ) - mappings.append(kwargs.items()) - for mapping in mappings: - for (key, value) in mapping: - if key.isupper(): - self[key] = value + mappings: t.Dict[str, t.Any] = {} + if mapping is not None: + mappings.update(mapping) + mappings.update(kwargs) + for key, value in mappings.items(): + if key.isupper(): + self[key] = value return True - def get_namespace(self, namespace, lowercase=True, trim_namespace=True): + def get_namespace( + self, namespace: str, lowercase: bool = True, trim_namespace: bool = True + ) -> t.Dict[str, t.Any]: """Returns a dictionary containing a subset of configuration options that match the specified namespace/prefix. Example usage:: @@ -260,5 +262,5 @@ def get_namespace(self, namespace, lowercase=True, trim_namespace=True): rv[key] = v return rv - def __repr__(self): + def __repr__(self) -> str: return f"<{type(self).__name__} {dict.__repr__(self)}>" diff --git a/src/flask/ctx.py b/src/flask/ctx.py index f9cb87d223..70de8cad6a 100644 --- a/src/flask/ctx.py +++ b/src/flask/ctx.py @@ -1,5 +1,7 @@ import sys +import typing as t from functools import update_wrapper +from types import TracebackType from werkzeug.exceptions import HTTPException @@ -7,6 +9,12 @@ from .globals import _request_ctx_stack from .signals import appcontext_popped from .signals import appcontext_pushed +from .typing import AfterRequestCallable + +if t.TYPE_CHECKING: + from .app import Flask + from .sessions import SessionMixin + from .wrappers import Request # a singleton sentinel value for parameter defaults @@ -33,7 +41,7 @@ class _AppCtxGlobals: .. versionadded:: 0.10 """ - def get(self, name, default=None): + def get(self, name: str, default: t.Optional[t.Any] = None) -> t.Any: """Get an attribute by name, or a default value. Like :meth:`dict.get`. @@ -44,7 +52,7 @@ def get(self, name, default=None): """ return self.__dict__.get(name, default) - def pop(self, name, default=_sentinel): + def pop(self, name: str, default: t.Any = _sentinel) -> t.Any: """Get and remove an attribute by name. Like :meth:`dict.pop`. :param name: Name of attribute to pop. @@ -58,7 +66,7 @@ def pop(self, name, default=_sentinel): else: return self.__dict__.pop(name, default) - def setdefault(self, name, default=None): + def setdefault(self, name: str, default: t.Any = None) -> t.Any: """Get the value of an attribute if it is present, otherwise set and return a default value. Like :meth:`dict.setdefault`. @@ -70,20 +78,20 @@ def setdefault(self, name, default=None): """ return self.__dict__.setdefault(name, default) - def __contains__(self, item): + def __contains__(self, item: t.Any) -> bool: return item in self.__dict__ - def __iter__(self): + def __iter__(self) -> t.Iterator: return iter(self.__dict__) - def __repr__(self): + def __repr__(self) -> str: top = _app_ctx_stack.top if top is not None: return f"" return object.__repr__(self) -def after_this_request(f): +def after_this_request(f: AfterRequestCallable) -> AfterRequestCallable: """Executes a function after this request. This is useful to modify response objects. The function is passed the response object and has to return the same or a new one. @@ -108,7 +116,7 @@ def add_header(response): return f -def copy_current_request_context(f): +def copy_current_request_context(f: t.Callable) -> t.Callable: """A helper function that decorates a function to retain the current request context. This is useful when working with greenlets. The moment the function is decorated a copy of the request context is created and @@ -148,7 +156,7 @@ def wrapper(*args, **kwargs): return update_wrapper(wrapper, f) -def has_request_context(): +def has_request_context() -> bool: """If you have code that wants to test if a request context is there or not this function can be used. For instance, you may want to take advantage of request information if the request object is available, but fail @@ -180,7 +188,7 @@ def __init__(self, username, remote_addr=None): return _request_ctx_stack.top is not None -def has_app_context(): +def has_app_context() -> bool: """Works like :func:`has_request_context` but for the application context. You can also just do a boolean check on the :data:`current_app` object instead. @@ -199,7 +207,7 @@ class AppContext: context. """ - def __init__(self, app): + def __init__(self, app: "Flask") -> None: self.app = app self.url_adapter = app.create_url_adapter(None) self.g = app.app_ctx_globals_class() @@ -208,13 +216,13 @@ def __init__(self, app): # but there a basic "refcount" is enough to track them. self._refcnt = 0 - def push(self): + def push(self) -> None: """Binds the app context to the current context.""" self._refcnt += 1 _app_ctx_stack.push(self) appcontext_pushed.send(self.app) - def pop(self, exc=_sentinel): + def pop(self, exc: t.Optional[BaseException] = _sentinel) -> None: # type: ignore """Pops the app context.""" try: self._refcnt -= 1 @@ -227,11 +235,13 @@ def pop(self, exc=_sentinel): assert rv is self, f"Popped wrong app context. ({rv!r} instead of {self!r})" appcontext_popped.send(self.app) - def __enter__(self): + def __enter__(self) -> "AppContext": self.push() return self - def __exit__(self, exc_type, exc_value, tb): + def __exit__( + self, exc_type: type, exc_value: BaseException, tb: TracebackType + ) -> None: self.pop(exc_value) @@ -265,7 +275,13 @@ class RequestContext: that situation, otherwise your unittests will leak memory. """ - def __init__(self, app, environ, request=None, session=None): + def __init__( + self, + app: "Flask", + environ: dict, + request: t.Optional["Request"] = None, + session: t.Optional["SessionMixin"] = None, + ) -> None: self.app = app if request is None: request = app.request_class(environ) @@ -282,7 +298,7 @@ def __init__(self, app, environ, request=None, session=None): # other request contexts. Now only if the last level is popped we # get rid of them. Additionally if an application context is missing # one is created implicitly so for each level we add this information - self._implicit_app_ctx_stack = [] + self._implicit_app_ctx_stack: t.List[t.Optional["AppContext"]] = [] # indicator if the context was preserved. Next time another context # is pushed the preserved context is popped. @@ -295,17 +311,17 @@ def __init__(self, app, environ, request=None, session=None): # Functions that should be executed after the request on the response # object. These will be called before the regular "after_request" # functions. - self._after_request_functions = [] + self._after_request_functions: t.List[AfterRequestCallable] = [] @property - def g(self): + def g(self) -> AppContext: return _app_ctx_stack.top.g @g.setter - def g(self, value): + def g(self, value: AppContext) -> None: _app_ctx_stack.top.g = value - def copy(self): + def copy(self) -> "RequestContext": """Creates a copy of this request context with the same request object. This can be used to move a request context to a different greenlet. Because the actual request object is the same this cannot be used to @@ -325,17 +341,17 @@ def copy(self): session=self.session, ) - def match_request(self): + def match_request(self) -> None: """Can be overridden by a subclass to hook into the matching of the request. """ try: - result = self.url_adapter.match(return_rule=True) - self.request.url_rule, self.request.view_args = result + result = self.url_adapter.match(return_rule=True) # type: ignore + self.request.url_rule, self.request.view_args = result # type: ignore except HTTPException as e: self.request.routing_exception = e - def push(self): + def push(self) -> None: """Binds the request context to the current context.""" # If an exception occurs in debug mode or if context preservation is # activated under exception situations exactly one context stays @@ -375,7 +391,7 @@ def push(self): if self.session is None: self.session = session_interface.make_null_session(self.app) - def pop(self, exc=_sentinel): + def pop(self, exc: t.Optional[BaseException] = _sentinel) -> None: # type: ignore """Pops the request context and unbinds it by doing that. This will also trigger the execution of functions registered by the :meth:`~flask.Flask.teardown_request` decorator. @@ -414,20 +430,22 @@ def pop(self, exc=_sentinel): rv is self ), f"Popped wrong request context. ({rv!r} instead of {self!r})" - def auto_pop(self, exc): + def auto_pop(self, exc: t.Optional[BaseException]) -> None: if self.request.environ.get("flask._preserve_context") or ( exc is not None and self.app.preserve_context_on_exception ): self.preserved = True - self._preserved_exc = exc + self._preserved_exc = exc # type: ignore else: self.pop(exc) - def __enter__(self): + def __enter__(self) -> "RequestContext": self.push() return self - def __exit__(self, exc_type, exc_value, tb): + def __exit__( + self, exc_type: type, exc_value: BaseException, tb: TracebackType + ) -> None: # do not pop the request stack if we are in debug mode and an # exception happened. This will allow the debugger to still # access the request object in the interactive shell. Furthermore @@ -435,7 +453,7 @@ def __exit__(self, exc_type, exc_value, tb): # See flask.testing for how this works. self.auto_pop(exc_value) - def __repr__(self): + def __repr__(self) -> str: return ( f"<{type(self).__name__} {self.request.url!r}" f" [{self.request.method}] of {self.app.name}>" diff --git a/src/flask/debughelpers.py b/src/flask/debughelpers.py index 4bd85bc50b..ce65c487c8 100644 --- a/src/flask/debughelpers.py +++ b/src/flask/debughelpers.py @@ -1,4 +1,5 @@ import os +import typing as t from warnings import warn from .app import Flask @@ -92,7 +93,7 @@ def __getitem__(self, key): request.files.__class__ = newcls -def _dump_loader_info(loader): +def _dump_loader_info(loader) -> t.Generator: yield f"class: {type(loader).__module__}.{type(loader).__name__}" for key, value in sorted(loader.__dict__.items()): if key.startswith("_"): @@ -109,7 +110,7 @@ def _dump_loader_info(loader): yield f"{key}: {value!r}" -def explain_template_loading_attempts(app, template, attempts): +def explain_template_loading_attempts(app: Flask, template, attempts) -> None: """This should help developers understand what failed""" info = [f"Locating template {template!r}:"] total_found = 0 @@ -157,7 +158,7 @@ def explain_template_loading_attempts(app, template, attempts): app.logger.info("\n".join(info)) -def explain_ignored_app_run(): +def explain_ignored_app_run() -> None: if os.environ.get("WERKZEUG_RUN_MAIN") != "true": warn( Warning( diff --git a/src/flask/globals.py b/src/flask/globals.py index d46ccb4160..5e6e8c751d 100644 --- a/src/flask/globals.py +++ b/src/flask/globals.py @@ -1,8 +1,14 @@ +import typing as t from functools import partial from werkzeug.local import LocalProxy from werkzeug.local import LocalStack +if t.TYPE_CHECKING: + from .app import Flask + from .ctx import AppContext + from .sessions import SessionMixin + from .wrappers import Request _request_ctx_err_msg = """\ Working outside of request context. @@ -45,7 +51,7 @@ def _find_app(): # context locals _request_ctx_stack = LocalStack() _app_ctx_stack = LocalStack() -current_app = LocalProxy(_find_app) -request = LocalProxy(partial(_lookup_req_object, "request")) -session = LocalProxy(partial(_lookup_req_object, "session")) -g = LocalProxy(partial(_lookup_app_object, "g")) +current_app: "Flask" = LocalProxy(_find_app) # type: ignore +request: "Request" = LocalProxy(partial(_lookup_req_object, "request")) # type: ignore +session: "SessionMixin" = LocalProxy(partial(_lookup_req_object, "session")) # type: ignore # noqa: B950 +g: "AppContext" = LocalProxy(partial(_lookup_app_object, "g")) # type: ignore diff --git a/src/flask/helpers.py b/src/flask/helpers.py index 6a6bbcf118..99594fcef1 100644 --- a/src/flask/helpers.py +++ b/src/flask/helpers.py @@ -1,6 +1,8 @@ import os import socket +import typing as t import warnings +from datetime import timedelta from functools import update_wrapper from functools import wraps from threading import RLock @@ -18,8 +20,11 @@ from .globals import session from .signals import message_flashed +if t.TYPE_CHECKING: + from .wrappers import Response -def get_env(): + +def get_env() -> str: """Get the environment the app is running in, indicated by the :envvar:`FLASK_ENV` environment variable. The default is ``'production'``. @@ -27,7 +32,7 @@ def get_env(): return os.environ.get("FLASK_ENV") or "production" -def get_debug_flag(): +def get_debug_flag() -> bool: """Get whether debug mode should be enabled for the app, indicated by the :envvar:`FLASK_DEBUG` environment variable. The default is ``True`` if :func:`.get_env` returns ``'development'``, or ``False`` @@ -41,7 +46,7 @@ def get_debug_flag(): return val.lower() not in ("0", "false", "no") -def get_load_dotenv(default=True): +def get_load_dotenv(default: bool = True) -> bool: """Get whether the user has disabled loading dotenv files by setting :envvar:`FLASK_SKIP_DOTENV`. The default is ``True``, load the files. @@ -56,7 +61,9 @@ def get_load_dotenv(default=True): return val.lower() in ("0", "false", "no") -def stream_with_context(generator_or_function): +def stream_with_context( + generator_or_function: t.Union[t.Generator, t.Callable] +) -> t.Generator: """Request contexts disappear when the response is started on the server. This is done for efficiency reasons and to make it less likely to encounter memory leaks with badly written WSGI middlewares. The downside is that if @@ -91,16 +98,16 @@ def generate(): .. versionadded:: 0.9 """ try: - gen = iter(generator_or_function) + gen = iter(generator_or_function) # type: ignore except TypeError: - def decorator(*args, **kwargs): - gen = generator_or_function(*args, **kwargs) + def decorator(*args: t.Any, **kwargs: t.Any) -> t.Any: + gen = generator_or_function(*args, **kwargs) # type: ignore return stream_with_context(gen) - return update_wrapper(decorator, generator_or_function) + return update_wrapper(decorator, generator_or_function) # type: ignore - def generator(): + def generator() -> t.Generator: ctx = _request_ctx_stack.top if ctx is None: raise RuntimeError( @@ -120,7 +127,7 @@ def generator(): yield from gen finally: if hasattr(gen, "close"): - gen.close() + gen.close() # type: ignore # The trick is to start the generator. Then the code execution runs until # the first dummy None is yielded at which point the context was already @@ -131,7 +138,7 @@ def generator(): return wrapped_g -def make_response(*args): +def make_response(*args: t.Any) -> "Response": """Sometimes it is necessary to set additional headers in a view. Because views do not have to return response objects but can return a value that is converted into a response object by Flask itself, it becomes tricky to @@ -180,7 +187,7 @@ def index(): return current_app.make_response(args) -def url_for(endpoint, **values): +def url_for(endpoint: str, **values: t.Any) -> str: """Generates a URL to the given endpoint with the method provided. Variable arguments that are unknown to the target endpoint are appended @@ -331,7 +338,7 @@ def external_url_handler(error, endpoint, values): return rv -def get_template_attribute(template_name, attribute): +def get_template_attribute(template_name: str, attribute: str) -> t.Any: """Loads a macro (or variable) a template exports. This can be used to invoke a macro from within Python code. If you for example have a template named :file:`_cider.html` with the following contents: @@ -353,7 +360,7 @@ def get_template_attribute(template_name, attribute): return getattr(current_app.jinja_env.get_template(template_name).module, attribute) -def flash(message, category="message"): +def flash(message: str, category: str = "message") -> None: """Flashes a message to the next request. In order to remove the flashed message from the session and to display it to the user, the template has to call :func:`get_flashed_messages`. @@ -379,11 +386,15 @@ def flash(message, category="message"): flashes.append((category, message)) session["_flashes"] = flashes message_flashed.send( - current_app._get_current_object(), message=message, category=category + current_app._get_current_object(), # type: ignore + message=message, + category=category, ) -def get_flashed_messages(with_categories=False, category_filter=()): +def get_flashed_messages( + with_categories: bool = False, category_filter: t.Iterable[str] = () +) -> t.Union[t.List[str], t.List[t.Tuple[str, str]]]: """Pulls all flashed messages from the session and returns them. Further calls in the same request to the function will return the same messages. By default just the messages are returned, @@ -608,7 +619,7 @@ def send_file( ) -def safe_join(directory, *pathnames): +def safe_join(directory: str, *pathnames: str) -> str: """Safely join zero or more untrusted path components to a base directory to avoid escaping the base directory. @@ -631,7 +642,7 @@ def safe_join(directory, *pathnames): return path -def send_from_directory(directory, path, **kwargs): +def send_from_directory(directory: str, path: str, **kwargs: t.Any) -> "Response": """Send a file from within a directory using :func:`send_file`. .. code-block:: python @@ -661,7 +672,7 @@ def download_file(name): .. versionadded:: 0.5 """ - return werkzeug.utils.send_from_directory( + return werkzeug.utils.send_from_directory( # type: ignore directory, path, **_prepare_send_file_kwargs(**kwargs) ) @@ -675,27 +686,32 @@ class locked_cached_property(werkzeug.utils.cached_property): Inherits from Werkzeug's ``cached_property`` (and ``property``). """ - def __init__(self, fget, name=None, doc=None): + def __init__( + self, + fget: t.Callable[[t.Any], t.Any], + name: t.Optional[str] = None, + doc: t.Optional[str] = None, + ) -> None: super().__init__(fget, name=name, doc=doc) self.lock = RLock() - def __get__(self, obj, type=None): + def __get__(self, obj: object, type: type = None) -> t.Any: # type: ignore if obj is None: return self with self.lock: return super().__get__(obj, type=type) - def __set__(self, obj, value): + def __set__(self, obj: object, value: t.Any) -> None: with self.lock: super().__set__(obj, value) - def __delete__(self, obj): + def __delete__(self, obj: object) -> None: with self.lock: super().__delete__(obj) -def total_seconds(td): +def total_seconds(td: timedelta) -> int: """Returns the total seconds from a timedelta object. :param timedelta td: the timedelta to be converted in seconds @@ -716,7 +732,7 @@ def total_seconds(td): return td.days * 60 * 60 * 24 + td.seconds -def is_ip(value): +def is_ip(value: str) -> bool: """Determine if the given string is an IP address. :param value: value to check @@ -736,7 +752,7 @@ def is_ip(value): return False -def run_async(func): +def run_async(func: t.Callable[..., t.Coroutine]) -> t.Callable[..., t.Any]: """Return a sync function that will run the coroutine function *func*.""" try: from asgiref.sync import async_to_sync @@ -752,7 +768,7 @@ def run_async(func): ) @wraps(func) - def outer(*args, **kwargs): + def outer(*args: t.Any, **kwargs: t.Any) -> t.Any: """This function grabs the current context for the inner function. This is similar to the copy_current_xxx_context functions in the @@ -764,7 +780,7 @@ def outer(*args, **kwargs): ctx = _request_ctx_stack.top.copy() @wraps(func) - async def inner(*a, **k): + async def inner(*a: t.Any, **k: t.Any) -> t.Any: """This restores the context before awaiting the func. This is required as the function must be awaited within the @@ -780,5 +796,5 @@ async def inner(*a, **k): return async_to_sync(inner)(*args, **kwargs) - outer._flask_sync_wrapper = True + outer._flask_sync_wrapper = True # type: ignore return outer diff --git a/src/flask/json/__init__.py b/src/flask/json/__init__.py index 7ca0db90b0..5a6e4942fe 100644 --- a/src/flask/json/__init__.py +++ b/src/flask/json/__init__.py @@ -1,20 +1,25 @@ import io import json as _json +import typing as t import uuid import warnings from datetime import date -from jinja2.utils import htmlsafe_json_dumps as _jinja_htmlsafe_dumps +from jinja2.utils import htmlsafe_json_dumps as _jinja_htmlsafe_dumps # type: ignore from werkzeug.http import http_date from ..globals import current_app from ..globals import request +if t.TYPE_CHECKING: + from ..app import Flask + from ..wrappers import Response + try: import dataclasses except ImportError: # Python < 3.7 - dataclasses = None + dataclasses = None # type: ignore class JSONEncoder(_json.JSONEncoder): @@ -34,7 +39,7 @@ class JSONEncoder(_json.JSONEncoder): :attr:`flask.Blueprint.json_encoder` to override the default. """ - def default(self, o): + def default(self, o: t.Any) -> t.Any: """Convert ``o`` to a JSON serializable type. See :meth:`json.JSONEncoder.default`. Python does not support overriding how basic types like ``str`` or ``list`` are @@ -48,7 +53,7 @@ def default(self, o): return dataclasses.asdict(o) if hasattr(o, "__html__"): return str(o.__html__()) - return super().default(self, o) + return super().default(o) class JSONDecoder(_json.JSONDecoder): @@ -62,14 +67,19 @@ class JSONDecoder(_json.JSONDecoder): """ -def _dump_arg_defaults(kwargs, app=None): +def _dump_arg_defaults( + kwargs: t.Dict[str, t.Any], app: t.Optional["Flask"] = None +) -> None: """Inject default arguments for dump functions.""" if app is None: app = current_app if app: - bp = app.blueprints.get(request.blueprint) if request else None - cls = bp.json_encoder if bp and bp.json_encoder else app.json_encoder + cls = app.json_encoder + bp = app.blueprints.get(request.blueprint) if request else None # type: ignore + if bp is not None and bp.json_encoder is not None: + cls = bp.json_encoder + kwargs.setdefault("cls", cls) kwargs.setdefault("ensure_ascii", app.config["JSON_AS_ASCII"]) kwargs.setdefault("sort_keys", app.config["JSON_SORT_KEYS"]) @@ -78,20 +88,25 @@ def _dump_arg_defaults(kwargs, app=None): kwargs.setdefault("cls", JSONEncoder) -def _load_arg_defaults(kwargs, app=None): +def _load_arg_defaults( + kwargs: t.Dict[str, t.Any], app: t.Optional["Flask"] = None +) -> None: """Inject default arguments for load functions.""" if app is None: app = current_app if app: - bp = app.blueprints.get(request.blueprint) if request else None - cls = bp.json_decoder if bp and bp.json_decoder else app.json_decoder + cls = app.json_decoder + bp = app.blueprints.get(request.blueprint) if request else None # type: ignore + if bp is not None and bp.json_decoder is not None: + cls = bp.json_decoder + kwargs.setdefault("cls", cls) else: kwargs.setdefault("cls", JSONDecoder) -def dumps(obj, app=None, **kwargs): +def dumps(obj: t.Any, app: t.Optional["Flask"] = None, **kwargs: t.Any) -> str: """Serialize an object to a string of JSON. Takes the same arguments as the built-in :func:`json.dumps`, with @@ -121,12 +136,14 @@ def dumps(obj, app=None, **kwargs): ) if isinstance(rv, str): - return rv.encode(encoding) + return rv.encode(encoding) # type: ignore return rv -def dump(obj, fp, app=None, **kwargs): +def dump( + obj: t.Any, fp: t.IO[str], app: t.Optional["Flask"] = None, **kwargs: t.Any +) -> None: """Serialize an object to JSON written to a file object. Takes the same arguments as the built-in :func:`json.dump`, with @@ -150,7 +167,7 @@ def dump(obj, fp, app=None, **kwargs): fp.write("") except TypeError: show_warning = True - fp = io.TextIOWrapper(fp, encoding or "utf-8") + fp = io.TextIOWrapper(fp, encoding or "utf-8") # type: ignore if show_warning: warnings.warn( @@ -163,7 +180,7 @@ def dump(obj, fp, app=None, **kwargs): _json.dump(obj, fp, **kwargs) -def loads(s, app=None, **kwargs): +def loads(s: str, app: t.Optional["Flask"] = None, **kwargs: t.Any) -> t.Any: """Deserialize an object from a string of JSON. Takes the same arguments as the built-in :func:`json.loads`, with @@ -199,7 +216,7 @@ def loads(s, app=None, **kwargs): return _json.loads(s, **kwargs) -def load(fp, app=None, **kwargs): +def load(fp: t.IO[str], app: t.Optional["Flask"] = None, **kwargs: t.Any) -> t.Any: """Deserialize an object from JSON read from a file object. Takes the same arguments as the built-in :func:`json.load`, with @@ -227,12 +244,12 @@ def load(fp, app=None, **kwargs): ) if isinstance(fp.read(0), bytes): - fp = io.TextIOWrapper(fp, encoding) + fp = io.TextIOWrapper(fp, encoding) # type: ignore return _json.load(fp, **kwargs) -def htmlsafe_dumps(obj, **kwargs): +def htmlsafe_dumps(obj: t.Any, **kwargs: t.Any) -> str: """Serialize an object to a string of JSON with :func:`dumps`, then replace HTML-unsafe characters with Unicode escapes and mark the result safe with :class:`~markupsafe.Markup`. @@ -256,7 +273,7 @@ def htmlsafe_dumps(obj, **kwargs): return _jinja_htmlsafe_dumps(obj, dumps=dumps, **kwargs) -def htmlsafe_dump(obj, fp, **kwargs): +def htmlsafe_dump(obj: t.Any, fp: t.IO[str], **kwargs: t.Any) -> None: """Serialize an object to JSON written to a file object, replacing HTML-unsafe characters with Unicode escapes. See :func:`htmlsafe_dumps` and :func:`dumps`. @@ -264,7 +281,7 @@ def htmlsafe_dump(obj, fp, **kwargs): fp.write(htmlsafe_dumps(obj, **kwargs)) -def jsonify(*args, **kwargs): +def jsonify(*args: t.Any, **kwargs: t.Any) -> "Response": """Serialize data to JSON and wrap it in a :class:`~flask.Response` with the :mimetype:`application/json` mimetype. diff --git a/src/flask/json/tag.py b/src/flask/json/tag.py index d3c29adbf9..97f365a9b0 100644 --- a/src/flask/json/tag.py +++ b/src/flask/json/tag.py @@ -40,6 +40,7 @@ def to_python(self, value): app.session_interface.serializer.register(TagOrderedDict, index=0) """ +import typing as t from base64 import b64decode from base64 import b64encode from datetime import datetime @@ -60,27 +61,27 @@ class JSONTag: #: The tag to mark the serialized object with. If ``None``, this tag is #: only used as an intermediate step during tagging. - key = None + key: t.Optional[str] = None - def __init__(self, serializer): + def __init__(self, serializer: "TaggedJSONSerializer") -> None: """Create a tagger for the given serializer.""" self.serializer = serializer - def check(self, value): + def check(self, value: t.Any) -> bool: """Check if the given value should be tagged by this tag.""" raise NotImplementedError - def to_json(self, value): + def to_json(self, value: t.Any) -> t.Any: """Convert the Python object to an object that is a valid JSON type. The tag will be added later.""" raise NotImplementedError - def to_python(self, value): + def to_python(self, value: t.Any) -> t.Any: """Convert the JSON representation back to the correct type. The tag will already be removed.""" raise NotImplementedError - def tag(self, value): + def tag(self, value: t.Any) -> t.Any: """Convert the value to a valid JSON type and add the tag structure around it.""" return {self.key: self.to_json(value)} @@ -96,18 +97,18 @@ class TagDict(JSONTag): __slots__ = () key = " di" - def check(self, value): + def check(self, value: t.Any) -> bool: return ( isinstance(value, dict) and len(value) == 1 and next(iter(value)) in self.serializer.tags ) - def to_json(self, value): + def to_json(self, value: t.Any) -> t.Any: key = next(iter(value)) return {f"{key}__": self.serializer.tag(value[key])} - def to_python(self, value): + def to_python(self, value: t.Any) -> t.Any: key = next(iter(value)) return {key[:-2]: value[key]} @@ -115,10 +116,10 @@ def to_python(self, value): class PassDict(JSONTag): __slots__ = () - def check(self, value): + def check(self, value: t.Any) -> bool: return isinstance(value, dict) - def to_json(self, value): + def to_json(self, value: t.Any) -> t.Any: # JSON objects may only have string keys, so don't bother tagging the # key here. return {k: self.serializer.tag(v) for k, v in value.items()} @@ -130,23 +131,23 @@ class TagTuple(JSONTag): __slots__ = () key = " t" - def check(self, value): + def check(self, value: t.Any) -> bool: return isinstance(value, tuple) - def to_json(self, value): + def to_json(self, value: t.Any) -> t.Any: return [self.serializer.tag(item) for item in value] - def to_python(self, value): + def to_python(self, value: t.Any) -> t.Any: return tuple(value) class PassList(JSONTag): __slots__ = () - def check(self, value): + def check(self, value: t.Any) -> bool: return isinstance(value, list) - def to_json(self, value): + def to_json(self, value: t.Any) -> t.Any: return [self.serializer.tag(item) for item in value] tag = to_json @@ -156,13 +157,13 @@ class TagBytes(JSONTag): __slots__ = () key = " b" - def check(self, value): + def check(self, value: t.Any) -> bool: return isinstance(value, bytes) - def to_json(self, value): + def to_json(self, value: t.Any) -> t.Any: return b64encode(value).decode("ascii") - def to_python(self, value): + def to_python(self, value: t.Any) -> t.Any: return b64decode(value) @@ -174,13 +175,13 @@ class TagMarkup(JSONTag): __slots__ = () key = " m" - def check(self, value): + def check(self, value: t.Any) -> bool: return callable(getattr(value, "__html__", None)) - def to_json(self, value): + def to_json(self, value: t.Any) -> t.Any: return str(value.__html__()) - def to_python(self, value): + def to_python(self, value: t.Any) -> t.Any: return Markup(value) @@ -188,13 +189,13 @@ class TagUUID(JSONTag): __slots__ = () key = " u" - def check(self, value): + def check(self, value: t.Any) -> bool: return isinstance(value, UUID) - def to_json(self, value): + def to_json(self, value: t.Any) -> t.Any: return value.hex - def to_python(self, value): + def to_python(self, value: t.Any) -> t.Any: return UUID(value) @@ -202,13 +203,13 @@ class TagDateTime(JSONTag): __slots__ = () key = " d" - def check(self, value): + def check(self, value: t.Any) -> bool: return isinstance(value, datetime) - def to_json(self, value): + def to_json(self, value: t.Any) -> t.Any: return http_date(value) - def to_python(self, value): + def to_python(self, value: t.Any) -> t.Any: return parse_date(value) @@ -242,14 +243,19 @@ class TaggedJSONSerializer: TagDateTime, ] - def __init__(self): - self.tags = {} - self.order = [] + def __init__(self) -> None: + self.tags: t.Dict[str, JSONTag] = {} + self.order: t.List[JSONTag] = [] for cls in self.default_tags: self.register(cls) - def register(self, tag_class, force=False, index=None): + def register( + self, + tag_class: t.Type[JSONTag], + force: bool = False, + index: t.Optional[int] = None, + ) -> None: """Register a new tag with this serializer. :param tag_class: tag class to register. Will be instantiated with this @@ -277,7 +283,7 @@ def register(self, tag_class, force=False, index=None): else: self.order.insert(index, tag) - def tag(self, value): + def tag(self, value: t.Any) -> t.Dict[str, t.Any]: """Convert a value to a tagged representation if necessary.""" for tag in self.order: if tag.check(value): @@ -285,7 +291,7 @@ def tag(self, value): return value - def untag(self, value): + def untag(self, value: t.Dict[str, t.Any]) -> t.Any: """Convert a tagged representation back to the original type.""" if len(value) != 1: return value @@ -297,10 +303,10 @@ def untag(self, value): return self.tags[key].to_python(value[key]) - def dumps(self, value): + def dumps(self, value: t.Any) -> str: """Tag the value and dump it to a compact JSON string.""" return dumps(self.tag(value), separators=(",", ":")) - def loads(self, value): + def loads(self, value: str) -> t.Any: """Load data from a JSON string and deserialized any tagged objects.""" return loads(value, object_hook=self.untag) diff --git a/src/flask/logging.py b/src/flask/logging.py index fe6809b226..48a5b7ff4c 100644 --- a/src/flask/logging.py +++ b/src/flask/logging.py @@ -1,13 +1,17 @@ import logging import sys +import typing as t from werkzeug.local import LocalProxy from .globals import request +if t.TYPE_CHECKING: + from .app import Flask + @LocalProxy -def wsgi_errors_stream(): +def wsgi_errors_stream() -> t.TextIO: """Find the most appropriate error stream for the application. If a request is active, log to ``wsgi.errors``, otherwise use ``sys.stderr``. @@ -19,7 +23,7 @@ def wsgi_errors_stream(): return request.environ["wsgi.errors"] if request else sys.stderr -def has_level_handler(logger): +def has_level_handler(logger: logging.Logger) -> bool: """Check if there is a handler in the logging chain that will handle the given logger's :meth:`effective level <~logging.Logger.getEffectiveLevel>`. """ @@ -33,20 +37,20 @@ def has_level_handler(logger): if not current.propagate: break - current = current.parent + current = current.parent # type: ignore return False #: Log messages to :func:`~flask.logging.wsgi_errors_stream` with the format #: ``[%(asctime)s] %(levelname)s in %(module)s: %(message)s``. -default_handler = logging.StreamHandler(wsgi_errors_stream) +default_handler = logging.StreamHandler(wsgi_errors_stream) # type: ignore default_handler.setFormatter( logging.Formatter("[%(asctime)s] %(levelname)s in %(module)s: %(message)s") ) -def create_logger(app): +def create_logger(app: "Flask") -> logging.Logger: """Get the Flask app's logger and configure it if needed. The logger name will be the same as diff --git a/src/flask/scaffold.py b/src/flask/scaffold.py index 44745b7dbe..445ac50007 100644 --- a/src/flask/scaffold.py +++ b/src/flask/scaffold.py @@ -2,8 +2,11 @@ import os import pkgutil import sys +import typing as t from collections import defaultdict from functools import update_wrapper +from json import JSONDecoder +from json import JSONEncoder from jinja2 import FileSystemLoader from werkzeug.exceptions import default_exceptions @@ -14,17 +17,28 @@ from .helpers import locked_cached_property from .helpers import send_from_directory from .templating import _default_template_ctx_processor +from .typing import AfterRequestCallable +from .typing import AppOrBlueprintKey +from .typing import BeforeRequestCallable +from .typing import ErrorHandlerCallable +from .typing import TeardownCallable +from .typing import TemplateContextProcessorCallable +from .typing import URLDefaultCallable +from .typing import URLValuePreprocessorCallable + +if t.TYPE_CHECKING: + from .wrappers import Response # a singleton sentinel value for parameter defaults _sentinel = object() -def setupmethod(f): +def setupmethod(f: t.Callable) -> t.Callable: """Wraps a method so that it performs a check in debug mode if the first request was already handled. """ - def wrapper_func(self, *args, **kwargs): + def wrapper_func(self, *args: t.Any, **kwargs: t.Any) -> t.Any: if self._is_setup_finished(): raise AssertionError( "A setup function was called after the first request " @@ -60,24 +74,24 @@ class Scaffold: """ name: str - _static_folder = None - _static_url_path = None + _static_folder: t.Optional[str] = None + _static_url_path: t.Optional[str] = None #: JSON encoder class used by :func:`flask.json.dumps`. If a #: blueprint sets this, it will be used instead of the app's value. - json_encoder = None + json_encoder: t.Optional[t.Type[JSONEncoder]] = None #: JSON decoder class used by :func:`flask.json.loads`. If a #: blueprint sets this, it will be used instead of the app's value. - json_decoder = None + json_decoder: t.Optional[t.Type[JSONDecoder]] = None def __init__( self, - import_name, - static_folder=None, - static_url_path=None, - template_folder=None, - root_path=None, + import_name: str, + static_folder: t.Optional[str] = None, + static_url_path: t.Optional[str] = None, + template_folder: t.Optional[str] = None, + root_path: t.Optional[str] = None, ): #: The name of the package or module that this object belongs #: to. Do not change this once it is set by the constructor. @@ -110,7 +124,7 @@ def __init__( #: #: This data structure is internal. It should not be modified #: directly and its format may change at any time. - self.view_functions = {} + self.view_functions: t.Dict[str, t.Callable] = {} #: A data structure of registered error handlers, in the format #: ``{scope: {code: {class: handler}}}```. The ``scope`` key is @@ -125,7 +139,10 @@ def __init__( #: #: This data structure is internal. It should not be modified #: directly and its format may change at any time. - self.error_handler_spec = defaultdict(lambda: defaultdict(dict)) + self.error_handler_spec: t.Dict[ + AppOrBlueprintKey, + t.Dict[t.Optional[int], t.Dict[t.Type[Exception], ErrorHandlerCallable]], + ] = defaultdict(lambda: defaultdict(dict)) #: A data structure of functions to call at the beginning of #: each request, in the format ``{scope: [functions]}``. The @@ -137,7 +154,9 @@ def __init__( #: #: This data structure is internal. It should not be modified #: directly and its format may change at any time. - self.before_request_funcs = defaultdict(list) + self.before_request_funcs: t.Dict[ + AppOrBlueprintKey, t.List[BeforeRequestCallable] + ] = defaultdict(list) #: A data structure of functions to call at the end of each #: request, in the format ``{scope: [functions]}``. The @@ -149,7 +168,9 @@ def __init__( #: #: This data structure is internal. It should not be modified #: directly and its format may change at any time. - self.after_request_funcs = defaultdict(list) + self.after_request_funcs: t.Dict[ + AppOrBlueprintKey, t.List[AfterRequestCallable] + ] = defaultdict(list) #: A data structure of functions to call at the end of each #: request even if an exception is raised, in the format @@ -162,7 +183,9 @@ def __init__( #: #: This data structure is internal. It should not be modified #: directly and its format may change at any time. - self.teardown_request_funcs = defaultdict(list) + self.teardown_request_funcs: t.Dict[ + AppOrBlueprintKey, t.List[TeardownCallable] + ] = defaultdict(list) #: A data structure of functions to call to pass extra context #: values when rendering templates, in the format @@ -175,9 +198,9 @@ def __init__( #: #: This data structure is internal. It should not be modified #: directly and its format may change at any time. - self.template_context_processors = defaultdict( - list, {None: [_default_template_ctx_processor]} - ) + self.template_context_processors: t.Dict[ + AppOrBlueprintKey, t.List[TemplateContextProcessorCallable] + ] = defaultdict(list, {None: [_default_template_ctx_processor]}) #: A data structure of functions to call to modify the keyword #: arguments passed to the view function, in the format @@ -190,7 +213,10 @@ def __init__( #: #: This data structure is internal. It should not be modified #: directly and its format may change at any time. - self.url_value_preprocessors = defaultdict(list) + self.url_value_preprocessors: t.Dict[ + AppOrBlueprintKey, + t.List[URLValuePreprocessorCallable], + ] = defaultdict(list) #: A data structure of functions to call to modify the keyword #: arguments when generating URLs, in the format @@ -203,31 +229,35 @@ def __init__( #: #: This data structure is internal. It should not be modified #: directly and its format may change at any time. - self.url_default_functions = defaultdict(list) + self.url_default_functions: t.Dict[ + AppOrBlueprintKey, t.List[URLDefaultCallable] + ] = defaultdict(list) - def __repr__(self): + def __repr__(self) -> str: return f"<{type(self).__name__} {self.name!r}>" - def _is_setup_finished(self): + def _is_setup_finished(self) -> bool: raise NotImplementedError @property - def static_folder(self): + def static_folder(self) -> t.Optional[str]: """The absolute path to the configured static folder. ``None`` if no static folder is set. """ if self._static_folder is not None: return os.path.join(self.root_path, self._static_folder) + else: + return None @static_folder.setter - def static_folder(self, value): + def static_folder(self, value: t.Optional[str]) -> None: if value is not None: value = os.fspath(value).rstrip(r"\/") self._static_folder = value @property - def has_static_folder(self): + def has_static_folder(self) -> bool: """``True`` if :attr:`static_folder` is set. .. versionadded:: 0.5 @@ -235,7 +265,7 @@ def has_static_folder(self): return self.static_folder is not None @property - def static_url_path(self): + def static_url_path(self) -> t.Optional[str]: """The URL prefix that the static route will be accessible from. If it was not configured during init, it is derived from @@ -248,14 +278,16 @@ def static_url_path(self): basename = os.path.basename(self.static_folder) return f"/{basename}".rstrip("/") + return None + @static_url_path.setter - def static_url_path(self, value): + def static_url_path(self, value: t.Optional[str]) -> None: if value is not None: value = value.rstrip("/") self._static_url_path = value - def get_send_file_max_age(self, filename): + def get_send_file_max_age(self, filename: str) -> t.Optional[int]: """Used by :func:`send_file` to determine the ``max_age`` cache value for a given file path if it wasn't passed. @@ -276,7 +308,7 @@ def get_send_file_max_age(self, filename): return int(value.total_seconds()) - def send_static_file(self, filename): + def send_static_file(self, filename: str) -> "Response": """The view function used to serve files from :attr:`static_folder`. A route is automatically registered for this view at :attr:`static_url_path` if :attr:`static_folder` is @@ -290,10 +322,12 @@ def send_static_file(self, filename): # send_file only knows to call get_send_file_max_age on the app, # call it here so it works for blueprints too. max_age = self.get_send_file_max_age(filename) - return send_from_directory(self.static_folder, filename, max_age=max_age) + return send_from_directory( + t.cast(str, self.static_folder), filename, max_age=max_age + ) @locked_cached_property - def jinja_loader(self): + def jinja_loader(self) -> t.Optional[FileSystemLoader]: """The Jinja loader for this object's templates. By default this is a class :class:`jinja2.loaders.FileSystemLoader` to :attr:`template_folder` if it is set. @@ -302,8 +336,10 @@ def jinja_loader(self): """ if self.template_folder is not None: return FileSystemLoader(os.path.join(self.root_path, self.template_folder)) + else: + return None - def open_resource(self, resource, mode="rb"): + def open_resource(self, resource: str, mode: str = "rb") -> t.IO[t.AnyStr]: """Open a resource file relative to :attr:`root_path` for reading. @@ -326,48 +362,48 @@ def open_resource(self, resource, mode="rb"): return open(os.path.join(self.root_path, resource), mode) - def _method_route(self, method, rule, options): + def _method_route(self, method: str, rule: str, options: dict) -> t.Callable: if "methods" in options: raise TypeError("Use the 'route' decorator to use the 'methods' argument.") return self.route(rule, methods=[method], **options) - def get(self, rule, **options): + def get(self, rule: str, **options: t.Any) -> t.Callable: """Shortcut for :meth:`route` with ``methods=["GET"]``. .. versionadded:: 2.0 """ return self._method_route("GET", rule, options) - def post(self, rule, **options): + def post(self, rule: str, **options: t.Any) -> t.Callable: """Shortcut for :meth:`route` with ``methods=["POST"]``. .. versionadded:: 2.0 """ return self._method_route("POST", rule, options) - def put(self, rule, **options): + def put(self, rule: str, **options: t.Any) -> t.Callable: """Shortcut for :meth:`route` with ``methods=["PUT"]``. .. versionadded:: 2.0 """ return self._method_route("PUT", rule, options) - def delete(self, rule, **options): + def delete(self, rule: str, **options: t.Any) -> t.Callable: """Shortcut for :meth:`route` with ``methods=["DELETE"]``. .. versionadded:: 2.0 """ return self._method_route("DELETE", rule, options) - def patch(self, rule, **options): + def patch(self, rule: str, **options: t.Any) -> t.Callable: """Shortcut for :meth:`route` with ``methods=["PATCH"]``. .. versionadded:: 2.0 """ return self._method_route("PATCH", rule, options) - def route(self, rule, **options): + def route(self, rule: str, **options: t.Any) -> t.Callable: """Decorate a view function to register it with the given URL rule and options. Calls :meth:`add_url_rule`, which has more details about the implementation. @@ -391,7 +427,7 @@ def index(): :class:`~werkzeug.routing.Rule` object. """ - def decorator(f): + def decorator(f: t.Callable) -> t.Callable: endpoint = options.pop("endpoint", None) self.add_url_rule(rule, endpoint, f, **options) return f @@ -401,12 +437,12 @@ def decorator(f): @setupmethod def add_url_rule( self, - rule, - endpoint=None, - view_func=None, - provide_automatic_options=None, - **options, - ): + rule: str, + endpoint: t.Optional[str] = None, + view_func: t.Optional[t.Callable] = None, + provide_automatic_options: t.Optional[bool] = None, + **options: t.Any, + ) -> t.Callable: """Register a rule for routing incoming requests and building URLs. The :meth:`route` decorator is a shortcut to call this with the ``view_func`` argument. These are equivalent: @@ -466,7 +502,7 @@ def index(): """ raise NotImplementedError - def endpoint(self, endpoint): + def endpoint(self, endpoint: str) -> t.Callable: """Decorate a view function to register it for the given endpoint. Used if a rule is added without a ``view_func`` with :meth:`add_url_rule`. @@ -490,7 +526,7 @@ def decorator(f): return decorator @setupmethod - def before_request(self, f): + def before_request(self, f: BeforeRequestCallable) -> BeforeRequestCallable: """Register a function to run before each request. For example, this can be used to open a database connection, or @@ -512,7 +548,7 @@ def load_user(): return f @setupmethod - def after_request(self, f): + def after_request(self, f: AfterRequestCallable) -> AfterRequestCallable: """Register a function to run after each request to this object. The function is called with the response object, and must return @@ -528,7 +564,7 @@ def after_request(self, f): return f @setupmethod - def teardown_request(self, f): + def teardown_request(self, f: TeardownCallable) -> TeardownCallable: """Register a function to be run at the end of each request, regardless of whether there was an exception or not. These functions are executed when the request context is popped, even if not an @@ -567,13 +603,17 @@ def teardown_request(self, f): return f @setupmethod - def context_processor(self, f): + def context_processor( + self, f: TemplateContextProcessorCallable + ) -> TemplateContextProcessorCallable: """Registers a template context processor function.""" self.template_context_processors[None].append(f) return f @setupmethod - def url_value_preprocessor(self, f): + def url_value_preprocessor( + self, f: URLValuePreprocessorCallable + ) -> URLValuePreprocessorCallable: """Register a URL value preprocessor function for all view functions in the application. These functions will be called before the :meth:`before_request` functions. @@ -590,7 +630,7 @@ def url_value_preprocessor(self, f): return f @setupmethod - def url_defaults(self, f): + def url_defaults(self, f: URLDefaultCallable) -> URLDefaultCallable: """Callback function for URL defaults for all view functions of the application. It's called with the endpoint and values and should update the values passed in place. @@ -599,7 +639,9 @@ def url_defaults(self, f): return f @setupmethod - def errorhandler(self, code_or_exception): + def errorhandler( + self, code_or_exception: t.Union[t.Type[Exception], int] + ) -> t.Callable: """Register a function to handle errors by code or exception class. A decorator that is used to register a function given an @@ -629,14 +671,18 @@ def special_exception_handler(error): an arbitrary exception """ - def decorator(f): + def decorator(f: ErrorHandlerCallable) -> ErrorHandlerCallable: self.register_error_handler(code_or_exception, f) return f return decorator @setupmethod - def register_error_handler(self, code_or_exception, f): + def register_error_handler( + self, + code_or_exception: t.Union[t.Type[Exception], int], + f: ErrorHandlerCallable, + ) -> None: """Alternative error attach function to the :meth:`errorhandler` decorator that is more straightforward to use for non decorator usage. @@ -662,7 +708,9 @@ def register_error_handler(self, code_or_exception, f): self.error_handler_spec[None][code][exc_class] = self.ensure_sync(f) @staticmethod - def _get_exc_class_and_code(exc_class_or_code): + def _get_exc_class_and_code( + exc_class_or_code: t.Union[t.Type[Exception], int] + ) -> t.Tuple[t.Type[Exception], t.Optional[int]]: """Get the exception class being handled. For HTTP status codes or ``HTTPException`` subclasses, return both the exception and status code. @@ -670,6 +718,7 @@ def _get_exc_class_and_code(exc_class_or_code): :param exc_class_or_code: Any exception class, or an HTTP status code as an integer. """ + exc_class: t.Type[Exception] if isinstance(exc_class_or_code, int): exc_class = default_exceptions[exc_class_or_code] else: @@ -684,11 +733,11 @@ def _get_exc_class_and_code(exc_class_or_code): else: return exc_class, None - def ensure_sync(self, func): + def ensure_sync(self, func: t.Callable) -> t.Callable: raise NotImplementedError() -def _endpoint_from_view_func(view_func): +def _endpoint_from_view_func(view_func: t.Callable) -> str: """Internal helper that returns the default endpoint for a given function. This always is the function name. """ @@ -696,7 +745,7 @@ def _endpoint_from_view_func(view_func): return view_func.__name__ -def get_root_path(import_name): +def get_root_path(import_name: str) -> str: """Find the root path of a package, or the path that contains a module. If it cannot be found, returns the current working directory. @@ -721,7 +770,7 @@ def get_root_path(import_name): return os.getcwd() if hasattr(loader, "get_filename"): - filepath = loader.get_filename(import_name) + filepath = loader.get_filename(import_name) # type: ignore else: # Fall back to imports. __import__(import_name) @@ -822,7 +871,7 @@ def _find_package_path(root_mod_name): return package_path -def find_package(import_name): +def find_package(import_name: str): """Find the prefix that a package is installed under, and the path that it would be imported from. diff --git a/src/flask/sessions.py b/src/flask/sessions.py index 795a922c6c..0e68e884f1 100644 --- a/src/flask/sessions.py +++ b/src/flask/sessions.py @@ -1,4 +1,5 @@ import hashlib +import typing as t import warnings from collections.abc import MutableMapping from datetime import datetime @@ -10,17 +11,21 @@ from .helpers import is_ip from .json.tag import TaggedJSONSerializer +if t.TYPE_CHECKING: + from .app import Flask + from .wrappers import Request, Response + class SessionMixin(MutableMapping): """Expands a basic dictionary with session attributes.""" @property - def permanent(self): + def permanent(self) -> bool: """This reflects the ``'_permanent'`` key in the dict.""" return self.get("_permanent", False) @permanent.setter - def permanent(self, value): + def permanent(self, value: bool) -> None: self["_permanent"] = bool(value) #: Some implementations can detect whether a session is newly @@ -61,22 +66,22 @@ class SecureCookieSession(CallbackDict, SessionMixin): #: different users. accessed = False - def __init__(self, initial=None): - def on_update(self): + def __init__(self, initial: t.Any = None) -> None: + def on_update(self) -> None: self.modified = True self.accessed = True super().__init__(initial, on_update) - def __getitem__(self, key): + def __getitem__(self, key: str) -> t.Any: self.accessed = True return super().__getitem__(key) - def get(self, key, default=None): + def get(self, key: str, default: t.Any = None) -> t.Any: self.accessed = True return super().get(key, default) - def setdefault(self, key, default=None): + def setdefault(self, key: str, default: t.Any = None) -> t.Any: self.accessed = True return super().setdefault(key, default) @@ -87,14 +92,14 @@ class NullSession(SecureCookieSession): but fail on setting. """ - def _fail(self, *args, **kwargs): + def _fail(self, *args: t.Any, **kwargs: t.Any) -> t.NoReturn: raise RuntimeError( "The session is unavailable because no secret " "key was set. Set the secret_key on the " "application to something unique and secret." ) - __setitem__ = __delitem__ = clear = pop = popitem = update = setdefault = _fail + __setitem__ = __delitem__ = clear = pop = popitem = update = setdefault = _fail # type: ignore # noqa: B950 del _fail @@ -141,7 +146,7 @@ class Session(dict, SessionMixin): #: .. versionadded:: 0.10 pickle_based = False - def make_null_session(self, app): + def make_null_session(self, app: "Flask") -> NullSession: """Creates a null session which acts as a replacement object if the real session support could not be loaded due to a configuration error. This mainly aids the user experience because the job of the @@ -153,7 +158,7 @@ def make_null_session(self, app): """ return self.null_session_class() - def is_null_session(self, obj): + def is_null_session(self, obj: object) -> bool: """Checks if a given object is a null session. Null sessions are not asked to be saved. @@ -162,14 +167,14 @@ def is_null_session(self, obj): """ return isinstance(obj, self.null_session_class) - def get_cookie_name(self, app): + def get_cookie_name(self, app: "Flask") -> str: """Returns the name of the session cookie. Uses ``app.session_cookie_name`` which is set to ``SESSION_COOKIE_NAME`` """ return app.session_cookie_name - def get_cookie_domain(self, app): + def get_cookie_domain(self, app: "Flask") -> t.Optional[str]: """Returns the domain that should be set for the session cookie. Uses ``SESSION_COOKIE_DOMAIN`` if it is configured, otherwise @@ -227,7 +232,7 @@ def get_cookie_domain(self, app): app.config["SESSION_COOKIE_DOMAIN"] = rv return rv - def get_cookie_path(self, app): + def get_cookie_path(self, app: "Flask") -> str: """Returns the path for which the cookie should be valid. The default implementation uses the value from the ``SESSION_COOKIE_PATH`` config var if it's set, and falls back to ``APPLICATION_ROOT`` or @@ -235,27 +240,29 @@ def get_cookie_path(self, app): """ return app.config["SESSION_COOKIE_PATH"] or app.config["APPLICATION_ROOT"] - def get_cookie_httponly(self, app): + def get_cookie_httponly(self, app: "Flask") -> bool: """Returns True if the session cookie should be httponly. This currently just returns the value of the ``SESSION_COOKIE_HTTPONLY`` config var. """ return app.config["SESSION_COOKIE_HTTPONLY"] - def get_cookie_secure(self, app): + def get_cookie_secure(self, app: "Flask") -> bool: """Returns True if the cookie should be secure. This currently just returns the value of the ``SESSION_COOKIE_SECURE`` setting. """ return app.config["SESSION_COOKIE_SECURE"] - def get_cookie_samesite(self, app): + def get_cookie_samesite(self, app: "Flask") -> str: """Return ``'Strict'`` or ``'Lax'`` if the cookie should use the ``SameSite`` attribute. This currently just returns the value of the :data:`SESSION_COOKIE_SAMESITE` setting. """ return app.config["SESSION_COOKIE_SAMESITE"] - def get_expiration_time(self, app, session): + def get_expiration_time( + self, app: "Flask", session: SessionMixin + ) -> t.Optional[datetime]: """A helper method that returns an expiration date for the session or ``None`` if the session is linked to the browser session. The default implementation returns now + the permanent session @@ -263,8 +270,9 @@ def get_expiration_time(self, app, session): """ if session.permanent: return datetime.utcnow() + app.permanent_session_lifetime + return None - def should_set_cookie(self, app, session): + def should_set_cookie(self, app: "Flask", session: SessionMixin) -> bool: """Used by session backends to determine if a ``Set-Cookie`` header should be set for this session cookie for this response. If the session has been modified, the cookie is set. If the session is permanent and @@ -280,7 +288,9 @@ def should_set_cookie(self, app, session): session.permanent and app.config["SESSION_REFRESH_EACH_REQUEST"] ) - def open_session(self, app, request): + def open_session( + self, app: "Flask", request: "Request" + ) -> t.Optional[SessionMixin]: """This method has to be implemented and must either return ``None`` in case the loading failed because of a configuration error or an instance of a session object which implements a dictionary like @@ -288,7 +298,9 @@ def open_session(self, app, request): """ raise NotImplementedError() - def save_session(self, app, session, response): + def save_session( + self, app: "Flask", session: SessionMixin, response: "Response" + ) -> None: """This is called for actual sessions returned by :meth:`open_session` at the end of the request. This is still called during a request context so if you absolutely need access to the request you can do @@ -319,7 +331,9 @@ class SecureCookieSessionInterface(SessionInterface): serializer = session_json_serializer session_class = SecureCookieSession - def get_signing_serializer(self, app): + def get_signing_serializer( + self, app: "Flask" + ) -> t.Optional[URLSafeTimedSerializer]: if not app.secret_key: return None signer_kwargs = dict( @@ -332,7 +346,9 @@ def get_signing_serializer(self, app): signer_kwargs=signer_kwargs, ) - def open_session(self, app, request): + def open_session( + self, app: "Flask", request: "Request" + ) -> t.Optional[SecureCookieSession]: s = self.get_signing_serializer(app) if s is None: return None @@ -346,7 +362,9 @@ def open_session(self, app, request): except BadSignature: return self.session_class() - def save_session(self, app, session, response): + def save_session( + self, app: "Flask", session: SessionMixin, response: "Response" + ) -> None: name = self.get_cookie_name(app) domain = self.get_cookie_domain(app) path = self.get_cookie_path(app) @@ -372,10 +390,10 @@ def save_session(self, app, session, response): httponly = self.get_cookie_httponly(app) expires = self.get_expiration_time(app, session) - val = self.get_signing_serializer(app).dumps(dict(session)) + val = self.get_signing_serializer(app).dumps(dict(session)) # type: ignore response.set_cookie( name, - val, + val, # type: ignore expires=expires, httponly=httponly, domain=domain, diff --git a/src/flask/signals.py b/src/flask/signals.py index d2179c65da..63667bdb76 100644 --- a/src/flask/signals.py +++ b/src/flask/signals.py @@ -1,3 +1,5 @@ +import typing as t + try: from blinker import Namespace @@ -5,8 +7,8 @@ except ImportError: signals_available = False - class Namespace: - def signal(self, name, doc=None): + class Namespace: # type: ignore + def signal(self, name: str, doc: t.Optional[str] = None) -> "_FakeSignal": return _FakeSignal(name, doc) class _FakeSignal: @@ -16,14 +18,14 @@ class _FakeSignal: will just ignore the arguments and do nothing instead. """ - def __init__(self, name, doc=None): + def __init__(self, name: str, doc: t.Optional[str] = None) -> None: self.name = name self.__doc__ = doc - def send(self, *args, **kwargs): + def send(self, *args: t.Any, **kwargs: t.Any) -> t.Any: pass - def _fail(self, *args, **kwargs): + def _fail(self, *args: t.Any, **kwargs: t.Any) -> t.Any: raise RuntimeError( "Signalling support is unavailable because the blinker" " library is not installed." diff --git a/src/flask/templating.py b/src/flask/templating.py index 6eebb13d61..1987d9e96e 100644 --- a/src/flask/templating.py +++ b/src/flask/templating.py @@ -1,5 +1,8 @@ +import typing as t + from jinja2 import BaseLoader from jinja2 import Environment as BaseEnvironment +from jinja2 import Template from jinja2 import TemplateNotFound from .globals import _app_ctx_stack @@ -7,8 +10,12 @@ from .signals import before_render_template from .signals import template_rendered +if t.TYPE_CHECKING: + from .app import Flask + from .scaffold import Scaffold + -def _default_template_ctx_processor(): +def _default_template_ctx_processor() -> t.Dict[str, t.Any]: """Default template context processor. Injects `request`, `session` and `g`. """ @@ -29,7 +36,7 @@ class Environment(BaseEnvironment): name of the blueprint to referenced templates if necessary. """ - def __init__(self, app, **options): + def __init__(self, app: "Flask", **options: t.Any) -> None: if "loader" not in options: options["loader"] = app.create_global_jinja_loader() BaseEnvironment.__init__(self, **options) @@ -41,15 +48,19 @@ class DispatchingJinjaLoader(BaseLoader): the blueprint folders. """ - def __init__(self, app): + def __init__(self, app: "Flask") -> None: self.app = app - def get_source(self, environment, template): + def get_source( + self, environment: Environment, template: str + ) -> t.Tuple[str, t.Optional[str], t.Callable]: if self.app.config["EXPLAIN_TEMPLATE_LOADING"]: return self._get_source_explained(environment, template) return self._get_source_fast(environment, template) - def _get_source_explained(self, environment, template): + def _get_source_explained( + self, environment: Environment, template: str + ) -> t.Tuple[str, t.Optional[str], t.Callable]: attempts = [] trv = None @@ -70,7 +81,9 @@ def _get_source_explained(self, environment, template): return trv raise TemplateNotFound(template) - def _get_source_fast(self, environment, template): + def _get_source_fast( + self, environment: Environment, template: str + ) -> t.Tuple[str, t.Optional[str], t.Callable]: for _srcobj, loader in self._iter_loaders(template): try: return loader.get_source(environment, template) @@ -78,7 +91,9 @@ def _get_source_fast(self, environment, template): continue raise TemplateNotFound(template) - def _iter_loaders(self, template): + def _iter_loaders( + self, template: str + ) -> t.Generator[t.Tuple["Scaffold", BaseLoader], None, None]: loader = self.app.jinja_loader if loader is not None: yield self.app, loader @@ -88,7 +103,7 @@ def _iter_loaders(self, template): if loader is not None: yield blueprint, loader - def list_templates(self): + def list_templates(self) -> t.List[str]: result = set() loader = self.app.jinja_loader if loader is not None: @@ -103,7 +118,7 @@ def list_templates(self): return list(result) -def _render(template, context, app): +def _render(template: Template, context: dict, app: "Flask") -> str: """Renders the template and fires the signal""" before_render_template.send(app, template=template, context=context) @@ -112,7 +127,9 @@ def _render(template, context, app): return rv -def render_template(template_name_or_list, **context): +def render_template( + template_name_or_list: t.Union[str, t.List[str]], **context: t.Any +) -> str: """Renders a template from the template folder with the given context. @@ -131,7 +148,7 @@ def render_template(template_name_or_list, **context): ) -def render_template_string(source, **context): +def render_template_string(source: str, **context: t.Any) -> str: """Renders a template from the given template source string with the given context. Template variables will be autoescaped. diff --git a/src/flask/testing.py b/src/flask/testing.py index 247e660598..fe3b846a17 100644 --- a/src/flask/testing.py +++ b/src/flask/testing.py @@ -1,5 +1,7 @@ +import typing as t from contextlib import contextmanager from copy import copy +from types import TracebackType import werkzeug.test from click.testing import CliRunner @@ -10,6 +12,11 @@ from . import _request_ctx_stack from .cli import ScriptInfo from .json import dumps as json_dumps +from .sessions import SessionMixin + +if t.TYPE_CHECKING: + from .app import Flask + from .wrappers import Response class EnvironBuilder(werkzeug.test.EnvironBuilder): @@ -36,14 +43,14 @@ class EnvironBuilder(werkzeug.test.EnvironBuilder): def __init__( self, - app, - path="/", - base_url=None, - subdomain=None, - url_scheme=None, - *args, - **kwargs, - ): + app: "Flask", + path: str = "/", + base_url: t.Optional[str] = None, + subdomain: t.Optional[str] = None, + url_scheme: t.Optional[str] = None, + *args: t.Any, + **kwargs: t.Any, + ) -> None: assert not (base_url or subdomain or url_scheme) or ( base_url is not None ) != bool( @@ -74,7 +81,7 @@ def __init__( self.app = app super().__init__(path, base_url, *args, **kwargs) - def json_dumps(self, obj, **kwargs): + def json_dumps(self, obj: t.Any, **kwargs: t.Any) -> str: # type: ignore """Serialize ``obj`` to a JSON-formatted string. The serialization will be configured according to the config associated @@ -99,9 +106,10 @@ class FlaskClient(Client): Basic usage is outlined in the :doc:`/testing` chapter. """ + application: "Flask" preserve_context = False - def __init__(self, *args, **kwargs): + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: super().__init__(*args, **kwargs) self.environ_base = { "REMOTE_ADDR": "127.0.0.1", @@ -109,7 +117,9 @@ def __init__(self, *args, **kwargs): } @contextmanager - def session_transaction(self, *args, **kwargs): + def session_transaction( + self, *args: t.Any, **kwargs: t.Any + ) -> t.Generator[SessionMixin, None, None]: """When used in combination with a ``with`` statement this opens a session transaction. This can be used to modify the session that the test client uses. Once the ``with`` block is left the session is @@ -161,9 +171,14 @@ def session_transaction(self, *args, **kwargs): headers = resp.get_wsgi_headers(c.request.environ) self.cookie_jar.extract_wsgi(c.request.environ, headers) - def open( - self, *args, as_tuple=False, buffered=False, follow_redirects=False, **kwargs - ): + def open( # type: ignore + self, + *args: t.Any, + as_tuple: bool = False, + buffered: bool = False, + follow_redirects: bool = False, + **kwargs: t.Any, + ) -> "Response": # Same logic as super.open, but apply environ_base and preserve_context. request = None @@ -198,20 +213,22 @@ def copy_environ(other): finally: builder.close() - return super().open( + return super().open( # type: ignore request, as_tuple=as_tuple, buffered=buffered, follow_redirects=follow_redirects, ) - def __enter__(self): + def __enter__(self) -> "FlaskClient": if self.preserve_context: raise RuntimeError("Cannot nest client invocations") self.preserve_context = True return self - def __exit__(self, exc_type, exc_value, tb): + def __exit__( + self, exc_type: type, exc_value: BaseException, tb: TracebackType + ) -> None: self.preserve_context = False # Normally the request context is preserved until the next @@ -233,11 +250,13 @@ class FlaskCliRunner(CliRunner): :meth:`~flask.Flask.test_cli_runner`. See :ref:`testing-cli`. """ - def __init__(self, app, **kwargs): + def __init__(self, app: "Flask", **kwargs: t.Any) -> None: self.app = app super().__init__(**kwargs) - def invoke(self, cli=None, args=None, **kwargs): + def invoke( # type: ignore + self, cli: t.Any = None, args: t.Any = None, **kwargs: t.Any + ) -> t.Any: """Invokes a CLI command in an isolated environment. See :meth:`CliRunner.invoke ` for full method documentation. See :ref:`testing-cli` for examples. diff --git a/src/flask/typing.py b/src/flask/typing.py new file mode 100644 index 0000000000..9a664e4141 --- /dev/null +++ b/src/flask/typing.py @@ -0,0 +1,46 @@ +import typing as t + + +if t.TYPE_CHECKING: + from werkzeug.datastructures import Headers # noqa: F401 + from wsgiref.types import WSGIApplication # noqa: F401 + from .wrappers import Response # noqa: F401 + +# The possible types that are directly convertible or are a Response object. +ResponseValue = t.Union[ + "Response", + t.AnyStr, + t.Dict[str, t.Any], # any jsonify-able dict + t.Generator[t.AnyStr, None, None], +] +StatusCode = int + +# the possible types for an individual HTTP header +HeaderName = str +HeaderValue = t.Union[str, t.List[str], t.Tuple[str, ...]] + +# the possible types for HTTP headers +HeadersValue = t.Union[ + "Headers", t.Dict[HeaderName, HeaderValue], t.List[t.Tuple[HeaderName, HeaderValue]] +] + +# The possible types returned by a route function. +ResponseReturnValue = t.Union[ + ResponseValue, + t.Tuple[ResponseValue, HeadersValue], + t.Tuple[ResponseValue, StatusCode], + t.Tuple[ResponseValue, StatusCode, HeadersValue], + "WSGIApplication", +] + +AppOrBlueprintKey = t.Optional[str] # The App key is None, whereas blueprints are named +AfterRequestCallable = t.Callable[["Response"], "Response"] +BeforeRequestCallable = t.Callable[[], None] +ErrorHandlerCallable = t.Callable[[Exception], ResponseReturnValue] +TeardownCallable = t.Callable[[t.Optional[BaseException]], "Response"] +TemplateContextProcessorCallable = t.Callable[[], t.Dict[str, t.Any]] +TemplateFilterCallable = t.Callable[[t.Any], str] +TemplateGlobalCallable = t.Callable[[], t.Any] +TemplateTestCallable = t.Callable[[t.Any], bool] +URLDefaultCallable = t.Callable[[str, dict], None] +URLValuePreprocessorCallable = t.Callable[[t.Optional[str], t.Optional[dict]], None] diff --git a/src/flask/views.py b/src/flask/views.py index 323e6118e6..339ffa18f2 100644 --- a/src/flask/views.py +++ b/src/flask/views.py @@ -1,4 +1,7 @@ +import typing as t + from .globals import request +from .typing import ResponseReturnValue http_method_funcs = frozenset( @@ -39,10 +42,10 @@ def dispatch_request(self): """ #: A list of methods this view can handle. - methods = None + methods: t.Optional[t.List[str]] = None #: Setting this disables or force-enables the automatic options handling. - provide_automatic_options = None + provide_automatic_options: t.Optional[bool] = None #: The canonical way to decorate class-based views is to decorate the #: return value of as_view(). However since this moves parts of the @@ -53,9 +56,9 @@ def dispatch_request(self): #: view function is created the result is automatically decorated. #: #: .. versionadded:: 0.8 - decorators = () + decorators: t.List[t.Callable] = [] - def dispatch_request(self): + def dispatch_request(self) -> ResponseReturnValue: """Subclasses have to override this method to implement the actual view function code. This method is called with all the arguments from the URL rule. @@ -63,7 +66,9 @@ def dispatch_request(self): raise NotImplementedError() @classmethod - def as_view(cls, name, *class_args, **class_kwargs): + def as_view( + cls, name: str, *class_args: t.Any, **class_kwargs: t.Any + ) -> t.Callable: """Converts the class into an actual view function that can be used with the routing system. Internally this generates a function on the fly which will instantiate the :class:`View` on each request and call @@ -73,8 +78,8 @@ def as_view(cls, name, *class_args, **class_kwargs): constructor of the class. """ - def view(*args, **kwargs): - self = view.view_class(*class_args, **class_kwargs) + def view(*args: t.Any, **kwargs: t.Any) -> ResponseReturnValue: + self = view.view_class(*class_args, **class_kwargs) # type: ignore return self.dispatch_request(*args, **kwargs) if cls.decorators: @@ -88,12 +93,12 @@ def view(*args, **kwargs): # view this thing came from, secondly it's also used for instantiating # the view class so you can actually replace it with something else # for testing purposes and debugging. - view.view_class = cls + view.view_class = cls # type: ignore view.__name__ = name view.__doc__ = cls.__doc__ view.__module__ = cls.__module__ - view.methods = cls.methods - view.provide_automatic_options = cls.provide_automatic_options + view.methods = cls.methods # type: ignore + view.provide_automatic_options = cls.provide_automatic_options # type: ignore return view @@ -140,7 +145,7 @@ def post(self): app.add_url_rule('/counter', view_func=CounterAPI.as_view('counter')) """ - def dispatch_request(self, *args, **kwargs): + def dispatch_request(self, *args: t.Any, **kwargs: t.Any) -> ResponseReturnValue: meth = getattr(self, request.method.lower(), None) # If the request method is HEAD and we don't have a handler for it diff --git a/src/flask/wrappers.py b/src/flask/wrappers.py index 1d8f17d7b1..48fcc34b7d 100644 --- a/src/flask/wrappers.py +++ b/src/flask/wrappers.py @@ -1,3 +1,5 @@ +import typing as t + from werkzeug.exceptions import BadRequest from werkzeug.wrappers import Request as RequestBase from werkzeug.wrappers import Response as ResponseBase @@ -5,6 +7,9 @@ from . import json from .globals import current_app +if t.TYPE_CHECKING: + from werkzeug.routing import Rule + class Request(RequestBase): """The request object used by default in Flask. Remembers the @@ -31,26 +36,28 @@ class Request(RequestBase): #: because the request was never internally bound. #: #: .. versionadded:: 0.6 - url_rule = None + url_rule: t.Optional["Rule"] = None #: A dict of view arguments that matched the request. If an exception #: happened when matching, this will be ``None``. - view_args = None + view_args: t.Optional[t.Dict[str, t.Any]] = None #: If matching the URL failed, this is the exception that will be #: raised / was raised as part of the request handling. This is #: usually a :exc:`~werkzeug.exceptions.NotFound` exception or #: something similar. - routing_exception = None + routing_exception: t.Optional[Exception] = None @property - def max_content_length(self): + def max_content_length(self) -> t.Optional[int]: # type: ignore """Read-only view of the ``MAX_CONTENT_LENGTH`` config key.""" if current_app: return current_app.config["MAX_CONTENT_LENGTH"] + else: + return None @property - def endpoint(self): + def endpoint(self) -> t.Optional[str]: """The endpoint that matched the request. This in combination with :attr:`view_args` can be used to reconstruct the same or a modified URL. If an exception happened when matching, this will @@ -58,14 +65,18 @@ def endpoint(self): """ if self.url_rule is not None: return self.url_rule.endpoint + else: + return None @property - def blueprint(self): + def blueprint(self) -> t.Optional[str]: """The name of the current blueprint""" if self.url_rule and "." in self.url_rule.endpoint: return self.url_rule.endpoint.rsplit(".", 1)[0] + else: + return None - def _load_form_data(self): + def _load_form_data(self) -> None: RequestBase._load_form_data(self) # In debug mode we're replacing the files multidict with an ad-hoc @@ -80,7 +91,7 @@ def _load_form_data(self): attach_enctype_error_multidict(self) - def on_json_loading_failed(self, e): + def on_json_loading_failed(self, e: Exception) -> t.NoReturn: if current_app and current_app.debug: raise BadRequest(f"Failed to decode JSON object: {e}") @@ -110,7 +121,7 @@ class Response(ResponseBase): json_module = json @property - def max_cookie_size(self): + def max_cookie_size(self) -> int: # type: ignore """Read-only view of the :data:`MAX_COOKIE_SIZE` config key. See :attr:`~werkzeug.wrappers.Response.max_cookie_size` in