Skip to content

Commit

Permalink
feat: speed up ServiceInfo with a cython pxd (#1264)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Sep 10, 2023
1 parent 6bf5d95 commit 7ca690a
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 27 deletions.
1 change: 1 addition & 0 deletions build_ext.py
Expand Up @@ -32,6 +32,7 @@ def build(setup_kwargs: Any) -> None:
"src/zeroconf/_handlers/answers.py",
"src/zeroconf/_handlers/record_manager.py",
"src/zeroconf/_handlers/query_handler.py",
"src/zeroconf/_services/info.py",
"src/zeroconf/_services/registry.py",
"src/zeroconf/_updates.py",
"src/zeroconf/_utils/time.py",
Expand Down
87 changes: 87 additions & 0 deletions src/zeroconf/_services/info.pxd
@@ -0,0 +1,87 @@

import cython

from .._cache cimport DNSCache
from .._dns cimport DNSPointer, DNSRecord, DNSService, DNSText
from .._protocol.outgoing cimport DNSOutgoing
from .._updates cimport RecordUpdateListener
from .._utils.time cimport current_time_millis


cdef object _resolve_all_futures_to_none

cdef object _TYPE_SRV
cdef object _TYPE_TXT
cdef object _TYPE_A
cdef object _TYPE_AAAA
cdef object _TYPE_PTR
cdef object _TYPE_NSEC
cdef object _CLASS_IN
cdef object _FLAGS_QR_QUERY

cdef object service_type_name

cdef object DNS_QUESTION_TYPE_QU
cdef object DNS_QUESTION_TYPE_QM

cdef object _IPVersion_All_value
cdef object _IPVersion_V4Only_value

cdef object TYPE_CHECKING

cdef class ServiceInfo(RecordUpdateListener):

cdef public cython.bytes text
cdef public str type
cdef str _name
cdef public str key
cdef public cython.list _ipv4_addresses
cdef public cython.list _ipv6_addresses
cdef public object port
cdef public object weight
cdef public object priority
cdef public str server
cdef public str server_key
cdef public cython.dict _properties
cdef public object host_ttl
cdef public object other_ttl
cdef public object interface_index
cdef public cython.set _new_records_futures
cdef public DNSPointer _dns_pointer_cache
cdef public DNSService _dns_service_cache
cdef public DNSText _dns_text_cache
cdef public cython.list _dns_address_cache
cdef public cython.set _get_address_and_nsec_records_cache

@cython.locals(
cache=DNSCache
)
cpdef async_update_records(self, object zc, object now, cython.list records)

@cython.locals(
cache=DNSCache
)
cpdef _load_from_cache(self, object zc, object now)

cdef _unpack_text_into_properties(self)

cdef _set_properties(self, cython.dict properties)

cdef _set_text(self, cython.bytes text)

cdef _get_ip_addresses_from_cache_lifo(self, object zc, object now, object type)

cdef _process_record_threadsafe(self, object zc, DNSRecord record, object now)

@cython.locals(
cache=DNSCache
)
cdef cython.list _get_address_records_from_cache_by_type(self, object zc, object _type)

cdef _set_ipv4_addresses_from_cache(self, object zc, object now)

cdef _set_ipv6_addresses_from_cache(self, object zc, object now)

cdef cython.list _ip_addresses_by_version_value(self, object version_value)

cdef addresses_by_version(self, object version)
67 changes: 40 additions & 27 deletions src/zeroconf/_services/info.py
Expand Up @@ -78,6 +78,12 @@
# the A/AAAA/SRV records for a host.
_AVOID_SYNC_DELAY_RANDOM_INTERVAL = (20, 120)

float_ = float
int_ = int

DNS_QUESTION_TYPE_QU = DNSQuestionType.QU
DNS_QUESTION_TYPE_QM = DNSQuestionType.QM

if TYPE_CHECKING:
from .._core import Zeroconf

Expand Down Expand Up @@ -281,10 +287,9 @@ def addresses_by_version(self, version: IPVersion) -> List[bytes]:
"""
version_value = version.value
if version_value == _IPVersion_All_value:
return [
*(addr.packed for addr in self._ipv4_addresses),
*(addr.packed for addr in self._ipv6_addresses),
]
ip_v4_packed = [addr.packed for addr in self._ipv4_addresses]
ip_v6_packed = [addr.packed for addr in self._ipv6_addresses]
return [*ip_v4_packed, *ip_v6_packed]
if version_value == _IPVersion_V4Only_value:
return [addr.packed for addr in self._ipv4_addresses]
return [addr.packed for addr in self._ipv6_addresses]
Expand All @@ -303,7 +308,7 @@ def ip_addresses_by_version(
return self._ip_addresses_by_version_value(version.value)

def _ip_addresses_by_version_value(
self, version_value: int
self, version_value: int_
) -> Union[List[IPv4Address], List[IPv6Address], List[_BaseAddress]]:
"""Backend for addresses_by_version that uses the raw value."""
if version_value == _IPVersion_All_value:
Expand Down Expand Up @@ -397,7 +402,7 @@ def get_name(self) -> str:
return self._name[: len(self._name) - len(self.type) - 1]

def _get_ip_addresses_from_cache_lifo(
self, zc: 'Zeroconf', now: float, type: int
self, zc: 'Zeroconf', now: float_, type: int_
) -> List[Union[IPv4Address, IPv6Address]]:
"""Set IPv6 addresses from the cache."""
address_list: List[Union[IPv4Address, IPv6Address]] = []
Expand All @@ -410,7 +415,7 @@ def _get_ip_addresses_from_cache_lifo(
address_list.reverse() # Reverse to get LIFO order
return address_list

def _set_ipv6_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None:
def _set_ipv6_addresses_from_cache(self, zc: 'Zeroconf', now: float_) -> None:
"""Set IPv6 addresses from the cache."""
if TYPE_CHECKING:
self._ipv6_addresses = cast(
Expand All @@ -419,7 +424,7 @@ def _set_ipv6_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None:
else:
self._ipv6_addresses = self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA)

def _set_ipv4_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None:
def _set_ipv4_addresses_from_cache(self, zc: 'Zeroconf', now: float_) -> None:
"""Set IPv4 addresses from the cache."""
if TYPE_CHECKING:
self._ipv4_addresses = cast(
Expand All @@ -428,7 +433,7 @@ def _set_ipv4_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None:
else:
self._ipv4_addresses = self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A)

def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> None:
def async_update_records(self, zc: 'Zeroconf', now: float_, records: List[RecordUpdate]) -> None:
"""Updates service information from a DNS record.
This method will be run in the event loop.
Expand All @@ -440,7 +445,7 @@ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordU
if updated and new_records_futures:
_resolve_all_futures_to_none(new_records_futures)

def _process_record_threadsafe(self, zc: 'Zeroconf', record: DNSRecord, now: float) -> bool:
def _process_record_threadsafe(self, zc: 'Zeroconf', record: DNSRecord, now: float_) -> bool:
"""Thread safe record updating.
Returns True if a new record was added.
Expand Down Expand Up @@ -624,14 +629,15 @@ def get_address_and_nsec_records(self, override_ttl: Optional[int] = None) -> Se
self._get_address_and_nsec_records_cache = records
return records

def _get_address_records_from_cache_by_type(self, zc: 'Zeroconf', _type: int) -> List[DNSAddress]:
def _get_address_records_from_cache_by_type(self, zc: 'Zeroconf', _type: int_) -> List[DNSAddress]:
"""Get the addresses from the cache."""
if self.server_key is None:
return []
cache = zc.cache
if TYPE_CHECKING:
records = cast("List[DNSAddress]", zc.cache.get_all_by_details(self.server_key, _type, _CLASS_IN))
records = cast("List[DNSAddress]", cache.get_all_by_details(self.server_key, _type, _CLASS_IN))
else:
records = zc.cache.get_all_by_details(self.server_key, _type, _CLASS_IN)
records = cache.get_all_by_details(self.server_key, _type, _CLASS_IN)
return records

def set_server_if_missing(self) -> None:
Expand All @@ -643,28 +649,33 @@ def set_server_if_missing(self) -> None:
self.server = self._name
self.server_key = self.key

def load_from_cache(self, zc: 'Zeroconf', now: Optional[float] = None) -> bool:
def load_from_cache(self, zc: 'Zeroconf', now: Optional[float_] = None) -> bool:
"""Populate the service info from the cache.
This method is designed to be threadsafe.
"""
return self._load_from_cache(zc, now or current_time_millis())

def _load_from_cache(self, zc: 'Zeroconf', now: float_) -> bool:
"""Populate the service info from the cache.
This method is designed to be threadsafe.
"""
if not now:
now = current_time_millis()
cache = zc.cache
original_server_key = self.server_key
cached_srv_record = zc.cache.get_by_details(self._name, _TYPE_SRV, _CLASS_IN)
cached_srv_record = cache.get_by_details(self._name, _TYPE_SRV, _CLASS_IN)
if cached_srv_record:
self._process_record_threadsafe(zc, cached_srv_record, now)
cached_txt_record = zc.cache.get_by_details(self._name, _TYPE_TXT, _CLASS_IN)
cached_txt_record = cache.get_by_details(self._name, _TYPE_TXT, _CLASS_IN)
if cached_txt_record:
self._process_record_threadsafe(zc, cached_txt_record, now)
if original_server_key == self.server_key:
# If there is a srv which changes the server_key,
# A and AAAA will already be loaded from the cache
# and we do not want to do it twice
for record in [
*self._get_address_records_from_cache_by_type(zc, _TYPE_A),
*self._get_address_records_from_cache_by_type(zc, _TYPE_AAAA),
]:
for record in self._get_address_records_from_cache_by_type(zc, _TYPE_A):
self._process_record_threadsafe(zc, record, now)
for record in self._get_address_records_from_cache_by_type(zc, _TYPE_AAAA):
self._process_record_threadsafe(zc, record, now)
return self._is_complete

Expand Down Expand Up @@ -720,7 +731,7 @@ async def async_request(

now = current_time_millis()

if self.load_from_cache(zc, now):
if self._load_from_cache(zc, now):
return True

if TYPE_CHECKING:
Expand All @@ -737,11 +748,13 @@ async def async_request(
return False
if next_ <= now:
out = self.generate_request_query(
zc, now, question_type or DNSQuestionType.QU if first_request else DNSQuestionType.QM
zc,
now,
question_type or DNS_QUESTION_TYPE_QU if first_request else DNS_QUESTION_TYPE_QM,
)
first_request = False
if not out.questions:
return self.load_from_cache(zc, now)
return self._load_from_cache(zc, now)
zc.async_send(out, addr, port)
next_ = now + delay
delay *= 2
Expand All @@ -755,7 +768,7 @@ async def async_request(
return True

def generate_request_query(
self, zc: 'Zeroconf', now: float, question_type: Optional[DNSQuestionType] = None
self, zc: 'Zeroconf', now: float_, question_type: Optional[DNSQuestionType] = None
) -> DNSOutgoing:
"""Generate the request query."""
out = DNSOutgoing(_FLAGS_QR_QUERY)
Expand All @@ -766,7 +779,7 @@ def generate_request_query(
out.add_question_or_one_cache(cache, now, name, _TYPE_TXT, _CLASS_IN)
out.add_question_or_all_cache(cache, now, server_or_name, _TYPE_A, _CLASS_IN)
out.add_question_or_all_cache(cache, now, server_or_name, _TYPE_AAAA, _CLASS_IN)
if question_type == DNSQuestionType.QU:
if question_type == DNS_QUESTION_TYPE_QU:
for question in out.questions:
question.unicast = True
return out
Expand Down

0 comments on commit 7ca690a

Please sign in to comment.