Skip to content

Commit

Permalink
feat: speed up instances only used to lookup answers (#1307)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Nov 12, 2023
1 parent 9ca9a57 commit 0701b8a
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 22 deletions.
11 changes: 11 additions & 0 deletions src/zeroconf/_listener.pxd
Expand Up @@ -3,6 +3,7 @@ import cython

from ._handlers.record_manager cimport RecordManager
from ._protocol.incoming cimport DNSIncoming
from ._services.registry cimport ServiceRegistry
from ._utils.time cimport current_time_millis, millis_to_seconds


Expand All @@ -18,6 +19,7 @@ cdef cython.uint _DUPLICATE_PACKET_SUPPRESSION_INTERVAL
cdef class AsyncListener:

cdef public object zc
cdef ServiceRegistry _registry
cdef RecordManager _record_manager
cdef public cython.bytes data
cdef public cython.float last_time
Expand All @@ -34,3 +36,12 @@ cdef class AsyncListener:
cpdef _process_datagram_at_time(self, bint debug, cython.uint data_len, cython.float now, bytes data, cython.tuple addrs)

cdef _cancel_any_timers_for_addr(self, object addr)

cpdef handle_query_or_defer(
self,
DNSIncoming msg,
object addr,
object port,
object transport,
tuple v6_flow_scope
)
12 changes: 9 additions & 3 deletions src/zeroconf/_listener.py
Expand Up @@ -57,6 +57,7 @@ class AsyncListener:

__slots__ = (
'zc',
'_registry',
'_record_manager',
'data',
'last_time',
Expand All @@ -69,6 +70,7 @@ class AsyncListener:

def __init__(self, zc: 'Zeroconf') -> None:
self.zc = zc
self._registry = zc.registry
self._record_manager = zc.record_manager
self.data: Optional[bytes] = None
self.last_time: float = 0
Expand Down Expand Up @@ -171,17 +173,21 @@ def _process_datagram_at_time(
self._record_manager.async_updates_from_response(msg)
return

if not self._registry.has_entries:
# If the registry is empty, we have no answers to give.
return

if TYPE_CHECKING:
assert self.transport is not None
self.handle_query_or_defer(msg, addr, port, self.transport, v6_flow_scope)

def handle_query_or_defer(
self,
msg: DNSIncoming,
addr: str,
port: int,
addr: _str,
port: _int,
transport: _WrappedTransport,
v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
v6_flow_scope: Union[Tuple[()], Tuple[int, int]],
) -> None:
"""Deal with incoming query packets. Provides a response if
possible."""
Expand Down
5 changes: 5 additions & 0 deletions src/zeroconf/_services/registry.pxd
Expand Up @@ -9,6 +9,7 @@ cdef class ServiceRegistry:
cdef cython.dict _services
cdef public cython.dict types
cdef public cython.dict servers
cdef public bint has_entries

@cython.locals(
record_list=cython.list,
Expand All @@ -17,6 +18,10 @@ cdef class ServiceRegistry:

cdef _add(self, ServiceInfo info)

@cython.locals(
info=ServiceInfo,
old_service_info=ServiceInfo
)
cdef _remove(self, cython.list infos)

cpdef ServiceInfo async_get_info_name(self, str name)
Expand Down
10 changes: 7 additions & 3 deletions src/zeroconf/_services/registry.py
Expand Up @@ -35,7 +35,7 @@ class ServiceRegistry:
the event loop as it is not thread safe.
"""

__slots__ = ("_services", "types", "servers")
__slots__ = ("_services", "types", "servers", "has_entries")

def __init__(
self,
Expand All @@ -44,6 +44,7 @@ def __init__(
self._services: Dict[str, ServiceInfo] = {}
self.types: Dict[str, List] = {}
self.servers: Dict[str, List] = {}
self.has_entries: bool = False

def async_add(self, info: ServiceInfo) -> None:
"""Add a new service to the registry."""
Expand Down Expand Up @@ -95,14 +96,17 @@ def _add(self, info: ServiceInfo) -> None:
self._services[info.key] = info
self.types.setdefault(info.type.lower(), []).append(info.key)
self.servers.setdefault(info.server_key, []).append(info.key)
self.has_entries = True

def _remove(self, infos: List[ServiceInfo]) -> None:
"""Remove a services under the lock."""
for info in infos:
if info.key not in self._services:
old_service_info = self._services.get(info.key)
if old_service_info is None:
continue
old_service_info = self._services[info.key]
assert old_service_info.server_key is not None
self.types[old_service_info.type.lower()].remove(info.key)
self.servers[old_service_info.server_key].remove(info.key)
del self._services[info.key]

self.has_entries = bool(self._services)
39 changes: 24 additions & 15 deletions tests/test_core.py
Expand Up @@ -12,12 +12,10 @@
import time
import unittest
import unittest.mock
from typing import cast
from unittest.mock import patch
from typing import Tuple, Union, cast
from unittest.mock import Mock, patch

if sys.version_info[:3][1] < 8:
from unittest.mock import Mock

AsyncMock = Mock
else:
from unittest.mock import AsyncMock
Expand All @@ -26,6 +24,8 @@

import zeroconf as r
from zeroconf import NotRunningException, Zeroconf, const, current_time_millis
from zeroconf._listener import AsyncListener, _WrappedTransport
from zeroconf._protocol.incoming import DNSIncoming
from zeroconf.asyncio import AsyncZeroconf

from . import _clear_cache, _inject_response, _wait_for_start, has_working_ipv6
Expand All @@ -45,10 +45,19 @@ def teardown_module():
log.setLevel(original_logging_level)


def threadsafe_query(zc, protocol, *args):
def threadsafe_query(
zc: 'Zeroconf',
protocol: 'AsyncListener',
msg: DNSIncoming,
addr: str,
port: int,
transport: _WrappedTransport,
v6_flow_scope: Union[Tuple[()], Tuple[int, int]],
) -> None:
async def make_query():
protocol.handle_query_or_defer(*args)
protocol.handle_query_or_defer(msg, addr, port, transport, v6_flow_scope)

assert zc.loop is not None
asyncio.run_coroutine_threadsafe(make_query(), zc.loop).result()


Expand Down Expand Up @@ -476,28 +485,28 @@ def test_tc_bit_defers():

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
assert protocol._deferred[source_ip] == expected_deferred
assert source_ip in protocol._timers

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
assert protocol._deferred[source_ip] == expected_deferred
assert source_ip in protocol._timers
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
assert protocol._deferred[source_ip] == expected_deferred
assert source_ip in protocol._timers

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
assert protocol._deferred[source_ip] == expected_deferred
assert source_ip in protocol._timers

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
assert source_ip not in protocol._deferred
assert source_ip not in protocol._timers

Expand Down Expand Up @@ -555,20 +564,20 @@ def test_tc_bit_defers_last_response_missing():

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
assert protocol._deferred[source_ip] == expected_deferred
timer1 = protocol._timers[source_ip]

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
assert protocol._deferred[source_ip] == expected_deferred
timer2 = protocol._timers[source_ip]
assert timer1.cancelled()
assert timer2 != timer1

# Send the same packet again to similar multi interfaces
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
assert protocol._deferred[source_ip] == expected_deferred
assert source_ip in protocol._timers
timer3 = protocol._timers[source_ip]
Expand All @@ -577,7 +586,7 @@ def test_tc_bit_defers_last_response_missing():

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
assert protocol._deferred[source_ip] == expected_deferred
assert source_ip in protocol._timers
timer4 = protocol._timers[source_ip]
Expand Down
12 changes: 11 additions & 1 deletion tests/test_listener.py
Expand Up @@ -10,7 +10,14 @@
from unittest.mock import MagicMock, patch

import zeroconf as r
from zeroconf import Zeroconf, _engine, _listener, const, current_time_millis
from zeroconf import (
ServiceInfo,
Zeroconf,
_engine,
_listener,
const,
current_time_millis,
)
from zeroconf._protocol import outgoing
from zeroconf._protocol.incoming import DNSIncoming

Expand Down Expand Up @@ -125,6 +132,9 @@ def test_guard_against_duplicate_packets():
These packets can quickly overwhelm the system.
"""
zc = Zeroconf(interfaces=['127.0.0.1'])
zc.registry.async_add(
ServiceInfo("_http._tcp.local.", "Test._http._tcp.local.", server="Test._http._tcp.local.", port=4)
)
zc.question_history = QuestionHistoryWithoutSuppression()

class SubListener(_listener.AsyncListener):
Expand Down

0 comments on commit 0701b8a

Please sign in to comment.