Skip to content

Commit

Permalink
feat: optimize the dns cache (#1119)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Dec 21, 2022
1 parent f57d9f1 commit e80fcef
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 30 deletions.
1 change: 1 addition & 0 deletions build_ext.py
Expand Up @@ -23,6 +23,7 @@ def build(setup_kwargs: Any) -> None:
dict(
ext_modules=cythonize(
[
"src/zeroconf/_cache.py",
"src/zeroconf/_dns.py",
"src/zeroconf/_protocol/incoming.py",
"src/zeroconf/_protocol/outgoing.py",
Expand Down
28 changes: 28 additions & 0 deletions src/zeroconf/_cache.pxd
@@ -0,0 +1,28 @@
import cython
from ._dns cimport (
DNSAddress,
DNSEntry,
DNSHinfo,
DNSPointer,
DNSRecord,
DNSService,
DNSText,
)


cdef object _TYPE_PTR

cdef _remove_key(cython.dict cache, object key, DNSRecord record)


cdef class DNSCache:

cdef public cython.dict cache
cdef public cython.dict service_cache

cdef _async_add(self, DNSRecord record)

cdef _async_remove(self, DNSRecord record)


cdef _dns_record_matches(DNSRecord record, object key, object type_, object class_)
42 changes: 25 additions & 17 deletions src/zeroconf/_cache.py
Expand Up @@ -32,22 +32,23 @@
DNSRecord,
DNSService,
DNSText,
dns_entry_matches,
)
from ._utils.time import current_time_millis
from .const import _TYPE_PTR

_UNIQUE_RECORD_TYPES = (DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService)
_UniqueRecordsType = Union[DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService]
_DNSRecordCacheType = Dict[str, Dict[DNSRecord, DNSRecord]]
_DNSRecord = DNSRecord
_str = str


def _remove_key(cache: _DNSRecordCacheType, key: str, entry: DNSRecord) -> None:
def _remove_key(cache: _DNSRecordCacheType, key: _str, record: _DNSRecord) -> None:
"""Remove a key from a DNSRecord cache
This function must be run in from event loop.
"""
del cache[key][entry]
del cache[key][record]
if not cache[key]:
del cache[key]

Expand All @@ -62,7 +63,7 @@ def __init__(self) -> None:
# Functions prefixed with async_ are NOT threadsafe and must
# be run in the event loop.

def _async_add(self, entry: DNSRecord) -> bool:
def _async_add(self, record: _DNSRecord) -> bool:
"""Adds an entry.
Returns true if the entry was not already in the cache.
Expand All @@ -75,11 +76,11 @@ def _async_add(self, entry: DNSRecord) -> bool:
# replaces any existing records that are __eq__ to each other which
# removes the risk that accessing the cache from the wrong
# direction would return the old incorrect entry.
store = self.cache.setdefault(entry.key, {})
new = entry not in store and not isinstance(entry, DNSNsec)
store[entry] = entry
if isinstance(entry, DNSService):
self.service_cache.setdefault(entry.server_key, {})[entry] = entry
store = self.cache.setdefault(record.key, {})
new = record not in store and not isinstance(record, DNSNsec)
store[record] = record
if isinstance(record, DNSService):
self.service_cache.setdefault(record.server_key, {})[record] = record
return new

def async_add_records(self, entries: Iterable[DNSRecord]) -> bool:
Expand All @@ -95,14 +96,14 @@ def async_add_records(self, entries: Iterable[DNSRecord]) -> bool:
new = True
return new

def _async_remove(self, entry: DNSRecord) -> None:
def _async_remove(self, record: _DNSRecord) -> None:
"""Removes an entry.
This function must be run in from event loop.
"""
if isinstance(entry, DNSService):
_remove_key(self.service_cache, entry.server_key, entry)
_remove_key(self.cache, entry.key, entry)
if isinstance(record, DNSService):
_remove_key(self.service_cache, record.server_key, record)
_remove_key(self.cache, record.key, record)

def async_remove_records(self, entries: Iterable[DNSRecord]) -> None:
"""Remove multiple records.
Expand All @@ -128,7 +129,10 @@ def async_get_unique(self, entry: _UniqueRecordsType) -> Optional[DNSRecord]:
This function is not threadsafe and must be called from
the event loop.
"""
return self.cache.get(entry.key, {}).get(entry)
store = self.cache.get(entry.key)
if store is None:
return None
return store.get(entry)

def async_all_by_details(self, name: str, type_: int, class_: int) -> Iterator[DNSRecord]:
"""Gets all matching entries by details.
Expand All @@ -138,7 +142,7 @@ def async_all_by_details(self, name: str, type_: int, class_: int) -> Iterator[D
"""
key = name.lower()
for entry in self.cache.get(key, []):
if dns_entry_matches(entry, key, type_, class_):
if _dns_record_matches(entry, key, type_, class_):
yield entry

def async_entries_with_name(self, name: str) -> Dict[DNSRecord, DNSRecord]:
Expand Down Expand Up @@ -185,15 +189,15 @@ def get_by_details(self, name: str, type_: int, class_: int) -> Optional[DNSReco
"""
key = name.lower()
for cached_entry in reversed(list(self.cache.get(key, []))):
if dns_entry_matches(cached_entry, key, type_, class_):
if _dns_record_matches(cached_entry, key, type_, class_):
return cached_entry
return None

def get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSRecord]:
"""Gets all matching entries by details."""
key = name.lower()
return [
entry for entry in list(self.cache.get(key, [])) if dns_entry_matches(entry, key, type_, class_)
entry for entry in list(self.cache.get(key, [])) if _dns_record_matches(entry, key, type_, class_)
]

def entries_with_server(self, server: str) -> List[DNSRecord]:
Expand All @@ -218,3 +222,7 @@ def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[D
def names(self) -> List[str]:
"""Return a copy of the list of current cache names."""
return list(self.cache)


def _dns_record_matches(record: _DNSRecord, key: _str, type_: int, class_: int) -> bool:
return key == record.key and type_ == record.type and class_ == record.class_
2 changes: 2 additions & 0 deletions src/zeroconf/_dns.pxd
Expand Up @@ -74,3 +74,5 @@ cdef class DNSRRSet:

cdef _records
cdef _lookup

cdef _dns_entry_matches(DNSEntry entry, object key, object type_, object class_)
28 changes: 16 additions & 12 deletions src/zeroconf/_dns.py
Expand Up @@ -59,10 +59,6 @@ class DNSQuestionType(enum.Enum):
QM = 2


def dns_entry_matches(record: 'DNSEntry', key: str, type_: int, class_: int) -> bool:
return key == record.key and type_ == record.type and class_ == record.class_


class DNSEntry:

"""A DNS entry"""
Expand All @@ -78,7 +74,7 @@ def __init__(self, name: str, type_: int, class_: int) -> None:

def __eq__(self, other: Any) -> bool:
"""Equality test on key (lowercase name), type, and class"""
return dns_entry_matches(other, self.key, self.type, self.class_) and isinstance(other, DNSEntry)
return _dns_entry_matches(other, self.key, self.type, self.class_) and isinstance(other, DNSEntry)

@staticmethod
def get_class_(class_: int) -> str:
Expand Down Expand Up @@ -121,7 +117,7 @@ def __hash__(self) -> int:

def __eq__(self, other: Any) -> bool:
"""Tests equality on dns question."""
return isinstance(other, DNSQuestion) and dns_entry_matches(other, self.key, self.type, self.class_)
return isinstance(other, DNSQuestion) and _dns_entry_matches(other, self.key, self.type, self.class_)

@property
def max_size(self) -> int:
Expand Down Expand Up @@ -254,7 +250,7 @@ def __eq__(self, other: Any) -> bool:
isinstance(other, DNSAddress)
and self.address == other.address
and self.scope_id == other.scope_id
and dns_entry_matches(other, self.key, self.type, self.class_)
and _dns_entry_matches(other, self.key, self.type, self.class_)
)

def __hash__(self) -> int:
Expand Down Expand Up @@ -298,7 +294,7 @@ def __eq__(self, other: Any) -> bool:
isinstance(other, DNSHinfo)
and self.cpu == other.cpu
and self.os == other.os
and dns_entry_matches(other, self.key, self.type, self.class_)
and _dns_entry_matches(other, self.key, self.type, self.class_)
)

def __hash__(self) -> int:
Expand Down Expand Up @@ -342,7 +338,7 @@ def __eq__(self, other: Any) -> bool:
return (
isinstance(other, DNSPointer)
and self.alias == other.alias
and dns_entry_matches(other, self.key, self.type, self.class_)
and _dns_entry_matches(other, self.key, self.type, self.class_)
)

def __hash__(self) -> int:
Expand Down Expand Up @@ -381,7 +377,7 @@ def __eq__(self, other: Any) -> bool:
return (
isinstance(other, DNSText)
and self.text == other.text
and dns_entry_matches(other, self.key, self.type, self.class_)
and _dns_entry_matches(other, self.key, self.type, self.class_)
)

def __repr__(self) -> str:
Expand Down Expand Up @@ -432,7 +428,7 @@ def __eq__(self, other: Any) -> bool:
and self.weight == other.weight
and self.port == other.port
and self.server == other.server
and dns_entry_matches(other, self.key, self.type, self.class_)
and _dns_entry_matches(other, self.key, self.type, self.class_)
)

def __hash__(self) -> int:
Expand Down Expand Up @@ -487,7 +483,7 @@ def __eq__(self, other: Any) -> bool:
isinstance(other, DNSNsec)
and self.next_name == other.next_name
and self.rdtypes == other.rdtypes
and dns_entry_matches(other, self.key, self.type, self.class_)
and _dns_entry_matches(other, self.key, self.type, self.class_)
)

def __hash__(self) -> int:
Expand Down Expand Up @@ -527,3 +523,11 @@ def suppresses(self, record: DNSRecord) -> bool:
def __contains__(self, record: DNSRecord) -> bool:
"""Returns true if the rrset contains the record."""
return record in self.lookup


_DNSEntry = DNSEntry
_str = str


def _dns_entry_matches(entry: _DNSEntry, key: _str, type_: int, class_: int) -> bool:
return key == entry.key and type_ == entry.type and class_ == entry.class_
5 changes: 4 additions & 1 deletion src/zeroconf/_protocol/incoming.py
Expand Up @@ -66,6 +66,7 @@
UNPACK_HHiH = struct.Struct(b'!HHiH').unpack_from

_seen_logs: Dict[str, Union[int, tuple]] = {}
_str = str


class DNSIncoming:
Expand Down Expand Up @@ -250,7 +251,9 @@ def _read_others(self) -> None:
if rec is not None:
self._answers.append(rec)

def _read_record(self, domain, type_: int, class_: int, ttl: int, length: int) -> Optional[DNSRecord]: # type: ignore[no-untyped-def]
def _read_record(
self, domain: _str, type_: int, class_: int, ttl: int, length: int
) -> Optional[DNSRecord]:
"""Read known records types and skip unknown ones."""
if type_ == _TYPE_A:
return DNSAddress(domain, type_, class_, ttl, self._read_string(4), created=self.now)
Expand Down

0 comments on commit e80fcef

Please sign in to comment.