diff --git a/src/ecdsa/der.py b/src/ecdsa/der.py index 3ed7f6c3..2d2a7a25 100644 --- a/src/ecdsa/der.py +++ b/src/ecdsa/der.py @@ -2,6 +2,7 @@ import binascii import base64 +import warnings from six import int2byte, b, integer_types, text_type @@ -29,8 +30,61 @@ def encode_integer(r): return b("\x02") + int2byte(len(s)+1) + b("\x00") + s -def encode_bitstring(s): - return b("\x03") + encode_length(len(s)) + s +# sentry object to check if an argument was specified (used to detect +# deprecated calling convention) +_sentry = object() + + +def encode_bitstring(s, unused=_sentry): + """ + Encode a binary string as a BIT STRING using :term:`DER` encoding. + + Note, because there is no native Python object that can encode an actual + bit string, this function only accepts byte strings as the `s` argument. + The byte string is the actual bit string that will be encoded, padded + on the right (least significant bits, looking from big endian perspective) + to the first full byte. If the bit string has a bit length that is multiple + of 8, then the padding should not be included. For correct DER encoding + the padding bits MUST be set to 0. + + Number of bits of padding need to be provided as the `unused` parameter. + In case they are specified as None, it means the number of unused bits + is already encoded in the string as the first byte. + + The deprecated call convention specifies just the `s` parameters and + encodes the number of unused bits as first parameter (same convention + as with None). + + Empty string must be encoded with `unused` specified as 0. + + Future version of python-ecdsa will make specifying the `unused` argument + mandatory. + + :param s: bytes to encode + :type s: bytes like object + :param unused: number of bits at the end of `s` that are unused, must be + between 0 and 7 (inclusive) + :type unused: int or None + + :raises ValueError: when `unused` is too large or too small + + :return: `s` encoded using DER + :rtype: bytes + """ + encoded_unused = b'' + len_extra = 0 + if unused is _sentry: + warnings.warn("Legacy call convention used, unused= needs to be " + "specified", + DeprecationWarning) + elif unused is not None: + if not 0 <= unused <= 7: + raise ValueError("unused must be integer between 0 and 7") + if unused and not s: + raise ValueError("unused is non-zero but s is empty") + encoded_unused = int2byte(unused) + len_extra = 1 + return b("\x03") + encode_length(len(s) + len_extra) + encoded_unused + s def encode_octet_string(s): @@ -198,13 +252,77 @@ def read_length(string): return int(binascii.hexlify(string[1:1+llen]), 16), 1+llen -def remove_bitstring(string): +def remove_bitstring(string, expect_unused=_sentry): + """ + Remove a BIT STRING object from `string` following :term:`DER`. + + The `expect_unused` can be used to specify if the bit string should + have the amount of unused bits decoded or not. If it's an integer, any + read BIT STRING that has number of unused bits different from specified + value will cause UnexpectedDER exception to be raised (this is especially + useful when decoding BIT STRINGS that have DER encoded object in them; + DER encoding is byte oriented, so the unused bits will always equal 0). + + If the `expect_unused` is specified as None, the first element returned + will be a tuple, with the first value being the extracted bit string + while the second value will be the decoded number of unused bits. + + If the `expect_unused` is unspecified, the decoding of byte with + number of unused bits will not be attempted and the bit string will be + returned as-is, the callee will be required to decode it and verify its + correctness. + + Future version of python will require the `expected_unused` parameter + to be specified. + + :param string: string of bytes to extract the BIT STRING from + :type string: bytes like object + :param expect_unused: number of bits that should be unused in the BIT + STRING, or None, to return it to caller + :type expect_unused: int or None + + :raises UnexpectedDER: when the encoding does not follow DER. + + :return: a tuple with first element being the extracted bit string and + the second being the remaining bytes in the string (if any); if the + `expect_unused` is specified as None, the first element of the returned + tuple will be a tuple itself, with first element being the bit string + as bytes and the second element bing the number of unused bits at the + end of the byte array as an integer + :rtype: tuple + """ + if not string: + raise UnexpectedDER("Empty string does not encode a bitstring") + if expect_unused is _sentry: + warnings.warn("Legacy call convention used, expect_unused= needs to be" + " specified", + DeprecationWarning) num = string[0] if isinstance(string[0], integer_types) else ord(string[0]) if not string.startswith(b("\x03")): raise UnexpectedDER("wanted bitstring (0x03), got 0x%02x" % num) length, llen = read_length(string[1:]) + if not length: + raise UnexpectedDER("Invalid length of bit string, can't be 0") body = string[1+llen:1+llen+length] rest = string[1+llen+length:] + if expect_unused is not _sentry: + unused = body[0] if isinstance(body[0], integer_types) \ + else ord(body[0]) + if not 0 <= unused <= 7: + raise UnexpectedDER("Invalid encoding of unused bits") + if expect_unused is not None and expect_unused != unused: + raise UnexpectedDER("Unexpected number of unused bits") + body = body[1:] + if unused: + if not body: + raise UnexpectedDER("Invalid encoding of empty bit string") + last = body[-1] if isinstance(body[-1], integer_types) else \ + ord(body[-1]) + # verify that all the unused bits are set to zero (DER requirement) + if last & (2 ** unused - 1): + raise UnexpectedDER("Non zero padding bits in bit string") + if expect_unused is None: + body = (body, unused) return body, rest # SEQUENCE([1, STRING(secexp), cont[0], OBJECT(curvename), cont[1], BINTSTRING) diff --git a/src/ecdsa/keys.py b/src/ecdsa/keys.py index 8c006b51..5a979d5a 100644 --- a/src/ecdsa/keys.py +++ b/src/ecdsa/keys.py @@ -149,16 +149,14 @@ def from_der(klass, string): raise der.UnexpectedDER("Unexpected object identifier in DER " "encoding: {0!r}".format(oid_pk)) curve = find_curve(oid_curve) - point_str, empty = der.remove_bitstring(point_str_bitstring) + point_str, empty = der.remove_bitstring(point_str_bitstring, 0) if empty != b(""): raise der.UnexpectedDER("trailing junk after pubkey pointstring: %s" % binascii.hexlify(empty)) - # the point encoding is padded with a zero byte # raw encoding of point is invalid in DER files - if not point_str.startswith(b("\x00")) or \ - len(point_str[1:]) == curve.verifying_key_length: + if len(point_str) == curve.verifying_key_length: raise der.UnexpectedDER("Malformed encoding of public point") - return klass.from_string(point_str[1:], curve) + return klass.from_string(point_str, curve) @classmethod def from_public_key_recovery(cls, signature, data, curve, hashfunc=sha1, @@ -226,10 +224,12 @@ def to_pem(self): return der.topem(self.to_der(), "PUBLIC KEY") def to_der(self, point_encoding="uncompressed"): - point_str = b("\x00") + self.to_string(point_encoding) + point_str = self.to_string(point_encoding) return der.encode_sequence(der.encode_sequence(encoded_oid_ecPublicKey, self.curve.encoded_oid), - der.encode_bitstring(point_str)) + # 0 is the number of unused bits in the + # bit string + der.encode_bitstring(point_str, 0)) def verify(self, signature, data, hashfunc=None, sigdecode=sigdecode_string): hashfunc = hashfunc or self.default_hashfunc @@ -334,7 +334,7 @@ def from_der(klass, string, hashfunc=sha1): # if tag != 1: # raise der.UnexpectedDER("expected tag 1 in DER privkey, got %d" # % tag) - # pubkey_str = der.remove_bitstring(pubkey_bitstring) + # pubkey_str = der.remove_bitstring(pubkey_bitstring, 0) # if empty != "": # raise der.UnexpectedDER("trailing junk after DER privkey " # "pubkeystr: %s" % binascii.hexlify(empty)) @@ -356,13 +356,15 @@ def to_pem(self): def to_der(self, point_encoding="uncompressed"): # SEQ([int(1), octetstring(privkey),cont[0], oid(secp224r1), # cont[1],bitstring]) - encoded_vk = b("\x00") + \ - self.get_verifying_key().to_string(point_encoding) - return der.encode_sequence(der.encode_integer(1), - der.encode_octet_string(self.to_string()), - der.encode_constructed(0, self.curve.encoded_oid), - der.encode_constructed(1, der.encode_bitstring(encoded_vk)), - ) + encoded_vk = self.get_verifying_key().to_string(point_encoding) + # the 0 in encode_bitstring specifies the number of unused bits + # in the `encoded_vk` string + return der.encode_sequence( + der.encode_integer(1), + der.encode_octet_string(self.to_string()), + der.encode_constructed(0, self.curve.encoded_oid), + der.encode_constructed(1, der.encode_bitstring(encoded_vk, 0)), + ) def get_verifying_key(self): return self.verifying_key diff --git a/src/ecdsa/test_der.py b/src/ecdsa/test_der.py index 76bdfd46..cae04e0b 100644 --- a/src/ecdsa/test_der.py +++ b/src/ecdsa/test_der.py @@ -5,8 +5,11 @@ import unittest2 as unittest except ImportError: import unittest -from .der import remove_integer, UnexpectedDER, read_length +from .der import remove_integer, UnexpectedDER, read_length, encode_bitstring,\ + remove_bitstring from six import b +import pytest +import warnings class TestRemoveInteger(unittest.TestCase): # DER requires the integers to be 0-padded only if they would be @@ -92,3 +95,128 @@ def test_empty_string(self): def test_length_overflow(self): with self.assertRaises(UnexpectedDER): read_length(b('\x83\x01\x00')) + + +class TestEncodeBitstring(unittest.TestCase): + # DER requires BIT STRINGS to include a number of padding bits in the + # encoded byte string, that padding must be between 0 and 7 + + def test_old_call_convention(self): + """This is the old way to use the function.""" + warnings.simplefilter('always') + with pytest.warns(DeprecationWarning) as warns: + der = encode_bitstring(b'\x00\xff') + + self.assertEqual(len(warns), 1) + self.assertIn("unused= needs to be specified", + warns[0].message.args[0]) + + self.assertEqual(der, b'\x03\x02\x00\xff') + + def test_new_call_convention(self): + """This is how it should be called now.""" + warnings.simplefilter('always') + with pytest.warns(None) as warns: + der = encode_bitstring(b'\xff', 0) + + # verify that new call convention doesn't raise Warnings + self.assertEqual(len(warns), 0) + + self.assertEqual(der, b'\x03\x02\x00\xff') + + def test_implicit_unused_bits(self): + """ + Writing bit string with already included the number of unused bits. + """ + warnings.simplefilter('always') + with pytest.warns(None) as warns: + der = encode_bitstring(b'\x00\xff', None) + + # verify that new call convention doesn't raise Warnings + self.assertEqual(len(warns), 0) + + self.assertEqual(der, b'\x03\x02\x00\xff') + + def test_empty_string(self): + self.assertEqual(encode_bitstring(b'', 0), b'\x03\x01\x00') + + def test_invalid_unused_count(self): + with self.assertRaises(ValueError): + encode_bitstring(b'\xff\x00', 8) + + def test_invalid_unused_with_empty_string(self): + with self.assertRaises(ValueError): + encode_bitstring(b'', 1) + + +class TestRemoveBitstring(unittest.TestCase): + def test_old_call_convention(self): + """This is the old way to call the function.""" + warnings.simplefilter('always') + with pytest.warns(DeprecationWarning) as warns: + bits, rest = remove_bitstring(b'\x03\x02\x00\xff') + + self.assertEqual(len(warns), 1) + self.assertIn("expect_unused= needs to be specified", + warns[0].message.args[0]) + + self.assertEqual(bits, b'\x00\xff') + self.assertEqual(rest, b'') + + def test_new_call_convention(self): + warnings.simplefilter('always') + with pytest.warns(None) as warns: + bits, rest = remove_bitstring(b'\x03\x02\x00\xff', 0) + + self.assertEqual(len(warns), 0) + + self.assertEqual(bits, b'\xff') + self.assertEqual(rest, b'') + + def test_implicit_unexpected_unused(self): + warnings.simplefilter('always') + with pytest.warns(None) as warns: + bits, rest = remove_bitstring(b'\x03\x02\x00\xff', None) + + self.assertEqual(len(warns), 0) + + self.assertEqual(bits, (b'\xff', 0)) + self.assertEqual(rest, b'') + + def test_with_padding(self): + ret, rest = remove_bitstring(b'\x03\x02\x04\xf0', None) + + self.assertEqual(ret, (b'\xf0', 4)) + self.assertEqual(rest, b'') + + def test_not_a_bitstring(self): + with self.assertRaises(UnexpectedDER): + remove_bitstring(b'\x02\x02\x00\xff', None) + + def test_empty_encoding(self): + with self.assertRaises(UnexpectedDER): + remove_bitstring(b'\x03\x00', None) + + def test_empty_string(self): + with self.assertRaises(UnexpectedDER): + remove_bitstring(b'', None) + + def test_no_length(self): + with self.assertRaises(UnexpectedDER): + remove_bitstring(b'\x03', None) + + def test_unexpected_number_of_unused_bits(self): + with self.assertRaises(UnexpectedDER): + remove_bitstring(b'\x03\x02\x00\xff', 1) + + def test_invalid_encoding_of_unused_bits(self): + with self.assertRaises(UnexpectedDER): + remove_bitstring(b'\x03\x03\x08\xff\x00', None) + + def test_invalid_encoding_of_empty_string(self): + with self.assertRaises(UnexpectedDER): + remove_bitstring(b'\x03\x01\x01', None) + + def test_invalid_padding_bits(self): + with self.assertRaises(UnexpectedDER): + remove_bitstring(b'\x03\x02\x01\xff', None) diff --git a/src/ecdsa/test_pyecdsa.py b/src/ecdsa/test_pyecdsa.py index 36b989df..8644e3dc 100644 --- a/src/ecdsa/test_pyecdsa.py +++ b/src/ecdsa/test_pyecdsa.py @@ -282,7 +282,7 @@ def test_vk_from_der_garbage_after_curve_oid(self): curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1)) + \ b('garbage') enc_type_der = der.encode_sequence(type_oid_der, curve_oid_der) - point_der = der.encode_bitstring(b'\x00\xff') + point_der = der.encode_bitstring(b'\x00\xff', None) to_decode = der.encode_sequence(enc_type_der, point_der) with self.assertRaises(der.UnexpectedDER): @@ -292,7 +292,7 @@ def test_vk_from_der_invalid_key_type(self): type_oid_der = der.encode_oid(*(1, 2, 3)) curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1)) enc_type_der = der.encode_sequence(type_oid_der, curve_oid_der) - point_der = der.encode_bitstring(b'\x00\xff') + point_der = der.encode_bitstring(b'\x00\xff', None) to_decode = der.encode_sequence(enc_type_der, point_der) with self.assertRaises(der.UnexpectedDER): @@ -302,7 +302,7 @@ def test_vk_from_der_garbage_after_point_string(self): type_oid_der = encoded_oid_ecPublicKey curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1)) enc_type_der = der.encode_sequence(type_oid_der, curve_oid_der) - point_der = der.encode_bitstring(b'\x00\xff') + b('garbage') + point_der = der.encode_bitstring(b'\x00\xff', None) + b('garbage') to_decode = der.encode_sequence(enc_type_der, point_der) with self.assertRaises(der.UnexpectedDER): @@ -312,7 +312,27 @@ def test_vk_from_der_invalid_bitstring(self): type_oid_der = encoded_oid_ecPublicKey curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1)) enc_type_der = der.encode_sequence(type_oid_der, curve_oid_der) - point_der = der.encode_bitstring(b'\x08\xff') + point_der = der.encode_bitstring(b'\x08\xff', None) + to_decode = der.encode_sequence(enc_type_der, point_der) + + with self.assertRaises(der.UnexpectedDER): + VerifyingKey.from_der(to_decode) + + def test_vk_from_der_with_invalid_length_of_encoding(self): + type_oid_der = encoded_oid_ecPublicKey + curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1)) + enc_type_der = der.encode_sequence(type_oid_der, curve_oid_der) + point_der = der.encode_bitstring(b'\xff'*64, 0) + to_decode = der.encode_sequence(enc_type_der, point_der) + + with self.assertRaises(MalformedPointError): + VerifyingKey.from_der(to_decode) + + def test_vk_from_der_with_raw_encoding(self): + type_oid_der = encoded_oid_ecPublicKey + curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1)) + enc_type_der = der.encode_sequence(type_oid_der, curve_oid_der) + point_der = der.encode_bitstring(b'\xff'*48, 0) to_decode = der.encode_sequence(enc_type_der, point_der) with self.assertRaises(der.UnexpectedDER):