Skip to content

Commit

Permalink
feat: improve performance responding to queries (#1217)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Aug 14, 2023
1 parent 844c554 commit 69b33be
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 31 deletions.
73 changes: 42 additions & 31 deletions src/zeroconf/_services/info.py
Expand Up @@ -46,7 +46,7 @@
from ..const import (
_ADDRESS_RECORD_TYPES,
_CLASS_IN,
_CLASS_UNIQUE,
_CLASS_IN_UNIQUE,
_DNS_HOST_TTL,
_DNS_OTHER_TTL,
_FLAGS_QR_QUERY,
Expand Down Expand Up @@ -388,7 +388,7 @@ def _unpack_text_into_properties(self) -> None:

def get_name(self) -> str:
"""Name accessor"""
return self.name[: len(self.name) - len(self.type) - 1]
return self._name[: len(self._name) - len(self.type) - 1]

def _get_ip_addresses_from_cache_lifo(
self, zc: 'Zeroconf', now: float, type: int
Expand All @@ -409,15 +409,21 @@ def _get_ip_addresses_from_cache_lifo(

def _set_ipv6_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None:
"""Set IPv6 addresses from the cache."""
self._ipv6_addresses = cast(
"List[IPv6Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA)
)
if TYPE_CHECKING:
self._ipv6_addresses = cast(
"List[IPv6Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA)
)
else:
self._ipv6_addresses = self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA)

def _set_ipv4_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None:
"""Set IPv4 addresses from the cache."""
self._ipv4_addresses = cast(
"List[IPv4Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A)
)
if TYPE_CHECKING:
self._ipv4_addresses = cast(
"List[IPv4Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A)
)
else:
self._ipv4_addresses = self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A)

def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) -> None:
"""Updates service information from a DNS record.
Expand Down Expand Up @@ -523,9 +529,9 @@ def dns_addresses(
created: Optional[float] = None,
) -> List[DNSAddress]:
"""Return matching DNSAddress from ServiceInfo."""
name = self.server or self.name
name = self.server or self._name
ttl = override_ttl if override_ttl is not None else self.host_ttl
class_ = _CLASS_IN | _CLASS_UNIQUE
class_ = _CLASS_IN_UNIQUE
version_value = version.value
return [
DNSAddress(
Expand All @@ -546,30 +552,33 @@ def dns_pointer(self, override_ttl: Optional[int] = None, created: Optional[floa
_TYPE_PTR,
_CLASS_IN,
override_ttl if override_ttl is not None else self.other_ttl,
self.name,
self._name,
created,
)

def dns_service(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSService:
"""Return DNSService from ServiceInfo."""
port = self.port
if TYPE_CHECKING:
assert isinstance(port, int)
return DNSService(
self.name,
self._name,
_TYPE_SRV,
_CLASS_IN | _CLASS_UNIQUE,
_CLASS_IN_UNIQUE,
override_ttl if override_ttl is not None else self.host_ttl,
self.priority,
self.weight,
cast(int, self.port),
self.server or self.name,
port,
self.server or self._name,
created,
)

def dns_text(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSText:
"""Return DNSText from ServiceInfo."""
return DNSText(
self.name,
self._name,
_TYPE_TXT,
_CLASS_IN | _CLASS_UNIQUE,
_CLASS_IN_UNIQUE,
override_ttl if override_ttl is not None else self.other_ttl,
self.text,
created,
Expand All @@ -580,11 +589,11 @@ def dns_nsec(
) -> DNSNsec:
"""Return DNSNsec from ServiceInfo."""
return DNSNsec(
self.name,
self._name,
_TYPE_NSEC,
_CLASS_IN | _CLASS_UNIQUE,
_CLASS_IN_UNIQUE,
override_ttl if override_ttl is not None else self.host_ttl,
self.name,
self._name,
missing_types,
created,
)
Expand All @@ -593,12 +602,11 @@ def get_address_and_nsec_records(
self, override_ttl: Optional[int] = None, created: Optional[float] = None
) -> Set[DNSRecord]:
"""Build a set of address records and NSEC records for non-present record types."""
seen_types: Set[int] = set()
missing_types: Set[int] = _ADDRESS_RECORD_TYPES.copy()
records: Set[DNSRecord] = set()
for dns_address in self.dns_addresses(override_ttl, IPVersion.All, created):
seen_types.add(dns_address.type)
missing_types.discard(dns_address.type)
records.add(dns_address)
missing_types: Set[int] = _ADDRESS_RECORD_TYPES - seen_types
if missing_types:
assert self.server is not None, "Service server must be set for NSEC record."
records.add(self.dns_nsec(list(missing_types), override_ttl, created))
Expand All @@ -616,7 +624,7 @@ def set_server_if_missing(self) -> None:
This function is for backwards compatibility.
"""
if self.server is None:
self.server = self.name
self.server = self._name
self.server_key = self.server.lower()

def load_from_cache(self, zc: 'Zeroconf', now: Optional[float] = None) -> bool:
Expand All @@ -627,10 +635,10 @@ def load_from_cache(self, zc: 'Zeroconf', now: Optional[float] = None) -> bool:
if not now:
now = current_time_millis()
original_server_key = self.server_key
cached_srv_record = zc.cache.get_by_details(self.name, _TYPE_SRV, _CLASS_IN)
cached_srv_record = zc.cache.get_by_details(self._name, _TYPE_SRV, _CLASS_IN)
if cached_srv_record:
self._process_record_threadsafe(zc, cached_srv_record, now)
cached_txt_record = zc.cache.get_by_details(self.name, _TYPE_TXT, _CLASS_IN)
cached_txt_record = zc.cache.get_by_details(self._name, _TYPE_TXT, _CLASS_IN)
if cached_txt_record:
self._process_record_threadsafe(zc, cached_txt_record, now)
if original_server_key == self.server_key:
Expand Down Expand Up @@ -732,18 +740,21 @@ def generate_request_query(
) -> DNSOutgoing:
"""Generate the request query."""
out = DNSOutgoing(_FLAGS_QR_QUERY)
out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_SRV, _CLASS_IN)
out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_TXT, _CLASS_IN)
out.add_question_or_all_cache(zc.cache, now, self.server or self.name, _TYPE_A, _CLASS_IN)
out.add_question_or_all_cache(zc.cache, now, self.server or self.name, _TYPE_AAAA, _CLASS_IN)
name = self._name
server_or_name = self.server or name
cache = zc.cache
out.add_question_or_one_cache(cache, now, name, _TYPE_SRV, _CLASS_IN)
out.add_question_or_one_cache(cache, now, name, _TYPE_TXT, _CLASS_IN)
out.add_question_or_all_cache(cache, now, server_or_name, _TYPE_A, _CLASS_IN)
out.add_question_or_all_cache(cache, now, server_or_name, _TYPE_AAAA, _CLASS_IN)
if question_type == DNSQuestionType.QU:
for question in out.questions:
question.unicast = True
return out

def __eq__(self, other: object) -> bool:
"""Tests equality of service name"""
return isinstance(other, ServiceInfo) and other.name == self.name
return isinstance(other, ServiceInfo) and other._name == self._name

def __repr__(self) -> str:
"""String representation"""
Expand Down
1 change: 1 addition & 0 deletions src/zeroconf/const.py
Expand Up @@ -84,6 +84,7 @@
_CLASS_ANY = 255
_CLASS_MASK = 0x7FFF
_CLASS_UNIQUE = 0x8000
_CLASS_IN_UNIQUE = _CLASS_IN | _CLASS_UNIQUE

_TYPE_A = 1
_TYPE_NS = 2
Expand Down

0 comments on commit 69b33be

Please sign in to comment.