Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: speed up the service registry #1174

Merged
merged 2 commits into from
May 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/zeroconf/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,7 @@ async def async_unregister_service(self, info: ServiceInfo) -> Awaitable:
# goodbye packets for the address records

assert info.server is not None
entries = self.registry.async_get_infos_server(info.server)
entries = self.registry.async_get_infos_server(info.server.lower())
broadcast_addresses = not bool(entries)
return asyncio.ensure_future(
self._async_broadcast_service(info, _UNREGISTER_TIME, 0, broadcast_addresses)
Expand Down
17 changes: 9 additions & 8 deletions src/zeroconf/_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,10 +255,10 @@ def _add_service_type_enumeration_query_answers(
answer_set[dns_pointer] = set()

def _add_pointer_answers(
self, name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float
self, lower_name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float
) -> None:
"""Answer PTR/ANY question."""
for service in self.registry.async_get_infos_type(name):
for service in self.registry.async_get_infos_type(lower_name):
# Add recommended additional answers according to
# https://tools.ietf.org/html/rfc6763#section-12.1.
dns_pointer = service.dns_pointer(created=now)
Expand All @@ -270,14 +270,14 @@ def _add_pointer_answers(

def _add_address_answers(
self,
name: str,
lower_name: str,
answer_set: _AnswerWithAdditionalsType,
known_answers: DNSRRSet,
now: float,
type_: int,
) -> None:
"""Answer A/AAAA/ANY question."""
for service in self.registry.async_get_infos_server(name):
for service in self.registry.async_get_infos_server(lower_name):
answers: List[DNSAddress] = []
additionals: Set[DNSRecord] = set()
seen_types: Set[int] = set()
Expand Down Expand Up @@ -305,21 +305,22 @@ def _answer_question(
now: float,
) -> _AnswerWithAdditionalsType:
answer_set: _AnswerWithAdditionalsType = {}
question_lower_name = question.name.lower()

if question.type == _TYPE_PTR and question.name.lower() == _SERVICE_TYPE_ENUMERATION_NAME:
if question.type == _TYPE_PTR and question_lower_name == _SERVICE_TYPE_ENUMERATION_NAME:
self._add_service_type_enumeration_query_answers(answer_set, known_answers, now)
return answer_set

type_ = question.type

if type_ in (_TYPE_PTR, _TYPE_ANY):
self._add_pointer_answers(question.name, answer_set, known_answers, now)
self._add_pointer_answers(question_lower_name, answer_set, known_answers, now)

if type_ in (_TYPE_A, _TYPE_AAAA, _TYPE_ANY):
self._add_address_answers(question.name, answer_set, known_answers, now, type_)
self._add_address_answers(question_lower_name, answer_set, known_answers, now, type_)

if type_ in (_TYPE_SRV, _TYPE_TXT, _TYPE_ANY):
service = self.registry.async_get_info_name(question.name)
service = self.registry.async_get_info_name(question_lower_name)
if service is not None:
if type_ in (_TYPE_SRV, _TYPE_ANY):
# Add recommended additional answers according to
Expand Down
4 changes: 2 additions & 2 deletions src/zeroconf/_services/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def async_get_service_infos(self) -> List[ServiceInfo]:

def async_get_info_name(self, name: str) -> Optional[ServiceInfo]:
"""Return all ServiceInfo for the name."""
return self._services.get(name.lower())
return self._services.get(name)

def async_get_types(self) -> List[str]:
"""Return all types."""
Expand All @@ -76,7 +76,7 @@ def async_get_infos_server(self, server: str) -> List[ServiceInfo]:

def _async_get_by_index(self, records: Dict[str, List], key: str) -> List[ServiceInfo]:
"""Return all ServiceInfo matching the index."""
return [self._services[name] for name in records.get(key.lower(), [])]
return [self._services[name] for name in records.get(key, [])]

def _add(self, info: ServiceInfo) -> None:
"""Add a new service under the lock."""
Expand Down
19 changes: 0 additions & 19 deletions tests/services/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,22 +110,3 @@ def test_lookups_upper_case_by_lower_case(self):
assert registry.async_get_infos_type(type_.lower()) == [info]
assert registry.async_get_infos_server("ash-2.local.") == [info]
assert registry.async_get_types() == [type_.lower()]

def test_lookups_lower_case_by_upper_case(self):
type_ = "_test-srvc-type._tcp.local."
name = "xxxyyy"
registration_name = f"{name}.{type_}"

desc = {'path': '/~paulsm/'}
info = ServiceInfo(
type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")]
)

registry = r.ServiceRegistry()
registry.async_add(info)

assert registry.async_get_service_infos() == [info]
assert registry.async_get_info_name(registration_name.upper()) == info
assert registry.async_get_infos_type(type_.upper()) == [info]
assert registry.async_get_infos_server("ASH-2.local.") == [info]
assert registry.async_get_types() == [type_]
107 changes: 107 additions & 0 deletions tests/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,32 @@ def test_aaaa_query():
zc.close()


@unittest.skipIf(not has_working_ipv6(), 'Requires IPv6')
@unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled')
def test_aaaa_query_upper_case():
"""Test that queries for AAAA records work and should respond right away with an upper case name."""
zc = Zeroconf(interfaces=['127.0.0.1'])
type_ = "_knownaaaservice._tcp.local."
name = "knownname"
registration_name = f"{name}.{type_}"
desc = {'path': '/~paulsm/'}
server_name = "ash-2.local."
ipv6_address = socket.inet_pton(socket.AF_INET6, "2001:db8::1")
info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, server_name, addresses=[ipv6_address])
zc.registry.async_add(info)

generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
question = r.DNSQuestion(server_name.upper(), const._TYPE_AAAA, const._CLASS_IN)
generated.add_question(question)
packets = generated.packets()
question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False)
mcast_answers = list(question_answers.mcast_now)
assert mcast_answers[0].address == ipv6_address # type: ignore[attr-defined]
# unregister
zc.registry.async_remove(info)
zc.close()


@unittest.skipIf(not has_working_ipv6(), 'Requires IPv6')
@unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled')
def test_a_and_aaaa_record_fate_sharing():
Expand Down Expand Up @@ -481,6 +507,48 @@ async def test_probe_answered_immediately():
zc.close()


@pytest.mark.asyncio
async def test_probe_answered_immediately_with_uppercase_name():
"""Verify probes are responded to immediately with an uppercase name."""
# instantiate a zeroconf instance
zc = Zeroconf(interfaces=['127.0.0.1'])

# service definition
type_ = "_test-srvc-type._tcp.local."
name = "xxxyyy"
registration_name = f"{name}.{type_}"
desc = {'path': '/~paulsm/'}
info = ServiceInfo(
type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")]
)
zc.registry.async_add(info)
query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
question = r.DNSQuestion(info.type.upper(), const._TYPE_PTR, const._CLASS_IN)
query.add_question(question)
query.add_authorative_answer(info.dns_pointer())
question_answers = zc.query_handler.async_response(
[r.DNSIncoming(packet) for packet in query.packets()], False
)
assert not question_answers.ucast
assert not question_answers.mcast_aggregate
assert not question_answers.mcast_aggregate_last_second
assert question_answers.mcast_now

query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)
question.unicast = True
query.add_question(question)
query.add_authorative_answer(info.dns_pointer())
question_answers = zc.query_handler.async_response(
[r.DNSIncoming(packet) for packet in query.packets()], False
)
assert question_answers.ucast
assert question_answers.mcast_now
assert not question_answers.mcast_aggregate
assert not question_answers.mcast_aggregate_last_second
zc.close()


def test_qu_response():
"""Handle multicast incoming with the QU bit set."""
# instantiate a zeroconf instance
Expand Down Expand Up @@ -842,6 +910,45 @@ def test_known_answer_supression_service_type_enumeration_query():
zc.close()


def test_upper_case_enumeration_query():
zc = Zeroconf(interfaces=['127.0.0.1'])
type_ = "_otherknown._tcp.local."
name = "knownname"
registration_name = f"{name}.{type_}"
desc = {'path': '/~paulsm/'}
server_name = "ash-2.local."
info = ServiceInfo(
type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")]
)
zc.registry.async_add(info)

type_2 = "_otherknown2._tcp.local."
name = "knownname"
registration_name2 = f"{name}.{type_2}"
desc = {'path': '/~paulsm/'}
server_name2 = "ash-3.local."
info2 = ServiceInfo(
type_2, registration_name2, 80, 0, 0, desc, server_name2, addresses=[socket.inet_aton("10.0.1.2")]
)
zc.registry.async_add(info2)
_clear_cache(zc)

# Test PTR supression
generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
question = r.DNSQuestion(const._SERVICE_TYPE_ENUMERATION_NAME.upper(), const._TYPE_PTR, const._CLASS_IN)
generated.add_question(question)
packets = generated.packets()
question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False)
assert not question_answers.ucast
assert not question_answers.mcast_now
assert question_answers.mcast_aggregate
assert not question_answers.mcast_aggregate_last_second
# unregister
zc.registry.async_remove(info)
zc.registry.async_remove(info2)
zc.close()


# This test uses asyncio because it needs to access the cache directly
# which is not threadsafe
@pytest.mark.asyncio
Expand Down