Skip to content

Commit

Permalink
Merge pull request #231 from tomato42/fix_equality_tests
Browse files Browse the repository at this point in the history
Fix (in)equality tests
  • Loading branch information
tomato42 committed Dec 10, 2020
2 parents e0626d0 + 7c5b3da commit 24afc18
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 31 deletions.
17 changes: 15 additions & 2 deletions src/ecdsa/ecdsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,20 @@ def __init__(self, generator, point, verify=True):
raise InvalidPointError("Generator point order is bad.")

def __eq__(self, other):
"""Return True if the keys are identical, False otherwise.
Note: for comparison, only placement on the same curve and point
equality is considered, use of the same generator point is not
considered.
"""
if isinstance(other, Public_key):
"""Return True if the points are identical, False otherwise."""
return self.curve == other.curve and self.point == other.point
return NotImplemented

def __ne__(self, other):
"""Return False if the keys are identical, True otherwise."""
return not self == other

def verifies(self, hash, signature):
"""Verify that signature is a valid signature of hash.
Return True if the signature is valid.
Expand Down Expand Up @@ -188,14 +197,18 @@ def __init__(self, public_key, secret_multiplier):
self.secret_multiplier = secret_multiplier

def __eq__(self, other):
"""Return True if the points are identical, False otherwise."""
if isinstance(other, Private_key):
"""Return True if the points are identical, False otherwise."""
return (
self.public_key == other.public_key
and self.secret_multiplier == other.secret_multiplier
)
return NotImplemented

def __ne__(self, other):
"""Return False if the points are identical, True otherwise."""
return not self == other

def sign(self, hash, random_k):
"""Return a signature for the provided hash, using the provided
random nonce. It is absolutely vital that random_k be an unpredictable
Expand Down
34 changes: 27 additions & 7 deletions src/ecdsa/ellipticcurve.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,12 @@
# Signature checking (5.4.2):
# - Verify that r and s are in [1,n-1].
#
# Version of 2008.11.25.
#
# Revision history:
# 2005.12.31 - Initial version.
# 2008.11.25 - Change CurveFp.is_on to contains_point.
#
# Written in 2005 by Peter Pearson and placed in the public domain.
# Modified extensively as part of python-ecdsa.

from __future__ import division

Expand Down Expand Up @@ -92,8 +91,14 @@ def __init__(self, p, a, b, h=None):
self.__h = h

def __eq__(self, other):
"""Return True if other is an identical curve, False otherwise.
Note: the value of the cofactor of the curve is not taken into account
when comparing curves, as it's derived from the base point and
intrinsic curve characteristic (but it's complex to compute),
only the prime and curve parameters are considered.
"""
if isinstance(other, CurveFp):
"""Return True if the curves are identical, False otherwise."""
return (
self.__p == other.__p
and self.__a == other.__a
Expand All @@ -102,7 +107,8 @@ def __eq__(self, other):
return NotImplemented

def __ne__(self, other):
return not (self == other)
"""Return False if other is an identical curve, True otherwise."""
return not self == other

def __hash__(self):
return hash((self.__p, self.__a, self.__b))
Expand Down Expand Up @@ -158,7 +164,7 @@ def __init__(self, curve, x, y, z, order=None, generator=False):
generator=True
:param bool generator: the point provided is a curve generator, as
such, it will be commonly used with scalar multiplication. This will
cause to precompute multiplication table for it
cause to precompute multiplication table generation for it
"""
self.__curve = curve
# since it's generally better (faster) to use scaled points vs unscaled
Expand Down Expand Up @@ -224,7 +230,10 @@ def __setstate__(self, state):
self._update_lock = RWLock()

def __eq__(self, other):
"""Compare two points with each-other."""
"""Compare for equality two points with each-other.
Note: only points that lie on the same curve can be equal.
"""
try:
self._update_lock.reader_acquire()
if other is INFINITY:
Expand Down Expand Up @@ -256,6 +265,10 @@ def __eq__(self, other):
y1 * zz2 * z2 - y2 * zz1 * z1
) % p == 0

def __ne__(self, other):
"""Compare for inequality two points with each-other."""
return not self == other

def order(self):
"""Return the order of the point.
Expand Down Expand Up @@ -757,7 +770,10 @@ def __init__(self, curve, x, y, order=None):
assert self * order == INFINITY

def __eq__(self, other):
"""Return True if the points are identical, False otherwise."""
"""Return True if the points are identical, False otherwise.
Note: only points that lie on the same curve can be equal.
"""
if isinstance(other, Point):
return (
self.__curve == other.__curve
Expand All @@ -766,6 +782,10 @@ def __eq__(self, other):
)
return NotImplemented

def __ne__(self, other):
"""Returns False if points are identical, True otherwise."""
return not self == other

def __neg__(self):
return Point(self.__curve, self.__x, self.__curve.p() - self.__y)

Expand Down
8 changes: 8 additions & 0 deletions src/ecdsa/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,10 @@ def __eq__(self, other):
return self.curve == other.curve and self.pubkey == other.pubkey
return NotImplemented

def __ne__(self, other):
"""Return False if the points are identical, True otherwise."""
return not self == other

@classmethod
def from_public_point(
cls, point, curve=NIST192p, hashfunc=sha1, validate_point=True
Expand Down Expand Up @@ -817,6 +821,10 @@ def __eq__(self, other):
)
return NotImplemented

def __ne__(self, other):
"""Return False if the points are identical, True otherwise."""
return not self == other

@classmethod
def generate(cls, curve=NIST192p, entropy=None, hashfunc=sha1):
"""
Expand Down
8 changes: 4 additions & 4 deletions src/ecdsa/test_der.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_minimal_with_high_bit_set(self):
val, rem = remove_integer(b("\x02\x02\x00\x80"))

self.assertEqual(val, 0x80)
self.assertFalse(rem)
self.assertEqual(rem, b"")

def test_two_zero_bytes_with_high_bit_set(self):
with self.assertRaises(UnexpectedDER):
Expand All @@ -60,19 +60,19 @@ def test_encoding_of_zero(self):
val, rem = remove_integer(b("\x02\x01\x00"))

self.assertEqual(val, 0)
self.assertFalse(rem)
self.assertEqual(rem, b"")

def test_encoding_of_127(self):
val, rem = remove_integer(b("\x02\x01\x7f"))

self.assertEqual(val, 127)
self.assertFalse(rem)
self.assertEqual(rem, b"")

def test_encoding_of_128(self):
val, rem = remove_integer(b("\x02\x02\x00\x80"))

self.assertEqual(val, 128)
self.assertFalse(rem)
self.assertEqual(rem, b"")

def test_wrong_tag(self):
with self.assertRaises(UnexpectedDER) as e:
Expand Down
18 changes: 6 additions & 12 deletions src/ecdsa/test_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,20 +198,16 @@ def test_equality_on_verifying_keys(self):
self.assertEqual(self.vk, self.sk.get_verifying_key())

def test_inequality_on_verifying_keys(self):
# use `==` to workaround instrumental <-> unittest compat issue
self.assertFalse(self.vk == self.vk2)
self.assertNotEqual(self.vk, self.vk2)

def test_inequality_on_verifying_keys_not_implemented(self):
# use `==` to workaround instrumental <-> unittest compat issue
self.assertFalse(self.vk == None)
self.assertNotEqual(self.vk, None)

def test_VerifyingKey_inequality_on_same_curve(self):
# use `==` to workaround instrumental <-> unittest compat issue
self.assertFalse(self.vk == self.sk2.verifying_key)
self.assertNotEqual(self.vk, self.sk2.verifying_key)

def test_SigningKey_inequality_on_same_curve(self):
# use `==` to workaround instrumental <-> unittest compat issue
self.assertFalse(self.sk == self.sk2)
self.assertNotEqual(self.sk, self.sk2)


class TestSigningKey(unittest.TestCase):
Expand Down Expand Up @@ -283,12 +279,10 @@ def test_verify_with_lazy_precompute(self):
self.assertTrue(vk.verify(sig, b"other message"))

def test_inequality_on_signing_keys(self):
# use `==` to workaround instrumental <-> unittest compat issue
self.assertFalse(self.sk1 == self.sk2)
self.assertNotEqual(self.sk1, self.sk2)

def test_inequality_on_signing_keys_not_implemented(self):
# use `==` to workaround instrumental <-> unittest compat issue
self.assertFalse(self.sk1 == None)
self.assertNotEqual(self.sk1, None)


# test VerifyingKey.verify()
Expand Down
12 changes: 6 additions & 6 deletions src/ecdsa/test_pyecdsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,9 +653,9 @@ def test_public_key_recovery(self):
)

# Test if original vk is the list of recovered keys
self.assertTrue(
vk.pubkey.point
in [recovered_vk.pubkey.point for recovered_vk in recovered_vks]
self.assertIn(
vk.pubkey.point,
[recovered_vk.pubkey.point for recovered_vk in recovered_vks],
)

def test_public_key_recovery_with_custom_hash(self):
Expand Down Expand Up @@ -684,9 +684,9 @@ def test_public_key_recovery_with_custom_hash(self):
self.assertEqual(sha256, recovered_vk.default_hashfunc)

# Test if original vk is the list of recovered keys
self.assertTrue(
vk.pubkey.point
in [recovered_vk.pubkey.point for recovered_vk in recovered_vks]
self.assertIn(
vk.pubkey.point,
[recovered_vk.pubkey.point for recovered_vk in recovered_vks],
)

def test_encoding(self):
Expand Down

0 comments on commit 24afc18

Please sign in to comment.