diff --git a/fastapi/routing.py b/fastapi/routing.py index fa1351859fb91..d6f105446176a 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -208,7 +208,11 @@ def get_request_handler( dependency_overrides_provider: Optional[Any] = None, ) -> Callable[[Request], Coroutine[Any, Any, Response]]: assert dependant.call is not None, "dependant.call must be a function" - is_coroutine = asyncio.iscoroutinefunction(dependant.call) + is_coroutine = ( + asyncio.iscoroutinefunction(dependant.call) + or callable(dependant.call) + and inspect.iscoroutinefunction(dependant.call.__call__) # type: ignore[operator] + ) is_body_form = body_field and isinstance(body_field.field_info, params.Form) if isinstance(response_class, DefaultPlaceholder): actual_response_class: Type[Response] = response_class.value diff --git a/tests/test_endpoint_decorator.py b/tests/test_endpoint_decorator.py new file mode 100644 index 0000000000000..ba99fb70cdd29 --- /dev/null +++ b/tests/test_endpoint_decorator.py @@ -0,0 +1,65 @@ +from functools import update_wrapper +from typing import Any, Callable + +from fastapi import Depends, FastAPI +from fastapi.routing import APIRoute +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from fastapi.testclient import TestClient +from starlette.exceptions import HTTPException + + +class EndpointWrapper(Callable[..., Any]): + def __init__(self, endpoint: Callable[..., Any]): + self.endpoint = endpoint + self.protected = False + update_wrapper(self, endpoint) + + async def __call__(self, *args, **kwargs): + return await self.endpoint(*args, **kwargs) + + +def dummy_secruity_check(token: HTTPAuthorizationCredentials = Depends(HTTPBearer())): + if token.credentials != "fake-token": + raise HTTPException(status_code=401, detail="Unauthorized") + + +def protect(endpoint: Callable[..., Any]): + if not isinstance(endpoint, EndpointWrapper): + endpoint = EndpointWrapper(endpoint) + endpoint.protected = True + return endpoint + + +class CustomAPIRoute(APIRoute): + def __init__( + self, path: str, endpoint: Callable[..., Any], dependencies, **kwargs + ) -> None: + if isinstance(endpoint, EndpointWrapper) and endpoint.protected: + dependencies.append(Depends(dummy_secruity_check)) + super().__init__(path, endpoint, dependencies=dependencies, **kwargs) + + +app = FastAPI() + +app.router.route_class = CustomAPIRoute + + +@app.get("/protected") +@protect +async def protected_route(): + return {"message": "This is a protected route"} + + +client = TestClient(app) + + +def test_protected_route(): + response = client.get("/protected") + assert response.status_code == 403 + + response = client.get("/protected", headers={"Authorization": "Bearer some-token"}) + assert response.status_code == 401 + + response = client.get("/protected", headers={"Authorization": "Bearer fake-token"}) + assert response.status_code == 200 + assert response.json() == {"message": "This is a protected route"}