From 69b33be3b2f9d4a27ef5154cae94afca048efffa Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Aug 2023 20:40:46 -0500 Subject: [PATCH] feat: improve performance responding to queries (#1217) --- src/zeroconf/_services/info.py | 73 +++++++++++++++++++--------------- src/zeroconf/const.py | 1 + 2 files changed, 43 insertions(+), 31 deletions(-) diff --git a/src/zeroconf/_services/info.py b/src/zeroconf/_services/info.py index 29ddb9a0..2f4ae59e 100644 --- a/src/zeroconf/_services/info.py +++ b/src/zeroconf/_services/info.py @@ -46,7 +46,7 @@ from ..const import ( _ADDRESS_RECORD_TYPES, _CLASS_IN, - _CLASS_UNIQUE, + _CLASS_IN_UNIQUE, _DNS_HOST_TTL, _DNS_OTHER_TTL, _FLAGS_QR_QUERY, @@ -388,7 +388,7 @@ def _unpack_text_into_properties(self) -> None: def get_name(self) -> str: """Name accessor""" - return self.name[: len(self.name) - len(self.type) - 1] + return self._name[: len(self._name) - len(self.type) - 1] def _get_ip_addresses_from_cache_lifo( self, zc: 'Zeroconf', now: float, type: int @@ -409,15 +409,21 @@ def _get_ip_addresses_from_cache_lifo( def _set_ipv6_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None: """Set IPv6 addresses from the cache.""" - self._ipv6_addresses = cast( - "List[IPv6Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA) - ) + if TYPE_CHECKING: + self._ipv6_addresses = cast( + "List[IPv6Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA) + ) + else: + self._ipv6_addresses = self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA) def _set_ipv4_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None: """Set IPv4 addresses from the cache.""" - self._ipv4_addresses = cast( - "List[IPv4Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A) - ) + if TYPE_CHECKING: + self._ipv4_addresses = cast( + "List[IPv4Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A) + ) + else: + self._ipv4_addresses = self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A) def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) -> None: """Updates service information from a DNS record. @@ -523,9 +529,9 @@ def dns_addresses( created: Optional[float] = None, ) -> List[DNSAddress]: """Return matching DNSAddress from ServiceInfo.""" - name = self.server or self.name + name = self.server or self._name ttl = override_ttl if override_ttl is not None else self.host_ttl - class_ = _CLASS_IN | _CLASS_UNIQUE + class_ = _CLASS_IN_UNIQUE version_value = version.value return [ DNSAddress( @@ -546,30 +552,33 @@ def dns_pointer(self, override_ttl: Optional[int] = None, created: Optional[floa _TYPE_PTR, _CLASS_IN, override_ttl if override_ttl is not None else self.other_ttl, - self.name, + self._name, created, ) def dns_service(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSService: """Return DNSService from ServiceInfo.""" + port = self.port + if TYPE_CHECKING: + assert isinstance(port, int) return DNSService( - self.name, + self._name, _TYPE_SRV, - _CLASS_IN | _CLASS_UNIQUE, + _CLASS_IN_UNIQUE, override_ttl if override_ttl is not None else self.host_ttl, self.priority, self.weight, - cast(int, self.port), - self.server or self.name, + port, + self.server or self._name, created, ) def dns_text(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSText: """Return DNSText from ServiceInfo.""" return DNSText( - self.name, + self._name, _TYPE_TXT, - _CLASS_IN | _CLASS_UNIQUE, + _CLASS_IN_UNIQUE, override_ttl if override_ttl is not None else self.other_ttl, self.text, created, @@ -580,11 +589,11 @@ def dns_nsec( ) -> DNSNsec: """Return DNSNsec from ServiceInfo.""" return DNSNsec( - self.name, + self._name, _TYPE_NSEC, - _CLASS_IN | _CLASS_UNIQUE, + _CLASS_IN_UNIQUE, override_ttl if override_ttl is not None else self.host_ttl, - self.name, + self._name, missing_types, created, ) @@ -593,12 +602,11 @@ def get_address_and_nsec_records( self, override_ttl: Optional[int] = None, created: Optional[float] = None ) -> Set[DNSRecord]: """Build a set of address records and NSEC records for non-present record types.""" - seen_types: Set[int] = set() + missing_types: Set[int] = _ADDRESS_RECORD_TYPES.copy() records: Set[DNSRecord] = set() for dns_address in self.dns_addresses(override_ttl, IPVersion.All, created): - seen_types.add(dns_address.type) + missing_types.discard(dns_address.type) records.add(dns_address) - missing_types: Set[int] = _ADDRESS_RECORD_TYPES - seen_types if missing_types: assert self.server is not None, "Service server must be set for NSEC record." records.add(self.dns_nsec(list(missing_types), override_ttl, created)) @@ -616,7 +624,7 @@ def set_server_if_missing(self) -> None: This function is for backwards compatibility. """ if self.server is None: - self.server = self.name + self.server = self._name self.server_key = self.server.lower() def load_from_cache(self, zc: 'Zeroconf', now: Optional[float] = None) -> bool: @@ -627,10 +635,10 @@ def load_from_cache(self, zc: 'Zeroconf', now: Optional[float] = None) -> bool: if not now: now = current_time_millis() original_server_key = self.server_key - cached_srv_record = zc.cache.get_by_details(self.name, _TYPE_SRV, _CLASS_IN) + cached_srv_record = zc.cache.get_by_details(self._name, _TYPE_SRV, _CLASS_IN) if cached_srv_record: self._process_record_threadsafe(zc, cached_srv_record, now) - cached_txt_record = zc.cache.get_by_details(self.name, _TYPE_TXT, _CLASS_IN) + cached_txt_record = zc.cache.get_by_details(self._name, _TYPE_TXT, _CLASS_IN) if cached_txt_record: self._process_record_threadsafe(zc, cached_txt_record, now) if original_server_key == self.server_key: @@ -732,10 +740,13 @@ def generate_request_query( ) -> DNSOutgoing: """Generate the request query.""" out = DNSOutgoing(_FLAGS_QR_QUERY) - out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_SRV, _CLASS_IN) - out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_TXT, _CLASS_IN) - out.add_question_or_all_cache(zc.cache, now, self.server or self.name, _TYPE_A, _CLASS_IN) - out.add_question_or_all_cache(zc.cache, now, self.server or self.name, _TYPE_AAAA, _CLASS_IN) + name = self._name + server_or_name = self.server or name + cache = zc.cache + out.add_question_or_one_cache(cache, now, name, _TYPE_SRV, _CLASS_IN) + out.add_question_or_one_cache(cache, now, name, _TYPE_TXT, _CLASS_IN) + out.add_question_or_all_cache(cache, now, server_or_name, _TYPE_A, _CLASS_IN) + out.add_question_or_all_cache(cache, now, server_or_name, _TYPE_AAAA, _CLASS_IN) if question_type == DNSQuestionType.QU: for question in out.questions: question.unicast = True @@ -743,7 +754,7 @@ def generate_request_query( def __eq__(self, other: object) -> bool: """Tests equality of service name""" - return isinstance(other, ServiceInfo) and other.name == self.name + return isinstance(other, ServiceInfo) and other._name == self._name def __repr__(self) -> str: """String representation""" diff --git a/src/zeroconf/const.py b/src/zeroconf/const.py index f87c1336..ca199df5 100644 --- a/src/zeroconf/const.py +++ b/src/zeroconf/const.py @@ -84,6 +84,7 @@ _CLASS_ANY = 255 _CLASS_MASK = 0x7FFF _CLASS_UNIQUE = 0x8000 +_CLASS_IN_UNIQUE = _CLASS_IN | _CLASS_UNIQUE _TYPE_A = 1 _TYPE_NS = 2