diff --git a/tests/test_zigbee_util.py b/tests/test_zigbee_util.py index 1600ba00e..2dd197f00 100644 --- a/tests/test_zigbee_util.py +++ b/tests/test_zigbee_util.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import logging import sys @@ -467,3 +469,16 @@ def test_singleton(): obj = {} obj[singleton] = 5 assert obj[singleton] == 5 + + +@pytest.mark.parametrize( + "input_relays, expected_relays", + [ + ([0x0000, 0x0000, 0x0001, 0x0001, 0x0002], [0x0001, 0x0002]), + ([0x0001, 0x0002], [0x0001, 0x0002]), + ([], []), + ([0x0000], []), + ], +) +def test_relay_filtering(input_relays: list[int], expected_relays: list[int]): + assert util.filter_relays(input_relays) == expected_relays diff --git a/zigpy/appdb.py b/zigpy/appdb.py index ffe59a125..02d8a26e3 100644 --- a/zigpy/appdb.py +++ b/zigpy/appdb.py @@ -745,7 +745,8 @@ async def _load_relays(self) -> None: async with self.execute(f"SELECT * FROM relays{DB_V}") as cursor: async for (ieee, value) in cursor: dev = self._application.get_device(ieee) - dev.relays, _ = t.Relays.deserialize(value) + relays, _ = t.Relays.deserialize(value) + dev.relays = zigpy.util.filter_relays(relays) async def _load_neighbors(self) -> None: async with self.execute(f"SELECT * FROM neighbors{DB_V}") as cursor: diff --git a/zigpy/application.py b/zigpy/application.py index 71c4cfd2b..098200905 100644 --- a/zigpy/application.py +++ b/zigpy/application.py @@ -598,8 +598,7 @@ def handle_relays(self, nwk: t.NWK, relays: list[t.NWK]) -> None: f"discover_unknown_device_from_relays-nwk={nwk!r}", ) else: - # `relays` is a property with a setter that emits an event - device.relays = relays + device.relays = zigpy.util.filter_relays(relays) @classmethod async def probe(cls, device_config: dict[str, Any]) -> bool | dict[str, Any]: diff --git a/zigpy/util.py b/zigpy/util.py index c9817fecc..59ca6654d 100644 --- a/zigpy/util.py +++ b/zigpy/util.py @@ -450,3 +450,15 @@ def __repr__(self) -> str: def __hash__(self) -> int: return hash(self.name) + + +def filter_relays(relays: list[int]) -> list[int]: + """Filter out invalid relays.""" + filtered_relays = [] + + # BUG: relays sometimes include 0x0000 or duplicate entries + for relay in relays: + if relay != 0x0000 and relay not in filtered_relays: + filtered_relays.append(relay) + + return filtered_relays