From 1682991b985b1f7b2bf0cff1a7eb7793070e7cb1 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 10 Dec 2023 10:38:27 -1000 Subject: [PATCH] fix: correct handling of IPv6 addresses with scope_id in ServiceInfo (#1322) --- examples/browser.py | 10 +++--- src/zeroconf/_services/info.pxd | 8 +++++ src/zeroconf/_services/info.py | 53 +++++++++++++++++++++-------- tests/services/test_info.py | 59 ++++++++++++++++++++++++++++++--- 4 files changed, 107 insertions(+), 23 deletions(-) diff --git a/examples/browser.py b/examples/browser.py index 60933e2a..a456a9eb 100755 --- a/examples/browser.py +++ b/examples/browser.py @@ -51,18 +51,18 @@ def on_service_state_change( parser.add_argument('--debug', action='store_true') parser.add_argument('--find', action='store_true', help='Browse all available services') version_group = parser.add_mutually_exclusive_group() - version_group.add_argument('--v6', action='store_true') version_group.add_argument('--v6-only', action='store_true') + version_group.add_argument('--v4-only', action='store_true') args = parser.parse_args() if args.debug: logging.getLogger('zeroconf').setLevel(logging.DEBUG) - if args.v6: - ip_version = IPVersion.All - elif args.v6_only: + if args.v6_only: ip_version = IPVersion.V6Only - else: + elif args.v4_only: ip_version = IPVersion.V4Only + else: + ip_version = IPVersion.All zeroconf = Zeroconf(ip_version=ip_version) diff --git a/src/zeroconf/_services/info.pxd b/src/zeroconf/_services/info.pxd index 3506c3a9..b7a2ee30 100644 --- a/src/zeroconf/_services/info.pxd +++ b/src/zeroconf/_services/info.pxd @@ -32,6 +32,14 @@ cdef object _IPVersion_V4Only_value cdef cython.set _ADDRESS_RECORD_TYPES cdef bint TYPE_CHECKING +cdef bint IPADDRESS_SUPPORTS_SCOPE_ID + +cdef _get_ip_address_object_from_record(DNSAddress record) + +@cython.locals(address_str=str) +cdef _str_without_scope_id(object addr) + +cdef _ip_bytes_and_scope_to_address(object addr, object scope_id) cdef class ServiceInfo(RecordUpdateListener): diff --git a/src/zeroconf/_services/info.py b/src/zeroconf/_services/info.py index f363b55b..e9e25763 100644 --- a/src/zeroconf/_services/info.py +++ b/src/zeroconf/_services/info.py @@ -22,6 +22,7 @@ import asyncio import random +import sys from functools import lru_cache from ipaddress import IPv4Address, IPv6Address, _BaseAddress, ip_address from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union, cast @@ -78,12 +79,15 @@ # the A/AAAA/SRV records for a host. _AVOID_SYNC_DELAY_RANDOM_INTERVAL = (20, 120) +bytes_ = bytes float_ = float int_ = int DNS_QUESTION_TYPE_QU = DNSQuestionType.QU DNS_QUESTION_TYPE_QM = DNSQuestionType.QM +IPADDRESS_SUPPORTS_SCOPE_ID = sys.version_info >= (3, 9, 0) + if TYPE_CHECKING: from .._core import Zeroconf @@ -110,6 +114,29 @@ def _cached_ip_addresses(address: Union[str, bytes, int]) -> Optional[Union[IPv4 _cached_ip_addresses_wrapper = _cached_ip_addresses +def _get_ip_address_object_from_record(record: DNSAddress) -> Optional[Union[IPv4Address, IPv6Address]]: + """Get the IP address object from the record.""" + if IPADDRESS_SUPPORTS_SCOPE_ID and record.type == _TYPE_AAAA and record.scope_id is not None: + return _ip_bytes_and_scope_to_address(record.address, record.scope_id) + return _cached_ip_addresses_wrapper(record.address) + + +def _ip_bytes_and_scope_to_address(address: bytes_, scope: int_) -> Optional[Union[IPv4Address, IPv6Address]]: + """Convert the bytes and scope to an IP address object.""" + base_address = _cached_ip_addresses_wrapper(address) + if base_address is not None and base_address.is_link_local: + return _cached_ip_addresses_wrapper(f"{base_address}%{scope}") + return base_address + + +def _str_without_scope_id(addr: Union[IPv4Address, IPv6Address]) -> str: + """Return the string representation of the address without the scope id.""" + if IPADDRESS_SUPPORTS_SCOPE_ID and addr.version == 6: + address_str = str(addr) + return address_str.partition('%')[0] + return str(addr) + + class ServiceInfo(RecordUpdateListener): """Service information. @@ -177,6 +204,7 @@ def __init__( raise TypeError("addresses and parsed_addresses cannot be provided together") if not type_.endswith(service_type_name(name, strict=False)): raise BadTypeInNameException + self.interface_index = interface_index self.text = b'' self.type = type_ self._name = name @@ -199,7 +227,6 @@ def __init__( self._set_properties(properties) self.host_ttl = host_ttl self.other_ttl = other_ttl - self.interface_index = interface_index self._new_records_futures: Optional[Set[asyncio.Future]] = None self._dns_address_cache: Optional[List[DNSAddress]] = None self._dns_pointer_cache: Optional[DNSPointer] = None @@ -243,7 +270,10 @@ def addresses(self, value: List[bytes]) -> None: self._get_address_and_nsec_records_cache = None for address in value: - addr = _cached_ip_addresses_wrapper(address) + if IPADDRESS_SUPPORTS_SCOPE_ID and len(address) == 16 and self.interface_index is not None: + addr = _ip_bytes_and_scope_to_address(address, self.interface_index) + else: + addr = _cached_ip_addresses_wrapper(address) if addr is None: raise TypeError( "Addresses must either be IPv4 or IPv6 strings, bytes, or integers;" @@ -322,10 +352,10 @@ def ip_addresses_by_version( def _ip_addresses_by_version_value( self, version_value: int_ - ) -> Union[List[IPv4Address], List[IPv6Address], List[_BaseAddress]]: + ) -> Union[List[IPv4Address], List[IPv6Address]]: """Backend for addresses_by_version that uses the raw value.""" if version_value == _IPVersion_All_value: - return [*self._ipv4_addresses, *self._ipv6_addresses] + return [*self._ipv4_addresses, *self._ipv6_addresses] # type: ignore[return-value] if version_value == _IPVersion_V4Only_value: return self._ipv4_addresses return self._ipv6_addresses @@ -339,7 +369,7 @@ def parsed_addresses(self, version: IPVersion = IPVersion.All) -> List[str]: This means the first address will always be the most recently added address of the given IP version. """ - return [str(addr) for addr in self._ip_addresses_by_version_value(version.value)] + return [_str_without_scope_id(addr) for addr in self._ip_addresses_by_version_value(version.value)] def parsed_scoped_addresses(self, version: IPVersion = IPVersion.All) -> List[str]: """Equivalent to parsed_addresses, with the exception that IPv6 Link-Local @@ -351,12 +381,7 @@ def parsed_scoped_addresses(self, version: IPVersion = IPVersion.All) -> List[st This means the first address will always be the most recently added address of the given IP version. """ - if self.interface_index is None: - return self.parsed_addresses(version) - return [ - f"{addr}%{self.interface_index}" if addr.version == 6 and addr.is_link_local else str(addr) - for addr in self._ip_addresses_by_version_value(version.value) - ] + return [str(addr) for addr in self._ip_addresses_by_version_value(version.value)] def _set_properties(self, properties: Dict[Union[str, bytes], Optional[Union[str, bytes]]]) -> None: """Sets properties and text of this info from a dictionary""" @@ -421,8 +446,8 @@ def _get_ip_addresses_from_cache_lifo( for record in self._get_address_records_from_cache_by_type(zc, type): if record.is_expired(now): continue - ip_addr = _cached_ip_addresses_wrapper(record.address) - if ip_addr is not None: + ip_addr = _get_ip_address_object_from_record(record) + if ip_addr is not None and ip_addr not in address_list: address_list.append(ip_addr) address_list.reverse() # Reverse to get LIFO order return address_list @@ -471,7 +496,7 @@ def _process_record_threadsafe(self, zc: 'Zeroconf', record: DNSRecord, now: flo dns_address_record = record if TYPE_CHECKING: assert isinstance(dns_address_record, DNSAddress) - ip_addr = _cached_ip_addresses_wrapper(dns_address_record.address) + ip_addr = _get_ip_address_object_from_record(dns_address_record) if ip_addr is None: log.warning( "Encountered invalid address while processing %s: %s", diff --git a/tests/services/test_info.py b/tests/services/test_info.py index 7d437d23..482b3b0c 100644 --- a/tests/services/test_info.py +++ b/tests/services/test_info.py @@ -7,6 +7,7 @@ import logging import os import socket +import sys import threading import unittest from ipaddress import ip_address @@ -538,6 +539,7 @@ def test_multiple_addresses(): assert info.addresses == [address, address] assert info.parsed_addresses() == [address_parsed, address_parsed] assert info.parsed_scoped_addresses() == [address_parsed, address_parsed] + ipaddress_supports_scope_id = sys.version_info >= (3, 9, 0) if has_working_ipv6() and not os.environ.get('SKIP_IPV6'): address_v6_parsed = "2001:db8::1" @@ -576,14 +578,18 @@ def test_multiple_addresses(): assert info.ip_addresses_by_version(r.IPVersion.All) == [ ip_address(address), ip_address(address_v6), - ip_address(address_v6_ll), + ip_address(address_v6_ll_scoped_parsed) + if ipaddress_supports_scope_id + else ip_address(address_v6_ll), ] assert info.addresses_by_version(r.IPVersion.V4Only) == [address] assert info.ip_addresses_by_version(r.IPVersion.V4Only) == [ip_address(address)] assert info.addresses_by_version(r.IPVersion.V6Only) == [address_v6, address_v6_ll] assert info.ip_addresses_by_version(r.IPVersion.V6Only) == [ ip_address(address_v6), - ip_address(address_v6_ll), + ip_address(address_v6_ll_scoped_parsed) + if ipaddress_supports_scope_id + else ip_address(address_v6_ll), ] assert info.parsed_addresses() == [address_parsed, address_v6_parsed, address_v6_ll_parsed] assert info.parsed_addresses(r.IPVersion.V4Only) == [address_parsed] @@ -591,15 +597,60 @@ def test_multiple_addresses(): assert info.parsed_scoped_addresses() == [ address_parsed, address_v6_parsed, - address_v6_ll_scoped_parsed, + address_v6_ll_scoped_parsed if ipaddress_supports_scope_id else address_v6_ll_parsed, ] assert info.parsed_scoped_addresses(r.IPVersion.V4Only) == [address_parsed] assert info.parsed_scoped_addresses(r.IPVersion.V6Only) == [ address_v6_parsed, - address_v6_ll_scoped_parsed, + address_v6_ll_scoped_parsed if ipaddress_supports_scope_id else address_v6_ll_parsed, ] +@unittest.skipIf(sys.version_info < (3, 9, 0), 'Requires newer python') +def test_scoped_addresses_from_cache(): + type_ = "_http._tcp.local." + registration_name = f"scoped.{type_}" + zeroconf = r.Zeroconf(interfaces=['127.0.0.1']) + host = "scoped.local." + + zeroconf.cache.async_add_records( + [ + r.DNSPointer( + type_, + const._TYPE_PTR, + const._CLASS_IN | const._CLASS_UNIQUE, + 120, + registration_name, + ), + r.DNSService( + registration_name, + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + 120, + 0, + 0, + 80, + host, + ), + r.DNSAddress( + host, + const._TYPE_AAAA, + const._CLASS_IN | const._CLASS_UNIQUE, + 120, + socket.inet_pton(socket.AF_INET6, "fe80::52e:c2f2:bc5f:e9c6"), + scope_id=12, + ), + ] + ) + + # New kwarg way + info = ServiceInfo(type_, registration_name) + info.load_from_cache(zeroconf) + assert info.parsed_scoped_addresses() == ["fe80::52e:c2f2:bc5f:e9c6%12"] + assert info.ip_addresses_by_version(r.IPVersion.V6Only) == [ip_address("fe80::52e:c2f2:bc5f:e9c6%12")] + zeroconf.close() + + # This test uses asyncio because it needs to access the cache directly # which is not threadsafe @pytest.mark.asyncio