diff --git a/build_ext.py b/build_ext.py index 3746fb25..430ad9f1 100644 --- a/build_ext.py +++ b/build_ext.py @@ -23,6 +23,7 @@ def build(setup_kwargs: Any) -> None: dict( ext_modules=cythonize( [ + "src/zeroconf/_cache.py", "src/zeroconf/_dns.py", "src/zeroconf/_protocol/incoming.py", "src/zeroconf/_protocol/outgoing.py", diff --git a/src/zeroconf/_cache.pxd b/src/zeroconf/_cache.pxd new file mode 100644 index 00000000..acf4029e --- /dev/null +++ b/src/zeroconf/_cache.pxd @@ -0,0 +1,28 @@ +import cython +from ._dns cimport ( + DNSAddress, + DNSEntry, + DNSHinfo, + DNSPointer, + DNSRecord, + DNSService, + DNSText, +) + + +cdef object _TYPE_PTR + +cdef _remove_key(cython.dict cache, object key, DNSRecord record) + + +cdef class DNSCache: + + cdef public cython.dict cache + cdef public cython.dict service_cache + + cdef _async_add(self, DNSRecord record) + + cdef _async_remove(self, DNSRecord record) + + +cdef _dns_record_matches(DNSRecord record, object key, object type_, object class_) diff --git a/src/zeroconf/_cache.py b/src/zeroconf/_cache.py index cb485eaf..f022c2cb 100644 --- a/src/zeroconf/_cache.py +++ b/src/zeroconf/_cache.py @@ -32,7 +32,6 @@ DNSRecord, DNSService, DNSText, - dns_entry_matches, ) from ._utils.time import current_time_millis from .const import _TYPE_PTR @@ -40,14 +39,16 @@ _UNIQUE_RECORD_TYPES = (DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService) _UniqueRecordsType = Union[DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService] _DNSRecordCacheType = Dict[str, Dict[DNSRecord, DNSRecord]] +_DNSRecord = DNSRecord +_str = str -def _remove_key(cache: _DNSRecordCacheType, key: str, entry: DNSRecord) -> None: +def _remove_key(cache: _DNSRecordCacheType, key: _str, record: _DNSRecord) -> None: """Remove a key from a DNSRecord cache This function must be run in from event loop. """ - del cache[key][entry] + del cache[key][record] if not cache[key]: del cache[key] @@ -62,7 +63,7 @@ def __init__(self) -> None: # Functions prefixed with async_ are NOT threadsafe and must # be run in the event loop. - def _async_add(self, entry: DNSRecord) -> bool: + def _async_add(self, record: _DNSRecord) -> bool: """Adds an entry. Returns true if the entry was not already in the cache. @@ -75,11 +76,11 @@ def _async_add(self, entry: DNSRecord) -> bool: # replaces any existing records that are __eq__ to each other which # removes the risk that accessing the cache from the wrong # direction would return the old incorrect entry. - store = self.cache.setdefault(entry.key, {}) - new = entry not in store and not isinstance(entry, DNSNsec) - store[entry] = entry - if isinstance(entry, DNSService): - self.service_cache.setdefault(entry.server_key, {})[entry] = entry + store = self.cache.setdefault(record.key, {}) + new = record not in store and not isinstance(record, DNSNsec) + store[record] = record + if isinstance(record, DNSService): + self.service_cache.setdefault(record.server_key, {})[record] = record return new def async_add_records(self, entries: Iterable[DNSRecord]) -> bool: @@ -95,14 +96,14 @@ def async_add_records(self, entries: Iterable[DNSRecord]) -> bool: new = True return new - def _async_remove(self, entry: DNSRecord) -> None: + def _async_remove(self, record: _DNSRecord) -> None: """Removes an entry. This function must be run in from event loop. """ - if isinstance(entry, DNSService): - _remove_key(self.service_cache, entry.server_key, entry) - _remove_key(self.cache, entry.key, entry) + if isinstance(record, DNSService): + _remove_key(self.service_cache, record.server_key, record) + _remove_key(self.cache, record.key, record) def async_remove_records(self, entries: Iterable[DNSRecord]) -> None: """Remove multiple records. @@ -128,7 +129,10 @@ def async_get_unique(self, entry: _UniqueRecordsType) -> Optional[DNSRecord]: This function is not threadsafe and must be called from the event loop. """ - return self.cache.get(entry.key, {}).get(entry) + store = self.cache.get(entry.key) + if store is None: + return None + return store.get(entry) def async_all_by_details(self, name: str, type_: int, class_: int) -> Iterator[DNSRecord]: """Gets all matching entries by details. @@ -138,7 +142,7 @@ def async_all_by_details(self, name: str, type_: int, class_: int) -> Iterator[D """ key = name.lower() for entry in self.cache.get(key, []): - if dns_entry_matches(entry, key, type_, class_): + if _dns_record_matches(entry, key, type_, class_): yield entry def async_entries_with_name(self, name: str) -> Dict[DNSRecord, DNSRecord]: @@ -185,7 +189,7 @@ def get_by_details(self, name: str, type_: int, class_: int) -> Optional[DNSReco """ key = name.lower() for cached_entry in reversed(list(self.cache.get(key, []))): - if dns_entry_matches(cached_entry, key, type_, class_): + if _dns_record_matches(cached_entry, key, type_, class_): return cached_entry return None @@ -193,7 +197,7 @@ def get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSReco """Gets all matching entries by details.""" key = name.lower() return [ - entry for entry in list(self.cache.get(key, [])) if dns_entry_matches(entry, key, type_, class_) + entry for entry in list(self.cache.get(key, [])) if _dns_record_matches(entry, key, type_, class_) ] def entries_with_server(self, server: str) -> List[DNSRecord]: @@ -218,3 +222,7 @@ def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[D def names(self) -> List[str]: """Return a copy of the list of current cache names.""" return list(self.cache) + + +def _dns_record_matches(record: _DNSRecord, key: _str, type_: int, class_: int) -> bool: + return key == record.key and type_ == record.type and class_ == record.class_ diff --git a/src/zeroconf/_dns.pxd b/src/zeroconf/_dns.pxd index 8246326b..c83269ef 100644 --- a/src/zeroconf/_dns.pxd +++ b/src/zeroconf/_dns.pxd @@ -74,3 +74,5 @@ cdef class DNSRRSet: cdef _records cdef _lookup + +cdef _dns_entry_matches(DNSEntry entry, object key, object type_, object class_) diff --git a/src/zeroconf/_dns.py b/src/zeroconf/_dns.py index 72b56377..4d46263e 100644 --- a/src/zeroconf/_dns.py +++ b/src/zeroconf/_dns.py @@ -59,10 +59,6 @@ class DNSQuestionType(enum.Enum): QM = 2 -def dns_entry_matches(record: 'DNSEntry', key: str, type_: int, class_: int) -> bool: - return key == record.key and type_ == record.type and class_ == record.class_ - - class DNSEntry: """A DNS entry""" @@ -78,7 +74,7 @@ def __init__(self, name: str, type_: int, class_: int) -> None: def __eq__(self, other: Any) -> bool: """Equality test on key (lowercase name), type, and class""" - return dns_entry_matches(other, self.key, self.type, self.class_) and isinstance(other, DNSEntry) + return _dns_entry_matches(other, self.key, self.type, self.class_) and isinstance(other, DNSEntry) @staticmethod def get_class_(class_: int) -> str: @@ -121,7 +117,7 @@ def __hash__(self) -> int: def __eq__(self, other: Any) -> bool: """Tests equality on dns question.""" - return isinstance(other, DNSQuestion) and dns_entry_matches(other, self.key, self.type, self.class_) + return isinstance(other, DNSQuestion) and _dns_entry_matches(other, self.key, self.type, self.class_) @property def max_size(self) -> int: @@ -254,7 +250,7 @@ def __eq__(self, other: Any) -> bool: isinstance(other, DNSAddress) and self.address == other.address and self.scope_id == other.scope_id - and dns_entry_matches(other, self.key, self.type, self.class_) + and _dns_entry_matches(other, self.key, self.type, self.class_) ) def __hash__(self) -> int: @@ -298,7 +294,7 @@ def __eq__(self, other: Any) -> bool: isinstance(other, DNSHinfo) and self.cpu == other.cpu and self.os == other.os - and dns_entry_matches(other, self.key, self.type, self.class_) + and _dns_entry_matches(other, self.key, self.type, self.class_) ) def __hash__(self) -> int: @@ -342,7 +338,7 @@ def __eq__(self, other: Any) -> bool: return ( isinstance(other, DNSPointer) and self.alias == other.alias - and dns_entry_matches(other, self.key, self.type, self.class_) + and _dns_entry_matches(other, self.key, self.type, self.class_) ) def __hash__(self) -> int: @@ -381,7 +377,7 @@ def __eq__(self, other: Any) -> bool: return ( isinstance(other, DNSText) and self.text == other.text - and dns_entry_matches(other, self.key, self.type, self.class_) + and _dns_entry_matches(other, self.key, self.type, self.class_) ) def __repr__(self) -> str: @@ -432,7 +428,7 @@ def __eq__(self, other: Any) -> bool: and self.weight == other.weight and self.port == other.port and self.server == other.server - and dns_entry_matches(other, self.key, self.type, self.class_) + and _dns_entry_matches(other, self.key, self.type, self.class_) ) def __hash__(self) -> int: @@ -487,7 +483,7 @@ def __eq__(self, other: Any) -> bool: isinstance(other, DNSNsec) and self.next_name == other.next_name and self.rdtypes == other.rdtypes - and dns_entry_matches(other, self.key, self.type, self.class_) + and _dns_entry_matches(other, self.key, self.type, self.class_) ) def __hash__(self) -> int: @@ -527,3 +523,11 @@ def suppresses(self, record: DNSRecord) -> bool: def __contains__(self, record: DNSRecord) -> bool: """Returns true if the rrset contains the record.""" return record in self.lookup + + +_DNSEntry = DNSEntry +_str = str + + +def _dns_entry_matches(entry: _DNSEntry, key: _str, type_: int, class_: int) -> bool: + return key == entry.key and type_ == entry.type and class_ == entry.class_ diff --git a/src/zeroconf/_protocol/incoming.py b/src/zeroconf/_protocol/incoming.py index c9e04e12..aaf0340b 100644 --- a/src/zeroconf/_protocol/incoming.py +++ b/src/zeroconf/_protocol/incoming.py @@ -66,6 +66,7 @@ UNPACK_HHiH = struct.Struct(b'!HHiH').unpack_from _seen_logs: Dict[str, Union[int, tuple]] = {} +_str = str class DNSIncoming: @@ -250,7 +251,9 @@ def _read_others(self) -> None: if rec is not None: self._answers.append(rec) - def _read_record(self, domain, type_: int, class_: int, ttl: int, length: int) -> Optional[DNSRecord]: # type: ignore[no-untyped-def] + def _read_record( + self, domain: _str, type_: int, class_: int, ttl: int, length: int + ) -> Optional[DNSRecord]: """Read known records types and skip unknown ones.""" if type_ == _TYPE_A: return DNSAddress(domain, type_, class_, ttl, self._read_string(4), created=self.now)