Skip to content

Commit

Permalink
feat: make ServiceInfo aware of question history (#1348)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Dec 16, 2023
1 parent cf40470 commit b9aae1d
Show file tree
Hide file tree
Showing 6 changed files with 234 additions and 66 deletions.
4 changes: 2 additions & 2 deletions src/zeroconf/_history.pxd
Expand Up @@ -9,10 +9,10 @@ cdef class QuestionHistory:

cdef cython.dict _history

cpdef add_question_at_time(self, DNSQuestion question, double now, cython.set known_answers)
cpdef void add_question_at_time(self, DNSQuestion question, double now, cython.set known_answers)

@cython.locals(than=double, previous_question=cython.tuple, previous_known_answers=cython.set)
cpdef bint suppresses(self, DNSQuestion question, double now, cython.set known_answers)

@cython.locals(than=double, now_known_answers=cython.tuple)
cpdef async_expire(self, double now)
cpdef void async_expire(self, double now)
15 changes: 5 additions & 10 deletions src/zeroconf/_protocol/outgoing.pxd
@@ -1,7 +1,6 @@

import cython

from .._cache cimport DNSCache
from .._dns cimport DNSEntry, DNSPointer, DNSQuestion, DNSRecord
from .incoming cimport DNSIncoming

Expand Down Expand Up @@ -127,20 +126,16 @@ cdef class DNSOutgoing:
)
cpdef packets(self)

cpdef add_question_or_all_cache(self, DNSCache cache, double now, str name, object type_, object class_)
cpdef void add_question(self, DNSQuestion question)

cpdef add_question_or_one_cache(self, DNSCache cache, double now, str name, object type_, object class_)

cpdef add_question(self, DNSQuestion question)

cpdef add_answer(self, DNSIncoming inp, DNSRecord record)
cpdef void add_answer(self, DNSIncoming inp, DNSRecord record)

@cython.locals(now_double=double)
cpdef add_answer_at_time(self, DNSRecord record, double now)
cpdef void add_answer_at_time(self, DNSRecord record, double now)

cpdef add_authorative_answer(self, DNSPointer record)
cpdef void add_authorative_answer(self, DNSPointer record)

cpdef add_additional_answer(self, DNSRecord record)
cpdef void add_additional_answer(self, DNSRecord record)

cpdef bint is_query(self)

Expand Down
24 changes: 0 additions & 24 deletions src/zeroconf/_protocol/outgoing.py
Expand Up @@ -25,7 +25,6 @@
from struct import Struct
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union

from .._cache import DNSCache
from .._dns import DNSPointer, DNSQuestion, DNSRecord
from .._exceptions import NamePartTooLongException
from .._logger import log
Expand Down Expand Up @@ -198,29 +197,6 @@ def add_additional_answer(self, record: DNSRecord) -> None:
"""
self.additionals.append(record)

def add_question_or_one_cache(
self, cache: DNSCache, now: float_, name: str_, type_: int_, class_: int_
) -> None:
"""Add a question if it is not already cached."""
cached_entry = cache.get_by_details(name, type_, class_)
if not cached_entry:
self.add_question(DNSQuestion(name, type_, class_))
else:
self.add_answer_at_time(cached_entry, now)

def add_question_or_all_cache(
self, cache: DNSCache, now: float_, name: str_, type_: int_, class_: int_
) -> None:
"""Add a question if it is not already cached.
This is currently only used for IPv6 addresses.
"""
cached_entries = cache.get_all_by_details(name, type_, class_)
if not cached_entries:
self.add_question(DNSQuestion(name, type_, class_))
return
for cached_entry in cached_entries:
self.add_answer_at_time(cached_entry, now)

def _write_byte(self, value: int_) -> None:
"""Writes a single byte to the packet"""
self.data.append(BYTE_TABLE[value])
Expand Down
39 changes: 35 additions & 4 deletions src/zeroconf/_services/info.pxd
Expand Up @@ -2,7 +2,16 @@
import cython

from .._cache cimport DNSCache
from .._dns cimport DNSAddress, DNSNsec, DNSPointer, DNSRecord, DNSService, DNSText
from .._dns cimport (
DNSAddress,
DNSNsec,
DNSPointer,
DNSQuestion,
DNSRecord,
DNSService,
DNSText,
)
from .._history cimport QuestionHistory
from .._protocol.outgoing cimport DNSOutgoing
from .._record_update cimport RecordUpdate
from .._updates cimport RecordUpdateListener
Expand All @@ -27,18 +36,22 @@ cdef object _FLAGS_QR_QUERY

cdef object service_type_name

cdef object DNS_QUESTION_TYPE_QU
cdef object DNS_QUESTION_TYPE_QM
cdef object QU_QUESTION
cdef object QM_QUESTION

cdef object _IPVersion_All_value
cdef object _IPVersion_V4Only_value

cdef cython.set _ADDRESS_RECORD_TYPES

cdef unsigned int _DUPLICATE_QUESTION_INTERVAL

cdef bint TYPE_CHECKING
cdef bint IPADDRESS_SUPPORTS_SCOPE_ID
cdef object cached_ip_addresses

cdef object randint

cdef class ServiceInfo(RecordUpdateListener):

cdef public cython.bytes text
Expand Down Expand Up @@ -123,5 +136,23 @@ cdef class ServiceInfo(RecordUpdateListener):

cpdef void async_clear_cache(self)

@cython.locals(cache=DNSCache)
@cython.locals(cache=DNSCache, history=QuestionHistory, out=DNSOutgoing, qu_question=bint)
cdef DNSOutgoing _generate_request_query(self, object zc, double now, object question_type)

@cython.locals(question=DNSQuestion, answer=DNSRecord)
cdef void _add_question_with_known_answers(
self,
DNSOutgoing out,
bint qu_question,
QuestionHistory question_history,
DNSCache cache,
double now,
str name,
object type_,
object class_,
bint skip_if_known_answers
)

cdef double _get_initial_delay(self)

cdef double _get_random_delay(self)
97 changes: 76 additions & 21 deletions src/zeroconf/_services/info.py
Expand Up @@ -26,16 +26,19 @@
from ipaddress import IPv4Address, IPv6Address, _BaseAddress
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union, cast

from .._cache import DNSCache
from .._dns import (
DNSAddress,
DNSNsec,
DNSPointer,
DNSQuestion,
DNSQuestionType,
DNSRecord,
DNSService,
DNSText,
)
from .._exceptions import BadTypeInNameException
from .._history import QuestionHistory
from .._logger import log
from .._protocol.outgoing import DNSOutgoing
from .._record_update import RecordUpdate
Expand All @@ -61,6 +64,7 @@
_CLASS_IN_UNIQUE,
_DNS_HOST_TTL,
_DNS_OTHER_TTL,
_DUPLICATE_QUESTION_INTERVAL,
_FLAGS_QR_QUERY,
_LISTENER_TIME,
_MDNS_PORT,
Expand Down Expand Up @@ -89,10 +93,12 @@
bytes_ = bytes
float_ = float
int_ = int
str_ = str

DNS_QUESTION_TYPE_QU = DNSQuestionType.QU
DNS_QUESTION_TYPE_QM = DNSQuestionType.QM
QU_QUESTION = DNSQuestionType.QU
QM_QUESTION = DNSQuestionType.QM

randint = random.randint

if TYPE_CHECKING:
from .._core import Zeroconf
Expand Down Expand Up @@ -774,6 +780,12 @@ def request(
)
)

def _get_initial_delay(self) -> float_:
return _LISTENER_TIME

def _get_random_delay(self) -> int_:
return randint(*_AVOID_SYNC_DELAY_RANDOM_INTERVAL)

async def async_request(
self,
zc: 'Zeroconf',
Expand Down Expand Up @@ -804,7 +816,7 @@ async def async_request(
assert zc.loop is not None

first_request = True
delay = _LISTENER_TIME
delay = self._get_initial_delay()
next_ = now
last = now + timeout
try:
Expand All @@ -813,18 +825,25 @@ async def async_request(
if last <= now:
return False
if next_ <= now:
out = self._generate_request_query(
zc,
now,
question_type or DNS_QUESTION_TYPE_QU if first_request else DNS_QUESTION_TYPE_QM,
)
this_question_type = question_type or QU_QUESTION if first_request else QM_QUESTION
out = self._generate_request_query(zc, now, this_question_type)
first_request = False
if not out.questions:
return self._load_from_cache(zc, now)
zc.async_send(out, addr, port)
if out.questions:
# All questions may have been suppressed
# by the question history, so nothing to send,
# but keep waiting for answers in case another
# client on the network is asking the same
# question or they have not arrived yet.
zc.async_send(out, addr, port)
next_ = now + delay
delay *= 2
next_ += random.randint(*_AVOID_SYNC_DELAY_RANDOM_INTERVAL)
next_ += self._get_random_delay()
if this_question_type is QM_QUESTION and delay < _DUPLICATE_QUESTION_INTERVAL:
# If we just asked a QM question, we need to
# wait at least the duplicate question interval
# before asking another QM question otherwise
# its likely to be suppressed by the question
# history of the remote responder.
delay = _DUPLICATE_QUESTION_INTERVAL

await self.async_wait(min(next_, last) - now, zc.loop)
now = current_time_millis()
Expand All @@ -833,21 +852,57 @@ async def async_request(

return True

def _add_question_with_known_answers(
self,
out: DNSOutgoing,
qu_question: bool,
question_history: QuestionHistory,
cache: DNSCache,
now: float_,
name: str_,
type_: int_,
class_: int_,
skip_if_known_answers: bool,
) -> None:
"""Add a question with known answers if its not suppressed."""
known_answers = {
answer for answer in cache.get_all_by_details(name, type_, class_) if not answer.is_stale(now)
}
if skip_if_known_answers and known_answers:
return
question = DNSQuestion(name, type_, class_)
if qu_question:
question.unicast = True
elif question_history.suppresses(question, now, known_answers):
return
else:
question_history.add_question_at_time(question, now, known_answers)
out.add_question(question)
for answer in known_answers:
out.add_answer_at_time(answer, now)

def _generate_request_query(
self, zc: 'Zeroconf', now: float_, question_type: DNSQuestionType
) -> DNSOutgoing:
"""Generate the request query."""
out = DNSOutgoing(_FLAGS_QR_QUERY)
name = self._name
server_or_name = self.server or name
server = self.server or name
cache = zc.cache
out.add_question_or_one_cache(cache, now, name, _TYPE_SRV, _CLASS_IN)
out.add_question_or_one_cache(cache, now, name, _TYPE_TXT, _CLASS_IN)
out.add_question_or_all_cache(cache, now, server_or_name, _TYPE_A, _CLASS_IN)
out.add_question_or_all_cache(cache, now, server_or_name, _TYPE_AAAA, _CLASS_IN)
if question_type is DNS_QUESTION_TYPE_QU:
for question in out.questions:
question.unicast = True
history = zc.question_history
qu_question = question_type is QU_QUESTION
self._add_question_with_known_answers(
out, qu_question, history, cache, now, name, _TYPE_SRV, _CLASS_IN, True
)
self._add_question_with_known_answers(
out, qu_question, history, cache, now, name, _TYPE_TXT, _CLASS_IN, True
)
self._add_question_with_known_answers(
out, qu_question, history, cache, now, server, _TYPE_A, _CLASS_IN, False
)
self._add_question_with_known_answers(
out, qu_question, history, cache, now, server, _TYPE_AAAA, _CLASS_IN, False
)
return out

def __repr__(self) -> str:
Expand Down

0 comments on commit b9aae1d

Please sign in to comment.