Skip to content

Commit

Permalink
feat: speed up ServiceBrowsers with a cython pxd (#1270)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Sep 14, 2023
1 parent c88530b commit 4837876
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 32 deletions.
1 change: 1 addition & 0 deletions build_ext.py
Expand Up @@ -32,6 +32,7 @@ def build(setup_kwargs: Any) -> None:
"src/zeroconf/_handlers/answers.py",
"src/zeroconf/_handlers/record_manager.py",
"src/zeroconf/_handlers/query_handler.py",
"src/zeroconf/_services/browser.py",
"src/zeroconf/_services/info.py",
"src/zeroconf/_services/registry.py",
"src/zeroconf/_updates.py",
Expand Down
74 changes: 74 additions & 0 deletions src/zeroconf/_services/browser.pxd
@@ -0,0 +1,74 @@

import cython

from .._cache cimport DNSCache
from .._protocol.outgoing cimport DNSOutgoing, DNSPointer, DNSQuestion, DNSRecord
from .._updates cimport RecordUpdateListener
from .._utils.time cimport current_time_millis, millis_to_seconds


cdef object TYPE_CHECKING
cdef object cached_possible_types
cdef cython.uint _EXPIRE_REFRESH_TIME_PERCENT
cdef object SERVICE_STATE_CHANGE_ADDED, SERVICE_STATE_CHANGE_REMOVED, SERVICE_STATE_CHANGE_UPDATED

cdef class _DNSPointerOutgoingBucket:

cdef public object now
cdef public DNSOutgoing out
cdef public cython.uint bytes

cpdef add(self, cython.uint max_compressed_size, DNSQuestion question, cython.set answers)


@cython.locals(answer=DNSPointer)
cdef _group_ptr_queries_with_known_answers(object now, object multicast, cython.dict question_with_known_answers)

cdef class QueryScheduler:

cdef cython.set _types
cdef cython.dict _next_time
cdef object _first_random_delay_interval
cdef cython.dict _delay

cpdef millis_to_wait(self, object now)

cpdef reschedule_type(self, object type_, object next_time)

cpdef process_ready_types(self, object now)

cdef class _ServiceBrowserBase(RecordUpdateListener):

cdef public cython.set types
cdef public object zc
cdef object _loop
cdef public object addr
cdef public object port
cdef public object multicast
cdef public object question_type
cdef public cython.dict _pending_handlers
cdef public object _service_state_changed
cdef public QueryScheduler query_scheduler
cdef public bint done
cdef public object _first_request
cdef public object _next_send_timer
cdef public object _query_sender_task

cpdef _generate_ready_queries(self, object first_request, object now)

cpdef _enqueue_callback(self, object state_change, object type_, object name)

@cython.locals(record=DNSRecord, cache=DNSCache, service=DNSRecord)
cpdef async_update_records(self, object zc, cython.float now, cython.list records)

cpdef _names_matching_types(self, object types)

cpdef reschedule_type(self, object type_, object now, object next_time)

cpdef _fire_service_state_changed_event(self, cython.tuple event)

cpdef _async_send_ready_queries_schedule_next(self)

cpdef _async_schedule_next(self, object now)

cpdef _async_send_ready_queries(self, object now)
65 changes: 37 additions & 28 deletions src/zeroconf/_services/browser.py
Expand Up @@ -78,9 +78,17 @@
ServiceStateChange.Updated: "update_service",
}

SERVICE_STATE_CHANGE_ADDED = ServiceStateChange.Added
SERVICE_STATE_CHANGE_REMOVED = ServiceStateChange.Removed
SERVICE_STATE_CHANGE_UPDATED = ServiceStateChange.Updated

if TYPE_CHECKING:
from .._core import Zeroconf

float_ = float
int_ = int
bool_ = bool
str_ = str

_QuestionWithKnownAnswers = Dict[DNSQuestion, Set[DNSPointer]]

Expand All @@ -96,7 +104,7 @@ def __init__(self, now: float, multicast: bool) -> None:
self.out = DNSOutgoing(_FLAGS_QR_QUERY, multicast=multicast)
self.bytes = 0

def add(self, max_compressed_size: int, question: DNSQuestion, answers: Set[DNSPointer]) -> None:
def add(self, max_compressed_size: int_, question: DNSQuestion, answers: Set[DNSPointer]) -> None:
"""Add a new set of questions and known answers to the outgoing."""
self.out.add_question(question)
for answer in answers:
Expand All @@ -105,7 +113,7 @@ def add(self, max_compressed_size: int, question: DNSQuestion, answers: Set[DNSP


def _group_ptr_queries_with_known_answers(
now: float, multicast: bool, question_with_known_answers: _QuestionWithKnownAnswers
now: float_, multicast: bool_, question_with_known_answers: _QuestionWithKnownAnswers
) -> List[DNSOutgoing]:
"""Aggregate queries so that as many known answers as possible fit in the same packet
without having known answers spill over into the next packet unless the
Expand Down Expand Up @@ -205,26 +213,24 @@ class QueryScheduler:
"""

__slots__ = ('_schedule_changed_event', '_types', '_next_time', '_first_random_delay_interval', '_delay')
__slots__ = ('_types', '_next_time', '_first_random_delay_interval', '_delay')

def __init__(
self,
types: Set[str],
delay: int,
first_random_delay_interval: Tuple[int, int],
) -> None:
self._schedule_changed_event: Optional[asyncio.Event] = None
self._types = types
self._next_time: Dict[str, float] = {}
self._first_random_delay_interval = first_random_delay_interval
self._delay: Dict[str, float] = {check_type_: delay for check_type_ in self._types}

def start(self, now: float) -> None:
def start(self, now: float_) -> None:
"""Start the scheduler."""
self._schedule_changed_event = asyncio.Event()
self._generate_first_next_time(now)

def _generate_first_next_time(self, now: float) -> None:
def _generate_first_next_time(self, now: float_) -> None:
"""Generate the initial next query times.
https://datatracker.ietf.org/doc/html/rfc6762#section-5.2
Expand All @@ -238,20 +244,20 @@ def _generate_first_next_time(self, now: float) -> None:
next_time = now + delay
self._next_time = {check_type_: next_time for check_type_ in self._types}

def millis_to_wait(self, now: float) -> float:
def millis_to_wait(self, now: float_) -> float:
"""Returns the number of milliseconds to wait for the next event."""
# Wait for the type has the smallest next time
next_time = min(self._next_time.values())
return 0 if next_time <= now else next_time - now

def reschedule_type(self, type_: str, next_time: float) -> bool:
def reschedule_type(self, type_: str_, next_time: float_) -> bool:
"""Reschedule the query for a type to happen sooner."""
if next_time >= self._next_time[type_]:
return False
self._next_time[type_] = next_time
return True

def process_ready_types(self, now: float) -> List[str]:
def process_ready_types(self, now: float_) -> List[str]:
"""Generate a list of ready types that is due and schedule the next time."""
if self.millis_to_wait(now):
return []
Expand All @@ -275,6 +281,7 @@ class _ServiceBrowserBase(RecordUpdateListener):
__slots__ = (
'types',
'zc',
'_loop',
'addr',
'port',
'multicast',
Expand Down Expand Up @@ -322,6 +329,8 @@ def __init__(
# Will generate BadTypeInNameException on a bad name
service_type_name(check_type_, strict=False)
self.zc = zc
assert zc.loop is not None
self._loop = zc.loop
self.addr = addr
self.port = port
self.multicast = self.addr in (None, _MDNS_ADDR, _MDNS_ADDR6)
Expand Down Expand Up @@ -370,23 +379,23 @@ def _names_matching_types(self, names: Iterable[str]) -> List[Tuple[str, str]]:
def _enqueue_callback(
self,
state_change: ServiceStateChange,
type_: str,
name: str,
type_: str_,
name: str_,
) -> None:
# Code to ensure we only do a single update message
# Precedence is; Added, Remove, Update
key = (name, type_)
if (
state_change is ServiceStateChange.Added
state_change is SERVICE_STATE_CHANGE_ADDED
or (
state_change is ServiceStateChange.Removed
and self._pending_handlers.get(key) != ServiceStateChange.Added
state_change is SERVICE_STATE_CHANGE_REMOVED
and self._pending_handlers.get(key) != SERVICE_STATE_CHANGE_ADDED
)
or (state_change is ServiceStateChange.Updated and key not in self._pending_handlers)
or (state_change is SERVICE_STATE_CHANGE_UPDATED and key not in self._pending_handlers)
):
self._pending_handlers[key] = state_change

def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> None:
def async_update_records(self, zc: 'Zeroconf', now: float_, records: List[RecordUpdate]) -> None:
"""Callback invoked by Zeroconf when new information arrives.
Updates information required by browser in the Zeroconf cache.
Expand All @@ -404,9 +413,9 @@ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordU
record = cast(DNSPointer, record)
for type_ in self.types.intersection(cached_possible_types(record.name)):
if old_record is None:
self._enqueue_callback(ServiceStateChange.Added, type_, record.alias)
self._enqueue_callback(SERVICE_STATE_CHANGE_ADDED, type_, record.alias)
elif record.is_expired(now):
self._enqueue_callback(ServiceStateChange.Removed, type_, record.alias)
self._enqueue_callback(SERVICE_STATE_CHANGE_REMOVED, type_, record.alias)
else:
expire_time = record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT)
self.reschedule_type(type_, now, expire_time)
Expand All @@ -417,15 +426,16 @@ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordU
continue

if record_type in _ADDRESS_RECORD_TYPES:
cache = self.zc.cache
# Iterate through the DNSCache and callback any services that use this address
for type_, name in self._names_matching_types(
{service.name for service in self.zc.cache.async_entries_with_server(record.name)}
{service.name for service in cache.async_entries_with_server(record.name)}
):
self._enqueue_callback(ServiceStateChange.Updated, type_, name)
self._enqueue_callback(SERVICE_STATE_CHANGE_UPDATED, type_, name)
continue

for type_, name in self._names_matching_types((record.name,)):
self._enqueue_callback(ServiceStateChange.Updated, type_, name)
self._enqueue_callback(SERVICE_STATE_CHANGE_UPDATED, type_, name)

@abstractmethod
def async_update_records_complete(self) -> None:
Expand Down Expand Up @@ -460,7 +470,7 @@ def _async_cancel(self) -> None:
assert self._query_sender_task is not None, "Attempted to cancel a browser that was not started"
self._query_sender_task.cancel()

def _generate_ready_queries(self, first_request: bool, now: float) -> List[DNSOutgoing]:
def _generate_ready_queries(self, first_request: bool_, now: float_) -> List[DNSOutgoing]:
"""Generate the service browser query for any type that is due."""
ready_types = self.query_scheduler.process_ready_types(now)
if not ready_types:
Expand All @@ -485,7 +495,7 @@ def _cancel_send_timer(self) -> None:
self._next_send_timer.cancel()
self._next_send_timer = None

def reschedule_type(self, type_: str, now: float, next_time: float) -> None:
def reschedule_type(self, type_: str_, now: float_, next_time: float_) -> None:
"""Reschedule a type to be refreshed in the future."""
if self.query_scheduler.reschedule_type(type_, next_time):
# We need to send the queries before rescheduling the next one
Expand All @@ -496,7 +506,7 @@ def reschedule_type(self, type_: str, now: float, next_time: float) -> None:
self._cancel_send_timer()
self._async_schedule_next(now)

def _async_send_ready_queries(self, now: float) -> None:
def _async_send_ready_queries(self, now: float_) -> None:
"""Send any ready queries."""
outs = self._generate_ready_queries(self._first_request, now)
if outs:
Expand All @@ -512,11 +522,10 @@ def _async_send_ready_queries_schedule_next(self) -> None:
self._async_send_ready_queries(now)
self._async_schedule_next(now)

def _async_schedule_next(self, now: float) -> None:
def _async_schedule_next(self, now: float_) -> None:
"""Scheule the next time."""
assert self.zc.loop is not None
delay = millis_to_seconds(self.query_scheduler.millis_to_wait(now))
self._next_send_timer = self.zc.loop.call_later(delay, self._async_send_ready_queries_schedule_next)
self._next_send_timer = self._loop.call_later(delay, self._async_send_ready_queries_schedule_next)


class ServiceBrowser(_ServiceBrowserBase, threading.Thread):
Expand Down
6 changes: 3 additions & 3 deletions src/zeroconf/_services/info.pxd
Expand Up @@ -56,9 +56,9 @@ cdef class ServiceInfo(RecordUpdateListener):
cdef public cython.set _get_address_and_nsec_records_cache

@cython.locals(
cache=DNSCache
cache=DNSCache,
)
cpdef async_update_records(self, object zc, object now, cython.list records)
cpdef async_update_records(self, object zc, cython.float now, cython.list records)

@cython.locals(
cache=DNSCache
Expand All @@ -73,7 +73,7 @@ cdef class ServiceInfo(RecordUpdateListener):

cdef _get_ip_addresses_from_cache_lifo(self, object zc, object now, object type)

cdef _process_record_threadsafe(self, object zc, DNSRecord record, object now)
cdef _process_record_threadsafe(self, object zc, DNSRecord record, cython.float now)

@cython.locals(
cache=DNSCache
Expand Down
2 changes: 1 addition & 1 deletion src/zeroconf/_updates.pxd
Expand Up @@ -4,6 +4,6 @@ import cython

cdef class RecordUpdateListener:

cpdef async_update_records(self, object zc, object now, cython.list records)
cpdef async_update_records(self, object zc, cython.float now, cython.list records)

cpdef async_update_records_complete(self)

0 comments on commit 4837876

Please sign in to comment.