Skip to content

Commit

Permalink
Merge pull request #118 from tomato42/point-compression
Browse files Browse the repository at this point in the history
Support for X9.62 formatted public keys
  • Loading branch information
tomato42 committed Oct 1, 2019
2 parents bcf6afe + cb15e5f commit ba0a5c6
Show file tree
Hide file tree
Showing 3 changed files with 256 additions and 28 deletions.
136 changes: 118 additions & 18 deletions src/ecdsa/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from . import ecdsa
from . import der
from . import rfc6979
from . import ellipticcurve
from .curves import NIST192p, find_curve
from .numbertheory import square_root_mod_prime, SquareRootError
from .ecdsa import RSZeroError
from .util import string_to_number, number_to_string, randrange
from .util import sigencode_string, sigdecode_string
Expand All @@ -23,6 +25,10 @@ class BadDigestError(Exception):
pass


class MalformedPointError(AssertionError):
pass


class VerifyingKey:
def __init__(self, _error__please_use_generate=None):
if not _error__please_use_generate:
Expand All @@ -38,9 +44,8 @@ def from_public_point(klass, point, curve=NIST192p, hashfunc=sha1):
self.pubkey.order = curve.order
return self

@classmethod
def from_string(klass, string, curve=NIST192p, hashfunc=sha1,
validate_point=True):
@staticmethod
def _from_raw_encoding(string, curve, validate_point):
order = curve.order
assert (len(string) == curve.verifying_key_length), \
(len(string), curve.verifying_key_length)
Expand All @@ -50,10 +55,72 @@ def from_string(klass, string, curve=NIST192p, hashfunc=sha1,
assert len(ys) == curve.baselen, (len(ys), curve.baselen)
x = string_to_number(xs)
y = string_to_number(ys)
if validate_point:
assert ecdsa.point_is_valid(curve.generator, x, y)
from . import ellipticcurve
point = ellipticcurve.Point(curve.curve, x, y, order)
if validate_point and not ecdsa.point_is_valid(curve.generator, x, y):
raise MalformedPointError("Point does not lie on the curve")

return ellipticcurve.Point(curve.curve, x, y, order)

@staticmethod
def _from_compressed(string, curve, validate_point):
if string[:1] not in (b('\x02'), b('\x03')):
raise MalformedPointError("Malformed compressed point encoding")

is_even = string[:1] == b('\x02')
x = string_to_number(string[1:])
order = curve.order
p = curve.curve.p()
alpha = (pow(x, 3, p) + (curve.curve.a() * x) + curve.curve.b()) % p
try:
beta = square_root_mod_prime(alpha, p)
except SquareRootError as e:
raise MalformedPointError(
"Encoding does not correspond to a point on curve", e)
if is_even == bool(beta & 1):
y = p - beta
else:
y = beta
if validate_point and not ecdsa.point_is_valid(curve.generator, x, y):
raise MalformedPointError("Point does not lie on curve")
return ellipticcurve.Point(curve.curve, x, y, order)

@classmethod
def _from_hybrid(cls, string, curve, validate_point):
assert string[:1] in (b('\x06'), b('\x07'))

# primarily use the uncompressed as it's easiest to handle
point = cls._from_raw_encoding(string[1:], curve, validate_point)

# but validate if it's self-consistent if we're asked to do that
if validate_point and \
(point.y() & 1 and string[:1] != b('\x07') or
(not point.y() & 1) and string[:1] != b('\x06')):
raise MalformedPointError("Inconsistent hybrid point encoding")

return point

@classmethod
def from_string(klass, string, curve=NIST192p, hashfunc=sha1,
validate_point=True):
sig_len = len(string)
if sig_len == curve.verifying_key_length:
point = klass._from_raw_encoding(string, curve, validate_point)
elif sig_len == curve.verifying_key_length + 1:
if string[:1] in (b('\x06'), b('\x07')):
point = klass._from_hybrid(string, curve, validate_point)
elif string[:1] == b('\x04'):
point = klass._from_raw_encoding(string[1:], curve,
validate_point)
else:
raise MalformedPointError(
"Invalid X9.62 encoding of the public point")
elif sig_len == curve.baselen + 1:
point = klass._from_compressed(string, curve, validate_point)
else:
raise MalformedPointError(
"Length of string does not match lengths of "
"any of the supported encodings of {0} "
"curve.".format(curve.name))

return klass.from_public_point(point, curve, hashfunc)

@classmethod
Expand All @@ -74,14 +141,20 @@ def from_der(klass, string):
if empty != b(""):
raise der.UnexpectedDER("trailing junk after DER pubkey objects: %s" %
binascii.hexlify(empty))
assert oid_pk == oid_ecPublicKey, (oid_pk, oid_ecPublicKey)
if not oid_pk == oid_ecPublicKey:
raise der.UnexpectedDER("Unexpected object identifier in DER "
"encoding: {0!r}".format(oid_pk))
curve = find_curve(oid_curve)
point_str, empty = der.remove_bitstring(point_str_bitstring)
if empty != b(""):
raise der.UnexpectedDER("trailing junk after pubkey pointstring: %s" %
binascii.hexlify(empty))
assert point_str.startswith(b("\x00\x04"))
return klass.from_string(point_str[2:], curve)
# the point encoding is padded with a zero byte
# raw encoding of point is invalid in DER files
if not point_str.startswith(b("\x00")) or \
len(point_str[1:]) == curve.verifying_key_length:
raise der.UnexpectedDER("Malformed encoding of public point")
return klass.from_string(point_str[1:], curve)

@classmethod
def from_public_key_recovery(cls, signature, data, curve, hashfunc=sha1,
Expand Down Expand Up @@ -110,23 +183,49 @@ def from_public_key_recovery_with_digest(klass, signature, digest, curve, hashfu
verifying_keys = [klass.from_public_point(pk.point, curve, hashfunc) for pk in pks]
return verifying_keys

def to_string(self):
# VerifyingKey.from_string(vk.to_string()) == vk as long as the
# curves are the same: the curve itself is not included in the
# serialized form
def _raw_encode(self):
order = self.pubkey.order
x_str = number_to_string(self.pubkey.point.x(), order)
y_str = number_to_string(self.pubkey.point.y(), order)
return x_str + y_str

def _compressed_encode(self):
order = self.pubkey.order
x_str = number_to_string(self.pubkey.point.x(), order)
if self.pubkey.point.y() & 1:
return b('\x03') + x_str
else:
return b('\x02') + x_str

def _hybrid_encode(self):
raw_enc = self._raw_encode()
if self.pubkey.point.y() & 1:
return b('\x07') + raw_enc
else:
return b('\x06') + raw_enc

def to_string(self, encoding="raw"):
# VerifyingKey.from_string(vk.to_string()) == vk as long as the
# curves are the same: the curve itself is not included in the
# serialized form
assert encoding in ("raw", "uncompressed", "compressed", "hybrid")
if encoding == "raw":
return self._raw_encode()
elif encoding == "uncompressed":
return b('\x04') + self._raw_encode()
elif encoding == "hybrid":
return self._hybrid_encode()
else:
return self._compressed_encode()

def to_pem(self):
return der.topem(self.to_der(), "PUBLIC KEY")

def to_der(self):
def to_der(self, point_encoding="uncompressed"):
order = self.pubkey.order
x_str = number_to_string(self.pubkey.point.x(), order)
y_str = number_to_string(self.pubkey.point.y(), order)
point_str = b("\x00\x04") + x_str + y_str
point_str = b("\x00") + self.to_string(point_encoding)
return der.encode_sequence(der.encode_sequence(encoded_oid_ecPublicKey,
self.curve.encoded_oid),
der.encode_bitstring(point_str))
Expand Down Expand Up @@ -247,10 +346,11 @@ def to_pem(self):
# TODO: "BEGIN ECPARAMETERS"
return der.topem(self.to_der(), "EC PRIVATE KEY")

def to_der(self):
def to_der(self, point_encoding="uncompressed"):
# SEQ([int(1), octetstring(privkey),cont[0], oid(secp224r1),
# cont[1],bitstring])
encoded_vk = b("\x00\x04") + self.get_verifying_key().to_string()
encoded_vk = b("\x00") + \
self.get_verifying_key().to_string(point_encoding)
return der.encode_sequence(der.encode_integer(1),
der.encode_octet_string(self.to_string()),
der.encode_constructed(0, self.curve.encoded_oid),
Expand Down
23 changes: 16 additions & 7 deletions src/ecdsa/numbertheory.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@

from __future__ import division

from six import integer_types
from six import integer_types, PY3
from six.moves import reduce
try:
xrange
except NameError:
xrange = range

import math

Expand Down Expand Up @@ -62,7 +66,7 @@ def polynomial_reduce_mod(poly, polymod, p):

while len(poly) >= len(polymod):
if poly[-1] != 0:
for i in range(2, len(polymod) + 1):
for i in xrange(2, len(polymod) + 1):
poly[-i] = (poly[-i] - poly[-1] * polymod[-i]) % p
poly = poly[0:-1]

Expand All @@ -86,8 +90,8 @@ def polynomial_multiply_mod(m1, m2, polymod, p):

# Add together all the cross-terms:

for i in range(len(m1)):
for j in range(len(m2)):
for i in xrange(len(m1)):
for j in xrange(len(m2)):
prod[i + j] = (prod[i + j] + m1[i] * m2[j]) % p

return polynomial_reduce_mod(prod, polymod, p)
Expand Down Expand Up @@ -187,7 +191,12 @@ def square_root_mod_prime(a, p):
return (2 * a * modular_exp(4 * a, (p - 5) // 8, p)) % p
raise RuntimeError("Shouldn't get here.")

for b in range(2, p):
if PY3:
range_top = p
else:
# xrange on python2 can take integers representable as C long only
range_top = min(0x7fffffff, p)
for b in xrange(2, range_top):
if jacobi(b * b - 4 * a, p) == -1:
f = (a, -b, 1)
ff = polynomial_exp_mod((0, 1), (p + 1) // 2, f, p)
Expand Down Expand Up @@ -355,7 +364,7 @@ def carmichael_of_factorized(f_list):
return 1

result = carmichael_of_ppower(f_list[0])
for i in range(1, len(f_list)):
for i in xrange(1, len(f_list)):
result = lcm(result, carmichael_of_ppower(f_list[i]))

return result
Expand Down Expand Up @@ -477,7 +486,7 @@ def is_prime(n):
while (r % 2) == 0:
s = s + 1
r = r // 2
for i in range(t):
for i in xrange(t):
a = smallprimes[i]
y = modular_exp(a, r, n)
if y != 1 and y != n - 1:
Expand Down
Loading

0 comments on commit ba0a5c6

Please sign in to comment.