Skip to content

Commit

Permalink
Add Message.section_count(). (#1024)
Browse files Browse the repository at this point in the history
Adds a method to return a count of the number of records in each
section.
  • Loading branch information
bwelling committed Dec 20, 2023
1 parent 1fa7860 commit 63aa46c
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 0 deletions.
28 changes: 28 additions & 0 deletions dns/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,34 @@ def get_rrset(
rrset = None
return rrset

def section_count(self, section: SectionType) -> int:
"""Returns the number of records in the specified section.
*section*, an ``int`` section number, a ``str`` section name, or one of
the section attributes of this message. This specifies the
the section of the message to count. For example::
my_message.section_count(my_message.answer)
my_message.section_count(dns.message.ANSWER)
my_message.section_count("ANSWER")
"""

if isinstance(section, int):
section_number = section
section = self.section_from_number(section_number)
elif isinstance(section, str):
section_number = MessageSection.from_text(section)
section = self.section_from_number(section_number)
else:
section_number = self.section_number(section)
count = sum(max(1, len(rrs)) for rrs in section)
if section_number == MessageSection.ADDITIONAL:
if self.opt is not None:
count += 1
if self.tsig is not None:
count += 1
return count

def _compute_opt_reserve(self) -> int:
"""Compute the size required for the OPT RR, padding excluded"""
if not self.opt:
Expand Down
32 changes: 32 additions & 0 deletions tests/test_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,6 +942,38 @@ def test_prefer_truncation_additional(self):
self.assertEqual(r2.flags & dns.flags.TC, 0)
self.assertEqual(len(r2.additional), 30)

def test_section_count(self):
a = dns.message.from_text(answer_text)
self.assertEqual(a.section_count(a.question), 1)
self.assertEqual(a.section_count(a.answer), 1)
self.assertEqual(a.section_count("authority"), 3)
self.assertEqual(a.section_count(dns.message.MessageSection.ADDITIONAL), 1)

a.use_edns()
a.use_tsig(dns.tsig.Key("foo.", b"abcd"))
self.assertEqual(a.section_count(dns.message.MessageSection.ADDITIONAL), 3)

def test_section_count_update(self):
update = dns.update.Update("example")
update.id = 1
# These each add 1 record to the prereq section
update.present("foo")
update.present("foo", "a")
update.present("bar", "a", "10.0.0.5")
update.absent("blaz2")
update.absent("blaz2", "a")
# This adds 3 records to the update section
update.replace("foo", 300, "a", "10.0.0.1", "10.0.0.2")
# These each add 1 record to the update section
update.add("bar", dns.rdataset.from_text(1, 1, 300, "10.0.0.3"))
update.delete("bar", "a", "10.0.0.4")
update.delete("blaz", "a")
update.delete("blaz2")

self.assertEqual(update.section_count(dns.update.UpdateSection.ZONE), 1)
self.assertEqual(update.section_count(dns.update.UpdateSection.PREREQ), 5)
self.assertEqual(update.section_count(dns.update.UpdateSection.UPDATE), 7)


if __name__ == "__main__":
unittest.main()

0 comments on commit 63aa46c

Please sign in to comment.