Skip to content

Commit

Permalink
feat: speed up processing incoming records (#1206)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Aug 2, 2023
1 parent 1310f12 commit 126849c
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 12 deletions.
26 changes: 21 additions & 5 deletions src/zeroconf/_dns.pxd
@@ -1,6 +1,8 @@

import cython

from ._protocol.incoming cimport DNSIncoming


cdef object _LEN_BYTE
cdef object _LEN_SHORT
Expand All @@ -9,9 +11,9 @@ cdef object _LEN_INT
cdef object _NAME_COMPRESSION_MIN_SIZE
cdef object _BASE_MAX_SIZE

cdef object _EXPIRE_FULL_TIME_MS
cdef object _EXPIRE_STALE_TIME_MS
cdef object _RECENT_TIME_MS
cdef cython.uint _EXPIRE_FULL_TIME_MS
cdef cython.uint _EXPIRE_STALE_TIME_MS
cdef cython.uint _RECENT_TIME_MS

cdef object _CLASS_UNIQUE
cdef object _CLASS_MASK
Expand All @@ -34,11 +36,25 @@ cdef class DNSQuestion(DNSEntry):

cdef class DNSRecord(DNSEntry):

cdef public object ttl
cdef public object created
cdef public cython.float ttl
cdef public cython.float created

cdef _suppressed_by_answer(self, DNSRecord answer)

@cython.locals(
answers=cython.list,
)
cpdef suppressed_by(self, DNSIncoming msg)

cpdef get_expiration_time(self, cython.uint percent)

cpdef is_expired(self, cython.float now)

cpdef is_stale(self, cython.float now)

cpdef is_recent(self, cython.float now)

cpdef reset_ttl(self, DNSRecord other)

cdef class DNSAddress(DNSRecord):

Expand Down
20 changes: 13 additions & 7 deletions src/zeroconf/_dns.py
Expand Up @@ -40,6 +40,8 @@
_EXPIRE_STALE_TIME_MS = 500
_RECENT_TIME_MS = 250

_float = float
_int = int

if TYPE_CHECKING:
from ._protocol.incoming import DNSIncoming
Expand Down Expand Up @@ -172,32 +174,36 @@ def __eq__(self, other: Any) -> bool: # pylint: disable=no-self-use
def suppressed_by(self, msg: 'DNSIncoming') -> bool:
"""Returns true if any answer in a message can suffice for the
information held in this record."""
return any(self._suppressed_by_answer(record) for record in msg.answers)
answers = msg.answers
for record in answers:
if self._suppressed_by_answer(record):
return True
return False

def _suppressed_by_answer(self, other) -> bool: # type: ignore[no-untyped-def]
def _suppressed_by_answer(self, other: 'DNSRecord') -> bool:
"""Returns true if another record has same name, type and class,
and if its TTL is at least half of this record's."""
return self == other and other.ttl > (self.ttl / 2)

def get_expiration_time(self, percent: int) -> float:
def get_expiration_time(self, percent: _int) -> float:
"""Returns the time at which this record will have expired
by a certain percentage."""
return self.created + (percent * self.ttl * 10)

# TODO: Switch to just int here
def get_remaining_ttl(self, now: float) -> Union[int, float]:
def get_remaining_ttl(self, now: _float) -> Union[int, float]:
"""Returns the remaining TTL in seconds."""
return max(0, millis_to_seconds((self.created + (_EXPIRE_FULL_TIME_MS * self.ttl)) - now))

def is_expired(self, now: float) -> bool:
def is_expired(self, now: _float) -> bool:
"""Returns true if this record has expired."""
return self.created + (_EXPIRE_FULL_TIME_MS * self.ttl) <= now

def is_stale(self, now: float) -> bool:
def is_stale(self, now: _float) -> bool:
"""Returns true if this record is at least half way expired."""
return self.created + (_EXPIRE_STALE_TIME_MS * self.ttl) <= now

def is_recent(self, now: float) -> bool:
def is_recent(self, now: _float) -> bool:
"""Returns true if the record more than one quarter of its TTL remaining."""
return self.created + (_RECENT_TIME_MS * self.ttl) > now

Expand Down

0 comments on commit 126849c

Please sign in to comment.