feat: speed up processing incoming records (#1216)
bdraco committed Aug 13, 2023
1 parent 8a9dc0b commit aff625d
Showing 3 changed files with 17 additions and 9 deletions.
4 changes: 4 additions & 0 deletions src/zeroconf/_dns.pxd
@@ -44,6 +44,8 @@ cdef class DNSRecord(DNSEntry):
)
cpdef suppressed_by(self, object msg)

+cpdef get_remaining_ttl(self, cython.float now)
+
cpdef get_expiration_time(self, cython.uint percent)

cpdef is_expired(self, cython.float now)
@@ -54,6 +56,8 @@ cdef class DNSRecord(DNSEntry):

cpdef reset_ttl(self, DNSRecord other)

+cpdef set_created_ttl(self, cython.float now, cython.float ttl)
+
cdef class DNSAddress(DNSRecord):

cdef public cython.int _hash
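
The two new .pxd declarations expose methods that already exist on DNSRecord in _dns.py; declaring them as cpdef lets Cython-compiled callers dispatch to them at the C level while they remain callable from ordinary Python. Below is a minimal pure-Python sketch of what those two methods compute, with a hypothetical class name and units inferred from the _dns.py hunks that follow (created and now are millisecond timestamps, ttl is in seconds, so the constant is assumed to be 1000):

# Minimal sketch, not the library's actual class.
_EXPIRE_FULL_TIME_MS = 1000  # assumed value: milliseconds per second of TTL

class RecordSketch:
    def __init__(self, created: float, ttl: float) -> None:
        self.created = created  # creation timestamp in milliseconds
        self.ttl = ttl          # time to live in seconds

    def get_remaining_ttl(self, now: float) -> float:
        """Remaining TTL in seconds, clamped at zero (mirrors the new _dns.py body)."""
        remain = (self.created + _EXPIRE_FULL_TIME_MS * self.ttl - now) / 1000.0
        return 0 if remain < 0 else remain

    def set_created_ttl(self, created: float, ttl: float) -> None:
        """Reset creation time and TTL, e.g. when a cached record is refreshed."""
        self.created = created
        self.ttl = ttl

record = RecordSketch(created=1_000_000.0, ttl=120.0)
print(record.get_remaining_ttl(now=1_030_000.0))  # 90.0 seconds remaining
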
7 changes: 4 additions & 3 deletions src/zeroconf/_dns.py
@@ -26,7 +26,7 @@

from ._exceptions import AbstractMethodException
from ._utils.net import _is_v6_address
-from ._utils.time import current_time_millis, millis_to_seconds
+from ._utils.time import current_time_millis
from .const import _CLASS_MASK, _CLASS_UNIQUE, _CLASSES, _TYPE_ANY, _TYPES

_LEN_BYTE = 1
@@ -193,7 +193,8 @@ def get_expiration_time(self, percent: _int) -> float:
# TODO: Switch to just int here
def get_remaining_ttl(self, now: _float) -> Union[int, float]:
"""Returns the remaining TTL in seconds."""
-return max(0, millis_to_seconds((self.created + (_EXPIRE_FULL_TIME_MS * self.ttl)) - now))
+remain = (self.created + (_EXPIRE_FULL_TIME_MS * self.ttl) - now) / 1000.0
+return 0 if remain < 0 else remain

def is_expired(self, now: _float) -> bool:
"""Returns true if this record has expired."""
@@ -212,7 +213,7 @@ def reset_ttl(self, other) -> None: # type: ignore[no-untyped-def]
another record."""
self.set_created_ttl(other.created, other.ttl)

-def set_created_ttl(self, created: float, ttl: Union[float, int]) -> None:
+def set_created_ttl(self, created: _float, ttl: Union[float, int]) -> None:
"""Set the created and ttl of a record."""
self.created = created
self.ttl = ttl
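
The _dns.py change drops the millis_to_seconds import and inlines the conversion: the old body paid for two extra Python-level calls per lookup (the helper plus max()), while the new body is plain arithmetic and a conditional expression, which also compiles to tighter code under Cython. The created: float to _float annotation change in set_created_ttl appears to align the signature with the module's _float alias used elsewhere in the file rather than change behavior. A rough, self-contained illustration of the call-overhead difference (not a rigorous benchmark; millis_to_seconds here is a stand-in with the obvious definition):

import timeit

def millis_to_seconds(millis: float) -> float:
    # Stand-in for the removed helper from ._utils.time.
    return millis / 1000.0

created, ttl, now = 1_000_000.0, 120.0, 1_030_000.0

def old_style() -> float:
    # Helper call plus max(): two extra function calls per invocation.
    return max(0, millis_to_seconds(created + 1000 * ttl - now))

def new_style() -> float:
    # Inline arithmetic and a conditional expression, as in the new code.
    remain = (created + 1000 * ttl - now) / 1000.0
    return 0 if remain < 0 else remain

print("old:", timeit.timeit(old_style, number=1_000_000))
print("new:", timeit.timeit(new_style, number=1_000_000))
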
15 changes: 9 additions & 6 deletions src/zeroconf/_handlers.py
@@ -408,6 +408,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
removes: Set[DNSRecord] = set()
now = msg.now
unique_types: Set[Tuple[str, int, int]] = set()
+cache = self.cache

for record in msg.answers:
# Protect zeroconf from records that can cause denial of service.
@@ -416,7 +417,9 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
# ServiceBrowsers generating excessive queries refresh queries.
# Apple uses a 15s minimum TTL, however we do not have the same
# level of rate limit and safe guards so we use 1/4 of the recommended value.
-if record.ttl and record.type == _TYPE_PTR and record.ttl < _DNS_PTR_MIN_TTL:
+record_type = record.type
+record_ttl = record.ttl
+if record_ttl and record_type == _TYPE_PTR and record_ttl < _DNS_PTR_MIN_TTL:
log.debug(
"Increasing effective ttl of %s to minimum of %s to protect against excessive refreshes.",
record,
@@ -425,12 +428,12 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
record.set_created_ttl(record.created, _DNS_PTR_MIN_TTL)

if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2
-unique_types.add((record.name, record.type, record.class_))
+unique_types.add((record.name, record_type, record.class_))

if TYPE_CHECKING:
record = cast(_UniqueRecordsType, record)

-maybe_entry = self.cache.async_get_unique(record)
+maybe_entry = cache.async_get_unique(record)
if not record.is_expired(now):
if maybe_entry is not None:
maybe_entry.reset_ttl(record)
@@ -447,7 +450,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
removes.add(record)

if unique_types:
-self.cache.async_mark_unique_records_older_than_1s_to_expire(unique_types, msg.answers, now)
+cache.async_mark_unique_records_older_than_1s_to_expire(unique_types, msg.answers, now)

if updates:
self.async_updates(now, updates)
@@ -468,12 +471,12 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
# processsed.
new = False
if other_adds or address_adds:
-new = self.cache.async_add_records(itertools.chain(address_adds, other_adds))
+new = cache.async_add_records(itertools.chain(address_adds, other_adds))
# Removes are processed last since
# ServiceInfo could generate an un-needed query
# because the data was not yet populated.
if removes:
-self.cache.async_remove_records(removes)
+cache.async_remove_records(removes)
if updates:
self.async_updates_complete(new)

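
The _handlers.py changes are all of one kind: values read repeatedly inside the per-record loop (self.cache, record.type, record.ttl) are bound to locals once and reused. In CPython every attribute access goes through the instance and class lookup machinery, while reading a local is a cheap indexed load, and the same hoisting pattern also helps the Cython-compiled build. A toy, self-contained illustration of the pattern, with hypothetical names rather than the zeroconf handler itself:

import timeit

class Handler:
    def __init__(self) -> None:
        self.cache: dict = {}

    def store_slow(self, records: list) -> None:
        for record in records:
            self.cache[record] = record  # attribute lookup on every iteration

    def store_fast(self, records: list) -> None:
        cache = self.cache  # looked up once, reused as a local
        for record in records:
            cache[record] = record

handler = Handler()
records = list(range(10_000))
print("self.cache per iteration:", timeit.timeit(lambda: handler.store_slow(records), number=200))
print("local alias             :", timeit.timeit(lambda: handler.store_fast(records), number=200))
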
