diff --git a/src/ecdsa/keys.py b/src/ecdsa/keys.py index 7dd2c111..87adc401 100644 --- a/src/ecdsa/keys.py +++ b/src/ecdsa/keys.py @@ -139,7 +139,7 @@ def __repr__(self): pub_key = self.to_string("compressed") return "VerifyingKey.from_string({0!r}, {1!r}, {2})".format( pub_key, self.curve, self.default_hashfunc().name) - + def __eq__(self, other): """Return True if the points are identical, False otherwise.""" if isinstance(other, VerifyingKey): @@ -290,7 +290,7 @@ def from_string(cls, string, curve=NIST192p, hashfunc=sha1, return cls.from_public_point(point, curve, hashfunc) @classmethod - def from_pem(cls, string): + def from_pem(cls, string, hashfunc=sha1): """ Initialise from public key stored in :term:`PEM` format. @@ -308,10 +308,10 @@ def from_pem(cls, string): :return: Initialised VerifyingKey object :rtype: VerifyingKey """ - return cls.from_der(der.unpem(string)) + return cls.from_der(der.unpem(string), hashfunc=hashfunc) @classmethod - def from_der(cls, string): + def from_der(cls, string, hashfunc=sha1): """ Initialise the key stored in :term:`DER` format. @@ -364,7 +364,7 @@ def from_der(cls, string): # raw encoding of point is invalid in DER files if len(point_str) == curve.verifying_key_length: raise der.UnexpectedDER("Malformed encoding of public point") - return cls.from_string(point_str, curve) + return cls.from_string(point_str, curve, hashfunc=hashfunc) @classmethod def from_public_key_recovery(cls, signature, data, curve, hashfunc=sha1, diff --git a/src/ecdsa/test_keys.py b/src/ecdsa/test_keys.py index a6f23389..1ca66d2b 100644 --- a/src/ecdsa/test_keys.py +++ b/src/ecdsa/test_keys.py @@ -120,13 +120,14 @@ def setUpClass(cls): "-----BEGIN PUBLIC KEY-----\n" "MEkwEwYHKoZIzj0CAQYIKoZIzj0DAQEDMgAEuIF30ITvF/XkVjlAgCg2D59ZtKTX\n" "Jk5i2gZR3OR6NaTFtFz1FZNCOotVe5wgmfNs\n" - "-----END PUBLIC KEY-----\n") - + "-----END PUBLIC KEY-----\n") + cls.key_pem = key_str + cls.key_bytes = unpem(key_str) assert isinstance(cls.key_bytes, bytes) cls.vk = VerifyingKey.from_pem(key_str) cls.sk = SigningKey.from_pem(prv_key_str) - + key_str = ( "-----BEGIN PUBLIC KEY-----\n" "MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE4H3iRbG4TSrsSRb/gusPQB/4YcN8\n" @@ -135,6 +136,16 @@ def setUpClass(cls): ) cls.vk2 = VerifyingKey.from_pem(key_str) + def test_custom_hashfunc(self): + vk = VerifyingKey.from_der(self.key_bytes, hashlib.sha256) + + self.assertIs(vk.default_hashfunc, hashlib.sha256) + + def test_from_pem_with_custom_hashfunc(self): + vk = VerifyingKey.from_pem(self.key_pem, hashlib.sha256) + + self.assertIs(vk.default_hashfunc, hashlib.sha256) + def test_bytes(self): vk = VerifyingKey.from_der(self.key_bytes) @@ -166,14 +177,14 @@ def test_array_array_of_bytes_memoryview(self): vk = VerifyingKey.from_der(buffer(arr)) self.assertEqual(self.vk.to_string(), vk.to_string()) - - def test_equality_on_verifying_keys(self): + + def test_equality_on_verifying_keys(self): self.assertEqual(self.vk, self.sk.get_verifying_key()) - - def test_inequality_on_verifying_keys(self): + + def test_inequality_on_verifying_keys(self): self.assertNotEqual(self.vk, self.vk2) - - def test_inequality_on_verifying_keys_not_implemented(self): + + def test_inequality_on_verifying_keys_not_implemented(self): self.assertNotEqual(self.vk, None)