Skip to content

Commit

Permalink
Call UpdateService on SRV & A/AAAA updates as well as TXT (#239)
Browse files Browse the repository at this point in the history
Fix #235

Contains:

* Add lock around handlers list
* Reverse DNSCache order to ensure newest records take precedence

  When there are multiple records in the cache, the behaviour was
  inconsistent. Whilst the DNSCache.get() method returned the newest,
  any function which iterated over the entire cache suffered from
  a last write winds issue. This change makes this behaviour consistent
  and allows the removal of an (incorrect) wait from one of the unit tests.
  • Loading branch information
mattsaxon authored and jstasiak committed Apr 25, 2020
1 parent f8fe400 commit 552a030
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 84 deletions.
132 changes: 85 additions & 47 deletions zeroconf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import threading
import time
import warnings
from collections import OrderedDict
from typing import Dict, List, Optional, Sequence, Union, cast
from typing import Any, Callable, Set, Tuple # noqa # used in type hints

Expand Down Expand Up @@ -1121,8 +1122,9 @@ def __init__(self) -> None:

def add(self, entry: DNSRecord) -> None:
"""Adds an entry"""
# Insert first in list so get returns newest entry
self.cache.setdefault(entry.key, []).insert(0, entry)
# Insert last in list, get will return newest entry
# iteration will result in last update winning
self.cache.setdefault(entry.key, []).append(entry)

def remove(self, entry: DNSRecord) -> None:
"""Removes an entry"""
Expand All @@ -1142,7 +1144,7 @@ def get(self, entry: DNSEntry) -> Optional[DNSRecord]:
matching entry."""
try:
list_ = self.cache[entry.key]
for cached_entry in list_:
for cached_entry in reversed(list_):
if entry.__eq__(cached_entry):
return cached_entry
return None
Expand All @@ -1164,7 +1166,7 @@ def entries_with_name(self, name: str) -> List[DNSRecord]:

def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[DNSRecord]:
now = current_time_millis()
for record in self.entries_with_name(name):
for record in reversed(self.entries_with_name(name)):
if (
record.type == _TYPE_PTR
and not record.is_expired(now)
Expand Down Expand Up @@ -1400,7 +1402,7 @@ def __init__(
self.services = {} # type: Dict[str, DNSRecord]
self.next_time = current_time_millis()
self.delay = delay
self._handlers_to_call = [] # type: List[Callable[[Zeroconf], None]]
self._handlers_to_call = OrderedDict() # type: OrderedDict[str, ServiceStateChange]

self._service_state_changed = Signal()

Expand Down Expand Up @@ -1445,14 +1447,30 @@ def service_state_changed(self) -> SignalRegistrationInterface:
def update_record(self, zc: 'Zeroconf', now: float, record: DNSRecord) -> None:
"""Callback invoked by Zeroconf when new information arrives.
Updates information required by browser in the Zeroconf cache."""
Updates information required by browser in the Zeroconf cache.
Ensures that there is are no unecessary duplicates in the list
"""

def enqueue_callback(state_change: ServiceStateChange, name: str) -> None:
self._handlers_to_call.append(
lambda zeroconf: self._service_state_changed.fire(
zeroconf=zeroconf, service_type=self.type, name=name, state_change=state_change

# Code to ensure we only do a single update message
# Precedence is; Added, Remove, Update

if (
state_change is ServiceStateChange.Added
or (
state_change is ServiceStateChange.Removed
and (
self._handlers_to_call.get(name) is ServiceStateChange.Updated
or self._handlers_to_call.get(name) is ServiceStateChange.Added
or self._handlers_to_call.get(name) is None
)
)
)
or (state_change is ServiceStateChange.Updated and name not in self._handlers_to_call)
):
self._handlers_to_call[name] = state_change

if record.type == _TYPE_PTR and record.name == self.type:
assert isinstance(record, DNSPointer)
Expand All @@ -1476,8 +1494,20 @@ def enqueue_callback(state_change: ServiceStateChange, name: str) -> None:
if expires < self.next_time:
self.next_time = expires

elif record.type == _TYPE_TXT and record.name.endswith(self.type):
assert isinstance(record, DNSText)
elif record.type == _TYPE_A or record.type == _TYPE_AAAA:
assert isinstance(record, DNSAddress)

# Iterate through the DNSCache and callback any services that use this address
for service in zc.cache.entries():
if (
isinstance(service, DNSService)
and service.name.endswith(self.type)
and service.server == record.name
and not record.is_expired(now)
):
enqueue_callback(ServiceStateChange.Updated, service.name)

elif record.name.endswith(self.type):
expired = record.is_expired(now)
if not expired:
enqueue_callback(ServiceStateChange.Updated, record.name)
Expand Down Expand Up @@ -1509,8 +1539,11 @@ def run(self) -> None:
self.delay = min(_BROWSER_BACKOFF_LIMIT * 1000, self.delay * 2)

if len(self._handlers_to_call) > 0 and not self.zc.done:
handler = self._handlers_to_call.pop(0)
handler(self.zc)
with self.zc._handlers_lock:
handler = self._handlers_to_call.popitem(False)
self._service_state_changed.fire(
zeroconf=self.zc, service_type=self.type, name=handler[0], state_change=handler[1]
)


class ServiceInfo(RecordUpdateListener):
Expand Down Expand Up @@ -2173,6 +2206,8 @@ def __init__(

self.debug = None # type: Optional[DNSOutgoing]

self._handlers_lock = threading.Lock() # ensure we process a full message in one go

@property
def done(self) -> bool:
return self._GLOBAL_DONE
Expand Down Expand Up @@ -2449,42 +2484,45 @@ def update_record(self, now: float, rec: DNSRecord) -> None:
def handle_response(self, msg: DNSIncoming) -> None:
"""Deal with incoming response packets. All answers
are held in the cache, and listeners are notified."""
now = current_time_millis()
for record in msg.answers:

updated = True

if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2
# Since the cache format is keyed on the lower case record name
# we can avoid iterating everything in the cache and
# only look though entries for the specific name.
# entries_with_name will take care of converting to lowercase
#
# We make a copy of the list that entries_with_name returns
# since we cannot iterate over something we might remove
for entry in self.cache.entries_with_name(record.name).copy():

if entry == record:
updated = False
with self._handlers_lock:

# Check the time first because it is far cheaper
# than the __eq__
if (record.created - entry.created > 1000) and DNSEntry.__eq__(entry, record):
self.cache.remove(entry)

expired = record.is_expired(now)
maybe_entry = self.cache.get(record)
if not expired:
if maybe_entry is not None:
maybe_entry.reset_ttl(record)
now = current_time_millis()
for record in msg.answers:

updated = True

if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2
# Since the cache format is keyed on the lower case record name
# we can avoid iterating everything in the cache and
# only look though entries for the specific name.
# entries_with_name will take care of converting to lowercase
#
# We make a copy of the list that entries_with_name returns
# since we cannot iterate over something we might remove
for entry in self.cache.entries_with_name(record.name).copy():

if entry == record:
updated = False

# Check the time first because it is far cheaper
# than the __eq__
if (record.created - entry.created > 1000) and DNSEntry.__eq__(entry, record):
self.cache.remove(entry)

expired = record.is_expired(now)
maybe_entry = self.cache.get(record)
if not expired:
if maybe_entry is not None:
maybe_entry.reset_ttl(record)
else:
self.cache.add(record)
if updated:
self.update_record(now, record)
else:
self.cache.add(record)
if updated:
self.update_record(now, record)
else:
if maybe_entry is not None:
self.update_record(now, record)
self.cache.remove(maybe_entry)
if maybe_entry is not None:
self.update_record(now, record)
self.cache.remove(maybe_entry)

def handle_query(self, msg: DNSIncoming, addr: Optional[str], port: int) -> None:
"""Deal with incoming query packets. Provides a response if
Expand Down
102 changes: 65 additions & 37 deletions zeroconf/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def test_numbers(self):
def test_numbers_questions(self):
generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE)
question = r.DNSQuestion("testname.local.", r._TYPE_SRV, r._CLASS_IN)
for i in range(10):
for i in range(10): # pylint: disable=unused-variable

This comment has been minimized.

Copy link
@PhilippSelenium

PhilippSelenium Apr 30, 2020

@mattsaxon I think one could get rid of # pylint: disable=unused-variable by using _

This comment has been minimized.

Copy link
@jstasiak

jstasiak Apr 30, 2020

Collaborator

@PhilippSelenium That's on me, I resolved one merge conflict incorrectly, this wasn't in the final pull request I took this from.

generated.add_question(question)
bytes = generated.packet()
(num_questions, num_answers, num_authorities, num_additionals) = struct.unpack('!4H', bytes[4:12])
Expand Down Expand Up @@ -756,7 +756,7 @@ def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT):
"""Sends an outgoing packet."""
nonlocal nbr_answers, nbr_additionals, nbr_authorities

for answer, time_ in out.answers:
for answer, time_ in out.answers: # pylint: disable=unused-variable
nbr_answers += 1
assert answer.ttl == get_ttl(answer.type)
for answer in out.additionals:
Expand Down Expand Up @@ -1053,62 +1053,57 @@ def test_update_record(self):

service_name = 'name._type._tcp.local.'
service_type = '_type._tcp.local.'
service_server = 'ash-2.local.'
service_text = b'path=/~paulsm/'
service_server = 'ash-1.local.'
service_text = b'path=/~matt1/'
service_address = '10.0.1.2'

service_added = False
service_removed = False
service_added_count = 0
service_removed_count = 0
service_updated_count = 0
service_add_event = Event()
service_removed_event = Event()
service_updated_event = Event()

class MyServiceListener(r.ServiceListener):
def add_service(self, zc, type_, name) -> None:
nonlocal service_added
service_added = True
nonlocal service_added_count
service_added_count += 1
service_add_event.set()

def remove_service(self, zc, type_, name) -> None:
nonlocal service_added, service_removed
service_added = False
service_removed = True
nonlocal service_removed_count
service_removed_count += 1
service_removed_event.set()

def update_service(self, zc, type_, name) -> None:
nonlocal service_updated_count
service_updated_count += 1

service_info = zc.get_service_info(type_, name)
assert service_info.addresses[0] == socket.inet_aton(service_address)
assert service_info.text == service_text
assert service_info.server == service_server
service_updated_event.set()

def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming:
ttl = 120
generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE)

if service_state_change == r.ServiceStateChange.Updated:
generated.add_answer_at_time(
r.DNSText(service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text), 0
)
return r.DNSIncoming(generated.packet())
generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE)

if service_state_change == r.ServiceStateChange.Removed:
ttl = 0
else:
ttl = 120

generated.add_answer_at_time(
r.DNSPointer(service_type, r._TYPE_PTR, r._CLASS_IN, ttl, service_name), 0
r.DNSText(service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text), 0
)

generated.add_answer_at_time(
r.DNSService(
service_name, r._TYPE_SRV, r._CLASS_IN | r._CLASS_UNIQUE, ttl, 0, 0, 80, service_server
),
0,
)
generated.add_answer_at_time(
r.DNSText(service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text), 0
)

generated.add_answer_at_time(
r.DNSAddress(
service_server,
Expand All @@ -1120,36 +1115,69 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi
0,
)

generated.add_answer_at_time(
r.DNSPointer(service_type, r._TYPE_PTR, r._CLASS_IN, ttl, service_name), 0
)

return r.DNSIncoming(generated.packet())

zeroconf = r.Zeroconf(interfaces=['127.0.0.1'])
service_browser = r.ServiceBrowser(zeroconf, service_type, listener=MyServiceListener())

try:
wait_time = 3

# service added
zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Added))
service_add_event.wait(1)
service_updated_event.wait(1)
assert service_added is True
assert service_updated_count == 1
assert service_removed is False
service_add_event.wait(wait_time)
assert service_added_count == 1
assert service_updated_count == 0
assert service_removed_count == 0

# service updated. currently only text record can be updated
# service SRV updated
service_updated_event.clear()
service_text = b'path=/~humingchun/'
service_server = 'ash-2.local.'
zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Updated))
service_updated_event.wait(wait_time)
assert service_added_count == 1
assert service_updated_count == 1
assert service_removed_count == 0

# service TXT updated
service_updated_event.clear()
service_text = b'path=/~matt2/'
zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Updated))
service_updated_event.wait(1)
assert service_added is True
service_updated_event.wait(wait_time)
assert service_added_count == 1
assert service_updated_count == 2
assert service_removed is False
assert service_removed_count == 0

# service A updated
service_updated_event.clear()
service_address = '10.0.1.3'
zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Updated))
service_updated_event.wait(wait_time)
assert service_added_count == 1
assert service_updated_count == 3
assert service_removed_count == 0

# service all updated
service_updated_event.clear()
service_server = 'ash-3.local.'
service_text = b'path=/~matt3/'
service_address = '10.0.1.3'
zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Updated))
service_updated_event.wait(wait_time)
assert service_added_count == 1
assert service_updated_count == 4
assert service_removed_count == 0

# service removed
zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Removed))
service_removed_event.wait(1)
assert service_added is False
assert service_updated_count == 2
assert service_removed is True
service_removed_event.wait(wait_time)
assert service_added_count == 1
assert service_updated_count == 4
assert service_removed_count == 1

finally:
service_browser.cancel()
Expand Down

0 comments on commit 552a030

Please sign in to comment.