diff --git a/src/zeroconf/_dns.py b/src/zeroconf/_dns.py index 3e9f074a..4ca429a8 100644 --- a/src/zeroconf/_dns.py +++ b/src/zeroconf/_dns.py @@ -244,7 +244,6 @@ def __init__( class_: int, ttl: int, address: bytes, - *, scope_id: Optional[int] = None, created: Optional[float] = None, ) -> None: diff --git a/src/zeroconf/_protocol/incoming.pxd b/src/zeroconf/_protocol/incoming.pxd index 3bfc57f2..07ae6e78 100644 --- a/src/zeroconf/_protocol/incoming.pxd +++ b/src/zeroconf/_protocol/incoming.pxd @@ -21,11 +21,6 @@ cdef cython.uint _FLAGS_TC cdef cython.uint _FLAGS_QR_QUERY cdef cython.uint _FLAGS_QR_RESPONSE -cdef object UNPACK_3H -cdef object UNPACK_6H -cdef object UNPACK_HH -cdef object UNPACK_HHiH - cdef object DECODE_EXCEPTIONS cdef object IncomingDecodeError @@ -62,7 +57,6 @@ cdef class DNSIncoming: cdef cython.uint _num_additionals cdef public bint valid cdef public object now - cdef cython.float _now_float cdef public object scope_id cdef public object source cdef bint _has_qu_question @@ -81,49 +75,53 @@ cdef class DNSIncoming: cpdef bint is_response(self) @cython.locals( - off=cython.uint, - label_idx=cython.uint, - length=cython.uint, - link=cython.uint, - link_data=cython.uint, + off="unsigned int", + label_idx="unsigned int", + length="unsigned int", + link="unsigned int", + link_data="unsigned int", link_py_int=object, linked_labels=cython.list ) - cdef cython.uint _decode_labels_at_offset(self, unsigned int off, cython.list labels, cython.set seen_pointers) + cdef unsigned int _decode_labels_at_offset(self, unsigned int off, cython.list labels, cython.set seen_pointers) + @cython.locals(offset="unsigned int") cdef _read_header(self) cdef _initial_parse(self) @cython.locals( - end=cython.uint, - length=cython.uint + end="unsigned int", + length="unsigned int", + offset="unsigned int" ) cdef _read_others(self) + @cython.locals(offset="unsigned int") cdef _read_questions(self) @cython.locals( - length=cython.uint, + length="unsigned int", ) cdef str _read_character_string(self) cdef bytes _read_string(self, unsigned int length) @cython.locals( - name_start=cython.uint + name_start="unsigned int", + offset="unsigned int" ) - cdef _read_record(self, object domain, unsigned int type_, object class_, object ttl, unsigned int length) + cdef _read_record(self, object domain, unsigned int type_, unsigned int class_, unsigned int ttl, unsigned int length) @cython.locals( - offset=cython.uint, - offset_plus_one=cython.uint, - offset_plus_two=cython.uint, - window=cython.uint, - bit=cython.uint, - byte=cython.uint, - i=cython.uint, - bitmap_length=cython.uint, + offset="unsigned int", + offset_plus_one="unsigned int", + offset_plus_two="unsigned int", + window="unsigned int", + bit="unsigned int", + byte="unsigned int", + i="unsigned int", + bitmap_length="unsigned int", ) cdef _read_bitmap(self, unsigned int end) diff --git a/src/zeroconf/_protocol/incoming.py b/src/zeroconf/_protocol/incoming.py index fd5fafb6..9e208b63 100644 --- a/src/zeroconf/_protocol/incoming.py +++ b/src/zeroconf/_protocol/incoming.py @@ -60,10 +60,6 @@ DECODE_EXCEPTIONS = (IndexError, struct.error, IncomingDecodeError) -UNPACK_3H = struct.Struct(b'!3H').unpack_from -UNPACK_6H = struct.Struct(b'!6H').unpack_from -UNPACK_HH = struct.Struct(b'!HH').unpack_from -UNPACK_HHiH = struct.Struct(b'!HHiH').unpack_from _seen_logs: Dict[str, Union[int, tuple]] = {} _str = str @@ -90,7 +86,6 @@ class DNSIncoming: '_num_additionals', 'valid', 'now', - '_now_float', 'scope_id', 'source', '_has_qu_question', @@ -120,7 +115,6 @@ def __init__( self.valid = False self._did_read_others = False self.now = now or current_time_millis() - self._now_float = self.now self.source = source self.scope_id = scope_id self._has_qu_question = False @@ -230,23 +224,28 @@ def __repr__(self) -> str: def _read_header(self) -> None: """Reads header portion of packet""" - ( - self.id, - self.flags, - self._num_questions, - self._num_answers, - self._num_authorities, - self._num_additionals, - ) = UNPACK_6H(self.data) + view = self.view + offset = self.offset self.offset += 12 + # The header has 6 unsigned shorts in network order + self.id = view[offset] << 8 | view[offset + 1] + self.flags = view[offset + 2] << 8 | view[offset + 3] + self._num_questions = view[offset + 4] << 8 | view[offset + 5] + self._num_answers = view[offset + 6] << 8 | view[offset + 7] + self._num_authorities = view[offset + 8] << 8 | view[offset + 9] + self._num_additionals = view[offset + 10] << 8 | view[offset + 11] def _read_questions(self) -> None: """Reads questions section of packet""" + view = self.view questions = self._questions for _ in range(self._num_questions): name = self._read_name() - type_, class_ = UNPACK_HH(self.data, self.offset) + offset = self.offset self.offset += 4 + # The question has 2 unsigned shorts in network order + type_ = view[offset] << 8 | view[offset + 1] + class_ = view[offset + 2] << 8 | view[offset + 3] question = DNSQuestion(name, type_, class_) if question.unique: # QU questions use the same bit as unique self._has_qu_question = True @@ -270,11 +269,18 @@ def _read_others(self) -> None: """Reads the answers, authorities and additionals section of the packet""" self._did_read_others = True + view = self.view n = self._num_answers + self._num_authorities + self._num_additionals for _ in range(n): domain = self._read_name() - type_, class_, ttl, length = UNPACK_HHiH(self.data, self.offset) + offset = self.offset self.offset += 10 + # type_, class_ and length are unsigned shorts in network order + # ttl is an unsigned long in network order https://www.rfc-editor.org/errata/eid2130 + type_ = view[offset] << 8 | view[offset + 1] + class_ = view[offset + 2] << 8 | view[offset + 3] + ttl = view[offset + 4] << 24 | view[offset + 5] << 16 | view[offset + 6] << 8 | view[offset + 7] + length = view[offset + 8] << 8 | view[offset + 9] end = self.offset + length rec = None try: @@ -300,16 +306,19 @@ def _read_record( ) -> Optional[DNSRecord]: """Read known records types and skip unknown ones.""" if type_ == _TYPE_A: - dns_address = DNSAddress(domain, type_, class_, ttl, self._read_string(4)) - dns_address.created = self._now_float - return dns_address + return DNSAddress(domain, type_, class_, ttl, self._read_string(4), None, self.now) if type_ in (_TYPE_CNAME, _TYPE_PTR): return DNSPointer(domain, type_, class_, ttl, self._read_name(), self.now) if type_ == _TYPE_TXT: return DNSText(domain, type_, class_, ttl, self._read_string(length), self.now) if type_ == _TYPE_SRV: - priority, weight, port = UNPACK_3H(self.data, self.offset) + view = self.view + offset = self.offset self.offset += 6 + # The SRV record has 3 unsigned shorts in network order + priority = view[offset] << 8 | view[offset + 1] + weight = view[offset + 2] << 8 | view[offset + 3] + port = view[offset + 4] << 8 | view[offset + 5] return DNSService( domain, type_, @@ -332,10 +341,7 @@ def _read_record( self.now, ) if type_ == _TYPE_AAAA: - dns_address = DNSAddress(domain, type_, class_, ttl, self._read_string(16)) - dns_address.created = self._now_float - dns_address.scope_id = self.scope_id - return dns_address + return DNSAddress(domain, type_, class_, ttl, self._read_string(16), self.scope_id, self.now) if type_ == _TYPE_NSEC: name_start = self.offset return DNSNsec( @@ -356,12 +362,13 @@ def _read_record( def _read_bitmap(self, end: _int) -> List[int]: """Reads an NSEC bitmap from the packet.""" rdtypes = [] + view = self.view while self.offset < end: offset = self.offset offset_plus_one = offset + 1 offset_plus_two = offset + 2 - window = self.view[offset] - bitmap_length = self.view[offset_plus_one] + window = view[offset] + bitmap_length = view[offset_plus_one] bitmap_end = offset_plus_two + bitmap_length for i, byte in enumerate(self.data[offset_plus_two:bitmap_end]): for bit in range(0, 8): @@ -386,8 +393,9 @@ def _read_name(self) -> str: def _decode_labels_at_offset(self, off: _int, labels: List[str], seen_pointers: Set[int]) -> int: # This is a tight loop that is called frequently, small optimizations can make a difference. + view = self.view while off < self._data_len: - length = self.view[off] + length = view[off] if length == 0: return off + DNS_COMPRESSION_HEADER_LEN @@ -403,7 +411,7 @@ def _decode_labels_at_offset(self, off: _int, labels: List[str], seen_pointers: ) # We have a DNS compression pointer - link_data = self.view[off + 1] + link_data = view[off + 1] link = (length & 0x3F) * 256 + link_data link_py_int = link if link > self._data_len: