Skip to content

Commit

Permalink
feat: improve performance of loading records from cache in ServiceInfo (
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Sep 24, 2023
1 parent 248b506 commit 6257d49
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 31 deletions.
4 changes: 2 additions & 2 deletions src/zeroconf/_dns.pxd
Expand Up @@ -106,8 +106,8 @@ cdef class DNSService(DNSRecord):
cdef public object priority
cdef public object weight
cdef public object port
cdef public object server
cdef public object server_key
cdef public str server
cdef public str server_key

cdef _eq(self, DNSService other)

Expand Down
25 changes: 13 additions & 12 deletions src/zeroconf/_services/info.pxd
Expand Up @@ -2,13 +2,14 @@
import cython

from .._cache cimport DNSCache
from .._dns cimport DNSNsec, DNSPointer, DNSRecord, DNSService, DNSText
from .._dns cimport DNSAddress, DNSNsec, DNSPointer, DNSRecord, DNSService, DNSText
from .._protocol.outgoing cimport DNSOutgoing
from .._updates cimport RecordUpdateListener
from .._utils.time cimport current_time_millis


cdef object _resolve_all_futures_to_none
cdef object _cached_ip_addresses_wrapper

cdef object _TYPE_SRV
cdef object _TYPE_TXT
Expand Down Expand Up @@ -55,29 +56,29 @@ cdef class ServiceInfo(RecordUpdateListener):
cdef public cython.list _dns_address_cache
cdef public cython.set _get_address_and_nsec_records_cache

@cython.locals(
cache=DNSCache,
)
@cython.locals(cache=DNSCache)
cpdef async_update_records(self, object zc, cython.float now, cython.list records)

@cython.locals(
cache=DNSCache
)
cpdef _load_from_cache(self, object zc, object now)
@cython.locals(cache=DNSCache)
cpdef _load_from_cache(self, object zc, cython.float now)

cdef _unpack_text_into_properties(self)

cdef _set_properties(self, cython.dict properties)

cdef _set_text(self, cython.bytes text)

cdef _get_ip_addresses_from_cache_lifo(self, object zc, object now, object type)

cdef _process_record_threadsafe(self, object zc, DNSRecord record, cython.float now)
@cython.locals(record=DNSAddress)
cdef _get_ip_addresses_from_cache_lifo(self, object zc, cython.float now, object type)

@cython.locals(
cache=DNSCache
dns_service_record=DNSService,
dns_text_record=DNSText,
dns_address_record=DNSAddress
)
cdef _process_record_threadsafe(self, object zc, DNSRecord record, cython.float now)

@cython.locals(cache=DNSCache)
cdef cython.list _get_address_records_from_cache_by_type(self, object zc, object _type)

cdef _set_ipv4_addresses_from_cache(self, object zc, object now)
Expand Down
46 changes: 29 additions & 17 deletions src/zeroconf/_services/info.py
Expand Up @@ -107,6 +107,9 @@ def _cached_ip_addresses(address: Union[str, bytes, int]) -> Optional[Union[IPv4
return None


_cached_ip_addresses_wrapper = _cached_ip_addresses


class ServiceInfo(RecordUpdateListener):
"""Service information.
Expand Down Expand Up @@ -197,7 +200,7 @@ def __init__(
self.host_ttl = host_ttl
self.other_ttl = other_ttl
self.interface_index = interface_index
self._new_records_futures: Set[asyncio.Future] = set()
self._new_records_futures: Optional[Set[asyncio.Future]] = None
self._dns_address_cache: Optional[List[DNSAddress]] = None
self._dns_pointer_cache: Optional[DNSPointer] = None
self._dns_service_cache: Optional[DNSService] = None
Expand Down Expand Up @@ -240,7 +243,7 @@ def addresses(self, value: List[bytes]) -> None:
self._get_address_and_nsec_records_cache = None

for address in value:
addr = _cached_ip_addresses(address)
addr = _cached_ip_addresses_wrapper(address)
if addr is None:
raise TypeError(
"Addresses must either be IPv4 or IPv6 strings, bytes, or integers;"
Expand Down Expand Up @@ -272,6 +275,8 @@ def properties(self) -> Dict[Union[str, bytes], Optional[Union[str, bytes]]]:

async def async_wait(self, timeout: float, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
"""Calling task waits for a given number of milliseconds or until notified."""
if not self._new_records_futures:
self._new_records_futures = set()
await wait_for_future_set_or_timeout(
loop or asyncio.get_running_loop(), self._new_records_futures, timeout
)
Expand Down Expand Up @@ -409,7 +414,7 @@ def _get_ip_addresses_from_cache_lifo(
for record in self._get_address_records_from_cache_by_type(zc, type):
if record.is_expired(now):
continue
ip_addr = _cached_ip_addresses(record.address)
ip_addr = _cached_ip_addresses_wrapper(record.address)
if ip_addr is not None:
address_list.append(ip_addr)
address_list.reverse() # Reverse to get LIFO order
Expand Down Expand Up @@ -455,12 +460,17 @@ def _process_record_threadsafe(self, zc: 'Zeroconf', record: DNSRecord, now: flo

record_key = record.key
record_type = type(record)
if record_key == self.server_key and record_type is DNSAddress:
if record_type is DNSAddress and record_key == self.server_key:
dns_address_record = record
if TYPE_CHECKING:
assert isinstance(record, DNSAddress)
ip_addr = _cached_ip_addresses(record.address)
assert isinstance(dns_address_record, DNSAddress)
ip_addr = _cached_ip_addresses_wrapper(dns_address_record.address)
if ip_addr is None:
log.warning("Encountered invalid address while processing %s: %s", record, record.address)
log.warning(
"Encountered invalid address while processing %s: %s",
dns_address_record,
dns_address_record.address,
)
return False

if ip_addr.version == 4:
Expand Down Expand Up @@ -492,22 +502,24 @@ def _process_record_threadsafe(self, zc: 'Zeroconf', record: DNSRecord, now: flo
return False

if record_type is DNSText:
dns_text_record = record
if TYPE_CHECKING:
assert isinstance(record, DNSText)
self._set_text(record.text)
assert isinstance(dns_text_record, DNSText)
self._set_text(dns_text_record.text)
return True

if record_type is DNSService:
dns_service_record = record
if TYPE_CHECKING:
assert isinstance(record, DNSService)
assert isinstance(dns_service_record, DNSService)
old_server_key = self.server_key
self._name = record.name
self.key = record.key
self.server = record.server
self.server_key = record.server_key
self.port = record.port
self.weight = record.weight
self.priority = record.priority
self._name = dns_service_record.name
self.key = dns_service_record.key
self.server = dns_service_record.server
self.server_key = dns_service_record.server_key
self.port = dns_service_record.port
self.weight = dns_service_record.weight
self.priority = dns_service_record.priority
if old_server_key != self.server_key:
self._set_ipv4_addresses_from_cache(zc, now)
self._set_ipv6_addresses_from_cache(zc, now)
Expand Down

0 comments on commit 6257d49

Please sign in to comment.