From 85ba9900d09cbd103ea434f7b2e3a54ebe1a9b1a Mon Sep 17 00:00:00 2001 From: Hubert Kario Date: Wed, 16 Oct 2019 01:10:00 +0200 Subject: [PATCH 1/2] make encode_bitstring and decode_bitstring follow ASN.1 the functions required the callees to handle the encoded length of unused bits, handle it inside the functions now also add test coverage for that deprecate the old calling convention --- src/ecdsa/der.py | 124 +++++++++++++++++++++++++++++++++++++++- src/ecdsa/test_der.py | 130 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 250 insertions(+), 4 deletions(-) diff --git a/src/ecdsa/der.py b/src/ecdsa/der.py index 3ed7f6c3..2d2a7a25 100644 --- a/src/ecdsa/der.py +++ b/src/ecdsa/der.py @@ -2,6 +2,7 @@ import binascii import base64 +import warnings from six import int2byte, b, integer_types, text_type @@ -29,8 +30,61 @@ def encode_integer(r): return b("\x02") + int2byte(len(s)+1) + b("\x00") + s -def encode_bitstring(s): - return b("\x03") + encode_length(len(s)) + s +# sentry object to check if an argument was specified (used to detect +# deprecated calling convention) +_sentry = object() + + +def encode_bitstring(s, unused=_sentry): + """ + Encode a binary string as a BIT STRING using :term:`DER` encoding. + + Note, because there is no native Python object that can encode an actual + bit string, this function only accepts byte strings as the `s` argument. + The byte string is the actual bit string that will be encoded, padded + on the right (least significant bits, looking from big endian perspective) + to the first full byte. If the bit string has a bit length that is multiple + of 8, then the padding should not be included. For correct DER encoding + the padding bits MUST be set to 0. + + Number of bits of padding need to be provided as the `unused` parameter. + In case they are specified as None, it means the number of unused bits + is already encoded in the string as the first byte. + + The deprecated call convention specifies just the `s` parameters and + encodes the number of unused bits as first parameter (same convention + as with None). + + Empty string must be encoded with `unused` specified as 0. + + Future version of python-ecdsa will make specifying the `unused` argument + mandatory. + + :param s: bytes to encode + :type s: bytes like object + :param unused: number of bits at the end of `s` that are unused, must be + between 0 and 7 (inclusive) + :type unused: int or None + + :raises ValueError: when `unused` is too large or too small + + :return: `s` encoded using DER + :rtype: bytes + """ + encoded_unused = b'' + len_extra = 0 + if unused is _sentry: + warnings.warn("Legacy call convention used, unused= needs to be " + "specified", + DeprecationWarning) + elif unused is not None: + if not 0 <= unused <= 7: + raise ValueError("unused must be integer between 0 and 7") + if unused and not s: + raise ValueError("unused is non-zero but s is empty") + encoded_unused = int2byte(unused) + len_extra = 1 + return b("\x03") + encode_length(len(s) + len_extra) + encoded_unused + s def encode_octet_string(s): @@ -198,13 +252,77 @@ def read_length(string): return int(binascii.hexlify(string[1:1+llen]), 16), 1+llen -def remove_bitstring(string): +def remove_bitstring(string, expect_unused=_sentry): + """ + Remove a BIT STRING object from `string` following :term:`DER`. + + The `expect_unused` can be used to specify if the bit string should + have the amount of unused bits decoded or not. If it's an integer, any + read BIT STRING that has number of unused bits different from specified + value will cause UnexpectedDER exception to be raised (this is especially + useful when decoding BIT STRINGS that have DER encoded object in them; + DER encoding is byte oriented, so the unused bits will always equal 0). + + If the `expect_unused` is specified as None, the first element returned + will be a tuple, with the first value being the extracted bit string + while the second value will be the decoded number of unused bits. + + If the `expect_unused` is unspecified, the decoding of byte with + number of unused bits will not be attempted and the bit string will be + returned as-is, the callee will be required to decode it and verify its + correctness. + + Future version of python will require the `expected_unused` parameter + to be specified. + + :param string: string of bytes to extract the BIT STRING from + :type string: bytes like object + :param expect_unused: number of bits that should be unused in the BIT + STRING, or None, to return it to caller + :type expect_unused: int or None + + :raises UnexpectedDER: when the encoding does not follow DER. + + :return: a tuple with first element being the extracted bit string and + the second being the remaining bytes in the string (if any); if the + `expect_unused` is specified as None, the first element of the returned + tuple will be a tuple itself, with first element being the bit string + as bytes and the second element bing the number of unused bits at the + end of the byte array as an integer + :rtype: tuple + """ + if not string: + raise UnexpectedDER("Empty string does not encode a bitstring") + if expect_unused is _sentry: + warnings.warn("Legacy call convention used, expect_unused= needs to be" + " specified", + DeprecationWarning) num = string[0] if isinstance(string[0], integer_types) else ord(string[0]) if not string.startswith(b("\x03")): raise UnexpectedDER("wanted bitstring (0x03), got 0x%02x" % num) length, llen = read_length(string[1:]) + if not length: + raise UnexpectedDER("Invalid length of bit string, can't be 0") body = string[1+llen:1+llen+length] rest = string[1+llen+length:] + if expect_unused is not _sentry: + unused = body[0] if isinstance(body[0], integer_types) \ + else ord(body[0]) + if not 0 <= unused <= 7: + raise UnexpectedDER("Invalid encoding of unused bits") + if expect_unused is not None and expect_unused != unused: + raise UnexpectedDER("Unexpected number of unused bits") + body = body[1:] + if unused: + if not body: + raise UnexpectedDER("Invalid encoding of empty bit string") + last = body[-1] if isinstance(body[-1], integer_types) else \ + ord(body[-1]) + # verify that all the unused bits are set to zero (DER requirement) + if last & (2 ** unused - 1): + raise UnexpectedDER("Non zero padding bits in bit string") + if expect_unused is None: + body = (body, unused) return body, rest # SEQUENCE([1, STRING(secexp), cont[0], OBJECT(curvename), cont[1], BINTSTRING) diff --git a/src/ecdsa/test_der.py b/src/ecdsa/test_der.py index 76bdfd46..cae04e0b 100644 --- a/src/ecdsa/test_der.py +++ b/src/ecdsa/test_der.py @@ -5,8 +5,11 @@ import unittest2 as unittest except ImportError: import unittest -from .der import remove_integer, UnexpectedDER, read_length +from .der import remove_integer, UnexpectedDER, read_length, encode_bitstring,\ + remove_bitstring from six import b +import pytest +import warnings class TestRemoveInteger(unittest.TestCase): # DER requires the integers to be 0-padded only if they would be @@ -92,3 +95,128 @@ def test_empty_string(self): def test_length_overflow(self): with self.assertRaises(UnexpectedDER): read_length(b('\x83\x01\x00')) + + +class TestEncodeBitstring(unittest.TestCase): + # DER requires BIT STRINGS to include a number of padding bits in the + # encoded byte string, that padding must be between 0 and 7 + + def test_old_call_convention(self): + """This is the old way to use the function.""" + warnings.simplefilter('always') + with pytest.warns(DeprecationWarning) as warns: + der = encode_bitstring(b'\x00\xff') + + self.assertEqual(len(warns), 1) + self.assertIn("unused= needs to be specified", + warns[0].message.args[0]) + + self.assertEqual(der, b'\x03\x02\x00\xff') + + def test_new_call_convention(self): + """This is how it should be called now.""" + warnings.simplefilter('always') + with pytest.warns(None) as warns: + der = encode_bitstring(b'\xff', 0) + + # verify that new call convention doesn't raise Warnings + self.assertEqual(len(warns), 0) + + self.assertEqual(der, b'\x03\x02\x00\xff') + + def test_implicit_unused_bits(self): + """ + Writing bit string with already included the number of unused bits. + """ + warnings.simplefilter('always') + with pytest.warns(None) as warns: + der = encode_bitstring(b'\x00\xff', None) + + # verify that new call convention doesn't raise Warnings + self.assertEqual(len(warns), 0) + + self.assertEqual(der, b'\x03\x02\x00\xff') + + def test_empty_string(self): + self.assertEqual(encode_bitstring(b'', 0), b'\x03\x01\x00') + + def test_invalid_unused_count(self): + with self.assertRaises(ValueError): + encode_bitstring(b'\xff\x00', 8) + + def test_invalid_unused_with_empty_string(self): + with self.assertRaises(ValueError): + encode_bitstring(b'', 1) + + +class TestRemoveBitstring(unittest.TestCase): + def test_old_call_convention(self): + """This is the old way to call the function.""" + warnings.simplefilter('always') + with pytest.warns(DeprecationWarning) as warns: + bits, rest = remove_bitstring(b'\x03\x02\x00\xff') + + self.assertEqual(len(warns), 1) + self.assertIn("expect_unused= needs to be specified", + warns[0].message.args[0]) + + self.assertEqual(bits, b'\x00\xff') + self.assertEqual(rest, b'') + + def test_new_call_convention(self): + warnings.simplefilter('always') + with pytest.warns(None) as warns: + bits, rest = remove_bitstring(b'\x03\x02\x00\xff', 0) + + self.assertEqual(len(warns), 0) + + self.assertEqual(bits, b'\xff') + self.assertEqual(rest, b'') + + def test_implicit_unexpected_unused(self): + warnings.simplefilter('always') + with pytest.warns(None) as warns: + bits, rest = remove_bitstring(b'\x03\x02\x00\xff', None) + + self.assertEqual(len(warns), 0) + + self.assertEqual(bits, (b'\xff', 0)) + self.assertEqual(rest, b'') + + def test_with_padding(self): + ret, rest = remove_bitstring(b'\x03\x02\x04\xf0', None) + + self.assertEqual(ret, (b'\xf0', 4)) + self.assertEqual(rest, b'') + + def test_not_a_bitstring(self): + with self.assertRaises(UnexpectedDER): + remove_bitstring(b'\x02\x02\x00\xff', None) + + def test_empty_encoding(self): + with self.assertRaises(UnexpectedDER): + remove_bitstring(b'\x03\x00', None) + + def test_empty_string(self): + with self.assertRaises(UnexpectedDER): + remove_bitstring(b'', None) + + def test_no_length(self): + with self.assertRaises(UnexpectedDER): + remove_bitstring(b'\x03', None) + + def test_unexpected_number_of_unused_bits(self): + with self.assertRaises(UnexpectedDER): + remove_bitstring(b'\x03\x02\x00\xff', 1) + + def test_invalid_encoding_of_unused_bits(self): + with self.assertRaises(UnexpectedDER): + remove_bitstring(b'\x03\x03\x08\xff\x00', None) + + def test_invalid_encoding_of_empty_string(self): + with self.assertRaises(UnexpectedDER): + remove_bitstring(b'\x03\x01\x01', None) + + def test_invalid_padding_bits(self): + with self.assertRaises(UnexpectedDER): + remove_bitstring(b'\x03\x02\x01\xff', None) From 650a2a8e6dccfc188ab2a377021609f8e5fccbaf Mon Sep 17 00:00:00 2001 From: Hubert Kario Date: Wed, 16 Oct 2019 01:18:24 +0200 Subject: [PATCH 2/2] make our code follow the new calling convention for bistring --- src/ecdsa/keys.py | 32 +++++++++++++++++--------------- src/ecdsa/test_pyecdsa.py | 28 ++++++++++++++++++++++++---- 2 files changed, 41 insertions(+), 19 deletions(-) diff --git a/src/ecdsa/keys.py b/src/ecdsa/keys.py index 8c006b51..5a979d5a 100644 --- a/src/ecdsa/keys.py +++ b/src/ecdsa/keys.py @@ -149,16 +149,14 @@ def from_der(klass, string): raise der.UnexpectedDER("Unexpected object identifier in DER " "encoding: {0!r}".format(oid_pk)) curve = find_curve(oid_curve) - point_str, empty = der.remove_bitstring(point_str_bitstring) + point_str, empty = der.remove_bitstring(point_str_bitstring, 0) if empty != b(""): raise der.UnexpectedDER("trailing junk after pubkey pointstring: %s" % binascii.hexlify(empty)) - # the point encoding is padded with a zero byte # raw encoding of point is invalid in DER files - if not point_str.startswith(b("\x00")) or \ - len(point_str[1:]) == curve.verifying_key_length: + if len(point_str) == curve.verifying_key_length: raise der.UnexpectedDER("Malformed encoding of public point") - return klass.from_string(point_str[1:], curve) + return klass.from_string(point_str, curve) @classmethod def from_public_key_recovery(cls, signature, data, curve, hashfunc=sha1, @@ -226,10 +224,12 @@ def to_pem(self): return der.topem(self.to_der(), "PUBLIC KEY") def to_der(self, point_encoding="uncompressed"): - point_str = b("\x00") + self.to_string(point_encoding) + point_str = self.to_string(point_encoding) return der.encode_sequence(der.encode_sequence(encoded_oid_ecPublicKey, self.curve.encoded_oid), - der.encode_bitstring(point_str)) + # 0 is the number of unused bits in the + # bit string + der.encode_bitstring(point_str, 0)) def verify(self, signature, data, hashfunc=None, sigdecode=sigdecode_string): hashfunc = hashfunc or self.default_hashfunc @@ -334,7 +334,7 @@ def from_der(klass, string, hashfunc=sha1): # if tag != 1: # raise der.UnexpectedDER("expected tag 1 in DER privkey, got %d" # % tag) - # pubkey_str = der.remove_bitstring(pubkey_bitstring) + # pubkey_str = der.remove_bitstring(pubkey_bitstring, 0) # if empty != "": # raise der.UnexpectedDER("trailing junk after DER privkey " # "pubkeystr: %s" % binascii.hexlify(empty)) @@ -356,13 +356,15 @@ def to_pem(self): def to_der(self, point_encoding="uncompressed"): # SEQ([int(1), octetstring(privkey),cont[0], oid(secp224r1), # cont[1],bitstring]) - encoded_vk = b("\x00") + \ - self.get_verifying_key().to_string(point_encoding) - return der.encode_sequence(der.encode_integer(1), - der.encode_octet_string(self.to_string()), - der.encode_constructed(0, self.curve.encoded_oid), - der.encode_constructed(1, der.encode_bitstring(encoded_vk)), - ) + encoded_vk = self.get_verifying_key().to_string(point_encoding) + # the 0 in encode_bitstring specifies the number of unused bits + # in the `encoded_vk` string + return der.encode_sequence( + der.encode_integer(1), + der.encode_octet_string(self.to_string()), + der.encode_constructed(0, self.curve.encoded_oid), + der.encode_constructed(1, der.encode_bitstring(encoded_vk, 0)), + ) def get_verifying_key(self): return self.verifying_key diff --git a/src/ecdsa/test_pyecdsa.py b/src/ecdsa/test_pyecdsa.py index 36b989df..8644e3dc 100644 --- a/src/ecdsa/test_pyecdsa.py +++ b/src/ecdsa/test_pyecdsa.py @@ -282,7 +282,7 @@ def test_vk_from_der_garbage_after_curve_oid(self): curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1)) + \ b('garbage') enc_type_der = der.encode_sequence(type_oid_der, curve_oid_der) - point_der = der.encode_bitstring(b'\x00\xff') + point_der = der.encode_bitstring(b'\x00\xff', None) to_decode = der.encode_sequence(enc_type_der, point_der) with self.assertRaises(der.UnexpectedDER): @@ -292,7 +292,7 @@ def test_vk_from_der_invalid_key_type(self): type_oid_der = der.encode_oid(*(1, 2, 3)) curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1)) enc_type_der = der.encode_sequence(type_oid_der, curve_oid_der) - point_der = der.encode_bitstring(b'\x00\xff') + point_der = der.encode_bitstring(b'\x00\xff', None) to_decode = der.encode_sequence(enc_type_der, point_der) with self.assertRaises(der.UnexpectedDER): @@ -302,7 +302,7 @@ def test_vk_from_der_garbage_after_point_string(self): type_oid_der = encoded_oid_ecPublicKey curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1)) enc_type_der = der.encode_sequence(type_oid_der, curve_oid_der) - point_der = der.encode_bitstring(b'\x00\xff') + b('garbage') + point_der = der.encode_bitstring(b'\x00\xff', None) + b('garbage') to_decode = der.encode_sequence(enc_type_der, point_der) with self.assertRaises(der.UnexpectedDER): @@ -312,7 +312,27 @@ def test_vk_from_der_invalid_bitstring(self): type_oid_der = encoded_oid_ecPublicKey curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1)) enc_type_der = der.encode_sequence(type_oid_der, curve_oid_der) - point_der = der.encode_bitstring(b'\x08\xff') + point_der = der.encode_bitstring(b'\x08\xff', None) + to_decode = der.encode_sequence(enc_type_der, point_der) + + with self.assertRaises(der.UnexpectedDER): + VerifyingKey.from_der(to_decode) + + def test_vk_from_der_with_invalid_length_of_encoding(self): + type_oid_der = encoded_oid_ecPublicKey + curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1)) + enc_type_der = der.encode_sequence(type_oid_der, curve_oid_der) + point_der = der.encode_bitstring(b'\xff'*64, 0) + to_decode = der.encode_sequence(enc_type_der, point_der) + + with self.assertRaises(MalformedPointError): + VerifyingKey.from_der(to_decode) + + def test_vk_from_der_with_raw_encoding(self): + type_oid_der = encoded_oid_ecPublicKey + curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1)) + enc_type_der = der.encode_sequence(type_oid_der, curve_oid_der) + point_der = der.encode_bitstring(b'\xff'*48, 0) to_decode = der.encode_sequence(enc_type_der, point_der) with self.assertRaises(der.UnexpectedDER):