From 50c667b7ccb66c3184d8f7677877dfc2d3efa5b5 Mon Sep 17 00:00:00 2001 From: Paulo Costa Date: Mon, 7 Feb 2022 22:50:09 -0300 Subject: [PATCH 1/5] Support `dependencies` in websocket routes I've been using dependencies to handle authentication. But imagine my surprise when I realized my websocket endpoint wasn't authenticated at all? This commit cherry-picks the `dependencies` chunks from `APIRoute` into `APIWebSocketRoute` I also made a few minor style nit-picks --- fastapi/applications.py | 27 ++++++++++++++++++---- fastapi/routing.py | 50 ++++++++++++++++++++++++++++++++--------- 2 files changed, 62 insertions(+), 15 deletions(-) diff --git a/fastapi/applications.py b/fastapi/applications.py index 132a94c9a1c7c..976b9cdd409d2 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -380,15 +380,34 @@ def decorator(func: DecoratedCallable) -> DecoratedCallable: return decorator def add_api_websocket_route( - self, path: str, endpoint: Callable[..., Any], name: Optional[str] = None + self, + path: str, + endpoint: Callable[..., Any], + *, + dependencies: Optional[Sequence[Depends]] = None, + name: Optional[str] = None, ) -> None: - self.router.add_api_websocket_route(path, endpoint, name=name) + self.router.add_api_websocket_route( + path, + endpoint, + dependencies=dependencies, + name=name, + ) def websocket( - self, path: str, name: Optional[str] = None + self, + path: str, + *, + dependencies: Optional[Sequence[Depends]] = None, + name: Optional[str] = None, ) -> Callable[[DecoratedCallable], DecoratedCallable]: def decorator(func: DecoratedCallable) -> DecoratedCallable: - self.add_api_websocket_route(path, func, name=name) + self.add_api_websocket_route( + path, + func, + dependencies=dependencies, + name=name, + ) return func return decorator diff --git a/fastapi/routing.py b/fastapi/routing.py index 0f416ac42e1df..5e01dd6a53467 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -282,20 +282,27 @@ def __init__( path: str, endpoint: Callable[..., Any], *, + dependencies: Optional[Sequence[params.Depends]] = None, name: Optional[str] = None, dependency_overrides_provider: Optional[Any] = None, ) -> None: self.path = path self.endpoint = endpoint + self.dependencies = dependencies or [] self.name = get_name(endpoint) if name is None else name + self.path_regex, self.path_format, self.param_convertors = compile_path(path) self.dependant = get_dependant(path=path, call=self.endpoint) + for depends in self.dependencies[::-1]: + self.dependant.dependencies.insert( + 0, + get_parameterless_sub_dependant(depends=depends, path=self.path_format), + ) self.app = websocket_session( get_websocket_app( dependant=self.dependant, dependency_overrides_provider=dependency_overrides_provider, ) ) - self.path_regex, self.path_format, self.param_convertors = compile_path(path) def matches(self, scope: Scope) -> Tuple[Match, Scope]: match, child_scope = super().matches(scope) @@ -397,10 +404,7 @@ def __init__( else: self.response_field = None # type: ignore self.secure_cloned_response_field = None - if dependencies: - self.dependencies = list(dependencies) - else: - self.dependencies = [] + self.dependencies = dependencies or [] self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "") # if a "form feed" character (page break) is found in the description text, # truncate description text to the content preceding the first "form feed" @@ -491,7 +495,7 @@ def __init__( ), "A path prefix must not end with '/', as the routes will start with '/'" self.prefix = prefix self.tags: List[Union[str, Enum]] = tags or [] - self.dependencies = list(dependencies or []) or [] + self.dependencies = dependencies or [] self.deprecated = deprecated self.include_in_schema = include_in_schema self.responses = responses or {} @@ -544,7 +548,7 @@ def add_api_route( current_tags = self.tags.copy() if tags: current_tags.extend(tags) - current_dependencies = self.dependencies.copy() + current_dependencies = list(self.dependencies) if dependencies: current_dependencies.extend(dependencies) current_callbacks = self.callbacks.copy() @@ -646,21 +650,37 @@ def decorator(func: DecoratedCallable) -> DecoratedCallable: return decorator def add_api_websocket_route( - self, path: str, endpoint: Callable[..., Any], name: Optional[str] = None + self, + path: str, + endpoint: Callable[..., Any], + *, + dependencies: Optional[Sequence[params.Depends]] = None, + name: Optional[str] = None, ) -> None: + current_dependencies = list(self.dependencies) + if dependencies: + current_dependencies.extend(dependencies) + route = APIWebSocketRoute( path, endpoint=endpoint, name=name, dependency_overrides_provider=self.dependency_overrides_provider, + dependencies=current_dependencies, ) self.routes.append(route) def websocket( - self, path: str, name: Optional[str] = None + self, + path: str, + *, + dependencies: Optional[Sequence[params.Depends]] = None, + name: Optional[str] = None, ) -> Callable[[DecoratedCallable], DecoratedCallable]: def decorator(func: DecoratedCallable) -> DecoratedCallable: - self.add_api_websocket_route(path, func, name=name) + self.add_api_websocket_route( + path, func, name=name, dependencies=dependencies + ) return func return decorator @@ -766,8 +786,16 @@ def include_router( name=route.name, ) elif isinstance(route, APIWebSocketRoute): + current_dependencies: List[params.Depends] = [] + if dependencies: + current_dependencies.extend(dependencies) + if route.dependencies: + current_dependencies.extend(route.dependencies) self.add_api_websocket_route( - prefix + route.path, route.endpoint, name=route.name + prefix + route.path, + route.endpoint, + dependencies=current_dependencies, + name=route.name, ) elif isinstance(route, routing.WebSocketRoute): self.add_websocket_route( From 0fc264d861074afc58411c4eb7c86cb2e89025a2 Mon Sep 17 00:00:00 2001 From: Paulo Costa Date: Wed, 2 Mar 2022 21:49:27 -0300 Subject: [PATCH 2/5] Add tests for dependencies on websocket endpoints/routes --- tests/test_ws_dependencies.py | 70 +++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 tests/test_ws_dependencies.py diff --git a/tests/test_ws_dependencies.py b/tests/test_ws_dependencies.py new file mode 100644 index 0000000000000..8cb0e669b8e3b --- /dev/null +++ b/tests/test_ws_dependencies.py @@ -0,0 +1,70 @@ +import json +from typing import Optional + +from fastapi import APIRouter, Depends, FastAPI, WebSocket +from fastapi.testclient import TestClient + + +def dependency_list(deps: Optional[str] = None) -> list[str]: + return [deps] if deps else [] + + +def create_dependency(name: str): + def fun(deps: list[str] = Depends(dependency_list)): + print(f"create_dependency.fun({name})") + deps.append(name) + + return Depends(fun) + + +router = APIRouter(dependencies=[create_dependency("router")]) +prefix_router = APIRouter(dependencies=[create_dependency("prefix_router")]) +app = FastAPI(dependencies=[create_dependency("app")]) + + +@app.websocket("/", dependencies=[create_dependency("index")]) +async def index(websocket: WebSocket, deps=Depends(dependency_list)): + await websocket.accept() + await websocket.send_text(json.dumps(deps)) + await websocket.close() + + +@router.websocket("/router", dependencies=[create_dependency("routerindex")]) +async def routerindex(websocket: WebSocket, deps=Depends(dependency_list)): + await websocket.accept() + await websocket.send_text(json.dumps(deps)) + await websocket.close() + + +@prefix_router.websocket("/", dependencies=[create_dependency("routerprefixindex")]) +async def routerprefixindex(websocket: WebSocket, deps=Depends(dependency_list)): + await websocket.accept() + await websocket.send_text(json.dumps(deps)) + await websocket.close() + + +app.include_router(router, dependencies=[create_dependency("router2")]) +app.include_router( + prefix_router, prefix="/prefix", dependencies=[create_dependency("prefix_router2")] +) + + +def test_index(): + client = TestClient(app) + with client.websocket_connect("/") as websocket: + data = json.loads(websocket.receive_text()) + assert data == ["app", "index"] + + +def test_routerindex(): + client = TestClient(app) + with client.websocket_connect("/router") as websocket: + data = json.loads(websocket.receive_text()) + assert data == ["app", "router2", "router", "routerindex"] + + +def test_routerprefixindex(): + client = TestClient(app) + with client.websocket_connect("/prefix/") as websocket: + data = json.loads(websocket.receive_text()) + assert data == ["app", "prefix_router2", "prefix_router", "routerprefixindex"] From e4ddcb8348a92e7c8e4a62db7052e98d67331449 Mon Sep 17 00:00:00 2001 From: Paulo Costa Date: Mon, 3 Oct 2022 22:26:55 -0300 Subject: [PATCH 3/5] Preserve websocket's `name` as a positional argument --- fastapi/applications.py | 8 ++++---- fastapi/routing.py | 10 +++++----- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/fastapi/applications.py b/fastapi/applications.py index bb3da2e8400da..a756b12a3c0ca 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -392,30 +392,30 @@ def add_api_websocket_route( self, path: str, endpoint: Callable[..., Any], + name: Optional[str] = None, *, dependencies: Optional[Sequence[Depends]] = None, - name: Optional[str] = None, ) -> None: self.router.add_api_websocket_route( path, endpoint, - dependencies=dependencies, name=name, + dependencies=dependencies, ) def websocket( self, path: str, + name: Optional[str] = None, *, dependencies: Optional[Sequence[Depends]] = None, - name: Optional[str] = None, ) -> Callable[[DecoratedCallable], DecoratedCallable]: def decorator(func: DecoratedCallable) -> DecoratedCallable: self.add_api_websocket_route( path, func, - dependencies=dependencies, name=name, + dependencies=dependencies, ) return func diff --git a/fastapi/routing.py b/fastapi/routing.py index 38d86f17dde52..31d6b30a8b1a8 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -291,14 +291,14 @@ def __init__( path: str, endpoint: Callable[..., Any], *, - dependencies: Optional[Sequence[params.Depends]] = None, name: Optional[str] = None, + dependencies: Optional[Sequence[params.Depends]] = None, dependency_overrides_provider: Optional[Any] = None, ) -> None: self.path = path self.endpoint = endpoint - self.dependencies = dependencies or [] self.name = get_name(endpoint) if name is None else name + self.dependencies = dependencies or [] self.path_regex, self.path_format, self.param_convertors = compile_path(path) self.dependant = get_dependant(path=self.path_format, call=self.endpoint) for depends in self.dependencies[::-1]: @@ -663,9 +663,9 @@ def add_api_websocket_route( self, path: str, endpoint: Callable[..., Any], + name: Optional[str] = None, *, dependencies: Optional[Sequence[params.Depends]] = None, - name: Optional[str] = None, ) -> None: current_dependencies = list(self.dependencies) if dependencies: @@ -675,17 +675,17 @@ def add_api_websocket_route( self.prefix + path, endpoint=endpoint, name=name, - dependency_overrides_provider=self.dependency_overrides_provider, dependencies=current_dependencies, + dependency_overrides_provider=self.dependency_overrides_provider, ) self.routes.append(route) def websocket( self, path: str, + name: Optional[str] = None, *, dependencies: Optional[Sequence[params.Depends]] = None, - name: Optional[str] = None, ) -> Callable[[DecoratedCallable], DecoratedCallable]: def decorator(func: DecoratedCallable) -> DecoratedCallable: self.add_api_websocket_route( From 230f291153165c220e54ae5eb2175c93fb74627f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 11 Jun 2023 22:29:31 +0200 Subject: [PATCH 4/5] =?UTF-8?q?=E2=8F=AA=EF=B8=8F=20Revert=20changes=20to?= =?UTF-8?q?=20dependencies=20expected=20types?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/routing.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/fastapi/routing.py b/fastapi/routing.py index 41512669639dc..af628f32d7d13 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -302,7 +302,7 @@ def __init__( self.path = path self.endpoint = endpoint self.name = get_name(endpoint) if name is None else name - self.dependencies = dependencies or [] + self.dependencies = list(dependencies or []) self.path_regex, self.path_format, self.param_convertors = compile_path(path) self.dependant = get_dependant(path=self.path_format, call=self.endpoint) for depends in self.dependencies[::-1]: @@ -424,7 +424,7 @@ def __init__( else: self.response_field = None # type: ignore self.secure_cloned_response_field = None - self.dependencies = dependencies or [] + self.dependencies = list(dependencies or []) self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "") # if a "form feed" character (page break) is found in the description text, # truncate description text to the content preceding the first "form feed" @@ -519,7 +519,7 @@ def __init__( ), "A path prefix must not end with '/', as the routes will start with '/'" self.prefix = prefix self.tags: List[Union[str, Enum]] = tags or [] - self.dependencies = dependencies or [] + self.dependencies = list(dependencies or []) self.deprecated = deprecated self.include_in_schema = include_in_schema self.responses = responses or {} @@ -591,7 +591,7 @@ def add_api_route( current_tags = self.tags.copy() if tags: current_tags.extend(tags) - current_dependencies = list(self.dependencies) + current_dependencies = self.dependencies.copy() if dependencies: current_dependencies.extend(dependencies) current_callbacks = self.callbacks.copy() @@ -700,7 +700,7 @@ def add_api_websocket_route( *, dependencies: Optional[Sequence[params.Depends]] = None, ) -> None: - current_dependencies = list(self.dependencies) + current_dependencies = self.dependencies.copy() if dependencies: current_dependencies.extend(dependencies) From 931b492013a86be3545aa17b59aa2606ebbd9984 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 11 Jun 2023 22:30:18 +0200 Subject: [PATCH 5/5] =?UTF-8?q?=E2=9C=85=20Tweak=20tests=20for=20WebSocket?= =?UTF-8?q?s=20with=20dependencies?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_ws_dependencies.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/tests/test_ws_dependencies.py b/tests/test_ws_dependencies.py index e3e0579369c87..ccb1c4b7da2e8 100644 --- a/tests/test_ws_dependencies.py +++ b/tests/test_ws_dependencies.py @@ -1,17 +1,20 @@ import json -from typing import List, Optional +from typing import List from fastapi import APIRouter, Depends, FastAPI, WebSocket from fastapi.testclient import TestClient +from typing_extensions import Annotated -def dependency_list(deps: Optional[str] = None) -> List[str]: - return [deps] if deps else [] +def dependency_list() -> List[str]: + return [] + + +DepList = Annotated[List[str], Depends(dependency_list)] def create_dependency(name: str): - def fun(deps: List[str] = Depends(dependency_list)): - print(f"create_dependency.fun({name})") + def fun(deps: DepList): deps.append(name) return Depends(fun) @@ -23,21 +26,21 @@ def fun(deps: List[str] = Depends(dependency_list)): @app.websocket("/", dependencies=[create_dependency("index")]) -async def index(websocket: WebSocket, deps=Depends(dependency_list)): +async def index(websocket: WebSocket, deps: DepList): await websocket.accept() await websocket.send_text(json.dumps(deps)) await websocket.close() @router.websocket("/router", dependencies=[create_dependency("routerindex")]) -async def routerindex(websocket: WebSocket, deps=Depends(dependency_list)): +async def routerindex(websocket: WebSocket, deps: DepList): await websocket.accept() await websocket.send_text(json.dumps(deps)) await websocket.close() @prefix_router.websocket("/", dependencies=[create_dependency("routerprefixindex")]) -async def routerprefixindex(websocket: WebSocket, deps=Depends(dependency_list)): +async def routerprefixindex(websocket: WebSocket, deps: DepList): await websocket.accept() await websocket.send_text(json.dumps(deps)) await websocket.close()