diff --git a/src/zeroconf/_services/info.py b/src/zeroconf/_services/info.py index e4fe5cdd..fd7a9619 100644 --- a/src/zeroconf/_services/info.py +++ b/src/zeroconf/_services/info.py @@ -20,6 +20,7 @@ USA """ +import asyncio import ipaddress import random from functools import lru_cache @@ -37,10 +38,14 @@ from .._logger import log from .._protocol.outgoing import DNSOutgoing from .._updates import RecordUpdate, RecordUpdateListener -from .._utils.asyncio import get_running_loop, run_coro_with_timeout +from .._utils.asyncio import ( + get_running_loop, + run_coro_with_timeout, + wait_event_or_timeout, +) from .._utils.name import service_type_name from .._utils.net import IPVersion, _encode_address -from .._utils.time import current_time_millis +from .._utils.time import current_time_millis, millis_to_seconds from ..const import ( _CLASS_IN, _CLASS_UNIQUE, @@ -166,6 +171,7 @@ def __init__( self.host_ttl = host_ttl self.other_ttl = other_ttl self.interface_index = interface_index + self._notify_event: Optional[asyncio.Event] = None @property def name(self) -> str: @@ -221,6 +227,12 @@ def properties(self) -> Dict: """ return self._properties + async def async_wait(self, timeout: float) -> None: + """Calling task waits for a given number of milliseconds or until notified.""" + if self._notify_event is None: + self._notify_event = asyncio.Event() + await wait_event_or_timeout(self._notify_event, timeout=millis_to_seconds(timeout)) + def addresses_by_version(self, version: IPVersion) -> List[bytes]: """List addresses matching IP version. @@ -384,7 +396,9 @@ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordU This method will be run in the event loop. """ - self._process_records_threadsafe(zc, now, records) + if self._process_records_threadsafe(zc, now, records) and self._notify_event: + self._notify_event.set() + self._notify_event.clear() def _process_records_threadsafe(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> bool: """Thread safe record updating. @@ -605,7 +619,7 @@ async def async_request( delay *= 2 next_ += random.randint(*_AVOID_SYNC_DELAY_RANDOM_INTERVAL) - await zc.async_wait(min(next_, last) - now) + await self.async_wait(min(next_, last) - now) now = current_time_millis() finally: zc.async_remove_listener(self)