Skip to content

Commit

Permalink
make our code follow the new calling convention for bistring
Browse files Browse the repository at this point in the history
  • Loading branch information
tomato42 committed Oct 17, 2019
1 parent 627a7b2 commit ba3d483
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 19 deletions.
32 changes: 17 additions & 15 deletions src/ecdsa/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,16 +149,14 @@ def from_der(klass, string):
raise der.UnexpectedDER("Unexpected object identifier in DER "
"encoding: {0!r}".format(oid_pk))
curve = find_curve(oid_curve)
point_str, empty = der.remove_bitstring(point_str_bitstring)
point_str, empty = der.remove_bitstring(point_str_bitstring, 0)
if empty != b(""):
raise der.UnexpectedDER("trailing junk after pubkey pointstring: %s" %
binascii.hexlify(empty))
# the point encoding is padded with a zero byte
# raw encoding of point is invalid in DER files
if not point_str.startswith(b("\x00")) or \
len(point_str[1:]) == curve.verifying_key_length:
if len(point_str) == curve.verifying_key_length:
raise der.UnexpectedDER("Malformed encoding of public point")
return klass.from_string(point_str[1:], curve)
return klass.from_string(point_str, curve)

@classmethod
def from_public_key_recovery(cls, signature, data, curve, hashfunc=sha1,
Expand Down Expand Up @@ -228,10 +226,12 @@ def to_pem(self):
def to_der(self, point_encoding="uncompressed"):
if point_encoding == "raw":
raise ValueError("raw point_encoding not allowed in DER")
point_str = b("\x00") + self.to_string(point_encoding)
point_str = self.to_string(point_encoding)
return der.encode_sequence(der.encode_sequence(encoded_oid_ecPublicKey,
self.curve.encoded_oid),
der.encode_bitstring(point_str))
# 0 is the number of unused bits in the
# bit string
der.encode_bitstring(point_str, 0))

def verify(self, signature, data, hashfunc=None, sigdecode=sigdecode_string):
hashfunc = hashfunc or self.default_hashfunc
Expand Down Expand Up @@ -336,7 +336,7 @@ def from_der(klass, string, hashfunc=sha1):
# if tag != 1:
# raise der.UnexpectedDER("expected tag 1 in DER privkey, got %d"
# % tag)
# pubkey_str = der.remove_bitstring(pubkey_bitstring)
# pubkey_str = der.remove_bitstring(pubkey_bitstring, 0)
# if empty != "":
# raise der.UnexpectedDER("trailing junk after DER privkey "
# "pubkeystr: %s" % binascii.hexlify(empty))
Expand All @@ -360,13 +360,15 @@ def to_der(self, point_encoding="uncompressed"):
# cont[1],bitstring])
if point_encoding == "raw":
raise ValueError("raw encoding not allowed in DER")
encoded_vk = b("\x00") + \
self.get_verifying_key().to_string(point_encoding)
return der.encode_sequence(der.encode_integer(1),
der.encode_octet_string(self.to_string()),
der.encode_constructed(0, self.curve.encoded_oid),
der.encode_constructed(1, der.encode_bitstring(encoded_vk)),
)
encoded_vk = self.get_verifying_key().to_string(point_encoding)
# the 0 in encode_bitstring specifies the number of unused bits
# in the `encoded_vk` string
return der.encode_sequence(
der.encode_integer(1),
der.encode_octet_string(self.to_string()),
der.encode_constructed(0, self.curve.encoded_oid),
der.encode_constructed(1, der.encode_bitstring(encoded_vk, 0)),
)

def get_verifying_key(self):
return self.verifying_key
Expand Down
28 changes: 24 additions & 4 deletions src/ecdsa/test_pyecdsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def test_vk_from_der_garbage_after_curve_oid(self):
curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1)) + \
b('garbage')
enc_type_der = der.encode_sequence(type_oid_der, curve_oid_der)
point_der = der.encode_bitstring(b'\x00\xff')
point_der = der.encode_bitstring(b'\x00\xff', None)
to_decode = der.encode_sequence(enc_type_der, point_der)

with self.assertRaises(der.UnexpectedDER):
Expand All @@ -305,7 +305,7 @@ def test_vk_from_der_invalid_key_type(self):
type_oid_der = der.encode_oid(*(1, 2, 3))
curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1))
enc_type_der = der.encode_sequence(type_oid_der, curve_oid_der)
point_der = der.encode_bitstring(b'\x00\xff')
point_der = der.encode_bitstring(b'\x00\xff', None)
to_decode = der.encode_sequence(enc_type_der, point_der)

with self.assertRaises(der.UnexpectedDER):
Expand All @@ -315,7 +315,7 @@ def test_vk_from_der_garbage_after_point_string(self):
type_oid_der = encoded_oid_ecPublicKey
curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1))
enc_type_der = der.encode_sequence(type_oid_der, curve_oid_der)
point_der = der.encode_bitstring(b'\x00\xff') + b('garbage')
point_der = der.encode_bitstring(b'\x00\xff', None) + b('garbage')
to_decode = der.encode_sequence(enc_type_der, point_der)

with self.assertRaises(der.UnexpectedDER):
Expand All @@ -325,7 +325,27 @@ def test_vk_from_der_invalid_bitstring(self):
type_oid_der = encoded_oid_ecPublicKey
curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1))
enc_type_der = der.encode_sequence(type_oid_der, curve_oid_der)
point_der = der.encode_bitstring(b'\x08\xff')
point_der = der.encode_bitstring(b'\x08\xff', None)
to_decode = der.encode_sequence(enc_type_der, point_der)

with self.assertRaises(der.UnexpectedDER):
VerifyingKey.from_der(to_decode)

def test_vk_from_der_with_invalid_length_of_encoding(self):
type_oid_der = encoded_oid_ecPublicKey
curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1))
enc_type_der = der.encode_sequence(type_oid_der, curve_oid_der)
point_der = der.encode_bitstring(b'\xff'*64, 0)
to_decode = der.encode_sequence(enc_type_der, point_der)

with self.assertRaises(MalformedPointError):
VerifyingKey.from_der(to_decode)

def test_vk_from_der_with_raw_encoding(self):
type_oid_der = encoded_oid_ecPublicKey
curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1))
enc_type_der = der.encode_sequence(type_oid_der, curve_oid_der)
point_der = der.encode_bitstring(b'\xff'*48, 0)
to_decode = der.encode_sequence(enc_type_der, point_der)

with self.assertRaises(der.UnexpectedDER):
Expand Down

0 comments on commit ba3d483

Please sign in to comment.