Skip to content

Commit

Permalink
feat: speed up responding to queries (#1275)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Sep 25, 2023
1 parent aa8fd1a commit 3c6b18c
Show file tree
Hide file tree
Showing 11 changed files with 93 additions and 81 deletions.
2 changes: 1 addition & 1 deletion src/zeroconf/_dns.pxd
Expand Up @@ -125,7 +125,7 @@ cdef class DNSNsec(DNSRecord):

cdef class DNSRRSet:

cdef cython.list _record_sets
cdef cython.list _records
cdef cython.dict _lookup

@cython.locals(other=DNSRecord)
Expand Down
25 changes: 11 additions & 14 deletions src/zeroconf/_dns.py
Expand Up @@ -174,7 +174,7 @@ def __eq__(self, other: Any) -> bool: # pylint: disable=no-self-use
def suppressed_by(self, msg: 'DNSIncoming') -> bool:
"""Returns true if any answer in a message can suffice for the
information held in this record."""
answers = msg.answers
answers = msg.answers()
for record in answers:
if self._suppressed_by_answer(record):
return True
Expand Down Expand Up @@ -521,37 +521,34 @@ def __repr__(self) -> str:
class DNSRRSet:
    """A set of dns records with a lookup to get the stored record (and its ttl).

    Records hash and compare without regard to TTL (see ``suppresses``), so the
    lookup maps an incoming record to the equal stored instance whose TTL can
    then be inspected.
    """

    __slots__ = ('_records', '_lookup')

    def __init__(self, records: List['DNSRecord']) -> None:
        """Create an RRset from a flat list of records.

        :param records: the DNS records that make up this RRset.
        """
        self._records = records
        # Built lazily on first use; maps each record to the stored instance.
        self._lookup: Optional[Dict['DNSRecord', 'DNSRecord']] = None

    @property
    def lookup(self) -> Dict['DNSRecord', 'DNSRecord']:
        """Return the lookup table."""
        return self._get_lookup()

    def lookup_set(self) -> Set['DNSRecord']:
        """Return the lookup table as a set."""
        return set(self._get_lookup())

    def _get_lookup(self) -> Dict['DNSRecord', 'DNSRecord']:
        """Return the lookup table, building it if needed."""
        if self._lookup is None:
            # Build the hash table so we can look up the stored record
            # (and hence its ttl) by record identity.
            self._lookup = {record: record for record in self._records}
        return self._lookup

    def suppresses(self, record: '_DNSRecord') -> bool:
        """Returns true if any answer in the rrset can suffice for the
        information held in this record.

        A stored answer suffices only when its remaining TTL is more than
        half of the candidate record's TTL.
        """
        lookup = self._get_lookup()
        other = lookup.get(record)
        if other is None:
            return False
        return other.ttl > (record.ttl / 2)
14 changes: 8 additions & 6 deletions src/zeroconf/_handlers/query_handler.pxd
Expand Up @@ -21,7 +21,7 @@ cdef object _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL
cdef class _QueryResponse:

cdef bint _is_probe
cdef DNSIncoming _msg
cdef cython.list _questions
cdef float _now
cdef DNSCache _cache
cdef cython.dict _additionals
Expand All @@ -31,20 +31,20 @@ cdef class _QueryResponse:
cdef cython.set _mcast_aggregate_last_second

@cython.locals(record=DNSRecord)
cpdef add_qu_question_response(self, cython.dict answers)
cdef add_qu_question_response(self, cython.dict answers)

cpdef add_ucast_question_response(self, cython.dict answers)
cdef add_ucast_question_response(self, cython.dict answers)

@cython.locals(answer=DNSRecord)
cpdef add_mcast_question_response(self, cython.dict answers)
@cython.locals(answer=DNSRecord, question=DNSQuestion)
cdef add_mcast_question_response(self, cython.dict answers)

@cython.locals(maybe_entry=DNSRecord)
cdef bint _has_mcast_within_one_quarter_ttl(self, DNSRecord record)

@cython.locals(maybe_entry=DNSRecord)
cdef bint _has_mcast_record_in_last_second(self, DNSRecord record)

cpdef answers(self)
cdef QuestionAnswers answers(self)

cdef class QueryHandler:

Expand All @@ -70,5 +70,7 @@ cdef class QueryHandler:
answer_set=cython.dict,
known_answers=DNSRRSet,
known_answers_set=cython.set,
is_probe=object,
now=object
)
cpdef async_response(self, cython.list msgs, cython.bint unicast_source)
43 changes: 27 additions & 16 deletions src/zeroconf/_handlers/query_handler.py
Expand Up @@ -55,7 +55,7 @@ class _QueryResponse:

__slots__ = (
"_is_probe",
"_msg",
"_questions",
"_now",
"_cache",
"_additionals",
Expand All @@ -65,15 +65,11 @@ class _QueryResponse:
"_mcast_aggregate_last_second",
)

def __init__(self, cache: DNSCache, msgs: List[DNSIncoming]) -> None:
def __init__(self, cache: DNSCache, questions: List[DNSQuestion], is_probe: bool, now: float) -> None:
"""Build a query response."""
self._is_probe = False
for msg in msgs:
if msg.is_probe:
self._is_probe = True
break
self._msg = msgs[0]
self._now = self._msg.now
self._is_probe = is_probe
self._questions = questions
self._now = now
self._cache = cache
self._additionals: _AnswerWithAdditionalsType = {}
self._ucast: Set[DNSRecord] = set()
Expand Down Expand Up @@ -107,10 +103,15 @@ def add_mcast_question_response(self, answers: _AnswerWithAdditionalsType) -> No

if self._has_mcast_record_in_last_second(answer):
self._mcast_aggregate_last_second.add(answer)
elif len(self._msg.questions) == 1 and self._msg.questions[0].type in _RESPOND_IMMEDIATE_TYPES:
self._mcast_now.add(answer)
else:
self._mcast_aggregate.add(answer)
continue

if len(self._questions) == 1:
question = self._questions[0]
if question.type in _RESPOND_IMMEDIATE_TYPES:
self._mcast_now.add(answer)
continue

self._mcast_aggregate.add(answer)

def answers(
self,
Expand Down Expand Up @@ -262,16 +263,26 @@ def async_response( # pylint: disable=unused-argument
This function must be run in the event loop as it is not
threadsafe.
"""
known_answers = DNSRRSet([msg.answers for msg in msgs if not msg.is_probe])
query_res = _QueryResponse(self.cache, msgs)
answers: List[DNSRecord] = []
is_probe = False
msg = msgs[0]
questions = msg.questions
now = msg.now
for msg in msgs:
if not msg.is_probe():
answers.extend(msg.answers())
else:
is_probe = True
known_answers = DNSRRSet(answers)
query_res = _QueryResponse(self.cache, questions, is_probe, now)
known_answers_set: Optional[Set[DNSRecord]] = None

for msg in msgs:
for question in msg.questions:
if not question.unique: # unique and unicast are the same flag
if not known_answers_set: # pragma: no branch
known_answers_set = known_answers.lookup_set()
self.question_history.add_question_at_time(question, msg.now, known_answers_set)
self.question_history.add_question_at_time(question, now, known_answers_set)
answer_set = self._answer_question(question, known_answers)
if not ucast_source and question.unique: # unique and unicast are the same flag
query_res.add_qu_question_response(answer_set)
Expand Down
2 changes: 1 addition & 1 deletion src/zeroconf/_handlers/record_manager.py
Expand Up @@ -87,7 +87,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
now_float = now
unique_types: Set[Tuple[str, int, int]] = set()
cache = self.cache
answers = msg.answers
answers = msg.answers()

for record in answers:
# Protect zeroconf from records that can cause denial of service.
Expand Down
4 changes: 4 additions & 0 deletions src/zeroconf/_protocol/incoming.pxd
Expand Up @@ -72,6 +72,10 @@ cdef class DNSIncoming:

cpdef is_query(self)

cpdef is_probe(self)

cpdef answers(self)

cpdef is_response(self)

@cython.locals(
Expand Down
4 changes: 1 addition & 3 deletions src/zeroconf/_protocol/incoming.py
Expand Up @@ -172,7 +172,6 @@ def _log_exception_debug(cls, *logger_data: Any) -> None:
log_exc_info = True
log.debug(*(logger_data or ['Exception occurred']), exc_info=log_exc_info)

@property
def answers(self) -> List[DNSRecord]:
"""Answers in the packet."""
if not self._did_read_others:
Expand All @@ -187,7 +186,6 @@ def answers(self) -> List[DNSRecord]:
)
return self._answers

@property
def is_probe(self) -> bool:
"""Returns true if this is a probe."""
return self.num_authorities > 0
Expand All @@ -203,7 +201,7 @@ def __repr__(self) -> str:
'n_auth=%s' % self.num_authorities,
'n_add=%s' % self.num_additionals,
'questions=%s' % self.questions,
'answers=%s' % self.answers,
'answers=%s' % self.answers(),
]
)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_asyncio.py
Expand Up @@ -997,7 +997,7 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()):
"""Sends an outgoing packet."""
pout = DNSIncoming(out.packets()[0])
nonlocal nbr_answers
for answer in pout.answers:
for answer in pout.answers():
nbr_answers += 1
if not answer.ttl > expected_ttl / 2:
unexpected_ttl.set()
Expand Down
4 changes: 2 additions & 2 deletions tests/test_dns.py
Expand Up @@ -392,7 +392,7 @@ def test_rrset_does_not_consider_ttl():
longaaaarec = r.DNSAddress('irrelevant', const._TYPE_AAAA, const._CLASS_IN, 100, b'same')
shortaaaarec = r.DNSAddress('irrelevant', const._TYPE_AAAA, const._CLASS_IN, 10, b'same')

rrset = DNSRRSet([[longarec, shortaaaarec]])
rrset = DNSRRSet([longarec, shortaaaarec])

assert rrset.suppresses(longarec)
assert rrset.suppresses(shortarec)
Expand All @@ -404,7 +404,7 @@ def test_rrset_does_not_consider_ttl():
mediumarec = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 60, b'same')
shortarec = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 10, b'same')

rrset2 = DNSRRSet([[mediumarec]])
rrset2 = DNSRRSet([mediumarec])
assert not rrset2.suppresses(verylongarec)
assert rrset2.suppresses(longarec)
assert rrset2.suppresses(mediumarec)
Expand Down
14 changes: 7 additions & 7 deletions tests/test_handlers.py
Expand Up @@ -1425,8 +1425,8 @@ async def test_response_aggregation_timings(run_isolated):
outgoing = send_mock.call_args[0][0]
incoming = r.DNSIncoming(outgoing.packets()[0])
zc.record_manager.async_updates_from_response(incoming)
assert info.dns_pointer() in incoming.answers
assert info2.dns_pointer() in incoming.answers
assert info.dns_pointer() in incoming.answers()
assert info2.dns_pointer() in incoming.answers()
send_mock.reset_mock()

protocol.datagram_received(query3.packets()[0], ('127.0.0.1', const._MDNS_PORT))
Expand All @@ -1439,7 +1439,7 @@ async def test_response_aggregation_timings(run_isolated):
outgoing = send_mock.call_args[0][0]
incoming = r.DNSIncoming(outgoing.packets()[0])
zc.record_manager.async_updates_from_response(incoming)
assert info3.dns_pointer() in incoming.answers
assert info3.dns_pointer() in incoming.answers()
send_mock.reset_mock()

# Because the response was sent in the last second we need to make
Expand All @@ -1461,7 +1461,7 @@ async def test_response_aggregation_timings(run_isolated):
assert len(calls) == 1
outgoing = send_mock.call_args[0][0]
incoming = r.DNSIncoming(outgoing.packets()[0])
assert info.dns_pointer() in incoming.answers
assert info.dns_pointer() in incoming.answers()

await aiozc.async_close()

Expand Down Expand Up @@ -1501,7 +1501,7 @@ async def test_response_aggregation_timings_multiple(run_isolated, disable_dupli
outgoing = send_mock.call_args[0][0]
incoming = r.DNSIncoming(outgoing.packets()[0])
zc.record_manager.async_updates_from_response(incoming)
assert info2.dns_pointer() in incoming.answers
assert info2.dns_pointer() in incoming.answers()

send_mock.reset_mock()
protocol.datagram_received(query2.packets()[0], ('127.0.0.1', const._MDNS_PORT))
Expand All @@ -1511,7 +1511,7 @@ async def test_response_aggregation_timings_multiple(run_isolated, disable_dupli
outgoing = send_mock.call_args[0][0]
incoming = r.DNSIncoming(outgoing.packets()[0])
zc.record_manager.async_updates_from_response(incoming)
assert info2.dns_pointer() in incoming.answers
assert info2.dns_pointer() in incoming.answers()

send_mock.reset_mock()
protocol.datagram_received(query2.packets()[0], ('127.0.0.1', const._MDNS_PORT))
Expand All @@ -1534,7 +1534,7 @@ async def test_response_aggregation_timings_multiple(run_isolated, disable_dupli
outgoing = send_mock.call_args[0][0]
incoming = r.DNSIncoming(outgoing.packets()[0])
zc.record_manager.async_updates_from_response(incoming)
assert info2.dns_pointer() in incoming.answers
assert info2.dns_pointer() in incoming.answers()


@pytest.mark.asyncio
Expand Down

0 comments on commit 3c6b18c

Please sign in to comment.