Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions serialx/platforms/serial_esphome.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def __init__(
self._read_buffer = bytearray()
self._read_event = asyncio.Event()
self._unsub: Callable[[], None] | None = None
self._instance_subscribed = False

def _on_data(self, msg: aioesphomeapi.SerialProxyDataReceived) -> None:
if msg.instance == self.instance:
Expand All @@ -91,6 +92,7 @@ def open(self) -> None:
"""Open the serial port."""
asyncio.run(self._async_open())
assert self.api is not None
self._subscribe_instance()
self._unsub = self.api.subscribe_serial_proxy_data(self._on_data)

async def _async_open(self) -> None:
Expand All @@ -102,6 +104,27 @@ async def _async_open(self) -> None:
)
await self.api.connect(login=True)

def _subscribe_instance(self) -> None:
"""Subscribe serial proxy streaming for this instance if supported."""
if self.api is None or self._instance_subscribed:
return
subscribe = getattr(self.api, "serial_proxy_subscribe", None)
if subscribe is None:
return
subscribe(self.instance)
self._instance_subscribed = True

def _unsubscribe_instance(self) -> None:
"""Unsubscribe serial proxy streaming for this instance if supported."""
if self.api is None or not self._instance_subscribed:
return
unsubscribe = getattr(self.api, "serial_proxy_unsubscribe", None)
if unsubscribe is None:
self._instance_subscribed = False
return
unsubscribe(self.instance)
self._instance_subscribed = False

def configure_port(self) -> None:
"""Configure the serial port settings."""
assert self.api is not None
Expand Down Expand Up @@ -170,6 +193,7 @@ def close(self) -> None:
self._unsub = None

if self.api is not None:
self._unsubscribe_instance()
asyncio.run(self.api.disconnect())
self.api = None

Expand Down Expand Up @@ -213,6 +237,7 @@ async def _connect( # type: ignore[override]
self._serial.configure_port()

assert self._serial.api is not None
self._serial._subscribe_instance()
self._unsub = self._serial.api.subscribe_serial_proxy_data(self._on_data)

self._protocol.connection_made(self)
Expand Down Expand Up @@ -240,6 +265,7 @@ def close(self) -> None:
self._unsub = None

if self._serial is not None and self._serial.api is not None:
self._serial._unsubscribe_instance()
api = self._serial.api
self._serial.api = None
self._loop.create_task(self._async_close(api))
Expand Down
136 changes: 136 additions & 0 deletions tests/test_serial_esphome.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
"""Tests for ESPHome serial transport behavior."""

from __future__ import annotations

import asyncio

import pytest

pytest.importorskip("aioesphomeapi")

from serialx.platforms import serial_esphome


class _DummyAPIClient:
"""Simple API client test double."""

def __init__(self, *_args, **_kwargs) -> None:
self.connected = False
self.disconnected = False
self.serial_proxy_configure_calls: list[dict] = []
self.serial_proxy_subscribe_calls: list[int] = []
self.serial_proxy_unsubscribe_calls: list[int] = []
self.stream_subscribed = False
self.stream_unsubscribed = False

async def connect(self, *, login: bool = True) -> None:
self.connected = login

async def disconnect(self) -> None:
self.disconnected = True

def serial_proxy_configure(self, **kwargs) -> None:
self.serial_proxy_configure_calls.append(kwargs)

def serial_proxy_subscribe(self, instance: int) -> None:
self.serial_proxy_subscribe_calls.append(instance)

def serial_proxy_unsubscribe(self, instance: int) -> None:
self.serial_proxy_unsubscribe_calls.append(instance)

def subscribe_serial_proxy_data(self, _callback):
self.stream_subscribed = True

def _unsub() -> None:
self.stream_unsubscribed = True

return _unsub

def serial_proxy_write(self, *, instance: int, data: bytes) -> None:
_ = (instance, data)

async def serial_proxy_flush(self, *, instance: int) -> None:
_ = instance


class _DummyAPIClientNoInstanceSubscription:
"""API client without instance subscribe APIs (older aioesphomeapi)."""

def __init__(self, *_args, **_kwargs) -> None:
self.connected = False
self.disconnected = False
self.stream_subscribed = False
self.stream_unsubscribed = False

async def connect(self, *, login: bool = True) -> None:
self.connected = login

async def disconnect(self) -> None:
self.disconnected = True

def serial_proxy_configure(self, **_kwargs) -> None:
return

def subscribe_serial_proxy_data(self, _callback):
self.stream_subscribed = True

def _unsub() -> None:
self.stream_unsubscribed = True

return _unsub

def serial_proxy_write(self, *, instance: int, data: bytes) -> None:
_ = (instance, data)

async def serial_proxy_flush(self, *, instance: int) -> None:
_ = instance


@pytest.mark.asyncio
async def test_transport_subscribes_instance_and_unsubscribes_on_close(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Ensure instance subscription is managed with the transport lifecycle."""
api = _DummyAPIClient()
monkeypatch.setattr(serial_esphome.aioesphomeapi, "APIClient", lambda *_a, **_k: api)

loop = asyncio.get_running_loop()
protocol = asyncio.Protocol()
transport = serial_esphome.ESPHomeSerialTransport(loop=loop, protocol=protocol)

await transport.connect(url="esphome://example-host/1", baudrate=9600)

assert api.connected
assert api.stream_subscribed
assert api.serial_proxy_subscribe_calls == [1]

transport.close()
await asyncio.sleep(0)
await asyncio.sleep(0)

assert api.stream_unsubscribed
assert api.serial_proxy_unsubscribe_calls == [1]
assert api.disconnected


@pytest.mark.asyncio
async def test_transport_works_without_instance_subscribe_api(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Keep compatibility when aioesphomeapi lacks instance subscribe methods."""
api = _DummyAPIClientNoInstanceSubscription()
monkeypatch.setattr(serial_esphome.aioesphomeapi, "APIClient", lambda *_a, **_k: api)

loop = asyncio.get_running_loop()
protocol = asyncio.Protocol()
transport = serial_esphome.ESPHomeSerialTransport(loop=loop, protocol=protocol)

await transport.connect(url="esphome://example-host/2", baudrate=9600)
transport.close()
await asyncio.sleep(0)
await asyncio.sleep(0)

assert api.connected
assert api.stream_subscribed
assert api.stream_unsubscribed
assert api.disconnected
Loading