Skip to content

Commit

Permalink
fix: cleanup naming from previous refactoring in ServiceInfo (#1202)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Jul 24, 2023
1 parent fed3dec commit b272d75
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 31 deletions.
2 changes: 0 additions & 2 deletions src/zeroconf/_services/__init__.py
Expand Up @@ -46,7 +46,6 @@ def update_service(self, zc: 'Zeroconf', type_: str, name: str) -> None:


class Signal:

__slots__ = ('_handlers',)

def __init__(self) -> None:
Expand All @@ -62,7 +61,6 @@ def registration_interface(self) -> 'SignalRegistrationInterface':


class SignalRegistrationInterface:

__slots__ = ('_handlers',)

def __init__(self, handlers: List[Callable[..., None]]) -> None:
Expand Down
62 changes: 33 additions & 29 deletions src/zeroconf/_services/info.py
Expand Up @@ -21,9 +21,9 @@
"""

import asyncio
import ipaddress
import random
from functools import lru_cache
from ipaddress import IPv4Address, IPv6Address, _BaseAddress, ip_address
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union, cast

from .._dns import (
Expand Down Expand Up @@ -90,7 +90,7 @@ def instance_name_from_service_info(info: "ServiceInfo") -> str:
return info.name[: -len(service_name) - 1]


_cached_ip_addresses = lru_cache(maxsize=256)(ipaddress.ip_address)
_cached_ip_addresses = lru_cache(maxsize=256)(ip_address)


class ServiceInfo(RecordUpdateListener):
Expand Down Expand Up @@ -158,8 +158,8 @@ def __init__(
self.type = type_
self._name = name
self.key = name.lower()
self._ipv4_addresses: List[ipaddress.IPv4Address] = []
self._ipv6_addresses: List[ipaddress.IPv6Address] = []
self._ipv4_addresses: List[IPv4Address] = []
self._ipv6_addresses: List[IPv6Address] = []
if addresses is not None:
self.addresses = addresses
elif parsed_addresses is not None:
Expand Down Expand Up @@ -260,7 +260,7 @@ def addresses_by_version(self, version: IPVersion) -> List[bytes]:

def ip_addresses_by_version(
self, version: IPVersion
) -> Union[List[ipaddress.IPv4Address], List[ipaddress.IPv6Address], List[ipaddress._BaseAddress]]:
) -> Union[List[IPv4Address], List[IPv6Address], List[_BaseAddress]]:
"""List ip_address objects matching IP version.
Addresses are guaranteed to be returned in LIFO (last in, first out)
Expand All @@ -273,7 +273,7 @@ def ip_addresses_by_version(

def _ip_addresses_by_version_value(
self, version_value: int
) -> Union[List[ipaddress.IPv4Address], List[ipaddress.IPv6Address], List[ipaddress._BaseAddress]]:
) -> Union[List[IPv4Address], List[IPv6Address], List[_BaseAddress]]:
"""Backend for addresses_by_version that uses the raw value."""
if version_value == _IPVersion_All_value:
return [*self._ipv4_addresses, *self._ipv6_addresses]
Expand Down Expand Up @@ -366,31 +366,31 @@ def get_name(self) -> str:

def _get_ip_addresses_from_cache_lifo(
self, zc: 'Zeroconf', now: float, type: int
) -> List[Union[ipaddress.IPv4Address, ipaddress.IPv6Address]]:
) -> List[Union[IPv4Address, IPv6Address]]:
"""Set IPv6 addresses from the cache."""
address_list: List[Union[ipaddress.IPv4Address, ipaddress.IPv6Address]] = []
address_list: List[Union[IPv4Address, IPv6Address]] = []
for record in self._get_address_records_from_cache_by_type(zc, type):
if record.is_expired(now):
continue
try:
ip_address = _cached_ip_addresses(record.address)
ip_addr = _cached_ip_addresses(record.address)
except ValueError:
continue
else:
address_list.append(ip_address)
address_list.append(ip_addr)
address_list.reverse() # Reverse to get LIFO order
return address_list

def _set_ipv6_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None:
"""Set IPv6 addresses from the cache."""
self._ipv6_addresses = cast(
"List[ipaddress.IPv6Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA)
"List[IPv6Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA)
)

def _set_ipv4_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None:
"""Set IPv4 addresses from the cache."""
self._ipv4_addresses = cast(
"List[ipaddress.IPv4Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A)
"List[IPv4Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A)
)

def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) -> None:
Expand Down Expand Up @@ -431,46 +431,49 @@ def _process_record_threadsafe(self, zc: 'Zeroconf', record: DNSRecord, now: flo
if record.is_expired(now):
return False

if record.key == self.server_key and isinstance(record, DNSAddress):
record_key = record.key
if record_key == self.server_key and type(record) is DNSAddress:
try:
ip_addr = _cached_ip_addresses(record.address)
except ValueError as ex:
log.warning("Encountered invalid address while processing %s: %s", record, ex)
return False

if ip_addr.version == 4:
if not self._ipv4_addresses:
if type(ip_addr) is IPv4Address:
if self._ipv4_addresses:
self._set_ipv4_addresses_from_cache(zc, now)

if ip_addr not in self._ipv4_addresses:
self._ipv4_addresses.insert(0, ip_addr)
ipv4_addresses = self._ipv4_addresses
if ip_addr not in ipv4_addresses:
ipv4_addresses.insert(0, ip_addr)
return True
elif ip_addr != self._ipv4_addresses[0]:
self._ipv4_addresses.remove(ip_addr)
self._ipv4_addresses.insert(0, ip_addr)
elif ip_addr != ipv4_addresses[0]:
ipv4_addresses.remove(ip_addr)
ipv4_addresses.insert(0, ip_addr)

return False

if not self._ipv6_addresses:
self._set_ipv6_addresses_from_cache(zc, now)

ipv6_addresses = self._ipv6_addresses
if ip_addr not in self._ipv6_addresses:
self._ipv6_addresses.insert(0, ip_addr)
ipv6_addresses.insert(0, ip_addr)
return True
elif ip_addr != self._ipv6_addresses[0]:
self._ipv6_addresses.remove(ip_addr)
self._ipv6_addresses.insert(0, ip_addr)
ipv6_addresses.remove(ip_addr)
ipv6_addresses.insert(0, ip_addr)

return False

if record.key != self.key:
if record_key != self.key:
return False

if record.type == _TYPE_TXT and isinstance(record, DNSText):
if record.type == _TYPE_TXT and type(record) is DNSText:
self._set_text(record.text)
return True

if record.type == _TYPE_SRV and isinstance(record, DNSService):
if record.type == _TYPE_SRV and type(record) is DNSService:
old_server_key = self.server_key
self.name = record.name
self.server = record.server
Expand All @@ -495,16 +498,17 @@ def dns_addresses(
name = self.server or self.name
ttl = override_ttl if override_ttl is not None else self.host_ttl
class_ = _CLASS_IN | _CLASS_UNIQUE
version_value = version.value
return [
DNSAddress(
name,
_TYPE_AAAA if address.version == 6 else _TYPE_A,
_TYPE_AAAA if type(ip_addr) is IPv6Address else _TYPE_A,
class_,
ttl,
address.packed,
ip_addr.packed,
created=created,
)
for address in self._ip_addresses_by_version_value(version.value)
for ip_addr in self._ip_addresses_by_version_value(version_value)
]

def dns_pointer(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSPointer:
Expand Down

0 comments on commit b272d75

Please sign in to comment.