Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
✨ Update internal
AsyncExitStack
to fix context for dependencies wi…
…th `yield` (#4575)
- Loading branch information
Showing
7 changed files
with
272 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from typing import Optional | ||
|
||
from fastapi.concurrency import AsyncExitStack | ||
from starlette.types import ASGIApp, Receive, Scope, Send | ||
|
||
|
||
class AsyncExitStackMiddleware: | ||
def __init__(self, app: ASGIApp, context_name: str = "fastapi_astack") -> None: | ||
self.app = app | ||
self.context_name = context_name | ||
|
||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: | ||
if AsyncExitStack: | ||
dependency_exception: Optional[Exception] = None | ||
async with AsyncExitStack() as stack: | ||
scope[self.context_name] = stack | ||
try: | ||
await self.app(scope, receive, send) | ||
except Exception as e: | ||
dependency_exception = e | ||
raise e | ||
if dependency_exception: | ||
# This exception was possibly handled by the dependency but it should | ||
# still bubble up so that the ServerErrorMiddleware can return a 500 | ||
# or the ExceptionMiddleware can catch and handle any other exceptions | ||
raise dependency_exception | ||
else: | ||
await self.app(scope, receive, send) # pragma: no cover |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
from contextvars import ContextVar | ||
from typing import Any, Awaitable, Callable, Dict, Optional | ||
|
||
from fastapi import Depends, FastAPI, Request, Response | ||
from fastapi.testclient import TestClient | ||
|
||
legacy_request_state_context_var: ContextVar[Optional[Dict[str, Any]]] = ContextVar( | ||
"legacy_request_state_context_var", default=None | ||
) | ||
|
||
app = FastAPI() | ||
|
||
|
||
async def set_up_request_state_dependency(): | ||
request_state = {"user": "deadpond"} | ||
contextvar_token = legacy_request_state_context_var.set(request_state) | ||
yield request_state | ||
legacy_request_state_context_var.reset(contextvar_token) | ||
|
||
|
||
@app.middleware("http") | ||
async def custom_middleware( | ||
request: Request, call_next: Callable[[Request], Awaitable[Response]] | ||
): | ||
response = await call_next(request) | ||
response.headers["custom"] = "foo" | ||
return response | ||
|
||
|
||
@app.get("/user", dependencies=[Depends(set_up_request_state_dependency)]) | ||
def get_user(): | ||
request_state = legacy_request_state_context_var.get() | ||
assert request_state | ||
return request_state["user"] | ||
|
||
|
||
client = TestClient(app) | ||
|
||
|
||
def test_dependency_contextvars(): | ||
""" | ||
Check that custom middlewares don't affect the contextvar context for dependencies. | ||
The code before yield and the code after yield should be run in the same contextvar | ||
context, so that request_state_context_var.reset(contextvar_token). | ||
If they are run in a different context, that raises an error. | ||
""" | ||
response = client.get("/user") | ||
assert response.json() == "deadpond" | ||
assert response.headers["custom"] == "foo" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import pytest | ||
from fastapi import Body, Depends, FastAPI, HTTPException | ||
from fastapi.testclient import TestClient | ||
|
||
initial_fake_database = {"rick": "Rick Sanchez"} | ||
|
||
fake_database = initial_fake_database.copy() | ||
|
||
initial_state = {"except": False, "finally": False} | ||
|
||
state = initial_state.copy() | ||
|
||
app = FastAPI() | ||
|
||
|
||
async def get_database(): | ||
temp_database = fake_database.copy() | ||
try: | ||
yield temp_database | ||
fake_database.update(temp_database) | ||
except HTTPException: | ||
state["except"] = True | ||
finally: | ||
state["finally"] = True | ||
|
||
|
||
@app.put("/invalid-user/{user_id}") | ||
def put_invalid_user( | ||
user_id: str, name: str = Body(...), db: dict = Depends(get_database) | ||
): | ||
db[user_id] = name | ||
raise HTTPException(status_code=400, detail="Invalid user") | ||
|
||
|
||
@app.put("/user/{user_id}") | ||
def put_user(user_id: str, name: str = Body(...), db: dict = Depends(get_database)): | ||
db[user_id] = name | ||
return {"message": "OK"} | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def reset_state_and_db(): | ||
global fake_database | ||
global state | ||
fake_database = initial_fake_database.copy() | ||
state = initial_state.copy() | ||
|
||
|
||
client = TestClient(app) | ||
|
||
|
||
def test_dependency_gets_exception(): | ||
assert state["except"] is False | ||
assert state["finally"] is False | ||
response = client.put("/invalid-user/rick", json="Morty") | ||
assert response.status_code == 400, response.text | ||
assert response.json() == {"detail": "Invalid user"} | ||
assert state["except"] is True | ||
assert state["finally"] is True | ||
assert fake_database["rick"] == "Rick Sanchez" | ||
|
||
|
||
def test_dependency_no_exception(): | ||
assert state["except"] is False | ||
assert state["finally"] is False | ||
response = client.put("/user/rick", json="Morty") | ||
assert response.status_code == 200, response.text | ||
assert response.json() == {"message": "OK"} | ||
assert state["except"] is False | ||
assert state["finally"] is True | ||
assert fake_database["rick"] == "Morty" |
Oops, something went wrong.