diff --git a/src/zeroconf/_dns.pxd b/src/zeroconf/_dns.pxd index fa73d692..ccdcc34f 100644 --- a/src/zeroconf/_dns.pxd +++ b/src/zeroconf/_dns.pxd @@ -125,7 +125,7 @@ cdef class DNSNsec(DNSRecord): cdef class DNSRRSet: - cdef cython.list _record_sets + cdef cython.list _records cdef cython.dict _lookup @cython.locals(other=DNSRecord) diff --git a/src/zeroconf/_dns.py b/src/zeroconf/_dns.py index b546d273..0b43f410 100644 --- a/src/zeroconf/_dns.py +++ b/src/zeroconf/_dns.py @@ -174,7 +174,7 @@ def __eq__(self, other: Any) -> bool: # pylint: disable=no-self-use def suppressed_by(self, msg: 'DNSIncoming') -> bool: """Returns true if any answer in a message can suffice for the information held in this record.""" - answers = msg.answers + answers = msg.answers() for record in answers: if self._suppressed_by_answer(record): return True @@ -521,15 +521,15 @@ def __repr__(self) -> str: class DNSRRSet: """A set of dns records with a lookup to get the ttl.""" - __slots__ = ('_record_sets', '_lookup') + __slots__ = ('_records', '_lookup') - def __init__(self, record_sets: List[List[DNSRecord]]) -> None: + def __init__(self, records: List[DNSRecord]) -> None: """Create an RRset from records sets.""" - self._record_sets = record_sets - self._lookup: Optional[Dict[DNSRecord, float]] = None + self._records = records + self._lookup: Optional[Dict[DNSRecord, DNSRecord]] = None @property - def lookup(self) -> Dict[DNSRecord, float]: + def lookup(self) -> Dict[DNSRecord, DNSRecord]: """Return the lookup table.""" return self._get_lookup() @@ -537,21 +537,18 @@ def lookup_set(self) -> Set[DNSRecord]: """Return the lookup table as aset.""" return set(self._get_lookup()) - def _get_lookup(self) -> Dict[DNSRecord, float]: + def _get_lookup(self) -> Dict[DNSRecord, DNSRecord]: """Return the lookup table, building it if needed.""" if self._lookup is None: # Build the hash table so we can lookup the record ttl - self._lookup = {} - for record_sets in self._record_sets: - for record in record_sets: - self._lookup[record] = record.ttl + self._lookup = {record: record for record in self._records} return self._lookup def suppresses(self, record: _DNSRecord) -> bool: """Returns true if any answer in the rrset can suffice for the information held in this record.""" lookup = self._get_lookup() - other_ttl = lookup.get(record) - if other_ttl is None: + other = lookup.get(record) + if other is None: return False - return other_ttl > (record.ttl / 2) + return other.ttl > (record.ttl / 2) diff --git a/src/zeroconf/_handlers/query_handler.pxd b/src/zeroconf/_handlers/query_handler.pxd index 31261a69..365e3a27 100644 --- a/src/zeroconf/_handlers/query_handler.pxd +++ b/src/zeroconf/_handlers/query_handler.pxd @@ -21,7 +21,7 @@ cdef object _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL cdef class _QueryResponse: cdef bint _is_probe - cdef DNSIncoming _msg + cdef cython.list _questions cdef float _now cdef DNSCache _cache cdef cython.dict _additionals @@ -31,12 +31,12 @@ cdef class _QueryResponse: cdef cython.set _mcast_aggregate_last_second @cython.locals(record=DNSRecord) - cpdef add_qu_question_response(self, cython.dict answers) + cdef add_qu_question_response(self, cython.dict answers) - cpdef add_ucast_question_response(self, cython.dict answers) + cdef add_ucast_question_response(self, cython.dict answers) - @cython.locals(answer=DNSRecord) - cpdef add_mcast_question_response(self, cython.dict answers) + @cython.locals(answer=DNSRecord, question=DNSQuestion) + cdef add_mcast_question_response(self, cython.dict answers) @cython.locals(maybe_entry=DNSRecord) cdef bint _has_mcast_within_one_quarter_ttl(self, DNSRecord record) @@ -44,7 +44,7 @@ cdef class _QueryResponse: @cython.locals(maybe_entry=DNSRecord) cdef bint _has_mcast_record_in_last_second(self, DNSRecord record) - cpdef answers(self) + cdef QuestionAnswers answers(self) cdef class QueryHandler: @@ -70,5 +70,7 @@ cdef class QueryHandler: answer_set=cython.dict, known_answers=DNSRRSet, known_answers_set=cython.set, + is_probe=object, + now=object ) cpdef async_response(self, cython.list msgs, cython.bint unicast_source) diff --git a/src/zeroconf/_handlers/query_handler.py b/src/zeroconf/_handlers/query_handler.py index f4243021..776d6a3f 100644 --- a/src/zeroconf/_handlers/query_handler.py +++ b/src/zeroconf/_handlers/query_handler.py @@ -55,7 +55,7 @@ class _QueryResponse: __slots__ = ( "_is_probe", - "_msg", + "_questions", "_now", "_cache", "_additionals", @@ -65,15 +65,11 @@ class _QueryResponse: "_mcast_aggregate_last_second", ) - def __init__(self, cache: DNSCache, msgs: List[DNSIncoming]) -> None: + def __init__(self, cache: DNSCache, questions: List[DNSQuestion], is_probe: bool, now: float) -> None: """Build a query response.""" - self._is_probe = False - for msg in msgs: - if msg.is_probe: - self._is_probe = True - break - self._msg = msgs[0] - self._now = self._msg.now + self._is_probe = is_probe + self._questions = questions + self._now = now self._cache = cache self._additionals: _AnswerWithAdditionalsType = {} self._ucast: Set[DNSRecord] = set() @@ -107,10 +103,15 @@ def add_mcast_question_response(self, answers: _AnswerWithAdditionalsType) -> No if self._has_mcast_record_in_last_second(answer): self._mcast_aggregate_last_second.add(answer) - elif len(self._msg.questions) == 1 and self._msg.questions[0].type in _RESPOND_IMMEDIATE_TYPES: - self._mcast_now.add(answer) - else: - self._mcast_aggregate.add(answer) + continue + + if len(self._questions) == 1: + question = self._questions[0] + if question.type in _RESPOND_IMMEDIATE_TYPES: + self._mcast_now.add(answer) + continue + + self._mcast_aggregate.add(answer) def answers( self, @@ -262,8 +263,18 @@ def async_response( # pylint: disable=unused-argument This function must be run in the event loop as it is not threadsafe. """ - known_answers = DNSRRSet([msg.answers for msg in msgs if not msg.is_probe]) - query_res = _QueryResponse(self.cache, msgs) + answers: List[DNSRecord] = [] + is_probe = False + msg = msgs[0] + questions = msg.questions + now = msg.now + for msg in msgs: + if not msg.is_probe(): + answers.extend(msg.answers()) + else: + is_probe = True + known_answers = DNSRRSet(answers) + query_res = _QueryResponse(self.cache, questions, is_probe, now) known_answers_set: Optional[Set[DNSRecord]] = None for msg in msgs: @@ -271,7 +282,7 @@ def async_response( # pylint: disable=unused-argument if not question.unique: # unique and unicast are the same flag if not known_answers_set: # pragma: no branch known_answers_set = known_answers.lookup_set() - self.question_history.add_question_at_time(question, msg.now, known_answers_set) + self.question_history.add_question_at_time(question, now, known_answers_set) answer_set = self._answer_question(question, known_answers) if not ucast_source and question.unique: # unique and unicast are the same flag query_res.add_qu_question_response(answer_set) diff --git a/src/zeroconf/_handlers/record_manager.py b/src/zeroconf/_handlers/record_manager.py index 396bad45..63572c1e 100644 --- a/src/zeroconf/_handlers/record_manager.py +++ b/src/zeroconf/_handlers/record_manager.py @@ -87,7 +87,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: now_float = now unique_types: Set[Tuple[str, int, int]] = set() cache = self.cache - answers = msg.answers + answers = msg.answers() for record in answers: # Protect zeroconf from records that can cause denial of service. diff --git a/src/zeroconf/_protocol/incoming.pxd b/src/zeroconf/_protocol/incoming.pxd index ebd09a0e..37fc91e7 100644 --- a/src/zeroconf/_protocol/incoming.pxd +++ b/src/zeroconf/_protocol/incoming.pxd @@ -72,6 +72,10 @@ cdef class DNSIncoming: cpdef is_query(self) + cpdef is_probe(self) + + cpdef answers(self) + cpdef is_response(self) @cython.locals( diff --git a/src/zeroconf/_protocol/incoming.py b/src/zeroconf/_protocol/incoming.py index 87d25816..5838657a 100644 --- a/src/zeroconf/_protocol/incoming.py +++ b/src/zeroconf/_protocol/incoming.py @@ -172,7 +172,6 @@ def _log_exception_debug(cls, *logger_data: Any) -> None: log_exc_info = True log.debug(*(logger_data or ['Exception occurred']), exc_info=log_exc_info) - @property def answers(self) -> List[DNSRecord]: """Answers in the packet.""" if not self._did_read_others: @@ -187,7 +186,6 @@ def answers(self) -> List[DNSRecord]: ) return self._answers - @property def is_probe(self) -> bool: """Returns true if this is a probe.""" return self.num_authorities > 0 @@ -203,7 +201,7 @@ def __repr__(self) -> str: 'n_auth=%s' % self.num_authorities, 'n_add=%s' % self.num_additionals, 'questions=%s' % self.questions, - 'answers=%s' % self.answers, + 'answers=%s' % self.answers(), ] ) diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 18e8c8e0..d77e7e83 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -997,7 +997,7 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()): """Sends an outgoing packet.""" pout = DNSIncoming(out.packets()[0]) nonlocal nbr_answers - for answer in pout.answers: + for answer in pout.answers(): nbr_answers += 1 if not answer.ttl > expected_ttl / 2: unexpected_ttl.set() diff --git a/tests/test_dns.py b/tests/test_dns.py index b82f5d81..08f805f0 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -392,7 +392,7 @@ def test_rrset_does_not_consider_ttl(): longaaaarec = r.DNSAddress('irrelevant', const._TYPE_AAAA, const._CLASS_IN, 100, b'same') shortaaaarec = r.DNSAddress('irrelevant', const._TYPE_AAAA, const._CLASS_IN, 10, b'same') - rrset = DNSRRSet([[longarec, shortaaaarec]]) + rrset = DNSRRSet([longarec, shortaaaarec]) assert rrset.suppresses(longarec) assert rrset.suppresses(shortarec) @@ -404,7 +404,7 @@ def test_rrset_does_not_consider_ttl(): mediumarec = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 60, b'same') shortarec = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 10, b'same') - rrset2 = DNSRRSet([[mediumarec]]) + rrset2 = DNSRRSet([mediumarec]) assert not rrset2.suppresses(verylongarec) assert rrset2.suppresses(longarec) assert rrset2.suppresses(mediumarec) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 6266ad91..11b58292 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -1425,8 +1425,8 @@ async def test_response_aggregation_timings(run_isolated): outgoing = send_mock.call_args[0][0] incoming = r.DNSIncoming(outgoing.packets()[0]) zc.record_manager.async_updates_from_response(incoming) - assert info.dns_pointer() in incoming.answers - assert info2.dns_pointer() in incoming.answers + assert info.dns_pointer() in incoming.answers() + assert info2.dns_pointer() in incoming.answers() send_mock.reset_mock() protocol.datagram_received(query3.packets()[0], ('127.0.0.1', const._MDNS_PORT)) @@ -1439,7 +1439,7 @@ async def test_response_aggregation_timings(run_isolated): outgoing = send_mock.call_args[0][0] incoming = r.DNSIncoming(outgoing.packets()[0]) zc.record_manager.async_updates_from_response(incoming) - assert info3.dns_pointer() in incoming.answers + assert info3.dns_pointer() in incoming.answers() send_mock.reset_mock() # Because the response was sent in the last second we need to make @@ -1461,7 +1461,7 @@ async def test_response_aggregation_timings(run_isolated): assert len(calls) == 1 outgoing = send_mock.call_args[0][0] incoming = r.DNSIncoming(outgoing.packets()[0]) - assert info.dns_pointer() in incoming.answers + assert info.dns_pointer() in incoming.answers() await aiozc.async_close() @@ -1501,7 +1501,7 @@ async def test_response_aggregation_timings_multiple(run_isolated, disable_dupli outgoing = send_mock.call_args[0][0] incoming = r.DNSIncoming(outgoing.packets()[0]) zc.record_manager.async_updates_from_response(incoming) - assert info2.dns_pointer() in incoming.answers + assert info2.dns_pointer() in incoming.answers() send_mock.reset_mock() protocol.datagram_received(query2.packets()[0], ('127.0.0.1', const._MDNS_PORT)) @@ -1511,7 +1511,7 @@ async def test_response_aggregation_timings_multiple(run_isolated, disable_dupli outgoing = send_mock.call_args[0][0] incoming = r.DNSIncoming(outgoing.packets()[0]) zc.record_manager.async_updates_from_response(incoming) - assert info2.dns_pointer() in incoming.answers + assert info2.dns_pointer() in incoming.answers() send_mock.reset_mock() protocol.datagram_received(query2.packets()[0], ('127.0.0.1', const._MDNS_PORT)) @@ -1534,7 +1534,7 @@ async def test_response_aggregation_timings_multiple(run_isolated, disable_dupli outgoing = send_mock.call_args[0][0] incoming = r.DNSIncoming(outgoing.packets()[0]) zc.record_manager.async_updates_from_response(incoming) - assert info2.dns_pointer() in incoming.answers + assert info2.dns_pointer() in incoming.answers() @pytest.mark.asyncio diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 79f32755..a8593850 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -63,7 +63,7 @@ def test_parse_own_packet_nsec(self): generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) generated.add_answer_at_time(answer, 0) parsed = r.DNSIncoming(generated.packets()[0]) - assert answer in parsed.answers + assert answer in parsed.answers() # Types > 255 should be ignored answer_invalid_types = r.DNSNsec( @@ -77,7 +77,7 @@ def test_parse_own_packet_nsec(self): generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) generated.add_answer_at_time(answer_invalid_types, 0) parsed = r.DNSIncoming(generated.packets()[0]) - assert answer in parsed.answers + assert answer in parsed.answers() def test_parse_own_packet_response(self): generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) @@ -96,7 +96,7 @@ def test_parse_own_packet_response(self): ) parsed = r.DNSIncoming(generated.packets()[0]) assert len(generated.answers) == 1 - assert len(generated.answers) == len(parsed.answers) + assert len(generated.answers) == len(parsed.answers()) def test_adding_empty_answer(self): generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) @@ -119,7 +119,7 @@ def test_adding_empty_answer(self): ) parsed = r.DNSIncoming(generated.packets()[0]) assert len(generated.answers) == 1 - assert len(generated.answers) == len(parsed.answers) + assert len(generated.answers) == len(parsed.answers()) def test_adding_expired_answer(self): generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) @@ -138,7 +138,7 @@ def test_adding_expired_answer(self): ) parsed = r.DNSIncoming(generated.packets()[0]) assert len(generated.answers) == 0 - assert len(generated.answers) == len(parsed.answers) + assert len(generated.answers) == len(parsed.answers()) def test_match_question(self): generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) @@ -221,7 +221,7 @@ def test_dns_hinfo(self): generated = r.DNSOutgoing(0) generated.add_additional_answer(DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu', 'os')) parsed = r.DNSIncoming(generated.packets()[0]) - answer = cast(r.DNSHinfo, parsed.answers[0]) + answer = cast(r.DNSHinfo, parsed.answers()[0]) assert answer.cpu == 'cpu' assert answer.os == 'os' @@ -276,15 +276,15 @@ def test_many_questions_with_many_known_answers(self): parsed1 = r.DNSIncoming(packets[0]) assert len(parsed1.questions) == 30 - assert len(parsed1.answers) == 88 + assert len(parsed1.answers()) == 88 assert parsed1.truncated parsed2 = r.DNSIncoming(packets[1]) assert len(parsed2.questions) == 0 - assert len(parsed2.answers) == 101 + assert len(parsed2.answers()) == 101 assert parsed2.truncated parsed3 = r.DNSIncoming(packets[2]) assert len(parsed3.questions) == 0 - assert len(parsed3.answers) == 11 + assert len(parsed3.answers()) == 11 assert not parsed3.truncated def test_massive_probe_packet_split(self): @@ -375,7 +375,7 @@ def test_only_one_answer_can_by_large(self): for packet in packets: parsed = r.DNSIncoming(packet) - assert len(parsed.answers) == 1 + assert len(parsed.answers()) == 1 def test_questions_do_not_end_up_every_packet(self): """Test that questions are not sent again when multiple packets are needed. @@ -413,11 +413,11 @@ def test_questions_do_not_end_up_every_packet(self): parsed1 = r.DNSIncoming(packets[0]) assert len(parsed1.questions) == 35 - assert len(parsed1.answers) == 33 + assert len(parsed1.answers()) == 33 parsed2 = r.DNSIncoming(packets[1]) assert len(parsed2.questions) == 0 - assert len(parsed2.answers) == 2 + assert len(parsed2.answers()) == 2 class PacketForm(unittest.TestCase): @@ -482,7 +482,7 @@ def test_incoming_unknown_type(self): generated.add_additional_answer(answer) packet = generated.packets()[0] parsed = r.DNSIncoming(packet) - assert len(parsed.answers) == 0 + assert len(parsed.answers()) == 0 assert parsed.is_query() != parsed.is_response() def test_incoming_circular_reference(self): @@ -505,7 +505,7 @@ def test_incoming_ipv6(self): generated.add_additional_answer(answer) packet = generated.packets()[0] parsed = r.DNSIncoming(packet) - record = parsed.answers[0] + record = parsed.answers()[0] assert isinstance(record, r.DNSAddress) assert record.address == packed @@ -662,7 +662,7 @@ def test_dns_compression_rollback_for_corruption(): incoming = r.DNSIncoming(packet) assert incoming.valid is True assert ( - len(incoming.answers) + len(incoming.answers()) == incoming.num_answers + incoming.num_authorities + incoming.num_additionals ) @@ -767,7 +767,7 @@ def test_parse_packet_with_nsec_record(): b"\x00\x00\x80\x00@" ) parsed = DNSIncoming(nsec_packet) - nsec_record = cast(r.DNSNsec, parsed.answers[3]) + nsec_record = cast(r.DNSNsec, parsed.answers()[3]) assert "nsec," in str(nsec_record) assert nsec_record.rdtypes == [16, 33] assert nsec_record.next_name == "MyHome54 (2)._meshcop._udp.local." @@ -794,8 +794,8 @@ def test_records_same_packet_share_fate(): for packet in out.packets(): dnsin = DNSIncoming(packet) - first_time = dnsin.answers[0].created - for answer in dnsin.answers: + first_time = dnsin.answers()[0].created + for answer in dnsin.answers(): assert answer.created == first_time @@ -828,7 +828,7 @@ def test_dns_compression_all_invalid(caplog): ) parsed = r.DNSIncoming(packet, ("2.4.5.4", 5353)) assert len(parsed.questions) == 0 - assert len(parsed.answers) == 0 + assert len(parsed.answers()) == 0 assert " Unable to parse; skipping record" in caplog.text @@ -845,7 +845,7 @@ def test_invalid_next_name_ignored(): ) parsed = r.DNSIncoming(packet) assert len(parsed.questions) == 1 - assert len(parsed.answers) == 2 + assert len(parsed.answers()) == 2 def test_dns_compression_invalid_skips_record(): @@ -868,7 +868,7 @@ def test_dns_compression_invalid_skips_record(): 'eufy HomeBase2-2464._hap._tcp.local.', [const._TYPE_TXT, const._TYPE_SRV], ) - assert answer in parsed.answers + assert answer in parsed.answers() def test_dns_compression_points_forward(): @@ -893,7 +893,7 @@ def test_dns_compression_points_forward(): 'TV Beneden (2)._androidtvremote._tcp.local.', [const._TYPE_TXT, const._TYPE_SRV], ) - assert answer in parsed.answers + assert answer in parsed.answers() def test_dns_compression_points_to_itself(): @@ -904,7 +904,7 @@ def test_dns_compression_points_to_itself(): b"\x01\x00\x04\xc0\xa8\xd0\x06" ) parsed = r.DNSIncoming(packet) - assert len(parsed.answers) == 1 + assert len(parsed.answers()) == 1 def test_dns_compression_points_beyond_packet(): @@ -915,7 +915,7 @@ def test_dns_compression_points_beyond_packet(): b'\x00\x01\x00\x04\xc0\xa8\xd0\x06' ) parsed = r.DNSIncoming(packet) - assert len(parsed.answers) == 1 + assert len(parsed.answers()) == 1 def test_dns_compression_generic_failure(caplog): @@ -926,7 +926,7 @@ def test_dns_compression_generic_failure(caplog): b'\x00\x01\x00\x04\xc0\xa8\xd0\x06' ) parsed = r.DNSIncoming(packet, ("1.2.3.4", 5353)) - assert len(parsed.answers) == 1 + assert len(parsed.answers()) == 1 assert "Received invalid packet from ('1.2.3.4', 5353)" in caplog.text @@ -946,7 +946,7 @@ def test_label_length_attack(): b'\x01\x00\x04\xc0\xa8\xd0\x06' ) parsed = r.DNSIncoming(packet) - assert len(parsed.answers) == 0 + assert len(parsed.answers()) == 0 def test_label_compression_attack(): @@ -976,7 +976,7 @@ def test_label_compression_attack(): b'\x0c\x00\x01\x80\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0\x06' ) parsed = r.DNSIncoming(packet) - assert len(parsed.answers) == 1 + assert len(parsed.answers()) == 1 def test_dns_compression_loop_attack(): @@ -993,7 +993,7 @@ def test_dns_compression_loop_attack(): b'\x04\xc0\xa8\xd0\x05' ) parsed = r.DNSIncoming(packet) - assert len(parsed.answers) == 0 + assert len(parsed.answers()) == 0 def test_txt_after_invalid_nsec_name_still_usable(): @@ -1013,7 +1013,7 @@ def test_txt_after_invalid_nsec_name_still_usable(): b'ce=0' ) parsed = r.DNSIncoming(packet) - txt_record = cast(r.DNSText, parsed.answers[4]) + txt_record = cast(r.DNSText, parsed.answers()[4]) # The NSEC record with the invalid name compression should be skipped assert txt_record.text == ( b'2info=/api/v1/players/RINCON_542A1BC9220E01400/info\x06vers=3\x10protovers' @@ -1022,4 +1022,4 @@ def test_txt_after_invalid_nsec_name_still_usable(): b'00/xml/device_description.xml\x0csslport=1443\x0ehhsslport=1843\tvarian' b't=2\x0emdnssequence=0' ) - assert len(parsed.answers) == 5 + assert len(parsed.answers()) == 5