Skip to content

Commit

Permalink
feat: speed up decoding incoming packets (#1256)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Sep 6, 2023
1 parent aebabd9 commit ac081cf
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 16 deletions.
15 changes: 10 additions & 5 deletions src/zeroconf/_protocol/incoming.pxd
Expand Up @@ -7,8 +7,6 @@ cdef cython.uint MAX_DNS_LABELS
cdef cython.uint DNS_COMPRESSION_POINTER_LEN
cdef cython.uint MAX_NAME_LENGTH

cdef object current_time_millis

cdef cython.uint _TYPE_A
cdef cython.uint _TYPE_CNAME
cdef cython.uint _TYPE_PTR
Expand Down Expand Up @@ -43,6 +41,7 @@ from .._dns cimport (
DNSService,
DNSText,
)
from .._utils.time cimport current_time_millis


cdef class DNSIncoming:
Expand All @@ -62,6 +61,7 @@ cdef class DNSIncoming:
cdef public cython.uint num_additionals
cdef public object valid
cdef public object now
cdef cython.float _now_float
cdef public object scope_id
cdef public object source

Expand All @@ -79,7 +79,9 @@ cdef class DNSIncoming:
label_idx=cython.uint,
length=cython.uint,
link=cython.uint,
link_data=cython.uint
link_data=cython.uint,
link_py_int=object,
linked_labels=cython.list
)
cdef _decode_labels_at_offset(self, unsigned int off, cython.list labels, cython.set seen_pointers)

Expand All @@ -95,9 +97,12 @@ cdef class DNSIncoming:

cdef _read_questions(self)

cdef bytes _read_character_string(self)
@cython.locals(
length=cython.uint,
)
cdef str _read_character_string(self)

cdef _read_string(self, unsigned int length)
cdef bytes _read_string(self, unsigned int length)

@cython.locals(
name_start=cython.uint
Expand Down
26 changes: 15 additions & 11 deletions src/zeroconf/_protocol/incoming.py
Expand Up @@ -89,6 +89,7 @@ class DNSIncoming:
'num_additionals',
'valid',
'now',
'_now_float',
'scope_id',
'source',
)
Expand Down Expand Up @@ -116,6 +117,7 @@ def __init__(
self.valid = False
self._did_read_others = False
self.now = now or current_time_millis()
self._now_float = self.now
self.source = source
self.scope_id = scope_id
try:
Expand Down Expand Up @@ -226,11 +228,13 @@ def _read_questions(self) -> None:
question = DNSQuestion(name, type_, class_)
self.questions.append(question)

def _read_character_string(self) -> bytes:
def _read_character_string(self) -> str:
"""Reads a character string from the packet"""
length = self.data[self.offset]
self.offset += 1
return self._read_string(length)
info = self.data[self.offset : self.offset + length].decode('utf-8', 'replace')
self.offset += length
return info

def _read_string(self, length: _int) -> bytes:
"""Reads a string of a given length from the packet"""
Expand Down Expand Up @@ -273,7 +277,7 @@ def _read_record(
"""Read known records types and skip unknown ones."""
if type_ == _TYPE_A:
dns_address = DNSAddress(domain, type_, class_, ttl, self._read_string(4))
dns_address.created = self.now
dns_address.created = self._now_float
return dns_address
if type_ in (_TYPE_CNAME, _TYPE_PTR):
return DNSPointer(domain, type_, class_, ttl, self._read_name(), self.now)
Expand All @@ -299,13 +303,13 @@ def _read_record(
type_,
class_,
ttl,
self._read_character_string().decode('utf-8', 'replace'),
self._read_character_string().decode('utf-8', 'replace'),
self._read_character_string(),
self._read_character_string(),
self.now,
)
if type_ == _TYPE_AAAA:
dns_address = DNSAddress(domain, type_, class_, ttl, self._read_string(16))
dns_address.created = self.now
dns_address.created = self._now_float
dns_address.scope_id = self.scope_id
return dns_address
if type_ == _TYPE_NSEC:
Expand Down Expand Up @@ -377,7 +381,7 @@ def _decode_labels_at_offset(self, off: _int, labels: List[str], seen_pointers:
# We have a DNS compression pointer
link_data = self.data[off + 1]
link = (length & 0x3F) * 256 + link_data
lint_int = int(link)
link_py_int = link
if link > self._data_len:
raise IncomingDecodeError(
f"DNS compression pointer at {off} points to {link} beyond packet from {self.source}"
Expand All @@ -386,16 +390,16 @@ def _decode_labels_at_offset(self, off: _int, labels: List[str], seen_pointers:
raise IncomingDecodeError(
f"DNS compression pointer at {off} points to itself from {self.source}"
)
if lint_int in seen_pointers:
if link_py_int in seen_pointers:
raise IncomingDecodeError(
f"DNS compression pointer at {off} was seen again from {self.source}"
)
linked_labels = self.name_cache.get(lint_int)
linked_labels = self.name_cache.get(link_py_int)
if not linked_labels:
linked_labels = []
seen_pointers.add(lint_int)
seen_pointers.add(link_py_int)
self._decode_labels_at_offset(link, linked_labels, seen_pointers)
self.name_cache[lint_int] = linked_labels
self.name_cache[link_py_int] = linked_labels
labels.extend(linked_labels)
if len(labels) > MAX_DNS_LABELS:
raise IncomingDecodeError(
Expand Down

0 comments on commit ac081cf

Please sign in to comment.