From 8f265b81693fbffcb374236d08ff772d3347a96e Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Sun, 19 Mar 2023 15:40:58 +0200 Subject: [PATCH] Add convenience for dynamic changes to routing (#2704) --- sanic/app.py | 59 +++++++++++++++++++++++++++++------------ tests/test_late_adds.py | 54 +++++++++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+), 17 deletions(-) create mode 100644 tests/test_late_adds.py diff --git a/sanic/app.py b/sanic/app.py index 1adc273235..9fb7e02701 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -16,7 +16,7 @@ ) from asyncio.futures import Future from collections import defaultdict, deque -from contextlib import suppress +from contextlib import contextmanager, suppress from functools import partial from inspect import isawaitable from os import environ @@ -33,6 +33,7 @@ Deque, Dict, Iterable, + Iterator, List, Optional, Set, @@ -433,14 +434,15 @@ def _apply_route(self, route: FutureRoute) -> List[Route]: ctx = params.pop("route_context") - routes = self.router.add(**params) - if isinstance(routes, Route): - routes = [routes] + with self.amend(): + routes = self.router.add(**params) + if isinstance(routes, Route): + routes = [routes] - for r in routes: - r.extra.websocket = websocket - r.extra.static = params.get("static", False) - r.ctx.__dict__.update(ctx) + for r in routes: + r.extra.websocket = websocket + r.extra.static = params.get("static", False) + r.ctx.__dict__.update(ctx) return routes @@ -449,17 +451,19 @@ def _apply_middleware( middleware: FutureMiddleware, route_names: Optional[List[str]] = None, ): - if route_names: - return self.register_named_middleware( - middleware.middleware, route_names, middleware.attach_to - ) - else: - return self.register_middleware( - middleware.middleware, middleware.attach_to - ) + with self.amend(): + if route_names: + return self.register_named_middleware( + middleware.middleware, route_names, middleware.attach_to + ) + else: + return self.register_middleware( + middleware.middleware, middleware.attach_to + ) def _apply_signal(self, signal: FutureSignal) -> Signal: - return self.signal_router.add(*signal) + with self.amend(): + return self.signal_router.add(*signal) def dispatch( self, @@ -1520,6 +1524,27 @@ def _check_uvloop_conflict(cls) -> None: # Lifecycle # -------------------------------------------------------------------- # + @contextmanager + def amend(self) -> Iterator[None]: + """ + If the application has started, this function allows changes + to be made to add routes, middleware, and signals. + """ + if not self.state.is_started: + yield + else: + do_router = self.router.finalized + do_signal_router = self.signal_router.finalized + if do_router: + self.router.reset() + if do_signal_router: + self.signal_router.reset() + yield + if do_signal_router: + self.signalize(self.config.TOUCHUP) + if do_router: + self.finalize() + def finalize(self): try: self.router.finalize() diff --git a/tests/test_late_adds.py b/tests/test_late_adds.py new file mode 100644 index 0000000000..f7281d3862 --- /dev/null +++ b/tests/test_late_adds.py @@ -0,0 +1,54 @@ +import pytest + +from sanic import Sanic, text + + +@pytest.fixture +def late_app(app: Sanic): + app.config.TOUCHUP = False + app.get("/")(lambda _: text("")) + return app + + +def test_late_route(late_app: Sanic): + @late_app.before_server_start + async def late(app: Sanic): + @app.get("/late") + def handler(_): + return text("late") + + _, response = late_app.test_client.get("/late") + assert response.status_code == 200 + assert response.text == "late" + + +def test_late_middleware(late_app: Sanic): + @late_app.get("/late") + def handler(request): + return text(request.ctx.late) + + @late_app.before_server_start + async def late(app: Sanic): + @app.on_request + def handler(request): + request.ctx.late = "late" + + _, response = late_app.test_client.get("/late") + assert response.status_code == 200 + assert response.text == "late" + + +def test_late_signal(late_app: Sanic): + @late_app.get("/late") + def handler(request): + return text(request.ctx.late) + + @late_app.before_server_start + async def late(app: Sanic): + @app.signal("http.lifecycle.request") + def handler(request): + request.ctx.late = "late" + + _, response = late_app.test_client.get("/late") + assert response.status_code == 200 + assert response.text == "late"