diff --git a/build_ext.py b/build_ext.py index 8c39f495..b9deeecb 100644 --- a/build_ext.py +++ b/build_ext.py @@ -32,6 +32,7 @@ def build(setup_kwargs: Any) -> None: "src/zeroconf/_handlers/answers.py", "src/zeroconf/_handlers/record_manager.py", "src/zeroconf/_handlers/query_handler.py", + "src/zeroconf/_services/info.py", "src/zeroconf/_services/registry.py", "src/zeroconf/_updates.py", "src/zeroconf/_utils/time.py", diff --git a/src/zeroconf/_services/info.pxd b/src/zeroconf/_services/info.pxd new file mode 100644 index 00000000..b06ea88d --- /dev/null +++ b/src/zeroconf/_services/info.pxd @@ -0,0 +1,87 @@ + +import cython + +from .._cache cimport DNSCache +from .._dns cimport DNSPointer, DNSRecord, DNSService, DNSText +from .._protocol.outgoing cimport DNSOutgoing +from .._updates cimport RecordUpdateListener +from .._utils.time cimport current_time_millis + + +cdef object _resolve_all_futures_to_none + +cdef object _TYPE_SRV +cdef object _TYPE_TXT +cdef object _TYPE_A +cdef object _TYPE_AAAA +cdef object _TYPE_PTR +cdef object _TYPE_NSEC +cdef object _CLASS_IN +cdef object _FLAGS_QR_QUERY + +cdef object service_type_name + +cdef object DNS_QUESTION_TYPE_QU +cdef object DNS_QUESTION_TYPE_QM + +cdef object _IPVersion_All_value +cdef object _IPVersion_V4Only_value + +cdef object TYPE_CHECKING + +cdef class ServiceInfo(RecordUpdateListener): + + cdef public cython.bytes text + cdef public str type + cdef str _name + cdef public str key + cdef public cython.list _ipv4_addresses + cdef public cython.list _ipv6_addresses + cdef public object port + cdef public object weight + cdef public object priority + cdef public str server + cdef public str server_key + cdef public cython.dict _properties + cdef public object host_ttl + cdef public object other_ttl + cdef public object interface_index + cdef public cython.set _new_records_futures + cdef public DNSPointer _dns_pointer_cache + cdef public DNSService _dns_service_cache + cdef public DNSText _dns_text_cache + cdef public cython.list _dns_address_cache + cdef public cython.set _get_address_and_nsec_records_cache + + @cython.locals( + cache=DNSCache + ) + cpdef async_update_records(self, object zc, object now, cython.list records) + + @cython.locals( + cache=DNSCache + ) + cpdef _load_from_cache(self, object zc, object now) + + cdef _unpack_text_into_properties(self) + + cdef _set_properties(self, cython.dict properties) + + cdef _set_text(self, cython.bytes text) + + cdef _get_ip_addresses_from_cache_lifo(self, object zc, object now, object type) + + cdef _process_record_threadsafe(self, object zc, DNSRecord record, object now) + + @cython.locals( + cache=DNSCache + ) + cdef cython.list _get_address_records_from_cache_by_type(self, object zc, object _type) + + cdef _set_ipv4_addresses_from_cache(self, object zc, object now) + + cdef _set_ipv6_addresses_from_cache(self, object zc, object now) + + cdef cython.list _ip_addresses_by_version_value(self, object version_value) + + cdef addresses_by_version(self, object version) diff --git a/src/zeroconf/_services/info.py b/src/zeroconf/_services/info.py index 14398b6a..425ad750 100644 --- a/src/zeroconf/_services/info.py +++ b/src/zeroconf/_services/info.py @@ -78,6 +78,12 @@ # the A/AAAA/SRV records for a host. _AVOID_SYNC_DELAY_RANDOM_INTERVAL = (20, 120) +float_ = float +int_ = int + +DNS_QUESTION_TYPE_QU = DNSQuestionType.QU +DNS_QUESTION_TYPE_QM = DNSQuestionType.QM + if TYPE_CHECKING: from .._core import Zeroconf @@ -281,10 +287,9 @@ def addresses_by_version(self, version: IPVersion) -> List[bytes]: """ version_value = version.value if version_value == _IPVersion_All_value: - return [ - *(addr.packed for addr in self._ipv4_addresses), - *(addr.packed for addr in self._ipv6_addresses), - ] + ip_v4_packed = [addr.packed for addr in self._ipv4_addresses] + ip_v6_packed = [addr.packed for addr in self._ipv6_addresses] + return [*ip_v4_packed, *ip_v6_packed] if version_value == _IPVersion_V4Only_value: return [addr.packed for addr in self._ipv4_addresses] return [addr.packed for addr in self._ipv6_addresses] @@ -303,7 +308,7 @@ def ip_addresses_by_version( return self._ip_addresses_by_version_value(version.value) def _ip_addresses_by_version_value( - self, version_value: int + self, version_value: int_ ) -> Union[List[IPv4Address], List[IPv6Address], List[_BaseAddress]]: """Backend for addresses_by_version that uses the raw value.""" if version_value == _IPVersion_All_value: @@ -397,7 +402,7 @@ def get_name(self) -> str: return self._name[: len(self._name) - len(self.type) - 1] def _get_ip_addresses_from_cache_lifo( - self, zc: 'Zeroconf', now: float, type: int + self, zc: 'Zeroconf', now: float_, type: int_ ) -> List[Union[IPv4Address, IPv6Address]]: """Set IPv6 addresses from the cache.""" address_list: List[Union[IPv4Address, IPv6Address]] = [] @@ -410,7 +415,7 @@ def _get_ip_addresses_from_cache_lifo( address_list.reverse() # Reverse to get LIFO order return address_list - def _set_ipv6_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None: + def _set_ipv6_addresses_from_cache(self, zc: 'Zeroconf', now: float_) -> None: """Set IPv6 addresses from the cache.""" if TYPE_CHECKING: self._ipv6_addresses = cast( @@ -419,7 +424,7 @@ def _set_ipv6_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None: else: self._ipv6_addresses = self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA) - def _set_ipv4_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None: + def _set_ipv4_addresses_from_cache(self, zc: 'Zeroconf', now: float_) -> None: """Set IPv4 addresses from the cache.""" if TYPE_CHECKING: self._ipv4_addresses = cast( @@ -428,7 +433,7 @@ def _set_ipv4_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None: else: self._ipv4_addresses = self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A) - def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> None: + def async_update_records(self, zc: 'Zeroconf', now: float_, records: List[RecordUpdate]) -> None: """Updates service information from a DNS record. This method will be run in the event loop. @@ -440,7 +445,7 @@ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordU if updated and new_records_futures: _resolve_all_futures_to_none(new_records_futures) - def _process_record_threadsafe(self, zc: 'Zeroconf', record: DNSRecord, now: float) -> bool: + def _process_record_threadsafe(self, zc: 'Zeroconf', record: DNSRecord, now: float_) -> bool: """Thread safe record updating. Returns True if a new record was added. @@ -624,14 +629,15 @@ def get_address_and_nsec_records(self, override_ttl: Optional[int] = None) -> Se self._get_address_and_nsec_records_cache = records return records - def _get_address_records_from_cache_by_type(self, zc: 'Zeroconf', _type: int) -> List[DNSAddress]: + def _get_address_records_from_cache_by_type(self, zc: 'Zeroconf', _type: int_) -> List[DNSAddress]: """Get the addresses from the cache.""" if self.server_key is None: return [] + cache = zc.cache if TYPE_CHECKING: - records = cast("List[DNSAddress]", zc.cache.get_all_by_details(self.server_key, _type, _CLASS_IN)) + records = cast("List[DNSAddress]", cache.get_all_by_details(self.server_key, _type, _CLASS_IN)) else: - records = zc.cache.get_all_by_details(self.server_key, _type, _CLASS_IN) + records = cache.get_all_by_details(self.server_key, _type, _CLASS_IN) return records def set_server_if_missing(self) -> None: @@ -643,28 +649,33 @@ def set_server_if_missing(self) -> None: self.server = self._name self.server_key = self.key - def load_from_cache(self, zc: 'Zeroconf', now: Optional[float] = None) -> bool: + def load_from_cache(self, zc: 'Zeroconf', now: Optional[float_] = None) -> bool: + """Populate the service info from the cache. + + This method is designed to be threadsafe. + """ + return self._load_from_cache(zc, now or current_time_millis()) + + def _load_from_cache(self, zc: 'Zeroconf', now: float_) -> bool: """Populate the service info from the cache. This method is designed to be threadsafe. """ - if not now: - now = current_time_millis() + cache = zc.cache original_server_key = self.server_key - cached_srv_record = zc.cache.get_by_details(self._name, _TYPE_SRV, _CLASS_IN) + cached_srv_record = cache.get_by_details(self._name, _TYPE_SRV, _CLASS_IN) if cached_srv_record: self._process_record_threadsafe(zc, cached_srv_record, now) - cached_txt_record = zc.cache.get_by_details(self._name, _TYPE_TXT, _CLASS_IN) + cached_txt_record = cache.get_by_details(self._name, _TYPE_TXT, _CLASS_IN) if cached_txt_record: self._process_record_threadsafe(zc, cached_txt_record, now) if original_server_key == self.server_key: # If there is a srv which changes the server_key, # A and AAAA will already be loaded from the cache # and we do not want to do it twice - for record in [ - *self._get_address_records_from_cache_by_type(zc, _TYPE_A), - *self._get_address_records_from_cache_by_type(zc, _TYPE_AAAA), - ]: + for record in self._get_address_records_from_cache_by_type(zc, _TYPE_A): + self._process_record_threadsafe(zc, record, now) + for record in self._get_address_records_from_cache_by_type(zc, _TYPE_AAAA): self._process_record_threadsafe(zc, record, now) return self._is_complete @@ -720,7 +731,7 @@ async def async_request( now = current_time_millis() - if self.load_from_cache(zc, now): + if self._load_from_cache(zc, now): return True if TYPE_CHECKING: @@ -737,11 +748,13 @@ async def async_request( return False if next_ <= now: out = self.generate_request_query( - zc, now, question_type or DNSQuestionType.QU if first_request else DNSQuestionType.QM + zc, + now, + question_type or DNS_QUESTION_TYPE_QU if first_request else DNS_QUESTION_TYPE_QM, ) first_request = False if not out.questions: - return self.load_from_cache(zc, now) + return self._load_from_cache(zc, now) zc.async_send(out, addr, port) next_ = now + delay delay *= 2 @@ -755,7 +768,7 @@ async def async_request( return True def generate_request_query( - self, zc: 'Zeroconf', now: float, question_type: Optional[DNSQuestionType] = None + self, zc: 'Zeroconf', now: float_, question_type: Optional[DNSQuestionType] = None ) -> DNSOutgoing: """Generate the request query.""" out = DNSOutgoing(_FLAGS_QR_QUERY) @@ -766,7 +779,7 @@ def generate_request_query( out.add_question_or_one_cache(cache, now, name, _TYPE_TXT, _CLASS_IN) out.add_question_or_all_cache(cache, now, server_or_name, _TYPE_A, _CLASS_IN) out.add_question_or_all_cache(cache, now, server_or_name, _TYPE_AAAA, _CLASS_IN) - if question_type == DNSQuestionType.QU: + if question_type == DNS_QUESTION_TYPE_QU: for question in out.questions: question.unicast = True return out