Skip to content

Commit

Permalink
feat: optimize processing of records in RecordUpdateListener
Browse files Browse the repository at this point in the history
These classes used to process a single record at a time, but since
we now process them all in batches, we can remove all the breakout
functions
  • Loading branch information
bdraco committed Aug 22, 2023
1 parent 0c5e5cf commit 85d6c2b
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 57 deletions.
70 changes: 33 additions & 37 deletions src/zeroconf/_services/browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
cast,
)

from .._dns import DNSPointer, DNSQuestion, DNSQuestionType, DNSRecord
from .._dns import DNSPointer, DNSQuestion, DNSQuestionType
from .._logger import log
from .._protocol.outgoing import DNSOutgoing
from .._services import (
Expand Down Expand Up @@ -383,50 +383,46 @@ def _enqueue_callback(
):
self._pending_handlers[key] = state_change

def _async_process_record_update(
self, now: float, record: DNSRecord, old_record: Optional[DNSRecord]
) -> None:
"""Process a single record update from a batch of updates."""
record_type = record.type

if record_type is _TYPE_PTR:
if TYPE_CHECKING:
record = cast(DNSPointer, record)
for type_ in self.types.intersection(cached_possible_types(record.name)):
if old_record is None:
self._enqueue_callback(ServiceStateChange.Added, type_, record.alias)
elif record.is_expired(now):
self._enqueue_callback(ServiceStateChange.Removed, type_, record.alias)
else:
self.reschedule_type(type_, now, record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT))
return

# If its expired or already exists in the cache it cannot be updated.
if old_record or record.is_expired(now):
return

if record_type in _ADDRESS_RECORD_TYPES:
# Iterate through the DNSCache and callback any services that use this address
for type_, name in self._names_matching_types(
{service.name for service in self.zc.cache.async_entries_with_server(record.name)}
):
self._enqueue_callback(ServiceStateChange.Updated, type_, name)
return

for type_, name in self._names_matching_types((record.name,)):
self._enqueue_callback(ServiceStateChange.Updated, type_, name)

def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> None:
"""Callback invoked by Zeroconf when new information arrives.
Updates information required by browser in the Zeroconf cache.
Ensures that there is are no unecessary duplicates in the list.
Ensures that there is are no unnecessary duplicates in the list.
This method will be run in the event loop.
"""
for record in records:
self._async_process_record_update(now, record[0], record[1])
for record_update in records:
record, old_record = record_update
record_type = record.type

if record_type is _TYPE_PTR:
if TYPE_CHECKING:
record = cast(DNSPointer, record)
for type_ in self.types.intersection(cached_possible_types(record.name)):
if old_record is None:
self._enqueue_callback(ServiceStateChange.Added, type_, record.alias)
elif record.is_expired(now):
self._enqueue_callback(ServiceStateChange.Removed, type_, record.alias)
else:
expire_time = record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT)
self.reschedule_type(type_, now, expire_time)
continue

# If its expired or already exists in the cache it cannot be updated.
if old_record or record.is_expired(now):
continue

if record_type in _ADDRESS_RECORD_TYPES:
# Iterate through the DNSCache and callback any services that use this address
for type_, name in self._names_matching_types(
{service.name for service in self.zc.cache.async_entries_with_server(record.name)}
):
self._enqueue_callback(ServiceStateChange.Updated, type_, name)
continue

for type_, name in self._names_matching_types((record.name,)):
self._enqueue_callback(ServiceStateChange.Updated, type_, name)

@abstractmethod
def async_update_records_complete(self) -> None:
Expand Down
22 changes: 2 additions & 20 deletions src/zeroconf/_services/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,35 +410,17 @@ def _set_ipv4_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None:
else:
self._ipv4_addresses = self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A)

def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) -> None:
"""Updates service information from a DNS record.
This method is deprecated and will be removed in a future version.
update_records should be implemented instead.
This method will be run in the event loop.
"""
if record is not None:
self._process_record_threadsafe(zc, record, now)

def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> None:
"""Updates service information from a DNS record.
This method will be run in the event loop.
"""
new_records_futures = self._new_records_futures
if self._process_records_threadsafe(zc, now, records) and new_records_futures:
_resolve_all_futures_to_none(new_records_futures)

def _process_records_threadsafe(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> bool:
"""Thread safe record updating.
Returns True if new records were added.
"""
updated: bool = False
for record_update in records:
updated |= self._process_record_threadsafe(zc, record_update.new, now)
return updated
if updated and new_records_futures:
_resolve_all_futures_to_none(new_records_futures)

def _process_record_threadsafe(self, zc: 'Zeroconf', record: DNSRecord, now: float) -> bool:
"""Thread safe record updating.
Expand Down

0 comments on commit 85d6c2b

Please sign in to comment.