Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improve performance responding to queries #1217

Merged
merged 2 commits into from
Aug 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
73 changes: 42 additions & 31 deletions src/zeroconf/_services/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from ..const import (
_ADDRESS_RECORD_TYPES,
_CLASS_IN,
_CLASS_UNIQUE,
_CLASS_IN_UNIQUE,
_DNS_HOST_TTL,
_DNS_OTHER_TTL,
_FLAGS_QR_QUERY,
Expand Down Expand Up @@ -388,7 +388,7 @@ def _unpack_text_into_properties(self) -> None:

def get_name(self) -> str:
"""Name accessor"""
return self.name[: len(self.name) - len(self.type) - 1]
return self._name[: len(self._name) - len(self.type) - 1]

def _get_ip_addresses_from_cache_lifo(
self, zc: 'Zeroconf', now: float, type: int
Expand All @@ -409,15 +409,21 @@ def _get_ip_addresses_from_cache_lifo(

def _set_ipv6_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None:
"""Set IPv6 addresses from the cache."""
self._ipv6_addresses = cast(
"List[IPv6Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA)
)
if TYPE_CHECKING:
self._ipv6_addresses = cast(
"List[IPv6Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA)
)
else:
self._ipv6_addresses = self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA)

def _set_ipv4_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None:
"""Set IPv4 addresses from the cache."""
self._ipv4_addresses = cast(
"List[IPv4Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A)
)
if TYPE_CHECKING:
self._ipv4_addresses = cast(
"List[IPv4Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A)
)
else:
self._ipv4_addresses = self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A)

def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) -> None:
"""Updates service information from a DNS record.
Expand Down Expand Up @@ -523,9 +529,9 @@ def dns_addresses(
created: Optional[float] = None,
) -> List[DNSAddress]:
"""Return matching DNSAddress from ServiceInfo."""
name = self.server or self.name
name = self.server or self._name
ttl = override_ttl if override_ttl is not None else self.host_ttl
class_ = _CLASS_IN | _CLASS_UNIQUE
class_ = _CLASS_IN_UNIQUE
version_value = version.value
return [
DNSAddress(
Expand All @@ -546,30 +552,33 @@ def dns_pointer(self, override_ttl: Optional[int] = None, created: Optional[floa
_TYPE_PTR,
_CLASS_IN,
override_ttl if override_ttl is not None else self.other_ttl,
self.name,
self._name,
created,
)

def dns_service(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSService:
"""Return DNSService from ServiceInfo."""
port = self.port
if TYPE_CHECKING:
assert isinstance(port, int)
return DNSService(
self.name,
self._name,
_TYPE_SRV,
_CLASS_IN | _CLASS_UNIQUE,
_CLASS_IN_UNIQUE,
override_ttl if override_ttl is not None else self.host_ttl,
self.priority,
self.weight,
cast(int, self.port),
self.server or self.name,
port,
self.server or self._name,
created,
)

def dns_text(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSText:
"""Return DNSText from ServiceInfo."""
return DNSText(
self.name,
self._name,
_TYPE_TXT,
_CLASS_IN | _CLASS_UNIQUE,
_CLASS_IN_UNIQUE,
override_ttl if override_ttl is not None else self.other_ttl,
self.text,
created,
Expand All @@ -580,11 +589,11 @@ def dns_nsec(
) -> DNSNsec:
"""Return DNSNsec from ServiceInfo."""
return DNSNsec(
self.name,
self._name,
_TYPE_NSEC,
_CLASS_IN | _CLASS_UNIQUE,
_CLASS_IN_UNIQUE,
override_ttl if override_ttl is not None else self.host_ttl,
self.name,
self._name,
missing_types,
created,
)
Expand All @@ -593,12 +602,11 @@ def get_address_and_nsec_records(
self, override_ttl: Optional[int] = None, created: Optional[float] = None
) -> Set[DNSRecord]:
"""Build a set of address records and NSEC records for non-present record types."""
seen_types: Set[int] = set()
missing_types: Set[int] = _ADDRESS_RECORD_TYPES.copy()
records: Set[DNSRecord] = set()
for dns_address in self.dns_addresses(override_ttl, IPVersion.All, created):
seen_types.add(dns_address.type)
missing_types.discard(dns_address.type)
records.add(dns_address)
missing_types: Set[int] = _ADDRESS_RECORD_TYPES - seen_types
if missing_types:
assert self.server is not None, "Service server must be set for NSEC record."
records.add(self.dns_nsec(list(missing_types), override_ttl, created))
Expand All @@ -616,7 +624,7 @@ def set_server_if_missing(self) -> None:
This function is for backwards compatibility.
"""
if self.server is None:
self.server = self.name
self.server = self._name
self.server_key = self.server.lower()

def load_from_cache(self, zc: 'Zeroconf', now: Optional[float] = None) -> bool:
Expand All @@ -627,10 +635,10 @@ def load_from_cache(self, zc: 'Zeroconf', now: Optional[float] = None) -> bool:
if not now:
now = current_time_millis()
original_server_key = self.server_key
cached_srv_record = zc.cache.get_by_details(self.name, _TYPE_SRV, _CLASS_IN)
cached_srv_record = zc.cache.get_by_details(self._name, _TYPE_SRV, _CLASS_IN)
if cached_srv_record:
self._process_record_threadsafe(zc, cached_srv_record, now)
cached_txt_record = zc.cache.get_by_details(self.name, _TYPE_TXT, _CLASS_IN)
cached_txt_record = zc.cache.get_by_details(self._name, _TYPE_TXT, _CLASS_IN)
if cached_txt_record:
self._process_record_threadsafe(zc, cached_txt_record, now)
if original_server_key == self.server_key:
Expand Down Expand Up @@ -732,18 +740,21 @@ def generate_request_query(
) -> DNSOutgoing:
"""Generate the request query."""
out = DNSOutgoing(_FLAGS_QR_QUERY)
out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_SRV, _CLASS_IN)
out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_TXT, _CLASS_IN)
out.add_question_or_all_cache(zc.cache, now, self.server or self.name, _TYPE_A, _CLASS_IN)
out.add_question_or_all_cache(zc.cache, now, self.server or self.name, _TYPE_AAAA, _CLASS_IN)
name = self._name
server_or_name = self.server or name
cache = zc.cache
out.add_question_or_one_cache(cache, now, name, _TYPE_SRV, _CLASS_IN)
out.add_question_or_one_cache(cache, now, name, _TYPE_TXT, _CLASS_IN)
out.add_question_or_all_cache(cache, now, server_or_name, _TYPE_A, _CLASS_IN)
out.add_question_or_all_cache(cache, now, server_or_name, _TYPE_AAAA, _CLASS_IN)
if question_type == DNSQuestionType.QU:
for question in out.questions:
question.unicast = True
return out

def __eq__(self, other: object) -> bool:
"""Tests equality of service name"""
return isinstance(other, ServiceInfo) and other.name == self.name
return isinstance(other, ServiceInfo) and other._name == self._name

def __repr__(self) -> str:
"""String representation"""
Expand Down
1 change: 1 addition & 0 deletions src/zeroconf/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
_CLASS_ANY = 255
_CLASS_MASK = 0x7FFF
_CLASS_UNIQUE = 0x8000
_CLASS_IN_UNIQUE = _CLASS_IN | _CLASS_UNIQUE

_TYPE_A = 1
_TYPE_NS = 2
Expand Down