Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add support for dependencies in WebSocket routes #4534

Merged
merged 12 commits into from Jun 11, 2023
Merged
27 changes: 23 additions & 4 deletions fastapi/applications.py
Expand Up @@ -380,15 +380,34 @@ def decorator(func: DecoratedCallable) -> DecoratedCallable:
return decorator

def add_api_websocket_route(
self, path: str, endpoint: Callable[..., Any], name: Optional[str] = None
self,
path: str,
endpoint: Callable[..., Any],
*,
dependencies: Optional[Sequence[Depends]] = None,
name: Optional[str] = None,
) -> None:
self.router.add_api_websocket_route(path, endpoint, name=name)
self.router.add_api_websocket_route(
path,
endpoint,
dependencies=dependencies,
name=name,
)

def websocket(
self, path: str, name: Optional[str] = None
self,
path: str,
*,
dependencies: Optional[Sequence[Depends]] = None,
name: Optional[str] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.add_api_websocket_route(path, func, name=name)
self.add_api_websocket_route(
path,
func,
dependencies=dependencies,
name=name,
)
return func

return decorator
Expand Down
50 changes: 39 additions & 11 deletions fastapi/routing.py
Expand Up @@ -282,20 +282,27 @@ def __init__(
path: str,
endpoint: Callable[..., Any],
*,
dependencies: Optional[Sequence[params.Depends]] = None,
name: Optional[str] = None,
dependency_overrides_provider: Optional[Any] = None,
) -> None:
self.path = path
self.endpoint = endpoint
self.dependencies = dependencies or []
self.name = get_name(endpoint) if name is None else name
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
self.dependant = get_dependant(path=path, call=self.endpoint)
for depends in self.dependencies[::-1]:
self.dependant.dependencies.insert(
0,
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
)
self.app = websocket_session(
get_websocket_app(
dependant=self.dependant,
dependency_overrides_provider=dependency_overrides_provider,
)
)
self.path_regex, self.path_format, self.param_convertors = compile_path(path)

def matches(self, scope: Scope) -> Tuple[Match, Scope]:
match, child_scope = super().matches(scope)
Expand Down Expand Up @@ -397,10 +404,7 @@ def __init__(
else:
self.response_field = None # type: ignore
self.secure_cloned_response_field = None
if dependencies:
self.dependencies = list(dependencies)
else:
self.dependencies = []
self.dependencies = dependencies or []
self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "")
# if a "form feed" character (page break) is found in the description text,
# truncate description text to the content preceding the first "form feed"
Expand Down Expand Up @@ -491,7 +495,7 @@ def __init__(
), "A path prefix must not end with '/', as the routes will start with '/'"
self.prefix = prefix
self.tags: List[Union[str, Enum]] = tags or []
self.dependencies = list(dependencies or []) or []
self.dependencies = dependencies or []
self.deprecated = deprecated
self.include_in_schema = include_in_schema
self.responses = responses or {}
Expand Down Expand Up @@ -544,7 +548,7 @@ def add_api_route(
current_tags = self.tags.copy()
if tags:
current_tags.extend(tags)
current_dependencies = self.dependencies.copy()
current_dependencies = list(self.dependencies)
if dependencies:
current_dependencies.extend(dependencies)
current_callbacks = self.callbacks.copy()
Expand Down Expand Up @@ -646,21 +650,37 @@ def decorator(func: DecoratedCallable) -> DecoratedCallable:
return decorator

def add_api_websocket_route(
self, path: str, endpoint: Callable[..., Any], name: Optional[str] = None
self,
path: str,
endpoint: Callable[..., Any],
*,
dependencies: Optional[Sequence[params.Depends]] = None,
name: Optional[str] = None,
paulo-raca marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
current_dependencies = list(self.dependencies)
if dependencies:
current_dependencies.extend(dependencies)

route = APIWebSocketRoute(
path,
endpoint=endpoint,
name=name,
dependency_overrides_provider=self.dependency_overrides_provider,
dependencies=current_dependencies,
)
self.routes.append(route)

def websocket(
self, path: str, name: Optional[str] = None
self,
path: str,
*,
dependencies: Optional[Sequence[params.Depends]] = None,
name: Optional[str] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.add_api_websocket_route(path, func, name=name)
self.add_api_websocket_route(
path, func, name=name, dependencies=dependencies
)
return func

return decorator
Expand Down Expand Up @@ -766,8 +786,16 @@ def include_router(
name=route.name,
)
elif isinstance(route, APIWebSocketRoute):
current_dependencies: List[params.Depends] = []
if dependencies:
current_dependencies.extend(dependencies)
if route.dependencies:
current_dependencies.extend(route.dependencies)
self.add_api_websocket_route(
prefix + route.path, route.endpoint, name=route.name
prefix + route.path,
route.endpoint,
dependencies=current_dependencies,
name=route.name,
)
elif isinstance(route, routing.WebSocketRoute):
self.add_websocket_route(
Expand Down
70 changes: 70 additions & 0 deletions tests/test_ws_dependencies.py
@@ -0,0 +1,70 @@
import json
from typing import Optional

from fastapi import APIRouter, Depends, FastAPI, WebSocket
from fastapi.testclient import TestClient


def dependency_list(deps: Optional[str] = None) -> list[str]:
return [deps] if deps else []


def create_dependency(name: str):
def fun(deps: list[str] = Depends(dependency_list)):
print(f"create_dependency.fun({name})")
deps.append(name)

return Depends(fun)


router = APIRouter(dependencies=[create_dependency("router")])
prefix_router = APIRouter(dependencies=[create_dependency("prefix_router")])
app = FastAPI(dependencies=[create_dependency("app")])


@app.websocket("/", dependencies=[create_dependency("index")])
async def index(websocket: WebSocket, deps=Depends(dependency_list)):
await websocket.accept()
await websocket.send_text(json.dumps(deps))
await websocket.close()


@router.websocket("/router", dependencies=[create_dependency("routerindex")])
async def routerindex(websocket: WebSocket, deps=Depends(dependency_list)):
await websocket.accept()
await websocket.send_text(json.dumps(deps))
await websocket.close()


@prefix_router.websocket("/", dependencies=[create_dependency("routerprefixindex")])
async def routerprefixindex(websocket: WebSocket, deps=Depends(dependency_list)):
await websocket.accept()
await websocket.send_text(json.dumps(deps))
await websocket.close()


app.include_router(router, dependencies=[create_dependency("router2")])
app.include_router(
prefix_router, prefix="/prefix", dependencies=[create_dependency("prefix_router2")]
)


def test_index():
client = TestClient(app)
with client.websocket_connect("/") as websocket:
data = json.loads(websocket.receive_text())
assert data == ["app", "index"]


def test_routerindex():
client = TestClient(app)
with client.websocket_connect("/router") as websocket:
data = json.loads(websocket.receive_text())
assert data == ["app", "router2", "router", "routerindex"]


def test_routerprefixindex():
client = TestClient(app)
with client.websocket_connect("/prefix/") as websocket:
data = json.loads(websocket.receive_text())
assert data == ["app", "prefix_router2", "prefix_router", "routerprefixindex"]