Skip to content

Commit

Permalink
fix: correct handling of IPv6 addresses with scope_id in ServiceInfo (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Dec 10, 2023
1 parent 1c2f194 commit 1682991
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 23 deletions.
10 changes: 5 additions & 5 deletions examples/browser.py
Expand Up @@ -51,18 +51,18 @@ def on_service_state_change(
parser.add_argument('--debug', action='store_true')
parser.add_argument('--find', action='store_true', help='Browse all available services')
version_group = parser.add_mutually_exclusive_group()
version_group.add_argument('--v6', action='store_true')
version_group.add_argument('--v6-only', action='store_true')
version_group.add_argument('--v4-only', action='store_true')
args = parser.parse_args()

if args.debug:
logging.getLogger('zeroconf').setLevel(logging.DEBUG)
if args.v6:
ip_version = IPVersion.All
elif args.v6_only:
if args.v6_only:
ip_version = IPVersion.V6Only
else:
elif args.v4_only:
ip_version = IPVersion.V4Only
else:
ip_version = IPVersion.All

zeroconf = Zeroconf(ip_version=ip_version)

Expand Down
8 changes: 8 additions & 0 deletions src/zeroconf/_services/info.pxd
Expand Up @@ -32,6 +32,14 @@ cdef object _IPVersion_V4Only_value
cdef cython.set _ADDRESS_RECORD_TYPES

cdef bint TYPE_CHECKING
cdef bint IPADDRESS_SUPPORTS_SCOPE_ID

cdef _get_ip_address_object_from_record(DNSAddress record)

@cython.locals(address_str=str)
cdef _str_without_scope_id(object addr)

cdef _ip_bytes_and_scope_to_address(object addr, object scope_id)

cdef class ServiceInfo(RecordUpdateListener):

Expand Down
53 changes: 39 additions & 14 deletions src/zeroconf/_services/info.py
Expand Up @@ -22,6 +22,7 @@

import asyncio
import random
import sys
from functools import lru_cache
from ipaddress import IPv4Address, IPv6Address, _BaseAddress, ip_address
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union, cast
Expand Down Expand Up @@ -78,12 +79,15 @@
# the A/AAAA/SRV records for a host.
_AVOID_SYNC_DELAY_RANDOM_INTERVAL = (20, 120)

bytes_ = bytes
float_ = float
int_ = int

DNS_QUESTION_TYPE_QU = DNSQuestionType.QU
DNS_QUESTION_TYPE_QM = DNSQuestionType.QM

IPADDRESS_SUPPORTS_SCOPE_ID = sys.version_info >= (3, 9, 0)

if TYPE_CHECKING:
from .._core import Zeroconf

Expand All @@ -110,6 +114,29 @@ def _cached_ip_addresses(address: Union[str, bytes, int]) -> Optional[Union[IPv4
_cached_ip_addresses_wrapper = _cached_ip_addresses


def _get_ip_address_object_from_record(record: DNSAddress) -> Optional[Union[IPv4Address, IPv6Address]]:
"""Get the IP address object from the record."""
if IPADDRESS_SUPPORTS_SCOPE_ID and record.type == _TYPE_AAAA and record.scope_id is not None:
return _ip_bytes_and_scope_to_address(record.address, record.scope_id)
return _cached_ip_addresses_wrapper(record.address)


def _ip_bytes_and_scope_to_address(address: bytes_, scope: int_) -> Optional[Union[IPv4Address, IPv6Address]]:
"""Convert the bytes and scope to an IP address object."""
base_address = _cached_ip_addresses_wrapper(address)
if base_address is not None and base_address.is_link_local:
return _cached_ip_addresses_wrapper(f"{base_address}%{scope}")
return base_address


def _str_without_scope_id(addr: Union[IPv4Address, IPv6Address]) -> str:
"""Return the string representation of the address without the scope id."""
if IPADDRESS_SUPPORTS_SCOPE_ID and addr.version == 6:
address_str = str(addr)
return address_str.partition('%')[0]
return str(addr)


class ServiceInfo(RecordUpdateListener):
"""Service information.
Expand Down Expand Up @@ -177,6 +204,7 @@ def __init__(
raise TypeError("addresses and parsed_addresses cannot be provided together")
if not type_.endswith(service_type_name(name, strict=False)):
raise BadTypeInNameException
self.interface_index = interface_index
self.text = b''
self.type = type_
self._name = name
Expand All @@ -199,7 +227,6 @@ def __init__(
self._set_properties(properties)
self.host_ttl = host_ttl
self.other_ttl = other_ttl
self.interface_index = interface_index
self._new_records_futures: Optional[Set[asyncio.Future]] = None
self._dns_address_cache: Optional[List[DNSAddress]] = None
self._dns_pointer_cache: Optional[DNSPointer] = None
Expand Down Expand Up @@ -243,7 +270,10 @@ def addresses(self, value: List[bytes]) -> None:
self._get_address_and_nsec_records_cache = None

for address in value:
addr = _cached_ip_addresses_wrapper(address)
if IPADDRESS_SUPPORTS_SCOPE_ID and len(address) == 16 and self.interface_index is not None:
addr = _ip_bytes_and_scope_to_address(address, self.interface_index)
else:
addr = _cached_ip_addresses_wrapper(address)
if addr is None:
raise TypeError(
"Addresses must either be IPv4 or IPv6 strings, bytes, or integers;"
Expand Down Expand Up @@ -322,10 +352,10 @@ def ip_addresses_by_version(

def _ip_addresses_by_version_value(
self, version_value: int_
) -> Union[List[IPv4Address], List[IPv6Address], List[_BaseAddress]]:
) -> Union[List[IPv4Address], List[IPv6Address]]:
"""Backend for addresses_by_version that uses the raw value."""
if version_value == _IPVersion_All_value:
return [*self._ipv4_addresses, *self._ipv6_addresses]
return [*self._ipv4_addresses, *self._ipv6_addresses] # type: ignore[return-value]
if version_value == _IPVersion_V4Only_value:
return self._ipv4_addresses
return self._ipv6_addresses
Expand All @@ -339,7 +369,7 @@ def parsed_addresses(self, version: IPVersion = IPVersion.All) -> List[str]:
This means the first address will always be the most recently added
address of the given IP version.
"""
return [str(addr) for addr in self._ip_addresses_by_version_value(version.value)]
return [_str_without_scope_id(addr) for addr in self._ip_addresses_by_version_value(version.value)]

def parsed_scoped_addresses(self, version: IPVersion = IPVersion.All) -> List[str]:
"""Equivalent to parsed_addresses, with the exception that IPv6 Link-Local
Expand All @@ -351,12 +381,7 @@ def parsed_scoped_addresses(self, version: IPVersion = IPVersion.All) -> List[st
This means the first address will always be the most recently added
address of the given IP version.
"""
if self.interface_index is None:
return self.parsed_addresses(version)
return [
f"{addr}%{self.interface_index}" if addr.version == 6 and addr.is_link_local else str(addr)
for addr in self._ip_addresses_by_version_value(version.value)
]
return [str(addr) for addr in self._ip_addresses_by_version_value(version.value)]

def _set_properties(self, properties: Dict[Union[str, bytes], Optional[Union[str, bytes]]]) -> None:
"""Sets properties and text of this info from a dictionary"""
Expand Down Expand Up @@ -421,8 +446,8 @@ def _get_ip_addresses_from_cache_lifo(
for record in self._get_address_records_from_cache_by_type(zc, type):
if record.is_expired(now):
continue
ip_addr = _cached_ip_addresses_wrapper(record.address)
if ip_addr is not None:
ip_addr = _get_ip_address_object_from_record(record)
if ip_addr is not None and ip_addr not in address_list:
address_list.append(ip_addr)
address_list.reverse() # Reverse to get LIFO order
return address_list
Expand Down Expand Up @@ -471,7 +496,7 @@ def _process_record_threadsafe(self, zc: 'Zeroconf', record: DNSRecord, now: flo
dns_address_record = record
if TYPE_CHECKING:
assert isinstance(dns_address_record, DNSAddress)
ip_addr = _cached_ip_addresses_wrapper(dns_address_record.address)
ip_addr = _get_ip_address_object_from_record(dns_address_record)
if ip_addr is None:
log.warning(
"Encountered invalid address while processing %s: %s",
Expand Down
59 changes: 55 additions & 4 deletions tests/services/test_info.py
Expand Up @@ -7,6 +7,7 @@
import logging
import os
import socket
import sys
import threading
import unittest
from ipaddress import ip_address
Expand Down Expand Up @@ -538,6 +539,7 @@ def test_multiple_addresses():
assert info.addresses == [address, address]
assert info.parsed_addresses() == [address_parsed, address_parsed]
assert info.parsed_scoped_addresses() == [address_parsed, address_parsed]
ipaddress_supports_scope_id = sys.version_info >= (3, 9, 0)

if has_working_ipv6() and not os.environ.get('SKIP_IPV6'):
address_v6_parsed = "2001:db8::1"
Expand Down Expand Up @@ -576,30 +578,79 @@ def test_multiple_addresses():
assert info.ip_addresses_by_version(r.IPVersion.All) == [
ip_address(address),
ip_address(address_v6),
ip_address(address_v6_ll),
ip_address(address_v6_ll_scoped_parsed)
if ipaddress_supports_scope_id
else ip_address(address_v6_ll),
]
assert info.addresses_by_version(r.IPVersion.V4Only) == [address]
assert info.ip_addresses_by_version(r.IPVersion.V4Only) == [ip_address(address)]
assert info.addresses_by_version(r.IPVersion.V6Only) == [address_v6, address_v6_ll]
assert info.ip_addresses_by_version(r.IPVersion.V6Only) == [
ip_address(address_v6),
ip_address(address_v6_ll),
ip_address(address_v6_ll_scoped_parsed)
if ipaddress_supports_scope_id
else ip_address(address_v6_ll),
]
assert info.parsed_addresses() == [address_parsed, address_v6_parsed, address_v6_ll_parsed]
assert info.parsed_addresses(r.IPVersion.V4Only) == [address_parsed]
assert info.parsed_addresses(r.IPVersion.V6Only) == [address_v6_parsed, address_v6_ll_parsed]
assert info.parsed_scoped_addresses() == [
address_parsed,
address_v6_parsed,
address_v6_ll_scoped_parsed,
address_v6_ll_scoped_parsed if ipaddress_supports_scope_id else address_v6_ll_parsed,
]
assert info.parsed_scoped_addresses(r.IPVersion.V4Only) == [address_parsed]
assert info.parsed_scoped_addresses(r.IPVersion.V6Only) == [
address_v6_parsed,
address_v6_ll_scoped_parsed,
address_v6_ll_scoped_parsed if ipaddress_supports_scope_id else address_v6_ll_parsed,
]


@unittest.skipIf(sys.version_info < (3, 9, 0), 'Requires newer python')
def test_scoped_addresses_from_cache():
type_ = "_http._tcp.local."
registration_name = f"scoped.{type_}"
zeroconf = r.Zeroconf(interfaces=['127.0.0.1'])
host = "scoped.local."

zeroconf.cache.async_add_records(
[
r.DNSPointer(
type_,
const._TYPE_PTR,
const._CLASS_IN | const._CLASS_UNIQUE,
120,
registration_name,
),
r.DNSService(
registration_name,
const._TYPE_SRV,
const._CLASS_IN | const._CLASS_UNIQUE,
120,
0,
0,
80,
host,
),
r.DNSAddress(
host,
const._TYPE_AAAA,
const._CLASS_IN | const._CLASS_UNIQUE,
120,
socket.inet_pton(socket.AF_INET6, "fe80::52e:c2f2:bc5f:e9c6"),
scope_id=12,
),
]
)

# New kwarg way
info = ServiceInfo(type_, registration_name)
info.load_from_cache(zeroconf)
assert info.parsed_scoped_addresses() == ["fe80::52e:c2f2:bc5f:e9c6%12"]
assert info.ip_addresses_by_version(r.IPVersion.V6Only) == [ip_address("fe80::52e:c2f2:bc5f:e9c6%12")]
zeroconf.close()


# This test uses asyncio because it needs to access the cache directly
# which is not threadsafe
@pytest.mark.asyncio
Expand Down

0 comments on commit 1682991

Please sign in to comment.