diff --git a/bellows/ash.py b/bellows/ash.py index 48b492ca..bc98423b 100644 --- a/bellows/ash.py +++ b/bellows/ash.py @@ -377,8 +377,13 @@ def connection_lost(self, exc: Exception | None) -> None: self._cancel_pending_data_frames() self._ezsp_protocol.connection_lost(exc) - def eof_received(self): + def eof_received(self) -> bool: self._ezsp_protocol.eof_received() + # Return True to prevent the transport from auto-closing. For + # serial-over-TCP connections (ser2net, ESPHome stream_server, etc.) + # the remote end may signal EOF during initialization without + # intending to close, and an auto-close orphans the connection. + return True def _cancel_pending_data_frames( self, exc: BaseException = RuntimeError("Connection has been closed") diff --git a/bellows/cli/util.py b/bellows/cli/util.py index 76f83511..ebf6b755 100644 --- a/bellows/cli/util.py +++ b/bellows/cli/util.py @@ -35,8 +35,7 @@ def convert(self, value, param, ctx): def background(f): @functools.wraps(f) def inner(*args, **kwargs): - loop = asyncio.get_event_loop() - loop.run_until_complete(f(*args, **kwargs)) + asyncio.run(f(*args, **kwargs)) return inner diff --git a/bellows/ezsp/__init__.py b/bellows/ezsp/__init__.py index 29554c93..828a458c 100644 --- a/bellows/ezsp/__init__.py +++ b/bellows/ezsp/__init__.py @@ -117,6 +117,9 @@ def is_tcp_serial_port(self) -> bool: async def _startup_reset(self) -> None: """Start EZSP and reset the stack.""" + if self._gw is None: + raise EzspError("Gateway is not connected") + # `zigbeed` resets on startup if self.is_tcp_serial_port: try: @@ -220,8 +223,21 @@ async def get_xncp_features(self) -> xncp.FirmwareFeatures: async def disconnect(self): self.stop_ezsp() - if self._gw: - await self._gw.disconnect() + if self._gw is not None: + try: + await self._gw.disconnect() + except ConnectionError: + # The secondary event loop is dead. Force-close the + # underlying TCP socket so ser2net (or similar) releases + # the serial port for subsequent connection attempts. + try: + ash = self._gw._obj._transport + if ash is not None and ash._transport is not None: + sock = ash._transport.get_extra_info("socket") + if sock is not None: + sock.close() + except Exception: + pass self._gw = None async def _command(self, name: str, *args: Any, **kwargs: Any) -> Any: diff --git a/bellows/thread.py b/bellows/thread.py index 4311768d..270402f6 100644 --- a/bellows/thread.py +++ b/bellows/thread.py @@ -1,6 +1,7 @@ import asyncio from concurrent.futures import ThreadPoolExecutor import functools +import inspect import logging LOGGER = logging.getLogger(__name__) @@ -14,7 +15,7 @@ def __init__(self): self.thread_complete = None def run_coroutine_threadsafe(self, coroutine): - current_loop = asyncio.get_event_loop() + current_loop = asyncio.get_running_loop() future = asyncio.run_coroutine_threadsafe(coroutine, self.loop) return asyncio.wrap_future(future, loop=current_loop) @@ -30,7 +31,7 @@ def _thread_main(self, init_task): self.loop = None async def start(self): - current_loop = asyncio.get_event_loop() + current_loop = asyncio.get_running_loop() if self.loop is not None and not self.loop.is_closed(): return @@ -95,11 +96,21 @@ def func_wrapper(*args, **kwargs): if loop == curr_loop: return call() if loop.is_closed(): - # Disconnected - LOGGER.warning("Attempted to use a closed event loop") - return - if asyncio.iscoroutinefunction(func): - future = asyncio.run_coroutine_threadsafe(call(), loop) + raise ConnectionError( + "Attempted to use a closed event loop, " + "the connection may have been lost" + ) + if inspect.iscoroutinefunction(func): + coro = call() + try: + future = asyncio.run_coroutine_threadsafe(coro, loop) + except RuntimeError: + # Loop closed between is_closed() check and dispatch + coro.close() + raise ConnectionError( + "Attempted to use a closed event loop, " + "the connection may have been lost" + ) return asyncio.wrap_future(future, loop=curr_loop) else: diff --git a/bellows/uart.py b/bellows/uart.py index af274dc8..f38b4155 100644 --- a/bellows/uart.py +++ b/bellows/uart.py @@ -52,8 +52,8 @@ def error_received(self, code: t.NcpResetCode) -> None: async def wait_for_startup_reset(self) -> None: """Wait for the first reset frame on startup.""" - assert self._startup_reset_future is None - self._startup_reset_future = asyncio.get_running_loop().create_future() + if self._startup_reset_future is None: + self._startup_reset_future = asyncio.get_running_loop().create_future() try: await self._startup_reset_future @@ -98,7 +98,7 @@ async def reset(self): return await self._reset_future self._transport.send_reset() - self._reset_future = asyncio.get_event_loop().create_future() + self._reset_future = asyncio.get_running_loop().create_future() self._reset_future.add_done_callback(self._reset_cleanup) async with asyncio_timeout(RESET_TIMEOUT): @@ -106,13 +106,18 @@ async def reset(self): async def _connect(config, api): - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() connection_done_future = loop.create_future() gateway = Gateway(api, connection_done_future) protocol = AshProtocol(gateway) + # Pre-create the startup reset future before opening the connection so that + # reset frames arriving immediately after connect are captured by + # reset_received() instead of triggering enter_failed_state(). + gateway._startup_reset_future = loop.create_future() + if config[zigpy.config.CONF_DEVICE_FLOW_CONTROL] is None: xon_xoff, rtscts = True, False else: @@ -135,7 +140,7 @@ async def _connect(config, api): async def connect(config, api, use_thread=True): if use_thread: - api = ThreadsafeProxy(api, asyncio.get_event_loop()) + api = ThreadsafeProxy(api, asyncio.get_running_loop()) thread = EventLoopThread() await thread.start() try: diff --git a/tests/test_ezsp.py b/tests/test_ezsp.py index d548309a..45975943 100644 --- a/tests/test_ezsp.py +++ b/tests/test_ezsp.py @@ -789,6 +789,78 @@ async def wait_forever(*args, **kwargs): assert version_mock.await_count == 1 +async def test_startup_reset_gw_none(): + """Test _startup_reset raises EzspError when gateway is None.""" + ezsp = make_ezsp( + config={ + **DEVICE_CONFIG, + zigpy.config.CONF_DEVICE_PATH: "socket://localhost:1234", + } + ) + ezsp._gw = None + + with pytest.raises(EzspError, match="Gateway is not connected"): + await ezsp._startup_reset() + + +async def test_disconnect_gw_none(): + """Test disconnect doesn't raise when gateway is already None.""" + ezsp = make_ezsp() + ezsp._gw = None + + await ezsp.disconnect() # Should not raise + + assert ezsp._gw is None + + +async def test_disconnect_force_closes_socket_on_connection_error(): + """If the gateway's `disconnect()` raises ConnectionError (the secondary + event loop is dead), force-close the underlying TCP socket so ser2net + or similar serial-over-TCP bridges release the port for subsequent + connection attempts.""" + ezsp = make_ezsp() + + mock_socket = MagicMock() + + mock_asyncio_transport = MagicMock() + mock_asyncio_transport.get_extra_info.return_value = mock_socket + + mock_ash_transport = MagicMock() + mock_ash_transport._transport = mock_asyncio_transport + + mock_obj = MagicMock() + mock_obj._transport = mock_ash_transport + + mock_gw = MagicMock() + mock_gw._obj = mock_obj + mock_gw.disconnect = AsyncMock(side_effect=ConnectionError("loop closed")) + ezsp._gw = mock_gw + + await ezsp.disconnect() + + mock_asyncio_transport.get_extra_info.assert_called_once_with("socket") + mock_socket.close.assert_called_once() + assert ezsp._gw is None + + +async def test_disconnect_socket_force_close_swallows_exceptions(): + """When force-closing the underlying TCP socket after a ConnectionError + from `_gw.disconnect()`, any AttributeError or other exception walking + the proxy/transport chain must be swallowed so the integration can + still mark the gateway as None and proceed to retry.""" + ezsp = make_ezsp() + + # _obj has no `_transport` attribute, so the inner access raises. + mock_gw = MagicMock() + mock_gw._obj = object() + mock_gw.disconnect = AsyncMock(side_effect=ConnectionError("loop closed")) + ezsp._gw = mock_gw + + await ezsp.disconnect() # Should not raise + + assert ezsp._gw is None + + async def test_wait_for_stack_status(ezsp_f): assert not ezsp_f._stack_status_listeners[t.sl_Status.NETWORK_DOWN] diff --git a/tests/test_thread.py b/tests/test_thread.py index 72efa701..056e96ff 100644 --- a/tests/test_thread.py +++ b/tests/test_thread.py @@ -157,10 +157,35 @@ async def test_proxy_loop_closed(): obj = mock.MagicMock() proxy = ThreadsafeProxy(obj, loop) loop.close() - proxy.test() + with pytest.raises(ConnectionError, match="closed event loop"): + proxy.test() assert obj.test.call_count == 0 +async def test_proxy_coroutine_loop_closed_mid_dispatch(): + """If the loop closes between the `is_closed()` check and + `run_coroutine_threadsafe()`, the proxy must close the orphaned + coroutine and surface the failure as ConnectionError instead of + leaking an un-awaited coroutine warning.""" + loop = asyncio.new_event_loop() + + async def fake_coro(): # pragma: no cover - never awaited + return None + + obj = mock.MagicMock() + obj.test = fake_coro + proxy = ThreadsafeProxy(obj, loop) + + with mock.patch( + "asyncio.run_coroutine_threadsafe", + side_effect=RuntimeError("loop closed"), + ): + with pytest.raises(ConnectionError, match="closed event loop"): + proxy.test() + + loop.close() + + async def test_thread_task_cancellation_after_stop(thread): loop = asyncio.get_event_loop() obj = mock.MagicMock()