diff --git a/src/zeroconf/_protocol/incoming.pxd b/src/zeroconf/_protocol/incoming.pxd index 604b1e30..ebd09a0e 100644 --- a/src/zeroconf/_protocol/incoming.pxd +++ b/src/zeroconf/_protocol/incoming.pxd @@ -7,8 +7,6 @@ cdef cython.uint MAX_DNS_LABELS cdef cython.uint DNS_COMPRESSION_POINTER_LEN cdef cython.uint MAX_NAME_LENGTH -cdef object current_time_millis - cdef cython.uint _TYPE_A cdef cython.uint _TYPE_CNAME cdef cython.uint _TYPE_PTR @@ -43,6 +41,7 @@ from .._dns cimport ( DNSService, DNSText, ) +from .._utils.time cimport current_time_millis cdef class DNSIncoming: @@ -62,6 +61,7 @@ cdef class DNSIncoming: cdef public cython.uint num_additionals cdef public object valid cdef public object now + cdef cython.float _now_float cdef public object scope_id cdef public object source @@ -79,7 +79,9 @@ cdef class DNSIncoming: label_idx=cython.uint, length=cython.uint, link=cython.uint, - link_data=cython.uint + link_data=cython.uint, + link_py_int=object, + linked_labels=cython.list ) cdef _decode_labels_at_offset(self, unsigned int off, cython.list labels, cython.set seen_pointers) @@ -95,9 +97,12 @@ cdef class DNSIncoming: cdef _read_questions(self) - cdef bytes _read_character_string(self) + @cython.locals( + length=cython.uint, + ) + cdef str _read_character_string(self) - cdef _read_string(self, unsigned int length) + cdef bytes _read_string(self, unsigned int length) @cython.locals( name_start=cython.uint diff --git a/src/zeroconf/_protocol/incoming.py b/src/zeroconf/_protocol/incoming.py index 352a6141..87d25816 100644 --- a/src/zeroconf/_protocol/incoming.py +++ b/src/zeroconf/_protocol/incoming.py @@ -89,6 +89,7 @@ class DNSIncoming: 'num_additionals', 'valid', 'now', + '_now_float', 'scope_id', 'source', ) @@ -116,6 +117,7 @@ def __init__( self.valid = False self._did_read_others = False self.now = now or current_time_millis() + self._now_float = self.now self.source = source self.scope_id = scope_id try: @@ -226,11 +228,13 @@ def _read_questions(self) -> None: question = DNSQuestion(name, type_, class_) self.questions.append(question) - def _read_character_string(self) -> bytes: + def _read_character_string(self) -> str: """Reads a character string from the packet""" length = self.data[self.offset] self.offset += 1 - return self._read_string(length) + info = self.data[self.offset : self.offset + length].decode('utf-8', 'replace') + self.offset += length + return info def _read_string(self, length: _int) -> bytes: """Reads a string of a given length from the packet""" @@ -273,7 +277,7 @@ def _read_record( """Read known records types and skip unknown ones.""" if type_ == _TYPE_A: dns_address = DNSAddress(domain, type_, class_, ttl, self._read_string(4)) - dns_address.created = self.now + dns_address.created = self._now_float return dns_address if type_ in (_TYPE_CNAME, _TYPE_PTR): return DNSPointer(domain, type_, class_, ttl, self._read_name(), self.now) @@ -299,13 +303,13 @@ def _read_record( type_, class_, ttl, - self._read_character_string().decode('utf-8', 'replace'), - self._read_character_string().decode('utf-8', 'replace'), + self._read_character_string(), + self._read_character_string(), self.now, ) if type_ == _TYPE_AAAA: dns_address = DNSAddress(domain, type_, class_, ttl, self._read_string(16)) - dns_address.created = self.now + dns_address.created = self._now_float dns_address.scope_id = self.scope_id return dns_address if type_ == _TYPE_NSEC: @@ -377,7 +381,7 @@ def _decode_labels_at_offset(self, off: _int, labels: List[str], seen_pointers: # We have a DNS compression pointer link_data = self.data[off + 1] link = (length & 0x3F) * 256 + link_data - lint_int = int(link) + link_py_int = link if link > self._data_len: raise IncomingDecodeError( f"DNS compression pointer at {off} points to {link} beyond packet from {self.source}" @@ -386,16 +390,16 @@ def _decode_labels_at_offset(self, off: _int, labels: List[str], seen_pointers: raise IncomingDecodeError( f"DNS compression pointer at {off} points to itself from {self.source}" ) - if lint_int in seen_pointers: + if link_py_int in seen_pointers: raise IncomingDecodeError( f"DNS compression pointer at {off} was seen again from {self.source}" ) - linked_labels = self.name_cache.get(lint_int) + linked_labels = self.name_cache.get(link_py_int) if not linked_labels: linked_labels = [] - seen_pointers.add(lint_int) + seen_pointers.add(link_py_int) self._decode_labels_at_offset(link, linked_labels, seen_pointers) - self.name_cache[lint_int] = linked_labels + self.name_cache[link_py_int] = linked_labels labels.extend(linked_labels) if len(labels) > MAX_DNS_LABELS: raise IncomingDecodeError(