Skip to content

Commit

Permalink
feat: speed up outgoing packet writer (#1313)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Nov 13, 2023
1 parent 9caeabb commit 55cf4cc
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 36 deletions.
3 changes: 2 additions & 1 deletion bench/outgoing.py
Expand Up @@ -158,9 +158,10 @@ def generate_packets() -> DNSOutgoing:


def make_outgoing_message() -> None:
out.packets()
out.state = State.init.value
out.finished = False
out.packets()
out._reset_for_next_packet()


count = 100000
Expand Down
47 changes: 30 additions & 17 deletions src/zeroconf/_protocol/outgoing.pxd
Expand Up @@ -15,8 +15,11 @@ cdef cython.uint _FLAGS_TC
cdef cython.uint _MAX_MSG_ABSOLUTE
cdef cython.uint _MAX_MSG_TYPICAL


cdef bint TYPE_CHECKING

cdef unsigned int SHORT_CACHE_MAX

cdef object PACK_BYTE
cdef object PACK_SHORT
cdef object PACK_LONG
Expand All @@ -28,6 +31,7 @@ cdef object LOGGING_IS_ENABLED_FOR
cdef object LOGGING_DEBUG

cdef cython.tuple BYTE_TABLE
cdef cython.tuple SHORT_LOOKUP

cdef class DNSOutgoing:

Expand All @@ -46,13 +50,15 @@ cdef class DNSOutgoing:
cdef public cython.list authorities
cdef public cython.list additionals

cdef _reset_for_next_packet(self)
cpdef _reset_for_next_packet(self)

cdef _write_byte(self, object value)
cdef _write_byte(self, cython.uint value)

cdef _insert_short_at_start(self, object value)
cdef void _insert_short_at_start(self, unsigned int value)

cdef _replace_short(self, object index, object value)
cdef _replace_short(self, cython.uint index, cython.uint value)

cdef _get_short(self, cython.uint value)

cdef _write_int(self, object value)

Expand All @@ -61,24 +67,29 @@ cdef class DNSOutgoing:
@cython.locals(
d=cython.bytes,
data_view=cython.list,
index=cython.uint,
length=cython.uint
)
cdef cython.bint _write_record(self, DNSRecord record, object now)

@cython.locals(class_=cython.uint)
cdef _write_record_class(self, DNSEntry record)

@cython.locals(
start_size_int=object
)
cdef cython.bint _check_data_limit_or_rollback(self, cython.uint start_data_length, cython.uint start_size)

cdef _write_questions_from_offset(self, object questions_offset)
@cython.locals(questions_written=cython.uint)
cdef cython.uint _write_questions_from_offset(self, unsigned int questions_offset)

cdef _write_answers_from_offset(self, object answer_offset)
@cython.locals(answers_written=cython.uint)
cdef cython.uint _write_answers_from_offset(self, unsigned int answer_offset)

cdef _write_records_from_offset(self, cython.list records, object offset)
@cython.locals(records_written=cython.uint)
cdef cython.uint _write_records_from_offset(self, cython.list records, unsigned int offset)

cdef _has_more_to_add(self, object questions_offset, object answer_offset, object authority_offset, object additional_offset)
cdef bint _has_more_to_add(self, unsigned int questions_offset, unsigned int answer_offset, unsigned int authority_offset, unsigned int additional_offset)

cdef _write_ttl(self, DNSRecord record, object now)

Expand All @@ -93,23 +104,25 @@ cdef class DNSOutgoing:

cdef _write_link_to_name(self, unsigned int index)

cpdef write_short(self, object value)
cpdef write_short(self, cython.uint value)

cpdef write_string(self, cython.bytes value)

@cython.locals(utfstr=bytes)
cpdef _write_utf(self, cython.str value)

@cython.locals(
debug_enable=bint,
made_progress=bint,
questions_offset=object,
answer_offset=object,
authority_offset=object,
additional_offset=object,
questions_written=object,
answers_written=object,
authorities_written=object,
additionals_written=object,
has_more_to_add=bint,
questions_offset="unsigned int",
answer_offset="unsigned int",
authority_offset="unsigned int",
additional_offset="unsigned int",
questions_written="unsigned int",
answers_written="unsigned int",
authorities_written="unsigned int",
additionals_written="unsigned int",
)
cpdef packets(self)

Expand Down
46 changes: 28 additions & 18 deletions src/zeroconf/_protocol/outgoing.py
Expand Up @@ -53,7 +53,10 @@
PACK_SHORT = Struct('>H').pack
PACK_LONG = Struct('>L').pack

SHORT_CACHE_MAX = 128

BYTE_TABLE = tuple(PACK_BYTE(i) for i in range(256))
SHORT_LOOKUP = tuple(PACK_SHORT(i) for i in range(SHORT_CACHE_MAX))


class State(enum.Enum):
Expand Down Expand Up @@ -220,17 +223,21 @@ def _write_byte(self, value: int_) -> None:
self.data.append(BYTE_TABLE[value])
self.size += 1

def _get_short(self, value: int_) -> bytes:
"""Convert an unsigned short to 2 bytes."""
return SHORT_LOOKUP[value] if value < SHORT_CACHE_MAX else PACK_SHORT(value)

def _insert_short_at_start(self, value: int_) -> None:
"""Inserts an unsigned short at the start of the packet"""
self.data.insert(0, PACK_SHORT(value))
self.data.insert(0, self._get_short(value))

def _replace_short(self, index: int_, value: int_) -> None:
"""Replaces an unsigned short in a certain position in the packet"""
self.data[index] = PACK_SHORT(value)
self.data[index] = self._get_short(value)

def write_short(self, value: int_) -> None:
"""Writes an unsigned short to the packet"""
self.data.append(PACK_SHORT(value))
self.data.append(self._get_short(value))
self.size += 2

def _write_int(self, value: Union[float, int]) -> None:
Expand Down Expand Up @@ -323,10 +330,11 @@ def _write_question(self, question: DNSQuestion_) -> bool:

def _write_record_class(self, record: Union[DNSQuestion_, DNSRecord_]) -> None:
"""Write out the record class including the unique/unicast (QU) bit."""
if record.unique and self.multicast:
self.write_short(record.class_ | _CLASS_UNIQUE)
class_ = record.class_
if record.unique is True and self.multicast is True:
self.write_short(class_ | _CLASS_UNIQUE)
else:
self.write_short(record.class_)
self.write_short(class_)

def _write_ttl(self, record: DNSRecord_, now: float_) -> None:
"""Write out the record ttl."""
Expand Down Expand Up @@ -417,21 +425,20 @@ def packets(self) -> List[bytes]:
will be written out to a single oversized packet no more than
_MAX_MSG_ABSOLUTE in length (and hence will be subject to IP
fragmentation potentially)."""
packets_data = self.packets_data

if self.state == STATE_FINISHED:
return self.packets_data
return packets_data

questions_offset = 0
answer_offset = 0
authority_offset = 0
additional_offset = 0
# we have to at least write out the question
first_time = True
debug_enable = LOGGING_IS_ENABLED_FOR(LOGGING_DEBUG)
debug_enable = LOGGING_IS_ENABLED_FOR(LOGGING_DEBUG) is True
has_more_to_add = True

while first_time or self._has_more_to_add(
questions_offset, answer_offset, authority_offset, additional_offset
):
first_time = False
while has_more_to_add:
if debug_enable:
log.debug(
"offsets = questions=%d, answers=%d, authorities=%d, additionals=%d",
Expand Down Expand Up @@ -473,9 +480,11 @@ def packets(self) -> List[bytes]:
additional_offset,
)

if self.is_query() and self._has_more_to_add(
has_more_to_add = self._has_more_to_add(
questions_offset, answer_offset, authority_offset, additional_offset
):
)

if has_more_to_add and self.is_query():
# https://datatracker.ietf.org/doc/html/rfc6762#section-7.2
if debug_enable: # pragma: no branch
log.debug("Setting TC flag")
Expand All @@ -488,7 +497,7 @@ def packets(self) -> List[bytes]:
else:
self._insert_short_at_start(self.id)

self.packets_data.append(b''.join(self.data))
packets_data.append(b''.join(self.data))

if not made_progress:
# Generating an empty packet is not a desirable outcome, but currently
Expand All @@ -498,7 +507,8 @@ def packets(self) -> List[bytes]:
log.warning("packets() made no progress adding records; returning")
break

self._reset_for_next_packet()
if has_more_to_add:
self._reset_for_next_packet()

self.state = STATE_FINISHED
return self.packets_data
return packets_data

0 comments on commit 55cf4cc

Please sign in to comment.