Skip to content

Commit

Permalink
feat: speed up the service registry (#1174)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed May 25, 2023
1 parent bb496a1 commit 360ceb2
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 30 deletions.
2 changes: 1 addition & 1 deletion src/zeroconf/_core.py
Expand Up @@ -745,7 +745,7 @@ async def async_unregister_service(self, info: ServiceInfo) -> Awaitable:
# goodbye packets for the address records

assert info.server is not None
entries = self.registry.async_get_infos_server(info.server)
entries = self.registry.async_get_infos_server(info.server.lower())
broadcast_addresses = not bool(entries)
return asyncio.ensure_future(
self._async_broadcast_service(info, _UNREGISTER_TIME, 0, broadcast_addresses)
Expand Down
17 changes: 9 additions & 8 deletions src/zeroconf/_handlers.py
Expand Up @@ -255,10 +255,10 @@ def _add_service_type_enumeration_query_answers(
answer_set[dns_pointer] = set()

def _add_pointer_answers(
self, name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float
self, lower_name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float
) -> None:
"""Answer PTR/ANY question."""
for service in self.registry.async_get_infos_type(name):
for service in self.registry.async_get_infos_type(lower_name):
# Add recommended additional answers according to
# https://tools.ietf.org/html/rfc6763#section-12.1.
dns_pointer = service.dns_pointer(created=now)
Expand All @@ -270,14 +270,14 @@ def _add_pointer_answers(

def _add_address_answers(
self,
name: str,
lower_name: str,
answer_set: _AnswerWithAdditionalsType,
known_answers: DNSRRSet,
now: float,
type_: int,
) -> None:
"""Answer A/AAAA/ANY question."""
for service in self.registry.async_get_infos_server(name):
for service in self.registry.async_get_infos_server(lower_name):
answers: List[DNSAddress] = []
additionals: Set[DNSRecord] = set()
seen_types: Set[int] = set()
Expand Down Expand Up @@ -305,21 +305,22 @@ def _answer_question(
now: float,
) -> _AnswerWithAdditionalsType:
answer_set: _AnswerWithAdditionalsType = {}
question_lower_name = question.name.lower()

if question.type == _TYPE_PTR and question.name.lower() == _SERVICE_TYPE_ENUMERATION_NAME:
if question.type == _TYPE_PTR and question_lower_name == _SERVICE_TYPE_ENUMERATION_NAME:
self._add_service_type_enumeration_query_answers(answer_set, known_answers, now)
return answer_set

type_ = question.type

if type_ in (_TYPE_PTR, _TYPE_ANY):
self._add_pointer_answers(question.name, answer_set, known_answers, now)
self._add_pointer_answers(question_lower_name, answer_set, known_answers, now)

if type_ in (_TYPE_A, _TYPE_AAAA, _TYPE_ANY):
self._add_address_answers(question.name, answer_set, known_answers, now, type_)
self._add_address_answers(question_lower_name, answer_set, known_answers, now, type_)

if type_ in (_TYPE_SRV, _TYPE_TXT, _TYPE_ANY):
service = self.registry.async_get_info_name(question.name)
service = self.registry.async_get_info_name(question_lower_name)
if service is not None:
if type_ in (_TYPE_SRV, _TYPE_ANY):
# Add recommended additional answers according to
Expand Down
4 changes: 2 additions & 2 deletions src/zeroconf/_services/registry.py
Expand Up @@ -60,7 +60,7 @@ def async_get_service_infos(self) -> List[ServiceInfo]:

def async_get_info_name(self, name: str) -> Optional[ServiceInfo]:
"""Return all ServiceInfo for the name."""
return self._services.get(name.lower())
return self._services.get(name)

def async_get_types(self) -> List[str]:
"""Return all types."""
Expand All @@ -76,7 +76,7 @@ def async_get_infos_server(self, server: str) -> List[ServiceInfo]:

def _async_get_by_index(self, records: Dict[str, List], key: str) -> List[ServiceInfo]:
"""Return all ServiceInfo matching the index."""
return [self._services[name] for name in records.get(key.lower(), [])]
return [self._services[name] for name in records.get(key, [])]

def _add(self, info: ServiceInfo) -> None:
"""Add a new service under the lock."""
Expand Down
19 changes: 0 additions & 19 deletions tests/services/test_registry.py
Expand Up @@ -110,22 +110,3 @@ def test_lookups_upper_case_by_lower_case(self):
assert registry.async_get_infos_type(type_.lower()) == [info]
assert registry.async_get_infos_server("ash-2.local.") == [info]
assert registry.async_get_types() == [type_.lower()]

def test_lookups_lower_case_by_upper_case(self):
type_ = "_test-srvc-type._tcp.local."
name = "xxxyyy"
registration_name = f"{name}.{type_}"

desc = {'path': '/~paulsm/'}
info = ServiceInfo(
type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")]
)

registry = r.ServiceRegistry()
registry.async_add(info)

assert registry.async_get_service_infos() == [info]
assert registry.async_get_info_name(registration_name.upper()) == info
assert registry.async_get_infos_type(type_.upper()) == [info]
assert registry.async_get_infos_server("ASH-2.local.") == [info]
assert registry.async_get_types() == [type_]
107 changes: 107 additions & 0 deletions tests/test_handlers.py
Expand Up @@ -340,6 +340,32 @@ def test_aaaa_query():
zc.close()


@unittest.skipIf(not has_working_ipv6(), 'Requires IPv6')
@unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled')
def test_aaaa_query_upper_case():
"""Test that queries for AAAA records work and should respond right away with an upper case name."""
zc = Zeroconf(interfaces=['127.0.0.1'])
type_ = "_knownaaaservice._tcp.local."
name = "knownname"
registration_name = f"{name}.{type_}"
desc = {'path': '/~paulsm/'}
server_name = "ash-2.local."
ipv6_address = socket.inet_pton(socket.AF_INET6, "2001:db8::1")
info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, server_name, addresses=[ipv6_address])
zc.registry.async_add(info)

generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
question = r.DNSQuestion(server_name.upper(), const._TYPE_AAAA, const._CLASS_IN)
generated.add_question(question)
packets = generated.packets()
question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False)
mcast_answers = list(question_answers.mcast_now)
assert mcast_answers[0].address == ipv6_address # type: ignore[attr-defined]
# unregister
zc.registry.async_remove(info)
zc.close()


@unittest.skipIf(not has_working_ipv6(), 'Requires IPv6')
@unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled')
def test_a_and_aaaa_record_fate_sharing():
Expand Down Expand Up @@ -481,6 +507,48 @@ async def test_probe_answered_immediately():
zc.close()


@pytest.mark.asyncio
async def test_probe_answered_immediately_with_uppercase_name():
"""Verify probes are responded to immediately with an uppercase name."""
# instantiate a zeroconf instance
zc = Zeroconf(interfaces=['127.0.0.1'])

# service definition
type_ = "_test-srvc-type._tcp.local."
name = "xxxyyy"
registration_name = f"{name}.{type_}"
desc = {'path': '/~paulsm/'}
info = ServiceInfo(
type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")]
)
zc.registry.async_add(info)
query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
question = r.DNSQuestion(info.type.upper(), const._TYPE_PTR, const._CLASS_IN)
query.add_question(question)
query.add_authorative_answer(info.dns_pointer())
question_answers = zc.query_handler.async_response(
[r.DNSIncoming(packet) for packet in query.packets()], False
)
assert not question_answers.ucast
assert not question_answers.mcast_aggregate
assert not question_answers.mcast_aggregate_last_second
assert question_answers.mcast_now

query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)
question.unicast = True
query.add_question(question)
query.add_authorative_answer(info.dns_pointer())
question_answers = zc.query_handler.async_response(
[r.DNSIncoming(packet) for packet in query.packets()], False
)
assert question_answers.ucast
assert question_answers.mcast_now
assert not question_answers.mcast_aggregate
assert not question_answers.mcast_aggregate_last_second
zc.close()


def test_qu_response():
"""Handle multicast incoming with the QU bit set."""
# instantiate a zeroconf instance
Expand Down Expand Up @@ -842,6 +910,45 @@ def test_known_answer_supression_service_type_enumeration_query():
zc.close()


def test_upper_case_enumeration_query():
zc = Zeroconf(interfaces=['127.0.0.1'])
type_ = "_otherknown._tcp.local."
name = "knownname"
registration_name = f"{name}.{type_}"
desc = {'path': '/~paulsm/'}
server_name = "ash-2.local."
info = ServiceInfo(
type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")]
)
zc.registry.async_add(info)

type_2 = "_otherknown2._tcp.local."
name = "knownname"
registration_name2 = f"{name}.{type_2}"
desc = {'path': '/~paulsm/'}
server_name2 = "ash-3.local."
info2 = ServiceInfo(
type_2, registration_name2, 80, 0, 0, desc, server_name2, addresses=[socket.inet_aton("10.0.1.2")]
)
zc.registry.async_add(info2)
_clear_cache(zc)

# Test PTR supression
generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
question = r.DNSQuestion(const._SERVICE_TYPE_ENUMERATION_NAME.upper(), const._TYPE_PTR, const._CLASS_IN)
generated.add_question(question)
packets = generated.packets()
question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False)
assert not question_answers.ucast
assert not question_answers.mcast_now
assert question_answers.mcast_aggregate
assert not question_answers.mcast_aggregate_last_second
# unregister
zc.registry.async_remove(info)
zc.registry.async_remove(info2)
zc.close()


# This test uses asyncio because it needs to access the cache directly
# which is not threadsafe
@pytest.mark.asyncio
Expand Down

0 comments on commit 360ceb2

Please sign in to comment.