From edcae93e0a9a2e7981581ea14f8739fe5beee4e7 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Mon, 18 Dec 2023 12:19:34 +0200 Subject: [PATCH 1/4] Add Hub Mode for inspector --- sanic/config.py | 14 ++--- sanic/mixins/startup.py | 57 +++++++------------- sanic/worker/inspector.py | 110 ++++++++++++++++++++++++++++++++++++-- 3 files changed, 133 insertions(+), 48 deletions(-) diff --git a/sanic/config.py b/sanic/config.py index 305949fb65..e46fd3c9aa 100644 --- a/sanic/config.py +++ b/sanic/config.py @@ -50,6 +50,9 @@ "INSPECTOR_TLS_KEY": _default, "INSPECTOR_TLS_CERT": _default, "INSPECTOR_API_KEY": "", + "INSPECTOR_HUB_MODE": _default, + "INSPECTOR_HUB_HOST": "", + "INSPECTOR_HUB_PORT": 0, "KEEP_ALIVE_TIMEOUT": 120, "KEEP_ALIVE": True, "LOCAL_CERT_CREATOR": LocalCertCreator.AUTO, @@ -108,6 +111,9 @@ class Config(dict, metaclass=DescriptorMeta): INSPECTOR_TLS_KEY: Union[Path, str, Default] INSPECTOR_TLS_CERT: Union[Path, str, Default] INSPECTOR_API_KEY: str + INSPECTOR_HUB_MODE: Union[bool, Default] + INSPECTOR_HUB_HOST: str + INSPECTOR_HUB_PORT: int KEEP_ALIVE_TIMEOUT: int KEEP_ALIVE: bool LOCAL_CERT_CREATOR: Union[str, LocalCertCreator] @@ -135,9 +141,7 @@ class Config(dict, metaclass=DescriptorMeta): def __init__( self, - defaults: Optional[ - Dict[str, Union[str, bool, int, float, None]] - ] = None, + defaults: Optional[Dict[str, Union[str, bool, int, float, None]]] = None, env_prefix: Optional[str] = SANIC_PREFIX, keep_alive: Optional[bool] = None, *, @@ -237,9 +241,7 @@ def _post_set(self, attr, value) -> None: if attr == "LOCAL_CERT_CREATOR" and not isinstance( self.LOCAL_CERT_CREATOR, LocalCertCreator ): - self.LOCAL_CERT_CREATOR = LocalCertCreator[ - self.LOCAL_CERT_CREATOR.upper() - ] + self.LOCAL_CERT_CREATOR = LocalCertCreator[self.LOCAL_CERT_CREATOR.upper()] elif attr == "DEPRECATION_FILTER": self._configure_warnings() diff --git a/sanic/mixins/startup.py b/sanic/mixins/startup.py index 5eb42eca2d..faa2c722cc 100644 --- a/sanic/mixins/startup.py +++ b/sanic/mixins/startup.py @@ -114,8 +114,7 @@ def setup_loop(self) -> None: """ if not self.asgi: if self.config.USE_UVLOOP is True or ( - isinstance(self.config.USE_UVLOOP, Default) - and not OS_IS_WINDOWS + isinstance(self.config.USE_UVLOOP, Default) and not OS_IS_WINDOWS ): try_use_uvloop() elif OS_IS_WINDOWS: @@ -387,8 +386,7 @@ def prepare( if single_process and (fast or (workers > 1) or auto_reload): raise RuntimeError( - "Single process cannot be run with multiple workers " - "or auto-reload" + "Single process cannot be run with multiple workers " "or auto-reload" ) if register_sys_signals is False and not single_process: @@ -407,9 +405,7 @@ def prepare( for directory in reload_dir: direc = Path(directory) if not direc.is_dir(): - logger.warning( - f"Directory {directory} could not be located" - ) + logger.warning(f"Directory {directory} could not be located") self.state.reload_dirs.add(Path(directory)) if loop is not None: @@ -424,9 +420,7 @@ def prepare( host, port = self.get_address(host, port, version, auto_tls) if protocol is None: - protocol = ( - WebSocketProtocol if self.websocket_enabled else HttpProtocol - ) + protocol = WebSocketProtocol if self.websocket_enabled else HttpProtocol # Set explicitly passed configuration values for attribute, value in { @@ -462,9 +456,7 @@ def prepare( register_sys_signals=register_sys_signals, auto_tls=auto_tls, ) - self.state.server_info.append( - ApplicationServerInfo(settings=server_settings) - ) + self.state.server_info.append(ApplicationServerInfo(settings=server_settings)) # if self.config.USE_UVLOOP is True or ( # self.config.USE_UVLOOP is _default and not OS_IS_WINDOWS @@ -560,9 +552,7 @@ async def main(): host, port = host, port = self.get_address(host, port) if protocol is None: - protocol = ( - WebSocketProtocol if self.websocket_enabled else HttpProtocol - ) + protocol = WebSocketProtocol if self.websocket_enabled else HttpProtocol # Set explicitly passed configuration values for attribute, value in { @@ -805,10 +795,7 @@ def get_motd_data( reload_display += ", ".join( [ "", - *( - str(path.absolute()) - for path in self.state.reload_dirs - ), + *(str(path.absolute()) for path in self.state.reload_dirs), ] ) display["auto-reload"] = reload_display @@ -914,9 +901,7 @@ def should_auto_reload(cls) -> bool: @classmethod def _get_startup_method(cls) -> str: return ( - cls.start_method - if not isinstance(cls.start_method, Default) - else "spawn" + cls.start_method if not isinstance(cls.start_method, Default) else "spawn" ) @classmethod @@ -1010,9 +995,7 @@ def serve( try: primary = apps[0] except IndexError: - raise RuntimeError( - "Did not find any applications." - ) from None + raise RuntimeError("Did not find any applications.") from None # This exists primarily for unit testing if not primary.state.server_info: # no cov @@ -1115,9 +1098,7 @@ def serve( inspector = None if primary.config.INSPECTOR: display, extra = primary.get_motd_data() - packages = [ - pkg.strip() for pkg in display["packages"].split(",") - ] + packages = [pkg.strip() for pkg in display["packages"].split(",")] module = import_module("sanic") sanic_version = f"sanic=={module.__version__}" # type: ignore app_info = { @@ -1134,8 +1115,12 @@ def serve( primary.config.INSPECTOR_API_KEY, primary.config.INSPECTOR_TLS_KEY, primary.config.INSPECTOR_TLS_CERT, + primary.config.INSPECTOR_HUB_MODE, + primary.config.INSPECTOR_HUB_HOST, + primary.config.INSPECTOR_HUB_PORT, ) - manager.manage("Inspector", inspector, {}, transient=False) + # TODO: Change back to false + manager.manage("Inspector", inspector, {}, transient=True) primary._inspector = inspector primary._manager = manager @@ -1148,9 +1133,7 @@ def serve( exit_code = 1 except BaseException: kwargs = primary_server_info.settings - error_logger.exception( - "Experienced exception while trying to serve" - ) + error_logger.exception("Experienced exception while trying to serve") raise finally: logger.info("Server Stopped") @@ -1191,9 +1174,7 @@ def serve( @staticmethod def _get_process_states(worker_state) -> List[str]: - return [ - state for s in worker_state.values() if (state := s.get("state")) - ] + return [state for s in worker_state.values() if (state := s.get("state"))] @classmethod def serve_single(cls, primary: Optional[Sanic] = None) -> None: @@ -1274,9 +1255,7 @@ def serve_single(cls, primary: Optional[Sanic] = None) -> None: try: worker_serve(monitor_publisher=None, **kwargs) except BaseException: - error_logger.exception( - "Experienced exception while trying to serve" - ) + error_logger.exception("Experienced exception while trying to serve") raise finally: logger.info("Server Stopped") diff --git a/sanic/worker/inspector.py b/sanic/worker/inspector.py index 524c7bbe40..32ca7c200d 100644 --- a/sanic/worker/inspector.py +++ b/sanic/worker/inspector.py @@ -2,16 +2,59 @@ from datetime import datetime from inspect import isawaitable +from logging import debug from multiprocessing.connection import Connection from os import environ from pathlib import Path -from typing import Any, Dict, Mapping, Union - +from typing import TYPE_CHECKING, Any, Dict, Mapping, Union, Tuple +from asyncio import sleep from sanic.exceptions import Unauthorized -from sanic.helpers import Default +from sanic.helpers import Default, _default from sanic.log import logger from sanic.request import Request from sanic.response import json +from dataclasses import dataclass +from websockets import connection, connect, ConnectionClosed +from sanic.server.websockets.impl import WebsocketImplProtocol + +if TYPE_CHECKING: + from sanic import Sanic + + +@dataclass +class NodeState: + ... + + +@dataclass +class HubState: + nodes: Dict[str, NodeState] + + +class NodeClient: + def __init__(self, hub_host: str, hub_port: int) -> None: + self.hub_host = hub_host + self.hub_port = hub_port + + async def run(self, state_getter) -> None: + try: + async for ws in connect(f"ws://{self.hub_host}:{self.hub_port}/hub"): + try: + await self._run_node(ws, state_getter) + except ConnectionClosed: + continue + except BaseException: + ... + finally: + print("Node out") + + def _setup_ws_client(self, hub_host: str, hub_port: int) -> connection: + return connect(f"ws://{hub_host}:{hub_port}/hub") + + async def _run_node(self, ws: connection, state_getter) -> None: + while True: + await ws.send(str(state_getter())) + await sleep(3) class Inspector: @@ -46,7 +89,13 @@ def __init__( api_key: str, tls_key: Union[Path, str, Default], tls_cert: Union[Path, str, Default], + hub_mode: Union[bool, Default] = _default, + hub_host: str = "", + hub_port: int = 0, ): + hub_mode, node_mode = self._detect_modes( + hub_mode, host, port, hub_host, hub_port + ) self._publisher = publisher self.app_info = app_info self.worker_state = worker_state @@ -55,6 +104,10 @@ def __init__( self.api_key = api_key self.tls_key = tls_key self.tls_cert = tls_cert + self.hub_mode = hub_mode + self.node_mode = node_mode + self.hub_host = hub_host + self.hub_port = hub_port def __call__(self, run=True, **_) -> Inspector: from sanic import Sanic @@ -70,14 +123,44 @@ def __call__(self, run=True, **_) -> Inspector: if not isinstance(self.tls_key, Default) and not isinstance(self.tls_cert, Default) else None, + debug=True, ) return self + def _detect_modes( + self, + hub_mode: Union[bool, Default], + host: str, + port: int, + hub_host: str, + hub_port: int, + ) -> Tuple[bool, bool]: + print(hub_mode, host, port, hub_host, hub_port) + if hub_host == host and hub_port == port: + if not hub_mode: + raise ValueError( + "Hub mode must be enabled when using the same host and port" + ) + hub_mode = True + if (hub_host and not hub_port) or (hub_port and not hub_host): + raise ValueError("Both hub host and hub port must be specified") + if hub_mode is True: + return True, False + elif hub_host and hub_port: + return False, True + else: + return False, False + def _setup(self): self.app.get("/")(self._info) self.app.post("/")(self._action) if self.api_key: self.app.on_request(self._authentication) + if self.hub_mode: + self.app.before_server_start(self._setup_hub) + self.app.websocket("/hub")(self._hub) + if self.node_mode: + self.app.before_server_start(self._run_node) environ["SANIC_IGNORE_PRODUCTION_WARNING"] = "true" def _authentication(self, request: Request) -> None: @@ -154,3 +237,24 @@ def shutdown(self) -> None: """Shutdown the workers""" message = "__TERMINATE__" self._publisher.send(message) + + def _setup_hub(self, app: Sanic) -> None: + app.ctx.hub_state = HubState(nodes={}) + + @staticmethod + async def _hub( + request: Request, + websocket: WebsocketImplProtocol, + ) -> None: + hub_state = request.app.ctx.hub_state + hub_state.nodes[request.id] = NodeState() + while True: + message = await websocket.recv() + if message == "ping": + await websocket.send("pong") + else: + logger.info("Hub received message: %s", message) + + async def _run_node(self, app: Sanic) -> None: + client = NodeClient(self.hub_host, self.hub_port) + app.add_task(client.run(self._state_to_json)) From 2f3ae12fbb1ad79ba15097de8733bcaed2cf6e8a Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Mon, 18 Dec 2023 12:24:28 +0200 Subject: [PATCH 2/4] Fix error --- sanic/config.py | 8 ++++-- sanic/mixins/startup.py | 51 +++++++++++++++++++++++++++++---------- sanic/worker/inspector.py | 19 +++++++++------ 3 files changed, 56 insertions(+), 22 deletions(-) diff --git a/sanic/config.py b/sanic/config.py index e46fd3c9aa..f0bc7c2731 100644 --- a/sanic/config.py +++ b/sanic/config.py @@ -141,7 +141,9 @@ class Config(dict, metaclass=DescriptorMeta): def __init__( self, - defaults: Optional[Dict[str, Union[str, bool, int, float, None]]] = None, + defaults: Optional[ + Dict[str, Union[str, bool, int, float, None]] + ] = None, env_prefix: Optional[str] = SANIC_PREFIX, keep_alive: Optional[bool] = None, *, @@ -241,7 +243,9 @@ def _post_set(self, attr, value) -> None: if attr == "LOCAL_CERT_CREATOR" and not isinstance( self.LOCAL_CERT_CREATOR, LocalCertCreator ): - self.LOCAL_CERT_CREATOR = LocalCertCreator[self.LOCAL_CERT_CREATOR.upper()] + self.LOCAL_CERT_CREATOR = LocalCertCreator[ + self.LOCAL_CERT_CREATOR.upper() + ] elif attr == "DEPRECATION_FILTER": self._configure_warnings() diff --git a/sanic/mixins/startup.py b/sanic/mixins/startup.py index faa2c722cc..01cdf9a030 100644 --- a/sanic/mixins/startup.py +++ b/sanic/mixins/startup.py @@ -114,7 +114,8 @@ def setup_loop(self) -> None: """ if not self.asgi: if self.config.USE_UVLOOP is True or ( - isinstance(self.config.USE_UVLOOP, Default) and not OS_IS_WINDOWS + isinstance(self.config.USE_UVLOOP, Default) + and not OS_IS_WINDOWS ): try_use_uvloop() elif OS_IS_WINDOWS: @@ -386,7 +387,8 @@ def prepare( if single_process and (fast or (workers > 1) or auto_reload): raise RuntimeError( - "Single process cannot be run with multiple workers " "or auto-reload" + "Single process cannot be run with multiple workers " + "or auto-reload" ) if register_sys_signals is False and not single_process: @@ -405,7 +407,9 @@ def prepare( for directory in reload_dir: direc = Path(directory) if not direc.is_dir(): - logger.warning(f"Directory {directory} could not be located") + logger.warning( + f"Directory {directory} could not be located" + ) self.state.reload_dirs.add(Path(directory)) if loop is not None: @@ -420,7 +424,9 @@ def prepare( host, port = self.get_address(host, port, version, auto_tls) if protocol is None: - protocol = WebSocketProtocol if self.websocket_enabled else HttpProtocol + protocol = ( + WebSocketProtocol if self.websocket_enabled else HttpProtocol + ) # Set explicitly passed configuration values for attribute, value in { @@ -456,7 +462,9 @@ def prepare( register_sys_signals=register_sys_signals, auto_tls=auto_tls, ) - self.state.server_info.append(ApplicationServerInfo(settings=server_settings)) + self.state.server_info.append( + ApplicationServerInfo(settings=server_settings) + ) # if self.config.USE_UVLOOP is True or ( # self.config.USE_UVLOOP is _default and not OS_IS_WINDOWS @@ -552,7 +560,9 @@ async def main(): host, port = host, port = self.get_address(host, port) if protocol is None: - protocol = WebSocketProtocol if self.websocket_enabled else HttpProtocol + protocol = ( + WebSocketProtocol if self.websocket_enabled else HttpProtocol + ) # Set explicitly passed configuration values for attribute, value in { @@ -795,7 +805,10 @@ def get_motd_data( reload_display += ", ".join( [ "", - *(str(path.absolute()) for path in self.state.reload_dirs), + *( + str(path.absolute()) + for path in self.state.reload_dirs + ), ] ) display["auto-reload"] = reload_display @@ -901,7 +914,9 @@ def should_auto_reload(cls) -> bool: @classmethod def _get_startup_method(cls) -> str: return ( - cls.start_method if not isinstance(cls.start_method, Default) else "spawn" + cls.start_method + if not isinstance(cls.start_method, Default) + else "spawn" ) @classmethod @@ -995,7 +1010,9 @@ def serve( try: primary = apps[0] except IndexError: - raise RuntimeError("Did not find any applications.") from None + raise RuntimeError( + "Did not find any applications." + ) from None # This exists primarily for unit testing if not primary.state.server_info: # no cov @@ -1098,7 +1115,9 @@ def serve( inspector = None if primary.config.INSPECTOR: display, extra = primary.get_motd_data() - packages = [pkg.strip() for pkg in display["packages"].split(",")] + packages = [ + pkg.strip() for pkg in display["packages"].split(",") + ] module = import_module("sanic") sanic_version = f"sanic=={module.__version__}" # type: ignore app_info = { @@ -1133,7 +1152,9 @@ def serve( exit_code = 1 except BaseException: kwargs = primary_server_info.settings - error_logger.exception("Experienced exception while trying to serve") + error_logger.exception( + "Experienced exception while trying to serve" + ) raise finally: logger.info("Server Stopped") @@ -1174,7 +1195,9 @@ def serve( @staticmethod def _get_process_states(worker_state) -> List[str]: - return [state for s in worker_state.values() if (state := s.get("state"))] + return [ + state for s in worker_state.values() if (state := s.get("state")) + ] @classmethod def serve_single(cls, primary: Optional[Sanic] = None) -> None: @@ -1255,7 +1278,9 @@ def serve_single(cls, primary: Optional[Sanic] = None) -> None: try: worker_serve(monitor_publisher=None, **kwargs) except BaseException: - error_logger.exception("Experienced exception while trying to serve") + error_logger.exception( + "Experienced exception while trying to serve" + ) raise finally: logger.info("Server Stopped") diff --git a/sanic/worker/inspector.py b/sanic/worker/inspector.py index 32ca7c200d..0958752ac3 100644 --- a/sanic/worker/inspector.py +++ b/sanic/worker/inspector.py @@ -1,22 +1,24 @@ from __future__ import annotations +from asyncio import sleep +from dataclasses import dataclass from datetime import datetime from inspect import isawaitable -from logging import debug from multiprocessing.connection import Connection from os import environ from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Mapping, Union, Tuple -from asyncio import sleep +from typing import TYPE_CHECKING, Any, Dict, Mapping, Tuple, Union + +from websockets import ConnectionClosed, connect, connection + from sanic.exceptions import Unauthorized from sanic.helpers import Default, _default from sanic.log import logger from sanic.request import Request from sanic.response import json -from dataclasses import dataclass -from websockets import connection, connect, ConnectionClosed from sanic.server.websockets.impl import WebsocketImplProtocol + if TYPE_CHECKING: from sanic import Sanic @@ -38,7 +40,9 @@ def __init__(self, hub_host: str, hub_port: int) -> None: async def run(self, state_getter) -> None: try: - async for ws in connect(f"ws://{self.hub_host}:{self.hub_port}/hub"): + async for ws in connect( + f"ws://{self.hub_host}:{self.hub_port}/hub" + ): try: await self._run_node(ws, state_getter) except ConnectionClosed: @@ -139,7 +143,8 @@ def _detect_modes( if hub_host == host and hub_port == port: if not hub_mode: raise ValueError( - "Hub mode must be enabled when using the same host and port" + "Hub mode must be enabled when using the same " + "host and port for the hub and the inspector" ) hub_mode = True if (hub_host and not hub_port) or (hub_port and not hub_host): From 69ed169f00d16a8493f18512ee91ca8b1aeb6b6e Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Tue, 26 Dec 2023 08:46:37 +0200 Subject: [PATCH 3/4] Change ws connection close --- sanic/worker/inspector.py | 36 ++++++++++++++++++++++++++---------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/sanic/worker/inspector.py b/sanic/worker/inspector.py index 0958752ac3..04e9b24d86 100644 --- a/sanic/worker/inspector.py +++ b/sanic/worker/inspector.py @@ -1,6 +1,6 @@ from __future__ import annotations -from asyncio import sleep +from asyncio import sleep, run as run_async from dataclasses import dataclass from datetime import datetime from inspect import isawaitable @@ -8,6 +8,7 @@ from os import environ from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, Mapping, Tuple, Union +from signal import SIGTERM, SIGINT, signal from websockets import ConnectionClosed, connect, connection @@ -37,16 +38,18 @@ class NodeClient: def __init__(self, hub_host: str, hub_port: int) -> None: self.hub_host = hub_host self.hub_port = hub_port + self._run = True async def run(self, state_getter) -> None: try: - async for ws in connect( - f"ws://{self.hub_host}:{self.hub_port}/hub" - ): + async for ws in connect(f"ws://{self.hub_host}:{self.hub_port}/hub"): + # async with connect(f"ws://{self.hub_host}:{self.hub_port}/hub") as ws: try: - await self._run_node(ws, state_getter) + close = await self._run_node(ws, state_getter) + if close: + ... except ConnectionClosed: - continue + ... except BaseException: ... finally: @@ -56,9 +59,13 @@ def _setup_ws_client(self, hub_host: str, hub_port: int) -> connection: return connect(f"ws://{hub_host}:{hub_port}/hub") async def _run_node(self, ws: connection, state_getter) -> None: - while True: + while self._run: await ws.send(str(state_getter())) await sleep(3) + return True + + def close(self, *args): + self._run = False class Inspector: @@ -118,7 +125,9 @@ def __call__(self, run=True, **_) -> Inspector: self.app = Sanic("Inspector") self._setup() - if run: + if self.node_mode: + run_async(self._run_node()) + elif run: self.app.run( host=self.host, port=self.port, @@ -260,6 +269,13 @@ async def _hub( else: logger.info("Hub received message: %s", message) - async def _run_node(self, app: Sanic) -> None: + async def _run_node(self) -> None: client = NodeClient(self.hub_host, self.hub_port) - app.add_task(client.run(self._state_to_json)) + + def signal_close(*args, **kwargs): + client.close() + + signal(SIGTERM, signal_close) + signal(SIGINT, signal_close) + + await client.run(self._state_to_json) From 44166ad45430341b2c814c1a4eb1b0796e431d23 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Sun, 23 Jun 2024 10:54:23 +0300 Subject: [PATCH 4/4] Pass Node state --- sanic/cli/inspector_client.py | 25 ++++- sanic/worker/inspector.py | 170 ++++++++++++++++++++++++++++------ 2 files changed, 165 insertions(+), 30 deletions(-) diff --git a/sanic/cli/inspector_client.py b/sanic/cli/inspector_client.py index fd22bbd869..f2c1e2bbb4 100644 --- a/sanic/cli/inspector_client.py +++ b/sanic/cli/inspector_client.py @@ -55,20 +55,39 @@ def do(self, action: str, **kwargs: Any) -> None: sys.stdout.write(out + "\n") def info(self) -> None: - out = sys.stdout.write response = self.request("", "GET") if self.raw or not response: return data = response["result"] display = data.pop("info") + nodes = data.pop("nodes", {}) + self._display_info(display) + self._display_workers(data["workers"], None if not nodes else "Hub") + if nodes: + for name, node in nodes.items(): + # info = node.pop("info") + workers = node.pop("workers") + # self._display_info(info) + self._display_workers(workers, name) + + def _display_info(self, display: Dict[str, Any]) -> None: extra = display.pop("extra", {}) + out = sys.stdout.write display["packages"] = ", ".join(display["packages"]) MOTDTTY(get_logo(), self.base_url, display, extra).display( version=False, action="Inspecting", out=out, ) - for name, info in data["workers"].items(): + + def _display_workers( + self, workers: Dict[str, Dict[str, Any]], node: Optional[str] = None + ) -> None: + out = sys.stdout.write + for name, info in workers.items(): + name = f"{Colors.BOLD}{Colors.SANIC}{name}{Colors.END}" + if node: + name += f" {Colors.BOLD}{Colors.YELLOW}({node}){Colors.END}" info = "\n".join( f"\t{key}: {Colors.BLUE}{value}{Colors.END}" for key, value in info.items() @@ -78,7 +97,7 @@ def info(self) -> None: + indent( "\n".join( [ - f"{Colors.BOLD}{Colors.SANIC}{name}{Colors.END}", + name, info, ] ), diff --git a/sanic/worker/inspector.py b/sanic/worker/inspector.py index 04e9b24d86..b143314fb9 100644 --- a/sanic/worker/inspector.py +++ b/sanic/worker/inspector.py @@ -1,16 +1,36 @@ from __future__ import annotations -from asyncio import sleep, run as run_async -from dataclasses import dataclass +import random + +from asyncio import ( + Task, + get_running_loop, + sleep, +) +from asyncio import ( + run as run_async, +) +from dataclasses import asdict, dataclass from datetime import datetime from inspect import isawaitable from multiprocessing.connection import Connection from os import environ from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Mapping, Tuple, Union -from signal import SIGTERM, SIGINT, signal - -from websockets import ConnectionClosed, connect, connection +from signal import SIGINT, SIGTERM, signal +from string import ascii_lowercase, ascii_uppercase +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Dict, + Mapping, + Optional, + Tuple, + Union, +) + +from websockets import WebSocketException, connection +from websockets.legacy.client import Connect, WebSocketClientProtocol from sanic.exceptions import Unauthorized from sanic.helpers import Default, _default @@ -20,13 +40,21 @@ from sanic.server.websockets.impl import WebsocketImplProtocol +try: + from ujson import dumps as dump_json + from ujson import loads as load_json +except ImportError: + from json import dumps as dump_json + from json import loads as load_json + if TYPE_CHECKING: from sanic import Sanic @dataclass class NodeState: - ... + info: Dict[str, Any] + workers: Dict[str, Any] @dataclass @@ -34,38 +62,107 @@ class HubState: nodes: Dict[str, NodeState] +class HubConnection(Connect): + MAX_RETRIES = 6 + BACKOFF_MAX = 15 + + async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]: + backoff_delay = self.BACKOFF_MIN + failures = 0 + while True: + if failures >= self.MAX_RETRIES: + raise RuntimeError( + "Could not connect to bridge " + f"after {self.MAX_RETRIES} retries" + ) + try: + async with self as protocol: + if failures > 0: + self.logger.info( + "! connect succeeded after %d failures", failures + ) + failures = 0 + yield protocol + except Exception: + # Add a random initial delay between 0 and 5 seconds. + # See 7.2.3. Recovering from Abnormal Closure in RFC 6455. + if backoff_delay == self.BACKOFF_MIN: + initial_delay = random.random() * self.BACKOFF_INITIAL + self.logger.info( + "! connect failed; reconnecting in %.1f seconds", + initial_delay, + ) + self.logger.debug("Exception", exc_info=True) + await sleep(initial_delay) + else: + self.logger.info( + "! connect failed again; retrying in %d seconds", + int(backoff_delay), + ) + self.logger.debug("Exception", exc_info=True) + await sleep(int(backoff_delay)) + # Increase delay with truncated exponential backoff. + backoff_delay = backoff_delay * self.BACKOFF_FACTOR + backoff_delay = min(backoff_delay, self.BACKOFF_MAX) + failures += 1 + continue + else: + # Connection succeeded - reset backoff delay + backoff_delay = self.BACKOFF_MIN + + class NodeClient: def __init__(self, hub_host: str, hub_port: int) -> None: self.hub_host = hub_host self.hub_port = hub_port self._run = True + self._heartbeat_task: Optional[Task] = None + self._command_task: Optional[Task] = None async def run(self, state_getter) -> None: + loop = get_running_loop() try: - async for ws in connect(f"ws://{self.hub_host}:{self.hub_port}/hub"): - # async with connect(f"ws://{self.hub_host}:{self.hub_port}/hub") as ws: + async for ws in HubConnection( + f"ws://{self.hub_host}:{self.hub_port}/hub" + ): try: - close = await self._run_node(ws, state_getter) - if close: - ... - except ConnectionClosed: - ... - except BaseException: - ... + self._cancel_tasks() + self._heartbeat_task = loop.create_task( + self._heartbeat(ws, state_getter) + ) + self._command_task = loop.create_task(self._command(ws)) + while self._run: + await sleep(1) + except WebSocketException: + logger.debug("Connection to hub dropped") + finally: + if not self._run: + break finally: - print("Node out") - - def _setup_ws_client(self, hub_host: str, hub_port: int) -> connection: - return connect(f"ws://{hub_host}:{hub_port}/hub") + self._cancel_tasks() + logger.debug("Node client shutting down") - async def _run_node(self, ws: connection, state_getter) -> None: + async def _heartbeat(self, ws: connection, state_getter) -> None: while self._run: - await ws.send(str(state_getter())) + await ws.send(dump_json(state_getter())) await sleep(3) - return True + + async def _command(self, ws: connection) -> None: + while self._run: + message = await ws.recv() + logger.info("Node received message: %s", message) + + def _cancel_tasks(self) -> None: + if self._heartbeat_task: + self._heartbeat_task.cancel() + self._heartbeat_task = None + if self._command_task: + self._command_task.cancel() + self._command_task = None def close(self, *args): self._run = False + self._cancel_tasks() class Inspector: @@ -148,7 +245,6 @@ def _detect_modes( hub_host: str, hub_port: int, ) -> Tuple[bool, bool]: - print(hub_mode, host, port, hub_host, hub_port) if hub_host == host and hub_port == port: if not hub_mode: raise ValueError( @@ -206,6 +302,11 @@ async def _respond(self, request: Request, output: Any): def _state_to_json(self) -> Dict[str, Any]: output = {"info": self.app_info} output["workers"] = self._make_safe(dict(self.worker_state)) + if self.hub_mode: + output["nodes"] = { + ident: self._make_safe(asdict(node)) + for ident, node in self.app.ctx.hub_state.nodes.items() + } return output @staticmethod @@ -253,21 +354,29 @@ def shutdown(self) -> None: self._publisher.send(message) def _setup_hub(self, app: Sanic) -> None: + logger.info( + f"Sanic Inspector running in hub mode on {self.host}:{self.port}" + ) app.ctx.hub_state = HubState(nodes={}) - @staticmethod async def _hub( + self, request: Request, websocket: WebsocketImplProtocol, ) -> None: hub_state = request.app.ctx.hub_state - hub_state.nodes[request.id] = NodeState() + ident = self._generate_ident() + hub_state.nodes[ident] = NodeState({}, {}) while True: message = await websocket.recv() if message == "ping": await websocket.send("pong") + elif not message: + break else: - logger.info("Hub received message: %s", message) + raw = load_json(message) + node_state = NodeState(**raw) + hub_state.nodes[ident] = node_state async def _run_node(self) -> None: client = NodeClient(self.hub_host, self.hub_port) @@ -278,4 +387,11 @@ def signal_close(*args, **kwargs): signal(SIGTERM, signal_close) signal(SIGINT, signal_close) + logger.info( + f"Sanic Inspector running in node mode on {self.host}:{self.port}" + ) await client.run(self._state_to_json) + + def _generate_ident(self, length: int = 8) -> str: + base = ascii_lowercase + ascii_uppercase + return "".join(random.choices(base, k=length))