Skip to content

Commit

Permalink
feat: expose flag to disable strict name checking in service registra…
Browse files Browse the repository at this point in the history
…tion (#1215)
  • Loading branch information
azogue committed Aug 13, 2023
1 parent aff625d commit 5df8a57
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 8 deletions.
16 changes: 11 additions & 5 deletions src/zeroconf/_core.py
Expand Up @@ -620,6 +620,7 @@ def register_service(
ttl: Optional[int] = None,
allow_name_change: bool = False,
cooperating_responders: bool = False,
strict: bool = True,
) -> None:
"""Registers service information to the network with a default TTL.
Zeroconf will then respond to requests for information for that
Expand All @@ -635,7 +636,7 @@ def register_service(
assert self.loop is not None
run_coro_with_timeout(
await_awaitable(
self.async_register_service(info, ttl, allow_name_change, cooperating_responders)
self.async_register_service(info, ttl, allow_name_change, cooperating_responders, strict)
),
self.loop,
_REGISTER_TIME * _REGISTER_BROADCASTS,
Expand All @@ -647,6 +648,7 @@ async def async_register_service(
ttl: Optional[int] = None,
allow_name_change: bool = False,
cooperating_responders: bool = False,
strict: bool = True,
) -> Awaitable:
"""Registers service information to the network with a default TTL.
Zeroconf will then respond to requests for information for that
Expand All @@ -662,7 +664,7 @@ async def async_register_service(

info.set_server_if_missing()
await self.async_wait_for_start()
await self.async_check_service(info, allow_name_change, cooperating_responders)
await self.async_check_service(info, allow_name_change, cooperating_responders, strict)
self.registry.async_add(info)
return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None))

Expand Down Expand Up @@ -810,11 +812,15 @@ def unregister_all_services(self) -> None:
)

async def async_check_service(
self, info: ServiceInfo, allow_name_change: bool, cooperating_responders: bool = False
self,
info: ServiceInfo,
allow_name_change: bool,
cooperating_responders: bool = False,
strict: bool = True,
) -> None:
"""Checks the network for a unique service name, modifying the
ServiceInfo passed in if it is not unique."""
instance_name = instance_name_from_service_info(info)
instance_name = instance_name_from_service_info(info, strict=strict)
if cooperating_responders:
return
next_instance_number = 2
Expand All @@ -829,7 +835,7 @@ async def async_check_service(
# change the name and look for a conflict
info.name = f'{instance_name}-{next_instance_number}.{info.type}'
next_instance_number += 1
service_type_name(info.name)
service_type_name(info.name, strict=strict)
next_time = now
i = 0

Expand Down
4 changes: 2 additions & 2 deletions src/zeroconf/_services/info.py
Expand Up @@ -76,11 +76,11 @@
from .._core import Zeroconf


def instance_name_from_service_info(info: "ServiceInfo") -> str:
def instance_name_from_service_info(info: "ServiceInfo", strict: bool = True) -> str:
"""Calculate the instance name from the ServiceInfo."""
# This is kind of funky because of the subtype based tests
# need to make subtypes a first class citizen
service_name = service_type_name(info.name)
service_name = service_type_name(info.name, strict=strict)
if not info.type.endswith(service_name):
raise BadTypeInNameException
return info.name[: -len(service_name) - 1]
Expand Down
3 changes: 2 additions & 1 deletion src/zeroconf/asyncio.py
Expand Up @@ -180,6 +180,7 @@ async def async_register_service(
ttl: Optional[int] = None,
allow_name_change: bool = False,
cooperating_responders: bool = False,
strict: bool = True,
) -> Awaitable:
"""Registers service information to the network with a default TTL.
Zeroconf will then respond to requests for information for that
Expand All @@ -192,7 +193,7 @@ async def async_register_service(
and therefore can be awaited if necessary.
"""
return await self.zeroconf.async_register_service(
info, ttl, allow_name_change, cooperating_responders
info, ttl, allow_name_change, cooperating_responders, strict
)

async def async_unregister_all_services(self) -> None:
Expand Down
35 changes: 35 additions & 0 deletions tests/test_asyncio.py
Expand Up @@ -456,6 +456,41 @@ async def test_async_service_registration_name_does_not_match_type() -> None:
await aiozc.async_close()


@pytest.mark.asyncio
async def test_async_service_registration_name_strict_check() -> None:
"""Test registering services throws when the name does not comply."""
zc = Zeroconf(interfaces=['127.0.0.1'])
aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
type_ = "_ibisip_http._tcp.local."
name = "CustomerInformationService-F4D4895E9EEB"
registration_name = f"{name}.{type_}"

desc = {'path': '/~paulsm/'}
info = ServiceInfo(
type_,
registration_name,
80,
0,
0,
desc,
"ash-2.local.",
addresses=[socket.inet_aton("10.0.1.2")],
)
with pytest.raises(BadTypeInNameException):
await zc.async_check_service(info, allow_name_change=False)

with pytest.raises(BadTypeInNameException):
task = await aiozc.async_register_service(info)
await task

await zc.async_check_service(info, allow_name_change=False, strict=False)
task = await aiozc.async_register_service(info, strict=False)
await task

await aiozc.async_unregister_service(info)
await aiozc.async_close()


@pytest.mark.asyncio
async def test_async_tasks() -> None:
"""Test awaiting broadcast tasks"""
Expand Down
29 changes: 29 additions & 0 deletions tests/utils/test_name.py
Expand Up @@ -2,10 +2,12 @@


"""Unit tests for zeroconf._utils.name."""
import socket

import pytest

from zeroconf import BadTypeInNameException
from zeroconf._services.info import ServiceInfo, instance_name_from_service_info
from zeroconf._utils import name as nameutils


Expand All @@ -25,6 +27,33 @@ def test_service_type_name_overlong_full_name():
nameutils.service_type_name(f"{long_name}._tivo-videostream._tcp.local.", strict=False)


@pytest.mark.parametrize(
"instance_name, service_type",
(
("CustomerInformationService-F4D4885E9EEB", "_ibisip_http._tcp.local."),
("DeviceManagementService_F4D4885E9EEB", "_ibisip_http._tcp.local."),
),
)
def test_service_type_name_non_strict_compliant_names(instance_name, service_type):
"""Test service_type_name for valid names, but not strict-compliant."""
desc = {'path': '/~paulsm/'}
service_name = f'{instance_name}.{service_type}'
service_server = 'ash-1.local.'
service_address = socket.inet_aton("10.0.1.2")
info = ServiceInfo(
service_type, service_name, 22, 0, 0, desc, service_server, addresses=[service_address]
)
assert info.get_name() == instance_name

with pytest.raises(BadTypeInNameException):
nameutils.service_type_name(service_name)
with pytest.raises(BadTypeInNameException):
instance_name_from_service_info(info)

nameutils.service_type_name(service_name, strict=False)
assert instance_name_from_service_info(info, strict=False) == instance_name


def test_possible_types():
"""Test possible types from name."""
assert nameutils.possible_types('.') == set()
Expand Down

0 comments on commit 5df8a57

Please sign in to comment.