Skip to content

Commit

Permalink
feat: speed up incoming packet reader (#1314)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Nov 15, 2023
1 parent bfe4c24 commit 0d60b61
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 54 deletions.
1 change: 0 additions & 1 deletion src/zeroconf/_dns.py
Expand Up @@ -244,7 +244,6 @@ def __init__(
class_: int,
ttl: int,
address: bytes,
*,
scope_id: Optional[int] = None,
created: Optional[float] = None,
) -> None:
Expand Down
48 changes: 23 additions & 25 deletions src/zeroconf/_protocol/incoming.pxd
Expand Up @@ -21,11 +21,6 @@ cdef cython.uint _FLAGS_TC
cdef cython.uint _FLAGS_QR_QUERY
cdef cython.uint _FLAGS_QR_RESPONSE

cdef object UNPACK_3H
cdef object UNPACK_6H
cdef object UNPACK_HH
cdef object UNPACK_HHiH

cdef object DECODE_EXCEPTIONS

cdef object IncomingDecodeError
Expand Down Expand Up @@ -62,7 +57,6 @@ cdef class DNSIncoming:
cdef cython.uint _num_additionals
cdef public bint valid
cdef public object now
cdef cython.float _now_float
cdef public object scope_id
cdef public object source
cdef bint _has_qu_question
Expand All @@ -81,49 +75,53 @@ cdef class DNSIncoming:
cpdef bint is_response(self)

@cython.locals(
off=cython.uint,
label_idx=cython.uint,
length=cython.uint,
link=cython.uint,
link_data=cython.uint,
off="unsigned int",
label_idx="unsigned int",
length="unsigned int",
link="unsigned int",
link_data="unsigned int",
link_py_int=object,
linked_labels=cython.list
)
cdef cython.uint _decode_labels_at_offset(self, unsigned int off, cython.list labels, cython.set seen_pointers)
cdef unsigned int _decode_labels_at_offset(self, unsigned int off, cython.list labels, cython.set seen_pointers)

@cython.locals(offset="unsigned int")
cdef _read_header(self)

cdef _initial_parse(self)

@cython.locals(
end=cython.uint,
length=cython.uint
end="unsigned int",
length="unsigned int",
offset="unsigned int"
)
cdef _read_others(self)

@cython.locals(offset="unsigned int")
cdef _read_questions(self)

@cython.locals(
length=cython.uint,
length="unsigned int",
)
cdef str _read_character_string(self)

cdef bytes _read_string(self, unsigned int length)

@cython.locals(
name_start=cython.uint
name_start="unsigned int",
offset="unsigned int"
)
cdef _read_record(self, object domain, unsigned int type_, object class_, object ttl, unsigned int length)
cdef _read_record(self, object domain, unsigned int type_, unsigned int class_, unsigned int ttl, unsigned int length)

@cython.locals(
offset=cython.uint,
offset_plus_one=cython.uint,
offset_plus_two=cython.uint,
window=cython.uint,
bit=cython.uint,
byte=cython.uint,
i=cython.uint,
bitmap_length=cython.uint,
offset="unsigned int",
offset_plus_one="unsigned int",
offset_plus_two="unsigned int",
window="unsigned int",
bit="unsigned int",
byte="unsigned int",
i="unsigned int",
bitmap_length="unsigned int",
)
cdef _read_bitmap(self, unsigned int end)

Expand Down
64 changes: 36 additions & 28 deletions src/zeroconf/_protocol/incoming.py
Expand Up @@ -60,10 +60,6 @@

DECODE_EXCEPTIONS = (IndexError, struct.error, IncomingDecodeError)

UNPACK_3H = struct.Struct(b'!3H').unpack_from
UNPACK_6H = struct.Struct(b'!6H').unpack_from
UNPACK_HH = struct.Struct(b'!HH').unpack_from
UNPACK_HHiH = struct.Struct(b'!HHiH').unpack_from

_seen_logs: Dict[str, Union[int, tuple]] = {}
_str = str
Expand All @@ -90,7 +86,6 @@ class DNSIncoming:
'_num_additionals',
'valid',
'now',
'_now_float',
'scope_id',
'source',
'_has_qu_question',
Expand Down Expand Up @@ -120,7 +115,6 @@ def __init__(
self.valid = False
self._did_read_others = False
self.now = now or current_time_millis()
self._now_float = self.now
self.source = source
self.scope_id = scope_id
self._has_qu_question = False
Expand Down Expand Up @@ -230,23 +224,28 @@ def __repr__(self) -> str:

def _read_header(self) -> None:
"""Reads header portion of packet"""
(
self.id,
self.flags,
self._num_questions,
self._num_answers,
self._num_authorities,
self._num_additionals,
) = UNPACK_6H(self.data)
view = self.view
offset = self.offset
self.offset += 12
# The header has 6 unsigned shorts in network order
self.id = view[offset] << 8 | view[offset + 1]
self.flags = view[offset + 2] << 8 | view[offset + 3]
self._num_questions = view[offset + 4] << 8 | view[offset + 5]
self._num_answers = view[offset + 6] << 8 | view[offset + 7]
self._num_authorities = view[offset + 8] << 8 | view[offset + 9]
self._num_additionals = view[offset + 10] << 8 | view[offset + 11]

def _read_questions(self) -> None:
"""Reads questions section of packet"""
view = self.view
questions = self._questions
for _ in range(self._num_questions):
name = self._read_name()
type_, class_ = UNPACK_HH(self.data, self.offset)
offset = self.offset
self.offset += 4
# The question has 2 unsigned shorts in network order
type_ = view[offset] << 8 | view[offset + 1]
class_ = view[offset + 2] << 8 | view[offset + 3]
question = DNSQuestion(name, type_, class_)
if question.unique: # QU questions use the same bit as unique
self._has_qu_question = True
Expand All @@ -270,11 +269,18 @@ def _read_others(self) -> None:
"""Reads the answers, authorities and additionals section of the
packet"""
self._did_read_others = True
view = self.view
n = self._num_answers + self._num_authorities + self._num_additionals
for _ in range(n):
domain = self._read_name()
type_, class_, ttl, length = UNPACK_HHiH(self.data, self.offset)
offset = self.offset
self.offset += 10
# type_, class_ and length are unsigned shorts in network order
# ttl is an unsigned long in network order https://www.rfc-editor.org/errata/eid2130
type_ = view[offset] << 8 | view[offset + 1]
class_ = view[offset + 2] << 8 | view[offset + 3]
ttl = view[offset + 4] << 24 | view[offset + 5] << 16 | view[offset + 6] << 8 | view[offset + 7]
length = view[offset + 8] << 8 | view[offset + 9]
end = self.offset + length
rec = None
try:
Expand All @@ -300,16 +306,19 @@ def _read_record(
) -> Optional[DNSRecord]:
"""Read known records types and skip unknown ones."""
if type_ == _TYPE_A:
dns_address = DNSAddress(domain, type_, class_, ttl, self._read_string(4))
dns_address.created = self._now_float
return dns_address
return DNSAddress(domain, type_, class_, ttl, self._read_string(4), None, self.now)
if type_ in (_TYPE_CNAME, _TYPE_PTR):
return DNSPointer(domain, type_, class_, ttl, self._read_name(), self.now)
if type_ == _TYPE_TXT:
return DNSText(domain, type_, class_, ttl, self._read_string(length), self.now)
if type_ == _TYPE_SRV:
priority, weight, port = UNPACK_3H(self.data, self.offset)
view = self.view
offset = self.offset
self.offset += 6
# The SRV record has 3 unsigned shorts in network order
priority = view[offset] << 8 | view[offset + 1]
weight = view[offset + 2] << 8 | view[offset + 3]
port = view[offset + 4] << 8 | view[offset + 5]
return DNSService(
domain,
type_,
Expand All @@ -332,10 +341,7 @@ def _read_record(
self.now,
)
if type_ == _TYPE_AAAA:
dns_address = DNSAddress(domain, type_, class_, ttl, self._read_string(16))
dns_address.created = self._now_float
dns_address.scope_id = self.scope_id
return dns_address
return DNSAddress(domain, type_, class_, ttl, self._read_string(16), self.scope_id, self.now)
if type_ == _TYPE_NSEC:
name_start = self.offset
return DNSNsec(
Expand All @@ -356,12 +362,13 @@ def _read_record(
def _read_bitmap(self, end: _int) -> List[int]:
"""Reads an NSEC bitmap from the packet."""
rdtypes = []
view = self.view
while self.offset < end:
offset = self.offset
offset_plus_one = offset + 1
offset_plus_two = offset + 2
window = self.view[offset]
bitmap_length = self.view[offset_plus_one]
window = view[offset]
bitmap_length = view[offset_plus_one]
bitmap_end = offset_plus_two + bitmap_length
for i, byte in enumerate(self.data[offset_plus_two:bitmap_end]):
for bit in range(0, 8):
Expand All @@ -386,8 +393,9 @@ def _read_name(self) -> str:

def _decode_labels_at_offset(self, off: _int, labels: List[str], seen_pointers: Set[int]) -> int:
# This is a tight loop that is called frequently, small optimizations can make a difference.
view = self.view
while off < self._data_len:
length = self.view[off]
length = view[off]
if length == 0:
return off + DNS_COMPRESSION_HEADER_LEN

Expand All @@ -403,7 +411,7 @@ def _decode_labels_at_offset(self, off: _int, labels: List[str], seen_pointers:
)

# We have a DNS compression pointer
link_data = self.view[off + 1]
link_data = view[off + 1]
link = (length & 0x3F) * 256 + link_data
link_py_int = link
if link > self._data_len:
Expand Down

0 comments on commit 0d60b61

Please sign in to comment.