diff --git a/src/zeroconf/_protocol/incoming.pxd b/src/zeroconf/_protocol/incoming.pxd index d5620692..348cc667 100644 --- a/src/zeroconf/_protocol/incoming.pxd +++ b/src/zeroconf/_protocol/incoming.pxd @@ -54,7 +54,7 @@ cdef class DNSIncoming: cdef unsigned int _data_len cdef public cython.dict name_cache cdef public cython.list questions - cdef object _answers + cdef cython.list _answers cdef public object id cdef public cython.uint num_questions cdef public cython.uint num_answers @@ -78,8 +78,6 @@ cdef class DNSIncoming: cdef _initial_parse(self) - cdef _unpack(self, object unpacker, object length) - @cython.locals( end=cython.uint, length=cython.uint @@ -88,9 +86,6 @@ cdef class DNSIncoming: cdef _read_questions(self) - @cython.locals( - length=cython.uint - ) cdef bytes _read_character_string(self) cdef _read_string(self, unsigned int length) diff --git a/src/zeroconf/_protocol/incoming.py b/src/zeroconf/_protocol/incoming.py index 9505f466..32fdc47f 100644 --- a/src/zeroconf/_protocol/incoming.py +++ b/src/zeroconf/_protocol/incoming.py @@ -22,7 +22,7 @@ import struct import sys -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Set, Tuple, Union from .._dns import ( DNSAddress, @@ -194,10 +194,6 @@ def __repr__(self) -> str: ] ) - def _unpack(self, unpacker: Callable[[bytes, int], tuple], length: int) -> tuple: - self.offset += length - return unpacker(self.data, self.offset - length) - def _read_header(self) -> None: """Reads header portion of packet""" ( @@ -207,7 +203,8 @@ def _read_header(self) -> None: self.num_answers, self.num_authorities, self.num_additionals, - ) = self._unpack(UNPACK_6H, 12) + ) = UNPACK_6H(self.data) + self.offset += 12 def _read_questions(self) -> None: """Reads questions section of packet""" @@ -264,18 +261,24 @@ def _read_record( ) -> Optional[DNSRecord]: """Read known records types and skip unknown ones.""" if type_ == _TYPE_A: - return DNSAddress(domain, type_, class_, ttl, self._read_string(4), created=self.now) + dns_address = DNSAddress(domain, type_, class_, ttl, self._read_string(4)) + dns_address.created = self.now + return dns_address if type_ in (_TYPE_CNAME, _TYPE_PTR): return DNSPointer(domain, type_, class_, ttl, self._read_name(), self.now) if type_ == _TYPE_TXT: return DNSText(domain, type_, class_, ttl, self._read_string(length), self.now) if type_ == _TYPE_SRV: + priority, weight, port = UNPACK_3H(self.data, self.offset) + self.offset += 6 return DNSService( domain, type_, class_, ttl, - *cast(Tuple[int, int, int], self._unpack(UNPACK_3H, 6)), + priority, + weight, + port, self._read_name(), self.now, ) @@ -285,14 +288,15 @@ def _read_record( type_, class_, ttl, - self._read_character_string().decode('utf-8'), - self._read_character_string().decode('utf-8'), + self._read_character_string().decode('utf-8', 'replace'), + self._read_character_string().decode('utf-8', 'replace'), self.now, ) if type_ == _TYPE_AAAA: - return DNSAddress( - domain, type_, class_, ttl, self._read_string(16), created=self.now, scope_id=self.scope_id - ) + dns_address = DNSAddress(domain, type_, class_, ttl, self._read_string(16)) + dns_address.created = self.now + dns_address.scope_id = self.scope_id + return dns_address if type_ == _TYPE_NSEC: name_start = self.offset return DNSNsec( @@ -384,4 +388,4 @@ def _decode_labels_at_offset(self, off: int, labels: List[str], seen_pointers: S ) return off + DNS_COMPRESSION_POINTER_LEN - raise IncomingDecodeError("Corrupt packet received while decoding name from {self.source}") + raise IncomingDecodeError(f"Corrupt packet received while decoding name from {self.source}")