Skip to content

Commit

Permalink
feat: refactor notify implementation to reduce overhead of adding and…
Browse files Browse the repository at this point in the history
… removing listeners (#1224)
  • Loading branch information
bdraco committed Aug 14, 2023
1 parent 0e96220 commit ceb92cf
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 31 deletions.
19 changes: 10 additions & 9 deletions src/zeroconf/_core.py
Expand Up @@ -25,7 +25,7 @@
import sys
import threading
from types import TracebackType
from typing import Awaitable, Dict, List, Optional, Tuple, Type, Union
from typing import Awaitable, Dict, List, Optional, Set, Tuple, Type, Union

from ._cache import DNSCache
from ._dns import DNSQuestion, DNSQuestionType
Expand All @@ -49,11 +49,13 @@
from ._transport import _WrappedTransport
from ._updates import RecordUpdateListener
from ._utils.asyncio import (
_resolve_all_futures_to_none,
await_awaitable,
get_running_loop,
run_coro_with_timeout,
shutdown_loop,
wait_event_or_timeout,
wait_for_future_set_or_timeout,
)
from ._utils.name import service_type_name
from ._utils.net import (
Expand Down Expand Up @@ -188,7 +190,7 @@ def __init__(
self.query_handler = QueryHandler(self.registry, self.cache, self.question_history)
self.record_manager = RecordManager(self)

self.notify_event: Optional[asyncio.Event] = None
self._notify_futures: Set[asyncio.Future] = set()
self.loop: Optional[asyncio.AbstractEventLoop] = None
self._loop_thread: Optional[threading.Thread] = None

Expand All @@ -206,7 +208,6 @@ def start(self) -> None:
"""Start Zeroconf."""
self.loop = get_running_loop()
if self.loop:
self.notify_event = asyncio.Event()
self.engine.setup(self.loop, None)
return
self._start_thread()
Expand All @@ -218,7 +219,6 @@ def _start_thread(self) -> None:
def _run_loop() -> None:
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self.notify_event = asyncio.Event()
self.engine.setup(self.loop, loop_thread_ready)
self.loop.run_forever()

Expand All @@ -245,8 +245,9 @@ def listeners(self) -> List[RecordUpdateListener]:

async def async_wait(self, timeout: float) -> None:
"""Calling task waits for a given number of milliseconds or until notified."""
assert self.notify_event is not None
await wait_event_or_timeout(self.notify_event, timeout=millis_to_seconds(timeout))
loop = self.loop
assert loop is not None
await wait_for_future_set_or_timeout(loop, self._notify_futures, timeout)

def notify_all(self) -> None:
"""Notifies all waiting threads and notify listeners."""
Expand All @@ -255,9 +256,9 @@ def notify_all(self) -> None:

def async_notify_all(self) -> None:
"""Schedule an async_notify_all."""
assert self.notify_event is not None
self.notify_event.set()
self.notify_event.clear()
notify_futures = self._notify_futures
if notify_futures:
_resolve_all_futures_to_none(notify_futures)

def get_service_info(
self, type_: str, name: str, timeout: int = 3000, question_type: Optional[DNSQuestionType] = None
Expand Down
36 changes: 14 additions & 22 deletions src/zeroconf/_services/info.py
Expand Up @@ -39,10 +39,15 @@
from .._logger import log
from .._protocol.outgoing import DNSOutgoing
from .._updates import RecordUpdate, RecordUpdateListener
from .._utils.asyncio import get_running_loop, run_coro_with_timeout
from .._utils.asyncio import (
_resolve_all_futures_to_none,
get_running_loop,
run_coro_with_timeout,
wait_for_future_set_or_timeout,
)
from .._utils.name import service_type_name
from .._utils.net import IPVersion, _encode_address
from .._utils.time import current_time_millis, millis_to_seconds
from .._utils.time import current_time_millis
from ..const import (
_ADDRESS_RECORD_TYPES,
_CLASS_IN,
Expand Down Expand Up @@ -89,12 +94,6 @@ def instance_name_from_service_info(info: "ServiceInfo", strict: bool = True) ->
_cached_ip_addresses = lru_cache(maxsize=256)(ip_address)


def _set_future_none_if_not_done(fut: asyncio.Future) -> None:
"""Set a future to None if it is not done."""
if not fut.done(): # pragma: no branch
fut.set_result(None)


class ServiceInfo(RecordUpdateListener):
"""Service information.
Expand Down Expand Up @@ -180,7 +179,7 @@ def __init__(
self.host_ttl = host_ttl
self.other_ttl = other_ttl
self.interface_index = interface_index
self._new_records_futures: List[asyncio.Future] = []
self._new_records_futures: Set[asyncio.Future] = set()

@property
def name(self) -> str:
Expand Down Expand Up @@ -242,14 +241,9 @@ def properties(self) -> Dict[Union[str, bytes], Optional[Union[str, bytes]]]:

async def async_wait(self, timeout: float) -> None:
"""Calling task waits for a given number of milliseconds or until notified."""
loop = asyncio.get_running_loop()
future = loop.create_future()
self._new_records_futures.append(future)
handle = loop.call_later(millis_to_seconds(timeout), _set_future_none_if_not_done, future)
try:
await future
finally:
handle.cancel()
loop = get_running_loop()
assert loop is not None
await wait_for_future_set_or_timeout(loop, self._new_records_futures, timeout)

def addresses_by_version(self, version: IPVersion) -> List[bytes]:
"""List addresses matching IP version.
Expand Down Expand Up @@ -441,11 +435,9 @@ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordU
This method will be run in the event loop.
"""
if self._process_records_threadsafe(zc, now, records) and self._new_records_futures:
for future in self._new_records_futures:
if not future.done():
future.set_result(None)
self._new_records_futures.clear()
new_records_futures = self._new_records_futures
if self._process_records_threadsafe(zc, now, records) and new_records_futures:
_resolve_all_futures_to_none(new_records_futures)

def _process_records_threadsafe(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> bool:
"""Thread safe record updating.
Expand Down
27 changes: 27 additions & 0 deletions src/zeroconf/_utils/asyncio.py
Expand Up @@ -41,6 +41,33 @@
_WAIT_FOR_LOOP_TASKS_TIMEOUT = 3 # Must be larger than _TASK_AWAIT_TIMEOUT


def _set_future_none_if_not_done(fut: asyncio.Future) -> None:
"""Set a future to None if it is not done."""
if not fut.done(): # pragma: no branch
fut.set_result(None)


def _resolve_all_futures_to_none(futures: Set[asyncio.Future]) -> None:
"""Resolve all futures to None."""
for fut in futures:
_set_future_none_if_not_done(fut)
futures.clear()


async def wait_for_future_set_or_timeout(
loop: asyncio.AbstractEventLoop, future_set: Set[asyncio.Future], timeout: float
) -> None:
"""Wait for a future or timeout (in milliseconds)."""
future = loop.create_future()
future_set.add(future)
handle = loop.call_later(millis_to_seconds(timeout), _set_future_none_if_not_done, future)
try:
await future
finally:
handle.cancel()
future_set.discard(future)


async def wait_event_or_timeout(event: asyncio.Event, timeout: float) -> None:
"""Wait for an event or timeout."""
with contextlib.suppress(asyncio.TimeoutError):
Expand Down

0 comments on commit ceb92cf

Please sign in to comment.