Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion bellows/ash.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,8 +377,13 @@ def connection_lost(self, exc: Exception | None) -> None:
self._cancel_pending_data_frames()
self._ezsp_protocol.connection_lost(exc)

def eof_received(self):
def eof_received(self) -> bool:
self._ezsp_protocol.eof_received()
# Return True to prevent the transport from auto-closing. For
# serial-over-TCP connections (ser2net, ESPHome stream_server, etc.)
# the remote end may signal EOF during initialization without
# intending to close, and an auto-close orphans the connection.
return True

def _cancel_pending_data_frames(
self, exc: BaseException = RuntimeError("Connection has been closed")
Expand Down
3 changes: 1 addition & 2 deletions bellows/cli/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ def convert(self, value, param, ctx):
def background(f):
@functools.wraps(f)
def inner(*args, **kwargs):
loop = asyncio.get_event_loop()
loop.run_until_complete(f(*args, **kwargs))
asyncio.run(f(*args, **kwargs))

return inner

Expand Down
20 changes: 18 additions & 2 deletions bellows/ezsp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ def is_tcp_serial_port(self) -> bool:

async def _startup_reset(self) -> None:
"""Start EZSP and reset the stack."""
if self._gw is None:
raise EzspError("Gateway is not connected")

# `zigbeed` resets on startup
if self.is_tcp_serial_port:
try:
Expand Down Expand Up @@ -220,8 +223,21 @@ async def get_xncp_features(self) -> xncp.FirmwareFeatures:

async def disconnect(self):
self.stop_ezsp()
if self._gw:
await self._gw.disconnect()
if self._gw is not None:
try:
await self._gw.disconnect()
except ConnectionError:
# The secondary event loop is dead. Force-close the
# underlying TCP socket so ser2net (or similar) releases
# the serial port for subsequent connection attempts.
try:
ash = self._gw._obj._transport
if ash is not None and ash._transport is not None:
sock = ash._transport.get_extra_info("socket")
if sock is not None:
sock.close()
except Exception:
pass
self._gw = None

async def _command(self, name: str, *args: Any, **kwargs: Any) -> Any:
Expand Down
25 changes: 18 additions & 7 deletions bellows/thread.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
import functools
import inspect
import logging

LOGGER = logging.getLogger(__name__)
Expand All @@ -14,7 +15,7 @@ def __init__(self):
self.thread_complete = None

def run_coroutine_threadsafe(self, coroutine):
current_loop = asyncio.get_event_loop()
current_loop = asyncio.get_running_loop()
future = asyncio.run_coroutine_threadsafe(coroutine, self.loop)
return asyncio.wrap_future(future, loop=current_loop)

Expand All @@ -30,7 +31,7 @@ def _thread_main(self, init_task):
self.loop = None

async def start(self):
current_loop = asyncio.get_event_loop()
current_loop = asyncio.get_running_loop()
if self.loop is not None and not self.loop.is_closed():
return

Expand Down Expand Up @@ -95,11 +96,21 @@ def func_wrapper(*args, **kwargs):
if loop == curr_loop:
return call()
if loop.is_closed():
# Disconnected
LOGGER.warning("Attempted to use a closed event loop")
return
if asyncio.iscoroutinefunction(func):
future = asyncio.run_coroutine_threadsafe(call(), loop)
raise ConnectionError(
"Attempted to use a closed event loop, "
"the connection may have been lost"
)
if inspect.iscoroutinefunction(func):
coro = call()
try:
future = asyncio.run_coroutine_threadsafe(coro, loop)
except RuntimeError:
# Loop closed between is_closed() check and dispatch
coro.close()
raise ConnectionError(
"Attempted to use a closed event loop, "
"the connection may have been lost"
)
return asyncio.wrap_future(future, loop=curr_loop)
else:

Expand Down
15 changes: 10 additions & 5 deletions bellows/uart.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def error_received(self, code: t.NcpResetCode) -> None:

async def wait_for_startup_reset(self) -> None:
"""Wait for the first reset frame on startup."""
assert self._startup_reset_future is None
self._startup_reset_future = asyncio.get_running_loop().create_future()
if self._startup_reset_future is None:
self._startup_reset_future = asyncio.get_running_loop().create_future()

try:
await self._startup_reset_future
Expand Down Expand Up @@ -98,21 +98,26 @@ async def reset(self):
return await self._reset_future

self._transport.send_reset()
self._reset_future = asyncio.get_event_loop().create_future()
self._reset_future = asyncio.get_running_loop().create_future()
self._reset_future.add_done_callback(self._reset_cleanup)

async with asyncio_timeout(RESET_TIMEOUT):
return await self._reset_future


async def _connect(config, api):
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()

connection_done_future = loop.create_future()

gateway = Gateway(api, connection_done_future)
protocol = AshProtocol(gateway)

# Pre-create the startup reset future before opening the connection so that
# reset frames arriving immediately after connect are captured by
# reset_received() instead of triggering enter_failed_state().
gateway._startup_reset_future = loop.create_future()

if config[zigpy.config.CONF_DEVICE_FLOW_CONTROL] is None:
xon_xoff, rtscts = True, False
else:
Expand All @@ -135,7 +140,7 @@ async def _connect(config, api):

async def connect(config, api, use_thread=True):
if use_thread:
api = ThreadsafeProxy(api, asyncio.get_event_loop())
api = ThreadsafeProxy(api, asyncio.get_running_loop())
thread = EventLoopThread()
await thread.start()
try:
Expand Down
72 changes: 72 additions & 0 deletions tests/test_ezsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,78 @@ async def wait_forever(*args, **kwargs):
assert version_mock.await_count == 1


async def test_startup_reset_gw_none():
"""Test _startup_reset raises EzspError when gateway is None."""
ezsp = make_ezsp(
config={
**DEVICE_CONFIG,
zigpy.config.CONF_DEVICE_PATH: "socket://localhost:1234",
}
)
ezsp._gw = None

with pytest.raises(EzspError, match="Gateway is not connected"):
await ezsp._startup_reset()


async def test_disconnect_gw_none():
"""Test disconnect doesn't raise when gateway is already None."""
ezsp = make_ezsp()
ezsp._gw = None

await ezsp.disconnect() # Should not raise

assert ezsp._gw is None


async def test_disconnect_force_closes_socket_on_connection_error():
"""If the gateway's `disconnect()` raises ConnectionError (the secondary
event loop is dead), force-close the underlying TCP socket so ser2net
or similar serial-over-TCP bridges release the port for subsequent
connection attempts."""
ezsp = make_ezsp()

mock_socket = MagicMock()

mock_asyncio_transport = MagicMock()
mock_asyncio_transport.get_extra_info.return_value = mock_socket

mock_ash_transport = MagicMock()
mock_ash_transport._transport = mock_asyncio_transport

mock_obj = MagicMock()
mock_obj._transport = mock_ash_transport

mock_gw = MagicMock()
mock_gw._obj = mock_obj
mock_gw.disconnect = AsyncMock(side_effect=ConnectionError("loop closed"))
ezsp._gw = mock_gw

await ezsp.disconnect()

mock_asyncio_transport.get_extra_info.assert_called_once_with("socket")
mock_socket.close.assert_called_once()
assert ezsp._gw is None


async def test_disconnect_socket_force_close_swallows_exceptions():
"""When force-closing the underlying TCP socket after a ConnectionError
from `_gw.disconnect()`, any AttributeError or other exception walking
the proxy/transport chain must be swallowed so the integration can
still mark the gateway as None and proceed to retry."""
ezsp = make_ezsp()

# _obj has no `_transport` attribute, so the inner access raises.
mock_gw = MagicMock()
mock_gw._obj = object()
mock_gw.disconnect = AsyncMock(side_effect=ConnectionError("loop closed"))
ezsp._gw = mock_gw

await ezsp.disconnect() # Should not raise

assert ezsp._gw is None


async def test_wait_for_stack_status(ezsp_f):
assert not ezsp_f._stack_status_listeners[t.sl_Status.NETWORK_DOWN]

Expand Down
27 changes: 26 additions & 1 deletion tests/test_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,35 @@ async def test_proxy_loop_closed():
obj = mock.MagicMock()
proxy = ThreadsafeProxy(obj, loop)
loop.close()
proxy.test()
with pytest.raises(ConnectionError, match="closed event loop"):
proxy.test()
assert obj.test.call_count == 0


async def test_proxy_coroutine_loop_closed_mid_dispatch():
"""If the loop closes between the `is_closed()` check and
`run_coroutine_threadsafe()`, the proxy must close the orphaned
coroutine and surface the failure as ConnectionError instead of
leaking an un-awaited coroutine warning."""
loop = asyncio.new_event_loop()

async def fake_coro(): # pragma: no cover - never awaited
return None

obj = mock.MagicMock()
obj.test = fake_coro
proxy = ThreadsafeProxy(obj, loop)

with mock.patch(
"asyncio.run_coroutine_threadsafe",
side_effect=RuntimeError("loop closed"),
):
with pytest.raises(ConnectionError, match="closed event loop"):
proxy.test()

loop.close()


async def test_thread_task_cancellation_after_stop(thread):
loop = asyncio.get_event_loop()
obj = mock.MagicMock()
Expand Down
Loading