Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: speed up record updates #1301

Merged
merged 5 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions build_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def build(setup_kwargs: Any) -> None:
"src/zeroconf/_dns.py",
"src/zeroconf/_cache.py",
"src/zeroconf/_history.py",
"src/zeroconf/_record_update.py",
"src/zeroconf/_listener.py",
"src/zeroconf/_protocol/incoming.py",
"src/zeroconf/_protocol/outgoing.py",
Expand Down
6 changes: 3 additions & 3 deletions src/zeroconf/_cache.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ cdef class DNSCache:
)
cpdef async_all_by_details(self, str name, object type_, object class_)

cpdef async_entries_with_name(self, str name)
cpdef cython.dict async_entries_with_name(self, str name)

cpdef async_entries_with_server(self, str name)
cpdef cython.dict async_entries_with_server(self, str name)

@cython.locals(
cached_entry=DNSRecord,
Expand All @@ -57,7 +57,7 @@ cdef class DNSCache:
records=cython.dict,
entry=DNSRecord,
)
cpdef get_all_by_details(self, str name, object type_, object class_)
cpdef cython.list get_all_by_details(self, str name, object type_, object class_)

@cython.locals(
store=cython.dict,
Expand Down
14 changes: 7 additions & 7 deletions src/zeroconf/_handlers/query_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def _has_mcast_within_one_quarter_ttl(self, record: DNSRecord) -> bool:
if TYPE_CHECKING:
record = cast(_UniqueRecordsType, record)
maybe_entry = self._cache.async_get_unique(record)
return bool(maybe_entry and maybe_entry.is_recent(self._now))
return bool(maybe_entry is not None and maybe_entry.is_recent(self._now) is True)

def _has_mcast_record_in_last_second(self, record: DNSRecord) -> bool:
"""Check if an answer was seen in the last second.
Expand All @@ -149,7 +149,7 @@ def _has_mcast_record_in_last_second(self, record: DNSRecord) -> bool:
if TYPE_CHECKING:
record = cast(_UniqueRecordsType, record)
maybe_entry = self._cache.async_get_unique(record)
return bool(maybe_entry and self._now - maybe_entry.created < _ONE_SECOND)
return bool(maybe_entry is not None and self._now - maybe_entry.created < _ONE_SECOND)


class QueryHandler:
Expand All @@ -174,7 +174,7 @@ def _add_service_type_enumeration_query_answers(
dns_pointer = DNSPointer(
_SERVICE_TYPE_ENUMERATION_NAME, _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL, stype, 0.0
)
if not known_answers.suppresses(dns_pointer):
if known_answers.suppresses(dns_pointer) is False:
answer_set[dns_pointer] = set()

def _add_pointer_answers(
Expand All @@ -185,7 +185,7 @@ def _add_pointer_answers(
# Add recommended additional answers according to
# https://tools.ietf.org/html/rfc6763#section-12.1.
dns_pointer = service._dns_pointer(None)
if known_answers.suppresses(dns_pointer):
if known_answers.suppresses(dns_pointer) is True:
continue
answer_set[dns_pointer] = {
service._dns_service(None),
Expand All @@ -208,7 +208,7 @@ def _add_address_answers(
seen_types.add(dns_address.type)
if dns_address.type != type_:
additionals.add(dns_address)
elif not known_answers.suppresses(dns_address):
elif known_answers.suppresses(dns_address) is False:
answers.append(dns_address)
missing_types: Set[int] = _ADDRESS_RECORD_TYPES - seen_types
if answers:
Expand Down Expand Up @@ -248,11 +248,11 @@ def _answer_question(
# Add recommended additional answers according to
# https://tools.ietf.org/html/rfc6763#section-12.2.
dns_service = service._dns_service(None)
if not known_answers.suppresses(dns_service):
if known_answers.suppresses(dns_service) is False:
answer_set[dns_service] = service._get_address_and_nsec_records(None)
if type_ in (_TYPE_TXT, _TYPE_ANY):
dns_text = service._dns_text(None)
if not known_answers.suppresses(dns_text):
if known_answers.suppresses(dns_text) is False:
answer_set[dns_text] = set()

return answer_set
Expand Down
10 changes: 10 additions & 0 deletions src/zeroconf/_record_update.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@

import cython

from ._dns cimport DNSRecord


cdef class RecordUpdate:

cdef public DNSRecord new
cdef public DNSRecord old
21 changes: 17 additions & 4 deletions src/zeroconf/_record_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,24 @@
USA
"""

from typing import NamedTuple, Optional
from typing import Optional

from ._dns import DNSRecord


class RecordUpdate(NamedTuple):
new: DNSRecord
old: Optional[DNSRecord]
class RecordUpdate:

__slots__ = ("new", "old")

def __init__(self, new: DNSRecord, old: Optional[DNSRecord] = None):
"""RecordUpdate represents a change in a DNS record."""
self.new = new
self.old = old

def __getitem__(self, index: int) -> Optional[DNSRecord]:
"""Get the new or old record."""
if index == 0:
return self.new
elif index == 1:
return self.old
raise IndexError(index)
7 changes: 5 additions & 2 deletions src/zeroconf/_services/browser.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import cython

from .._cache cimport DNSCache
from .._protocol.outgoing cimport DNSOutgoing, DNSPointer, DNSQuestion, DNSRecord
from .._record_update cimport RecordUpdate
from .._updates cimport RecordUpdateListener
from .._utils.time cimport current_time_millis, millis_to_seconds
from . cimport Signal, SignalRegistrationInterface
Expand All @@ -13,6 +14,7 @@ cdef object cached_possible_types
cdef cython.uint _EXPIRE_REFRESH_TIME_PERCENT
cdef cython.uint _TYPE_PTR
cdef object SERVICE_STATE_CHANGE_ADDED, SERVICE_STATE_CHANGE_REMOVED, SERVICE_STATE_CHANGE_UPDATED
cdef cython.set _ADDRESS_RECORD_TYPES

cdef class _DNSPointerOutgoingBucket:

Expand Down Expand Up @@ -43,6 +45,7 @@ cdef class _ServiceBrowserBase(RecordUpdateListener):

cdef public cython.set types
cdef public object zc
cdef DNSCache _cache
cdef object _loop
cdef public object addr
cdef public object port
Expand All @@ -60,10 +63,10 @@ cdef class _ServiceBrowserBase(RecordUpdateListener):

cpdef _enqueue_callback(self, object state_change, object type_, object name)

@cython.locals(record=DNSRecord, cache=DNSCache, service=DNSRecord, pointer=DNSPointer)
@cython.locals(record_update=RecordUpdate, record=DNSRecord, cache=DNSCache, service=DNSRecord, pointer=DNSPointer)
cpdef async_update_records(self, object zc, cython.float now, cython.list records)

cpdef _names_matching_types(self, object types)
cpdef cython.list _names_matching_types(self, object types)

cpdef reschedule_type(self, object type_, object now, object next_time)

Expand Down
15 changes: 8 additions & 7 deletions src/zeroconf/_services/browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ class _ServiceBrowserBase(RecordUpdateListener):
__slots__ = (
'types',
'zc',
'_cache',
'_loop',
'addr',
'port',
Expand Down Expand Up @@ -345,6 +346,7 @@ def __init__(
# Will generate BadTypeInNameException on a bad name
service_type_name(check_type_, strict=False)
self.zc = zc
self._cache = zc.cache
assert zc.loop is not None
self._loop = zc.loop
self.addr = addr
Expand Down Expand Up @@ -421,8 +423,8 @@ def async_update_records(self, zc: 'Zeroconf', now: float_, records: List[Record
This method will be run in the event loop.
"""
for record_update in records:
record = record_update[0]
old_record = record_update[1]
record = record_update.new
old_record = record_update.old
record_type = record.type

if record_type is _TYPE_PTR:
Expand All @@ -440,15 +442,14 @@ def async_update_records(self, zc: 'Zeroconf', now: float_, records: List[Record
continue

# If its expired or already exists in the cache it cannot be updated.
if old_record or record.is_expired(now) is True:
if old_record is not None or record.is_expired(now) is True:
continue

if record_type in _ADDRESS_RECORD_TYPES:
cache = self.zc.cache
cache = self._cache
names = {service.name for service in cache.async_entries_with_server(record.name)}
# Iterate through the DNSCache and callback any services that use this address
for type_, name in self._names_matching_types(
{service.name for service in cache.async_entries_with_server(record.name)}
):
for type_, name in self._names_matching_types(names):
self._enqueue_callback(SERVICE_STATE_CHANGE_UPDATED, type_, name)
continue

Expand Down
8 changes: 6 additions & 2 deletions src/zeroconf/_services/info.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import cython
from .._cache cimport DNSCache
from .._dns cimport DNSAddress, DNSNsec, DNSPointer, DNSRecord, DNSService, DNSText
from .._protocol.outgoing cimport DNSOutgoing
from .._record_update cimport RecordUpdate
from .._updates cimport RecordUpdateListener
from .._utils.time cimport current_time_millis

Expand Down Expand Up @@ -56,7 +57,7 @@ cdef class ServiceInfo(RecordUpdateListener):
cdef public cython.list _dns_address_cache
cdef public cython.set _get_address_and_nsec_records_cache

@cython.locals(cache=DNSCache)
@cython.locals(record_update=RecordUpdate, update=bint, cache=DNSCache)
cpdef async_update_records(self, object zc, cython.float now, cython.list records)

@cython.locals(cache=DNSCache)
Expand All @@ -76,7 +77,7 @@ cdef class ServiceInfo(RecordUpdateListener):
dns_text_record=DNSText,
dns_address_record=DNSAddress
)
cdef _process_record_threadsafe(self, object zc, DNSRecord record, cython.float now)
cdef bint _process_record_threadsafe(self, object zc, DNSRecord record, cython.float now)

@cython.locals(cache=DNSCache)
cdef cython.list _get_address_records_from_cache_by_type(self, object zc, object _type)
Expand Down Expand Up @@ -109,3 +110,6 @@ cdef class ServiceInfo(RecordUpdateListener):
cdef cython.set _get_address_and_nsec_records(self, object override_ttl)

cpdef async_clear_cache(self)

@cython.locals(cache=DNSCache)
cdef _generate_request_query(self, object zc, object now, object question_type)
12 changes: 6 additions & 6 deletions src/zeroconf/_services/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def _get_ip_addresses_from_cache_lifo(
"""Set IPv6 addresses from the cache."""
address_list: List[Union[IPv4Address, IPv6Address]] = []
for record in self._get_address_records_from_cache_by_type(zc, type):
if record.is_expired(now):
if record.is_expired(now) is True:
continue
ip_addr = _cached_ip_addresses_wrapper(record.address)
if ip_addr is not None:
Expand Down Expand Up @@ -463,7 +463,7 @@ def _process_record_threadsafe(self, zc: 'Zeroconf', record: DNSRecord, now: flo

Returns True if a new record was added.
"""
if record.is_expired(now):
if record.is_expired(now) is True:
return False

record_key = record.key
Expand Down Expand Up @@ -779,7 +779,7 @@ async def async_request(

now = current_time_millis()

if self._load_from_cache(zc, now):
if self._load_from_cache(zc, now) is True:
return True

if TYPE_CHECKING:
Expand All @@ -795,7 +795,7 @@ async def async_request(
if last <= now:
return False
if next_ <= now:
out = self.generate_request_query(
out = self._generate_request_query(
zc,
now,
question_type or DNS_QUESTION_TYPE_QU if first_request else DNS_QUESTION_TYPE_QM,
Expand All @@ -815,8 +815,8 @@ async def async_request(

return True

def generate_request_query(
self, zc: 'Zeroconf', now: float_, question_type: Optional[DNSQuestionType] = None
def _generate_request_query(
self, zc: 'Zeroconf', now: float_, question_type: DNSQuestionType
) -> DNSOutgoing:
"""Generate the request query."""
out = DNSOutgoing(_FLAGS_QR_QUERY)
Expand Down
2 changes: 1 addition & 1 deletion src/zeroconf/_updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def async_update_records(self, zc: 'Zeroconf', now: float_, records: List[Record
This method will be run in the event loop.
"""
for record in records:
self.update_record(zc, now, record[0])
self.update_record(zc, now, record.new)

def async_update_records_complete(self) -> None:
"""Called when a record update has completed for all handlers.
Expand Down
14 changes: 14 additions & 0 deletions tests/test_updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import zeroconf as r
from zeroconf import Zeroconf, const
from zeroconf._record_update import RecordUpdate
from zeroconf._services.browser import ServiceBrowser
from zeroconf._services.info import ServiceInfo

Expand Down Expand Up @@ -87,3 +88,16 @@ def on_service_state_change(zeroconf, service_type, state_change, name):
zc.remove_listener(listener)

zc.close()


def test_record_update_compat():
"""Test a RecordUpdate can fetch by index."""
new = r.DNSPointer('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 'new')
old = r.DNSPointer('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 'old')
update = RecordUpdate(new, old)
assert update[0] == new
assert update[1] == old
with pytest.raises(IndexError):
update[2]
assert update.new == new
assert update.old == old