diff --git a/dns/message.py b/dns/message.py index 44cacbd9..8513db95 100644 --- a/dns/message.py +++ b/dns/message.py @@ -20,7 +20,7 @@ import contextlib import io import time -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, cast import dns.edns import dns.entropy @@ -912,6 +912,14 @@ def set_opcode(self, opcode: dns.opcode.Opcode) -> None: self.flags &= 0x87FF self.flags |= dns.opcode.to_flags(opcode) + def get_options(self, otype: dns.edns.OptionType) -> List[dns.edns.Option]: + """Return the list of options of the specified type.""" + return [option for option in self.options if option.otype == otype] + + def extended_errors(self) -> List[dns.edns.EDEOption]: + """Return the list of Extended DNS Error (EDE) options in the message""" + return cast(List[dns.edns.EDEOption], self.get_options(dns.edns.OptionType.EDE)) + def _get_one_rr_per_rrset(self, value): # What the caller picked is fine. return value diff --git a/tests/test_message.py b/tests/test_message.py index 93c8aafd..bbd45718 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -991,6 +991,15 @@ def test_section_count_update(self): self.assertEqual(update.section_count(dns.update.UpdateSection.PREREQ), 5) self.assertEqual(update.section_count(dns.update.UpdateSection.UPDATE), 7) + def test_extended_errors(self): + options = [ + dns.edns.EDEOption(dns.edns.EDECode.NETWORK_ERROR, "tubes not tubing"), + dns.edns.EDEOption(dns.edns.EDECode.OTHER, "catch all code"), + ] + r = dns.message.make_query("example", "A", use_edns=0, options=options) + r.flags |= dns.flags.QR + self.assertEqual(r.extended_errors(), options) + if __name__ == "__main__": unittest.main()