diff --git a/sanic/app.py b/sanic/app.py index c928f028ce..35ae8266e7 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -1300,6 +1300,48 @@ def shutdown_tasks( self.purge_tasks() timeout -= increment + def shutdown_signal_handlers( + self, timeout: Optional[float] = None, increment: float = 0.1 + ) -> None: + """Cancel running signal handler tasks. + + Any running ``asyncio.Task`` with a name starting with "signal" will be + cancelled. If a :param:`timeout` is not provided, it will be set to the + ``GRACEFUL_SHUTDOWN_TIMEOUT`` config. + + :param timeout: the max amount of time to wait for the tasks to be + cancelled. Defaults to None. + :type timeout: Optional[float], optional + :param increment: the amount of time to wait between checking that the + tasks have been cancelled. Defaults to 0.1. + :type increment: float, optional + """ + logger.info("Cancelling signal handlers") + + if timeout is None: + timeout = self.config.GRACEFUL_SHUTDOWN_TIMEOUT + + signal_handlers = [ + task + for task in asyncio.all_tasks(self.loop) + if task.get_name().startswith("signal") + ] + + logger.debug("%d signal handlers found", len(signal_handlers)) + + for handler in signal_handlers: + logger.debug("Cancelling signal handler: %s", handler.get_name()) + handler.cancel() + + with suppress(RuntimeError): + while timeout and not all( + [handler.done() for handler in signal_handlers] + ): + self.loop.run_until_complete(asyncio.sleep(increment)) + timeout -= increment + + logger.info("Signal handlers cancelled") + @property def tasks(self): return iter(self._task_registry.values()) diff --git a/sanic/mixins/runner.py b/sanic/mixins/runner.py index 1df77e551e..383f94ff79 100644 --- a/sanic/mixins/runner.py +++ b/sanic/mixins/runner.py @@ -367,6 +367,7 @@ def stop(self): """ if self.state.stage is not ServerStage.STOPPED: self.shutdown_tasks(timeout=0) + self.shutdown_signal_handlers() for task in all_tasks(): with suppress(AttributeError): if task.get_name() == "RunServer": diff --git a/sanic/signals.py b/sanic/signals.py index d62a117c52..e0e8749ca7 100644 --- a/sanic/signals.py +++ b/sanic/signals.py @@ -192,7 +192,7 @@ async def dispatch( if inline: return await dispatch - task = asyncio.get_running_loop().create_task(dispatch) + task = asyncio.get_running_loop().create_task(dispatch, name="signal") await asyncio.sleep(0) return task diff --git a/tests/test_graceful_shutdown.py b/tests/test_graceful_shutdown.py index 54ba92d89a..a4466fc82f 100644 --- a/tests/test_graceful_shutdown.py +++ b/tests/test_graceful_shutdown.py @@ -46,3 +46,36 @@ def ping(): "Transport is closed." ) assert info == 11 + + +def test_no_exceptions_when_cancel_signal_handlers(app, caplog): + @app.signal("foo.bar.baz") + async def async_signal(*_): + await asyncio.sleep(5) + + @app.get("/") + async def handler(request): + request.app.dispatch("foo.bar.baz") + + def ping(): + httpx.get("http://127.0.0.1:8000") + + p = Process(target=ping) + p.start() + + with caplog.at_level(logging.INFO): + app.run() + + p.kill() + + info = 0 + for record in caplog.record_tuples: + assert record[1] != logging.ERROR + + if record[1] == logging.INFO and ( + record[2] == "Cancelling signal handlers" + or record[2] == "Signal handlers cancelled" + ): + info += 1 + + assert info == 2 diff --git a/tests/test_signals.py b/tests/test_signals.py index 9835430967..7feed7fcb9 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -402,3 +402,15 @@ def test_signal_reservation(app, event, expected): app.signal(event)(lambda: ...) else: app.signal(event)(lambda: ...) + + +@pytest.mark.asyncio +async def test_signal_handler_task_name(app): + @app.signal("foo.bar.baz") + def sync_signal(*_): + ... + + app.signal_router.finalize() + + signal_handler_task = await app.dispatch("foo.bar.baz") + assert signal_handler_task.get_name() == "signal"