Skip to content

Commit

Permalink
fix: improve performance of ServiceInfo.async_request (#1205)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Aug 2, 2023
1 parent d92aad2 commit 8019a73
Showing 1 changed file with 24 additions and 17 deletions.
41 changes: 24 additions & 17 deletions src/zeroconf/_services/info.py
Expand Up @@ -39,11 +39,7 @@
from .._logger import log
from .._protocol.outgoing import DNSOutgoing
from .._updates import RecordUpdate, RecordUpdateListener
from .._utils.asyncio import (
get_running_loop,
run_coro_with_timeout,
wait_event_or_timeout,
)
from .._utils.asyncio import get_running_loop, run_coro_with_timeout
from .._utils.name import service_type_name
from .._utils.net import IPVersion, _encode_address
from .._utils.time import current_time_millis, millis_to_seconds
Expand Down Expand Up @@ -131,6 +127,7 @@ class ServiceInfo(RecordUpdateListener):
"host_ttl",
"other_ttl",
"interface_index",
"_new_records_futures",
)

def __init__(
Expand Down Expand Up @@ -177,7 +174,7 @@ def __init__(
self.host_ttl = host_ttl
self.other_ttl = other_ttl
self.interface_index = interface_index
self._notify_event: Optional[asyncio.Event] = None
self._new_records_futures: List[asyncio.Future] = []

@property
def name(self) -> str:
Expand Down Expand Up @@ -235,9 +232,14 @@ def properties(self) -> Dict:

async def async_wait(self, timeout: float) -> None:
"""Calling task waits for a given number of milliseconds or until notified."""
if self._notify_event is None:
self._notify_event = asyncio.Event()
await wait_event_or_timeout(self._notify_event, timeout=millis_to_seconds(timeout))
loop = asyncio.get_running_loop()
future = loop.create_future()
self._new_records_futures.append(future)
handle = loop.call_later(millis_to_seconds(timeout), future.set_result, None)
try:
await future
finally:
handle.cancel()

def addresses_by_version(self, version: IPVersion) -> List[bytes]:
"""List addresses matching IP version.
Expand Down Expand Up @@ -409,9 +411,11 @@ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordU
This method will be run in the event loop.
"""
if self._process_records_threadsafe(zc, now, records) and self._notify_event:
self._notify_event.set()
self._notify_event.clear()
if self._process_records_threadsafe(zc, now, records) and self._new_records_futures:
for future in self._new_records_futures:
if not future.done():
future.set_result(None)
self._new_records_futures.clear()

def _process_records_threadsafe(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> bool:
"""Thread safe record updating.
Expand Down Expand Up @@ -591,12 +595,13 @@ def set_server_if_missing(self) -> None:
self.server = self.name
self.server_key = self.server.lower()

def load_from_cache(self, zc: 'Zeroconf') -> bool:
def load_from_cache(self, zc: 'Zeroconf', now: Optional[float] = None) -> bool:
"""Populate the service info from the cache.
This method is designed to be threadsafe.
"""
now = current_time_millis()
if not now:
now = current_time_millis()
original_server_key = self.server_key
cached_srv_record = zc.cache.get_by_details(self.name, _TYPE_SRV, _CLASS_IN)
if cached_srv_record:
Expand Down Expand Up @@ -664,11 +669,13 @@ async def async_request(
"""
if not zc.started:
await zc.async_wait_for_start()
if self.load_from_cache(zc):

now = current_time_millis()

if self.load_from_cache(zc, now):
return True

first_request = True
now = current_time_millis()
delay = _LISTENER_TIME
next_ = now
last = now + timeout
Expand All @@ -683,7 +690,7 @@ async def async_request(
)
first_request = False
if not out.questions:
return self.load_from_cache(zc)
return self.load_from_cache(zc, now)
zc.async_send(out, addr, port)
next_ = now + delay
delay *= 2
Expand Down

0 comments on commit 8019a73

Please sign in to comment.