From f8074c72d96f73026d6fd730442797f5a5c62207 Mon Sep 17 00:00:00 2001 From: reton2 Date: Wed, 1 May 2024 10:55:30 +0100 Subject: [PATCH 1/8] fix: allow callabe in dependant for get_request_handler --- fastapi/routing.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/fastapi/routing.py b/fastapi/routing.py index fa1351859fb91..c34a547331c6d 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -208,7 +208,11 @@ def get_request_handler( dependency_overrides_provider: Optional[Any] = None, ) -> Callable[[Request], Coroutine[Any, Any, Response]]: assert dependant.call is not None, "dependant.call must be a function" - is_coroutine = asyncio.iscoroutinefunction(dependant.call) + is_coroutine = ( + asyncio.iscoroutinefunction(dependant.call) + or callable(dependant.call) + and inspect.iscoroutinefunction(dependant.call.__call__) + ) is_body_form = body_field and isinstance(body_field.field_info, params.Form) if isinstance(response_class, DefaultPlaceholder): actual_response_class: Type[Response] = response_class.value From 4c5c0f60d241320f8a9ccc28a633acd30b8be949 Mon Sep 17 00:00:00 2001 From: reton2 Date: Wed, 1 May 2024 11:20:49 +0100 Subject: [PATCH 2/8] test: add test for fix/allow-callable_get_request_handler --- tests/test_endpoint_decorator.py | 59 ++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 tests/test_endpoint_decorator.py diff --git a/tests/test_endpoint_decorator.py b/tests/test_endpoint_decorator.py new file mode 100644 index 0000000000000..c420bb486becb --- /dev/null +++ b/tests/test_endpoint_decorator.py @@ -0,0 +1,59 @@ +from typing import Any, Callable +from functools import update_wrapper +from fastapi import Depends, FastAPI +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from fastapi.routing import APIRoute +from fastapi.testclient import TestClient +from starlette.exceptions import HTTPException + +class EndpointWrapper(Callable[..., Any]): + def __init__(self, endpoint: Callable[..., Any]): + self.endpoint = endpoint + self.protected = False + update_wrapper(self, endpoint) + + async def __call__(self, *args, **kwargs): + return await self.endpoint(*args, **kwargs) + +def dummy_secruity_check(token: HTTPAuthorizationCredentials = Depends(HTTPBearer())): + if token.credentials != "fake-token": + raise HTTPException(status_code=401, detail="Unauthorized") + +def protect(endpoint: Callable[..., Any]): + if not isinstance(endpoint, EndpointWrapper): + endpoint = EndpointWrapper(endpoint) + endpoint.protected = True + return endpoint + +class CustomAPIRoute(APIRoute): + def __init__(self, path: str, endpoint: Callable[..., Any], dependencies=None, **kwargs) -> None: + if dependencies is None: + dependencies = [] + if ( + isinstance(endpoint, EndpointWrapper) + and endpoint.protected + ): + dependencies.append(Depends(dummy_secruity_check)) + super().__init__(path, endpoint, dependencies=dependencies, **kwargs) + +app = FastAPI() + +app.router.route_class = CustomAPIRoute + +@app.get("/protected") +@protect +async def protected_route(): + return {"message": "This is a protected route"} + +client = TestClient(app) + +def test_protected_route(): + response = client.get("/protected") + assert response.status_code == 403 + + response = client.get("/protected", headers={"Authorization": "Bearer some-token"}) + assert response.status_code == 401 + + response = client.get("/protected", headers={"Authorization": "Bearer fake-token"}) + assert response.status_code == 200 + assert response.json() == {"message": "This is a protected route"} From f644f72306efb0cf18369fd065811a59d3437f3a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 1 May 2024 10:22:21 +0000 Subject: [PATCH 3/8] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20for?= =?UTF-8?q?mat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_endpoint_decorator.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/tests/test_endpoint_decorator.py b/tests/test_endpoint_decorator.py index c420bb486becb..33f03ec5490f3 100644 --- a/tests/test_endpoint_decorator.py +++ b/tests/test_endpoint_decorator.py @@ -1,11 +1,13 @@ -from typing import Any, Callable from functools import update_wrapper +from typing import Any, Callable + from fastapi import Depends, FastAPI -from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from fastapi.routing import APIRoute +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from fastapi.testclient import TestClient from starlette.exceptions import HTTPException + class EndpointWrapper(Callable[..., Any]): def __init__(self, endpoint: Callable[..., Any]): self.endpoint = endpoint @@ -14,39 +16,45 @@ def __init__(self, endpoint: Callable[..., Any]): async def __call__(self, *args, **kwargs): return await self.endpoint(*args, **kwargs) - + + def dummy_secruity_check(token: HTTPAuthorizationCredentials = Depends(HTTPBearer())): if token.credentials != "fake-token": raise HTTPException(status_code=401, detail="Unauthorized") + def protect(endpoint: Callable[..., Any]): if not isinstance(endpoint, EndpointWrapper): endpoint = EndpointWrapper(endpoint) endpoint.protected = True return endpoint + class CustomAPIRoute(APIRoute): - def __init__(self, path: str, endpoint: Callable[..., Any], dependencies=None, **kwargs) -> None: + def __init__( + self, path: str, endpoint: Callable[..., Any], dependencies=None, **kwargs + ) -> None: if dependencies is None: dependencies = [] - if ( - isinstance(endpoint, EndpointWrapper) - and endpoint.protected - ): + if isinstance(endpoint, EndpointWrapper) and endpoint.protected: dependencies.append(Depends(dummy_secruity_check)) super().__init__(path, endpoint, dependencies=dependencies, **kwargs) + app = FastAPI() app.router.route_class = CustomAPIRoute + @app.get("/protected") @protect async def protected_route(): return {"message": "This is a protected route"} + client = TestClient(app) + def test_protected_route(): response = client.get("/protected") assert response.status_code == 403 From 033e2cff11beb1c7fa702f32e99005eed786674a Mon Sep 17 00:00:00 2001 From: reton2 Date: Wed, 1 May 2024 11:51:22 +0100 Subject: [PATCH 4/8] fix: linting issue --- fastapi/routing.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fastapi/routing.py b/fastapi/routing.py index c34a547331c6d..6dd7d0fe4b3ff 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -211,6 +211,7 @@ def get_request_handler( is_coroutine = ( asyncio.iscoroutinefunction(dependant.call) or callable(dependant.call) + and hasattr(dependant.call, "__call__") and inspect.iscoroutinefunction(dependant.call.__call__) ) is_body_form = body_field and isinstance(body_field.field_info, params.Form) From 22dc61c695e18d02cdf07a92528f4c0d6e4cd216 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 1 May 2024 10:54:02 +0000 Subject: [PATCH 5/8] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20for?= =?UTF-8?q?mat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/routing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastapi/routing.py b/fastapi/routing.py index 6dd7d0fe4b3ff..188ad22d3dedf 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -211,7 +211,7 @@ def get_request_handler( is_coroutine = ( asyncio.iscoroutinefunction(dependant.call) or callable(dependant.call) - and hasattr(dependant.call, "__call__") + and callable(dependant.call) and inspect.iscoroutinefunction(dependant.call.__call__) ) is_body_form = body_field and isinstance(body_field.field_info, params.Form) From e7962cbc9fec7b913872dd8a63fc0370b82eef6a Mon Sep 17 00:00:00 2001 From: reton2 Date: Wed, 1 May 2024 11:59:57 +0100 Subject: [PATCH 6/8] chore: add ignore for line --- fastapi/routing.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fastapi/routing.py b/fastapi/routing.py index 188ad22d3dedf..5b5290d744340 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -211,8 +211,7 @@ def get_request_handler( is_coroutine = ( asyncio.iscoroutinefunction(dependant.call) or callable(dependant.call) - and callable(dependant.call) - and inspect.iscoroutinefunction(dependant.call.__call__) + and inspect.iscoroutinefunction(dependant.call.__call__) # type: ignore[operator] ) is_body_form = body_field and isinstance(body_field.field_info, params.Form) if isinstance(response_class, DefaultPlaceholder): From 6421bc12b22245dc44236d3beaaee844a198ba07 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 1 May 2024 11:00:11 +0000 Subject: [PATCH 7/8] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20for?= =?UTF-8?q?mat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/routing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastapi/routing.py b/fastapi/routing.py index 5b5290d744340..d6f105446176a 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -211,7 +211,7 @@ def get_request_handler( is_coroutine = ( asyncio.iscoroutinefunction(dependant.call) or callable(dependant.call) - and inspect.iscoroutinefunction(dependant.call.__call__) # type: ignore[operator] + and inspect.iscoroutinefunction(dependant.call.__call__) # type: ignore[operator] ) is_body_form = body_field and isinstance(body_field.field_info, params.Form) if isinstance(response_class, DefaultPlaceholder): From c6f95103c2d7ac58c991402c8d4da15274ab8ec2 Mon Sep 17 00:00:00 2001 From: reton2 Date: Wed, 1 May 2024 12:07:31 +0100 Subject: [PATCH 8/8] chore: improve coverage, remove unused line --- tests/test_endpoint_decorator.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_endpoint_decorator.py b/tests/test_endpoint_decorator.py index 33f03ec5490f3..ba99fb70cdd29 100644 --- a/tests/test_endpoint_decorator.py +++ b/tests/test_endpoint_decorator.py @@ -32,10 +32,8 @@ def protect(endpoint: Callable[..., Any]): class CustomAPIRoute(APIRoute): def __init__( - self, path: str, endpoint: Callable[..., Any], dependencies=None, **kwargs + self, path: str, endpoint: Callable[..., Any], dependencies, **kwargs ) -> None: - if dependencies is None: - dependencies = [] if isinstance(endpoint, EndpointWrapper) and endpoint.protected: dependencies.append(Depends(dummy_secruity_check)) super().__init__(path, endpoint, dependencies=dependencies, **kwargs)