From a13b1d066f2553b116ff5a07dfb2bcc45549ecfa Mon Sep 17 00:00:00 2001 From: Hubert Kario Date: Sat, 24 Apr 2021 20:31:55 +0200 Subject: [PATCH] allow limiting point formats, don't accept malformed PEM public files Allow specifying what point formats are supported when loading public keys. Limit the point formats when loading PEM and DER public files to the formats actually allowed in them: uncompressed, compressed, and hybrid. Previous code would allow raw encoding too. --- src/ecdsa/ecdh.py | 15 +++++++-- src/ecdsa/keys.py | 71 ++++++++++++++++++++++++++++++++------- src/ecdsa/test_keys.py | 8 ++++- src/ecdsa/test_pyecdsa.py | 59 ++++++++++++++++++++++++++++++++ 4 files changed, 137 insertions(+), 16 deletions(-) diff --git a/src/ecdsa/ecdh.py b/src/ecdsa/ecdh.py index a12e94ee..824a09b4 100644 --- a/src/ecdsa/ecdh.py +++ b/src/ecdsa/ecdh.py @@ -216,7 +216,7 @@ def get_public_key(self): :return: public (verifying) key from local private key. :rtype: VerifyingKey object - """ + """ return self.private_key.get_verifying_key() def load_received_public_key(self, public_key): @@ -237,7 +237,9 @@ def load_received_public_key(self, public_key): raise InvalidCurveError("Curve mismatch.") self.public_key = public_key - def load_received_public_key_bytes(self, public_key_str): + def load_received_public_key_bytes( + self, public_key_str, valid_encodings=None + ): """ Load public key from byte string. @@ -247,9 +249,16 @@ def load_received_public_key_bytes(self, public_key_str): :param public_key_str: public key in bytes string format :type public_key_str: :term:`bytes-like object` + :param valid_encodings: list of acceptable point encoding formats, + supported ones are: :term:`uncompressed`, :term:`compressed`, + :term:`hybrid`, and :term:`raw encoding` (specified with ``raw`` + name). All formats by default (specified with ``None``). + :type valid_encodings: :term:`set-like object` """ return self.load_received_public_key( - VerifyingKey.from_string(public_key_str, self.curve) + VerifyingKey.from_string( + public_key_str, self.curve, valid_encodings + ) ) def load_received_public_key_der(self, public_key_der): diff --git a/src/ecdsa/keys.py b/src/ecdsa/keys.py index ded7cfb0..e0723540 100644 --- a/src/ecdsa/keys.py +++ b/src/ecdsa/keys.py @@ -64,6 +64,10 @@ string) is endianess dependant! Signature computed over ``array.array`` of integers on a big-endian system will not be verified on a little-endian system and vice-versa. + + set-like object + All the types that support the ``in`` operator, like ``list``, + ``tuple``, ``set``, ``frozenset``, etc. """ import binascii @@ -332,7 +336,12 @@ def _from_hybrid(cls, string, curve, validate_point): @classmethod def from_string( - cls, string, curve=NIST192p, hashfunc=sha1, validate_point=True + cls, + string, + curve=NIST192p, + hashfunc=sha1, + validate_point=True, + valid_encodings=None, ): """ Initialise the object from byte encoding of public key. @@ -355,6 +364,11 @@ def from_string( :param validate_point: whether to verify that the point lays on the provided curve or not, defaults to True :type validate_point: bool + :param valid_encodings: list of acceptable point encoding formats, + supported ones are: :term:`uncompressed`, :term:`compressed`, + :term:`hybrid`, and :term:`raw encoding` (specified with ``raw`` + name). All formats by default (specified with ``None``). + :type valid_encodings: :term:`set-like object` :raises MalformedPointError: if the public point does not lay on the curve or the encoding is invalid @@ -362,31 +376,43 @@ def from_string( :return: Initialised VerifyingKey object :rtype: VerifyingKey """ + if valid_encodings is None: + valid_encodings = set( + ["uncompressed", "compressed", "hybrid", "raw"] + ) string = normalise_bytes(string) sig_len = len(string) - if sig_len == curve.verifying_key_length: + if sig_len == curve.verifying_key_length and "raw" in valid_encodings: point = cls._from_raw_encoding(string, curve) - elif sig_len == curve.verifying_key_length + 1: - if string[:1] in (b("\x06"), b("\x07")): + elif sig_len == curve.verifying_key_length + 1 and ( + "hybrid" in valid_encodings or "uncompressed" in valid_encodings + ): + if ( + string[:1] in (b("\x06"), b("\x07")) + and "hybrid" in valid_encodings + ): point = cls._from_hybrid(string, curve, validate_point) - elif string[:1] == b("\x04"): + elif string[:1] == b("\x04") and "uncompressed" in valid_encodings: point = cls._from_raw_encoding(string[1:], curve) else: raise MalformedPointError( "Invalid X9.62 encoding of the public point" ) - elif sig_len == curve.verifying_key_length // 2 + 1: + elif ( + sig_len == curve.verifying_key_length // 2 + 1 + and "compressed" in valid_encodings + ): point = cls._from_compressed(string, curve) else: raise MalformedPointError( "Length of string does not match lengths of " - "any of the supported encodings of {0} " - "curve.".format(curve.name) + "any of the enabled ({1}) encodings of {0} " + "curve.".format(curve.name, ", ".join(valid_encodings)) ) return cls.from_public_point(point, curve, hashfunc, validate_point) @classmethod - def from_pem(cls, string, hashfunc=sha1): + def from_pem(cls, string, hashfunc=sha1, valid_encodings=None): """ Initialise from public key stored in :term:`PEM` format. @@ -400,14 +426,23 @@ def from_pem(cls, string, hashfunc=sha1): :param string: text with PEM-encoded public ECDSA key :type string: str + :param valid_encodings: list of allowed point encodings. + By default :term:`uncompressed`, :term:`compressed`, and + :term:`hybrid`. To read malformed files, include + :term:`raw encoding` with ``raw`` in the list. + :type valid_encodings: :term:`set-like object :return: Initialised VerifyingKey object :rtype: VerifyingKey """ - return cls.from_der(der.unpem(string), hashfunc=hashfunc) + return cls.from_der( + der.unpem(string), + hashfunc=hashfunc, + valid_encodings=valid_encodings, + ) @classmethod - def from_der(cls, string, hashfunc=sha1): + def from_der(cls, string, hashfunc=sha1, valid_encodings=None): """ Initialise the key stored in :term:`DER` format. @@ -432,10 +467,17 @@ def from_der(cls, string, hashfunc=sha1): :param string: binary string with the DER encoding of public ECDSA key :type string: bytes-like object + :param valid_encodings: list of allowed point encodings. + By default :term:`uncompressed`, :term:`compressed`, and + :term:`hybrid`. To read malformed files, include + :term:`raw encoding` with ``raw`` in the list. + :type valid_encodings: :term:`set-like object :return: Initialised VerifyingKey object :rtype: VerifyingKey """ + if valid_encodings is None: + valid_encodings = set(["uncompressed", "compressed", "hybrid"]) string = normalise_bytes(string) # [[oid_ecPublicKey,oid_curve], point_str_bitstring] s1, empty = der.remove_sequence(string) @@ -467,7 +509,12 @@ def from_der(cls, string, hashfunc=sha1): # raw encoding of point is invalid in DER files if len(point_str) == curve.verifying_key_length: raise der.UnexpectedDER("Malformed encoding of public point") - return cls.from_string(point_str, curve, hashfunc=hashfunc) + return cls.from_string( + point_str, + curve, + hashfunc=hashfunc, + valid_encodings=valid_encodings, + ) @classmethod def from_public_key_recovery( diff --git a/src/ecdsa/test_keys.py b/src/ecdsa/test_keys.py index 58cf9613..31353f45 100644 --- a/src/ecdsa/test_keys.py +++ b/src/ecdsa/test_keys.py @@ -13,7 +13,7 @@ import pytest import hashlib -from .keys import VerifyingKey, SigningKey +from .keys import VerifyingKey, SigningKey, MalformedPointError from .der import unpem from .util import ( sigencode_string, @@ -153,6 +153,12 @@ def setUpClass(cls): cls.sk2 = SigningKey.generate(vk.curve) + def test_load_key_with_disabled_format(self): + with self.assertRaises(MalformedPointError) as e: + VerifyingKey.from_der(self.key_bytes, valid_encodings=["raw"]) + + self.assertIn("enabled (raw) encodings", str(e.exception)) + def test_custom_hashfunc(self): vk = VerifyingKey.from_der(self.key_bytes, hashlib.sha256) diff --git a/src/ecdsa/test_pyecdsa.py b/src/ecdsa/test_pyecdsa.py index f61981ff..d1418c01 100644 --- a/src/ecdsa/test_pyecdsa.py +++ b/src/ecdsa/test_pyecdsa.py @@ -722,6 +722,65 @@ def test_decoding(self): from_uncompressed = VerifyingKey.from_string(b("\x06") + enc) self.assertEqual(from_uncompressed.pubkey.point, vk.pubkey.point) + def test_uncompressed_decoding_as_only_alowed(self): + enc = b( + "\x04" + "\x0c\xe0\x1d\xe0d\x1c\x8eS\x8a\xc0\x9eK\xa8x !\xd5\xc2\xc3" + "\xfd\xc8\xa0c\xff\xfb\x02\xb9\xc4\x84)\x1a\x0f\x8b\x87\xa4" + "z\x8a#\xb5\x97\xecO\xb6\xa0HQ\x89*" + ) + vk = VerifyingKey.from_string(enc, valid_encodings=("uncompressed",)) + sk = SigningKey.from_secret_exponent(123456789) + + self.assertEqual(vk, sk.verifying_key) + + def test_raw_decoding_with_blocked_format(self): + enc = b( + "\x0c\xe0\x1d\xe0d\x1c\x8eS\x8a\xc0\x9eK\xa8x !\xd5\xc2\xc3" + "\xfd\xc8\xa0c\xff\xfb\x02\xb9\xc4\x84)\x1a\x0f\x8b\x87\xa4" + "z\x8a#\xb5\x97\xecO\xb6\xa0HQ\x89*" + ) + with self.assertRaises(MalformedPointError) as exp: + VerifyingKey.from_string(enc, valid_encodings=("hybrid",)) + + self.assertIn("hybrid", str(exp.exception)) + + def test_uncompressed_decoding_with_blocked_format(self): + enc = b( + "\x04" + "\x0c\xe0\x1d\xe0d\x1c\x8eS\x8a\xc0\x9eK\xa8x !\xd5\xc2\xc3" + "\xfd\xc8\xa0c\xff\xfb\x02\xb9\xc4\x84)\x1a\x0f\x8b\x87\xa4" + "z\x8a#\xb5\x97\xecO\xb6\xa0HQ\x89*" + ) + with self.assertRaises(MalformedPointError) as exp: + VerifyingKey.from_string(enc, valid_encodings=("hybrid",)) + + self.assertIn("Invalid X9.62 encoding", str(exp.exception)) + + def test_hybrid_decoding_with_blocked_format(self): + enc = b( + "\x06" + "\x0c\xe0\x1d\xe0d\x1c\x8eS\x8a\xc0\x9eK\xa8x !\xd5\xc2\xc3" + "\xfd\xc8\xa0c\xff\xfb\x02\xb9\xc4\x84)\x1a\x0f\x8b\x87\xa4" + "z\x8a#\xb5\x97\xecO\xb6\xa0HQ\x89*" + ) + with self.assertRaises(MalformedPointError) as exp: + VerifyingKey.from_string(enc, valid_encodings=("uncompressed",)) + + self.assertIn("Invalid X9.62 encoding", str(exp.exception)) + + def test_compressed_decoding_with_blocked_format(self): + enc = b( + "\x02" + "\x0c\xe0\x1d\xe0d\x1c\x8eS\x8a\xc0\x9eK\xa8x !\xd5\xc2\xc3" + "\xfd\xc8\xa0c\xff\xfb\x02\xb9\xc4\x84)\x1a\x0f\x8b\x87\xa4" + "z\x8a#\xb5\x97\xecO\xb6\xa0HQ\x89*" + )[:25] + with self.assertRaises(MalformedPointError) as exp: + VerifyingKey.from_string(enc, valid_encodings=("hybrid", "raw")) + + self.assertIn("(hybrid, raw)", str(exp.exception)) + def test_decoding_with_malformed_uncompressed(self): enc = b( "\x0c\xe0\x1d\xe0d\x1c\x8eS\x8a\xc0\x9eK\xa8x !\xd5\xc2\xc3"