Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: speed up question and answer history with a cython pxd #1234

Merged
merged 5 commits into from
Aug 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions build_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def build(setup_kwargs: Any) -> None:
[
"src/zeroconf/_dns.py",
"src/zeroconf/_cache.py",
"src/zeroconf/_history.py",
"src/zeroconf/_listener.py",
"src/zeroconf/_protocol/incoming.py",
"src/zeroconf/_protocol/outgoing.py",
Expand Down
16 changes: 16 additions & 0 deletions src/zeroconf/_history.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import cython


cdef cython.double _DUPLICATE_QUESTION_INTERVAL

cdef class QuestionHistory:

cdef cython.dict _history


@cython.locals(than=cython.double, previous_question=cython.tuple, previous_known_answers=cython.set)
cpdef suppresses(self, object question, cython.double now, cython.set known_answers)


@cython.locals(than=cython.double, now_known_answers=cython.tuple)
cpdef async_expire(self, cython.double now)
27 changes: 18 additions & 9 deletions src/zeroconf/_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,29 @@
USA
"""

from typing import Dict, Set, Tuple
from typing import Dict, List, Set, Tuple

from ._dns import DNSQuestion, DNSRecord
from .const import _DUPLICATE_QUESTION_INTERVAL

# The QuestionHistory is used to implement Duplicate Question Suppression
# https://datatracker.ietf.org/doc/html/rfc6762#section-7.3

_float = float


class QuestionHistory:
"""Remember questions and known answers."""

def __init__(self) -> None:
"""Init a new QuestionHistory."""
self._history: Dict[DNSQuestion, Tuple[float, Set[DNSRecord]]] = {}

def add_question_at_time(self, question: DNSQuestion, now: float, known_answers: Set[DNSRecord]) -> None:
def add_question_at_time(self, question: DNSQuestion, now: _float, known_answers: Set[DNSRecord]) -> None:
"""Remember a question with known answers."""
self._history[question] = (now, known_answers)

def suppresses(self, question: DNSQuestion, now: float, known_answers: Set[DNSRecord]) -> bool:
def suppresses(self, question: DNSQuestion, now: _float, known_answers: Set[DNSRecord]) -> bool:
"""Check to see if a question should be suppressed.

https://datatracker.ietf.org/doc/html/rfc6762#section-7.3
Expand All @@ -59,12 +64,16 @@ def suppresses(self, question: DNSQuestion, now: float, known_answers: Set[DNSRe
return False
return True

def async_expire(self, now: float) -> None:
def async_expire(self, now: _float) -> None:
"""Expire the history of old questions."""
removes = [
question
for question, now_known_answers in self._history.items()
if now - now_known_answers[0] > _DUPLICATE_QUESTION_INTERVAL
]
removes: List[DNSQuestion] = []
for question, now_known_answers in self._history.items():
than, _ = now_known_answers
if now - than > _DUPLICATE_QUESTION_INTERVAL:
removes.append(question)
for question in removes:
del self._history[question]

def clear(self) -> None:
"""Clear the history."""
self._history.clear()
12 changes: 9 additions & 3 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,17 @@
import asyncio
import socket
from functools import lru_cache
from typing import List
from typing import List, Set

import ifaddr

from zeroconf import DNSIncoming, Zeroconf
from zeroconf import DNSIncoming, DNSQuestion, DNSRecord, Zeroconf
from zeroconf._history import QuestionHistory


class QuestionHistoryWithoutSuppression(QuestionHistory):
def suppresses(self, question: DNSQuestion, now: float, known_answers: Set[DNSRecord]) -> bool:
return False


def _inject_responses(zc: Zeroconf, msgs: List[DNSIncoming]) -> None:
Expand Down Expand Up @@ -77,4 +83,4 @@ def has_working_ipv6():

def _clear_cache(zc):
zc.cache.cache.clear()
zc.question_history._history.clear()
zc.question_history.clear()
14 changes: 9 additions & 5 deletions tests/services/test_browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@
from zeroconf._services.info import ServiceInfo
from zeroconf.asyncio import AsyncZeroconf

from .. import _inject_response, _wait_for_start, has_working_ipv6
from .. import (
QuestionHistoryWithoutSuppression,
_inject_response,
_wait_for_start,
has_working_ipv6,
)

log = logging.getLogger('zeroconf')
original_logging_level = logging.NOTSET
Expand Down Expand Up @@ -444,6 +449,7 @@ def test_backoff():
type_ = "_http._tcp.local."
zeroconf_browser = Zeroconf(interfaces=['127.0.0.1'])
_wait_for_start(zeroconf_browser)
zeroconf_browser.question_history = QuestionHistoryWithoutSuppression()

# we are going to patch the zeroconf send to check query transmission
old_send = zeroconf_browser.async_send
Expand All @@ -465,10 +471,8 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()):
# patch the zeroconf current_time_millis
# patch the backoff limit to prevent test running forever
with patch.object(zeroconf_browser, "async_send", send), patch.object(
zeroconf_browser.question_history, "suppresses", return_value=False
), patch.object(_services_browser, "current_time_millis", current_time_millis), patch.object(
_services_browser, "_BROWSER_BACKOFF_LIMIT", 10
), patch.object(
_services_browser, "current_time_millis", current_time_millis
), patch.object(_services_browser, "_BROWSER_BACKOFF_LIMIT", 10), patch.object(
_services_browser, "_FIRST_QUERY_DELAY_RANDOM_INTERVAL", (0, 0)
):
# dummy service callback
Expand Down
11 changes: 5 additions & 6 deletions tests/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
)
from zeroconf.const import _LISTENER_TIME

from . import _clear_cache, has_working_ipv6
from . import QuestionHistoryWithoutSuppression, _clear_cache, has_working_ipv6

log = logging.getLogger('zeroconf')
original_logging_level = logging.NOTSET
Expand Down Expand Up @@ -951,6 +951,7 @@ def on_service_state_change(zeroconf, service_type, state_change, name):

aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
zeroconf_browser = aiozc.zeroconf
zeroconf_browser.question_history = QuestionHistoryWithoutSuppression()
await zeroconf_browser.async_wait_for_start()

# we are going to patch the zeroconf send to check packet sizes
Expand Down Expand Up @@ -990,11 +991,9 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()):
# patch the backoff limit to ensure we always get one query every 1/4 of the DNS TTL
# Disable duplicate question suppression and duplicate packet suppression for this test as it works
# by asking the same question over and over
with patch.object(zeroconf_browser.question_history, "suppresses", return_value=False), patch.object(
zeroconf_browser, "async_send", send
), patch("zeroconf._services.browser.current_time_millis", _new_current_time_millis), patch.object(
_services_browser, "_BROWSER_BACKOFF_LIMIT", int(expected_ttl / 4)
):
with patch.object(zeroconf_browser, "async_send", send), patch(
"zeroconf._services.browser.current_time_millis", _new_current_time_millis
), patch.object(_services_browser, "_BROWSER_BACKOFF_LIMIT", int(expected_ttl / 4)):
service_added = asyncio.Event()
service_removed = asyncio.Event()

Expand Down
4 changes: 2 additions & 2 deletions tests/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1131,7 +1131,7 @@ async def test_cache_flush_bit():
for record in new_records:
assert zc.cache.async_get_unique(record) is not None

original_a_record.created = current_time_millis() - 1001
bdraco marked this conversation as resolved.
Show resolved Hide resolved
original_a_record.created = current_time_millis() - 1500

# Do the run within 1s to verify the original record is not going to be expired
out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA, multicast=True)
Expand All @@ -1146,7 +1146,7 @@ async def test_cache_flush_bit():
cached_records = [zc.cache.async_get_unique(record) for record in new_records]
for cached_record in cached_records:
assert cached_record is not None
cached_record.created = current_time_millis() - 1001
cached_record.created = current_time_millis() - 1500

fresh_address = socket.inet_aton("4.4.4.4")
info.addresses = [fresh_address]
Expand Down
3 changes: 3 additions & 0 deletions tests/test_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from zeroconf._protocol import outgoing
from zeroconf._protocol.incoming import DNSIncoming

from . import QuestionHistoryWithoutSuppression

log = logging.getLogger('zeroconf')
original_logging_level = logging.NOTSET

Expand Down Expand Up @@ -123,6 +125,7 @@ def test_guard_against_duplicate_packets():
These packets can quickly overwhelm the system.
"""
zc = Zeroconf(interfaces=['127.0.0.1'])
zc.question_history = QuestionHistoryWithoutSuppression()

class SubListener(_listener.AsyncListener):
def handle_query_or_defer(
Expand Down