diff --git a/backend/middleware.py b/backend/middleware.py index f74091b..3a932fd 100644 --- a/backend/middleware.py +++ b/backend/middleware.py @@ -1,31 +1,37 @@ -import typing as t - import ssl from motor.motor_asyncio import AsyncIOMotorClient -from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request -from starlette.responses import JSONResponse, Response +from starlette.responses import JSONResponse +from starlette.types import ASGIApp, Scope, Receive, Send from backend.constants import DATABASE_URL, DOCS_PASSWORD, MONGO_DATABASE -class DatabaseMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next: t.Callable) -> Response: +class DatabaseMiddleware: + + def __init__(self, app: ASGIApp) -> None: + self._app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: client: AsyncIOMotorClient = AsyncIOMotorClient( DATABASE_URL, ssl_cert_reqs=ssl.CERT_NONE ) db = client[MONGO_DATABASE] - request.state.db = db - response = await call_next(request) - return response + Request(scope).state.db = db + await self._app(scope, receive, send) -class ProtectedDocsMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next: t.Callable) -> Response: +class ProtectedDocsMiddleware: + + def __init__(self, app: ASGIApp) -> None: + self._app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + request = Request(scope) if DOCS_PASSWORD and request.url.path.startswith("/docs"): if request.cookies.get("docs_password") != DOCS_PASSWORD: - return JSONResponse({"status": "unauthorized"}, status_code=403) - - resp = await call_next(request) - return resp + resp = JSONResponse({"status": "unauthorized"}, status_code=403) + await resp(scope, receive, send) + return + await self._app(scope, receive, send)