Skip to content

Commit

Permalink
feat: add custom route on route level
Browse files Browse the repository at this point in the history
  • Loading branch information
arkadybag committed Jan 18, 2023
1 parent 5905c3f commit 814d256
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 0 deletions.
17 changes: 17 additions & 0 deletions fastapi/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
)
from fastapi.openapi.utils import get_openapi
from fastapi.params import Depends
from fastapi.routing import APIRoute
from fastapi.types import DecoratedCallable
from fastapi.utils import generate_unique_id
from starlette.applications import Starlette
Expand Down Expand Up @@ -454,6 +455,7 @@ def get(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[routing.APIRoute], str] = Default(
Expand Down Expand Up @@ -481,6 +483,7 @@ def get(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -509,6 +512,7 @@ def put(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[routing.APIRoute], str] = Default(
Expand Down Expand Up @@ -536,6 +540,7 @@ def put(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -564,6 +569,7 @@ def post(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[routing.APIRoute], str] = Default(
Expand Down Expand Up @@ -591,6 +597,7 @@ def post(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -619,6 +626,7 @@ def delete(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[routing.APIRoute], str] = Default(
Expand Down Expand Up @@ -646,6 +654,7 @@ def delete(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -674,6 +683,7 @@ def options(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[routing.APIRoute], str] = Default(
Expand Down Expand Up @@ -701,6 +711,7 @@ def options(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -729,6 +740,7 @@ def head(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[routing.APIRoute], str] = Default(
Expand Down Expand Up @@ -756,6 +768,7 @@ def head(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -784,6 +797,7 @@ def patch(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[routing.APIRoute], str] = Default(
Expand Down Expand Up @@ -811,6 +825,7 @@ def patch(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -839,6 +854,7 @@ def trace(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[routing.APIRoute], str] = Default(
Expand Down Expand Up @@ -866,6 +882,7 @@ def trace(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down
18 changes: 18 additions & 0 deletions fastapi/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,7 @@ def api_route(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
Expand Down Expand Up @@ -658,6 +659,7 @@ def decorator(func: DecoratedCallable) -> DecoratedCallable:
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -822,6 +824,7 @@ def get(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
Expand Down Expand Up @@ -850,6 +853,7 @@ def get(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -878,6 +882,7 @@ def put(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
Expand Down Expand Up @@ -906,6 +911,7 @@ def put(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -934,6 +940,7 @@ def post(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
Expand Down Expand Up @@ -962,6 +969,7 @@ def post(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -990,6 +998,7 @@ def delete(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
Expand Down Expand Up @@ -1018,6 +1027,7 @@ def delete(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -1046,6 +1056,7 @@ def options(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
Expand Down Expand Up @@ -1074,6 +1085,7 @@ def options(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -1102,6 +1114,7 @@ def head(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
Expand Down Expand Up @@ -1130,6 +1143,7 @@ def head(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -1158,6 +1172,7 @@ def patch(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
Expand Down Expand Up @@ -1186,6 +1201,7 @@ def patch(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -1214,6 +1230,7 @@ def trace(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
Expand Down Expand Up @@ -1243,6 +1260,7 @@ def trace(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down
59 changes: 59 additions & 0 deletions tests/test_custom_route_class_for_endpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from typing import Callable
from urllib.request import Request

import pytest
from fastapi import APIRouter, FastAPI, HTTPException, status
from fastapi.openapi.models import Response
from fastapi.routing import APIRoute
from fastapi.testclient import TestClient

app = FastAPI()
router = APIRouter()


class CustomRoute(APIRoute):
def get_route_handler(self) -> Callable:
original_route_handler = super().get_route_handler()

async def custom_route_handler(request: Request) -> Response:
if "test_header" not in request.headers:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
return await original_route_handler(request)

return custom_route_handler


@router.get("/a")
def get_a():
return {"msg": "A"}


@router.get("/b", route_class_override=CustomRoute)
def get_b():
return {"msg": "B"}


app.include_router(router=router, prefix="")


client = TestClient(app)


@pytest.mark.parametrize(
"path,expected_status,headers",
[
("/a", 200, {"test_header": "value"}),
("/a", 200, None),
("/b", 200, {"test_header": "value"}),
("/b", 400, None),
],
ids=[
"/a with test_header header",
"/a without test_header headers",
"/b with test_header headers",
"/b without test_header headers",
],
)
def test_get_path(path, expected_status, headers):
response = client.get(path, headers=headers)
assert response.status_code == expected_status

0 comments on commit 814d256

Please sign in to comment.