Skip to content

Commit

Permalink
feat: improve performance of constructing outgoing queries (#1267)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Sep 11, 2023
1 parent aed6391 commit 00c439a
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 20 deletions.
4 changes: 3 additions & 1 deletion src/zeroconf/_handlers/answers.pxd
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

import cython

from .._dns cimport DNSRecord
from .._protocol.outgoing cimport DNSOutgoing


Expand All @@ -10,7 +11,8 @@ cdef object NAME_GETTER
cpdef construct_outgoing_multicast_answers(cython.dict answers)

cpdef construct_outgoing_unicast_answers(
cython.dict answers, object ucast_source, cython.list questions, object id_
cython.dict answers, bint ucast_source, cython.list questions, object id_
)

@cython.locals(answer=DNSRecord, additionals=cython.set, additional=DNSRecord)
cdef _add_answers_additionals(DNSOutgoing out, cython.dict answers)
3 changes: 2 additions & 1 deletion src/zeroconf/_handlers/answers.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def _add_answers_additionals(out: DNSOutgoing, answers: _AnswerWithAdditionalsTy
# overall size of the outgoing response via name compression
for answer in sorted(answers, key=NAME_GETTER):
out.add_answer_at_time(answer, 0)
for additional in answers[answer]:
additionals = answers[answer]
for additional in additionals:
if additional not in sending:
out.add_additional_answer(additional)
sending.add(additional)
25 changes: 20 additions & 5 deletions src/zeroconf/_protocol/outgoing.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,25 @@ cdef object PACK_BYTE
cdef object PACK_SHORT
cdef object PACK_LONG

cdef object STATE_INIT
cdef object STATE_FINISHED

cdef object LOGGING_IS_ENABLED_FOR
cdef object LOGGING_DEBUG

cdef cython.tuple BYTE_TABLE

cdef class DNSOutgoing:

cdef public unsigned int flags
cdef public object finished
cdef public bint finished
cdef public object id
cdef public bint multicast
cdef public cython.list packets_data
cdef public cython.dict names
cdef public cython.list data
cdef public unsigned int size
cdef public object allow_long
cdef public bint allow_long
cdef public object state
cdef public cython.list questions
cdef public cython.list answers
Expand All @@ -48,18 +56,21 @@ cdef class DNSOutgoing:

cdef _write_int(self, object value)

cdef _write_question(self, DNSQuestion question)
cdef cython.bint _write_question(self, DNSQuestion question)

@cython.locals(
d=cython.bytes,
data_view=cython.list,
length=cython.uint
)
cdef _write_record(self, DNSRecord record, object now)
cdef cython.bint _write_record(self, DNSRecord record, object now)

cdef _write_record_class(self, DNSEntry record)

cdef _check_data_limit_or_rollback(self, object start_data_length, object start_size)
@cython.locals(
start_size_int=object
)
cdef cython.bint _check_data_limit_or_rollback(self, cython.uint start_data_length, cython.uint start_size)

cdef _write_questions_from_offset(self, object questions_offset)

Expand All @@ -74,6 +85,9 @@ cdef class DNSOutgoing:
@cython.locals(
labels=cython.list,
label=cython.str,
index=cython.uint,
start_size=cython.uint,
name_length=cython.uint,
)
cpdef write_name(self, cython.str name)

Expand Down Expand Up @@ -103,6 +117,7 @@ cdef class DNSOutgoing:

cpdef add_answer(self, DNSIncoming inp, DNSRecord record)

@cython.locals(now_float=cython.float)
cpdef add_answer_at_time(self, DNSRecord record, object now)

cpdef add_authorative_answer(self, DNSPointer record)
Expand Down
40 changes: 27 additions & 13 deletions src/zeroconf/_protocol/outgoing.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,21 @@
PACK_SHORT = Struct('>H').pack
PACK_LONG = Struct('>L').pack

BYTE_TABLE = tuple(PACK_BYTE(i) for i in range(256))


class State(enum.Enum):
init = 0
finished = 1


STATE_INIT = State.init
STATE_FINISHED = State.finished

LOGGING_IS_ENABLED_FOR = log.isEnabledFor
LOGGING_DEBUG = logging.DEBUG


class DNSOutgoing:

"""Object representation of an outgoing packet"""
Expand Down Expand Up @@ -93,7 +102,7 @@ def __init__(self, flags: int, multicast: bool = True, id_: int = 0) -> None:
self.size: int = _DNS_PACKET_HEADER_LEN
self.allow_long: bool = True

self.state = State.init
self.state = STATE_INIT

self.questions: List[DNSQuestion] = []
self.answers: List[Tuple[DNSRecord, float]] = []
Expand Down Expand Up @@ -137,7 +146,8 @@ def add_answer(self, inp: DNSIncoming, record: DNSRecord) -> None:

def add_answer_at_time(self, record: Optional[DNSRecord], now: Union[float, int]) -> None:
"""Adds an answer if it does not expire by a certain time"""
if record is not None and (now == 0 or not record.is_expired(now)):
now_float = now
if record is not None and (now_float == 0 or not record.is_expired(now_float)):
self.answers.append((record, now))

def add_authorative_answer(self, record: DNSPointer) -> None:
Expand Down Expand Up @@ -207,7 +217,7 @@ def add_question_or_all_cache(

def _write_byte(self, value: int_) -> None:
"""Writes a single byte to the packet"""
self.data.append(PACK_BYTE(value))
self.data.append(BYTE_TABLE[value])
self.size += 1

def _insert_short_at_start(self, value: int_) -> None:
Expand Down Expand Up @@ -267,7 +277,7 @@ def write_name(self, name: str_) -> None:
"""

# split name into each label
name_length = None
name_length = 0
if name.endswith('.'):
name = name[: len(name) - 1]
labels = name.split('.')
Expand All @@ -276,14 +286,14 @@ def write_name(self, name: str_) -> None:
start_size = self.size
for count in range(len(labels)):
label = name if count == 0 else '.'.join(labels[count:])
index = self.names.get(label)
index = self.names.get(label, 0)
if index:
# If part of the name already exists in the packet,
# create a pointer to it
self._write_byte((index >> 8) | 0xC0)
self._write_byte(index & 0xFF)
return
if name_length is None:
if name_length == 0:
name_length = len(name.encode('utf-8'))
self.names[label] = start_size + name_length - len(label.encode('utf-8'))
self._write_utf(labels[count])
Expand All @@ -293,7 +303,8 @@ def write_name(self, name: str_) -> None:

def _write_question(self, question: DNSQuestion_) -> bool:
"""Writes a question to the packet"""
start_data_length, start_size = len(self.data), self.size
start_data_length = len(self.data)
start_size = self.size
self.write_name(question.name)
self.write_short(question.type)
self._write_record_class(question)
Expand All @@ -314,7 +325,8 @@ def _write_record(self, record: DNSRecord_, now: float_) -> bool:
"""Writes a record (answer, authoritative answer, additional) to
the packet. Returns True on success, or False if we did not
because the packet because the record does not fit."""
start_data_length, start_size = len(self.data), self.size
start_data_length = len(self.data)
start_size = self.size
self.write_name(record.name)
self.write_short(record.type)
self._write_record_class(record)
Expand All @@ -339,11 +351,13 @@ def _check_data_limit_or_rollback(self, start_data_length: int_, start_size: int
if self.size <= len_limit:
return True

log.debug("Reached data limit (size=%d) > (limit=%d) - rolling back", self.size, len_limit)
if LOGGING_IS_ENABLED_FOR(LOGGING_DEBUG): # pragma: no branch
log.debug("Reached data limit (size=%d) > (limit=%d) - rolling back", self.size, len_limit)
del self.data[start_data_length:]
self.size = start_size

rollback_names = [name for name, idx in self.names.items() if idx >= start_size]
start_size_int = start_size
rollback_names = [name for name, idx in self.names.items() if idx >= start_size_int]
for name in rollback_names:
del self.names[name]
return False
Expand Down Expand Up @@ -395,7 +409,7 @@ def packets(self) -> List[bytes]:
return self._packets()

def _packets(self) -> List[bytes]:
if self.state == State.finished:
if self.state == STATE_FINISHED:
return self.packets_data

questions_offset = 0
Expand All @@ -404,7 +418,7 @@ def _packets(self) -> List[bytes]:
additional_offset = 0
# we have to at least write out the question
first_time = True
debug_enable = log.isEnabledFor(logging.DEBUG)
debug_enable = LOGGING_IS_ENABLED_FOR(LOGGING_DEBUG)

while first_time or self._has_more_to_add(
questions_offset, answer_offset, authority_offset, additional_offset
Expand Down Expand Up @@ -476,5 +490,5 @@ def _packets(self) -> List[bytes]:
):
log.warning("packets() made no progress adding records; returning")
break
self.state = State.finished
self.state = STATE_FINISHED
return self.packets_data

0 comments on commit 00c439a

Please sign in to comment.