Skip to content

Commit

Permalink
fix: performance regression with ServiceInfo IPv6Addresses (#1330)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Dec 13, 2023
1 parent 878a726 commit e2f9f81
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 51 deletions.
1 change: 1 addition & 0 deletions build_ext.py
Expand Up @@ -39,6 +39,7 @@ def build(setup_kwargs: Any) -> None:
"src/zeroconf/_services/info.py",
"src/zeroconf/_services/registry.py",
"src/zeroconf/_updates.py",
"src/zeroconf/_utils/ipaddress.py",
"src/zeroconf/_utils/time.py",
],
compiler_directives={"language_level": "3"}, # Python 3
Expand Down
14 changes: 6 additions & 8 deletions src/zeroconf/_services/info.pxd
Expand Up @@ -6,11 +6,15 @@ from .._dns cimport DNSAddress, DNSNsec, DNSPointer, DNSRecord, DNSService, DNST
from .._protocol.outgoing cimport DNSOutgoing
from .._record_update cimport RecordUpdate
from .._updates cimport RecordUpdateListener
from .._utils.ipaddress cimport (
get_ip_address_object_from_record,
ip_bytes_and_scope_to_address,
str_without_scope_id,
)
from .._utils.time cimport current_time_millis


cdef object _resolve_all_futures_to_none
cdef object _cached_ip_addresses_wrapper

cdef object _TYPE_SRV
cdef object _TYPE_TXT
Expand All @@ -33,13 +37,7 @@ cdef cython.set _ADDRESS_RECORD_TYPES

cdef bint TYPE_CHECKING
cdef bint IPADDRESS_SUPPORTS_SCOPE_ID

cdef _get_ip_address_object_from_record(DNSAddress record)

@cython.locals(address_str=str)
cdef _str_without_scope_id(object addr)

cdef _ip_bytes_and_scope_to_address(object addr, object scope_id)
cdef object cached_ip_addresses

cdef class ServiceInfo(RecordUpdateListener):

Expand Down
57 changes: 14 additions & 43 deletions src/zeroconf/_services/info.py
Expand Up @@ -23,8 +23,7 @@
import asyncio
import random
import sys
from functools import lru_cache
from ipaddress import IPv4Address, IPv6Address, _BaseAddress, ip_address
from ipaddress import IPv4Address, IPv6Address, _BaseAddress
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union, cast

from .._dns import (
Expand All @@ -47,6 +46,12 @@
run_coro_with_timeout,
wait_for_future_set_or_timeout,
)
from .._utils.ipaddress import (
cached_ip_addresses,
get_ip_address_object_from_record,
ip_bytes_and_scope_to_address,
str_without_scope_id,
)
from .._utils.name import service_type_name
from .._utils.net import IPVersion, _encode_address
from .._utils.time import current_time_millis
Expand All @@ -67,6 +72,8 @@
_TYPE_TXT,
)

IPADDRESS_SUPPORTS_SCOPE_ID = sys.version_info >= (3, 9, 0)

_IPVersion_All_value = IPVersion.All.value
_IPVersion_V4Only_value = IPVersion.V4Only.value
# https://datatracker.ietf.org/doc/html/rfc6762#section-5.2
Expand All @@ -86,7 +93,6 @@
DNS_QUESTION_TYPE_QU = DNSQuestionType.QU
DNS_QUESTION_TYPE_QM = DNSQuestionType.QM

IPADDRESS_SUPPORTS_SCOPE_ID = sys.version_info >= (3, 9, 0)

if TYPE_CHECKING:
from .._core import Zeroconf
Expand All @@ -102,41 +108,6 @@ def instance_name_from_service_info(info: "ServiceInfo", strict: bool = True) ->
return info.name[: -len(service_name) - 1]


@lru_cache(maxsize=512)
def _cached_ip_addresses(address: Union[str, bytes, int]) -> Optional[Union[IPv4Address, IPv6Address]]:
"""Cache IP addresses."""
try:
return ip_address(address)
except ValueError:
return None


_cached_ip_addresses_wrapper = _cached_ip_addresses


def _get_ip_address_object_from_record(record: DNSAddress) -> Optional[Union[IPv4Address, IPv6Address]]:
"""Get the IP address object from the record."""
if IPADDRESS_SUPPORTS_SCOPE_ID and record.type == _TYPE_AAAA and record.scope_id is not None:
return _ip_bytes_and_scope_to_address(record.address, record.scope_id)
return _cached_ip_addresses_wrapper(record.address)


def _ip_bytes_and_scope_to_address(address: bytes_, scope: int_) -> Optional[Union[IPv4Address, IPv6Address]]:
"""Convert the bytes and scope to an IP address object."""
base_address = _cached_ip_addresses_wrapper(address)
if base_address is not None and base_address.is_link_local:
return _cached_ip_addresses_wrapper(f"{base_address}%{scope}")
return base_address


def _str_without_scope_id(addr: Union[IPv4Address, IPv6Address]) -> str:
"""Return the string representation of the address without the scope id."""
if IPADDRESS_SUPPORTS_SCOPE_ID and addr.version == 6:
address_str = str(addr)
return address_str.partition('%')[0]
return str(addr)


class ServiceInfo(RecordUpdateListener):
"""Service information.
Expand Down Expand Up @@ -271,9 +242,9 @@ def addresses(self, value: List[bytes]) -> None:

for address in value:
if IPADDRESS_SUPPORTS_SCOPE_ID and len(address) == 16 and self.interface_index is not None:
addr = _ip_bytes_and_scope_to_address(address, self.interface_index)
addr = ip_bytes_and_scope_to_address(address, self.interface_index)
else:
addr = _cached_ip_addresses_wrapper(address)
addr = cached_ip_addresses(address)
if addr is None:
raise TypeError(
"Addresses must either be IPv4 or IPv6 strings, bytes, or integers;"
Expand Down Expand Up @@ -369,7 +340,7 @@ def parsed_addresses(self, version: IPVersion = IPVersion.All) -> List[str]:
This means the first address will always be the most recently added
address of the given IP version.
"""
return [_str_without_scope_id(addr) for addr in self._ip_addresses_by_version_value(version.value)]
return [str_without_scope_id(addr) for addr in self._ip_addresses_by_version_value(version.value)]

def parsed_scoped_addresses(self, version: IPVersion = IPVersion.All) -> List[str]:
"""Equivalent to parsed_addresses, with the exception that IPv6 Link-Local
Expand Down Expand Up @@ -446,7 +417,7 @@ def _get_ip_addresses_from_cache_lifo(
for record in self._get_address_records_from_cache_by_type(zc, type):
if record.is_expired(now):
continue
ip_addr = _get_ip_address_object_from_record(record)
ip_addr = get_ip_address_object_from_record(record)
if ip_addr is not None and ip_addr not in address_list:
address_list.append(ip_addr)
address_list.reverse() # Reverse to get LIFO order
Expand Down Expand Up @@ -496,7 +467,7 @@ def _process_record_threadsafe(self, zc: 'Zeroconf', record: DNSRecord, now: flo
dns_address_record = record
if TYPE_CHECKING:
assert isinstance(dns_address_record, DNSAddress)
ip_addr = _get_ip_address_object_from_record(dns_address_record)
ip_addr = get_ip_address_object_from_record(dns_address_record)
if ip_addr is None:
log.warning(
"Encountered invalid address while processing %s: %s",
Expand Down
14 changes: 14 additions & 0 deletions src/zeroconf/_utils/ipaddress.pxd
@@ -0,0 +1,14 @@
cdef bint TYPE_CHECKING
cdef bint IPADDRESS_SUPPORTS_SCOPE_ID

from .._dns cimport DNSAddress


cpdef get_ip_address_object_from_record(DNSAddress record)

@cython.locals(address_str=str)
cpdef str_without_scope_id(object addr)

cpdef ip_bytes_and_scope_to_address(object addr, object scope_id)

cdef object cached_ip_addresses_wrapper
121 changes: 121 additions & 0 deletions src/zeroconf/_utils/ipaddress.py
@@ -0,0 +1,121 @@
""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
This module provides a framework for the use of DNS Service Discovery
using IP multicast.
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
USA
"""
import sys
from functools import lru_cache
from ipaddress import AddressValueError, IPv4Address, IPv6Address, NetmaskValueError
from typing import Any, Optional, Union

from .._dns import DNSAddress
from ..const import _TYPE_AAAA

bytes_ = bytes
int_ = int
IPADDRESS_SUPPORTS_SCOPE_ID = sys.version_info >= (3, 9, 0)


class ZeroconfIPv4Address(IPv4Address):

__slots__ = ("_str", "_is_link_local")

def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize a new IPv4 address."""
super().__init__(*args, **kwargs)
self._str = super().__str__()
self._is_link_local = super().is_link_local

def __str__(self) -> str:
"""Return the string representation of the IPv4 address."""
return self._str

@property
def is_link_local(self) -> bool:
"""Return True if this is a link-local address."""
return self._is_link_local


class ZeroconfIPv6Address(IPv6Address):

__slots__ = ("_str", "_is_link_local")

def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize a new IPv6 address."""
super().__init__(*args, **kwargs)
self._str = super().__str__()
self._is_link_local = super().is_link_local

def __str__(self) -> str:
"""Return the string representation of the IPv6 address."""
return self._str

@property
def is_link_local(self) -> bool:
"""Return True if this is a link-local address."""
return self._is_link_local


@lru_cache(maxsize=512)
def _cached_ip_addresses(address: Union[str, bytes, int]) -> Optional[Union[IPv4Address, IPv6Address]]:
"""Cache IP addresses."""
try:
return ZeroconfIPv4Address(address)
except (AddressValueError, NetmaskValueError):
pass

try:
return ZeroconfIPv6Address(address)
except (AddressValueError, NetmaskValueError):
return None


cached_ip_addresses_wrapper = _cached_ip_addresses
cached_ip_addresses = cached_ip_addresses_wrapper


def get_ip_address_object_from_record(record: DNSAddress) -> Optional[Union[IPv4Address, IPv6Address]]:
"""Get the IP address object from the record."""
if IPADDRESS_SUPPORTS_SCOPE_ID and record.type == _TYPE_AAAA and record.scope_id is not None:
return ip_bytes_and_scope_to_address(record.address, record.scope_id)
return cached_ip_addresses_wrapper(record.address)


def ip_bytes_and_scope_to_address(address: bytes_, scope: int_) -> Optional[Union[IPv4Address, IPv6Address]]:
"""Convert the bytes and scope to an IP address object."""
base_address = cached_ip_addresses_wrapper(address)
if base_address is not None and base_address.is_link_local:
return cached_ip_addresses_wrapper(f"{base_address}%{scope}")
return base_address


def str_without_scope_id(addr: Union[IPv4Address, IPv6Address]) -> str:
"""Return the string representation of the address without the scope id."""
if IPADDRESS_SUPPORTS_SCOPE_ID and addr.version == 6:
address_str = str(addr)
return address_str.partition('%')[0]
return str(addr)


__all__ = (
"cached_ip_addresses",
"get_ip_address_object_from_record",
"ip_bytes_and_scope_to_address",
"str_without_scope_id",
)
24 changes: 24 additions & 0 deletions tests/utils/test_ipaddress.py
@@ -0,0 +1,24 @@
#!/usr/bin/env python

"""Unit tests for zeroconf._utils.ipaddress."""

from zeroconf._utils import ipaddress


def test_cached_ip_addresses_wrapper():
"""Test the cached_ip_addresses_wrapper."""
assert ipaddress.cached_ip_addresses('') is None
assert ipaddress.cached_ip_addresses('foo') is None
assert (
str(ipaddress.cached_ip_addresses(b'&\x06(\x00\x02 \x00\x01\x02H\x18\x93%\xc8\x19F'))
== '2606:2800:220:1:248:1893:25c8:1946'
)
assert ipaddress.cached_ip_addresses('::1') == ipaddress.IPv6Address('::1')

ipv4 = ipaddress.cached_ip_addresses('169.254.0.0')
assert ipv4 is not None
assert ipv4.is_link_local is True

ipv6 = ipaddress.cached_ip_addresses('fe80::1')
assert ipv6 is not None
assert ipv6.is_link_local is True

0 comments on commit e2f9f81

Please sign in to comment.