diff --git a/src/zeroconf/_handlers/query_handler.py b/src/zeroconf/_handlers/query_handler.py index cbb18eee..b232ea49 100644 --- a/src/zeroconf/_handlers/query_handler.py +++ b/src/zeroconf/_handlers/query_handler.py @@ -21,7 +21,7 @@ """ -from typing import TYPE_CHECKING, List, Set, cast +from typing import TYPE_CHECKING, List, Optional, Set, cast from .._cache import DNSCache, _UniqueRecordsType from .._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRecord, DNSRRSet @@ -109,19 +109,20 @@ def add_mcast_question_response(self, answers: _AnswerWithAdditionalsType) -> No else: self._mcast_aggregate.add(answer) - def _generate_answers_with_additionals(self, rrset: Set[DNSRecord]) -> _AnswerWithAdditionalsType: - """Create answers with additionals from an rrset.""" - return {record: self._additionals[record] for record in rrset} - def answers( self, ) -> QuestionAnswers: """Return answer sets that will be queued.""" return QuestionAnswers( - self._generate_answers_with_additionals(self._ucast), - self._generate_answers_with_additionals(self._mcast_now), - self._generate_answers_with_additionals(self._mcast_aggregate), - self._generate_answers_with_additionals(self._mcast_aggregate_last_second), + *( + {record: self._additionals[record] for record in rrset} + for rrset in ( + self._ucast, + self._mcast_now, + self._mcast_aggregate, + self._mcast_aggregate_last_second, + ) + ) ) def _has_mcast_within_one_quarter_ttl(self, record: DNSRecord) -> bool: @@ -224,17 +225,16 @@ def _answer_question( self, question: DNSQuestion, known_answers: DNSRRSet, - now: float, ) -> _AnswerWithAdditionalsType: + """Answer a question.""" answer_set: _AnswerWithAdditionalsType = {} question_lower_name = question.name.lower() + type_ = question.type - if question.type == _TYPE_PTR and question_lower_name == _SERVICE_TYPE_ENUMERATION_NAME: + if type_ == _TYPE_PTR and question_lower_name == _SERVICE_TYPE_ENUMERATION_NAME: self._add_service_type_enumeration_query_answers(answer_set, known_answers) return answer_set - type_ = question.type - if type_ in (_TYPE_PTR, _TYPE_ANY): self._add_pointer_answers(question_lower_name, answer_set, known_answers) @@ -267,12 +267,15 @@ def async_response( # pylint: disable=unused-argument """ known_answers = DNSRRSet([msg.answers for msg in msgs if not msg.is_probe]) query_res = _QueryResponse(self.cache, msgs) + known_answers_set: Optional[Set[DNSRecord]] = None for msg in msgs: for question in msg.questions: if not question.unicast: - self.question_history.add_question_at_time(question, msg.now, set(known_answers.lookup)) - answer_set = self._answer_question(question, known_answers, msg.now) + if not known_answers_set: # pragma: no branch + known_answers_set = set(known_answers.lookup) + self.question_history.add_question_at_time(question, msg.now, known_answers_set) + answer_set = self._answer_question(question, known_answers) if not ucast_source and question.unicast: query_res.add_qu_question_response(answer_set) continue