Skip to content

Commit

Permalink
feat: optimize cache implementation (#1252)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Sep 3, 2023
1 parent 235d528 commit 8d3ec79
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 26 deletions.
25 changes: 20 additions & 5 deletions src/zeroconf/_cache.pxd
Expand Up @@ -4,6 +4,7 @@ from ._dns cimport (
DNSAddress,
DNSEntry,
DNSHinfo,
DNSNsec,
DNSPointer,
DNSRecord,
DNSService,
Expand All @@ -13,7 +14,7 @@ from ._dns cimport (

cdef object _UNIQUE_RECORD_TYPES
cdef object _TYPE_PTR
cdef object _ONE_SECOND
cdef cython.uint _ONE_SECOND

cdef _remove_key(cython.dict cache, object key, DNSRecord record)

Expand All @@ -27,23 +28,37 @@ cdef class DNSCache:

cpdef async_remove_records(self, object entries)

@cython.locals(
store=cython.dict,
)
cpdef async_get_unique(self, DNSRecord entry)

@cython.locals(
record=DNSRecord,
)
cpdef async_expire(self, float now)

@cython.locals(
records=cython.dict,
record=DNSRecord,
)
cdef _async_all_by_details(self, object name, object type_, object class_)
cpdef async_all_by_details(self, str name, object type_, object class_)

cpdef async_entries_with_name(self, str name)

cpdef async_entries_with_server(self, str name)

@cython.locals(
store=cython.dict,
)
cdef _async_add(self, DNSRecord record)

cdef _async_remove(self, DNSRecord record)

cpdef async_mark_unique_records_older_than_1s_to_expire(self, object unique_types, object answers, object now)

@cython.locals(
record=DNSRecord,
created_float=cython.float,
)
cdef _async_mark_unique_records_older_than_1s_to_expire(self, object unique_types, object answers, object now)
cpdef async_mark_unique_records_older_than_1s_to_expire(self, cython.set unique_types, object answers, float now)

cdef _dns_record_matches(DNSRecord record, object key, object type_, object class_)
25 changes: 6 additions & 19 deletions src/zeroconf/_cache.py
Expand Up @@ -20,7 +20,6 @@
USA
"""

import itertools
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union, cast

from ._dns import (
Expand Down Expand Up @@ -115,12 +114,12 @@ def async_remove_records(self, entries: Iterable[DNSRecord]) -> None:
for entry in entries:
self._async_remove(entry)

def async_expire(self, now: float) -> List[DNSRecord]:
def async_expire(self, now: _float) -> List[DNSRecord]:
"""Purge expired entries from the cache.
This function must be run in the event loop.
"""
expired = [record for record in itertools.chain(*self.cache.values()) if record.is_expired(now)]
expired = [record for records in self.cache.values() for record in records if record.is_expired(now)]
self.async_remove_records(expired)
return expired

Expand All @@ -136,15 +135,7 @@ def async_get_unique(self, entry: _UniqueRecordsType) -> Optional[DNSRecord]:
return None
return store.get(entry)

def async_all_by_details(self, name: _str, type_: int, class_: int) -> Iterable[DNSRecord]:
"""Gets all matching entries by details.
This function is not thread-safe and must be called from
the event loop.
"""
return self._async_all_by_details(name, type_, class_)

def _async_all_by_details(self, name: _str, type_: _int, class_: _int) -> List[DNSRecord]:
def async_all_by_details(self, name: _str, type_: _int, class_: _int) -> List[DNSRecord]:
"""Gets all matching entries by details.
This function is not thread-safe and must be called from
Expand Down Expand Up @@ -240,20 +231,16 @@ def names(self) -> List[str]:

def async_mark_unique_records_older_than_1s_to_expire(
self, unique_types: Set[Tuple[_str, _int, _int]], answers: Iterable[DNSRecord], now: _float
) -> None:
self._async_mark_unique_records_older_than_1s_to_expire(unique_types, answers, now)

def _async_mark_unique_records_older_than_1s_to_expire(
self, unique_types: Set[Tuple[_str, _int, _int]], answers: Iterable[DNSRecord], now: _float
) -> None:
# rfc6762#section-10.2 para 2
# Since unique is set, all old records with that name, rrtype,
# and rrclass that were received more than one second ago are declared
# invalid, and marked to expire from the cache in one second.
answers_rrset = set(answers)
for name, type_, class_ in unique_types:
for record in self._async_all_by_details(name, type_, class_):
if (now - record.created > _ONE_SECOND) and record not in answers_rrset:
for record in self.async_all_by_details(name, type_, class_):
created_float = record.created
if (now - created_float > _ONE_SECOND) and record not in answers_rrset:
# Expire in 1s
record.set_created_ttl(now, 1)

Expand Down
5 changes: 3 additions & 2 deletions src/zeroconf/_handlers/record_manager.py
Expand Up @@ -86,8 +86,9 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
now_float = now
unique_types: Set[Tuple[str, int, int]] = set()
cache = self.cache
answers = msg.answers

for record in msg.answers:
for record in answers:
# Protect zeroconf from records that can cause denial of service.
#
# We enforce a minimum TTL for PTR records to avoid
Expand Down Expand Up @@ -127,7 +128,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
removes.add(record)

if unique_types:
cache.async_mark_unique_records_older_than_1s_to_expire(unique_types, msg.answers, now)
cache.async_mark_unique_records_older_than_1s_to_expire(unique_types, answers, now)

if updates:
self.async_updates(now, updates)
Expand Down

0 comments on commit 8d3ec79

Please sign in to comment.