diff --git a/platformio/home/rpc/server.py b/platformio/home/rpc/server.py index 942961643d..2437e40ec0 100644 --- a/platformio/home/rpc/server.py +++ b/platformio/home/rpc/server.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from urllib.parse import parse_qs + import click -from ajsonrpc.core import JSONRPC20Error +from ajsonrpc.core import JSONRPC20Error, JSONRPC20Request from ajsonrpc.dispatcher import Dispatcher from ajsonrpc.manager import AsyncJSONRPCResponseManager, JSONRPC20Response from starlette.endpoints import WebSocketEndpoint @@ -32,6 +34,7 @@ def __init__(self, shutdown_timeout=0): self.manager = AsyncJSONRPCResponseManager( Dispatcher(), is_server_error_verbose=True ) + self._clients = {} def __call__(self, *args, **kwargs): raise NotImplementedError @@ -40,13 +43,16 @@ def add_object_handler(self, handler, namespace): handler.factory = self self.manager.dispatcher.add_object(handler, prefix="%s." % namespace) - def on_client_connect(self): + def on_client_connect(self, connection, actor=None): + self._clients[connection] = {"actor": actor} self.connection_nums += 1 if self.shutdown_timer: self.shutdown_timer.cancel() self.shutdown_timer = None - def on_client_disconnect(self): + def on_client_disconnect(self, connection): + if connection in self._clients: + del self._clients[connection] self.connection_nums -= 1 if self.connection_nums < 1: self.connection_nums = 0 @@ -69,6 +75,14 @@ def _auto_shutdown_server(): self.shutdown_timeout, _auto_shutdown_server ) + async def notify_clients(self, method, params=None, actor=None): + for client, options in self._clients.items(): + if actor and options["actor"] != actor: + continue + request = JSONRPC20Request(method, params, is_notification=True) + await client.send_text(self.manager.serialize(request.body)) + return True + class WebSocketJSONRPCServerFactory(JSONRPCServerFactoryBase): def __call__(self, *args, **kwargs): @@ -83,13 +97,17 @@ class WebSocketJSONRPCServer(WebSocketEndpoint): async def on_connect(self, websocket): await websocket.accept() - self.factory.on_client_connect() # pylint: disable=no-member + qs = parse_qs(self.scope.get("query_string", b"")) + actors = qs.get(b"actor") + self.factory.on_client_connect( # pylint: disable=no-member + websocket, actor=actors[0].decode() if actors else None + ) async def on_receive(self, websocket, data): aio_create_task(self._handle_rpc(websocket, data)) async def on_disconnect(self, websocket, close_code): - self.factory.on_client_disconnect() # pylint: disable=no-member + self.factory.on_client_disconnect(websocket) # pylint: disable=no-member async def _handle_rpc(self, websocket, data): # pylint: disable=no-member