diff --git a/serialx/platforms/serial_esphome.py b/serialx/platforms/serial_esphome.py index 316bf31..24f4d61 100644 --- a/serialx/platforms/serial_esphome.py +++ b/serialx/platforms/serial_esphome.py @@ -81,6 +81,7 @@ def __init__( self._read_buffer = bytearray() self._read_event = asyncio.Event() self._unsub: Callable[[], None] | None = None + self._instance_subscribed = False def _on_data(self, msg: aioesphomeapi.SerialProxyDataReceived) -> None: if msg.instance == self.instance: @@ -91,6 +92,7 @@ def open(self) -> None: """Open the serial port.""" asyncio.run(self._async_open()) assert self.api is not None + self._subscribe_instance() self._unsub = self.api.subscribe_serial_proxy_data(self._on_data) async def _async_open(self) -> None: @@ -102,6 +104,27 @@ async def _async_open(self) -> None: ) await self.api.connect(login=True) + def _subscribe_instance(self) -> None: + """Subscribe serial proxy streaming for this instance if supported.""" + if self.api is None or self._instance_subscribed: + return + subscribe = getattr(self.api, "serial_proxy_subscribe", None) + if subscribe is None: + return + subscribe(self.instance) + self._instance_subscribed = True + + def _unsubscribe_instance(self) -> None: + """Unsubscribe serial proxy streaming for this instance if supported.""" + if self.api is None or not self._instance_subscribed: + return + unsubscribe = getattr(self.api, "serial_proxy_unsubscribe", None) + if unsubscribe is None: + self._instance_subscribed = False + return + unsubscribe(self.instance) + self._instance_subscribed = False + def configure_port(self) -> None: """Configure the serial port settings.""" assert self.api is not None @@ -170,6 +193,7 @@ def close(self) -> None: self._unsub = None if self.api is not None: + self._unsubscribe_instance() asyncio.run(self.api.disconnect()) self.api = None @@ -213,6 +237,7 @@ async def _connect( # type: ignore[override] self._serial.configure_port() assert self._serial.api is not None + self._serial._subscribe_instance() self._unsub = self._serial.api.subscribe_serial_proxy_data(self._on_data) self._protocol.connection_made(self) @@ -240,6 +265,7 @@ def close(self) -> None: self._unsub = None if self._serial is not None and self._serial.api is not None: + self._serial._unsubscribe_instance() api = self._serial.api self._serial.api = None self._loop.create_task(self._async_close(api)) diff --git a/tests/test_serial_esphome.py b/tests/test_serial_esphome.py new file mode 100644 index 0000000..e02f317 --- /dev/null +++ b/tests/test_serial_esphome.py @@ -0,0 +1,136 @@ +"""Tests for ESPHome serial transport behavior.""" + +from __future__ import annotations + +import asyncio + +import pytest + +pytest.importorskip("aioesphomeapi") + +from serialx.platforms import serial_esphome + + +class _DummyAPIClient: + """Simple API client test double.""" + + def __init__(self, *_args, **_kwargs) -> None: + self.connected = False + self.disconnected = False + self.serial_proxy_configure_calls: list[dict] = [] + self.serial_proxy_subscribe_calls: list[int] = [] + self.serial_proxy_unsubscribe_calls: list[int] = [] + self.stream_subscribed = False + self.stream_unsubscribed = False + + async def connect(self, *, login: bool = True) -> None: + self.connected = login + + async def disconnect(self) -> None: + self.disconnected = True + + def serial_proxy_configure(self, **kwargs) -> None: + self.serial_proxy_configure_calls.append(kwargs) + + def serial_proxy_subscribe(self, instance: int) -> None: + self.serial_proxy_subscribe_calls.append(instance) + + def serial_proxy_unsubscribe(self, instance: int) -> None: + self.serial_proxy_unsubscribe_calls.append(instance) + + def subscribe_serial_proxy_data(self, _callback): + self.stream_subscribed = True + + def _unsub() -> None: + self.stream_unsubscribed = True + + return _unsub + + def serial_proxy_write(self, *, instance: int, data: bytes) -> None: + _ = (instance, data) + + async def serial_proxy_flush(self, *, instance: int) -> None: + _ = instance + + +class _DummyAPIClientNoInstanceSubscription: + """API client without instance subscribe APIs (older aioesphomeapi).""" + + def __init__(self, *_args, **_kwargs) -> None: + self.connected = False + self.disconnected = False + self.stream_subscribed = False + self.stream_unsubscribed = False + + async def connect(self, *, login: bool = True) -> None: + self.connected = login + + async def disconnect(self) -> None: + self.disconnected = True + + def serial_proxy_configure(self, **_kwargs) -> None: + return + + def subscribe_serial_proxy_data(self, _callback): + self.stream_subscribed = True + + def _unsub() -> None: + self.stream_unsubscribed = True + + return _unsub + + def serial_proxy_write(self, *, instance: int, data: bytes) -> None: + _ = (instance, data) + + async def serial_proxy_flush(self, *, instance: int) -> None: + _ = instance + + +@pytest.mark.asyncio +async def test_transport_subscribes_instance_and_unsubscribes_on_close( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Ensure instance subscription is managed with the transport lifecycle.""" + api = _DummyAPIClient() + monkeypatch.setattr(serial_esphome.aioesphomeapi, "APIClient", lambda *_a, **_k: api) + + loop = asyncio.get_running_loop() + protocol = asyncio.Protocol() + transport = serial_esphome.ESPHomeSerialTransport(loop=loop, protocol=protocol) + + await transport.connect(url="esphome://example-host/1", baudrate=9600) + + assert api.connected + assert api.stream_subscribed + assert api.serial_proxy_subscribe_calls == [1] + + transport.close() + await asyncio.sleep(0) + await asyncio.sleep(0) + + assert api.stream_unsubscribed + assert api.serial_proxy_unsubscribe_calls == [1] + assert api.disconnected + + +@pytest.mark.asyncio +async def test_transport_works_without_instance_subscribe_api( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Keep compatibility when aioesphomeapi lacks instance subscribe methods.""" + api = _DummyAPIClientNoInstanceSubscription() + monkeypatch.setattr(serial_esphome.aioesphomeapi, "APIClient", lambda *_a, **_k: api) + + loop = asyncio.get_running_loop() + protocol = asyncio.Protocol() + transport = serial_esphome.ESPHomeSerialTransport(loop=loop, protocol=protocol) + + await transport.connect(url="esphome://example-host/2", baudrate=9600) + transport.close() + await asyncio.sleep(0) + await asyncio.sleep(0) + + assert api.connected + assert api.stream_subscribed + assert api.stream_unsubscribed + assert api.disconnected