Skip to content

Commit

Permalink
✨ Improve type annotations, add support for mypy --strict, internally…
Browse files Browse the repository at this point in the history
… and for external packages (#2547)
  • Loading branch information
tiangolo committed Dec 20, 2020
1 parent 4fdcdf3 commit fdb6c9c
Show file tree
Hide file tree
Showing 43 changed files with 314 additions and 244 deletions.
2 changes: 1 addition & 1 deletion docs_src/openapi_callbacks/tutorial001.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class InvoiceEventReceived(BaseModel):


@invoices_callback_router.post(
"{$callback_url}/invoices/{$request.body.id}", response_model=InvoiceEventReceived,
"{$callback_url}/invoices/{$request.body.id}", response_model=InvoiceEventReceived
)
def invoice_notification(body: InvoiceEvent):
pass
Expand Down
39 changes: 19 additions & 20 deletions fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,23 @@

__version__ = "0.62.0"

from starlette import status
from starlette import status as status

from .applications import FastAPI
from .background import BackgroundTasks
from .datastructures import UploadFile
from .exceptions import HTTPException
from .param_functions import (
Body,
Cookie,
Depends,
File,
Form,
Header,
Path,
Query,
Security,
)
from .requests import Request
from .responses import Response
from .routing import APIRouter
from .websockets import WebSocket, WebSocketDisconnect
from .applications import FastAPI as FastAPI
from .background import BackgroundTasks as BackgroundTasks
from .datastructures import UploadFile as UploadFile
from .exceptions import HTTPException as HTTPException
from .param_functions import Body as Body
from .param_functions import Cookie as Cookie
from .param_functions import Depends as Depends
from .param_functions import File as File
from .param_functions import Form as Form
from .param_functions import Header as Header
from .param_functions import Path as Path
from .param_functions import Query as Query
from .param_functions import Security as Security
from .requests import Request as Request
from .responses import Response as Response
from .routing import APIRouter as APIRouter
from .websockets import WebSocket as WebSocket
from .websockets import WebSocketDisconnect as WebSocketDisconnect
89 changes: 50 additions & 39 deletions fastapi/applications.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union
from typing import Any, Callable, Coroutine, Dict, List, Optional, Sequence, Type, Union

from fastapi import routing
from fastapi.concurrency import AsyncExitStack
Expand All @@ -17,14 +17,15 @@
)
from fastapi.openapi.utils import get_openapi
from fastapi.params import Depends
from fastapi.types import DecoratedCallable
from starlette.applications import Starlette
from starlette.datastructures import State
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import HTMLResponse, JSONResponse, Response
from starlette.routing import BaseRoute
from starlette.types import Receive, Scope, Send
from starlette.types import ASGIApp, Receive, Scope, Send


class FastAPI(Starlette):
Expand All @@ -44,24 +45,27 @@ def __init__(
docs_url: Optional[str] = "/docs",
redoc_url: Optional[str] = "/redoc",
swagger_ui_oauth2_redirect_url: Optional[str] = "/docs/oauth2-redirect",
swagger_ui_init_oauth: Optional[dict] = None,
swagger_ui_init_oauth: Optional[Dict[str, Any]] = None,
middleware: Optional[Sequence[Middleware]] = None,
exception_handlers: Optional[
Dict[Union[int, Type[Exception]], Callable]
Dict[
Union[int, Type[Exception]],
Callable[[Request, Any], Coroutine[Any, Any, Response]],
]
] = None,
on_startup: Optional[Sequence[Callable]] = None,
on_shutdown: Optional[Sequence[Callable]] = None,
on_startup: Optional[Sequence[Callable[[], Any]]] = None,
on_shutdown: Optional[Sequence[Callable[[], Any]]] = None,
openapi_prefix: str = "",
root_path: str = "",
root_path_in_servers: bool = True,
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
callbacks: Optional[List[routing.APIRoute]] = None,
deprecated: bool = None,
callbacks: Optional[List[BaseRoute]] = None,
deprecated: Optional[bool] = None,
include_in_schema: bool = True,
**extra: Any,
) -> None:
self._debug = debug
self.state = State()
self._debug: bool = debug
self.state: State = State()
self.router: routing.APIRouter = routing.APIRouter(
routes=routes,
dependency_overrides_provider=self,
Expand All @@ -74,16 +78,21 @@ def __init__(
include_in_schema=include_in_schema,
responses=responses,
)
self.exception_handlers = (
self.exception_handlers: Dict[
Union[int, Type[Exception]],
Callable[[Request, Any], Coroutine[Any, Any, Response]],
] = (
{} if exception_handlers is None else dict(exception_handlers)
)
self.exception_handlers.setdefault(HTTPException, http_exception_handler)
self.exception_handlers.setdefault(
RequestValidationError, request_validation_exception_handler
)

self.user_middleware = [] if middleware is None else list(middleware)
self.middleware_stack = self.build_middleware_stack()
self.user_middleware: List[Middleware] = (
[] if middleware is None else list(middleware)
)
self.middleware_stack: ASGIApp = self.build_middleware_stack()

self.title = title
self.description = description
Expand All @@ -106,7 +115,7 @@ def __init__(
self.swagger_ui_oauth2_redirect_url = swagger_ui_oauth2_redirect_url
self.swagger_ui_init_oauth = swagger_ui_init_oauth
self.extra = extra
self.dependency_overrides: Dict[Callable, Callable] = {}
self.dependency_overrides: Dict[Callable[..., Any], Callable[..., Any]] = {}

self.openapi_version = "3.0.2"

Expand All @@ -116,7 +125,7 @@ def __init__(
self.openapi_schema: Optional[Dict[str, Any]] = None
self.setup()

def openapi(self) -> Dict:
def openapi(self) -> Dict[str, Any]:
if not self.openapi_schema:
self.openapi_schema = get_openapi(
title=self.title,
Expand Down Expand Up @@ -194,7 +203,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
def add_api_route(
self,
path: str,
endpoint: Callable,
endpoint: Callable[..., Coroutine[Any, Any, Response]],
*,
response_model: Optional[Type[Any]] = None,
status_code: int = 200,
Expand Down Expand Up @@ -268,8 +277,8 @@ def api_route(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
) -> Callable:
def decorator(func: Callable) -> Callable:
) -> Callable[[DecoratedCallable], DecoratedCallable]:
def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.router.add_api_route(
path,
func,
Expand Down Expand Up @@ -299,12 +308,14 @@ def decorator(func: Callable) -> Callable:
return decorator

def add_api_websocket_route(
self, path: str, endpoint: Callable, name: Optional[str] = None
self, path: str, endpoint: Callable[..., Any], name: Optional[str] = None
) -> None:
self.router.add_api_websocket_route(path, endpoint, name=name)

def websocket(self, path: str, name: Optional[str] = None) -> Callable:
def decorator(func: Callable) -> Callable:
def websocket(
self, path: str, name: Optional[str] = None
) -> Callable[[DecoratedCallable], DecoratedCallable]:
def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.add_api_websocket_route(path, func, name=name)
return func

Expand All @@ -318,10 +329,10 @@ def include_router(
tags: Optional[List[str]] = None,
dependencies: Optional[Sequence[Depends]] = None,
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
deprecated: bool = None,
deprecated: Optional[bool] = None,
include_in_schema: bool = True,
default_response_class: Type[Response] = Default(JSONResponse),
callbacks: Optional[List[routing.APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
) -> None:
self.router.include_router(
router,
Expand Down Expand Up @@ -358,8 +369,8 @@ def get(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[routing.APIRoute]] = None,
) -> Callable:
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
return self.router.get(
path,
response_model=response_model,
Expand Down Expand Up @@ -407,8 +418,8 @@ def put(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[routing.APIRoute]] = None,
) -> Callable:
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
return self.router.put(
path,
response_model=response_model,
Expand Down Expand Up @@ -456,8 +467,8 @@ def post(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[routing.APIRoute]] = None,
) -> Callable:
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
return self.router.post(
path,
response_model=response_model,
Expand Down Expand Up @@ -505,8 +516,8 @@ def delete(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[routing.APIRoute]] = None,
) -> Callable:
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
return self.router.delete(
path,
response_model=response_model,
Expand Down Expand Up @@ -554,8 +565,8 @@ def options(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[routing.APIRoute]] = None,
) -> Callable:
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
return self.router.options(
path,
response_model=response_model,
Expand Down Expand Up @@ -603,8 +614,8 @@ def head(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[routing.APIRoute]] = None,
) -> Callable:
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
return self.router.head(
path,
response_model=response_model,
Expand Down Expand Up @@ -652,8 +663,8 @@ def patch(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[routing.APIRoute]] = None,
) -> Callable:
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
return self.router.patch(
path,
response_model=response_model,
Expand Down Expand Up @@ -701,8 +712,8 @@ def trace(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[routing.APIRoute]] = None,
) -> Callable:
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
return self.router.trace(
path,
response_model=response_model,
Expand Down
2 changes: 1 addition & 1 deletion fastapi/background.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from starlette.background import BackgroundTasks # noqa
from starlette.background import BackgroundTasks as BackgroundTasks # noqa
22 changes: 13 additions & 9 deletions fastapi/concurrency.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Any, Callable

from starlette.concurrency import iterate_in_threadpool # noqa
from starlette.concurrency import run_in_threadpool # noqa
from starlette.concurrency import run_until_first_complete # noqa
from starlette.concurrency import iterate_in_threadpool as iterate_in_threadpool # noqa
from starlette.concurrency import run_in_threadpool as run_in_threadpool # noqa
from starlette.concurrency import ( # noqa
run_until_first_complete as run_until_first_complete,
)

asynccontextmanager_error_message = """
FastAPI's contextmanager_in_threadpool require Python 3.7 or above,
Expand All @@ -11,31 +13,33 @@
"""


def _fake_asynccontextmanager(func: Callable) -> Callable:
def _fake_asynccontextmanager(func: Callable[..., Any]) -> Callable[..., Any]:
def raiser(*args: Any, **kwargs: Any) -> Any:
raise RuntimeError(asynccontextmanager_error_message)

return raiser


try:
from contextlib import asynccontextmanager # type: ignore
from contextlib import asynccontextmanager as asynccontextmanager # type: ignore
except ImportError:
try:
from async_generator import asynccontextmanager # type: ignore
from async_generator import ( # type: ignore # isort: skip
asynccontextmanager as asynccontextmanager,
)
except ImportError: # pragma: no cover
asynccontextmanager = _fake_asynccontextmanager

try:
from contextlib import AsyncExitStack # type: ignore
from contextlib import AsyncExitStack as AsyncExitStack # type: ignore
except ImportError:
try:
from async_exit_stack import AsyncExitStack # type: ignore
from async_exit_stack import AsyncExitStack as AsyncExitStack # type: ignore
except ImportError: # pragma: no cover
AsyncExitStack = None # type: ignore


@asynccontextmanager
@asynccontextmanager # type: ignore
async def contextmanager_in_threadpool(cm: Any) -> Any:
try:
yield await run_in_threadpool(cm.__enter__)
Expand Down
3 changes: 2 additions & 1 deletion fastapi/datastructures.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Any, Callable, Iterable, Type, TypeVar

from starlette.datastructures import State as State # noqa: F401
from starlette.datastructures import UploadFile as StarletteUploadFile


class UploadFile(StarletteUploadFile):
@classmethod
def __get_validators__(cls: Type["UploadFile"]) -> Iterable[Callable]:
def __get_validators__(cls: Type["UploadFile"]) -> Iterable[Callable[..., Any]]:
yield cls.validate

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions fastapi/dependencies/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, List, Optional, Sequence
from typing import Any, Callable, List, Optional, Sequence

from fastapi.security.base import SecurityBase
from pydantic.fields import ModelField
Expand All @@ -24,7 +24,7 @@ def __init__(
dependencies: Optional[List["Dependant"]] = None,
security_schemes: Optional[List[SecurityRequirement]] = None,
name: Optional[str] = None,
call: Optional[Callable] = None,
call: Optional[Callable[..., Any]] = None,
request_param_name: Optional[str] = None,
websocket_param_name: Optional[str] = None,
http_connection_param_name: Optional[str] = None,
Expand Down
Loading

0 comments on commit fdb6c9c

Please sign in to comment.