Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 174 additions & 0 deletions tests/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from zha.zigbee.device import (
ClusterBinding,
DeviceFirmwareInfoUpdatedEvent,
DeviceStatus,
ZHAEvent,
get_device_automation_triggers,
)
Expand Down Expand Up @@ -613,6 +614,59 @@ async def test_issue_cluster_command(

assert cluster.request.await_count == 1

cluster.request.reset_mock()

# Issue being validated:
# issue_cluster_command() accepts a manufacturer argument but does not forward
# it to the underlying cluster command invocation.
#
# Why this is a problem:
# manufacturer-specific commands can be encoded incorrectly (or treated as
# non-manufacturer-specific), leading to silent command failures on devices
# that require manufacturer framing.
await zha_device.issue_cluster_command(
3,
general.OnOff.cluster_id,
general.OnOff.ServerCommandDefs.on.id,
CLUSTER_COMMAND_SERVER,
None,
{},
manufacturer=0x1234,
)
assert cluster.request.await_count == 1
assert cluster.request.await_args.kwargs["manufacturer"] == 0x1234


async def test_issue_cluster_command_args_path_forwards_manufacturer(
    zha_gateway: Gateway,
) -> None:
    """Verify the deprecated ``args`` path still forwards the manufacturer code."""
    raw_device = zigpy_device(zha_gateway, with_basic_cluster_handler=True)
    device = await join_zigpy_device(zha_gateway, raw_device)
    on_off_cluster = raw_device.endpoints[3].on_off

    with patch("zigpy.zcl.Cluster.request", return_value=[0x5, Status.SUCCESS]):
        # Issue being validated:
        # The deprecated `args` code path in issue_cluster_command() can bypass
        # manufacturer forwarding even when a manufacturer code is provided.
        #
        # Why this is a problem:
        # Integrations that still use the args path for compatibility may emit
        # manufacturer-specific commands without manufacturer framing, causing
        # hard-to-diagnose command failures on affected devices.
        await device.issue_cluster_command(
            3,
            general.OnOff.cluster_id,
            general.OnOff.ServerCommandDefs.on.id,
            CLUSTER_COMMAND_SERVER,
            [],
            None,
            manufacturer=0x1234,
        )

        # The patched zigpy request must have seen exactly one call carrying
        # the manufacturer keyword through to the radio layer.
        request_mock = on_off_cluster.request
        assert request_mock.await_count == 1
        assert request_mock.await_args.kwargs["manufacturer"] == 0x1234


async def test_async_add_to_group_remove_from_group(
zha_gateway: Gateway,
Expand Down Expand Up @@ -1358,3 +1412,123 @@ async def test_device_on_remove_pending_entity_failure(

assert "Failed to remove pending entity" in caplog.text
assert "Pending entity removal failed" in caplog.text


async def test_async_initialize_does_not_grow_pending_entities_between_passes(
    zha_gateway: Gateway,
) -> None:
    """Verify that re-running async_initialize leaves _pending_entities stable."""
    raw_device = zigpy_device(zha_gateway, with_basic_cluster_handler=True)
    device = await join_zigpy_device(zha_gateway, raw_device)

    # Snapshot the pending-entity count before triggering another pass.
    pending_before = len(device._pending_entities)

    # Issue being validated:
    # repeated async_initialize() calls append entities into _pending_entities
    # without clearing completed entries from prior passes.
    #
    # Why this is a problem:
    # pending state should represent only the current discovery pass; growth across
    # passes leaks lifecycle state and causes unnecessary entity churn over time.
    await device.async_initialize(from_cache=False)

    assert len(device._pending_entities) == pending_before


async def test_async_initialize_does_not_mark_initialized_if_endpoint_init_fails(
    zha_gateway: Gateway,
) -> None:
    """Verify a failing endpoint init keeps the device out of INITIALIZED."""
    raw_device = zigpy_device(zha_gateway, with_basic_cluster_handler=True)
    device = zha_gateway.get_or_create_device(raw_device)

    try:
        # A freshly created (not yet joined/initialized) device starts CREATED.
        assert device.status is DeviceStatus.CREATED

        first_endpoint = next(iter(device.endpoints.values()))
        with patch.object(
            first_endpoint,
            "async_initialize",
            side_effect=RuntimeError("endpoint init failed"),
        ):
            # Issue being validated:
            # endpoint initialization exceptions are swallowed during async_initialize(),
            # then status is still moved to INITIALIZED.
            #
            # Why this is a problem:
            # gateway rejoin logic keys off INITIALIZED status; a false-positive status
            # transition can skip full initialization despite endpoint failure.
            await device.async_initialize(from_cache=False)

        assert device.status is DeviceStatus.CREATED
    finally:
        # Always tear the device down so it does not leak into other tests.
        await device.on_remove()


async def test_async_initialize_logs_stale_pending_entity_cleanup_failure(
    zha_gateway: Gateway,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Verify failing stale pending-entity teardown is logged, not fatal."""
    raw_device = zigpy_device(zha_gateway, with_basic_cluster_handler=True)
    device = await join_zigpy_device(zha_gateway, raw_device)

    # Stage a leftover entity whose removal hook always blows up.
    stale_entity = mock.Mock()
    stale_entity.on_remove = mock.AsyncMock(
        side_effect=RuntimeError("stale pending cleanup failed")
    )
    device._pending_entities.append(stale_entity)

    # Issue being validated:
    # async_initialize() first drains stale pending entities from a previous pass.
    # If stale entity cleanup raises, the failure must be logged instead of
    # aborting initialization.
    #
    # Why this is a problem:
    # A single stale entity teardown error can otherwise prevent all future
    # initialization work for the device, leaving discovery and entity lifecycle
    # in a partially initialized state.
    await device.async_initialize(from_cache=False)

    assert "Failed to remove stale pending entity" in caplog.text
    assert "stale pending cleanup failed" in caplog.text


async def test_platform_entity_on_remove_callback_failure_does_not_abort_cleanup(
    zha_gateway: Gateway,
) -> None:
    """Verify a raising on_remove callback still lets task cleanup complete."""
    raw_device = zigpy_device(zha_gateway, with_basic_cluster_handler=True)
    device = await join_zigpy_device(zha_gateway, raw_device)
    entity = get_entity(device, platform=Platform.SWITCH)

    # A future that is never resolved keeps the tracked task alive until
    # it is cancelled by entity teardown.
    never_resolved: asyncio.Future[None] = asyncio.get_running_loop().create_future()

    async def _wait_forever() -> None:
        await never_resolved

    pending_task: asyncio.Task[None] = asyncio.create_task(_wait_forever())
    entity._tracked_tasks.append(pending_task)

    def _exploding_callback() -> None:
        raise RuntimeError("entity callback failure")

    entity._on_remove_callbacks.append(_exploding_callback)

    # Issue being validated:
    # BaseEntity.on_remove() executes callbacks without per-callback exception handling.
    # A single callback failure aborts the rest of teardown immediately.
    #
    # Why this is a problem:
    # entity-owned tasks/handles may remain active after partial teardown, leaking
    # background work and causing unpredictable behavior during remove/shutdown flows.
    try:
        await entity.on_remove()

        assert pending_task.cancelled()
        assert pending_task not in entity._tracked_tasks
    finally:
        # Belt-and-braces cleanup so a failing assertion cannot leak the task
        # or its tracking entry into subsequent tests.
        if not pending_task.done():
            pending_task.cancel()
        if pending_task in entity._tracked_tasks:
            entity._tracked_tasks.remove(pending_task)
        await asyncio.gather(pending_task, return_exceptions=True)
11 changes: 10 additions & 1 deletion zha/application/platforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,14 @@ async def on_remove(self) -> None:
while self._on_remove_callbacks:
callback = self._on_remove_callbacks.pop()
self.debug("Running remove callback: %s", callback)
callback()
try:
callback()
except Exception: # pylint: disable=broad-exception-caught
self.warning(
"Failed to execute on_remove callback %s",
callback,
exc_info=True,
)

for handle in self._tracked_handles:
self.debug("Cancelling handle: %s", handle)
Expand All @@ -416,6 +423,8 @@ async def on_remove(self) -> None:
for task in tasks:
self.debug("Cancelling task: %s", task)
task.cancel()
with suppress(ValueError):
self._tracked_tasks.remove(task)
with suppress(asyncio.CancelledError):
await asyncio.gather(*tasks, return_exceptions=True)

Expand Down
45 changes: 43 additions & 2 deletions zha/zigbee/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,6 +980,19 @@ async def async_initialize(self, from_cache: bool = False) -> None:
"""Initialize cluster handlers."""
self.debug("started initialization")

# Clean up entities from any previous unfinished discovery pass before
# replacing the pending list for this initialization cycle.
while self._pending_entities:
pending_entity = self._pending_entities.pop()
try:
await pending_entity.on_remove()
except Exception: # pylint: disable=broad-exception-caught
_LOGGER.warning(
"Failed to remove stale pending entity %s for device %s",
pending_entity,
self,
exc_info=True,
)
self._discover_new_entities()

await self._zdo_handler.async_initialize(from_cache)
Expand All @@ -989,10 +1002,12 @@ async def async_initialize(self, from_cache: bool = False) -> None:
# three `device.async_initialize()`s are spawned, only three concurrent requests
# will ever be in flight at once. Startup concurrency is managed at the device
# level.
endpoint_init_failed = False
for endpoint in self._endpoints.values():
try:
await endpoint.async_initialize(from_cache)
except Exception: # pylint: disable=broad-exception-caught
endpoint_init_failed = True
self.debug("Failed to initialize endpoint", exc_info=True)

# Compute the final entities
Expand All @@ -1015,12 +1030,23 @@ async def async_initialize(self, from_cache: bool = False) -> None:
await entity.on_remove()
continue

# Keep existing entity instances stable across re-initialization
# passes. Newly rediscovered duplicates must be cleaned up.
if key in self._platform_entities:
await entity.on_remove()
continue

new_entities[key] = entity

if new_entities:
_LOGGER.debug("Discovered new entities %r", new_entities)
self._platform_entities.update(new_entities)

# Discovery for this initialization pass has been fully reconciled.
# Keep _pending_entities transient so the next pass only contains
# entities staged by async_configure or a fresh discovery cycle.
self._pending_entities.clear()

# At this point we can compute a primary entity
self._compute_primary_entity()

Expand All @@ -1045,6 +1071,12 @@ def entity_update_listener(event: EntityStateChangedEvent) -> None:
break

self.debug("power source: %s", self.power_source)
if endpoint_init_failed:
self.debug(
"completed initialization with endpoint failures; status unchanged"
)
return

self.status = DeviceStatus.INITIALIZED
self.debug("completed initialization")

Expand Down Expand Up @@ -1217,11 +1249,20 @@ async def issue_cluster_command(
args,
[field.name for field in commands[command].schema.fields],
)
response = await getattr(cluster, commands[command].name)(*args)
command_kwargs: dict[str, Any] = {}
if manufacturer is not None:
command_kwargs["manufacturer"] = manufacturer
response = await getattr(cluster, commands[command].name)(
*args, **command_kwargs
)
else:
assert params is not None
command_kwargs = {}
if manufacturer is not None:
command_kwargs["manufacturer"] = manufacturer
response = await getattr(cluster, commands[command].name)(
**convert_to_zcl_values(params, commands[command].schema)
**convert_to_zcl_values(params, commands[command].schema),
**command_kwargs,
)
self.debug(
"Issued cluster command: %s %s %s %s %s %s %s %s",
Expand Down
Loading