diff --git a/src/zeroconf/_core.py b/src/zeroconf/_core.py index ab8e72e5..6a9c2c8a 100644 --- a/src/zeroconf/_core.py +++ b/src/zeroconf/_core.py @@ -620,6 +620,7 @@ def register_service( ttl: Optional[int] = None, allow_name_change: bool = False, cooperating_responders: bool = False, + strict: bool = True, ) -> None: """Registers service information to the network with a default TTL. Zeroconf will then respond to requests for information for that @@ -635,7 +636,7 @@ def register_service( assert self.loop is not None run_coro_with_timeout( await_awaitable( - self.async_register_service(info, ttl, allow_name_change, cooperating_responders) + self.async_register_service(info, ttl, allow_name_change, cooperating_responders, strict) ), self.loop, _REGISTER_TIME * _REGISTER_BROADCASTS, @@ -647,6 +648,7 @@ async def async_register_service( ttl: Optional[int] = None, allow_name_change: bool = False, cooperating_responders: bool = False, + strict: bool = True, ) -> Awaitable: """Registers service information to the network with a default TTL. Zeroconf will then respond to requests for information for that @@ -662,7 +664,7 @@ async def async_register_service( info.set_server_if_missing() await self.async_wait_for_start() - await self.async_check_service(info, allow_name_change, cooperating_responders) + await self.async_check_service(info, allow_name_change, cooperating_responders, strict) self.registry.async_add(info) return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None)) @@ -810,11 +812,15 @@ def unregister_all_services(self) -> None: ) async def async_check_service( - self, info: ServiceInfo, allow_name_change: bool, cooperating_responders: bool = False + self, + info: ServiceInfo, + allow_name_change: bool, + cooperating_responders: bool = False, + strict: bool = True, ) -> None: """Checks the network for a unique service name, modifying the ServiceInfo passed in if it is not unique.""" - instance_name = instance_name_from_service_info(info) + instance_name = instance_name_from_service_info(info, strict=strict) if cooperating_responders: return next_instance_number = 2 @@ -829,7 +835,7 @@ async def async_check_service( # change the name and look for a conflict info.name = f'{instance_name}-{next_instance_number}.{info.type}' next_instance_number += 1 - service_type_name(info.name) + service_type_name(info.name, strict=strict) next_time = now i = 0 diff --git a/src/zeroconf/_services/info.py b/src/zeroconf/_services/info.py index 02b7137a..29ddb9a0 100644 --- a/src/zeroconf/_services/info.py +++ b/src/zeroconf/_services/info.py @@ -76,11 +76,11 @@ from .._core import Zeroconf -def instance_name_from_service_info(info: "ServiceInfo") -> str: +def instance_name_from_service_info(info: "ServiceInfo", strict: bool = True) -> str: """Calculate the instance name from the ServiceInfo.""" # This is kind of funky because of the subtype based tests # need to make subtypes a first class citizen - service_name = service_type_name(info.name) + service_name = service_type_name(info.name, strict=strict) if not info.type.endswith(service_name): raise BadTypeInNameException return info.name[: -len(service_name) - 1] diff --git a/src/zeroconf/asyncio.py b/src/zeroconf/asyncio.py index 7ded0ecb..755757d7 100644 --- a/src/zeroconf/asyncio.py +++ b/src/zeroconf/asyncio.py @@ -180,6 +180,7 @@ async def async_register_service( ttl: Optional[int] = None, allow_name_change: bool = False, cooperating_responders: bool = False, + strict: bool = True, ) -> Awaitable: """Registers service information to the network with a default TTL. Zeroconf will then respond to requests for information for that @@ -192,7 +193,7 @@ async def async_register_service( and therefore can be awaited if necessary. """ return await self.zeroconf.async_register_service( - info, ttl, allow_name_change, cooperating_responders + info, ttl, allow_name_change, cooperating_responders, strict ) async def async_unregister_all_services(self) -> None: diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 66c81e00..cd067ae1 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -456,6 +456,41 @@ async def test_async_service_registration_name_does_not_match_type() -> None: await aiozc.async_close() +@pytest.mark.asyncio +async def test_async_service_registration_name_strict_check() -> None: + """Test registering services throws when the name does not comply.""" + zc = Zeroconf(interfaces=['127.0.0.1']) + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + type_ = "_ibisip_http._tcp.local." + name = "CustomerInformationService-F4D4895E9EEB" + registration_name = f"{name}.{type_}" + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + with pytest.raises(BadTypeInNameException): + await zc.async_check_service(info, allow_name_change=False) + + with pytest.raises(BadTypeInNameException): + task = await aiozc.async_register_service(info) + await task + + await zc.async_check_service(info, allow_name_change=False, strict=False) + task = await aiozc.async_register_service(info, strict=False) + await task + + await aiozc.async_unregister_service(info) + await aiozc.async_close() + + @pytest.mark.asyncio async def test_async_tasks() -> None: """Test awaiting broadcast tasks""" diff --git a/tests/utils/test_name.py b/tests/utils/test_name.py index 3df73f5a..9604b775 100644 --- a/tests/utils/test_name.py +++ b/tests/utils/test_name.py @@ -2,10 +2,12 @@ """Unit tests for zeroconf._utils.name.""" +import socket import pytest from zeroconf import BadTypeInNameException +from zeroconf._services.info import ServiceInfo, instance_name_from_service_info from zeroconf._utils import name as nameutils @@ -25,6 +27,33 @@ def test_service_type_name_overlong_full_name(): nameutils.service_type_name(f"{long_name}._tivo-videostream._tcp.local.", strict=False) +@pytest.mark.parametrize( + "instance_name, service_type", + ( + ("CustomerInformationService-F4D4885E9EEB", "_ibisip_http._tcp.local."), + ("DeviceManagementService_F4D4885E9EEB", "_ibisip_http._tcp.local."), + ), +) +def test_service_type_name_non_strict_compliant_names(instance_name, service_type): + """Test service_type_name for valid names, but not strict-compliant.""" + desc = {'path': '/~paulsm/'} + service_name = f'{instance_name}.{service_type}' + service_server = 'ash-1.local.' + service_address = socket.inet_aton("10.0.1.2") + info = ServiceInfo( + service_type, service_name, 22, 0, 0, desc, service_server, addresses=[service_address] + ) + assert info.get_name() == instance_name + + with pytest.raises(BadTypeInNameException): + nameutils.service_type_name(service_name) + with pytest.raises(BadTypeInNameException): + instance_name_from_service_info(info) + + nameutils.service_type_name(service_name, strict=False) + assert instance_name_from_service_info(info, strict=False) == instance_name + + def test_possible_types(): """Test possible types from name.""" assert nameutils.possible_types('.') == set()