Skip to content

Commit

Permalink
feat: speed up the query handler (#1350)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Dec 19, 2023
1 parent 7ffbed8 commit 9eac0a1
Show file tree
Hide file tree
Showing 10 changed files with 107 additions and 73 deletions.
49 changes: 3 additions & 46 deletions src/zeroconf/_core.py
Expand Up @@ -31,10 +31,6 @@
from ._dns import DNSQuestion, DNSQuestionType
from ._engine import AsyncEngine
from ._exceptions import NonUniqueNameException, NotRunningException
from ._handlers.answers import (
construct_outgoing_multicast_answers,
construct_outgoing_unicast_answers,
)
from ._handlers.multicast_outgoing_queue import MulticastOutgoingQueue
from ._handlers.query_handler import QueryHandler
from ._handlers.record_manager import RecordManager
Expand Down Expand Up @@ -187,15 +183,15 @@ def __init__(
self.registry = ServiceRegistry()
self.cache = DNSCache()
self.question_history = QuestionHistory()
self.query_handler = QueryHandler(self.registry, self.cache, self.question_history)
self.query_handler = QueryHandler(self)
self.record_manager = RecordManager(self)

self._notify_futures: Set[asyncio.Future] = set()
self.loop: Optional[asyncio.AbstractEventLoop] = None
self._loop_thread: Optional[threading.Thread] = None

self._out_queue = MulticastOutgoingQueue(self, 0, _AGGREGATION_DELAY)
self._out_delay_queue = MulticastOutgoingQueue(self, _ONE_SECOND, _PROTECTED_AGGREGATION_DELAY)
self.out_queue = MulticastOutgoingQueue(self, 0, _AGGREGATION_DELAY)
self.out_delay_queue = MulticastOutgoingQueue(self, _ONE_SECOND, _PROTECTED_AGGREGATION_DELAY)

self.start()

Expand Down Expand Up @@ -567,45 +563,6 @@ def handle_response(self, msg: DNSIncoming) -> None:
self.log_warning_once("handle_response is deprecated, use record_manager.async_updates_from_response")
self.record_manager.async_updates_from_response(msg)

def handle_assembled_query(
self,
packets: List[DNSIncoming],
addr: str,
port: int,
transport: _WrappedTransport,
v6_flow_scope: Union[Tuple[()], Tuple[int, int]],
) -> None:
"""Respond to a (re)assembled query.
If the protocol received packets with the TC bit set, it will
wait a bit for the rest of the packets and only call
handle_assembled_query once it has a complete set of packets
or the timer expires. If the TC bit is not set, a single
packet will be in packets.
"""
ucast_source = port != _MDNS_PORT
question_answers = self.query_handler.async_response(packets, ucast_source)
if not question_answers:
return
now = packets[0].now
if question_answers.ucast:
questions = packets[0].questions
id_ = packets[0].id
out = construct_outgoing_unicast_answers(question_answers.ucast, ucast_source, questions, id_)
# When sending unicast, only send back the reply
# via the same socket that it was recieved from
# as we know its reachable from that socket
self.async_send(out, addr, port, v6_flow_scope, transport)
if question_answers.mcast_now:
self.async_send(construct_outgoing_multicast_answers(question_answers.mcast_now))
if question_answers.mcast_aggregate:
self._out_queue.async_add(now, question_answers.mcast_aggregate)
if question_answers.mcast_aggregate_last_second:
# https://datatracker.ietf.org/doc/html/rfc6762#section-14
# If we broadcast it in the last second, we have to delay
# at least a second before we send it again
self._out_delay_queue.async_add(now, question_answers.mcast_aggregate_last_second)

def send(
self,
out: DNSOutgoing,
Expand Down
25 changes: 23 additions & 2 deletions src/zeroconf/_handlers/query_handler.pxd
Expand Up @@ -7,7 +7,12 @@ from .._history cimport QuestionHistory
from .._protocol.incoming cimport DNSIncoming
from .._services.info cimport ServiceInfo
from .._services.registry cimport ServiceRegistry
from .answers cimport QuestionAnswers
from .answers cimport (
QuestionAnswers,
construct_outgoing_multicast_answers,
construct_outgoing_unicast_answers,
)
from .multicast_outgoing_queue cimport MulticastOutgoingQueue


cdef bint TYPE_CHECKING
Expand Down Expand Up @@ -65,6 +70,7 @@ cdef class _QueryResponse:

cdef class QueryHandler:

cdef object zc
cdef ServiceRegistry registry
cdef DNSCache cache
cdef QuestionHistory question_history
Expand Down Expand Up @@ -93,7 +99,22 @@ cdef class QueryHandler:
is_probe=object,
now=double
)
cpdef async_response(self, cython.list msgs, cython.bint unicast_source)
cpdef QuestionAnswers async_response(self, cython.list msgs, cython.bint unicast_source)

@cython.locals(name=str, question_lower_name=str)
cdef _get_answer_strategies(self, DNSQuestion question)

@cython.locals(
first_packet=DNSIncoming,
ucast_source=bint,
out_queue=MulticastOutgoingQueue,
out_delay_queue=MulticastOutgoingQueue
)
cpdef void handle_assembled_query(
self,
list packets,
object addr,
object port,
object transport,
tuple v6_flow_scope
)
71 changes: 61 additions & 10 deletions src/zeroconf/_handlers/query_handler.py
Expand Up @@ -20,19 +20,19 @@
USA
"""

from typing import TYPE_CHECKING, List, Optional, Set, cast
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union, cast

from .._cache import DNSCache, _UniqueRecordsType
from .._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRecord, DNSRRSet
from .._history import QuestionHistory
from .._protocol.incoming import DNSIncoming
from .._services.info import ServiceInfo
from .._services.registry import ServiceRegistry
from .._transport import _WrappedTransport
from .._utils.net import IPVersion
from ..const import (
_ADDRESS_RECORD_TYPES,
_CLASS_IN,
_DNS_OTHER_TTL,
_MDNS_PORT,
_ONE_SECOND,
_SERVICE_TYPE_ENUMERATION_NAME,
_TYPE_A,
Expand All @@ -43,7 +43,12 @@
_TYPE_SRV,
_TYPE_TXT,
)
from .answers import QuestionAnswers, _AnswerWithAdditionalsType
from .answers import (
QuestionAnswers,
_AnswerWithAdditionalsType,
construct_outgoing_multicast_answers,
construct_outgoing_unicast_answers,
)

_RESPOND_IMMEDIATE_TYPES = {_TYPE_NSEC, _TYPE_SRV, *_ADDRESS_RECORD_TYPES}

Expand All @@ -53,14 +58,17 @@
_IPVersion_ALL = IPVersion.All

_int = int

_str = str

_ANSWER_STRATEGY_SERVICE_TYPE_ENUMERATION = 0
_ANSWER_STRATEGY_POINTER = 1
_ANSWER_STRATEGY_ADDRESS = 2
_ANSWER_STRATEGY_SERVICE = 3
_ANSWER_STRATEGY_TEXT = 4

if TYPE_CHECKING:
from .._core import Zeroconf


class _AnswerStrategy:

Expand Down Expand Up @@ -183,13 +191,14 @@ def _has_mcast_record_in_last_second(self, record: DNSRecord) -> bool:
class QueryHandler:
"""Query the ServiceRegistry."""

__slots__ = ("registry", "cache", "question_history")
__slots__ = ("zc", "registry", "cache", "question_history")

def __init__(self, registry: ServiceRegistry, cache: DNSCache, question_history: QuestionHistory) -> None:
def __init__(self, zc: 'Zeroconf') -> None:
"""Init the query handler."""
self.registry = registry
self.cache = cache
self.question_history = question_history
self.zc = zc
self.registry = zc.registry
self.cache = zc.cache
self.question_history = zc.question_history

def _add_service_type_enumeration_query_answers(
self, types: List[str], answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet
Expand Down Expand Up @@ -385,3 +394,45 @@ def _get_answer_strategies(
)

return strategies

def handle_assembled_query(
self,
packets: List[DNSIncoming],
addr: _str,
port: _int,
transport: _WrappedTransport,
v6_flow_scope: Union[Tuple[()], Tuple[int, int]],
) -> None:
"""Respond to a (re)assembled query.
If the protocol recieved packets with the TC bit set, it will
wait a bit for the rest of the packets and only call
handle_assembled_query once it has a complete set of packets
or the timer expires. If the TC bit is not set, a single
packet will be in packets.
"""
first_packet = packets[0]
now = first_packet.now
ucast_source = port != _MDNS_PORT
question_answers = self.async_response(packets, ucast_source)
if not question_answers:
return
if question_answers.ucast:
questions = first_packet.questions
id_ = first_packet.id
out = construct_outgoing_unicast_answers(question_answers.ucast, ucast_source, questions, id_)
# When sending unicast, only send back the reply
# via the same socket that it was recieved from
# as we know its reachable from that socket
self.zc.async_send(out, addr, port, v6_flow_scope, transport)
if question_answers.mcast_now:
self.zc.async_send(construct_outgoing_multicast_answers(question_answers.mcast_now))
if question_answers.mcast_aggregate:
out_queue = self.zc.out_queue
out_queue.async_add(now, question_answers.mcast_aggregate)
if question_answers.mcast_aggregate_last_second:
# https://datatracker.ietf.org/doc/html/rfc6762#section-14
# If we broadcast it in the last second, we have to delay
# at least a second before we send it again
out_delay_queue = self.zc.out_delay_queue
out_delay_queue.async_add(now, question_answers.mcast_aggregate_last_second)
13 changes: 6 additions & 7 deletions src/zeroconf/_handlers/record_manager.pxd
Expand Up @@ -22,22 +22,21 @@ cdef class RecordManager:
cdef public DNSCache cache
cdef public cython.set listeners

cpdef async_updates(self, object now, object records)
cpdef void async_updates(self, object now, object records)

cpdef async_updates_complete(self, object notify)
cpdef void async_updates_complete(self, bint notify)

@cython.locals(
cache=DNSCache,
record=DNSRecord,
answers=cython.list,
maybe_entry=DNSRecord,
now_double=double
)
cpdef async_updates_from_response(self, DNSIncoming msg)
cpdef void async_updates_from_response(self, DNSIncoming msg)

cpdef async_add_listener(self, RecordUpdateListener listener, object question)
cpdef void async_add_listener(self, RecordUpdateListener listener, object question)

cpdef async_remove_listener(self, RecordUpdateListener listener)
cpdef void async_remove_listener(self, RecordUpdateListener listener)

@cython.locals(question=DNSQuestion, record=DNSRecord)
cdef _async_update_matching_records(self, RecordUpdateListener listener, cython.list questions)
cdef void _async_update_matching_records(self, RecordUpdateListener listener, cython.list questions)
5 changes: 2 additions & 3 deletions src/zeroconf/_handlers/record_manager.py
Expand Up @@ -84,7 +84,6 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
other_adds: List[DNSRecord] = []
removes: Set[DNSRecord] = set()
now = msg.now
now_double = now
unique_types: Set[Tuple[str, int, int]] = set()
cache = self.cache
answers = msg.answers()
Expand Down Expand Up @@ -113,7 +112,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
record = cast(_UniqueRecordsType, record)

maybe_entry = cache.async_get_unique(record)
if not record.is_expired(now_double):
if not record.is_expired(now):
if maybe_entry is not None:
maybe_entry.reset_ttl(record)
else:
Expand All @@ -129,7 +128,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
removes.add(record)

if unique_types:
cache.async_mark_unique_records_older_than_1s_to_expire(unique_types, answers, now_double)
cache.async_mark_unique_records_older_than_1s_to_expire(unique_types, answers, now)

if updates:
self.async_updates(now, updates)
Expand Down
2 changes: 2 additions & 0 deletions src/zeroconf/_listener.pxd
@@ -1,6 +1,7 @@

import cython

from ._handlers.query_handler cimport QueryHandler
from ._handlers.record_manager cimport RecordManager
from ._protocol.incoming cimport DNSIncoming
from ._services.registry cimport ServiceRegistry
Expand All @@ -21,6 +22,7 @@ cdef class AsyncListener:
cdef public object zc
cdef ServiceRegistry _registry
cdef RecordManager _record_manager
cdef QueryHandler _query_handler
cdef public cython.bytes data
cdef public double last_time
cdef public DNSIncoming last_message
Expand Down
4 changes: 3 additions & 1 deletion src/zeroconf/_listener.py
Expand Up @@ -59,6 +59,7 @@ class AsyncListener:
'zc',
'_registry',
'_record_manager',
"_query_handler",
'data',
'last_time',
'last_message',
Expand All @@ -72,6 +73,7 @@ def __init__(self, zc: 'Zeroconf') -> None:
self.zc = zc
self._registry = zc.registry
self._record_manager = zc.record_manager
self._query_handler = zc.query_handler
self.data: Optional[bytes] = None
self.last_time: float = 0
self.last_message: Optional[DNSIncoming] = None
Expand Down Expand Up @@ -228,7 +230,7 @@ def _respond_query(
if msg:
packets.append(msg)

self.zc.handle_assembled_query(packets, addr, port, transport, v6_flow_scope)
self._query_handler.handle_assembled_query(packets, addr, port, transport, v6_flow_scope)

def error_received(self, exc: Exception) -> None:
"""Likely socket closed or IPv6."""
Expand Down
2 changes: 1 addition & 1 deletion src/zeroconf/_protocol/incoming.pxd
Expand Up @@ -56,7 +56,7 @@ cdef class DNSIncoming:
cdef cython.uint _num_authorities
cdef cython.uint _num_additionals
cdef public bint valid
cdef public object now
cdef public double now
cdef public object scope_id
cdef public object source
cdef bint _has_qu_question
Expand Down
4 changes: 2 additions & 2 deletions src/zeroconf/_transport.py
Expand Up @@ -22,7 +22,7 @@

import asyncio
import socket
from typing import Any
from typing import Tuple


class _WrappedTransport:
Expand All @@ -42,7 +42,7 @@ def __init__(
is_ipv6: bool,
sock: socket.socket,
fileno: int,
sock_name: Any,
sock_name: Tuple,
) -> None:
"""Initialize the wrapped transport.
Expand Down
5 changes: 4 additions & 1 deletion tests/conftest.py
Expand Up @@ -9,6 +9,7 @@
import pytest

from zeroconf import _core, const
from zeroconf._handlers import query_handler


@pytest.fixture(autouse=True)
Expand All @@ -23,7 +24,9 @@ def verify_threads_ended():
@pytest.fixture
def run_isolated():
"""Change the mDNS port to run the test in isolation."""
with patch.object(_core, "_MDNS_PORT", 5454), patch.object(const, "_MDNS_PORT", 5454):
with patch.object(query_handler, "_MDNS_PORT", 5454), patch.object(
_core, "_MDNS_PORT", 5454
), patch.object(const, "_MDNS_PORT", 5454):
yield


Expand Down

0 comments on commit 9eac0a1

Please sign in to comment.