Skip to content

Commit

Permalink
feat: speed up question and answer history with a cython pxd (#1234)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Aug 26, 2023
1 parent 84054ce commit 703ecb2
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 25 deletions.
1 change: 1 addition & 0 deletions build_ext.py
Expand Up @@ -25,6 +25,7 @@ def build(setup_kwargs: Any) -> None:
[
"src/zeroconf/_dns.py",
"src/zeroconf/_cache.py",
"src/zeroconf/_history.py",
"src/zeroconf/_listener.py",
"src/zeroconf/_protocol/incoming.py",
"src/zeroconf/_protocol/outgoing.py",
Expand Down
16 changes: 16 additions & 0 deletions src/zeroconf/_history.pxd
@@ -0,0 +1,16 @@
import cython


cdef cython.double _DUPLICATE_QUESTION_INTERVAL

cdef class QuestionHistory:

cdef cython.dict _history


@cython.locals(than=cython.double, previous_question=cython.tuple, previous_known_answers=cython.set)
cpdef suppresses(self, object question, cython.double now, cython.set known_answers)


@cython.locals(than=cython.double, now_known_answers=cython.tuple)
cpdef async_expire(self, cython.double now)
27 changes: 18 additions & 9 deletions src/zeroconf/_history.py
Expand Up @@ -20,24 +20,29 @@
USA
"""

from typing import Dict, Set, Tuple
from typing import Dict, List, Set, Tuple

from ._dns import DNSQuestion, DNSRecord
from .const import _DUPLICATE_QUESTION_INTERVAL

# The QuestionHistory is used to implement Duplicate Question Suppression
# https://datatracker.ietf.org/doc/html/rfc6762#section-7.3

_float = float


class QuestionHistory:
"""Remember questions and known answers."""

def __init__(self) -> None:
"""Init a new QuestionHistory."""
self._history: Dict[DNSQuestion, Tuple[float, Set[DNSRecord]]] = {}

def add_question_at_time(self, question: DNSQuestion, now: float, known_answers: Set[DNSRecord]) -> None:
def add_question_at_time(self, question: DNSQuestion, now: _float, known_answers: Set[DNSRecord]) -> None:
"""Remember a question with known answers."""
self._history[question] = (now, known_answers)

def suppresses(self, question: DNSQuestion, now: float, known_answers: Set[DNSRecord]) -> bool:
def suppresses(self, question: DNSQuestion, now: _float, known_answers: Set[DNSRecord]) -> bool:
"""Check to see if a question should be suppressed.
https://datatracker.ietf.org/doc/html/rfc6762#section-7.3
Expand All @@ -59,12 +64,16 @@ def suppresses(self, question: DNSQuestion, now: float, known_answers: Set[DNSRe
return False
return True

def async_expire(self, now: float) -> None:
def async_expire(self, now: _float) -> None:
"""Expire the history of old questions."""
removes = [
question
for question, now_known_answers in self._history.items()
if now - now_known_answers[0] > _DUPLICATE_QUESTION_INTERVAL
]
removes: List[DNSQuestion] = []
for question, now_known_answers in self._history.items():
than, _ = now_known_answers
if now - than > _DUPLICATE_QUESTION_INTERVAL:
removes.append(question)
for question in removes:
del self._history[question]

def clear(self) -> None:
"""Clear the history."""
self._history.clear()
12 changes: 9 additions & 3 deletions tests/__init__.py
Expand Up @@ -23,11 +23,17 @@
import asyncio
import socket
from functools import lru_cache
from typing import List
from typing import List, Set

import ifaddr

from zeroconf import DNSIncoming, Zeroconf
from zeroconf import DNSIncoming, DNSQuestion, DNSRecord, Zeroconf
from zeroconf._history import QuestionHistory


class QuestionHistoryWithoutSuppression(QuestionHistory):
def suppresses(self, question: DNSQuestion, now: float, known_answers: Set[DNSRecord]) -> bool:
return False


def _inject_responses(zc: Zeroconf, msgs: List[DNSIncoming]) -> None:
Expand Down Expand Up @@ -77,4 +83,4 @@ def has_working_ipv6():

def _clear_cache(zc):
zc.cache.cache.clear()
zc.question_history._history.clear()
zc.question_history.clear()
14 changes: 9 additions & 5 deletions tests/services/test_browser.py
Expand Up @@ -31,7 +31,12 @@
from zeroconf._services.info import ServiceInfo
from zeroconf.asyncio import AsyncZeroconf

from .. import _inject_response, _wait_for_start, has_working_ipv6
from .. import (
QuestionHistoryWithoutSuppression,
_inject_response,
_wait_for_start,
has_working_ipv6,
)

log = logging.getLogger('zeroconf')
original_logging_level = logging.NOTSET
Expand Down Expand Up @@ -444,6 +449,7 @@ def test_backoff():
type_ = "_http._tcp.local."
zeroconf_browser = Zeroconf(interfaces=['127.0.0.1'])
_wait_for_start(zeroconf_browser)
zeroconf_browser.question_history = QuestionHistoryWithoutSuppression()

# we are going to patch the zeroconf send to check query transmission
old_send = zeroconf_browser.async_send
Expand All @@ -465,10 +471,8 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()):
# patch the zeroconf current_time_millis
# patch the backoff limit to prevent test running forever
with patch.object(zeroconf_browser, "async_send", send), patch.object(
zeroconf_browser.question_history, "suppresses", return_value=False
), patch.object(_services_browser, "current_time_millis", current_time_millis), patch.object(
_services_browser, "_BROWSER_BACKOFF_LIMIT", 10
), patch.object(
_services_browser, "current_time_millis", current_time_millis
), patch.object(_services_browser, "_BROWSER_BACKOFF_LIMIT", 10), patch.object(
_services_browser, "_FIRST_QUERY_DELAY_RANDOM_INTERVAL", (0, 0)
):
# dummy service callback
Expand Down
11 changes: 5 additions & 6 deletions tests/test_asyncio.py
Expand Up @@ -43,7 +43,7 @@
)
from zeroconf.const import _LISTENER_TIME

from . import _clear_cache, has_working_ipv6
from . import QuestionHistoryWithoutSuppression, _clear_cache, has_working_ipv6

log = logging.getLogger('zeroconf')
original_logging_level = logging.NOTSET
Expand Down Expand Up @@ -951,6 +951,7 @@ def on_service_state_change(zeroconf, service_type, state_change, name):

aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
zeroconf_browser = aiozc.zeroconf
zeroconf_browser.question_history = QuestionHistoryWithoutSuppression()
await zeroconf_browser.async_wait_for_start()

# we are going to patch the zeroconf send to check packet sizes
Expand Down Expand Up @@ -990,11 +991,9 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()):
# patch the backoff limit to ensure we always get one query every 1/4 of the DNS TTL
# Disable duplicate question suppression and duplicate packet suppression for this test as it works
# by asking the same question over and over
with patch.object(zeroconf_browser.question_history, "suppresses", return_value=False), patch.object(
zeroconf_browser, "async_send", send
), patch("zeroconf._services.browser.current_time_millis", _new_current_time_millis), patch.object(
_services_browser, "_BROWSER_BACKOFF_LIMIT", int(expected_ttl / 4)
):
with patch.object(zeroconf_browser, "async_send", send), patch(
"zeroconf._services.browser.current_time_millis", _new_current_time_millis
), patch.object(_services_browser, "_BROWSER_BACKOFF_LIMIT", int(expected_ttl / 4)):
service_added = asyncio.Event()
service_removed = asyncio.Event()

Expand Down
4 changes: 2 additions & 2 deletions tests/test_handlers.py
Expand Up @@ -1131,7 +1131,7 @@ async def test_cache_flush_bit():
for record in new_records:
assert zc.cache.async_get_unique(record) is not None

original_a_record.created = current_time_millis() - 1001
original_a_record.created = current_time_millis() - 1500

# Do the run within 1s to verify the original record is not going to be expired
out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA, multicast=True)
Expand All @@ -1146,7 +1146,7 @@ async def test_cache_flush_bit():
cached_records = [zc.cache.async_get_unique(record) for record in new_records]
for cached_record in cached_records:
assert cached_record is not None
cached_record.created = current_time_millis() - 1001
cached_record.created = current_time_millis() - 1500

fresh_address = socket.inet_aton("4.4.4.4")
info.addresses = [fresh_address]
Expand Down
3 changes: 3 additions & 0 deletions tests/test_listener.py
Expand Up @@ -14,6 +14,8 @@
from zeroconf._protocol import outgoing
from zeroconf._protocol.incoming import DNSIncoming

from . import QuestionHistoryWithoutSuppression

log = logging.getLogger('zeroconf')
original_logging_level = logging.NOTSET

Expand Down Expand Up @@ -123,6 +125,7 @@ def test_guard_against_duplicate_packets():
These packets can quickly overwhelm the system.
"""
zc = Zeroconf(interfaces=['127.0.0.1'])
zc.question_history = QuestionHistoryWithoutSuppression()

class SubListener(_listener.AsyncListener):
def handle_query_or_defer(
Expand Down

0 comments on commit 703ecb2

Please sign in to comment.