From 360ceb2548c4c4974ff798aac43a6fff9803ea0e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 25 May 2023 07:45:13 -0500 Subject: [PATCH] feat: speed up the service registry (#1174) --- src/zeroconf/_core.py | 2 +- src/zeroconf/_handlers.py | 17 ++--- src/zeroconf/_services/registry.py | 4 +- tests/services/test_registry.py | 19 ----- tests/test_handlers.py | 107 +++++++++++++++++++++++++++++ 5 files changed, 119 insertions(+), 30 deletions(-) diff --git a/src/zeroconf/_core.py b/src/zeroconf/_core.py index 18823ef2..a55f55e8 100644 --- a/src/zeroconf/_core.py +++ b/src/zeroconf/_core.py @@ -745,7 +745,7 @@ async def async_unregister_service(self, info: ServiceInfo) -> Awaitable: # goodbye packets for the address records assert info.server is not None - entries = self.registry.async_get_infos_server(info.server) + entries = self.registry.async_get_infos_server(info.server.lower()) broadcast_addresses = not bool(entries) return asyncio.ensure_future( self._async_broadcast_service(info, _UNREGISTER_TIME, 0, broadcast_addresses) diff --git a/src/zeroconf/_handlers.py b/src/zeroconf/_handlers.py index 240deb47..159fd0d5 100644 --- a/src/zeroconf/_handlers.py +++ b/src/zeroconf/_handlers.py @@ -255,10 +255,10 @@ def _add_service_type_enumeration_query_answers( answer_set[dns_pointer] = set() def _add_pointer_answers( - self, name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float + self, lower_name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float ) -> None: """Answer PTR/ANY question.""" - for service in self.registry.async_get_infos_type(name): + for service in self.registry.async_get_infos_type(lower_name): # Add recommended additional answers according to # https://tools.ietf.org/html/rfc6763#section-12.1. dns_pointer = service.dns_pointer(created=now) @@ -270,14 +270,14 @@ def _add_pointer_answers( def _add_address_answers( self, - name: str, + lower_name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float, type_: int, ) -> None: """Answer A/AAAA/ANY question.""" - for service in self.registry.async_get_infos_server(name): + for service in self.registry.async_get_infos_server(lower_name): answers: List[DNSAddress] = [] additionals: Set[DNSRecord] = set() seen_types: Set[int] = set() @@ -305,21 +305,22 @@ def _answer_question( now: float, ) -> _AnswerWithAdditionalsType: answer_set: _AnswerWithAdditionalsType = {} + question_lower_name = question.name.lower() - if question.type == _TYPE_PTR and question.name.lower() == _SERVICE_TYPE_ENUMERATION_NAME: + if question.type == _TYPE_PTR and question_lower_name == _SERVICE_TYPE_ENUMERATION_NAME: self._add_service_type_enumeration_query_answers(answer_set, known_answers, now) return answer_set type_ = question.type if type_ in (_TYPE_PTR, _TYPE_ANY): - self._add_pointer_answers(question.name, answer_set, known_answers, now) + self._add_pointer_answers(question_lower_name, answer_set, known_answers, now) if type_ in (_TYPE_A, _TYPE_AAAA, _TYPE_ANY): - self._add_address_answers(question.name, answer_set, known_answers, now, type_) + self._add_address_answers(question_lower_name, answer_set, known_answers, now, type_) if type_ in (_TYPE_SRV, _TYPE_TXT, _TYPE_ANY): - service = self.registry.async_get_info_name(question.name) + service = self.registry.async_get_info_name(question_lower_name) if service is not None: if type_ in (_TYPE_SRV, _TYPE_ANY): # Add recommended additional answers according to diff --git a/src/zeroconf/_services/registry.py b/src/zeroconf/_services/registry.py index b3dba674..1c4ad085 100644 --- a/src/zeroconf/_services/registry.py +++ b/src/zeroconf/_services/registry.py @@ -60,7 +60,7 @@ def async_get_service_infos(self) -> List[ServiceInfo]: def async_get_info_name(self, name: str) -> Optional[ServiceInfo]: """Return all ServiceInfo for the name.""" - return self._services.get(name.lower()) + return self._services.get(name) def async_get_types(self) -> List[str]: """Return all types.""" @@ -76,7 +76,7 @@ def async_get_infos_server(self, server: str) -> List[ServiceInfo]: def _async_get_by_index(self, records: Dict[str, List], key: str) -> List[ServiceInfo]: """Return all ServiceInfo matching the index.""" - return [self._services[name] for name in records.get(key.lower(), [])] + return [self._services[name] for name in records.get(key, [])] def _add(self, info: ServiceInfo) -> None: """Add a new service under the lock.""" diff --git a/tests/services/test_registry.py b/tests/services/test_registry.py index 3207b14e..f8656e2f 100644 --- a/tests/services/test_registry.py +++ b/tests/services/test_registry.py @@ -110,22 +110,3 @@ def test_lookups_upper_case_by_lower_case(self): assert registry.async_get_infos_type(type_.lower()) == [info] assert registry.async_get_infos_server("ash-2.local.") == [info] assert registry.async_get_types() == [type_.lower()] - - def test_lookups_lower_case_by_upper_case(self): - type_ = "_test-srvc-type._tcp.local." - name = "xxxyyy" - registration_name = f"{name}.{type_}" - - desc = {'path': '/~paulsm/'} - info = ServiceInfo( - type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] - ) - - registry = r.ServiceRegistry() - registry.async_add(info) - - assert registry.async_get_service_infos() == [info] - assert registry.async_get_info_name(registration_name.upper()) == info - assert registry.async_get_infos_type(type_.upper()) == [info] - assert registry.async_get_infos_server("ASH-2.local.") == [info] - assert registry.async_get_types() == [type_] diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 0a976d3d..c1c0a9a7 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -340,6 +340,32 @@ def test_aaaa_query(): zc.close() +@unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') +@unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') +def test_aaaa_query_upper_case(): + """Test that queries for AAAA records work and should respond right away with an upper case name.""" + zc = Zeroconf(interfaces=['127.0.0.1']) + type_ = "_knownaaaservice._tcp.local." + name = "knownname" + registration_name = f"{name}.{type_}" + desc = {'path': '/~paulsm/'} + server_name = "ash-2.local." + ipv6_address = socket.inet_pton(socket.AF_INET6, "2001:db8::1") + info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, server_name, addresses=[ipv6_address]) + zc.registry.async_add(info) + + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(server_name.upper(), const._TYPE_AAAA, const._CLASS_IN) + generated.add_question(question) + packets = generated.packets() + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + mcast_answers = list(question_answers.mcast_now) + assert mcast_answers[0].address == ipv6_address # type: ignore[attr-defined] + # unregister + zc.registry.async_remove(info) + zc.close() + + @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') def test_a_and_aaaa_record_fate_sharing(): @@ -481,6 +507,48 @@ async def test_probe_answered_immediately(): zc.close() +@pytest.mark.asyncio +async def test_probe_answered_immediately_with_uppercase_name(): + """Verify probes are responded to immediately with an uppercase name.""" + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + + # service definition + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = f"{name}.{type_}" + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + zc.registry.async_add(info) + query = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(info.type.upper(), const._TYPE_PTR, const._CLASS_IN) + query.add_question(question) + query.add_authorative_answer(info.dns_pointer()) + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], False + ) + assert not question_answers.ucast + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + assert question_answers.mcast_now + + query = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) + question.unicast = True + query.add_question(question) + query.add_authorative_answer(info.dns_pointer()) + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], False + ) + assert question_answers.ucast + assert question_answers.mcast_now + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + zc.close() + + def test_qu_response(): """Handle multicast incoming with the QU bit set.""" # instantiate a zeroconf instance @@ -842,6 +910,45 @@ def test_known_answer_supression_service_type_enumeration_query(): zc.close() +def test_upper_case_enumeration_query(): + zc = Zeroconf(interfaces=['127.0.0.1']) + type_ = "_otherknown._tcp.local." + name = "knownname" + registration_name = f"{name}.{type_}" + desc = {'path': '/~paulsm/'} + server_name = "ash-2.local." + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")] + ) + zc.registry.async_add(info) + + type_2 = "_otherknown2._tcp.local." + name = "knownname" + registration_name2 = f"{name}.{type_2}" + desc = {'path': '/~paulsm/'} + server_name2 = "ash-3.local." + info2 = ServiceInfo( + type_2, registration_name2, 80, 0, 0, desc, server_name2, addresses=[socket.inet_aton("10.0.1.2")] + ) + zc.registry.async_add(info2) + _clear_cache(zc) + + # Test PTR supression + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(const._SERVICE_TYPE_ENUMERATION_NAME.upper(), const._TYPE_PTR, const._CLASS_IN) + generated.add_question(question) + packets = generated.packets() + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert not question_answers.ucast + assert not question_answers.mcast_now + assert question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + # unregister + zc.registry.async_remove(info) + zc.registry.async_remove(info2) + zc.close() + + # This test uses asyncio because it needs to access the cache directly # which is not threadsafe @pytest.mark.asyncio