Skip to content

Commit

Permalink
Merge pull request #471 from bwelling/dnssec-enum
Browse files Browse the repository at this point in the history
Improve consistency in DNSSEC code.
  • Loading branch information
rthalley committed May 16, 2020
2 parents 61989e1 + 9b64605 commit 02839ca
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 17 deletions.
46 changes: 35 additions & 11 deletions dns/dnssec.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

"""Common DNSSEC-related functions and constants."""

import enum
import hashlib
import io
import struct
Expand Down Expand Up @@ -157,6 +158,13 @@ def key_id(key):
total += ((total >> 16) & 0xffff)
return total & 0xffff

class DSDigest(enum.IntEnum):
"""DNSSEC Delgation Signer Digest Algorithm"""

SHA1 = 1
SHA256 = 2
SHA384 = 4


def make_ds(name, key, algorithm, origin=None):
"""Create a DS record for a DNSSEC key.
Expand All @@ -165,7 +173,7 @@ def make_ds(name, key, algorithm, origin=None):
*key*, a ``dns.rdtypes.ANY.DNSKEY.DNSKEY``, the key the DS is about.
*algorithm*, a ``str`` specifying the hash algorithm.
*algorithm*, a ``str`` or ``int`` specifying the hash algorithm.
The currently supported hashes are "SHA1", "SHA256", and "SHA384". Case
does not matter for these strings.
Expand All @@ -177,14 +185,17 @@ def make_ds(name, key, algorithm, origin=None):
Returns a ``dns.rdtypes.ANY.DS.DS``
"""

if algorithm.upper() == 'SHA1':
dsalg = 1
try:
if isinstance(algorithm, str):
algorithm = DSDigest[algorithm.upper()]
except Exception:
raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm)

if algorithm == DSDigest.SHA1:
dshash = hashlib.sha1()
elif algorithm.upper() == 'SHA256':
dsalg = 2
elif algorithm == DSDigest.SHA256:
dshash = hashlib.sha256()
elif algorithm.upper() == 'SHA384':
dsalg = 4
elif algorithm == DSDigest.SHA384:
dshash = hashlib.sha384()
else:
raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm)
Expand All @@ -195,7 +206,8 @@ def make_ds(name, key, algorithm, origin=None):
dshash.update(_to_rdata(key, origin))
digest = dshash.digest()

dsrdata = struct.pack("!HBB", key_id(key), key.algorithm, dsalg) + digest
dsrdata = struct.pack("!HBB", key_id(key), key.algorithm, algorithm) + \
digest
return dns.rdata.from_wire(dns.rdataclass.IN, dns.rdatatype.DS, dsrdata, 0,
len(dsrdata))

Expand Down Expand Up @@ -524,6 +536,12 @@ def _validate(rrset, rrsigset, keys, origin=None, now=None):
raise ValidationFailure("no RRSIGs validated")


class NSEC3Hash(enum.IntEnum):
"""NSEC3 hash algorithm"""

SHA1 = 1


def nsec3_hash(domain, salt, iterations, algorithm):
"""
Calculate the NSEC3 hash, according to
Expand All @@ -536,8 +554,8 @@ def nsec3_hash(domain, salt, iterations, algorithm):
*iterations*, an ``int``, the number of iterations.
*algorithm*, an ``int``, the hash algorithm. The only defined algorithm
is SHA1.
*algorithm*, a ``str`` or ``int``, the hash algorithm.
The only defined algorithm is SHA1.
Returns a ``str``, the encoded NSEC3 hash.
"""
Expand All @@ -546,7 +564,13 @@ def nsec3_hash(domain, salt, iterations, algorithm):
"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567", "0123456789ABCDEFGHIJKLMNOPQRSTUV"
)

if algorithm != 1:
try:
if isinstance(algorithm, str):
algorithm = NSEC3Hash[algorithm.upper()]
except Exception:
raise ValueError("Wrong hash algorithm (only SHA1 is supported)")

if algorithm != NSEC3Hash.SHA1:
raise ValueError("Wrong hash algorithm (only SHA1 is supported)")

salt_encoded = salt
Expand Down
19 changes: 13 additions & 6 deletions tests/test_dnssec.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,21 +281,28 @@ def testAbsoluteED448Bad(self): # type: () -> None
class DNSSECMakeDSTestCase(unittest.TestCase):

def testMakeExampleSHA1DS(self): # type: () -> None
ds = dns.dnssec.make_ds(abs_example, example_sep_key, 'SHA1')
self.assertEqual(ds, example_ds_sha1)
for algorithm in ('SHA1', 'sha1', dns.dnssec.DSDigest.SHA1):
ds = dns.dnssec.make_ds(abs_example, example_sep_key, algorithm)
self.assertEqual(ds, example_ds_sha1)

def testMakeExampleSHA256DS(self): # type: () -> None
ds = dns.dnssec.make_ds(abs_example, example_sep_key, 'SHA256')
self.assertEqual(ds, example_ds_sha256)
for algorithm in ('SHA256', 'sha256', dns.dnssec.DSDigest.SHA256):
ds = dns.dnssec.make_ds(abs_example, example_sep_key, algorithm)
self.assertEqual(ds, example_ds_sha256)

def testMakeExampleSHA384DS(self): # type: () -> None
ds = dns.dnssec.make_ds(abs_example, example_sep_key, 'SHA384')
self.assertEqual(ds, example_ds_sha384)
for algorithm in ('SHA384', 'sha384', dns.dnssec.DSDigest.SHA384):
ds = dns.dnssec.make_ds(abs_example, example_sep_key, algorithm)
self.assertEqual(ds, example_ds_sha384)

def testMakeSHA256DS(self): # type: () -> None
ds = dns.dnssec.make_ds(abs_dnspython_org, sep_key, 'SHA256')
self.assertEqual(ds, good_ds)

def testInvalidAlgorithm(self): # type: () -> None
for algorithm in (10, 'shax'):
with self.assertRaises(dns.dnssec.UnsupportedAlgorithm):
ds = dns.dnssec.make_ds(abs_example, example_sep_key, algorithm)

if __name__ == '__main__':
unittest.main()
22 changes: 22 additions & 0 deletions tests/test_nsec3_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ class NSEC3Hash(unittest.TestCase):
1,
),
("*.test-domain.dev", None, 45, "505k9g118d9sofnjhh54rr8fadgpa0ct", 1),
(
"example",
"aabbccdd",
12,
"0p9mhaveqvm6t7vbl5lop2u3t2rp3tom",
dnssec.NSEC3Hash.SHA1
),
("example", "aabbccdd", 12, "0p9mhaveqvm6t7vbl5lop2u3t2rp3tom", "SHA1"),
("example", "aabbccdd", 12, "0p9mhaveqvm6t7vbl5lop2u3t2rp3tom", "sha1")
]

def test_hash_function(self):
Expand All @@ -67,6 +76,19 @@ def test_hash_invalid_salt_length(self):
with self.assertRaises(ValueError):
hash = dnssec.nsec3_hash(data[0], data[1], data[2], data[4])

def test_hash_invalid_algorithm(self):
data = (
"example.com",
"9F1AB450CF71D",
0,
"qfo2sv6jaej4cm11a3npoorfrckdao2c",
1,
)
with self.assertRaises(ValueError):
dnssec.nsec3_hash(data[0], data[1], data[2], 10)
with self.assertRaises(ValueError):
dnssec.nsec3_hash(data[0], data[1], data[2], "foo")


if __name__ == "__main__":
unittest.main()

0 comments on commit 02839ca

Please sign in to comment.